Compare commits

..

9 Commits

Author SHA1 Message Date
wxy
cef33fce0d fix(workflow): sanitize condition expression building and cache assigner node inputs
- Sanitize condition expression construction in graph_builder.py using json.dumps to prevent potential injection vulnerabilities.
- Cache input data prior to assigner node execution to ensure variable values are correctly captured before processing.
2026-05-07 16:26:47 +08:00
wxy
d9f08860bc feat(LLM node): integrate exception handling and enable branch routing
- Integrate exception handling configuration into LLM nodes, supporting three strategies: throw exception, return default value, or trigger exception branch.
- Modify execution logic to return a result structure containing a branch signal, enabling routing to designated branches upon failure.
- Update graph_builder to support LLM node branch routing logic using the branch_signal field for conditional judgment.
- Implement backward compatibility to support both legacy and new result formats.
2026-05-07 11:43:24 +08:00
wxy
461674c8d8 feat(workflow): parse and substitute template variables in node configurations
- Implement regex matching for {{xxx}} template variable format.
- Enable recursive parsing of all string template variables within node configurations.
- Resolve and substitute template variables with runtime values during input data extraction.
- Support dynamic parsing and substitution of file selector variables in the document extraction node.
- Make strict template variable mode optional and introduce support for default values.
2026-04-29 14:10:02 +08:00
wxy
c59e179cc2 feat(workflow): incorporate model references and streamline parsing logic
- Incorporate model reference metadata (name, provider, type) into workflow nodes and refactor parsing logic to support the new format.
- Streamline code structure by removing redundant model_id fields to enhance maintainability.
2026-04-28 11:18:06 +08:00
Mark
a5670bfff6 Merge branch 'feature/rag2' into develop 2026-04-27 18:17:49 +08:00
Mark
4bef9b578b [fix] document file delete 2026-04-27 17:35:13 +08:00
Mark
c53fcf3981 [fix] old code file_path 2026-04-27 17:10:00 +08:00
Mark
2997558bc8 Merge branch 'release/v0.3.2' into feature/rag2
* release/v0.3.2: (245 commits)
  fix(conversation_schema): refine citations field type to Dict[str, Any]
  fix(tool_controller): re-raise HTTPException to preserve original status codes
  fix(workflow): add reasoning content, suggested questions, citations and audio status support
  feat(workflow): augment logging queries and ameliorate error handling
  fix(api_key): bypass publication check for SERVICE type API keys
  fix(multimodal_service): add '文档内容:' prefix to document text and simplify image placeholder text
  fix(api): convert config_id to string in write_router
  fix(api): convert end_user_id to string in write_router
  fix(multimodal_service): refactor image processing to use intermediate list before extending result
  fix(web): node status ui
  fix(api): correct import paths in memory_read and celery task command
  fix(api): correct import paths in memory_read and celery task command
  refactor(tool): flatten request body parameters for model exposure
  fix(api): correct import paths in memory_read and celery task command
  refactor(workflow): streamline node execution handling and log service logic
  feat(web): http request add process
  feat(web): workflow app logs
  fix(app_chat_service,draft_run_service): move system_prompt augmentation before LangChainAgent instantiation
  fix(app_chat_service,draft_run_service): move system_prompt augmentation before LangChainAgent instantiation
  refactor(http_request): simplify request handling and remove unused fields
  ...

# Conflicts:
#	api/app/controllers/file_controller.py
#	api/app/tasks.py
2026-04-27 16:13:57 +08:00
Mark
30cdf229de [modify] rag file system 2026-04-27 16:05:27 +08:00
48 changed files with 1052 additions and 1159 deletions

View File

@@ -3,9 +3,12 @@ name: Sync to Gitee
on:
push:
branches:
- '**' # All branchs
- main # Production
- develop # Integration
- 'release/*' # Release preparation
- 'hotfix/*' # Urgent fixes
tags:
- '**' # All version tags (v1.0.0, etc.)
- '*' # All version tags (v1.0.0, etc.)
jobs:
sync:

View File

