Compare commits

..

9 Commits

Author SHA1 Message Date
wxy
cef33fce0d fix(workflow): sanitize condition expression building and cache assigner node inputs
- Sanitize condition expression construction in graph_builder.py using json.dumps to prevent potential injection vulnerabilities.
- Cache input data prior to assigner node execution to ensure variable values are correctly captured before processing.
2026-05-07 16:26:47 +08:00
wxy
d9f08860bc feat(LLM node): integrate exception handling and enable branch routing
- Integrate exception handling configuration into LLM nodes, supporting three strategies: throw exception, return default value, or trigger exception branch.
- Modify execution logic to return a result structure containing a branch signal, enabling routing to designated branches upon failure.
- Update graph_builder to support LLM node branch routing logic using the branch_signal field for conditional judgment.
- Implement backward compatibility to support both legacy and new result formats.
2026-05-07 11:43:24 +08:00
wxy
461674c8d8 feat(workflow): parse and substitute template variables in node configurations
- Implement regex matching for {{xxx}} template variable format.
- Enable recursive parsing of all string template variables within node configurations.
- Resolve and substitute template variables with runtime values during input data extraction.
- Support dynamic parsing and substitution of file selector variables in the document extraction node.
- Make strict template variable mode optional and introduce support for default values.
2026-04-29 14:10:02 +08:00
wxy
c59e179cc2 feat(workflow): incorporate model references and streamline parsing logic
- Incorporate model reference metadata (name, provider, type) into workflow nodes and refactor parsing logic to support the new format.
- Streamline code structure by removing redundant model_id fields to enhance maintainability.
2026-04-28 11:18:06 +08:00
Mark
a5670bfff6 Merge branch 'feature/rag2' into develop 2026-04-27 18:17:49 +08:00
Mark
4bef9b578b [fix] document file delete 2026-04-27 17:35:13 +08:00
Mark
c53fcf3981 [fix] old code file_path 2026-04-27 17:10:00 +08:00
Mark
2997558bc8 Merge branch 'release/v0.3.2' into feature/rag2
* release/v0.3.2: (245 commits)
  fix(conversation_schema): refine citations field type to Dict[str, Any]
  fix(tool_controller): re-raise HTTPException to preserve original status codes
  fix(workflow): add reasoning content, suggested questions, citations and audio status support
  feat(workflow): augment logging queries and ameliorate error handling
  fix(api_key): bypass publication check for SERVICE type API keys
  fix(multimodal_service): add '文档内容:' prefix to document text and simplify image placeholder text
  fix(api): convert config_id to string in write_router
  fix(api): convert end_user_id to string in write_router
  fix(multimodal_service): refactor image processing to use intermediate list before extending result
  fix(web): node status ui
  fix(api): correct import paths in memory_read and celery task command
  fix(api): correct import paths in memory_read and celery task command
  refactor(tool): flatten request body parameters for model exposure
  fix(api): correct import paths in memory_read and celery task command
  refactor(workflow): streamline node execution handling and log service logic
  feat(web): http request add process
  feat(web): workflow app logs
  fix(app_chat_service,draft_run_service): move system_prompt augmentation before LangChainAgent instantiation
  fix(app_chat_service,draft_run_service): move system_prompt augmentation before LangChainAgent instantiation
  refactor(http_request): simplify request handling and remove unused fields
  ...

# Conflicts:
#	api/app/controllers/file_controller.py
#	api/app/tasks.py
2026-04-27 16:13:57 +08:00
Mark
30cdf229de [modify] rag file system 2026-04-27 16:05:27 +08:00
48 changed files with 1052 additions and 1159 deletions

View File

@@ -3,9 +3,12 @@ name: Sync to Gitee
on: on:
push: push:
branches: branches:
- '**' # All branchs - main # Production
- develop # Integration
- 'release/*' # Release preparation
- 'hotfix/*' # Urgent fixes
tags: tags:
- '**' # All version tags (v1.0.0, etc.) - '*' # All version tags (v1.0.0, etc.)
jobs: jobs:
sync: sync:

View File

@@ -82,19 +82,32 @@ async def get_preview_chunks(
detail="The file does not exist or you do not have permission to access it" detail="The file does not exist or you do not have permission to access it"
) )
# 5. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext} # 5. Get file content from storage backend
file_path = os.path.join( if not db_file.file_key:
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 6. Check if the file exists
if not os.path.exists(file_path):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)" detail="File has no storage key (legacy data not migrated)"
)
from app.services.file_storage_service import FileStorageService
import asyncio
storage_service = FileStorageService()
async def _download():
return await storage_service.download_file(db_file.file_key)
try:
file_binary = asyncio.run(_download())
except RuntimeError:
loop = asyncio.new_event_loop()
try:
file_binary = loop.run_until_complete(_download())
finally:
loop.close()
except Exception as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"File not found in storage: {e}"
) )
# 7. Document parsing & segmentation # 7. Document parsing & segmentation
@@ -104,11 +117,12 @@ async def get_preview_chunks(
vision_model = QWenCV( vision_model = QWenCV(
key=db_knowledge.image2text.api_keys[0].api_key, key=db_knowledge.image2text.api_keys[0].api_key,
model_name=db_knowledge.image2text.api_keys[0].model_name, model_name=db_knowledge.image2text.api_keys[0].model_name,
lang="Chinese", # Default to Chinese lang="Chinese",
base_url=db_knowledge.image2text.api_keys[0].api_base base_url=db_knowledge.image2text.api_keys[0].api_base
) )
from app.core.rag.app.naive import chunk from app.core.rag.app.naive import chunk
res = chunk(filename=file_path, res = chunk(filename=db_file.file_name,
binary=file_binary,
from_page=0, from_page=0,
to_page=5, to_page=5,
callback=progress_callback, callback=progress_callback,

View File

@@ -20,6 +20,7 @@ from app.models.user_model import User
from app.schemas import document_schema from app.schemas import document_schema
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import document_service, file_service, knowledge_service from app.services import document_service, file_service, knowledge_service
from app.services.file_storage_service import FileStorageService, get_file_storage_service
# Obtain a dedicated API logger # Obtain a dedicated API logger
@@ -231,7 +232,8 @@ async def update_document(
async def delete_document( async def delete_document(
document_id: uuid.UUID, document_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
): ):
""" """
Delete document Delete document
@@ -257,7 +259,7 @@ async def delete_document(
db.commit() db.commit()
# 3. Delete file # 3. Delete file
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user) await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
# 4. Delete vector index # 4. Delete vector index
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user) db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
@@ -305,38 +307,25 @@ async def parse_documents(
detail="The file does not exist or you do not have permission to access it" detail="The file does not exist or you do not have permission to access it"
) )
# 3. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext} # 3. Get file_key for storage backend
file_path = os.path.join( if not db_file.file_key:
settings.FILE_PATH, api_logger.error(f"File has no storage key (legacy data not migrated): file_id={db_file.id}")
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 4. Check if the file exists
api_logger.debug(f"Constructed file path: {file_path}")
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
if not os.path.exists(file_path):
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)" detail="File has no storage key (legacy data not migrated)"
) )
# 5. Obtain knowledge base information # 4. Obtain knowledge base information
api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}") api_logger.info(f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user) db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
if not db_knowledge: if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 6. Task: Document parsing, vectorization, and storage # 5. Dispatch parse task with file_key (not file_path)
# from app.tasks import parse_document task = celery_app.send_task(
# parse_document(file_path, document_id) "app.core.rag.tasks.parse_document",
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id]) args=[db_file.file_key, document_id, db_file.file_name]
)
result = { result = {
"task_id": task.id "task_id": task.id
} }

View File

@@ -1,12 +1,10 @@
import os import os
from pathlib import Path
import shutil
from typing import Any, Optional from typing import Any, Optional
import uuid import uuid
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.responses import FileResponse from fastapi.responses import Response
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.config import settings from app.core.config import settings
@@ -19,10 +17,14 @@ from app.models.user_model import User
from app.schemas import file_schema, document_schema from app.schemas import file_schema, document_schema
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import file_service, document_service from app.services import file_service, document_service
from app.services.knowledge_service import get_knowledge_by_id as get_kb_by_id
from app.services.file_storage_service import (
FileStorageService,
generate_kb_file_key,
get_file_storage_service,
)
from app.core.quota_stub import check_knowledge_capacity_quota from app.core.quota_stub import check_knowledge_capacity_quota
# Obtain a dedicated API logger
api_logger = get_api_logger() api_logger = get_api_logger()
router = APIRouter( router = APIRouter(
@@ -35,67 +37,37 @@ router = APIRouter(
async def get_files( async def get_files(
kb_id: uuid.UUID, kb_id: uuid.UUID,
parent_id: uuid.UUID, parent_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0 page: int = Query(1, gt=0),
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items pagesize: int = Query(20, gt=0, le=100),
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"), orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"), desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (file name)"), keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """Paged query file list"""
Paged query file list api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}")
- Support filtering by kb_id and parent_id
- Support keyword search for file names
- Support dynamic sorting
- Return paging metadata + file list
"""
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions if page < 1 or pagesize < 1:
filters = [ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The paging parameter must be greater than 0")
file_model.File.kb_id == kb_id
] filters = [file_model.File.kb_id == kb_id]
if parent_id: if parent_id:
filters.append(file_model.File.parent_id == parent_id) filters.append(file_model.File.parent_id == parent_id)
# Keyword search (fuzzy matching of file name)
if keywords: if keywords:
filters.append(file_model.File.file_name.ilike(f"%{keywords}%")) filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
# 3. Execute paged query
try: try:
api_logger.debug("Start executing file paging query")
total, items = file_service.get_files_paginated( total, items = file_service.get_files_paginated(
db=db, db=db, filters=filters, page=page, pagesize=pagesize,
filters=filters, orderby=orderby, desc=desc, current_user=current_user
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
) )
api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Query failed: {str(e)}")
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = { result = {
"items": items, "items": items,
"page": { "page": {"page": page, "pagesize": pagesize, "total": total, "has_next": page * pagesize < total}
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
} }
return success(data=jsonable_encoder(result), msg="Query of file list succeeded") return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
@@ -108,23 +80,14 @@ async def create_folder(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
""" """Create a new folder"""
Create a new folder api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}")
"""
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")
try: try:
api_logger.debug(f"Start creating a folder: {folder_name}") create_folder_data = file_schema.FileCreate(
create_folder = file_schema.FileCreate( kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
kb_id=kb_id, file_name=folder_name, file_ext='folder', file_size=0,
created_by=current_user.id,
parent_id=parent_id,
file_name=folder_name,
file_ext='folder',
file_size=0,
) )
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user) db_file = file_service.create_file(db=db, file=create_folder_data, current_user=current_user)
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful") return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
except Exception as e: except Exception as e:
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}") api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
@@ -138,76 +101,58 @@ async def upload_file(
parent_id: uuid.UUID, parent_id: uuid.UUID,
file: UploadFile = File(...), file: UploadFile = File(...),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
): ):
""" """Upload file to storage backend"""
upload file api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}")
"""
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")
# Read the contents of the file
contents = await file.read() contents = await file.read()
# Check file size
file_size = len(contents) file_size = len(contents)
print(f"file size: {file_size} byte")
if file_size == 0: if file_size == 0:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The file is empty.")
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
if file_size > settings.MAX_FILE_SIZE: if file_size > settings.MAX_FILE_SIZE:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"File size exceeds {settings.MAX_FILE_SIZE} byte limit")
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
# Extract the extension using `os.path.splitext`
_, file_extension = os.path.splitext(file.filename) _, file_extension = os.path.splitext(file.filename)
upload_file = file_schema.FileCreate( file_ext = file_extension.lower()
kb_id=kb_id,
created_by=current_user.id, # Create File record
parent_id=parent_id, upload_file_data = file_schema.FileCreate(
file_name=file.filename, kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
file_ext=file_extension.lower(), file_name=file.filename, file_ext=file_ext, file_size=file_size,
file_size=file_size,
) )
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user) db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension} # Upload to storage backend
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id)) file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists try:
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
except Exception as e:
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
# Save file # Save file_key
with open(save_path, "wb") as f: db_file.file_key = file_key
f.write(contents) db.commit()
db.refresh(db_file)
# Verify whether the file has been saved successfully # Create document (inherit parser_config from knowledge base)
if not os.path.exists(save_path): default_parser_config = {
raise HTTPException( "layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, "auto_keywords": 0, "auto_questions": 0, "html4excel": "false"
detail="File save failed" }
) try:
db_knowledge = get_kb_by_id(db, knowledge_id=kb_id, current_user=current_user)
if db_knowledge and db_knowledge.parser_config:
default_parser_config.update(dict(db_knowledge.parser_config))
except Exception:
pass
# Create a document
create_data = document_schema.DocumentCreate( create_data = document_schema.DocumentCreate(
kb_id=kb_id, kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
created_by=current_user.id, file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
file_id=db_file.id, file_meta={}, parser_id="naive", parser_config=default_parser_config
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
) )
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user) db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
@@ -221,123 +166,73 @@ async def custom_text(
parent_id: uuid.UUID, parent_id: uuid.UUID,
create_data: file_schema.CustomTextFileCreate, create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
): ):
""" """Custom text upload"""
custom text
"""
api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}")
# Check file content size
# 将内容编码为字节UTF-8
content_bytes = create_data.content.encode('utf-8') content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes) file_size = len(content_bytes)
print(f"file size: {file_size} byte")
if file_size == 0: if file_size == 0:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The content is empty.")
status_code=status.HTTP_400_BAD_REQUEST,
detail="The content is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
if file_size > settings.MAX_FILE_SIZE: if file_size > settings.MAX_FILE_SIZE:
raise HTTPException( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Content size exceeds {settings.MAX_FILE_SIZE} byte limit")
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
upload_file = file_schema.FileCreate( upload_file_data = file_schema.FileCreate(
kb_id=kb_id, kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
created_by=current_user.id, file_name=f"{create_data.title}.txt", file_ext=".txt", file_size=file_size,
parent_id=parent_id,
file_name=f"{create_data.title}.txt",
file_ext=".txt",
file_size=file_size,
) )
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user) db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension} # Upload to storage backend
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id)) file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=".txt")
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists try:
save_path = os.path.join(save_dir, f"{db_file.id}.txt") await storage_service.storage.upload(file_key=file_key, content=content_bytes, content_type="text/plain")
except Exception as e:
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
# Save file db_file.file_key = file_key
with open(save_path, "wb") as f: db.commit()
f.write(content_bytes) db.refresh(db_file)
# Verify whether the file has been saved successfully
if not os.path.exists(save_path):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="File save failed"
)
# Create a document
create_document_data = document_schema.DocumentCreate( create_document_data = document_schema.DocumentCreate(
kb_id=kb_id, kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
created_by=current_user.id, file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
file_id=db_file.id, file_meta={}, parser_id="naive",
file_name=db_file.file_name, parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
file_ext=db_file.file_ext, "auto_keywords": 0, "auto_questions": 0, "html4excel": "false"}
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
) )
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user) db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful") return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
@router.get("/{file_id}", response_model=Any) @router.get("/{file_id}", response_model=Any)
async def get_file( async def get_file(
file_id: uuid.UUID, file_id: uuid.UUID,
db: Session = Depends(get_db) db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service),
) -> Any: ) -> Any:
""" """Download file by file_id"""
Download the file based on the file_id
- Query file information from the database
- Construct the file path and check if it exists
- Return a FileResponse to download the file
"""
api_logger.info(f"Download the file based on the file_id: file_id={file_id}")
# 1. Query file information from the database
db_file = file_service.get_file_by_id(db, file_id=file_id) db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file: if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Construct file path/files/{kb_id}/{parent_id}/{file.id}{file.file_ext} if not db_file.file_key:
file_path = os.path.join( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File has no storage key (legacy data not migrated)")
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 3. Check if the file exists try:
if not os.path.exists(file_path): content = await storage_service.download_file(db_file.file_key)
raise HTTPException( except Exception as e:
status_code=status.HTTP_404_NOT_FOUND, api_logger.error(f"Storage download failed: {e}")
detail="File not found (possibly deleted)" raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage")
)
# 4.Return FileResponse (automatically handle download) import mimetypes
return FileResponse( media_type = mimetypes.guess_type(db_file.file_name)[0] or "application/octet-stream"
path=file_path, return Response(
filename=db_file.file_name, # Use original file name content=content,
media_type="application/octet-stream" # Universal binary stream type media_type=media_type,
headers={"Content-Disposition": f'attachment; filename="{db_file.file_name}"'}
) )
@@ -348,50 +243,22 @@ async def update_file(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """Update file information (such as file name)"""
Update file information (such as file name)
- Only specified fields such as file_name are allowed to be modified
"""
api_logger.debug(f"Query the file to be updated: {file_id}")
# 1. Check if the file exists
db_file = file_service.get_file_by_id(db, file_id=file_id) db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file: if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the file fields: {file_id}")
updated_fields = []
for field, value in update_data.dict(exclude_unset=True).items(): for field, value in update_data.dict(exclude_unset=True).items():
if hasattr(db_file, field): if hasattr(db_file, field):
old_value = getattr(db_file, field) setattr(db_file, field, value)
if old_value != value:
# update value
setattr(db_file, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 3. Save to database
try: try:
db.commit() db.commit()
db.refresh(db_file) db.refresh(db_file)
api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
except Exception as e: except Exception as e:
db.rollback() db.rollback()
api_logger.error(f"File update failed: file_id={file_id} - {str(e)}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File update failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"File update failed: {str(e)}"
)
# 4. Return the updated file
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully") return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
@@ -399,60 +266,43 @@ async def update_file(
async def delete_file( async def delete_file(
file_id: uuid.UUID, file_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
): ):
""" """Delete a file or folder"""
Delete a file or folder api_logger.info(f"Request to delete file: file_id={file_id}")
""" await _delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
api_logger.info(f"Request to delete file: file_id={file_id}, username: {current_user.username}")
await _delete_file(db=db, file_id=file_id, current_user=current_user)
return success(msg="File deleted successfully") return success(msg="File deleted successfully")
async def _delete_file( async def _delete_file(
file_id: uuid.UUID, file_id: uuid.UUID,
db: Session = Depends(get_db), db: Session,
current_user: User = Depends(get_current_user) current_user: User,
storage_service: FileStorageService,
) -> None: ) -> None:
""" """Delete a file or folder from storage and database"""
Delete a file or folder
"""
# 1. Check if the file exists
db_file = file_service.get_file_by_id(db, file_id=file_id) db_file = file_service.get_file_by_id(db, file_id=file_id)
if not db_file: if not db_file:
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist or you do not have permission to access it"
)
# 2. Construct physical path # Delete from storage backend
file_path = Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.id)
) if db_file.file_ext == 'folder' else Path(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 3. Delete physical files/folders
try:
if file_path.exists():
if db_file.file_ext == 'folder':
shutil.rmtree(file_path) # Recursively delete folders
else:
file_path.unlink() # Delete a single file
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete physical file/folder: {str(e)}"
)
# 4.Delete db_file
if db_file.file_ext == 'folder': if db_file.file_ext == 'folder':
# For folders, delete all child files from storage first
child_files = db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).all()
for child in child_files:
if child.file_key:
try:
await storage_service.delete_file(child.file_key)
except Exception as e:
api_logger.warning(f"Failed to delete child file from storage: {child.file_key} - {e}")
db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete() db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
else:
if db_file.file_key:
try:
await storage_service.delete_file(db_file.file_key)
except Exception as e:
api_logger.warning(f"Failed to delete file from storage: {db_file.file_key} - {e}")
db.delete(db_file) db.delete(db_file)
db.commit() db.commit()

View File

@@ -296,7 +296,7 @@ async def chat(
} }
) )
# workflow 非流式返回 # 多 Agent 非流式返回
result = await app_chat_service.workflow_chat( result = await app_chat_service.workflow_chat(
message=payload.message, message=payload.message,

View File

@@ -221,7 +221,7 @@ def update_workspace_members(
@router.delete("/members/{member_id}", response_model=ApiResponse) @router.delete("/members/{member_id}", response_model=ApiResponse)
@cur_workspace_access_guard() @cur_workspace_access_guard()
async def delete_workspace_member( def delete_workspace_member(
member_id: uuid.UUID, member_id: uuid.UUID,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
@@ -230,7 +230,7 @@ async def delete_workspace_member(
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}") api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
await workspace_service.delete_workspace_member( workspace_service.delete_workspace_member(
db=db, db=db,
workspace_id=workspace_id, workspace_id=workspace_id,
member_id=member_id, member_id=member_id,

View File

@@ -241,8 +241,6 @@ class Settings:
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587")) SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER: str = os.getenv("SMTP_USER", "") SMTP_USER: str = os.getenv("SMTP_USER", "")
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "") SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
SANDBOX_URL: str = os.getenv("SANDBOX_URL", "")
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300")) REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))

View File

@@ -216,7 +216,7 @@ class RedBearModelFactory:
# 深度思考模式Claude 3.7 Sonnet 等支持思考的模型 # 深度思考模式Claude 3.7 Sonnet 等支持思考的模型
# 通过 additional_model_request_fields 传递 thinking 块关闭时不传Bedrock 无 disabled 选项) # 通过 additional_model_request_fields 传递 thinking 块关闭时不传Bedrock 无 disabled 选项)
if config.deep_thinking: if config.deep_thinking:
budget = config.thinking_budget_tokens or 1024 budget = config.thinking_budget_tokens or 10000
params["additional_model_request_fields"] = { params["additional_model_request_fields"] = {
"thinking": {"type": "enabled", "budget_tokens": budget} "thinking": {"type": "enabled", "budget_tokens": budget}
} }

View File

@@ -14,6 +14,7 @@ Transcribe the content from the provided PDF page image into clean Markdown form
6. Do NOT wrap the output in ```markdown or ``` blocks. 6. Do NOT wrap the output in ```markdown or ``` blocks.
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image. 7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
8. Preserve the original language, information, and order exactly as shown in the image. 8. Preserve the original language, information, and order exactly as shown in the image.
9. Your output language MUST match the language of the content in the image. If the image contains Chinese text, output in Chinese. If English, output in English. Never translate.
{% if page %} {% if page %}
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`. At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.

View File

@@ -2,6 +2,7 @@
# Author: Eternity # Author: Eternity
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/10 13:33 # @Time : 2026/2/10 13:33
import json
import logging import logging
import re import re
import uuid import uuid
@@ -141,9 +142,10 @@ class GraphBuilder:
for node_info in source_nodes: for node_info in source_nodes:
if self.get_node_type(node_info["id"]) in BRANCH_NODES: if self.get_node_type(node_info["id"]) in BRANCH_NODES:
branch_nodes.append( if node_info.get("branch") is not None:
(node_info["id"], node_info["branch"]) branch_nodes.append(
) (node_info["id"], node_info["branch"])
)
else: else:
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT): if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
output_nodes.append(node_info["id"]) output_nodes.append(node_info["id"])
@@ -314,9 +316,12 @@ class GraphBuilder:
for idx in range(len(related_edge)): for idx in range(len(related_edge)):
# Generate a condition expression for each edge # Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output # Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label # For LLM nodes, use branch_signal field for routing (output is dynamic text)
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' # For other branch nodes (e.g. HTTP), use output field
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'" route_field = "branch_signal" if node_type == NodeType.LLM else "output"
related_edge[idx]['condition'] = (
f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(related_edge[idx]['label'])}"
)
if node_instance: if node_instance:
# Wrap node's run method to avoid closure issues # Wrap node's run method to avoid closure issues

View File

@@ -18,10 +18,17 @@ class AssignerNode(BaseNode):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None self.typed_config: AssignerNodeConfig | None = None
self._input_data: dict[str, Any] | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {} return {}
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取节点输入,如果有缓存的执行前数据则使用缓存"""
if self._input_data is not None:
return self._input_data
return {"config": self._resolve_config(self.config, variable_pool)}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
""" """
Execute the assignment operation defined by this node. Execute the assignment operation defined by this node.
@@ -34,6 +41,9 @@ class AssignerNode(BaseNode):
Returns: Returns:
None or the result of the assignment operation. None or the result of the assignment operation.
""" """
# 在执行前提取并缓存输入数据(捕获执行前的变量值)
self._input_data = {"config": self._resolve_config(self.config, variable_pool)}
# Initialize a variable pool for accessing conversation, node, and system variables # Initialize a variable pool for accessing conversation, node, and system variables
self.typed_config = AssignerNodeConfig(**self.config) self.typed_config = AssignerNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} 开始执行") logger.info(f"节点 {self.node_id} 开始执行")

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
import re
import time import time
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@@ -22,6 +23,9 @@ from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 匹配模板变量 {{xxx}} 的正则
_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
class NodeExecutionError(Exception): class NodeExecutionError(Exception):
"""节点执行失败异常。 """节点执行失败异常。
@@ -503,10 +507,29 @@ class BaseNode(ABC):
variable_pool: The variable pool used for reading and writing variables. variable_pool: The variable pool used for reading and writing variables.
Returns: Returns:
A dictionary containing the node's input data. A dictionary containing the node's input data with all template
variables resolved to their actual runtime values.
""" """
# Default implementation returns the node configuration return {"config": self._resolve_config(self.config, variable_pool)}
return {"config": self.config}
@staticmethod
def _resolve_config(config: Any, variable_pool: VariablePool) -> Any:
"""递归解析 config 中的模板变量,将 {{xxx}} 替换为实际值。
Args:
config: 节点的原始配置(可能包含模板变量)。
variable_pool: 变量池,用于解析模板变量。
Returns:
解析后的配置,所有字符串中的 {{变量}} 已被替换为真实值。
"""
if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
return BaseNode._render_template(config, variable_pool, strict=False)
elif isinstance(config, dict):
return {k: BaseNode._resolve_config(v, variable_pool) for k, v in config.items()}
elif isinstance(config, list):
return [BaseNode._resolve_config(item, variable_pool) for item in config]
return config
def _extract_output(self, business_result: Any) -> Any: def _extract_output(self, business_result: Any) -> Any:
"""Extracts the actual output from the business result. """Extracts the actual output from the business result.

View File

@@ -14,7 +14,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import BaseNode from app.core.workflow.nodes import BaseNode
from app.core.workflow.nodes.code.config import CodeNodeConfig from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -132,7 +131,7 @@ class CodeNode(BaseNode):
async with httpx.AsyncClient(timeout=60) as client: async with httpx.AsyncClient(timeout=60) as client:
response = await client.post( response = await client.post(
f"{settings.SANDBOX_URL}:8194/v1/sandbox/run", "http://sandbox:8194/v1/sandbox/run",
headers={ headers={
"x-api-key": 'redbear-sandbox' "x-api-key": 'redbear-sandbox'
}, },

View File

@@ -121,7 +121,10 @@ class DocExtractorNode(BaseNode):
return business_result return business_result
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {"file_selector": self.config.get("file_selector")} file_selector = self.config.get("file_selector", "")
# 将变量选择器(如 sys.files解析为实际值
resolved = self.get_variable(file_selector, variable_pool, strict=False, default=file_selector)
return {"file_selector": resolved}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
config = DocExtractorNodeConfig(**self.config) config = DocExtractorNodeConfig(**self.config)
@@ -182,7 +185,7 @@ class DocExtractorNode(BaseNode):
mime_type=f"image/{ext}", mime_type=f"image/{ext}",
is_file=True, is_file=True,
).model_dump()) ).model_dump())
text = text + f"\n{placeholder}: <img src=\"{url}\" data-url=\"{url}\">" text = text + f"\n{placeholder}: {url}"
except Exception as e: except Exception as e:
logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}") logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}")