@@ -82,19 +82,32 @@ async def get_preview_chunks(
detail="The file does not exist or you do not have permission to access it"
)
# 5. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
file_path = os.path.join(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 6. Check if the file exists
if not os.path.exists(file_path):
# 5. Get file content from storage backend
if not db_file.file_key:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
detail="File has no storage key (legacy data not migrated)"
)
from app.services.file_storage_service import FileStorageService
import asyncio
storage_service = FileStorageService()
async def _download():
return await storage_service.download_file(db_file.file_key)
try:
file_binary = asyncio.run(_download())
except RuntimeError:
loop = asyncio.new_event_loop()
try:
file_binary = loop.run_until_complete(_download())
finally:
loop.close()
except Exception as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"File not found in storage: {e}"
)
# 7. Document parsing & segmentation
@@ -104,11 +117,12 @@ async def get_preview_chunks(
vision_model = QWenCV(
key=db_knowledge.image2text.api_keys[0].api_key,
model_name=db_knowledge.image2text.api_keys[0].model_name,
lang="Chinese", # Default to Chinese
lang="Chinese",
base_url=db_knowledge.image2text.api_keys[0].api_base
)
from app.core.rag.app.naive import chunk
res = chunk(filename=file_path,
res = chunk(filename=db_file.file_name,
binary=file_binary,
from_page=0,
to_page=5,
callback=progress_callback,

View File

@@ -20,6 +20,7 @@ from app.models.user_model import User
from app.schemas import document_schema
from app.schemas.response_schema import ApiResponse
from app.services import document_service, file_service, knowledge_service
from app.services.file_storage_service import FileStorageService, get_file_storage_service
# Obtain a dedicated API logger
@@ -231,7 +232,8 @@ async def update_document(
async def delete_document(
document_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Delete document
@@ -257,7 +259,7 @@ async def delete_document(
db.commit()
# 3. Delete file
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user)
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
# 4. Delete vector index
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
@@ -305,38 +307,25 @@ async def parse_documents(
detail="The file does not exist or you do not have permission to access it"
)
# 3. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
file_path = os.path.join(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 4. Check if the file exists
api_logger.debug(f"Constructed file path: {file_path}")
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
if not os.path.exists(file_path):
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
# 3. Get file_key for storage backend
if not db_file.file_key:
api_logger.error(f"File has no storage key (legacy data not migrated): file_id={db_file.id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
detail="File has no storage key (legacy data not migrated)"
)
# 5. Obtain knowledge base information
api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
# 4. Obtain knowledge base information
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
# 6. Task: Document parsing, vectorization, and storage
# from app.tasks import parse_document
# parse_document(file_path, document_id)
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
# 5. Dispatch parse task with file_key (not file_path)
task = celery_app.send_task(
"app.core.rag.tasks.parse_document",
args=[db_file.file_key, document_id, db_file.file_name]
)
result = {
"task_id": task.id
}

View File

@@ -1,12 +1,10 @@
import os
from pathlib import Path
import shutil
from typing import Any, Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import FileResponse
from fastapi.responses import Response
from sqlalchemy.orm import Session
from app.core.config import settings
@@ -19,10 +17,14 @@ from app.models.user_model import User
from app.schemas import file_schema, document_schema
from app.schemas.response_schema import ApiResponse
from app.services import file_service, document_service
from app.services.knowledge_service import get_knowledge_by_id as get_kb_by_id
from app.services.file_storage_service import (
FileStorageService,
generate_kb_file_key,
get_file_storage_service,
)
from app.core.quota_stub import check_knowledge_capacity_quota
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
@@ -35,67 +37,37 @@ router = APIRouter(
async def get_files(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
page: int = Query(1, gt=0),
pagesize: int = Query(20, gt=0, le=100),
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Paged query file list
- Support filtering by kb_id and parent_id
- Support keyword search for file names
- Support dynamic sorting
- Return paging metadata + file list
"""
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
"""Paged query file list"""
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}")
# 2. Construct query conditions
filters = [
file_model.File.kb_id == kb_id
]
if page < 1 or pagesize < 1:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The paging parameter must be greater than 0")
filters = [file_model.File.kb_id == kb_id]
if parent_id:
filters.append(file_model.File.parent_id == parent_id)
# Keyword search (fuzzy matching of file name)
if keywords:
filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
# 3. Execute paged query
try:
api_logger.debug("Start executing file paging query")
total, items = file_service.get_files_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
db=db, filters=filters, page=page, pagesize=pagesize,
orderby=orderby, desc=desc, current_user=current_user
)
api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Query failed: {str(e)}")
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
"page": {"page": page, "pagesize": pagesize, "total": total, "has_next": page * pagesize < total}
}
return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
@@ -108,23 +80,14 @@ async def create_folder(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
Create a new folder
"""
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")
"""Create a new folder"""
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}")
try:
api_logger.debug(f"Start creating a folder: {folder_name}")
create_folder = file_schema.FileCreate(
kb_id=kb_id,
created_by=current_user.id,
parent_id=parent_id,
file_name=folder_name,
file_ext='folder',
file_size=0,
create_folder_data = file_schema.FileCreate(
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
file_name=folder_name, file_ext='folder', file_size=0,
)
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
db_file = file_service.create_file(db=db, file=create_folder_data, current_user=current_user)
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
except Exception as e:
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
@@ -138,76 +101,58 @@ async def upload_file(
parent_id: uuid.UUID,
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
upload file
"""
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")
"""Upload file to storage backend"""
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}")
# Read the contents of the file
contents = await file.read()
# Check file size
file_size = len(contents)
print(f"file size: {file_size} byte")
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The file is empty.")
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"File size exceeds {settings.MAX_FILE_SIZE} byte limit")
# Extract the extension using `os.path.splitext`
_, file_extension = os.path.splitext(file.filename)
upload_file = file_schema.FileCreate(
kb_id=kb_id,
created_by=current_user.id,
parent_id=parent_id,
file_name=file.filename,
file_ext=file_extension.lower(),
file_size=file_size,
file_ext = file_extension.lower()
# Create File record
upload_file_data = file_schema.FileCreate(
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
file_name=file.filename, file_ext=file_ext, file_size=file_size,
)
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
# Upload to storage backend
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
try:
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
except Exception as e:
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
# Save file
with open(save_path, "wb") as f:
f.write(contents)
# Save file_key
db_file.file_key = file_key
db.commit()
db.refresh(db_file)
# Verify whether the file has been saved successfully
if not os.path.exists(save_path):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="File save failed"
)
# Create document (inherit parser_config from knowledge base)
default_parser_config = {
"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
"auto_keywords": 0, "auto_questions": 0, "html4excel": "false"
}
try:
db_knowledge = get_kb_by_id(db, knowledge_id=kb_id, current_user=current_user)
if db_knowledge and db_knowledge.parser_config:
default_parser_config.update(dict(db_knowledge.parser_config))
except Exception:
pass
# Create a document
create_data = document_schema.DocumentCreate(
kb_id=kb_id,
created_by=current_user.id,
file_id=db_file.id,
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
file_meta={}, parser_id="naive", parser_config=default_parser_config
)
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
@@ -221,123 +166,73 @@ async def custom_text(
parent_id: uuid.UUID,
create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
custom text
"""
api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}")
# Check file content size
# 将内容编码为字节UTF-8
"""Custom text upload"""
content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes)
print(f"file size: {file_size} byte")
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The content is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The content is empty.")
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Content size exceeds {settings.MAX_FILE_SIZE} byte limit")
upload_file = file_schema.FileCreate(
kb_id=kb_id,
created_by=current_user.id,
parent_id=parent_id,
file_name=f"{create_data.title}.txt",
file_ext=".txt",
file_size=file_size,
upload_file_data = file_schema.FileCreate(
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
file_name=f"{create_data.title}.txt", file_ext=".txt", file_size=file_size,
)
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
save_path = os.path.join(save_dir, f"{db_file.id}.txt")
# Upload to storage backend
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=".txt")
try:
await storage_service.storage.upload(file_key=file_key, content=content_bytes, content_type="text/plain")
except Exception as e:
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
# Save file
with open(save_path, "wb") as f:
f.write(content_bytes)
db_file.file_key = file_key
db.commit()
db.refresh(db_file)
# Verify whether the file has been saved successfully
if not os.path.exists(save_path):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="File save failed"
)
# Create a document
create_document_data = document_schema.DocumentCreate(
kb_id=kb_id,
created_by=current_user.id,
file_id=db_file.id,
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
file_meta={}, parser_id="naive",
parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
"auto_keywords": 0, "auto_questions": 0, "html4excel": "false"}
)
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
@router.get("/{file_id}", response_model=Any)
async def get_file(
file_id: uuid.UUID,
db: Session = Depends(get_db)
db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service),
) -> Any:
"""
Download the file based on the file_id
- Query file information from the database
- Construct the file path and check if it exists
- Return a FileResponse to download the file
"""
api_logger.info(f"Download the file based on the file_id: file_id={file_id}")
# 1. Query file information from the database
"""Download file by file_id"""
db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
# 2. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
file_path = os.path.join(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
if not db_file.file_key:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File has no storage key (legacy data not migrated)")
# 3. Check if the file exists
if not os.path.exists(file_path):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
)
try:
content = await storage_service.download_file(db_file.file_key)
except Exception as e:
api_logger.error(f"Storage download failed: {e}")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage")
# 4.Return FileResponse (automatically handle download)
return FileResponse(
path=file_path,
filename=db_file.file_name, # Use original file name
media_type="application/octet-stream" # Universal binary stream type
import mimetypes
media_type = mimetypes.guess_type(db_file.file_name)[0] or "application/octet-stream"
return Response(
content=content,
media_type=media_type,
headers={"Content-Disposition": f'attachment; filename="{db_file.file_name}"'}
)
@@ -348,50 +243,22 @@ async def update_file(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Update file information (such as file name)
- Only specified fields such as file_name are allowed to be modified
"""
api_logger.debug(f"Query the file to be updated: {file_id}")
# 1. Check if the file exists
"""Update file information (such as file name)"""
db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
# 2. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the file fields: {file_id}")
updated_fields = []
for field, value in update_data.dict(exclude_unset=True).items():
if hasattr(db_file, field):
old_value = getattr(db_file, field)
if old_value != value:
# update value
setattr(db_file, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
setattr(db_file, field, value)
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 3. Save to database
try:
db.commit()
db.refresh(db_file)
api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
except Exception as e:
db.rollback()
api_logger.error(f"File update failed: file_id={file_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"File update failed: {str(e)}"
)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File update failed: {str(e)}")
# 4. Return the updated file
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
@@ -399,60 +266,43 @@ async def update_file(
async def delete_file(
file_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Delete a file or folder
"""
api_logger.info(f"Request to delete file: file_id={file_id}, username: {current_user.username}")
await _delete_file(db=db, file_id=file_id, current_user=current_user)
"""Delete a file or folder"""
api_logger.info(f"Request to delete file: file_id={file_id}")
await _delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
return success(msg="File deleted successfully")
async def _delete_file(
file_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
db: Session,
current_user: User,
storage_service: FileStorageService,
) -> None:
"""
Delete a file or folder
"""
# 1. Check if the file exists
"""Delete a file or folder from storage and database"""
db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
# 2. Construct physical path
file_path = Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.id)
) if db_file.file_ext == 'folder' else Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 3. Delete physical files/folders
try:
if file_path.exists():
if db_file.file_ext == 'folder':
shutil.rmtree(file_path) # Recursively delete folders
else:
file_path.unlink() # Delete a single file
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete physical file/folder: {str(e)}"
)
# 4.Delete db_file
# Delete from storage backend
if db_file.file_ext == 'folder':
# For folders, delete all child files from storage first
child_files = db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).all()
for child in child_files:
if child.file_key:
try:
await storage_service.delete_file(child.file_key)
except Exception as e:
api_logger.warning(f"Failed to delete child file from storage: {child.file_key} - {e}")
db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
else:
if db_file.file_key:
try:
await storage_service.delete_file(db_file.file_key)
except Exception as e:
api_logger.warning(f"Failed to delete file from storage: {db_file.file_key} - {e}")
db.delete(db_file)
db.commit()

View File

@@ -296,7 +296,7 @@ async def chat(
}
)
# workflow 非流式返回
# 多 Agent 非流式返回
result = await app_chat_service.workflow_chat(
message=payload.message,

View File

@@ -221,7 +221,7 @@ def update_workspace_members(
@router.delete("/members/{member_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def delete_workspace_member(
def delete_workspace_member(
member_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
@@ -230,7 +230,7 @@ async def delete_workspace_member(
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
await workspace_service.delete_workspace_member(
workspace_service.delete_workspace_member(
db=db,
workspace_id=workspace_id,
member_id=member_id,

View File

@@ -241,8 +241,6 @@ class Settings:
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER: str = os.getenv("SMTP_USER", "")
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
SANDBOX_URL: str = os.getenv("SANDBOX_URL", "")
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))

View File

@@ -216,7 +216,7 @@ class RedBearModelFactory:
# 深度思考模式Claude 3.7 Sonnet 等支持思考的模型
# 通过 additional_model_request_fields 传递 thinking 块关闭时不传Bedrock 无 disabled 选项)
if config.deep_thinking:
budget = config.thinking_budget_tokens or 1024
budget = config.thinking_budget_tokens or 10000
params["additional_model_request_fields"] = {
"thinking": {"type": "enabled", "budget_tokens": budget}
}

View File

@@ -14,6 +14,7 @@ Transcribe the content from the provided PDF page image into clean Markdown form
6. Do NOT wrap the output in ```markdown or ``` blocks.
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
8. Preserve the original language, information, and order exactly as shown in the image.
9. Your output language MUST match the language of the content in the image. If the image contains Chinese text, output in Chinese. If English, output in English. Never translate.
{% if page %}
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.

View File

@@ -2,6 +2,7 @@
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/10 13:33
import json
import logging
import re
import uuid
@@ -141,9 +142,10 @@ class GraphBuilder:
for node_info in source_nodes:
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
branch_nodes.append(
(node_info["id"], node_info["branch"])
)
if node_info.get("branch") is not None:
branch_nodes.append(
(node_info["id"], node_info["branch"])
)
else:
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
output_nodes.append(node_info["id"])
@@ -314,9 +316,12 @@ class GraphBuilder:
for idx in range(len(related_edge)):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
# For LLM nodes, use branch_signal field for routing (output is dynamic text)
# For other branch nodes (e.g. HTTP), use output field
route_field = "branch_signal" if node_type == NodeType.LLM else "output"
related_edge[idx]['condition'] = (
f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(related_edge[idx]['label'])}"
)
if node_instance:
# Wrap node's run method to avoid closure issues

View File

@@ -18,10 +18,17 @@ class AssignerNode(BaseNode):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None
self._input_data: dict[str, Any] | None = None
def _output_types(self) -> dict[str, VariableType]:
return {}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取节点输入,如果有缓存的执行前数据则使用缓存"""
if self._input_data is not None:
return self._input_data
return {"config": self._resolve_config(self.config, variable_pool)}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the assignment operation defined by this node.
@@ -34,6 +41,9 @@ class AssignerNode(BaseNode):
Returns:
None or the result of the assignment operation.
"""
# 在执行前提取并缓存输入数据(捕获执行前的变量值)
self._input_data = {"config": self._resolve_config(self.config, variable_pool)}
# Initialize a variable pool for accessing conversation, node, and system variables
self.typed_config = AssignerNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} 开始执行")

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import re
import time
import uuid
from abc import ABC, abstractmethod
@@ -22,6 +23,9 @@ from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
# 匹配模板变量 {{xxx}} 的正则
_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
class NodeExecutionError(Exception):
"""节点执行失败异常。
@@ -503,10 +507,29 @@ class BaseNode(ABC):
variable_pool: The variable pool used for reading and writing variables.
Returns:
A dictionary containing the node's input data.
A dictionary containing the node's input data with all template
variables resolved to their actual runtime values.
"""
# Default implementation returns the node configuration
return {"config": self.config}
return {"config": self._resolve_config(self.config, variable_pool)}
@staticmethod
def _resolve_config(config: Any, variable_pool: VariablePool) -> Any:
"""递归解析 config 中的模板变量,将 {{xxx}} 替换为实际值。
Args:
config: 节点的原始配置(可能包含模板变量)。
variable_pool: 变量池,用于解析模板变量。
Returns:
解析后的配置,所有字符串中的 {{变量}} 已被替换为真实值。
"""
if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
return BaseNode._render_template(config, variable_pool, strict=False)
elif isinstance(config, dict):
return {k: BaseNode._resolve_config(v, variable_pool) for k, v in config.items()}
elif isinstance(config, list):
return [BaseNode._resolve_config(item, variable_pool) for item in config]
return config
def _extract_output(self, business_result: Any) -> Any:
"""Extracts the actual output from the business result.

View File

@@ -14,7 +14,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import BaseNode
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.config import settings
logger = logging.getLogger(__name__)
@@ -132,7 +131,7 @@ class CodeNode(BaseNode):
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(
f"{settings.SANDBOX_URL}:8194/v1/sandbox/run",
"http://sandbox:8194/v1/sandbox/run",
headers={
"x-api-key": 'redbear-sandbox'
},

View File

@@ -121,7 +121,10 @@ class DocExtractorNode(BaseNode):
return business_result
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {"file_selector": self.config.get("file_selector")}
file_selector = self.config.get("file_selector", "")
# 将变量选择器(如 sys.files解析为实际值
resolved = self.get_variable(file_selector, variable_pool, strict=False, default=file_selector)
return {"file_selector": resolved}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
config = DocExtractorNodeConfig(**self.config)
@@ -182,7 +185,7 @@ class DocExtractorNode(BaseNode):
mime_type=f"image/{ext}",
is_file=True,
).model_dump())
text = text + f"\n{placeholder}: <img src=\"{url}\" data-url=\"{url}\">"
text = text + f"\n{placeholder}: {url}"
except Exception as e:
logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}")

View File

@@ -31,7 +31,7 @@ class NodeType(StrEnum):
NOTES = "notes"
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM})
class ComparisonOperator(StrEnum):

View File

@@ -6,6 +6,7 @@ import uuid
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.nodes.enums import HttpErrorHandle
from app.core.workflow.variable.base_variable import VariableType
@@ -49,6 +50,20 @@ class MemoryWindowSetting(BaseModel):
)
class LLMErrorHandleConfig(BaseModel):
"""LLM 异常处理配置"""
method: HttpErrorHandle = Field(
default=HttpErrorHandle.NONE,
description="异常处理策略:'none' 抛出异常, 'default' 返回默认值, 'branch' 走异常分支",
)
output: str = Field(
default="",
description="LLM 异常时返回的默认输出文本method=default 时生效)",
)
class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置
@@ -152,6 +167,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="输出变量定义(自动生成,通常不需要修改)"
)
error_handle: LLMErrorHandleConfig = Field(
default_factory=LLMErrorHandleConfig,
description="LLM 异常处理配置",
)
@field_validator("messages", "prompt")
@classmethod
def validate_input_mode(cls, v):

View File

@@ -15,6 +15,7 @@ from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpErrorHandle
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_context
@@ -76,7 +77,7 @@ class LLMNode(BaseNode):
self.messages = []
def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING}
return {"output": VariableType.STRING, "branch_signal": VariableType.STRING}
def _render_context(self, message: str, variable_pool: VariablePool):
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
@@ -239,7 +240,7 @@ class LLMNode(BaseNode):
return llm
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
async def execute(self, state: WorkflowState, variable_pool: VariablePool):
"""非流式执行 LLM 调用
Args:
@@ -247,28 +248,36 @@ class LLMNode(BaseNode):
variable_pool: 变量池
Returns:
LLM 响应消息
dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success,
{"llm_result": None, "branch_signal": "ERROR"} on branch error
"""
# self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, False)
try:
# self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, False)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(self.messages)
# 提取内容
if hasattr(response, 'content'):
content = self.process_model_output(response.content)
else:
content = str(response)
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(self.messages)
# 提取内容
if hasattr(response, 'content'):
content = self.process_model_output(response.content)
else:
content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据
return AIMessage(content=content, response_metadata={
**response.response_metadata,
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
})
# 返回 AIMessage包含响应元数据
return {
"llm_result": AIMessage(content=content, response_metadata={
**response.response_metadata,
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
}),
"branch_signal": "SUCCESS",
}
except Exception as e:
logger.error(f"节点 {self.node_id} LLM 调用失败: {e}")
return self._handle_llm_error(e)
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
@@ -286,16 +295,36 @@ class LLMNode(BaseNode):
}
}
def _extract_output(self, business_result: Any) -> str:
""" AIMessage 中提取文本内容"""
def _extract_output(self, business_result: Any) -> dict:
"""业务结果中提取输出变量
支持新旧两种格式:
- 新格式:{"llm_result": AIMessage, "branch_signal": "SUCCESS"}
- 旧格式AIMessage向后兼容
"""
if isinstance(business_result, dict) and "branch_signal" in business_result:
llm_result = business_result.get("llm_result")
if isinstance(llm_result, AIMessage):
return {
"output": llm_result.content,
"branch_signal": business_result["branch_signal"],
}
return {
"output": str(llm_result) if llm_result else "",
"branch_signal": business_result["branch_signal"],
}
# 旧格式向后兼容
if isinstance(business_result, AIMessage):
return business_result.content
return str(business_result)
return {"output": business_result.content, "branch_signal": "SUCCESS"}
return {"output": str(business_result), "branch_signal": "SUCCESS"}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
""" AIMessage 中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
usage = business_result.response_metadata.get('token_usage')
"""业务结果中提取 token 使用情况"""
llm_result = business_result
if isinstance(business_result, dict):
llm_result = business_result.get("llm_result", business_result)
if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'):
usage = llm_result.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('input_tokens', 0),
@@ -304,6 +333,44 @@ class LLMNode(BaseNode):
}
return None
def _handle_llm_error(self, error: Exception) -> dict:
"""处理 LLM 调用异常,根据 error_handle 配置决定行为
Args:
error: LLM 调用中捕获的异常
Returns:
dict: {"llm_result": None, "branch_signal": "ERROR"} for branch mode,
or default output for default mode
Raises:
原异常(当 error_handle.method 为 NONE 时)
"""
if self.typed_config is None:
raise error
match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
raise error
case HttpErrorHandle.DEFAULT:
logger.warning(
f"节点 {self.node_id}: LLM 调用失败,返回默认输出"
)
default_output = self.typed_config.error_handle.output or ""
return {
"llm_result": AIMessage(content=default_output, response_metadata={}),
"branch_signal": "SUCCESS",
}
case HttpErrorHandle.BRANCH:
logger.warning(
f"节点 {self.node_id}: LLM 调用失败,切换到异常处理分支"
)
return {
"llm_result": None,
"branch_signal": "ERROR",
}
raise error
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""流式执行 LLM 调用
@@ -316,54 +383,58 @@ class LLMNode(BaseNode):
"""
self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, True)
try:
llm = await self._prepare_llm(state, variable_pool, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# 累积完整响应
full_response = ""
chunk_count = 0
# 累积完整响应
full_response = ""
chunk_count = 0
# 调用 LLM流式支持字符串或消息列表
last_meta_data = {}
last_usage_metadata = {}
async for chunk in llm.astream(self.messages):
if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content)
else:
content = str(chunk)
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
last_meta_data = chunk.response_metadata
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
last_usage_metadata = chunk.usage_metadata
# 调用 LLM流式支持字符串或消息列表
last_meta_data = {}
last_usage_metadata = {}
async for chunk in llm.astream(self.messages):
if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content)
else:
content = str(chunk)
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
last_meta_data = chunk.response_metadata
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
last_usage_metadata = chunk.usage_metadata
# 只有当内容不为空时才处理
if content:
full_response += content
chunk_count += 1
# 只有当内容不为空时才处理
if content:
full_response += content
chunk_count += 1
# 流式返回每个文本片段
yield {
"__final__": False,
"chunk": content
}
# 流式返回每个文本片段
yield {
"__final__": False,
"chunk": content
}
yield {
"__final__": False,
"chunk": "",
"done": True
}
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
yield {
"__final__": False,
"chunk": "",
"done": True
}
)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# yield 完成标记
yield {"__final__": True, "result": final_message}
# 构建完整的 AIMessage包含元数据
final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
}
)
# yield 完成标记
yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}}
except Exception as e:
logger.error(f"节点 {self.node_id} LLM 流式调用失败: {e}")
error_result = self._handle_llm_error(e)
yield {"__final__": True, "result": error_result}

View File

@@ -15,4 +15,5 @@ class File(Base):
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
file_size = Column(Integer, default=0, comment="file size(byte)")
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
file_key = Column(String(512), nullable=True, index=True, comment="storage file key for FileStorageService")
created_at = Column(DateTime, default=datetime.datetime.now)

View File

@@ -250,7 +250,7 @@ class ModelParameters(BaseModel):
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
stop: Optional[List[str]] = Field(default=None, description="停止序列")
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1, le=131072, description="深度思考 token 预算(仅部分模型支持)")
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")

View File

@@ -11,6 +11,7 @@ class FileBase(BaseModel):
file_ext: str
file_size: int
file_url: str | None = None
file_key: str | None = None
created_at: datetime.datetime | None = None

View File

@@ -161,10 +161,7 @@ class AppChatService:
f.type == FileType.DOCUMENT for f in files
):
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
)
# 创建 LangChain Agent
@@ -451,10 +448,7 @@ class AppChatService:
):
from langchain.agents import create_agent
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
)
# 创建 LangChain Agent

View File

@@ -102,6 +102,11 @@ class AppDslService:
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
]
return enriched
if app_type == AppType.WORKFLOW:
enriched = {**cfg}
if "nodes" in cfg:
enriched["nodes"] = self._enrich_workflow_nodes(cfg["nodes"])
return enriched
return cfg
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
@@ -110,7 +115,7 @@ class AppDslService:
config_data = {
"variables": config.variables if config else [],
"edges": config.edges if config else [],
"nodes": config.nodes if config else [],
"nodes": self._enrich_workflow_nodes(config.nodes) if config else [],
"features": config.features if config else {},
"execution_config": config.execution_config if config else {},
"triggers": config.triggers if config else [],
@@ -190,6 +195,23 @@ class AppDslService:
def _enrich_tools(self, tools: list) -> list:
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
def _enrich_workflow_nodes(self, nodes: list) -> list:
"""enrich 工作流节点中的模型引用,添加 name、provider、type 信息"""
from app.core.workflow.nodes.enums import NodeType
enriched_nodes = []
for node in (nodes or []):
node_type = node.get("type")
config = dict(node.get("config") or {})
if node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
model_id = config.get("model_id")
if model_id:
config["model_ref"] = self._model_ref(model_id)
del config["model_id"]
enriched_nodes.append({**node, "config": config})
return enriched_nodes
def _skill_ref(self, skill_id) -> Optional[dict]:
if not skill_id:
return None
@@ -620,16 +642,16 @@ class AppDslService:
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
config["knowledge_bases"] = resolved_kbs
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
model_ref = config.get("model_id")
model_ref = config.get("model_ref") or config.get("model_id")
if model_ref:
ref_dict = None
if isinstance(model_ref, dict):
ref_id = model_ref.get("id")
ref_name = model_ref.get("name")
if ref_id:
ref_dict = {"id": ref_id}
elif ref_name is not None:
ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")}
ref_dict = {
"id": model_ref.get("id"),
"name": model_ref.get("name"),
"provider": model_ref.get("provider"),
"type": model_ref.get("type")
}
elif isinstance(model_ref, str):
try:
uuid.UUID(model_ref)
@@ -640,12 +662,18 @@ class AppDslService:
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
if resolved_model_id:
config["model_id"] = resolved_model_id
if "model_ref" in config:
del config["model_ref"]
else:
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
config["model_id"] = None
if "model_ref" in config:
del config["model_ref"]
else:
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
config["model_id"] = None
if "model_ref" in config:
del config["model_ref"]
resolved_nodes.append({**node, "config": config})
return resolved_nodes

View File

@@ -650,10 +650,7 @@ class AgentRunService:
)
if has_doc_with_images:
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
)
agent = LangChainAgent(
@@ -927,10 +924,7 @@ class AgentRunService:
)
if has_doc_with_images:
system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
)
# 创建 LangChain Agent

View File

@@ -34,26 +34,7 @@ def generate_file_key(
Generate a unique file key for storage.
The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext}
Args:
tenant_id: The tenant UUID.
workspace_id: The workspace UUID.
file_id: The file UUID.
file_ext: The file extension (e.g., '.pdf', '.txt').
Returns:
A unique file key string.
Example:
>>> generate_file_key(
... uuid.UUID('550e8400-e29b-41d4-a716-446655440000'),
... uuid.UUID('660e8400-e29b-41d4-a716-446655440001'),
... uuid.UUID('770e8400-e29b-41d4-a716-446655440002'),
... '.pdf'
... )
'550e8400-e29b-41d4-a716-446655440000/660e8400-e29b-41d4-a716-446655440001/770e8400-e29b-41d4-a716-446655440002.pdf'
"""
# Ensure file_ext starts with a dot
if file_ext and not file_ext.startswith('.'):
file_ext = f'.{file_ext}'
if workspace_id:
@@ -61,6 +42,21 @@ def generate_file_key(
return f"{tenant_id}/{file_id}{file_ext}"
def generate_kb_file_key(
kb_id: uuid.UUID,
file_id: uuid.UUID,
file_ext: str,
) -> str:
"""
Generate a file key for knowledge base files.
Format: kb/{kb_id}/{file_id}{file_ext}
"""
if file_ext and not file_ext.startswith('.'):
file_ext = f'.{file_ext}'
return f"kb/{kb_id}/{file_id}{file_ext}"
class FileStorageService:
"""
High-level service for file storage operations.

View File

@@ -400,7 +400,7 @@ class MultimodalService:
# 在文本内容中追加图片位置标记
if result and result[-1].get("type") in ("text", "document"):
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: <img src=\"{img_url}\" data-url=\"{img_url}\">"
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: {img_url}"
# 将图片以视觉格式追加到消息内容中
img_file = FileInput(
type=FileType.IMAGE,

View File

@@ -554,16 +554,13 @@ class WorkflowService:
}
}
case "workflow_end":
data = {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
if "citations" in payload and payload["citations"]:
data["citations"] = payload["citations"]
return {
"event": "end",
"data": data
"data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
}
case "node_start" | "node_end" | "node_error" | "cycle_item":
return None

View File

@@ -20,7 +20,6 @@ from app.models.workspace_model import (
)
from app.repositories import workspace_repository
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
from app.services.session_service import SessionService
from app.schemas.workspace_schema import (
InviteAcceptRequest,
InviteValidateResponse,
@@ -59,7 +58,7 @@ def switch_workspace(
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
async def delete_workspace_member(
def delete_workspace_member(
db: Session,
workspace_id: uuid.UUID,
member_id: uuid.UUID,
@@ -77,29 +76,10 @@ async def delete_workspace_member(
BizCode.WORKSPACE_NOT_FOUND)
try:
deleted_user = workspace_member.user
workspace_member.is_active = False
deleted_user.current_workspace_id = None
# 若被删除成员不是超级管理员且没有其他可用工作空间,则禁用该用户
if not deleted_user.is_superuser:
remaining = (
db.query(WorkspaceMember)
.filter(
WorkspaceMember.user_id == deleted_user.id,
WorkspaceMember.workspace_id != workspace_id,
WorkspaceMember.is_active.is_(True),
)
.count()
)
if remaining == 0:
deleted_user.is_active = False
workspace_member.user.current_workspace_id = None
db.commit()
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
# 使被删除成员的所有 token 立即失效
await SessionService.invalidate_all_user_tokens(str(workspace_member.user_id))
except Exception as e:
db.rollback()
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")

View File

@@ -210,9 +210,14 @@ def _build_vision_model(file_path: str, db_knowledge):
@celery_app.task(name="app.core.rag.tasks.parse_document")
def parse_document(file_path: str, document_id: uuid.UUID):
def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
"""
Document parsing, vectorization, and storage
Document parsing, vectorization, and storage.
Args:
file_key: Storage key for FileStorageService (e.g. "kb/{kb_id}/{file_id}.docx")
document_id: Document UUID
file_name: Original file name (used for extension detection in chunk())
"""
db_document = None
@@ -223,7 +228,6 @@ def parse_document(file_path: str, document_id: uuid.UUID):
with get_db_context() as db:
try:
# Celery JSON 序列化会将 UUID 转为字符串,需要确保类型正确
if not isinstance(document_id, uuid.UUID):
document_id = uuid.UUID(str(document_id))
@@ -234,7 +238,11 @@ def parse_document(file_path: str, document_id: uuid.UUID):
if db_knowledge is None:
raise ValueError(f"Knowledge {db_document.kb_id} not found")
# 1. Document parsing & segmentation
# Use file_name from argument or fall back to document record
if not file_name:
file_name = db_document.file_name
# 1. Download file from storage backend
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
start_time = time.time()
db_document.progress = 0.0
@@ -245,45 +253,36 @@ def parse_document(file_path: str, document_id: uuid.UUID):
db.commit()
db.refresh(db_document)
# Read file content from storage backend (no NFS dependency)
from app.services.file_storage_service import FileStorageService
import asyncio
storage_service = FileStorageService()
async def _download():
return await storage_service.download_file(file_key)
try:
file_binary = asyncio.run(_download())
except RuntimeError:
# If there's already a running loop (e.g. in some worker configurations)
loop = asyncio.new_event_loop()
try:
file_binary = loop.run_until_complete(_download())
finally:
loop.close()
if not file_binary:
raise IOError(f"Downloaded empty file from storage: {file_key}")
logger.info(f"[ParseDoc] Downloaded {len(file_binary)} bytes from storage key: {file_key}")
def progress_callback(prog=None, msg=None):
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
# Prepare vision_model for parsing
vision_model = _build_vision_model(file_path, db_knowledge)
# 先将文件读入内存,避免解析过程中依赖 NFS 文件持续可访问
# python-docx 等库在 binary=None 时会用路径直接打开文件,
# 在 NFS/共享存储上可能因缓存失效导致 "Package not found"
max_wait_seconds = 30
wait_interval = 2
waited = 0
file_binary = None
while waited <= max_wait_seconds:
# os.listdir 强制 NFS 客户端刷新目录缓存
parent_dir = os.path.dirname(file_path)
try:
os.listdir(parent_dir)
except OSError:
pass
try:
with open(file_path, "rb") as f:
file_binary = f.read()
if not file_binary:
# NFS 上文件存在但内容为空(可能还在同步中)
raise IOError(f"File is empty (0 bytes), NFS may still be syncing: {file_path}")
break
except (FileNotFoundError, IOError) as e:
if waited >= max_wait_seconds:
raise type(e)(
f"File not accessible at '{file_path}' after waiting {max_wait_seconds}s: {e}"
)
logger.warning(f"File not ready on this node, retrying in {wait_interval}s: {file_path} ({e})")
time.sleep(wait_interval)
waited += wait_interval
vision_model = _build_vision_model(file_name, db_knowledge)
from app.core.rag.app.naive import chunk
logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
res = chunk(filename=file_path,
res = chunk(filename=file_name,
binary=file_binary,
from_page=0,
to_page=DEFAULT_PARSE_TO_PAGE,

View File

@@ -8,11 +8,12 @@ import { type FC, useRef, useEffect, useState } from 'react'
import clsx from 'clsx'
import Markdown from '@/components/Markdown'
import type { ChatContentProps } from './types'
import { Spin, Flex, Button } from 'antd'
import { Spin, Image, Flex, Button } from 'antd'
import { SoundOutlined } from '@ant-design/icons'
import { useTranslation } from 'react-i18next'
import MessageFiles from './MessageFiles'
import AudioPlayer from './AudioPlayer'
import VideoPlayer from './VideoPlayer'
const getFileUrl = (file: any) => {
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
@@ -148,7 +149,72 @@ const ChatContent: FC<ChatContentProps> = ({
{labelFormat(item)}
</div>
}
<MessageFiles files={item.meta_data?.files ?? []} contentClassNames={contentClassNames} onDownload={handleDownload} />
{item?.meta_data?.files && item.meta_data?.files.length > 0 && <Flex gap={8} vertical align="end" className="rb:mb-2!">
{item.meta_data?.files?.map((file) => {
if (file.type.includes('image')) {
return (
<div key={file.url || file.uid} className={`rb:inline-block rb:group rb:relative rb:rounded-lg ${contentClassNames}`}>
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
</div>
)
}
if (file.type.includes('video')) {
return (
<div key={file.url || file.uid} className="rb:w-50">
{/* <video src={getFileUrl(file)} controls className="rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" /> */}
<VideoPlayer key={file.url || file.uid} src={getFileUrl(file)} />
</div>
)
}
if (file.type.includes('audio')) {
return (
<div key={file.url || file.uid} className="rb:w-50">
<AudioPlayer key={file.url || file.uid} src={getFileUrl(file)} />
</div>
)
}
const documentType = (file.file_type || file.type)?.split('/')
return (
<Flex
key={file.url || file.uid}
align="center"
gap={10}
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
onClick={() => handleDownload(file)}
>
<div
className={clsx(
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
file.type?.includes('pdf')
? "rb:bg-[url('@/assets/images/file/pdf.svg')]"
: (file.type?.includes('excel') || file.type?.includes('spreadsheetml.sheet')) || file.type?.includes('xls') || file.type?.includes('xlsx')
? "rb:bg-[url('@/assets/images/file/excel.svg')]"
: file.type?.includes('csv')
? "rb:bg-[url('@/assets/images/file/csv.svg')]"
: file.type?.includes('html')
? "rb:bg-[url('@/assets/images/file/html.svg')]"
: file.type?.includes('json')
? "rb:bg-[url('@/assets/images/file/json.svg')]"
: file.type?.includes('ppt')
? "rb:bg-[url('@/assets/images/file/ppt.svg')]"
: file.type?.includes('markdown')
? "rb:bg-[url('@/assets/images/file/md.svg')]"
: file.type?.includes('text')
? "rb:bg-[url('@/assets/images/file/txt.svg')]"
: (file.type?.includes('doc') || file.type?.includes('docx') || file.type?.includes('word') || file.type?.includes('wordprocessingml.document'))
? "rb:bg-[url('@/assets/images/file/word.svg')]"
: "rb:bg-[url('@/assets/images/file/txt.svg')]"
)}
></div>
<div className="rb:flex-1 rb:w-32.5">
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{documentType?.[documentType.length - 1]} · {file.size}</div>
</div>
</Flex>
)
})}
</Flex>}
{/* Message bubble */}
<div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', {
// Error message style (content is null and not assistant message)

View File

@@ -1,87 +0,0 @@
import { Image, Flex } from 'antd'
import clsx from 'clsx'
import AudioPlayer from './AudioPlayer'
import VideoPlayer from './VideoPlayer'
const getFileUrl = (file: any) =>
file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
const DOC_ICONS: [string[], string][] = [
[['pdf'], "rb:bg-[url('@/assets/images/file/pdf.svg')]"],
[['excel', 'spreadsheetml.sheet', 'xls', 'xlsx'], "rb:bg-[url('@/assets/images/file/excel.svg')]"],
[['csv'], "rb:bg-[url('@/assets/images/file/csv.svg')]"],
[['html'], "rb:bg-[url('@/assets/images/file/html.svg')]"],
[['json'], "rb:bg-[url('@/assets/images/file/json.svg')]"],
[['ppt'], "rb:bg-[url('@/assets/images/file/ppt.svg')]"],
[['markdown'], "rb:bg-[url('@/assets/images/file/md.svg')]"],
[['text'], "rb:bg-[url('@/assets/images/file/txt.svg')]"],
[['doc', 'docx', 'word', 'wordprocessingml.document'], "rb:bg-[url('@/assets/images/file/word.svg')]"],
]
const getDocIcon = (parts: string[]) => {
const match = DOC_ICONS.find(([keys]) => keys.some(k => parts.includes(k)))
return match ? match[1] : "rb:bg-[url('@/assets/images/file/txt.svg')]"
}
interface MessageFilesProps {
files: any[]
contentClassNames?: string | Record<string, boolean>
onDownload: (file: any) => void
}
const MessageFiles = ({ files, contentClassNames, onDownload }: MessageFilesProps) => {
if (!files?.length) return null
return (
<Flex gap={8} vertical align="end" className="rb:mb-2!">
{files.map((file) => {
const key = file.url || file.uid
if (file.type.includes('image')) {
return (
<div key={key} className={clsx('rb:inline-block rb:group rb:relative rb:rounded-lg', contentClassNames)}>
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
</div>
)
}
if (file.type.includes('video')) {
return (
<div key={key} className="rb:w-50">
<VideoPlayer src={getFileUrl(file)} />
</div>
)
}
if (file.type.includes('audio')) {
return (
<div key={key} className="rb:w-50">
<AudioPlayer src={getFileUrl(file)} />
</div>
)
}
const documentType = (file.file_type || file.type)?.split('/') ?? []
return (
<Flex
key={key}
align="center"
gap={10}
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
onClick={() => onDownload(file)}
>
<div
className={clsx(
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
getDocIcon(documentType)
)}
/>
<div className="rb:flex-1 rb:w-32.5">
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">
{documentType?.[documentType.length - 1]} · {file.size}
</div>
</div>
</Flex>
)
})}
</Flex>
)
}
export default MessageFiles

View File

@@ -3,14 +3,14 @@ import { Popover, type PopoverProps } from 'antd'
import Tag, { type TagProps } from '@/components/Tag'
interface OverflowTagsProps {
items?: ReactNode[];
items: ReactNode[];
gap?: number;
numTagColor?: TagProps['color'];
numTag?: (num?: number) => ReactNode;
popoverProps?: PopoverProps | false;
}
const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => {
const OverflowTags = ({ items, gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => {
const containerRef = useRef<HTMLDivElement>(null)
const measureRef = useRef<HTMLDivElement>(null)
const [visibleCount, setVisibleCount] = useState(items.length)
@@ -20,7 +20,7 @@ const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, po
if (!measure || containerWidth === 0) return
const children = Array.from(measure.children) as HTMLElement[]
if (!children.length) { setVisibleCount(0); return }
if (!children.length) return
// last child is the sample +N tag
const extraTagWidth = (children[children.length - 1] as HTMLElement).offsetWidth

View File

@@ -399,7 +399,7 @@ const Menu: FC<{
className="rb:overflow-y-auto rb:flex-1!"
/>
{/* Return to space button for superusers */}
{source === 'space' &&
{user?.is_superuser && source === 'space' &&
<Flex gap={4} vertical className="rb:my-3! rb:mx-3!">
<Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" />
<Flex
@@ -412,18 +412,16 @@ const Menu: FC<{
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div>
{collapsed ? null : t('common.switchSpace')}
</Flex>
{user?.is_superuser &&
<Flex
gap={8}
align="center"
justify="start"
onClick={goToSpace}
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
>
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
{collapsed ? null : t('common.returnToSpace')}
</Flex>
}
<Flex
gap={8}
align="center"
justify="start"
onClick={goToSpace}
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
>
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
{collapsed ? null : t('common.returnToSpace')}
</Flex>
</Flex>
}
{source === 'manage' && subscription && !collapsed &&

View File

@@ -1538,7 +1538,6 @@ export const en = {
json_output: 'Support JSON formatted output',
thinking_budget_tokens: 'thinking budget tokens',
thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})",
thinking_budget_tokens_min_error: "Cannot be less than {{min}}",
logSearchPlaceholder: 'Search log content',
},
userMemory: {

View File

@@ -868,7 +868,6 @@ export const zh = {
json_output: '支持JSON格式化输出',
thinking_budget_tokens: '深度思考预算Token数',
thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})",
thinking_budget_tokens_min_error: "不能小于 {{min}}",
logSearchPlaceholder: '搜索日志内容',
},
table: {

View File

@@ -49,8 +49,6 @@ const configFields = [
{ key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 },
]
const minThinkingBudgetTokens = 128;
const defaultThinkingBudgetTokens = 1000;
const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({
refresh,
data,
@@ -110,7 +108,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
const newValues: ModelConfig = {
capability: (option as Model).capability,
deep_thinking: false,
thinking_budget_tokens: defaultThinkingBudgetTokens,
thinking_budget_tokens: undefined,
json_output: false,
}
if (source === 'chat') {
@@ -130,12 +128,6 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
form.setFieldsValue({ ...rest })
}, [data?.default_model_config_id])
useEffect(() => {
if (values?.deep_thinking && !values?.thinking_budget_tokens) {
form.setFieldValue('thinking_budget_tokens', defaultThinkingBudgetTokens)
}
}, [values?.deep_thinking])
const handleReset = () => {
if (!id) return
resetAppModelConfig(id).then((res) => {
@@ -186,20 +178,15 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
name="thinking_budget_tokens"
label={t('application.thinking_budget_tokens')}
hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))}
extra={<>{t('application.range')}: [{minThinkingBudgetTokens}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
extra={<>{t('application.range')}: [{0}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
rules={[
{ required: values?.deep_thinking, message: t('common.pleaseEnter') },
{
validator: (_, value) => {
const maxTokens = values?.max_tokens
const deep_thinking = values?.deep_thinking;
if (deep_thinking && value !== undefined) {
if (value < minThinkingBudgetTokens) {
return Promise.reject(t('application.thinking_budget_tokens_min_error', { min: minThinkingBudgetTokens }))
}
if (maxTokens !== undefined && value > maxTokens) {
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
}
if (deep_thinking && value !== undefined && maxTokens !== undefined && value > maxTokens) {
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
}
return Promise.resolve()
}
@@ -208,7 +195,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
>
<RbSlider
step={1}
min={minThinkingBudgetTokens}
min={0}
max={32000}
isInput={true}
disabled={!values?.deep_thinking}

View File

@@ -166,10 +166,10 @@ const Ontology: FC = () => {
<div className="rb:h-10 rb:wrap-break-word rb:line-clamp-2 rb:leading-5">{item.scene_description}</div>
</Tooltip>
<div className="rb:mt-2 rb:h-5.5">
<div className="rb:mt-2">
<OverflowTags
popoverProps={false}
items={item.entity_type ? [...item.entity_type.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>] : []}
items={[...item.entity_type?.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>]}
numTag={(num?: number) => <Tag variant="borderless" color="dark">{`+${item.type_num - 3 + (num ? num - 1 : 0)}`}</Tag>}
/>
</div>

View File

@@ -101,7 +101,6 @@ const CustomToolModal = forwardRef<CustomToolModalRef, CustomToolModalProps>(({
});
};
const formatSchema = (value: string) => {
if (!value || value.trim() === '') return
setParseSchemaData({} as ParseSchemaData)
parseSchema({ schema_content: value })
.then(res => {

View File

@@ -57,6 +57,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
}
}}
labelRender={(props) => {
console.log('props', props)
return `${props.value}%`
}}
className="rb:w-20 rb:h-4!"

View File

@@ -66,6 +66,8 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
const [fileList, setFileList] = useState<any[]>([])
const [message, setMessage] = useState<string | undefined>(undefined)
console.log('abortRef', abortRef, chatList)
/**
* Opens the chat drawer and loads workflow variables from the start node
*/

View File

@@ -18,7 +18,6 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
// Handle node selection from popover and create new node replacing the add-node placeholder
const handleNodeSelect = (selectedNodeType: any) => {
graph.startBatch('add-node');
const parentBBox = node.getBBox();
const cycleId = data.cycle;
const horizontalSpacing = 0;
@@ -44,7 +43,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) {
parentNode.addChild(newNode, { silent: true });
parentNode.addChild(newNode);
}
}
@@ -77,40 +76,55 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
}
});
setTimeout(() => {
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
}, 50);
// Automatically adjust loop node size
const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (loopNode) {
const adjustLoopSize = () => {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc, child) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
loopNode.prop('size', { width: newWidth, height: newHeight });
// Update right port x position
const ports = loopNode.getPorts();
ports.forEach(port => {
if (port.group === 'right' && port.args) {
loopNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
};
adjustLoopSize();
// Listen to child node movement events
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc, child) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
loopNode.prop('size', { width: newWidth, height: newHeight });
loopNode.getPorts().forEach(port => {
if (port.group === 'right' && port.args) {
loopNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
childNodes.forEach((childNode: any) => {
childNode.on('change:position', adjustLoopSize);
});
}
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
graph.stopBatch('add-node');
setOpen(false);
};

View File

@@ -99,7 +99,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
{data.type === 'if-else' &&
<Flex vertical gap={4} className="rb:mt-3!">
{data.config?.cases?.defaultValue.map((item: any, index: number) => (
<div key={index}>
<div key={index} className={item.expressions.length > 0 ? '' : 'rb:mb-1'}>
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>

View File

@@ -1,15 +1,134 @@
import { useEffect } from 'react';
import { useTranslation } from 'react-i18next'
import clsx from 'clsx';
import type { ReactShapeConfig } from '@antv/x6-react-shape';
import { Flex } from 'antd';
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons';
import { useTranslation } from 'react-i18next'
import { graphNodeLibrary, edgeAttrs } from '../../constant';
import NodeTools from './NodeTools'
const LoopNode: ReactShapeConfig['component'] = ({ node }) => {
const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
const data = node.getData() || {};
const { t } = useTranslation()
useEffect(() => {
// 使用setTimeout确保在所有节点都添加完成后再创建连线
const timer = setTimeout(() => {
initNodes()
checkAndAddAddNode()
}, 50)
return () => clearTimeout(timer)
}, [graph])
const checkAndAddAddNode = () => {
if (!graph) return;
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === data.id);
const cycleStartNodes = childNodes.filter((n: any) => n.getData()?.type === 'cycle-start');
// 如果只有一个cycle-start节点且没有其他类型的子节点则添加add-node
if (cycleStartNodes.length === 1 && childNodes.length === 1) {
const cycleStartNode = cycleStartNodes[0];
const cycleStartBBox = cycleStartNode.getBBox();
const addNode = graph.addNode({
...graphNodeLibrary.addStart,
x: cycleStartBBox.x + 84,
y: cycleStartBBox.y + 4,
data: {
type: 'add-node',
label: t('workflow.addNode'),
icon: '+',
parentId: node.id,
cycle: data.id,
},
});
node.addChild(addNode);
// 连接cycle-start和add-node
const sourcePorts = cycleStartNode.getPorts();
const targetPorts = addNode.getPorts();
const sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
const targetPort = targetPorts.find((port: any) => port.group === 'left')?.id || 'left';
// 然后创建连线
graph.addEdge({
source: { cell: cycleStartNode.id, port: sourcePort },
target: { cell: addNode.id, port: targetPort },
...edgeAttrs,
});
cycleStartNode.toFront()
addNode.toFront()
}
}
const initNodes = () => {
// 检查是否存在cycle为当前节点ID的子节点若存在则不调用initNodes避免重复创建
const existingCycleNodes = graph.getNodes().filter((n: any) =>
n.getData()?.cycle === data.id
);
if (existingCycleNodes.length > 0) return;
// 添加默认子节点
const parentBBox = node.getBBox();
const centerX = parentBBox.x + 24;
const centerY = parentBBox.y + 70;
const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const cycleStartNode = graph.addNode({
...graphNodeLibrary.cycleStart,
x: centerX,
y: centerY,
id: cycleStartNodeId,
data: {
id: cycleStartNodeId,
type: 'cycle-start',
parentId: node.id,
isDefault: true, // 标记为默认节点,不可删除
cycle: data.id,
},
});
const addNode = graph.addNode({
...graphNodeLibrary.addStart,
x: centerX + 84,
y: centerY + 4,
data: {
type: 'add-node',
label: t('workflow.addNode'),
icon: '+',
parentId: node.id,
cycle: data.id,
},
});
node.addChild(cycleStartNode)
node.addChild(addNode)
const sourcePorts = cycleStartNode.getPorts()
const targetPorts = addNode.getPorts()
let sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
const edgeConfig = {
source: {
cell: cycleStartNode.id,
port: sourcePort
},
target: {
cell: addNode.id,
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
},
...edgeAttrs
}
graph.addEdge(edgeConfig)
setTimeout(() => {
cycleStartNode.toFront()
addNode.toFront()
}, 0)
}
return (
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,

View File

@@ -43,52 +43,70 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
};
}, []);
// Handle node selection from popover menu and create new node with edge connection
const handleNodeSelect = (selectedNodeType: any) => {
if (!sourceNode || !graph) return;
const sourceNodeData = sourceNode.getData();
const sourceNodeType = sourceNodeData?.type;
const isCycleSubNode = !!sourceNodeData.cycle;
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
const newNodeType = selectedNodeType.type;
// Save add-node placeholder position before disabling history
// If it's a cycle-start node, handle the add-node placeholder
let addNodePosition = null;
const isCycleSubNode = sourceNodeData.cycle
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
const cycleId = sourceNodeData.cycle;
const addNodes = graph.getNodes().filter((n: any) =>
const addNodes = graph.getNodes().filter((n: any) =>
n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId
);
if (addNodes.length > 0) addNodePosition = addNodes[0].getBBox();
if (addNodes.length > 0) {
const addNode = addNodes[0];
addNodePosition = addNode.getBBox();
addNode.remove();
}
}
// Calculate position
// Calculate new node position to avoid overlapping
const sourceBBox = sourceNode.getBBox();
const nw = graphNodeLibrary[newNodeType]?.width || 120;
const nh = graphNodeLibrary[newNodeType]?.height || 88;
const hSpacing = isCycleSubNode ? 48 : 80;
const vSpacing = 10;
const nodeWidth = graphNodeLibrary[selectedNodeType.type]?.width || 120;
const nodeHeight = graphNodeLibrary[selectedNodeType.type]?.height || 88;
const horizontalSpacing = isCycleSubNode ? 48 : 80;
const verticalSpacing = 10;
// Get source port group information
const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort);
const sourcePortGroup = sourcePortInfo?.group || sourcePort;
let newX: number, newY: number;
// Calculate new node position
let newX, newY;
if (edgeInsertion) {
// Edge insertion: place new node on the same row as target, between source and target
const targetBBox = edgeInsertion.targetCell.getBBox();
const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width);
const requiredSpace = nw + hSpacing * 4;
newX = sourceBBox.x + sourceBBox.width + hSpacing;
newY = targetBBox.y + (targetBBox.height - nh) / 2;
const requiredSpace = nodeWidth + horizontalSpacing * 4;
// New node x: right after source + spacing
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
// Same row as target node
newY = targetBBox.y + (targetBBox.height - nodeHeight) / 2;
// If not enough space, shift target and all downstream nodes to the right
if (gap < requiredSpace) {
const shiftX = requiredSpace - gap;
const visited = new Set<string>();
const shiftDownstream = (cell: any) => {
if (visited.has(cell.id)) return;
visited.add(cell.id);
const cellId = cell.id;
if (visited.has(cellId)) return;
visited.add(cellId);
const pos = cell.getPosition();
cell.setPosition(pos.x + shiftX, pos.y);
// Recursively shift nodes connected from right ports
graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => {
const tCell = graph.getCellById(e.getTargetCellId());
if (tCell?.isNode()) shiftDownstream(tCell);
const tId = e.getTargetCellId();
if (tId && !visited.has(tId)) {
const tCell = graph.getCellById(tId);
if (tCell?.isNode()) shiftDownstream(tCell);
}
});
};
shiftDownstream(edgeInsertion.targetCell);
@@ -96,170 +114,208 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
} else if (addNodePosition) {
newX = addNodePosition.x;
newY = addNodePosition.y;
} else if (sourcePortGroup === 'left') {
newX = sourceBBox.x - nw * 2 - hSpacing;
newY = sourceBBox.y;
} else {
newX = sourceBBox.x + sourceBBox.width + hSpacing;
newY = sourceBBox.y;
const connectedNodes = new Set<string>();
graph.getConnectedEdges(sourceNode).forEach((e: any) => {
[e.getSourceCellId(), e.getTargetCellId()].forEach((cid: string) => {
if (cid !== sourceNode.id) connectedNodes.add(cid);
// Determine node placement direction based on port position
if (sourcePortGroup === 'left') {
// Left port: add node to the left
newX = sourceBBox.x - nodeWidth*2 - horizontalSpacing;
newY = sourceBBox.y;
} else {
// Right port: add node to the right
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
newY = sourceBBox.y;
}
// Check if position overlaps with existing nodes (only consider connected nodes)
const checkOverlap = (x: number, y: number) => {
// Get nodes connected to the source node
const connectedNodes = new Set();
graph.getConnectedEdges(sourceNode).forEach((edge: any) => {
const sourceId = edge.getSourceCellId();
const targetId = edge.getTargetCellId();
if (sourceId !== sourceNode.id) connectedNodes.add(sourceId);
if (targetId !== sourceNode.id) connectedNodes.add(targetId);
});
});
const checkOverlap = (x: number, y: number) =>
graph.getNodes().some((n: any) => {
if (n.id === sourceNode.id || !connectedNodes.has(n.id)) return false;
const b = n.getBBox();
return !(x + nw < b.x || x > b.x + b.width || y + nh < b.y || y > b.y + b.height);
return graph.getNodes().some((node: any) => {
if (node.id === sourceNode.id) return false;
if (!connectedNodes.has(node.id)) return false; // Only consider connected nodes
const bbox = node.getBBox();
return !(x + nodeWidth < bbox.x || x > bbox.x + bbox.width ||
y + nodeHeight < bbox.y || y > bbox.y + bbox.height);
});
while (checkOverlap(newX, newY)) newY += nh + vSpacing;
};
// If position is occupied, search downward for empty space
while (checkOverlap(newX, newY)) {
newY += nodeHeight + verticalSpacing;
}
}
// Disable history for all graph mutations
graph.disableHistory();
// Remove add-node placeholder
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
const cycleId = sourceNodeData.cycle;
graph.getNodes()
.filter((n: any) => n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId)
.forEach((n: any) => n.remove());
}
const id = `${newNodeType.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
// Create new node
const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const newNode = graph.addNode({
...(graphNodeLibrary[newNodeType] || graphNodeLibrary.default),
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
x: newX,
y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0),
id,
data: {
id,
type: newNodeType,
type: selectedNodeType.type,
icon: selectedNodeType.icon,
name: t(`workflow.${newNodeType}`),
cycle: sourceNodeData.cycle,
name: t(`workflow.${selectedNodeType.type}`),
cycle: sourceNodeData.cycle, // Inherit cycle from source node
config: selectedNodeType.config || {}
},
});
// Add new node as child of parent node
if (sourceNodeData.cycle) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle);
if (parentNode) parentNode.addChild(newNode, { silent: true });
}
if (edgeInsertion) {
const { edge: oldEdge } = edgeInsertion;
if (oldEdge.id && graph.getCellById(oldEdge.id)) graph.removeCell(oldEdge.id);
else graph.removeEdge(oldEdge);
}
const newPorts = newNode.getPorts();
const addedCells: any[] = [newNode];
if (edgeInsertion) {
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: newLeftPort }, ...edgeAttrs }));
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: newRightPort }, target: { cell: targetCell.id, port: origTargetPort }, ...edgeAttrs }));
setEdgeInsertion(null);
} else if (sourcePortGroup === 'left') {
const tp = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: tp }, target: { cell: sourceNode.id, port: sourcePort }, ...edgeAttrs }));
} else {
const tp = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: tp }, ...edgeAttrs }));
}
// If adding a loop/iteration node, create cycle-start, add-node and inner edge regardless of source type
if (isCycleContainer(newNodeType)) {
const parentBBox = newNode.getBBox();
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
const cycleStartNode = graph.addNode({
...graphNodeLibrary.cycleStart,
x: parentBBox.x + 24,
y: parentBBox.y + 70,
id: cycleStartId,
data: { id: cycleStartId, type: 'cycle-start', parentId: id, isDefault: true, cycle: id },
});
const addNodePlaceholder = graph.addNode({
...graphNodeLibrary.addStart,
x: parentBBox.x + 24 + 84,
y: parentBBox.y + 70 + 4,
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: id, cycle: id },
});
newNode.addChild(cycleStartNode, { silent: true });
newNode.addChild(addNodePlaceholder, { silent: true });
const innerEdge = graph.addEdge({
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find((p: any) => p.group === 'right')?.id || 'right' },
target: { cell: addNodePlaceholder.id, port: addNodePlaceholder.getPorts().find((p: any) => p.group === 'left')?.id || 'left' },
...edgeAttrs,
});
addedCells.push(cycleStartNode, addNodePlaceholder, innerEdge);
}
// Adjust parent size if adding inside a cycle container
const cycleId = sourceNodeData.cycle;
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc: any, child: any) => {
const b = child.getBBox();
return { minX: Math.min(acc.minX, b.x), minY: Math.min(acc.minY, b.y), maxX: Math.max(acc.maxX, b.x + b.width), maxY: Math.max(acc.maxY, b.y + b.height) };
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
parentNode.prop('size', { width: newWidth, height: newHeight });
parentNode.getPorts().forEach((port: any) => {
if (port.group === 'right' && port.args) parentNode.portProp(port.id!, 'args/x', newWidth);
});
}
parentNode.addChild(newNode);
}
}
// toFront
const bringCycleChildrenToFront = (cycleContainerId: string) => {
graph.getEdges().forEach((e: any) => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
});
graph.getNodes().forEach((n: any) => { if (n.getData()?.cycle === cycleContainerId) n.toFront(); });
};
if (isCycleContainer(sourceNodeType)) {
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(sourceNodeData.id);
if (isCycleContainer(newNodeType)) bringCycleChildrenToFront(id);
} else if (isCycleContainer(newNodeType)) {
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(id);
} else {
addedCells.forEach(c => { if (c.isNode?.()) c.toFront(); });
// Edge insertion: remove old edge immediately before creating new edges
if (edgeInsertion) {
const { edge: oldEdge } = edgeInsertion;
if (oldEdge.id && graph.getCellById(oldEdge.id)) {
graph.removeCell(oldEdge.id);
} else {
graph.removeEdge(oldEdge);
}
}
// Re-enable history and manually push one batch frame for all added cells
graph.enableHistory();
const history = graph.getPlugin('history') as any;
if (history) {
const batchFrame = addedCells.map((cell: any) => ({
batch: true,
event: 'cell:added',
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
options: {},
}));
history.undoStack.push(batchFrame);
history.redoStack = [];
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-node' } });
}
// Create edge connection
setTimeout(() => {
const newPorts = newNode.getPorts();
const addedEdges: any[] = [];
if (edgeInsertion) {
// Edge insertion: create source→new and new→target edges
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedEdges.push(graph.addEdge({
source: { cell: sourceNode.id, port: sourcePort },
target: { cell: newNode.id, port: newLeftPort },
...edgeAttrs
}));
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: newRightPort },
target: { cell: targetCell.id, port: origTargetPort },
...edgeAttrs
}));
setEdgeInsertion(null);
} else if (sourcePortGroup === 'left') {
// Connect from left port to new node's right side
const targetPort = newPorts.find((port: any) => port.group === 'right')?.id || 'right';
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: targetPort },
target: { cell: sourceNode.id, port: sourcePort },
...edgeAttrs
}));
} else {
// Connect from right port to new node's left side
const targetPort = newPorts.find((port: any) => port.group === 'left')?.id || 'left';
addedEdges.push(graph.addEdge({
source: { cell: sourceNode.id, port: sourcePort },
target: { cell: newNode.id, port: targetPort },
...edgeAttrs
}));
}
// Adjust loop node size when child node is added via port within loop node
const cycleId = sourceNodeData.cycle;
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) {
const adjustLoopSize = () => {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc: any, child: any) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
parentNode.prop('size', { width: newWidth, height: newHeight });
// Update right port x position
const ports = parentNode.getPorts();
ports.forEach((port: any) => {
if (port.group === 'right' && port.args) {
parentNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
};
adjustLoopSize();
// Listen to child node movement events
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
childNodes.forEach((childNode: any) => {
childNode.on('change:position', adjustLoopSize);
});
}
}
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
const newNodeType = selectedNodeType.type;
// Helper: bring all child nodes and their edges of a cycle container to front
const bringCycleChildrenToFront = (cycleContainerId: string) => {
graph.getEdges().forEach((e: any) => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
});
graph.getNodes().forEach((n: any) => {
if (n.getData()?.cycle === cycleContainerId) n.toFront();
});
};
if (isCycleContainer(sourceNodeType)) {
console.log('isCycleContainer(sourceNodeType)')
// Case 4: source is a loop/iteration node — bring new node to front, then its children
newNode.toFront();
sourceNode.toFront();
bringCycleChildrenToFront(sourceNodeData.id);
} else if (isCycleContainer(newNodeType)) {
console.log('isCycleContainer(newNodeType)')
// Case 3: adding a loop/iteration node from a normal node — bring new node to front, then its children
newNode.toFront();
sourceNode.toFront()
bringCycleChildrenToFront(id);
} else {
// Case 2: normal node → normal node
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
}
}, 50);
// Clean up temporary element
if (tempElement) {
document.body.removeChild(tempElement);
setTempElement(null);
}
setPopoverVisible(false);
};
@@ -335,4 +391,4 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
);
};
export default PortClickHandler;
export default PortClickHandler;

View File

@@ -242,11 +242,10 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''}
>
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
? <Select key={values.tool_id} size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
? <Select size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
: parameter.type === 'boolean'
? <Switch key={values.tool_id} size="small" />
? <Switch size="small" />
: <Editor
key={values.tool_id}
variant="outlined"
type="input"
size="small"

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 15:06:18
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-27 14:07:14
* @Last Modified time: 2026-04-21 18:23:31
*/
import type { ReactShapeConfig } from '@antv/x6-react-shape';
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
@@ -948,15 +948,6 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
width: nodeWidth,
height: 120,
shape: 'notes-node',
},
output: {
width: nodeWidth,
height: 76,
shape: 'normal-node',
ports: {
groups: { left: defaultPortGroup },
items: [defaultPortItems[0]],
},
}
}

View File

@@ -2,9 +2,10 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 15:17:48
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-28 13:49:11
* @Last Modified time: 2026-04-24 17:21:09
*/
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type';
import { register } from '@antv/x6-react-shape';
import type { PortMetadata } from '@antv/x6/lib/model/port';
import { App } from 'antd';
@@ -16,7 +17,7 @@ import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application';
import { useUser } from '@/store/user';
import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types';
import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant';
import type { ChatVariable, HistoryRecord, NodeProperties, WorkflowConfig } from '../types';
import type { ChatVariable, NodeProperties, WorkflowConfig } from '../types';
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
import { useWorkflowStore } from '@/store/workflow';
@@ -85,10 +86,6 @@ export interface UseWorkflowGraphReturn {
/** Get start node output variable list (user-defined + system variables) */
getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>;
nodeClick: ({ node }: { node: Node }) => void;
/** All recorded history operations */
historyRecords: HistoryRecord[];
/** Clear history records */
clearHistoryRecords: () => void;
}
/**
@@ -122,19 +119,14 @@ export const useWorkflowGraph = ({
const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined)
const [canUndo, setCanUndo] = useState(false)
const [canRedo, setCanRedo] = useState(false)
const [historyRecords, setHistoryRecords] = useState<HistoryRecord[]>([])
const lastHistoryRef = useRef<{ cellIds: string[]; timestamp: number; type: string } | null>(null)
const undoRef = useRef<() => void>(() => {})
const redoRef = useRef<() => void>(() => {})
const syncChildRelationshipsRef = useRef<() => void>(() => {})
const isSyncingRef = useRef(false)
useEffect(() => {
if (!graphRef.current) return
graphRef.current.getNodes().forEach(node => {
const data = node.getData()
if (data?.type === 'if-else' || data?.type === 'question-classifier') {
console.log('chatVariables', chatVariables)
node.setData({ ...data, chatVariables })
node.setData({ ...data, chatVariables }, { silent: true })
}
})
}, [chatVariables])
@@ -351,7 +343,7 @@ export const useWorkflowGraph = ({
if (parentNode) {
const addedChild = graphRef.current?.addNode(childNode)
if (addedChild) {
parentNode.addChild(addedChild, { silent: true })
parentNode.addChild(addedChild)
}
}
}
@@ -382,6 +374,8 @@ export const useWorkflowGraph = ({
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
console.log('newWidth', newHeight, newWidth)
parentNode.prop('size', { width: newWidth, height: newHeight })
// Update x position of right group ports
@@ -494,77 +488,8 @@ export const useWorkflowGraph = ({
graphRef.current.cleanHistory()
}
}, 200)
} else {
graphRef.current.enableHistory()
graphRef.current.cleanHistory()
}
}
const resizeGroupNodes = (graph: Graph) => {
graph.getNodes().forEach(parentNode => {
const parentType = parentNode.getData()?.type
if (parentType !== 'loop' && parentType !== 'iteration') return
const children = graph.getNodes().filter(
n => n.getData()?.cycle === parentNode.getData()?.id && n.getData()?.type !== 'add-node'
)
if (!children.length) return
const padding = 24
const headerHeight = 50
const childBounds = children.map(c => c.getBBox())
const minX = Math.min(...childBounds.map(b => b.x))
const minY = Math.min(...childBounds.map(b => b.y))
const maxX = Math.max(...childBounds.map(b => b.x + b.width))
const maxY = Math.max(...childBounds.map(b => b.y + b.height))
const parentBBox = parentNode.getBBox()
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
parentNode.prop('size', { width: newWidth, height: newHeight })
parentNode.getPorts().forEach(port => {
if (port.group === 'right' && port.args) {
parentNode.portProp(port.id!, 'args/x', newWidth)
}
})
})
}
const syncChildRelationships = () => {
if (!graphRef.current) return
const graph = graphRef.current
graph.disableHistory()
graph.getNodes().forEach(node => {
const cycleId = node.getData()?.cycle
if (!cycleId) return
const parentNode = graph.getCellById(cycleId) as Node | null
if (!parentNode) return
if (!parentNode.getChildren()?.some(c => c.id === node.id)) {
parentNode.addChild(node, { silent: true })
}
})
graph.getNodes().forEach(node => {
const children = node.getChildren()
if (!children?.length) return
children.forEach(child => {
if (!child.isNode()) return
const childCycleId = (child as Node).getData?.()?.cycle
if (childCycleId !== node.id && childCycleId !== node.getData?.()?.id) {
node.removeChild(child, { silent: true })
}
})
})
resizeGroupNodes(graph)
graph.getEdges().forEach(edge => {
const src = graph.getCellById(edge.getSourceCellId())
const tgt = graph.getCellById(edge.getTargetCellId())
if (src?.getData()?.cycle || tgt?.getData()?.cycle) {
edge.toFront()
}
})
graph.getNodes().forEach(node => {
if (node.getData()?.cycle) node.toFront()
})
graph.enableHistory()
}
syncChildRelationshipsRef.current = syncChildRelationships
/**
* Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard)
*/
@@ -600,44 +525,18 @@ export const useWorkflowGraph = ({
new History({
enabled: false,
beforeAddCommand(_event, args: any) {
const key = args?.key
if (key === 'attrs' || key === 'tools') return false
const event = args?.key ? `cell:change:${args.key}` : _event;
if (event.startsWith('cell:change:') &&
event !== 'cell:change:position' &&
event !== 'cell:change:source' &&
event !== 'cell:change:target') return false;
},
}),
);
const MERGE_INTERVAL = 1000
graphRef.current.on('history:change', ({ cmds, options }: { cmds: any[]; options: any }) => {
graphRef.current.on('history:change', ({ cmds }: { cmds: Command[] }) => {
setCanUndo(graphRef.current?.canUndo() ?? false)
setCanRedo(graphRef.current?.canRedo() ?? false)
console.log('history:change', cmds, options)
const batchName: string | undefined = options?.name
const actionType = batchName === 'undo' ? 'undo' : batchName === 'redo' ? 'redo' : batchName ? 'batch' : 'change'
const cellIds = [...new Set(cmds?.map((cmd: any) => cmd.data?.id).filter(Boolean))]
const now = Date.now()
const last = lastHistoryRef.current
const canMerge =
actionType === 'change' &&
last?.type === 'change' &&
now - last.timestamp < MERGE_INTERVAL &&
cellIds.length > 0 &&
cellIds.length === last.cellIds.length &&
cellIds.every((id, i) => id === last.cellIds[i])
if (canMerge) {
lastHistoryRef.current!.timestamp = now
setHistoryRecords(prev => {
const next = [...prev]
next[next.length - 1] = { ...next[next.length - 1], timestamp: now }
return next
})
} else {
const record: HistoryRecord = { type: actionType, timestamp: now, batchName, cellIds }
lastHistoryRef.current = { cellIds, timestamp: now, type: actionType }
setHistoryRecords(prev => [...prev, record])
}
})
graphRef.current.on('history:undo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
graphRef.current.on('history:redo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
};
// 显示/隐藏连接桩
// const showPorts = (show: boolean) => {
@@ -670,13 +569,13 @@ export const useWorkflowGraph = ({
vo.setData({
...data,
isSelected: false,
}, { silent: true });
});
}
});
node.setData({
...nodeData,
isSelected: true,
}, { silent: true });
});
clearEdgeSelect()
if (nodeData.type !== 'notes') {
setSelectedNode(node);
@@ -690,7 +589,7 @@ export const useWorkflowGraph = ({
const edgeClick = ({ edge }: { edge: Edge }) => {
clearEdgeSelect();
edge.setAttrByPath('line/stroke', edge_selected_color);
edge.setData({ ...edge.getData(), isSelected: true }, { silent: true });
edge.setData({ ...edge.getData(), isSelected: true });
clearNodeSelect();
};
/**
@@ -705,7 +604,7 @@ export const useWorkflowGraph = ({
node.setData({
...data,
isSelected: false,
}, { silent: true });
});
}
});
setSelectedNode(null);
@@ -715,7 +614,7 @@ export const useWorkflowGraph = ({
*/
const clearEdgeSelect = () => {
graphRef.current?.getEdges().forEach(e => {
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false }, { silent: true });
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false });
e.setAttrByPath('line/stroke', edge_color);
e.setAttrByPath('line/strokeWidth', edge_width);
});
@@ -854,6 +753,8 @@ export const useWorkflowGraph = ({
// Find corresponding parent node
const parentNode = nodes?.find(n => n.id === nodeData.cycle);
if (parentNode) {
// Use removeChild method to delete child node
parentNode.removeChild(nodeToDelete);
parentNodesToUpdate.push(parentNode);
}
// Add child node to deletion list
@@ -881,51 +782,42 @@ export const useWorkflowGraph = ({
// Delete all collected nodes and edges
if (cells.length > 0) {
// Pre-calculate which parents need an add-node restored (before removal changes the graph)
const parentsNeedingAddNode = parentNodesToUpdate
.filter(parentNode => {
const parentShape = parentNode.shape;
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return false;
const parentData = parentNode.getData();
const allChildren = graphRef.current!.getNodes().filter(n => n.getData()?.cycle === parentData.id);
const cycleStartNodes = allChildren.filter(n => n.getData()?.type === 'cycle-start');
// After deletion, only cycle-start will remain
const nonCycleStartToDelete = cells.filter(c =>
c.isNode() &&
(c as Node).getData()?.cycle === parentData.id &&
(c as Node).getData()?.type !== 'cycle-start'
);
return cycleStartNodes.length === 1 && (allChildren.length - nonCycleStartToDelete.length) === 1;
})
.map(parentNode => ({
parentNode,
cycleStartNode: graphRef.current!.getNodes().find(
n => n.getData()?.cycle === parentNode.getData().id && n.getData()?.type === 'cycle-start'
)!
}))
.filter(({ cycleStartNode }) => !!cycleStartNode);
graphRef.current?.startBatch('delete');
graphRef.current?.removeCells(cells);
parentsNeedingAddNode.forEach(({ parentNode, cycleStartNode }) => {
// If parent is iteration/loop and only cycle-start remains, add add-node connected to it
parentNodesToUpdate.forEach(parentNode => {
const parentShape = parentNode.shape;
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return;
const parentData = parentNode.getData();
const bbox = cycleStartNode.getBBox();
const addNode = graphRef.current!.addNode({
...graphNodeLibrary.addStart,
x: bbox.x + 84,
y: bbox.y + 4,
data: { type: 'add-node', parentId: parentNode.id, cycle: parentData.id, label: t('workflow.addNode'), icon: '+' },
});
parentNode.addChild(addNode, { silent: true });
graphRef.current!.addEdge({
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
...edgeAttrs,
});
const remainingChildren = graphRef.current!.getNodes().filter(
n => n.getData()?.cycle === parentData.id
);
const cycleStartNodes = remainingChildren.filter(n => n.getData()?.type === 'cycle-start');
if (cycleStartNodes.length === 1 && remainingChildren.length === 1) {
const cycleStartNode = cycleStartNodes[0];
const bbox = cycleStartNode.getBBox();
const addNode = graphRef.current!.addNode({
...graphNodeLibrary.addStart,
x: bbox.x + 84,
y: bbox.y + 4,
data: {
type: 'add-node',
parentId: parentNode.id,
cycle: parentData.id,
label: t('workflow.addNode'),
icon: '+',
},
});
parentNode.addChild(addNode);
const sourcePort = cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right';
const targetPort = addNode.getPorts().find(p => p.group === 'left')?.id || 'left';
graphRef.current!.addEdge({
source: { cell: cycleStartNode.id, port: sourcePort },
target: { cell: addNode.id, port: targetPort },
...edgeAttrs,
});
}
});
graphRef.current?.stopBatch('delete');
}
return false;
};
@@ -1144,7 +1036,7 @@ export const useWorkflowGraph = ({
graphRef.current?.getConnectedEdges(node).forEach(edge => {
if (!edge.getData()?.isSelected) {
edge.setAttrByPath('line/stroke', edge_selected_color);
edge.setData({ ...edge.getData(), isNodeHover: true }, { silent: true });
edge.setData({ ...edge.getData(), isNodeHover: true });
}
});
});
@@ -1152,7 +1044,7 @@ export const useWorkflowGraph = ({
graphRef.current?.getConnectedEdges(node).forEach(edge => {
if (!edge.getData()?.isSelected) {
edge.setAttrByPath('line/stroke', edge_color);
edge.setData({ ...edge.getData(), isNodeHover: false }, { silent: true });
edge.setData({ ...edge.getData(), isNodeHover: false });
}
});
});
@@ -1234,8 +1126,8 @@ export const useWorkflowGraph = ({
// Delete selected nodes and edges
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
// Undo / Redo
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { undo(); return false; });
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { redo(); return false; });
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { graphRef.current?.undo(); return false; });
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { graphRef.current?.redo(); return false; });
};
@@ -1301,51 +1193,13 @@ export const useWorkflowGraph = ({
};
if (dragData.type === 'loop' || dragData.type === 'iteration') {
graph.disableHistory()
const parentNode = graphRef.current.addNode({
graphRef.current.addNode({
...graphNodeLibrary[dragData.type],
x: point.x - 150,
y: point.y - 100,
id: cleanNodeData.id,
data: { ...cleanNodeData, isGroup: true },
})
const parentBBox = parentNode.getBBox()
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const cycleStartNode = graphRef.current.addNode({
...graphNodeLibrary.cycleStart,
x: parentBBox.x + 24,
y: parentBBox.y + 70,
id: cycleStartId,
data: { id: cycleStartId, type: 'cycle-start', parentId: cleanNodeData.id, isDefault: true, cycle: cleanNodeData.id },
})
const addNode = graphRef.current.addNode({
...graphNodeLibrary.addStart,
x: parentBBox.x + 24 + 84,
y: parentBBox.y + 70 + 4,
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: cleanNodeData.id, cycle: cleanNodeData.id },
})
parentNode.addChild(cycleStartNode, { silent: true })
parentNode.addChild(addNode, { silent: true })
const newEdge = graphRef.current.addEdge({
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
...edgeAttrs,
})
cycleStartNode.toFront()
addNode.toFront()
graph.enableHistory()
// Manually push a single batch frame covering all 4 cells into undoStack
const history = graph.getPlugin('history') as History
const makeBatchCmd = (cell: any) => ({
batch: true,
event: 'cell:added',
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
options: {},
})
const batchFrame = [parentNode, cycleStartNode, addNode, newEdge].map(makeBatchCmd)
;(history as any).undoStack.push(batchFrame)
;(history as any).redoStack = []
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-group' } })
});
} else if (dragData.type === 'if-else') {
// Create condition node
graphRef.current.addNode({
@@ -1592,80 +1446,8 @@ export const useWorkflowGraph = ({
return userVars
}
const clearHistoryRecords = () => {
setHistoryRecords([])
lastHistoryRef.current = null
}
const getStackCellIds = (cmds: any): string[] => {
const arr = Array.isArray(cmds) ? cmds : [cmds]
return [...new Set(arr.map((c: any) => c.data?.id).filter(Boolean))]
}
const isSkippableFrame = (frame: any): boolean => {
const arr = Array.isArray(frame) ? frame : [frame]
return arr.every((c: any) => ['zIndex', 'attrs', 'tools'].includes(c.data?.key))
}
const undo = () => {
const history = graphRef.current?.getPlugin('history') as History | undefined
if (!history || history.getUndoSize() === 0) return
const undoStack = (history as any).undoStack as any[]
isSyncingRef.current = true
while (undoStack.length > 0 && isSkippableFrame(undoStack[undoStack.length - 1])) {
graphRef.current!.undo()
}
if (undoStack.length === 0) {
isSyncingRef.current = false
return
}
const topIds = getStackCellIds(undoStack[undoStack.length - 1])
graphRef.current!.undo()
while (undoStack.length > 0) {
if (isSkippableFrame(undoStack[undoStack.length - 1])) {
graphRef.current!.undo()
continue
}
const nextIds = getStackCellIds(undoStack[undoStack.length - 1])
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
graphRef.current!.undo()
} else {
break
}
}
isSyncingRef.current = false
syncChildRelationships()
}
const redo = () => {
const history = graphRef.current?.getPlugin('history') as History | undefined
if (!history || history.getRedoSize() === 0) return
const redoStack = (history as any).redoStack as any[]
isSyncingRef.current = true
while (redoStack.length > 0 && isSkippableFrame(redoStack[redoStack.length - 1])) {
graphRef.current!.redo()
}
if (redoStack.length === 0) {
isSyncingRef.current = false
return
}
const topIds = getStackCellIds(redoStack[redoStack.length - 1])
graphRef.current!.redo()
while (redoStack.length > 0) {
if (isSkippableFrame(redoStack[redoStack.length - 1])) {
graphRef.current!.redo()
continue
}
const nextIds = getStackCellIds(redoStack[redoStack.length - 1])
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
graphRef.current!.redo()
} else {
break
}
}
isSyncingRef.current = false
syncChildRelationships()
}
const undo = () => graphRef.current?.undo()
const redo = () => graphRef.current?.redo()
const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => {
const { statement = '' } = value?.opening_statement || {}
@@ -1706,16 +1488,20 @@ export const useWorkflowGraph = ({
if (!graphRef.current) return;
const nodes = graphRef.current.getNodes();
// Reset all node execution status on every chatHistory change
const lastWithSub = [...chatHistory].reverse().find(item => item.subContent?.length);
// Reset all node execution status first
nodes.forEach(node => {
const data = node.getData();
node.setData({ ...data, executionStatus: '' });
if (typeof data.executionStatus === 'string') {
node.setData({ ...data, executionStatus: undefined });
}
});
const lastAssistant = [...chatHistory].reverse().find(item => item.role === 'assistant');
if (!lastAssistant?.subContent?.length) return;
lastAssistant.subContent.forEach(sub => {
if (!lastWithSub?.subContent) return;
// Build a nodeId -> status map first
const statusMap: Record<string, string> = {};
lastWithSub.subContent.forEach(sub => {
if (typeof sub.status === 'string') {
statusMap[sub.node_id] = sub.status;
const node = nodes.find(n => n.getData()?.id === sub.node_id);
if (node) {
node.setData({ ...node.getData(), executionStatus: sub.status });
@@ -1751,7 +1537,5 @@ export const useWorkflowGraph = ({
canRedo,
undo,
redo,
historyRecords,
clearHistoryRecords,
};
};

View File

@@ -113,13 +113,4 @@ export interface ChatVariable {
}
export interface AddChatVariableRef {
handleOpen: (value?: ChatVariable) => void;
}
export type HistoryActionType = 'add' | 'remove' | 'change' | 'undo' | 'redo' | 'batch'
export interface HistoryRecord {
type: HistoryActionType;
timestamp: number;
batchName?: string;
cellIds?: string[];
}

View File

@@ -17,7 +17,6 @@ export const isSubExprSet = (sub: any) => {
* Uses the same per-expression height logic as getConditionNodeCasePortY.
*/
export const calcConditionNodeTotalHeight = (cases: any[]) => {
if (!cases?.length) return conditionNodeHeight;
const casesHeight = cases.reduce((acc: number, c: any) => {
const exprs = c?.expressions ?? [];
const n = exprs.length;