View File

@@ -31,7 +31,7 @@ class NodeType(StrEnum):
NOTES = "notes" NOTES = "notes"
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER}) BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM})
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):

View File

@@ -6,6 +6,7 @@ import uuid
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.nodes.enums import HttpErrorHandle
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
@@ -49,6 +50,20 @@ class MemoryWindowSetting(BaseModel):
) )
class LLMErrorHandleConfig(BaseModel):
"""LLM 异常处理配置"""
method: HttpErrorHandle = Field(
default=HttpErrorHandle.NONE,
description="异常处理策略:'none' 抛出异常, 'default' 返回默认值, 'branch' 走异常分支",
)
output: str = Field(
default="",
description="LLM 异常时返回的默认输出文本method=default 时生效)",
)
class LLMNodeConfig(BaseNodeConfig): class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置 """LLM 节点配置
@@ -152,6 +167,11 @@ class LLMNodeConfig(BaseNodeConfig):
description="输出变量定义(自动生成,通常不需要修改)" description="输出变量定义(自动生成,通常不需要修改)"
) )
error_handle: LLMErrorHandleConfig = Field(
default_factory=LLMErrorHandleConfig,
description="LLM 异常处理配置",
)
@field_validator("messages", "prompt") @field_validator("messages", "prompt")
@classmethod @classmethod
def validate_input_mode(cls, v): def validate_input_mode(cls, v):

View File

@@ -15,6 +15,7 @@ from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpErrorHandle
from app.core.workflow.nodes.llm.config import LLMNodeConfig from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_context from app.db import get_db_context
@@ -76,7 +77,7 @@ class LLMNode(BaseNode):
self.messages = [] self.messages = []
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING, "branch_signal": VariableType.STRING}
def _render_context(self, message: str, variable_pool: VariablePool): def _render_context(self, message: str, variable_pool: VariablePool):
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>" context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
@@ -239,7 +240,7 @@ class LLMNode(BaseNode):
return llm return llm
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage: async def execute(self, state: WorkflowState, variable_pool: VariablePool):
"""非流式执行 LLM 调用 """非流式执行 LLM 调用
Args: Args:
@@ -247,28 +248,36 @@ class LLMNode(BaseNode):
variable_pool: 变量池 variable_pool: 变量池
Returns: Returns:
LLM 响应消息 dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success,
{"llm_result": None, "branch_signal": "ERROR"} on branch error
""" """
# self.typed_config = LLMNodeConfig(**self.config) try:
llm = await self._prepare_llm(state, variable_pool, False) # self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, False)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表 # 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(self.messages) response = await llm.ainvoke(self.messages)
# 提取内容 # 提取内容
if hasattr(response, 'content'): if hasattr(response, 'content'):
content = self.process_model_output(response.content) content = self.process_model_output(response.content)
else: else:
content = str(response) content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}") logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据 # 返回 AIMessage包含响应元数据
return AIMessage(content=content, response_metadata={ return {
**response.response_metadata, "llm_result": AIMessage(content=content, response_metadata={
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage') **response.response_metadata,
}) "token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
}),
"branch_signal": "SUCCESS",
}
except Exception as e:
logger.error(f"节点 {self.node_id} LLM 调用失败: {e}")
return self._handle_llm_error(e)
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""提取输入数据(用于记录)""" """提取输入数据(用于记录)"""
@@ -286,16 +295,36 @@ class LLMNode(BaseNode):
} }
} }
def _extract_output(self, business_result: Any) -> str: def _extract_output(self, business_result: Any) -> dict:
""" AIMessage 中提取文本内容""" """业务结果中提取输出变量
支持新旧两种格式:
- 新格式:{"llm_result": AIMessage, "branch_signal": "SUCCESS"}
- 旧格式AIMessage向后兼容
"""
if isinstance(business_result, dict) and "branch_signal" in business_result:
llm_result = business_result.get("llm_result")
if isinstance(llm_result, AIMessage):
return {
"output": llm_result.content,
"branch_signal": business_result["branch_signal"],
}
return {
"output": str(llm_result) if llm_result else "",
"branch_signal": business_result["branch_signal"],
}
# 旧格式向后兼容
if isinstance(business_result, AIMessage): if isinstance(business_result, AIMessage):
return business_result.content return {"output": business_result.content, "branch_signal": "SUCCESS"}
return str(business_result) return {"output": str(business_result), "branch_signal": "SUCCESS"}
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
""" AIMessage 中提取 token 使用情况""" """业务结果中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'): llm_result = business_result
usage = business_result.response_metadata.get('token_usage') if isinstance(business_result, dict):
llm_result = business_result.get("llm_result", business_result)
if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'):
usage = llm_result.response_metadata.get('token_usage')
if usage: if usage:
return { return {
"prompt_tokens": usage.get('input_tokens', 0), "prompt_tokens": usage.get('input_tokens', 0),
@@ -304,6 +333,44 @@ class LLMNode(BaseNode):
} }
return None return None
def _handle_llm_error(self, error: Exception) -> dict:
"""处理 LLM 调用异常,根据 error_handle 配置决定行为
Args:
error: LLM 调用中捕获的异常
Returns:
dict: {"llm_result": None, "branch_signal": "ERROR"} for branch mode,
or default output for default mode
Raises:
原异常(当 error_handle.method 为 NONE 时)
"""
if self.typed_config is None:
raise error
match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
raise error
case HttpErrorHandle.DEFAULT:
logger.warning(
f"节点 {self.node_id}: LLM 调用失败,返回默认输出"
)
default_output = self.typed_config.error_handle.output or ""
return {
"llm_result": AIMessage(content=default_output, response_metadata={}),
"branch_signal": "SUCCESS",
}
case HttpErrorHandle.BRANCH:
logger.warning(
f"节点 {self.node_id}: LLM 调用失败,切换到异常处理分支"
)
return {
"llm_result": None,
"branch_signal": "ERROR",
}
raise error
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
"""流式执行 LLM 调用 """流式执行 LLM 调用
@@ -316,54 +383,58 @@ class LLMNode(BaseNode):
""" """
self.typed_config = LLMNodeConfig(**self.config) self.typed_config = LLMNodeConfig(**self.config)
llm = await self._prepare_llm(state, variable_pool, True) try:
llm = await self._prepare_llm(state, variable_pool, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 累积完整响应 # 累积完整响应
full_response = "" full_response = ""
chunk_count = 0 chunk_count = 0
# 调用 LLM流式支持字符串或消息列表 # 调用 LLM流式支持字符串或消息列表
last_meta_data = {} last_meta_data = {}
last_usage_metadata = {} last_usage_metadata = {}
async for chunk in llm.astream(self.messages): async for chunk in llm.astream(self.messages):
if hasattr(chunk, 'content'): if hasattr(chunk, 'content'):
content = self.process_model_output(chunk.content) content = self.process_model_output(chunk.content)
else: else:
content = str(chunk) content = str(chunk)
if hasattr(chunk, 'response_metadata') and chunk.response_metadata: if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
last_meta_data = chunk.response_metadata last_meta_data = chunk.response_metadata
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata: if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
last_usage_metadata = chunk.usage_metadata last_usage_metadata = chunk.usage_metadata
# 只有当内容不为空时才处理 # 只有当内容不为空时才处理
if content: if content:
full_response += content full_response += content
chunk_count += 1 chunk_count += 1
# 流式返回每个文本片段 # 流式返回每个文本片段
yield { yield {
"__final__": False, "__final__": False,
"chunk": content "chunk": content
} }
yield { yield {
"__final__": False, "__final__": False,
"chunk": "", "chunk": "",
"done": True "done": True
}
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
} }
) logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# yield 完成标记 # 构建完整的 AIMessage包含元数据
yield {"__final__": True, "result": final_message} final_message = AIMessage(
content=full_response,
response_metadata={
**last_meta_data,
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
}
)
# yield 完成标记
yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}}
except Exception as e:
logger.error(f"节点 {self.node_id} LLM 流式调用失败: {e}")
error_result = self._handle_llm_error(e)
yield {"__final__": True, "result": error_result}

View File

@@ -15,4 +15,5 @@ class File(Base):
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf") file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
file_size = Column(Integer, default=0, comment="file size(byte)") file_size = Column(Integer, default=0, comment="file size(byte)")
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url") file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
file_key = Column(String(512), nullable=True, index=True, comment="storage file key for FileStorageService")
created_at = Column(DateTime, default=datetime.datetime.now) created_at = Column(DateTime, default=datetime.datetime.now)

View File

@@ -250,7 +250,7 @@ class ModelParameters(BaseModel):
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量") n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
stop: Optional[List[str]] = Field(default=None, description="停止序列") stop: Optional[List[str]] = Field(default=None, description="停止序列")
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)") deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1, le=131072, description="深度思考 token 预算(仅部分模型支持)") thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)") json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")

View File

@@ -11,6 +11,7 @@ class FileBase(BaseModel):
file_ext: str file_ext: str
file_size: int file_size: int
file_url: str | None = None file_url: str | None = None
file_key: str | None = None
created_at: datetime.datetime | None = None created_at: datetime.datetime | None = None

View File

@@ -161,10 +161,7 @@ class AppChatService:
f.type == FileType.DOCUMENT for f in files f.type == FileType.DOCUMENT for f in files
): ):
system_prompt += ( system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>" "\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
) )
# 创建 LangChain Agent # 创建 LangChain Agent
@@ -451,10 +448,7 @@ class AppChatService:
): ):
from langchain.agents import create_agent from langchain.agents import create_agent
system_prompt += ( system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>" "\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
) )
# 创建 LangChain Agent # 创建 LangChain Agent

View File

@@ -102,6 +102,11 @@ class AppDslService:
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or []) {**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
] ]
return enriched return enriched
if app_type == AppType.WORKFLOW:
enriched = {**cfg}
if "nodes" in cfg:
enriched["nodes"] = self._enrich_workflow_nodes(cfg["nodes"])
return enriched
return cfg return cfg
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]: def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
@@ -110,7 +115,7 @@ class AppDslService:
config_data = { config_data = {
"variables": config.variables if config else [], "variables": config.variables if config else [],
"edges": config.edges if config else [], "edges": config.edges if config else [],
"nodes": config.nodes if config else [], "nodes": self._enrich_workflow_nodes(config.nodes) if config else [],
"features": config.features if config else {}, "features": config.features if config else {},
"execution_config": config.execution_config if config else {}, "execution_config": config.execution_config if config else {},
"triggers": config.triggers if config else [], "triggers": config.triggers if config else [],
@@ -190,6 +195,23 @@ class AppDslService:
def _enrich_tools(self, tools: list) -> list: def _enrich_tools(self, tools: list) -> list:
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])] return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
def _enrich_workflow_nodes(self, nodes: list) -> list:
"""enrich 工作流节点中的模型引用,添加 name、provider、type 信息"""
from app.core.workflow.nodes.enums import NodeType
enriched_nodes = []
for node in (nodes or []):
node_type = node.get("type")
config = dict(node.get("config") or {})
if node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
model_id = config.get("model_id")
if model_id:
config["model_ref"] = self._model_ref(model_id)
del config["model_id"]
enriched_nodes.append({**node, "config": config})
return enriched_nodes
def _skill_ref(self, skill_id) -> Optional[dict]: def _skill_ref(self, skill_id) -> Optional[dict]:
if not skill_id: if not skill_id:
return None return None
@@ -620,16 +642,16 @@ class AppDslService:
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置") warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
config["knowledge_bases"] = resolved_kbs config["knowledge_bases"] = resolved_kbs
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value): elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
model_ref = config.get("model_id") model_ref = config.get("model_ref") or config.get("model_id")
if model_ref: if model_ref:
ref_dict = None ref_dict = None
if isinstance(model_ref, dict): if isinstance(model_ref, dict):
ref_id = model_ref.get("id") ref_dict = {
ref_name = model_ref.get("name") "id": model_ref.get("id"),
if ref_id: "name": model_ref.get("name"),
ref_dict = {"id": ref_id} "provider": model_ref.get("provider"),
elif ref_name is not None: "type": model_ref.get("type")
ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")} }
elif isinstance(model_ref, str): elif isinstance(model_ref, str):
try: try:
uuid.UUID(model_ref) uuid.UUID(model_ref)
@@ -640,12 +662,18 @@ class AppDslService:
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings) resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
if resolved_model_id: if resolved_model_id:
config["model_id"] = resolved_model_id config["model_id"] = resolved_model_id
if "model_ref" in config:
del config["model_ref"]
else: else:
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置") warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
config["model_id"] = None config["model_id"] = None
if "model_ref" in config:
del config["model_ref"]
else: else:
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置") warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
config["model_id"] = None config["model_id"] = None
if "model_ref" in config:
del config["model_ref"]
resolved_nodes.append({**node, "config": config}) resolved_nodes.append({**node, "config": config})
return resolved_nodes return resolved_nodes

View File

@@ -650,10 +650,7 @@ class AgentRunService:
) )
if has_doc_with_images: if has_doc_with_images:
system_prompt += ( system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>" "\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
) )
agent = LangChainAgent( agent = LangChainAgent(
@@ -927,10 +924,7 @@ class AgentRunService:
) )
if has_doc_with_images: if has_doc_with_images:
system_prompt += ( system_prompt += (
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>" "\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
"重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
) )
# 创建 LangChain Agent # 创建 LangChain Agent

View File

@@ -34,26 +34,7 @@ def generate_file_key(
Generate a unique file key for storage. Generate a unique file key for storage.
The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext} The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext}
Args:
tenant_id: The tenant UUID.
workspace_id: The workspace UUID.
file_id: The file UUID.
file_ext: The file extension (e.g., '.pdf', '.txt').
Returns:
A unique file key string.
Example:
>>> generate_file_key(
... uuid.UUID('550e8400-e29b-41d4-a716-446655440000'),
... uuid.UUID('660e8400-e29b-41d4-a716-446655440001'),
... uuid.UUID('770e8400-e29b-41d4-a716-446655440002'),
... '.pdf'
... )
'550e8400-e29b-41d4-a716-446655440000/660e8400-e29b-41d4-a716-446655440001/770e8400-e29b-41d4-a716-446655440002.pdf'
""" """
# Ensure file_ext starts with a dot
if file_ext and not file_ext.startswith('.'): if file_ext and not file_ext.startswith('.'):
file_ext = f'.{file_ext}' file_ext = f'.{file_ext}'
if workspace_id: if workspace_id:
@@ -61,6 +42,21 @@ def generate_file_key(
return f"{tenant_id}/{file_id}{file_ext}" return f"{tenant_id}/{file_id}{file_ext}"
def generate_kb_file_key(
kb_id: uuid.UUID,
file_id: uuid.UUID,
file_ext: str,
) -> str:
"""
Generate a file key for knowledge base files.
Format: kb/{kb_id}/{file_id}{file_ext}
"""
if file_ext and not file_ext.startswith('.'):
file_ext = f'.{file_ext}'
return f"kb/{kb_id}/{file_id}{file_ext}"
class FileStorageService: class FileStorageService:
""" """
High-level service for file storage operations. High-level service for file storage operations.

View File

@@ -400,7 +400,7 @@ class MultimodalService:
# 在文本内容中追加图片位置标记 # 在文本内容中追加图片位置标记
if result and result[-1].get("type") in ("text", "document"): if result and result[-1].get("type") in ("text", "document"):
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1] key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: <img src=\"{img_url}\" data-url=\"{img_url}\">" result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: {img_url}"
# 将图片以视觉格式追加到消息内容中 # 将图片以视觉格式追加到消息内容中
img_file = FileInput( img_file = FileInput(
type=FileType.IMAGE, type=FileType.IMAGE,

View File

@@ -554,16 +554,13 @@ class WorkflowService:
} }
} }
case "workflow_end": case "workflow_end":
data = {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
if "citations" in payload and payload["citations"]:
data["citations"] = payload["citations"]
return { return {
"event": "end", "event": "end",
"data": data "data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
} }
case "node_start" | "node_end" | "node_error" | "cycle_item": case "node_start" | "node_end" | "node_error" | "cycle_item":
return None return None

View File

@@ -20,7 +20,6 @@ from app.models.workspace_model import (
) )
from app.repositories import workspace_repository from app.repositories import workspace_repository
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
from app.services.session_service import SessionService
from app.schemas.workspace_schema import ( from app.schemas.workspace_schema import (
InviteAcceptRequest, InviteAcceptRequest,
InviteValidateResponse, InviteValidateResponse,
@@ -59,7 +58,7 @@ def switch_workspace(
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR) raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
async def delete_workspace_member( def delete_workspace_member(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
member_id: uuid.UUID, member_id: uuid.UUID,
@@ -77,29 +76,10 @@ async def delete_workspace_member(
BizCode.WORKSPACE_NOT_FOUND) BizCode.WORKSPACE_NOT_FOUND)
try: try:
deleted_user = workspace_member.user
workspace_member.is_active = False workspace_member.is_active = False
deleted_user.current_workspace_id = None workspace_member.user.current_workspace_id = None
# 若被删除成员不是超级管理员且没有其他可用工作空间,则禁用该用户
if not deleted_user.is_superuser:
remaining = (
db.query(WorkspaceMember)
.filter(
WorkspaceMember.user_id == deleted_user.id,
WorkspaceMember.workspace_id != workspace_id,
WorkspaceMember.is_active.is_(True),
)
.count()
)
if remaining == 0:
deleted_user.is_active = False
db.commit() db.commit()
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}") business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
# 使被删除成员的所有 token 立即失效
await SessionService.invalidate_all_user_tokens(str(workspace_member.user_id))
except Exception as e: except Exception as e:
db.rollback() db.rollback()
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}") business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")

View File

@@ -210,9 +210,14 @@ def _build_vision_model(file_path: str, db_knowledge):
@celery_app.task(name="app.core.rag.tasks.parse_document") @celery_app.task(name="app.core.rag.tasks.parse_document")
def parse_document(file_path: str, document_id: uuid.UUID): def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
""" """
Document parsing, vectorization, and storage Document parsing, vectorization, and storage.
Args:
file_key: Storage key for FileStorageService (e.g. "kb/{kb_id}/{file_id}.docx")
document_id: Document UUID
file_name: Original file name (used for extension detection in chunk())
""" """
db_document = None db_document = None
@@ -223,7 +228,6 @@ def parse_document(file_path: str, document_id: uuid.UUID):
with get_db_context() as db: with get_db_context() as db:
try: try:
# Celery JSON 序列化会将 UUID 转为字符串,需要确保类型正确
if not isinstance(document_id, uuid.UUID): if not isinstance(document_id, uuid.UUID):
document_id = uuid.UUID(str(document_id)) document_id = uuid.UUID(str(document_id))
@@ -234,7 +238,11 @@ def parse_document(file_path: str, document_id: uuid.UUID):
if db_knowledge is None: if db_knowledge is None:
raise ValueError(f"Knowledge {db_document.kb_id} not found") raise ValueError(f"Knowledge {db_document.kb_id} not found")
# 1. Document parsing & segmentation # Use file_name from argument or fall back to document record
if not file_name:
file_name = db_document.file_name
# 1. Download file from storage backend
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.") progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
start_time = time.time() start_time = time.time()
db_document.progress = 0.0 db_document.progress = 0.0
@@ -245,45 +253,36 @@ def parse_document(file_path: str, document_id: uuid.UUID):
db.commit() db.commit()
db.refresh(db_document) db.refresh(db_document)
# Read file content from storage backend (no NFS dependency)
from app.services.file_storage_service import FileStorageService
import asyncio
storage_service = FileStorageService()
async def _download():
return await storage_service.download_file(file_key)
try:
file_binary = asyncio.run(_download())
except RuntimeError:
# If there's already a running loop (e.g. in some worker configurations)
loop = asyncio.new_event_loop()
try:
file_binary = loop.run_until_complete(_download())
finally:
loop.close()
if not file_binary:
raise IOError(f"Downloaded empty file from storage: {file_key}")
logger.info(f"[ParseDoc] Downloaded {len(file_binary)} bytes from storage key: {file_key}")
def progress_callback(prog=None, msg=None): def progress_callback(prog=None, msg=None):
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.") progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
# Prepare vision_model for parsing # Prepare vision_model for parsing
vision_model = _build_vision_model(file_path, db_knowledge) vision_model = _build_vision_model(file_name, db_knowledge)
# 先将文件读入内存,避免解析过程中依赖 NFS 文件持续可访问
# python-docx 等库在 binary=None 时会用路径直接打开文件,
# 在 NFS/共享存储上可能因缓存失效导致 "Package not found"
max_wait_seconds = 30
wait_interval = 2
waited = 0
file_binary = None
while waited <= max_wait_seconds:
# os.listdir 强制 NFS 客户端刷新目录缓存
parent_dir = os.path.dirname(file_path)
try:
os.listdir(parent_dir)
except OSError:
pass
try:
with open(file_path, "rb") as f:
file_binary = f.read()
if not file_binary:
# NFS 上文件存在但内容为空(可能还在同步中)
raise IOError(f"File is empty (0 bytes), NFS may still be syncing: {file_path}")
break
except (FileNotFoundError, IOError) as e:
if waited >= max_wait_seconds:
raise type(e)(
f"File not accessible at '{file_path}' after waiting {max_wait_seconds}s: {e}"
)
logger.warning(f"File not ready on this node, retrying in {wait_interval}s: {file_path} ({e})")
time.sleep(wait_interval)
waited += wait_interval
from app.core.rag.app.naive import chunk from app.core.rag.app.naive import chunk
logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}") logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
res = chunk(filename=file_path, res = chunk(filename=file_name,
binary=file_binary, binary=file_binary,
from_page=0, from_page=0,
to_page=DEFAULT_PARSE_TO_PAGE, to_page=DEFAULT_PARSE_TO_PAGE,

View File

@@ -8,11 +8,12 @@ import { type FC, useRef, useEffect, useState } from 'react'
import clsx from 'clsx' import clsx from 'clsx'
import Markdown from '@/components/Markdown' import Markdown from '@/components/Markdown'
import type { ChatContentProps } from './types' import type { ChatContentProps } from './types'
import { Spin, Flex, Button } from 'antd' import { Spin, Image, Flex, Button } from 'antd'
import { SoundOutlined } from '@ant-design/icons' import { SoundOutlined } from '@ant-design/icons'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import MessageFiles from './MessageFiles' import AudioPlayer from './AudioPlayer'
import VideoPlayer from './VideoPlayer'
const getFileUrl = (file: any) => { const getFileUrl = (file: any) => {
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined) return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
@@ -148,7 +149,72 @@ const ChatContent: FC<ChatContentProps> = ({
{labelFormat(item)} {labelFormat(item)}
</div> </div>
} }
<MessageFiles files={item.meta_data?.files ?? []} contentClassNames={contentClassNames} onDownload={handleDownload} /> {item?.meta_data?.files && item.meta_data?.files.length > 0 && <Flex gap={8} vertical align="end" className="rb:mb-2!">
{item.meta_data?.files?.map((file) => {
if (file.type.includes('image')) {
return (
<div key={file.url || file.uid} className={`rb:inline-block rb:group rb:relative rb:rounded-lg ${contentClassNames}`}>
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
</div>
)
}
if (file.type.includes('video')) {
return (
<div key={file.url || file.uid} className="rb:w-50">
{/* <video src={getFileUrl(file)} controls className="rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" /> */}
<VideoPlayer key={file.url || file.uid} src={getFileUrl(file)} />
</div>
)
}
if (file.type.includes('audio')) {
return (
<div key={file.url || file.uid} className="rb:w-50">
<AudioPlayer key={file.url || file.uid} src={getFileUrl(file)} />
</div>
)
}
const documentType = (file.file_type || file.type)?.split('/')
return (
<Flex
key={file.url || file.uid}
align="center"
gap={10}
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
onClick={() => handleDownload(file)}
>
<div
className={clsx(
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
file.type?.includes('pdf')
? "rb:bg-[url('@/assets/images/file/pdf.svg')]"
: (file.type?.includes('excel') || file.type?.includes('spreadsheetml.sheet')) || file.type?.includes('xls') || file.type?.includes('xlsx')
? "rb:bg-[url('@/assets/images/file/excel.svg')]"
: file.type?.includes('csv')
? "rb:bg-[url('@/assets/images/file/csv.svg')]"
: file.type?.includes('html')
? "rb:bg-[url('@/assets/images/file/html.svg')]"
: file.type?.includes('json')
? "rb:bg-[url('@/assets/images/file/json.svg')]"
: file.type?.includes('ppt')
? "rb:bg-[url('@/assets/images/file/ppt.svg')]"
: file.type?.includes('markdown')
? "rb:bg-[url('@/assets/images/file/md.svg')]"
: file.type?.includes('text')
? "rb:bg-[url('@/assets/images/file/txt.svg')]"
: (file.type?.includes('doc') || file.type?.includes('docx') || file.type?.includes('word') || file.type?.includes('wordprocessingml.document'))
? "rb:bg-[url('@/assets/images/file/word.svg')]"
: "rb:bg-[url('@/assets/images/file/txt.svg')]"
)}
></div>
<div className="rb:flex-1 rb:w-32.5">
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{documentType?.[documentType.length - 1]} · {file.size}</div>
</div>
</Flex>
)
})}
</Flex>}
{/* Message bubble */} {/* Message bubble */}
<div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', { <div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', {
// Error message style (content is null and not assistant message) // Error message style (content is null and not assistant message)

View File

@@ -1,87 +0,0 @@
import { Image, Flex } from 'antd'
import clsx from 'clsx'
import AudioPlayer from './AudioPlayer'
import VideoPlayer from './VideoPlayer'
const getFileUrl = (file: any) =>
file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
const DOC_ICONS: [string[], string][] = [
[['pdf'], "rb:bg-[url('@/assets/images/file/pdf.svg')]"],
[['excel', 'spreadsheetml.sheet', 'xls', 'xlsx'], "rb:bg-[url('@/assets/images/file/excel.svg')]"],
[['csv'], "rb:bg-[url('@/assets/images/file/csv.svg')]"],
[['html'], "rb:bg-[url('@/assets/images/file/html.svg')]"],
[['json'], "rb:bg-[url('@/assets/images/file/json.svg')]"],
[['ppt'], "rb:bg-[url('@/assets/images/file/ppt.svg')]"],
[['markdown'], "rb:bg-[url('@/assets/images/file/md.svg')]"],
[['text'], "rb:bg-[url('@/assets/images/file/txt.svg')]"],
[['doc', 'docx', 'word', 'wordprocessingml.document'], "rb:bg-[url('@/assets/images/file/word.svg')]"],
]
const getDocIcon = (parts: string[]) => {
const match = DOC_ICONS.find(([keys]) => keys.some(k => parts.includes(k)))
return match ? match[1] : "rb:bg-[url('@/assets/images/file/txt.svg')]"
}
interface MessageFilesProps {
files: any[]
contentClassNames?: string | Record<string, boolean>
onDownload: (file: any) => void
}
const MessageFiles = ({ files, contentClassNames, onDownload }: MessageFilesProps) => {
if (!files?.length) return null
return (
<Flex gap={8} vertical align="end" className="rb:mb-2!">
{files.map((file) => {
const key = file.url || file.uid
if (file.type.includes('image')) {
return (
<div key={key} className={clsx('rb:inline-block rb:group rb:relative rb:rounded-lg', contentClassNames)}>
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
</div>
)
}
if (file.type.includes('video')) {
return (
<div key={key} className="rb:w-50">
<VideoPlayer src={getFileUrl(file)} />
</div>
)
}
if (file.type.includes('audio')) {
return (
<div key={key} className="rb:w-50">
<AudioPlayer src={getFileUrl(file)} />
</div>
)
}
const documentType = (file.file_type || file.type)?.split('/') ?? []
return (
<Flex
key={key}
align="center"
gap={10}
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
onClick={() => onDownload(file)}
>
<div
className={clsx(
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
getDocIcon(documentType)
)}
/>
<div className="rb:flex-1 rb:w-32.5">
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">
{documentType?.[documentType.length - 1]} · {file.size}
</div>
</div>
</Flex>
)
})}
</Flex>
)
}
export default MessageFiles

View File

@@ -3,14 +3,14 @@ import { Popover, type PopoverProps } from 'antd'
import Tag, { type TagProps } from '@/components/Tag' import Tag, { type TagProps } from '@/components/Tag'
interface OverflowTagsProps { interface OverflowTagsProps {
items?: ReactNode[]; items: ReactNode[];
gap?: number; gap?: number;
numTagColor?: TagProps['color']; numTagColor?: TagProps['color'];
numTag?: (num?: number) => ReactNode; numTag?: (num?: number) => ReactNode;
popoverProps?: PopoverProps | false; popoverProps?: PopoverProps | false;
} }
const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => { const OverflowTags = ({ items, gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => {
const containerRef = useRef<HTMLDivElement>(null) const containerRef = useRef<HTMLDivElement>(null)
const measureRef = useRef<HTMLDivElement>(null) const measureRef = useRef<HTMLDivElement>(null)
const [visibleCount, setVisibleCount] = useState(items.length) const [visibleCount, setVisibleCount] = useState(items.length)
@@ -20,7 +20,7 @@ const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, po
if (!measure || containerWidth === 0) return if (!measure || containerWidth === 0) return
const children = Array.from(measure.children) as HTMLElement[] const children = Array.from(measure.children) as HTMLElement[]
if (!children.length) { setVisibleCount(0); return } if (!children.length) return
// last child is the sample +N tag // last child is the sample +N tag
const extraTagWidth = (children[children.length - 1] as HTMLElement).offsetWidth const extraTagWidth = (children[children.length - 1] as HTMLElement).offsetWidth

View File

@@ -399,7 +399,7 @@ const Menu: FC<{
className="rb:overflow-y-auto rb:flex-1!" className="rb:overflow-y-auto rb:flex-1!"
/> />
{/* Return to space button for superusers */} {/* Return to space button for superusers */}
{source === 'space' && {user?.is_superuser && source === 'space' &&
<Flex gap={4} vertical className="rb:my-3! rb:mx-3!"> <Flex gap={4} vertical className="rb:my-3! rb:mx-3!">
<Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" /> <Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" />
<Flex <Flex
@@ -412,18 +412,16 @@ const Menu: FC<{
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div> <div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div>
{collapsed ? null : t('common.switchSpace')} {collapsed ? null : t('common.switchSpace')}
</Flex> </Flex>
{user?.is_superuser && <Flex
<Flex gap={8}
gap={8} align="center"
align="center" justify="start"
justify="start" onClick={goToSpace}
onClick={goToSpace} className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer" >
> <div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div> {collapsed ? null : t('common.returnToSpace')}
{collapsed ? null : t('common.returnToSpace')} </Flex>
</Flex>
}
</Flex> </Flex>
} }
{source === 'manage' && subscription && !collapsed && {source === 'manage' && subscription && !collapsed &&

View File

@@ -1538,7 +1538,6 @@ export const en = {
json_output: 'Support JSON formatted output', json_output: 'Support JSON formatted output',
thinking_budget_tokens: 'thinking budget tokens', thinking_budget_tokens: 'thinking budget tokens',
thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})", thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})",
thinking_budget_tokens_min_error: "Cannot be less than {{min}}",
logSearchPlaceholder: 'Search log content', logSearchPlaceholder: 'Search log content',
}, },
userMemory: { userMemory: {

View File

@@ -868,7 +868,6 @@ export const zh = {
json_output: '支持JSON格式化输出', json_output: '支持JSON格式化输出',
thinking_budget_tokens: '深度思考预算Token数', thinking_budget_tokens: '深度思考预算Token数',
thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})", thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})",
thinking_budget_tokens_min_error: "不能小于 {{min}}",
logSearchPlaceholder: '搜索日志内容', logSearchPlaceholder: '搜索日志内容',
}, },
table: { table: {

View File

@@ -49,8 +49,6 @@ const configFields = [
{ key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 }, { key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 },
] ]
const minThinkingBudgetTokens = 128;
const defaultThinkingBudgetTokens = 1000;
const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({
refresh, refresh,
data, data,
@@ -110,7 +108,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
const newValues: ModelConfig = { const newValues: ModelConfig = {
capability: (option as Model).capability, capability: (option as Model).capability,
deep_thinking: false, deep_thinking: false,
thinking_budget_tokens: defaultThinkingBudgetTokens, thinking_budget_tokens: undefined,
json_output: false, json_output: false,
} }
if (source === 'chat') { if (source === 'chat') {
@@ -130,12 +128,6 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
form.setFieldsValue({ ...rest }) form.setFieldsValue({ ...rest })
}, [data?.default_model_config_id]) }, [data?.default_model_config_id])
useEffect(() => {
if (values?.deep_thinking && !values?.thinking_budget_tokens) {
form.setFieldValue('thinking_budget_tokens', defaultThinkingBudgetTokens)
}
}, [values?.deep_thinking])
const handleReset = () => { const handleReset = () => {
if (!id) return if (!id) return
resetAppModelConfig(id).then((res) => { resetAppModelConfig(id).then((res) => {
@@ -186,20 +178,15 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
name="thinking_budget_tokens" name="thinking_budget_tokens"
label={t('application.thinking_budget_tokens')} label={t('application.thinking_budget_tokens')}
hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))} hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))}
extra={<>{t('application.range')}: [{minThinkingBudgetTokens}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>} extra={<>{t('application.range')}: [{0}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
rules={[ rules={[
{ required: values?.deep_thinking, message: t('common.pleaseEnter') }, { required: values?.deep_thinking, message: t('common.pleaseEnter') },
{ {
validator: (_, value) => { validator: (_, value) => {
const maxTokens = values?.max_tokens const maxTokens = values?.max_tokens
const deep_thinking = values?.deep_thinking; const deep_thinking = values?.deep_thinking;
if (deep_thinking && value !== undefined) { if (deep_thinking && value !== undefined && maxTokens !== undefined && value > maxTokens) {
if (value < minThinkingBudgetTokens) { return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
return Promise.reject(t('application.thinking_budget_tokens_min_error', { min: minThinkingBudgetTokens }))
}
if (maxTokens !== undefined && value > maxTokens) {
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
}
} }
return Promise.resolve() return Promise.resolve()
} }
@@ -208,7 +195,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
> >
<RbSlider <RbSlider
step={1} step={1}
min={minThinkingBudgetTokens} min={0}
max={32000} max={32000}
isInput={true} isInput={true}
disabled={!values?.deep_thinking} disabled={!values?.deep_thinking}

View File

@@ -166,10 +166,10 @@ const Ontology: FC = () => {
<div className="rb:h-10 rb:wrap-break-word rb:line-clamp-2 rb:leading-5">{item.scene_description}</div> <div className="rb:h-10 rb:wrap-break-word rb:line-clamp-2 rb:leading-5">{item.scene_description}</div>
</Tooltip> </Tooltip>
<div className="rb:mt-2 rb:h-5.5"> <div className="rb:mt-2">
<OverflowTags <OverflowTags
popoverProps={false} popoverProps={false}
items={item.entity_type ? [...item.entity_type.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>] : []} items={[...item.entity_type?.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>]}
numTag={(num?: number) => <Tag variant="borderless" color="dark">{`+${item.type_num - 3 + (num ? num - 1 : 0)}`}</Tag>} numTag={(num?: number) => <Tag variant="borderless" color="dark">{`+${item.type_num - 3 + (num ? num - 1 : 0)}`}</Tag>}
/> />
</div> </div>

View File

@@ -101,7 +101,6 @@ const CustomToolModal = forwardRef<CustomToolModalRef, CustomToolModalProps>(({
}); });
}; };
const formatSchema = (value: string) => { const formatSchema = (value: string) => {
if (!value || value.trim() === '') return
setParseSchemaData({} as ParseSchemaData) setParseSchemaData({} as ParseSchemaData)
parseSchema({ schema_content: value }) parseSchema({ schema_content: value })
.then(res => { .then(res => {

View File

@@ -57,6 +57,7 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
} }
}} }}
labelRender={(props) => { labelRender={(props) => {
console.log('props', props)
return `${props.value}%` return `${props.value}%`
}} }}
className="rb:w-20 rb:h-4!" className="rb:w-20 rb:h-4!"

View File

@@ -66,6 +66,8 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
const [fileList, setFileList] = useState<any[]>([]) const [fileList, setFileList] = useState<any[]>([])
const [message, setMessage] = useState<string | undefined>(undefined) const [message, setMessage] = useState<string | undefined>(undefined)
console.log('abortRef', abortRef, chatList)
/** /**
* Opens the chat drawer and loads workflow variables from the start node * Opens the chat drawer and loads workflow variables from the start node
*/ */

View File

@@ -18,7 +18,6 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
// Handle node selection from popover and create new node replacing the add-node placeholder // Handle node selection from popover and create new node replacing the add-node placeholder
const handleNodeSelect = (selectedNodeType: any) => { const handleNodeSelect = (selectedNodeType: any) => {
graph.startBatch('add-node');
const parentBBox = node.getBBox(); const parentBBox = node.getBBox();
const cycleId = data.cycle; const cycleId = data.cycle;
const horizontalSpacing = 0; const horizontalSpacing = 0;
@@ -44,7 +43,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
if (cycleId) { if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId); const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) { if (parentNode) {
parentNode.addChild(newNode, { silent: true }); parentNode.addChild(newNode);
} }
} }
@@ -77,40 +76,55 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
} }
}); });
setTimeout(() => {
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
}, 50);
// Automatically adjust loop node size // Automatically adjust loop node size
const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId); const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (loopNode) { if (loopNode) {
const adjustLoopSize = () => {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc, child) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
loopNode.prop('size', { width: newWidth, height: newHeight });
// Update right port x position
const ports = loopNode.getPorts();
ports.forEach(port => {
if (port.group === 'right' && port.args) {
loopNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
};
adjustLoopSize();
// Listen to child node movement events
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId); const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) { childNodes.forEach((childNode: any) => {
const bounds = childNodes.reduce((acc, child) => { childNode.on('change:position', adjustLoopSize);
const bbox = child.getBBox(); });
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
loopNode.prop('size', { width: newWidth, height: newHeight });
loopNode.getPorts().forEach(port => {
if (port.group === 'right' && port.args) {
loopNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
} }
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
graph.stopBatch('add-node');
setOpen(false); setOpen(false);
}; };

View File

@@ -99,7 +99,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
{data.type === 'if-else' && {data.type === 'if-else' &&
<Flex vertical gap={4} className="rb:mt-3!"> <Flex vertical gap={4} className="rb:mt-3!">
{data.config?.cases?.defaultValue.map((item: any, index: number) => ( {data.config?.cases?.defaultValue.map((item: any, index: number) => (
<div key={index}> <div key={index} className={item.expressions.length > 0 ? '' : 'rb:mb-1'}>
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4"> <Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>} {item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span> <span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>

View File

@@ -1,15 +1,134 @@
import { useEffect } from 'react';
import { useTranslation } from 'react-i18next'
import clsx from 'clsx'; import clsx from 'clsx';
import type { ReactShapeConfig } from '@antv/x6-react-shape'; import type { ReactShapeConfig } from '@antv/x6-react-shape';
import { Flex } from 'antd'; import { Flex } from 'antd';
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons'; import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons';
import { useTranslation } from 'react-i18next'
import { graphNodeLibrary, edgeAttrs } from '../../constant';
import NodeTools from './NodeTools' import NodeTools from './NodeTools'
const LoopNode: ReactShapeConfig['component'] = ({ node }) => { const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
const data = node.getData() || {}; const data = node.getData() || {};
const { t } = useTranslation() const { t } = useTranslation()
useEffect(() => {
// 使用setTimeout确保在所有节点都添加完成后再创建连线
const timer = setTimeout(() => {
initNodes()
checkAndAddAddNode()
}, 50)
return () => clearTimeout(timer)
}, [graph])
const checkAndAddAddNode = () => {
if (!graph) return;
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === data.id);
const cycleStartNodes = childNodes.filter((n: any) => n.getData()?.type === 'cycle-start');
// 如果只有一个cycle-start节点且没有其他类型的子节点则添加add-node
if (cycleStartNodes.length === 1 && childNodes.length === 1) {
const cycleStartNode = cycleStartNodes[0];
const cycleStartBBox = cycleStartNode.getBBox();
const addNode = graph.addNode({
...graphNodeLibrary.addStart,
x: cycleStartBBox.x + 84,
y: cycleStartBBox.y + 4,
data: {
type: 'add-node',
label: t('workflow.addNode'),
icon: '+',
parentId: node.id,
cycle: data.id,
},
});
node.addChild(addNode);
// 连接cycle-start和add-node
const sourcePorts = cycleStartNode.getPorts();
const targetPorts = addNode.getPorts();
const sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
const targetPort = targetPorts.find((port: any) => port.group === 'left')?.id || 'left';
// 然后创建连线
graph.addEdge({
source: { cell: cycleStartNode.id, port: sourcePort },
target: { cell: addNode.id, port: targetPort },
...edgeAttrs,
});
cycleStartNode.toFront()
addNode.toFront()
}
}
const initNodes = () => {
// 检查是否存在cycle为当前节点ID的子节点若存在则不调用initNodes避免重复创建
const existingCycleNodes = graph.getNodes().filter((n: any) =>
n.getData()?.cycle === data.id
);
if (existingCycleNodes.length > 0) return;
// 添加默认子节点
const parentBBox = node.getBBox();
const centerX = parentBBox.x + 24;
const centerY = parentBBox.y + 70;
const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const cycleStartNode = graph.addNode({
...graphNodeLibrary.cycleStart,
x: centerX,
y: centerY,
id: cycleStartNodeId,
data: {
id: cycleStartNodeId,
type: 'cycle-start',
parentId: node.id,
isDefault: true, // 标记为默认节点,不可删除
cycle: data.id,
},
});
const addNode = graph.addNode({
...graphNodeLibrary.addStart,
x: centerX + 84,
y: centerY + 4,
data: {
type: 'add-node',
label: t('workflow.addNode'),
icon: '+',
parentId: node.id,
cycle: data.id,
},
});
node.addChild(cycleStartNode)
node.addChild(addNode)
const sourcePorts = cycleStartNode.getPorts()
const targetPorts = addNode.getPorts()
let sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
const edgeConfig = {
source: {
cell: cycleStartNode.id,
port: sourcePort
},
target: {
cell: addNode.id,
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
},
...edgeAttrs
}
graph.addEdge(edgeConfig)
setTimeout(() => {
cycleStartNode.toFront()
addNode.toFront()
}, 0)
}
return ( return (
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', { <div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
'rb:border-[#171719]!': data.isSelected && !data.executionStatus, 'rb:border-[#171719]!': data.isSelected && !data.executionStatus,

View File

@@ -43,52 +43,70 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
}; };
}, []); }, []);
// Handle node selection from popover menu and create new node with edge connection
const handleNodeSelect = (selectedNodeType: any) => { const handleNodeSelect = (selectedNodeType: any) => {
if (!sourceNode || !graph) return; if (!sourceNode || !graph) return;
const sourceNodeData = sourceNode.getData(); const sourceNodeData = sourceNode.getData();
const sourceNodeType = sourceNodeData?.type; const sourceNodeType = sourceNodeData?.type;
const isCycleSubNode = !!sourceNodeData.cycle;
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration'; // If it's a cycle-start node, handle the add-node placeholder
const newNodeType = selectedNodeType.type;
// Save add-node placeholder position before disabling history
let addNodePosition = null; let addNodePosition = null;
const isCycleSubNode = sourceNodeData.cycle
if (isCycleSubNode && sourceNodeType === 'cycle-start') { if (isCycleSubNode && sourceNodeType === 'cycle-start') {
const cycleId = sourceNodeData.cycle; const cycleId = sourceNodeData.cycle;
const addNodes = graph.getNodes().filter((n: any) => const addNodes = graph.getNodes().filter((n: any) =>
n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId
); );
if (addNodes.length > 0) addNodePosition = addNodes[0].getBBox();
if (addNodes.length > 0) {
const addNode = addNodes[0];
addNodePosition = addNode.getBBox();
addNode.remove();
}
} }
// Calculate position // Calculate new node position to avoid overlapping
const sourceBBox = sourceNode.getBBox(); const sourceBBox = sourceNode.getBBox();
const nw = graphNodeLibrary[newNodeType]?.width || 120; const nodeWidth = graphNodeLibrary[selectedNodeType.type]?.width || 120;
const nh = graphNodeLibrary[newNodeType]?.height || 88; const nodeHeight = graphNodeLibrary[selectedNodeType.type]?.height || 88;
const hSpacing = isCycleSubNode ? 48 : 80; const horizontalSpacing = isCycleSubNode ? 48 : 80;
const vSpacing = 10; const verticalSpacing = 10;
// Get source port group information
const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort); const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort);
const sourcePortGroup = sourcePortInfo?.group || sourcePort; const sourcePortGroup = sourcePortInfo?.group || sourcePort;
let newX: number, newY: number; // Calculate new node position
let newX, newY;
if (edgeInsertion) { if (edgeInsertion) {
// Edge insertion: place new node on the same row as target, between source and target
const targetBBox = edgeInsertion.targetCell.getBBox(); const targetBBox = edgeInsertion.targetCell.getBBox();
const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width); const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width);
const requiredSpace = nw + hSpacing * 4; const requiredSpace = nodeWidth + horizontalSpacing * 4;
newX = sourceBBox.x + sourceBBox.width + hSpacing;
newY = targetBBox.y + (targetBBox.height - nh) / 2; // New node x: right after source + spacing
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
// Same row as target node
newY = targetBBox.y + (targetBBox.height - nodeHeight) / 2;
// If not enough space, shift target and all downstream nodes to the right
if (gap < requiredSpace) { if (gap < requiredSpace) {
const shiftX = requiredSpace - gap; const shiftX = requiredSpace - gap;
const visited = new Set<string>(); const visited = new Set<string>();
const shiftDownstream = (cell: any) => { const shiftDownstream = (cell: any) => {
if (visited.has(cell.id)) return; const cellId = cell.id;
visited.add(cell.id); if (visited.has(cellId)) return;
visited.add(cellId);
const pos = cell.getPosition(); const pos = cell.getPosition();
cell.setPosition(pos.x + shiftX, pos.y); cell.setPosition(pos.x + shiftX, pos.y);
// Recursively shift nodes connected from right ports
graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => { graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => {
const tCell = graph.getCellById(e.getTargetCellId()); const tId = e.getTargetCellId();
if (tCell?.isNode()) shiftDownstream(tCell); if (tId && !visited.has(tId)) {
const tCell = graph.getCellById(tId);
if (tCell?.isNode()) shiftDownstream(tCell);
}
}); });
}; };
shiftDownstream(edgeInsertion.targetCell); shiftDownstream(edgeInsertion.targetCell);
@@ -96,170 +114,208 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
} else if (addNodePosition) { } else if (addNodePosition) {
newX = addNodePosition.x; newX = addNodePosition.x;
newY = addNodePosition.y; newY = addNodePosition.y;
} else if (sourcePortGroup === 'left') {
newX = sourceBBox.x - nw * 2 - hSpacing;
newY = sourceBBox.y;
} else { } else {
newX = sourceBBox.x + sourceBBox.width + hSpacing; // Determine node placement direction based on port position
newY = sourceBBox.y; if (sourcePortGroup === 'left') {
const connectedNodes = new Set<string>(); // Left port: add node to the left
graph.getConnectedEdges(sourceNode).forEach((e: any) => { newX = sourceBBox.x - nodeWidth*2 - horizontalSpacing;
[e.getSourceCellId(), e.getTargetCellId()].forEach((cid: string) => { newY = sourceBBox.y;
if (cid !== sourceNode.id) connectedNodes.add(cid); } else {
// Right port: add node to the right
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
newY = sourceBBox.y;
}
// Check if position overlaps with existing nodes (only consider connected nodes)
const checkOverlap = (x: number, y: number) => {
// Get nodes connected to the source node
const connectedNodes = new Set();
graph.getConnectedEdges(sourceNode).forEach((edge: any) => {
const sourceId = edge.getSourceCellId();
const targetId = edge.getTargetCellId();
if (sourceId !== sourceNode.id) connectedNodes.add(sourceId);
if (targetId !== sourceNode.id) connectedNodes.add(targetId);
}); });
});
const checkOverlap = (x: number, y: number) => return graph.getNodes().some((node: any) => {
graph.getNodes().some((n: any) => { if (node.id === sourceNode.id) return false;
if (n.id === sourceNode.id || !connectedNodes.has(n.id)) return false; if (!connectedNodes.has(node.id)) return false; // Only consider connected nodes
const b = n.getBBox(); const bbox = node.getBBox();
return !(x + nw < b.x || x > b.x + b.width || y + nh < b.y || y > b.y + b.height); return !(x + nodeWidth < bbox.x || x > bbox.x + bbox.width ||
y + nodeHeight < bbox.y || y > bbox.y + bbox.height);
}); });
while (checkOverlap(newX, newY)) newY += nh + vSpacing; };
// If position is occupied, search downward for empty space
while (checkOverlap(newX, newY)) {
newY += nodeHeight + verticalSpacing;
}
} }
// Disable history for all graph mutations // Create new node
graph.disableHistory(); const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
// Remove add-node placeholder
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
const cycleId = sourceNodeData.cycle;
graph.getNodes()
.filter((n: any) => n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId)
.forEach((n: any) => n.remove());
}
const id = `${newNodeType.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
const newNode = graph.addNode({ const newNode = graph.addNode({
...(graphNodeLibrary[newNodeType] || graphNodeLibrary.default), ...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
x: newX, x: newX,
y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0), y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0),
id, id,
data: { data: {
id, id,
type: newNodeType, type: selectedNodeType.type,
icon: selectedNodeType.icon, icon: selectedNodeType.icon,
name: t(`workflow.${newNodeType}`), name: t(`workflow.${selectedNodeType.type}`),
cycle: sourceNodeData.cycle, cycle: sourceNodeData.cycle, // Inherit cycle from source node
config: selectedNodeType.config || {} config: selectedNodeType.config || {}
}, },
}); });
// Add new node as child of parent node
if (sourceNodeData.cycle) { if (sourceNodeData.cycle) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle); const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle);
if (parentNode) parentNode.addChild(newNode, { silent: true });
}
if (edgeInsertion) {
const { edge: oldEdge } = edgeInsertion;
if (oldEdge.id && graph.getCellById(oldEdge.id)) graph.removeCell(oldEdge.id);
else graph.removeEdge(oldEdge);
}
const newPorts = newNode.getPorts();
const addedCells: any[] = [newNode];
if (edgeInsertion) {
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: newLeftPort }, ...edgeAttrs }));
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: newRightPort }, target: { cell: targetCell.id, port: origTargetPort }, ...edgeAttrs }));
setEdgeInsertion(null);
} else if (sourcePortGroup === 'left') {
const tp = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: tp }, target: { cell: sourceNode.id, port: sourcePort }, ...edgeAttrs }));
} else {
const tp = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: tp }, ...edgeAttrs }));
}
// If adding a loop/iteration node, create cycle-start, add-node and inner edge regardless of source type
if (isCycleContainer(newNodeType)) {
const parentBBox = newNode.getBBox();
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
const cycleStartNode = graph.addNode({
...graphNodeLibrary.cycleStart,
x: parentBBox.x + 24,
y: parentBBox.y + 70,
id: cycleStartId,
data: { id: cycleStartId, type: 'cycle-start', parentId: id, isDefault: true, cycle: id },
});
const addNodePlaceholder = graph.addNode({
...graphNodeLibrary.addStart,
x: parentBBox.x + 24 + 84,
y: parentBBox.y + 70 + 4,
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: id, cycle: id },
});
newNode.addChild(cycleStartNode, { silent: true });
newNode.addChild(addNodePlaceholder, { silent: true });
const innerEdge = graph.addEdge({
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find((p: any) => p.group === 'right')?.id || 'right' },
target: { cell: addNodePlaceholder.id, port: addNodePlaceholder.getPorts().find((p: any) => p.group === 'left')?.id || 'left' },
...edgeAttrs,
});
addedCells.push(cycleStartNode, addNodePlaceholder, innerEdge);
}
// Adjust parent size if adding inside a cycle container
const cycleId = sourceNodeData.cycle;
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) { if (parentNode) {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId); parentNode.addChild(newNode);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc: any, child: any) => {
const b = child.getBBox();
return { minX: Math.min(acc.minX, b.x), minY: Math.min(acc.minY, b.y), maxX: Math.max(acc.maxX, b.x + b.width), maxY: Math.max(acc.maxY, b.y + b.height) };
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
parentNode.prop('size', { width: newWidth, height: newHeight });
parentNode.getPorts().forEach((port: any) => {
if (port.group === 'right' && port.args) parentNode.portProp(port.id!, 'args/x', newWidth);
});
}
} }
} }
// toFront // Edge insertion: remove old edge immediately before creating new edges
const bringCycleChildrenToFront = (cycleContainerId: string) => { if (edgeInsertion) {
graph.getEdges().forEach((e: any) => { const { edge: oldEdge } = edgeInsertion;
const src = graph.getCellById(e.getSourceCellId()); if (oldEdge.id && graph.getCellById(oldEdge.id)) {
const tgt = graph.getCellById(e.getTargetCellId()); graph.removeCell(oldEdge.id);
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront(); } else {
}); graph.removeEdge(oldEdge);
graph.getNodes().forEach((n: any) => { if (n.getData()?.cycle === cycleContainerId) n.toFront(); }); }
};
if (isCycleContainer(sourceNodeType)) {
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(sourceNodeData.id);
if (isCycleContainer(newNodeType)) bringCycleChildrenToFront(id);
} else if (isCycleContainer(newNodeType)) {
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(id);
} else {
addedCells.forEach(c => { if (c.isNode?.()) c.toFront(); });
} }
// Re-enable history and manually push one batch frame for all added cells // Create edge connection
graph.enableHistory(); setTimeout(() => {
const history = graph.getPlugin('history') as any; const newPorts = newNode.getPorts();
if (history) {
const batchFrame = addedCells.map((cell: any) => ({
batch: true,
event: 'cell:added',
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
options: {},
}));
history.undoStack.push(batchFrame);
history.redoStack = [];
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-node' } });
}
const addedEdges: any[] = [];
if (edgeInsertion) {
// Edge insertion: create source→new and new→target edges
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
addedEdges.push(graph.addEdge({
source: { cell: sourceNode.id, port: sourcePort },
target: { cell: newNode.id, port: newLeftPort },
...edgeAttrs
}));
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: newRightPort },
target: { cell: targetCell.id, port: origTargetPort },
...edgeAttrs
}));
setEdgeInsertion(null);
} else if (sourcePortGroup === 'left') {
// Connect from left port to new node's right side
const targetPort = newPorts.find((port: any) => port.group === 'right')?.id || 'right';
addedEdges.push(graph.addEdge({
source: { cell: newNode.id, port: targetPort },
target: { cell: sourceNode.id, port: sourcePort },
...edgeAttrs
}));
} else {
// Connect from right port to new node's left side
const targetPort = newPorts.find((port: any) => port.group === 'left')?.id || 'left';
addedEdges.push(graph.addEdge({
source: { cell: sourceNode.id, port: sourcePort },
target: { cell: newNode.id, port: targetPort },
...edgeAttrs
}));
}
// Adjust loop node size when child node is added via port within loop node
const cycleId = sourceNodeData.cycle;
if (cycleId) {
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
if (parentNode) {
const adjustLoopSize = () => {
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
if (childNodes.length > 0) {
const bounds = childNodes.reduce((acc: any, child: any) => {
const bbox = child.getBBox();
return {
minX: Math.min(acc.minX, bbox.x),
minY: Math.min(acc.minY, bbox.y),
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
};
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
const padding = 50;
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
parentNode.prop('size', { width: newWidth, height: newHeight });
// Update right port x position
const ports = parentNode.getPorts();
ports.forEach((port: any) => {
if (port.group === 'right' && port.args) {
parentNode.portProp(port.id!, 'args/x', newWidth);
}
});
}
};
adjustLoopSize();
// Listen to child node movement events
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
childNodes.forEach((childNode: any) => {
childNode.on('change:position', adjustLoopSize);
});
}
}
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
const newNodeType = selectedNodeType.type;
// Helper: bring all child nodes and their edges of a cycle container to front
const bringCycleChildrenToFront = (cycleContainerId: string) => {
graph.getEdges().forEach((e: any) => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
});
graph.getNodes().forEach((n: any) => {
if (n.getData()?.cycle === cycleContainerId) n.toFront();
});
};
if (isCycleContainer(sourceNodeType)) {
console.log('isCycleContainer(sourceNodeType)')
// Case 4: source is a loop/iteration node — bring new node to front, then its children
newNode.toFront();
sourceNode.toFront();
bringCycleChildrenToFront(sourceNodeData.id);
} else if (isCycleContainer(newNodeType)) {
console.log('isCycleContainer(newNodeType)')
// Case 3: adding a loop/iteration node from a normal node — bring new node to front, then its children
newNode.toFront();
sourceNode.toFront()
bringCycleChildrenToFront(id);
} else {
// Case 2: normal node → normal node
addedEdges.forEach(e => {
const src = graph.getCellById(e.getSourceCellId());
const tgt = graph.getCellById(e.getTargetCellId());
if (src?.isNode()) src.toFront();
if (tgt?.isNode()) tgt.toFront();
});
}
}, 50);
// Clean up temporary element
if (tempElement) { if (tempElement) {
document.body.removeChild(tempElement); document.body.removeChild(tempElement);
setTempElement(null); setTempElement(null);
} }
setPopoverVisible(false); setPopoverVisible(false);
}; };
@@ -335,4 +391,4 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
); );
}; };
export default PortClickHandler; export default PortClickHandler;

View File

@@ -242,11 +242,10 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''} className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''}
> >
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0 {parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
? <Select key={values.tool_id} size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} /> ? <Select size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
: parameter.type === 'boolean' : parameter.type === 'boolean'
? <Switch key={values.tool_id} size="small" /> ? <Switch size="small" />
: <Editor : <Editor
key={values.tool_id}
variant="outlined" variant="outlined"
type="input" type="input"
size="small" size="small"

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 15:06:18 * @Date: 2026-02-03 15:06:18
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-27 14:07:14 * @Last Modified time: 2026-04-21 18:23:31
*/ */
import type { ReactShapeConfig } from '@antv/x6-react-shape'; import type { ReactShapeConfig } from '@antv/x6-react-shape';
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port'; import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
@@ -948,15 +948,6 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
width: nodeWidth, width: nodeWidth,
height: 120, height: 120,
shape: 'notes-node', shape: 'notes-node',
},
output: {
width: nodeWidth,
height: 76,
shape: 'normal-node',
ports: {
groups: { left: defaultPortGroup },
items: [defaultPortItems[0]],
},
} }
} }

View File

@@ -2,9 +2,10 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 15:17:48 * @Date: 2026-02-03 15:17:48
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-04-28 13:49:11 * @Last Modified time: 2026-04-24 17:21:09
*/ */
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6'; import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type';
import { register } from '@antv/x6-react-shape'; import { register } from '@antv/x6-react-shape';
import type { PortMetadata } from '@antv/x6/lib/model/port'; import type { PortMetadata } from '@antv/x6/lib/model/port';
import { App } from 'antd'; import { App } from 'antd';
@@ -16,7 +17,7 @@ import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application';
import { useUser } from '@/store/user'; import { useUser } from '@/store/user';
import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types'; import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types';
import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant'; import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant';
import type { ChatVariable, HistoryRecord, NodeProperties, WorkflowConfig } from '../types'; import type { ChatVariable, NodeProperties, WorkflowConfig } from '../types';
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils'; import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
import { useWorkflowStore } from '@/store/workflow'; import { useWorkflowStore } from '@/store/workflow';
@@ -85,10 +86,6 @@ export interface UseWorkflowGraphReturn {
/** Get start node output variable list (user-defined + system variables) */ /** Get start node output variable list (user-defined + system variables) */
getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>; getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>;
nodeClick: ({ node }: { node: Node }) => void; nodeClick: ({ node }: { node: Node }) => void;
/** All recorded history operations */
historyRecords: HistoryRecord[];
/** Clear history records */
clearHistoryRecords: () => void;
} }
/** /**
@@ -122,19 +119,14 @@ export const useWorkflowGraph = ({
const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined) const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined)
const [canUndo, setCanUndo] = useState(false) const [canUndo, setCanUndo] = useState(false)
const [canRedo, setCanRedo] = useState(false) const [canRedo, setCanRedo] = useState(false)
const [historyRecords, setHistoryRecords] = useState<HistoryRecord[]>([])
const lastHistoryRef = useRef<{ cellIds: string[]; timestamp: number; type: string } | null>(null)
const undoRef = useRef<() => void>(() => {})
const redoRef = useRef<() => void>(() => {})
const syncChildRelationshipsRef = useRef<() => void>(() => {})
const isSyncingRef = useRef(false)
useEffect(() => { useEffect(() => {
if (!graphRef.current) return if (!graphRef.current) return
graphRef.current.getNodes().forEach(node => { graphRef.current.getNodes().forEach(node => {
const data = node.getData() const data = node.getData()
if (data?.type === 'if-else' || data?.type === 'question-classifier') { if (data?.type === 'if-else' || data?.type === 'question-classifier') {
console.log('chatVariables', chatVariables) console.log('chatVariables', chatVariables)
node.setData({ ...data, chatVariables }) node.setData({ ...data, chatVariables }, { silent: true })
} }
}) })
}, [chatVariables]) }, [chatVariables])
@@ -351,7 +343,7 @@ export const useWorkflowGraph = ({
if (parentNode) { if (parentNode) {
const addedChild = graphRef.current?.addNode(childNode) const addedChild = graphRef.current?.addNode(childNode)
if (addedChild) { if (addedChild) {
parentNode.addChild(addedChild, { silent: true }) parentNode.addChild(addedChild)
} }
} }
} }
@@ -382,6 +374,8 @@ export const useWorkflowGraph = ({
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2) const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight) const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
console.log('newWidth', newHeight, newWidth)
parentNode.prop('size', { width: newWidth, height: newHeight }) parentNode.prop('size', { width: newWidth, height: newHeight })
// Update x position of right group ports // Update x position of right group ports
@@ -494,77 +488,8 @@ export const useWorkflowGraph = ({
graphRef.current.cleanHistory() graphRef.current.cleanHistory()
} }
}, 200) }, 200)
} else {
graphRef.current.enableHistory()
graphRef.current.cleanHistory()
} }
} }
const resizeGroupNodes = (graph: Graph) => {
graph.getNodes().forEach(parentNode => {
const parentType = parentNode.getData()?.type
if (parentType !== 'loop' && parentType !== 'iteration') return
const children = graph.getNodes().filter(
n => n.getData()?.cycle === parentNode.getData()?.id && n.getData()?.type !== 'add-node'
)
if (!children.length) return
const padding = 24
const headerHeight = 50
const childBounds = children.map(c => c.getBBox())
const minX = Math.min(...childBounds.map(b => b.x))
const minY = Math.min(...childBounds.map(b => b.y))
const maxX = Math.max(...childBounds.map(b => b.x + b.width))
const maxY = Math.max(...childBounds.map(b => b.y + b.height))
const parentBBox = parentNode.getBBox()
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
parentNode.prop('size', { width: newWidth, height: newHeight })
parentNode.getPorts().forEach(port => {
if (port.group === 'right' && port.args) {
parentNode.portProp(port.id!, 'args/x', newWidth)
}
})
})
}
const syncChildRelationships = () => {
if (!graphRef.current) return
const graph = graphRef.current
graph.disableHistory()
graph.getNodes().forEach(node => {
const cycleId = node.getData()?.cycle
if (!cycleId) return
const parentNode = graph.getCellById(cycleId) as Node | null
if (!parentNode) return
if (!parentNode.getChildren()?.some(c => c.id === node.id)) {
parentNode.addChild(node, { silent: true })
}
})
graph.getNodes().forEach(node => {
const children = node.getChildren()
if (!children?.length) return
children.forEach(child => {
if (!child.isNode()) return
const childCycleId = (child as Node).getData?.()?.cycle
if (childCycleId !== node.id && childCycleId !== node.getData?.()?.id) {
node.removeChild(child, { silent: true })
}
})
})
resizeGroupNodes(graph)
graph.getEdges().forEach(edge => {
const src = graph.getCellById(edge.getSourceCellId())
const tgt = graph.getCellById(edge.getTargetCellId())
if (src?.getData()?.cycle || tgt?.getData()?.cycle) {
edge.toFront()
}
})
graph.getNodes().forEach(node => {
if (node.getData()?.cycle) node.toFront()
})
graph.enableHistory()
}
syncChildRelationshipsRef.current = syncChildRelationships
/** /**
* Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard) * Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard)
*/ */
@@ -600,44 +525,18 @@ export const useWorkflowGraph = ({
new History({ new History({
enabled: false, enabled: false,
beforeAddCommand(_event, args: any) { beforeAddCommand(_event, args: any) {
const key = args?.key const event = args?.key ? `cell:change:${args.key}` : _event;
if (key === 'attrs' || key === 'tools') return false if (event.startsWith('cell:change:') &&
event !== 'cell:change:position' &&
event !== 'cell:change:source' &&
event !== 'cell:change:target') return false;
}, },
}), }),
); );
const MERGE_INTERVAL = 1000 graphRef.current.on('history:change', ({ cmds }: { cmds: Command[] }) => {
graphRef.current.on('history:change', ({ cmds, options }: { cmds: any[]; options: any }) => {
setCanUndo(graphRef.current?.canUndo() ?? false) setCanUndo(graphRef.current?.canUndo() ?? false)
setCanRedo(graphRef.current?.canRedo() ?? false) setCanRedo(graphRef.current?.canRedo() ?? false)
console.log('history:change', cmds, options)
const batchName: string | undefined = options?.name
const actionType = batchName === 'undo' ? 'undo' : batchName === 'redo' ? 'redo' : batchName ? 'batch' : 'change'
const cellIds = [...new Set(cmds?.map((cmd: any) => cmd.data?.id).filter(Boolean))]
const now = Date.now()
const last = lastHistoryRef.current
const canMerge =
actionType === 'change' &&
last?.type === 'change' &&
now - last.timestamp < MERGE_INTERVAL &&
cellIds.length > 0 &&
cellIds.length === last.cellIds.length &&
cellIds.every((id, i) => id === last.cellIds[i])
if (canMerge) {
lastHistoryRef.current!.timestamp = now
setHistoryRecords(prev => {
const next = [...prev]
next[next.length - 1] = { ...next[next.length - 1], timestamp: now }
return next
})
} else {
const record: HistoryRecord = { type: actionType, timestamp: now, batchName, cellIds }
lastHistoryRef.current = { cellIds, timestamp: now, type: actionType }
setHistoryRecords(prev => [...prev, record])
}
}) })
graphRef.current.on('history:undo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
graphRef.current.on('history:redo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
}; };
// 显示/隐藏连接桩 // 显示/隐藏连接桩
// const showPorts = (show: boolean) => { // const showPorts = (show: boolean) => {
@@ -670,13 +569,13 @@ export const useWorkflowGraph = ({
vo.setData({ vo.setData({
...data, ...data,
isSelected: false, isSelected: false,
}, { silent: true }); });
} }
}); });
node.setData({ node.setData({
...nodeData, ...nodeData,
isSelected: true, isSelected: true,
}, { silent: true }); });
clearEdgeSelect() clearEdgeSelect()
if (nodeData.type !== 'notes') { if (nodeData.type !== 'notes') {
setSelectedNode(node); setSelectedNode(node);
@@ -690,7 +589,7 @@ export const useWorkflowGraph = ({
const edgeClick = ({ edge }: { edge: Edge }) => { const edgeClick = ({ edge }: { edge: Edge }) => {
clearEdgeSelect(); clearEdgeSelect();
edge.setAttrByPath('line/stroke', edge_selected_color); edge.setAttrByPath('line/stroke', edge_selected_color);
edge.setData({ ...edge.getData(), isSelected: true }, { silent: true }); edge.setData({ ...edge.getData(), isSelected: true });
clearNodeSelect(); clearNodeSelect();
}; };
/** /**
@@ -705,7 +604,7 @@ export const useWorkflowGraph = ({
node.setData({ node.setData({
...data, ...data,
isSelected: false, isSelected: false,
}, { silent: true }); });
} }
}); });
setSelectedNode(null); setSelectedNode(null);
@@ -715,7 +614,7 @@ export const useWorkflowGraph = ({
*/ */
const clearEdgeSelect = () => { const clearEdgeSelect = () => {
graphRef.current?.getEdges().forEach(e => { graphRef.current?.getEdges().forEach(e => {
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false }, { silent: true }); e.setData({ ...e.getData(), isSelected: false, isNodeHover: false });
e.setAttrByPath('line/stroke', edge_color); e.setAttrByPath('line/stroke', edge_color);
e.setAttrByPath('line/strokeWidth', edge_width); e.setAttrByPath('line/strokeWidth', edge_width);
}); });
@@ -854,6 +753,8 @@ export const useWorkflowGraph = ({
// Find corresponding parent node // Find corresponding parent node
const parentNode = nodes?.find(n => n.id === nodeData.cycle); const parentNode = nodes?.find(n => n.id === nodeData.cycle);
if (parentNode) { if (parentNode) {
// Use removeChild method to delete child node
parentNode.removeChild(nodeToDelete);
parentNodesToUpdate.push(parentNode); parentNodesToUpdate.push(parentNode);
} }
// Add child node to deletion list // Add child node to deletion list
@@ -881,51 +782,42 @@ export const useWorkflowGraph = ({
// Delete all collected nodes and edges // Delete all collected nodes and edges
if (cells.length > 0) { if (cells.length > 0) {
// Pre-calculate which parents need an add-node restored (before removal changes the graph)
const parentsNeedingAddNode = parentNodesToUpdate
.filter(parentNode => {
const parentShape = parentNode.shape;
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return false;
const parentData = parentNode.getData();
const allChildren = graphRef.current!.getNodes().filter(n => n.getData()?.cycle === parentData.id);
const cycleStartNodes = allChildren.filter(n => n.getData()?.type === 'cycle-start');
// After deletion, only cycle-start will remain
const nonCycleStartToDelete = cells.filter(c =>
c.isNode() &&
(c as Node).getData()?.cycle === parentData.id &&
(c as Node).getData()?.type !== 'cycle-start'
);
return cycleStartNodes.length === 1 && (allChildren.length - nonCycleStartToDelete.length) === 1;
})
.map(parentNode => ({
parentNode,
cycleStartNode: graphRef.current!.getNodes().find(
n => n.getData()?.cycle === parentNode.getData().id && n.getData()?.type === 'cycle-start'
)!
}))
.filter(({ cycleStartNode }) => !!cycleStartNode);
graphRef.current?.startBatch('delete');
graphRef.current?.removeCells(cells); graphRef.current?.removeCells(cells);
parentsNeedingAddNode.forEach(({ parentNode, cycleStartNode }) => { // If parent is iteration/loop and only cycle-start remains, add add-node connected to it
parentNodesToUpdate.forEach(parentNode => {
const parentShape = parentNode.shape;
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return;
const parentData = parentNode.getData(); const parentData = parentNode.getData();
const bbox = cycleStartNode.getBBox(); const remainingChildren = graphRef.current!.getNodes().filter(
const addNode = graphRef.current!.addNode({ n => n.getData()?.cycle === parentData.id
...graphNodeLibrary.addStart, );
x: bbox.x + 84, const cycleStartNodes = remainingChildren.filter(n => n.getData()?.type === 'cycle-start');
y: bbox.y + 4, if (cycleStartNodes.length === 1 && remainingChildren.length === 1) {
data: { type: 'add-node', parentId: parentNode.id, cycle: parentData.id, label: t('workflow.addNode'), icon: '+' }, const cycleStartNode = cycleStartNodes[0];
}); const bbox = cycleStartNode.getBBox();
parentNode.addChild(addNode, { silent: true }); const addNode = graphRef.current!.addNode({
graphRef.current!.addEdge({ ...graphNodeLibrary.addStart,
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' }, x: bbox.x + 84,
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' }, y: bbox.y + 4,
...edgeAttrs, data: {
}); type: 'add-node',
parentId: parentNode.id,
cycle: parentData.id,
label: t('workflow.addNode'),
icon: '+',
},
});
parentNode.addChild(addNode);
const sourcePort = cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right';
const targetPort = addNode.getPorts().find(p => p.group === 'left')?.id || 'left';
graphRef.current!.addEdge({
source: { cell: cycleStartNode.id, port: sourcePort },
target: { cell: addNode.id, port: targetPort },
...edgeAttrs,
});
}
}); });
graphRef.current?.stopBatch('delete');
} }
return false; return false;
}; };
@@ -1144,7 +1036,7 @@ export const useWorkflowGraph = ({
graphRef.current?.getConnectedEdges(node).forEach(edge => { graphRef.current?.getConnectedEdges(node).forEach(edge => {
if (!edge.getData()?.isSelected) { if (!edge.getData()?.isSelected) {
edge.setAttrByPath('line/stroke', edge_selected_color); edge.setAttrByPath('line/stroke', edge_selected_color);
edge.setData({ ...edge.getData(), isNodeHover: true }, { silent: true }); edge.setData({ ...edge.getData(), isNodeHover: true });
} }
}); });
}); });
@@ -1152,7 +1044,7 @@ export const useWorkflowGraph = ({
graphRef.current?.getConnectedEdges(node).forEach(edge => { graphRef.current?.getConnectedEdges(node).forEach(edge => {
if (!edge.getData()?.isSelected) { if (!edge.getData()?.isSelected) {
edge.setAttrByPath('line/stroke', edge_color); edge.setAttrByPath('line/stroke', edge_color);
edge.setData({ ...edge.getData(), isNodeHover: false }, { silent: true }); edge.setData({ ...edge.getData(), isNodeHover: false });
} }
}); });
}); });
@@ -1234,8 +1126,8 @@ export const useWorkflowGraph = ({
// Delete selected nodes and edges // Delete selected nodes and edges
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent); graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
// Undo / Redo // Undo / Redo
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { undo(); return false; }); graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { graphRef.current?.undo(); return false; });
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { redo(); return false; }); graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { graphRef.current?.redo(); return false; });
}; };
@@ -1301,51 +1193,13 @@ export const useWorkflowGraph = ({
}; };
if (dragData.type === 'loop' || dragData.type === 'iteration') { if (dragData.type === 'loop' || dragData.type === 'iteration') {
graph.disableHistory() graphRef.current.addNode({
const parentNode = graphRef.current.addNode({
...graphNodeLibrary[dragData.type], ...graphNodeLibrary[dragData.type],
x: point.x - 150, x: point.x - 150,
y: point.y - 100, y: point.y - 100,
id: cleanNodeData.id, id: cleanNodeData.id,
data: { ...cleanNodeData, isGroup: true }, data: { ...cleanNodeData, isGroup: true },
}) });
const parentBBox = parentNode.getBBox()
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
const cycleStartNode = graphRef.current.addNode({
...graphNodeLibrary.cycleStart,
x: parentBBox.x + 24,
y: parentBBox.y + 70,
id: cycleStartId,
data: { id: cycleStartId, type: 'cycle-start', parentId: cleanNodeData.id, isDefault: true, cycle: cleanNodeData.id },
})
const addNode = graphRef.current.addNode({
...graphNodeLibrary.addStart,
x: parentBBox.x + 24 + 84,
y: parentBBox.y + 70 + 4,
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: cleanNodeData.id, cycle: cleanNodeData.id },
})
parentNode.addChild(cycleStartNode, { silent: true })
parentNode.addChild(addNode, { silent: true })
const newEdge = graphRef.current.addEdge({
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
...edgeAttrs,
})
cycleStartNode.toFront()
addNode.toFront()
graph.enableHistory()
// Manually push a single batch frame covering all 4 cells into undoStack
const history = graph.getPlugin('history') as History
const makeBatchCmd = (cell: any) => ({
batch: true,
event: 'cell:added',
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
options: {},
})
const batchFrame = [parentNode, cycleStartNode, addNode, newEdge].map(makeBatchCmd)
;(history as any).undoStack.push(batchFrame)
;(history as any).redoStack = []
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-group' } })
} else if (dragData.type === 'if-else') { } else if (dragData.type === 'if-else') {
// Create condition node // Create condition node
graphRef.current.addNode({ graphRef.current.addNode({
@@ -1592,80 +1446,8 @@ export const useWorkflowGraph = ({
return userVars return userVars
} }
const clearHistoryRecords = () => { const undo = () => graphRef.current?.undo()
setHistoryRecords([]) const redo = () => graphRef.current?.redo()
lastHistoryRef.current = null
}
const getStackCellIds = (cmds: any): string[] => {
const arr = Array.isArray(cmds) ? cmds : [cmds]
return [...new Set(arr.map((c: any) => c.data?.id).filter(Boolean))]
}
const isSkippableFrame = (frame: any): boolean => {
const arr = Array.isArray(frame) ? frame : [frame]
return arr.every((c: any) => ['zIndex', 'attrs', 'tools'].includes(c.data?.key))
}
const undo = () => {
const history = graphRef.current?.getPlugin('history') as History | undefined
if (!history || history.getUndoSize() === 0) return
const undoStack = (history as any).undoStack as any[]
isSyncingRef.current = true
while (undoStack.length > 0 && isSkippableFrame(undoStack[undoStack.length - 1])) {
graphRef.current!.undo()
}
if (undoStack.length === 0) {
isSyncingRef.current = false
return
}
const topIds = getStackCellIds(undoStack[undoStack.length - 1])
graphRef.current!.undo()
while (undoStack.length > 0) {
if (isSkippableFrame(undoStack[undoStack.length - 1])) {
graphRef.current!.undo()
continue
}
const nextIds = getStackCellIds(undoStack[undoStack.length - 1])
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
graphRef.current!.undo()
} else {
break
}
}
isSyncingRef.current = false
syncChildRelationships()
}
const redo = () => {
const history = graphRef.current?.getPlugin('history') as History | undefined
if (!history || history.getRedoSize() === 0) return
const redoStack = (history as any).redoStack as any[]
isSyncingRef.current = true
while (redoStack.length > 0 && isSkippableFrame(redoStack[redoStack.length - 1])) {
graphRef.current!.redo()
}
if (redoStack.length === 0) {
isSyncingRef.current = false
return
}
const topIds = getStackCellIds(redoStack[redoStack.length - 1])
graphRef.current!.redo()
while (redoStack.length > 0) {
if (isSkippableFrame(redoStack[redoStack.length - 1])) {
graphRef.current!.redo()
continue
}
const nextIds = getStackCellIds(redoStack[redoStack.length - 1])
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
graphRef.current!.redo()
} else {
break
}
}
isSyncingRef.current = false
syncChildRelationships()
}
const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => { const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => {
const { statement = '' } = value?.opening_statement || {} const { statement = '' } = value?.opening_statement || {}
@@ -1706,16 +1488,20 @@ export const useWorkflowGraph = ({
if (!graphRef.current) return; if (!graphRef.current) return;
const nodes = graphRef.current.getNodes(); const nodes = graphRef.current.getNodes();
// Reset all node execution status on every chatHistory change const lastWithSub = [...chatHistory].reverse().find(item => item.subContent?.length);
// Reset all node execution status first
nodes.forEach(node => { nodes.forEach(node => {
const data = node.getData(); const data = node.getData();
node.setData({ ...data, executionStatus: '' }); if (typeof data.executionStatus === 'string') {
node.setData({ ...data, executionStatus: undefined });
}
}); });
if (!lastWithSub?.subContent) return;
const lastAssistant = [...chatHistory].reverse().find(item => item.role === 'assistant'); // Build a nodeId -> status map first
if (!lastAssistant?.subContent?.length) return; const statusMap: Record<string, string> = {};
lastAssistant.subContent.forEach(sub => { lastWithSub.subContent.forEach(sub => {
if (typeof sub.status === 'string') { if (typeof sub.status === 'string') {
statusMap[sub.node_id] = sub.status;
const node = nodes.find(n => n.getData()?.id === sub.node_id); const node = nodes.find(n => n.getData()?.id === sub.node_id);
if (node) { if (node) {
node.setData({ ...node.getData(), executionStatus: sub.status }); node.setData({ ...node.getData(), executionStatus: sub.status });
@@ -1751,7 +1537,5 @@ export const useWorkflowGraph = ({
canRedo, canRedo,
undo, undo,
redo, redo,
historyRecords,
clearHistoryRecords,
}; };
}; };

View File

@@ -113,13 +113,4 @@ export interface ChatVariable {
} }
export interface AddChatVariableRef { export interface AddChatVariableRef {
handleOpen: (value?: ChatVariable) => void; handleOpen: (value?: ChatVariable) => void;
}
export type HistoryActionType = 'add' | 'remove' | 'change' | 'undo' | 'redo' | 'batch'
export interface HistoryRecord {
type: HistoryActionType;
timestamp: number;
batchName?: string;
cellIds?: string[];
} }

View File

@@ -17,7 +17,6 @@ export const isSubExprSet = (sub: any) => {
* Uses the same per-expression height logic as getConditionNodeCasePortY. * Uses the same per-expression height logic as getConditionNodeCasePortY.
*/ */
export const calcConditionNodeTotalHeight = (cases: any[]) => { export const calcConditionNodeTotalHeight = (cases: any[]) => {
if (!cases?.length) return conditionNodeHeight;
const casesHeight = cases.reduce((acc: number, c: any) => { const casesHeight = cases.reduce((acc: number, c: any) => {
const exprs = c?.expressions ?? []; const exprs = c?.expressions ?? [];
const n = exprs.length; const n = exprs.length;