Compare commits
39 Commits
feat/wxy-d
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
feae2f2e1e | ||
|
|
415234d4c8 | ||
|
|
e38a60e107 | ||
|
|
86eb08c73f | ||
|
|
53f1b0e586 | ||
|
|
49cc47a79a | ||
|
|
1817f52edf | ||
|
|
40633d72c3 | ||
|
|
6f10296969 | ||
|
|
89228825cf | ||
|
|
cab4deb2ff | ||
|
|
4048a10858 | ||
|
|
d6ef0f4923 | ||
|
|
75fbe44839 | ||
|
|
06597c567b | ||
|
|
28694fefb0 | ||
|
|
7a0f08148e | ||
|
|
d3058ce379 | ||
|
|
8d88df391d | ||
|
|
7621321d1b | ||
|
|
0e29b0b2a5 | ||
|
|
2fa4d29548 | ||
|
|
7bb181c1c7 | ||
|
|
a9c87b03ff | ||
|
|
720af8d261 | ||
|
|
09d32ed446 | ||
|
|
9a5ce7f7c6 | ||
|
|
531d785629 | ||
|
|
6d80d74f4a | ||
|
|
3d9882643e | ||
|
|
b4e4be1133 | ||
|
|
16926d9db5 | ||
|
|
f369a63c8d | ||
|
|
1861b0fbc9 | ||
|
|
750d4ca841 | ||
|
|
8baa466b31 | ||
|
|
dd7f9f6cee | ||
|
|
d5d81f0c4f | ||
|
|
610ae27cf9 |
7
.github/workflows/sync-to-gitee.yml
vendored
7
.github/workflows/sync-to-gitee.yml
vendored
@@ -3,12 +3,9 @@ name: Sync to Gitee
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main # Production
|
- '**' # All branchs
|
||||||
- develop # Integration
|
|
||||||
- 'release/*' # Release preparation
|
|
||||||
- 'hotfix/*' # Urgent fixes
|
|
||||||
tags:
|
tags:
|
||||||
- '*' # All version tags (v1.0.0, etc.)
|
- '**' # All version tags (v1.0.0, etc.)
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
sync:
|
sync:
|
||||||
|
|||||||
@@ -82,32 +82,19 @@ async def get_preview_chunks(
|
|||||||
detail="The file does not exist or you do not have permission to access it"
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Get file content from storage backend
|
# 5. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||||||
if not db_file.file_key:
|
file_path = os.path.join(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Check if the file exists
|
||||||
|
if not os.path.exists(file_path):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="File has no storage key (legacy data not migrated)"
|
detail="File not found (possibly deleted)"
|
||||||
)
|
|
||||||
|
|
||||||
from app.services.file_storage_service import FileStorageService
|
|
||||||
import asyncio
|
|
||||||
storage_service = FileStorageService()
|
|
||||||
|
|
||||||
async def _download():
|
|
||||||
return await storage_service.download_file(db_file.file_key)
|
|
||||||
|
|
||||||
try:
|
|
||||||
file_binary = asyncio.run(_download())
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
try:
|
|
||||||
file_binary = loop.run_until_complete(_download())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"File not found in storage: {e}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 7. Document parsing & segmentation
|
# 7. Document parsing & segmentation
|
||||||
@@ -117,12 +104,11 @@ async def get_preview_chunks(
|
|||||||
vision_model = QWenCV(
|
vision_model = QWenCV(
|
||||||
key=db_knowledge.image2text.api_keys[0].api_key,
|
key=db_knowledge.image2text.api_keys[0].api_key,
|
||||||
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
||||||
lang="Chinese",
|
lang="Chinese", # Default to Chinese
|
||||||
base_url=db_knowledge.image2text.api_keys[0].api_base
|
base_url=db_knowledge.image2text.api_keys[0].api_base
|
||||||
)
|
)
|
||||||
from app.core.rag.app.naive import chunk
|
from app.core.rag.app.naive import chunk
|
||||||
res = chunk(filename=db_file.file_name,
|
res = chunk(filename=file_path,
|
||||||
binary=file_binary,
|
|
||||||
from_page=0,
|
from_page=0,
|
||||||
to_page=5,
|
to_page=5,
|
||||||
callback=progress_callback,
|
callback=progress_callback,
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from app.models.user_model import User
|
|||||||
from app.schemas import document_schema
|
from app.schemas import document_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import document_service, file_service, knowledge_service
|
from app.services import document_service, file_service, knowledge_service
|
||||||
from app.services.file_storage_service import FileStorageService, get_file_storage_service
|
|
||||||
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
@@ -232,8 +231,7 @@ async def update_document(
|
|||||||
async def delete_document(
|
async def delete_document(
|
||||||
document_id: uuid.UUID,
|
document_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Delete document
|
Delete document
|
||||||
@@ -259,7 +257,7 @@ async def delete_document(
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# 3. Delete file
|
# 3. Delete file
|
||||||
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
|
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user)
|
||||||
|
|
||||||
# 4. Delete vector index
|
# 4. Delete vector index
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
||||||
@@ -307,25 +305,38 @@ async def parse_documents(
|
|||||||
detail="The file does not exist or you do not have permission to access it"
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Get file_key for storage backend
|
# 3. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||||||
if not db_file.file_key:
|
file_path = os.path.join(
|
||||||
api_logger.error(f"File has no storage key (legacy data not migrated): file_id={db_file.id}")
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Check if the file exists
|
||||||
|
api_logger.debug(f"Constructed file path: {file_path}")
|
||||||
|
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="File has no storage key (legacy data not migrated)"
|
detail="File not found (possibly deleted)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Obtain knowledge base information
|
# 5. Obtain knowledge base information
|
||||||
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
|
api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
||||||
if not db_knowledge:
|
if not db_knowledge:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
|
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The knowledge base does not exist or access is denied"
|
||||||
|
)
|
||||||
|
|
||||||
# 5. Dispatch parse task with file_key (not file_path)
|
# 6. Task: Document parsing, vectorization, and storage
|
||||||
task = celery_app.send_task(
|
# from app.tasks import parse_document
|
||||||
"app.core.rag.tasks.parse_document",
|
# parse_document(file_path, document_id)
|
||||||
args=[db_file.file_key, document_id, db_file.file_name]
|
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
|
||||||
)
|
|
||||||
result = {
|
result = {
|
||||||
"task_id": task.id
|
"task_id": task.id
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import FileResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -17,14 +19,10 @@ from app.models.user_model import User
|
|||||||
from app.schemas import file_schema, document_schema
|
from app.schemas import file_schema, document_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import file_service, document_service
|
from app.services import file_service, document_service
|
||||||
from app.services.knowledge_service import get_knowledge_by_id as get_kb_by_id
|
|
||||||
from app.services.file_storage_service import (
|
|
||||||
FileStorageService,
|
|
||||||
generate_kb_file_key,
|
|
||||||
get_file_storage_service,
|
|
||||||
)
|
|
||||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||||
|
|
||||||
|
|
||||||
|
# Obtain a dedicated API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
@@ -37,37 +35,67 @@ router = APIRouter(
|
|||||||
async def get_files(
|
async def get_files(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
page: int = Query(1, gt=0),
|
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||||
pagesize: int = Query(20, gt=0, le=100),
|
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
|
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
|
||||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||||
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""Paged query file list"""
|
"""
|
||||||
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}")
|
Paged query file list
|
||||||
|
- Support filtering by kb_id and parent_id
|
||||||
|
- Support keyword search for file names
|
||||||
|
- Support dynamic sorting
|
||||||
|
- Return paging metadata + file list
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||||
|
# 1. parameter validation
|
||||||
if page < 1 or pagesize < 1:
|
if page < 1 or pagesize < 1:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The paging parameter must be greater than 0")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="The paging parameter must be greater than 0"
|
||||||
|
)
|
||||||
|
|
||||||
filters = [file_model.File.kb_id == kb_id]
|
# 2. Construct query conditions
|
||||||
|
filters = [
|
||||||
|
file_model.File.kb_id == kb_id
|
||||||
|
]
|
||||||
if parent_id:
|
if parent_id:
|
||||||
filters.append(file_model.File.parent_id == parent_id)
|
filters.append(file_model.File.parent_id == parent_id)
|
||||||
|
# Keyword search (fuzzy matching of file name)
|
||||||
if keywords:
|
if keywords:
|
||||||
filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
|
filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
|
||||||
|
|
||||||
|
# 3. Execute paged query
|
||||||
try:
|
try:
|
||||||
|
api_logger.debug("Start executing file paging query")
|
||||||
total, items = file_service.get_files_paginated(
|
total, items = file_service.get_files_paginated(
|
||||||
db=db, filters=filters, page=page, pagesize=pagesize,
|
db=db,
|
||||||
orderby=orderby, desc=desc, current_user=current_user
|
filters=filters,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize,
|
||||||
|
orderby=orderby,
|
||||||
|
desc=desc,
|
||||||
|
current_user=current_user
|
||||||
)
|
)
|
||||||
|
api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Query failed: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Query failed: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Return structured response
|
||||||
result = {
|
result = {
|
||||||
"items": items,
|
"items": items,
|
||||||
"page": {"page": page, "pagesize": pagesize, "total": total, "has_next": page * pagesize < total}
|
"page": {
|
||||||
|
"page": page,
|
||||||
|
"pagesize": pagesize,
|
||||||
|
"total": total,
|
||||||
|
"has_next": True if page * pagesize < total else False
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
|
return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
|
||||||
|
|
||||||
@@ -80,14 +108,23 @@ async def create_folder(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Create a new folder"""
|
"""
|
||||||
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}")
|
Create a new folder
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
create_folder_data = file_schema.FileCreate(
|
api_logger.debug(f"Start creating a folder: {folder_name}")
|
||||||
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
|
create_folder = file_schema.FileCreate(
|
||||||
file_name=folder_name, file_ext='folder', file_size=0,
|
kb_id=kb_id,
|
||||||
|
created_by=current_user.id,
|
||||||
|
parent_id=parent_id,
|
||||||
|
file_name=folder_name,
|
||||||
|
file_ext='folder',
|
||||||
|
file_size=0,
|
||||||
)
|
)
|
||||||
db_file = file_service.create_file(db=db, file=create_folder_data, current_user=current_user)
|
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
|
||||||
|
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
|
||||||
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
|
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
|
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
|
||||||
@@ -101,58 +138,76 @@ async def upload_file(
|
|||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
):
|
||||||
"""Upload file to storage backend"""
|
"""
|
||||||
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}")
|
upload file
|
||||||
|
"""
|
||||||
|
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")
|
||||||
|
|
||||||
|
# Read the contents of the file
|
||||||
contents = await file.read()
|
contents = await file.read()
|
||||||
|
# Check file size
|
||||||
file_size = len(contents)
|
file_size = len(contents)
|
||||||
|
print(f"file size: {file_size} byte")
|
||||||
if file_size == 0:
|
if file_size == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The file is empty.")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="The file is empty."
|
||||||
|
)
|
||||||
|
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"File size exceeds {settings.MAX_FILE_SIZE} byte limit")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the extension using `os.path.splitext`
|
||||||
_, file_extension = os.path.splitext(file.filename)
|
_, file_extension = os.path.splitext(file.filename)
|
||||||
file_ext = file_extension.lower()
|
upload_file = file_schema.FileCreate(
|
||||||
|
kb_id=kb_id,
|
||||||
# Create File record
|
created_by=current_user.id,
|
||||||
upload_file_data = file_schema.FileCreate(
|
parent_id=parent_id,
|
||||||
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
|
file_name=file.filename,
|
||||||
file_name=file.filename, file_ext=file_ext, file_size=file_size,
|
file_ext=file_extension.lower(),
|
||||||
|
file_size=file_size,
|
||||||
)
|
)
|
||||||
db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
|
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
|
||||||
|
|
||||||
# Upload to storage backend
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
|
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
|
||||||
try:
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
|
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Storage upload failed: {e}")
|
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
|
|
||||||
|
|
||||||
# Save file_key
|
# Save file
|
||||||
db_file.file_key = file_key
|
with open(save_path, "wb") as f:
|
||||||
db.commit()
|
f.write(contents)
|
||||||
db.refresh(db_file)
|
|
||||||
|
|
||||||
# Create document (inherit parser_config from knowledge base)
|
# Verify whether the file has been saved successfully
|
||||||
default_parser_config = {
|
if not os.path.exists(save_path):
|
||||||
"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
|
raise HTTPException(
|
||||||
"auto_keywords": 0, "auto_questions": 0, "html4excel": "false"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
}
|
detail="File save failed"
|
||||||
try:
|
)
|
||||||
db_knowledge = get_kb_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
||||||
if db_knowledge and db_knowledge.parser_config:
|
|
||||||
default_parser_config.update(dict(db_knowledge.parser_config))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
# Create a document
|
||||||
create_data = document_schema.DocumentCreate(
|
create_data = document_schema.DocumentCreate(
|
||||||
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
|
kb_id=kb_id,
|
||||||
file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
|
created_by=current_user.id,
|
||||||
file_meta={}, parser_id="naive", parser_config=default_parser_config
|
file_id=db_file.id,
|
||||||
|
file_name=db_file.file_name,
|
||||||
|
file_ext=db_file.file_ext,
|
||||||
|
file_size=db_file.file_size,
|
||||||
|
file_meta={},
|
||||||
|
parser_id="naive",
|
||||||
|
parser_config={
|
||||||
|
"layout_recognize": "DeepDOC",
|
||||||
|
"chunk_token_num": 128,
|
||||||
|
"delimiter": "\n",
|
||||||
|
"auto_keywords": 0,
|
||||||
|
"auto_questions": 0,
|
||||||
|
"html4excel": "false"
|
||||||
|
}
|
||||||
)
|
)
|
||||||
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
||||||
|
|
||||||
@@ -166,73 +221,123 @@ async def custom_text(
|
|||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
create_data: file_schema.CustomTextFileCreate,
|
create_data: file_schema.CustomTextFileCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
):
|
||||||
"""Custom text upload"""
|
"""
|
||||||
|
custom text
|
||||||
|
"""
|
||||||
|
api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}")
|
||||||
|
|
||||||
|
# Check file content size
|
||||||
|
# 将内容编码为字节(UTF-8)
|
||||||
content_bytes = create_data.content.encode('utf-8')
|
content_bytes = create_data.content.encode('utf-8')
|
||||||
file_size = len(content_bytes)
|
file_size = len(content_bytes)
|
||||||
|
print(f"file size: {file_size} byte")
|
||||||
if file_size == 0:
|
if file_size == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The content is empty.")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="The content is empty."
|
||||||
|
)
|
||||||
|
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Content size exceeds {settings.MAX_FILE_SIZE} byte limit")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
|
||||||
|
)
|
||||||
|
|
||||||
upload_file_data = file_schema.FileCreate(
|
upload_file = file_schema.FileCreate(
|
||||||
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
|
kb_id=kb_id,
|
||||||
file_name=f"{create_data.title}.txt", file_ext=".txt", file_size=file_size,
|
created_by=current_user.id,
|
||||||
|
parent_id=parent_id,
|
||||||
|
file_name=f"{create_data.title}.txt",
|
||||||
|
file_ext=".txt",
|
||||||
|
file_size=file_size,
|
||||||
)
|
)
|
||||||
db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
|
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
|
||||||
|
|
||||||
# Upload to storage backend
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=".txt")
|
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
|
||||||
try:
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
await storage_service.storage.upload(file_key=file_key, content=content_bytes, content_type="text/plain")
|
save_path = os.path.join(save_dir, f"{db_file.id}.txt")
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Storage upload failed: {e}")
|
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
|
|
||||||
|
|
||||||
db_file.file_key = file_key
|
# Save file
|
||||||
db.commit()
|
with open(save_path, "wb") as f:
|
||||||
db.refresh(db_file)
|
f.write(content_bytes)
|
||||||
|
|
||||||
|
# Verify whether the file has been saved successfully
|
||||||
|
if not os.path.exists(save_path):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="File save failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a document
|
||||||
create_document_data = document_schema.DocumentCreate(
|
create_document_data = document_schema.DocumentCreate(
|
||||||
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
|
kb_id=kb_id,
|
||||||
file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
|
created_by=current_user.id,
|
||||||
file_meta={}, parser_id="naive",
|
file_id=db_file.id,
|
||||||
parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
|
file_name=db_file.file_name,
|
||||||
"auto_keywords": 0, "auto_questions": 0, "html4excel": "false"}
|
file_ext=db_file.file_ext,
|
||||||
|
file_size=db_file.file_size,
|
||||||
|
file_meta={},
|
||||||
|
parser_id="naive",
|
||||||
|
parser_config={
|
||||||
|
"layout_recognize": "DeepDOC",
|
||||||
|
"chunk_token_num": 128,
|
||||||
|
"delimiter": "\n",
|
||||||
|
"auto_keywords": 0,
|
||||||
|
"auto_questions": 0,
|
||||||
|
"html4excel": "false"
|
||||||
|
}
|
||||||
)
|
)
|
||||||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||||||
|
|
||||||
|
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
|
||||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
|
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{file_id}", response_model=Any)
|
@router.get("/{file_id}", response_model=Any)
|
||||||
async def get_file(
|
async def get_file(
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db)
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Download file by file_id"""
|
"""
|
||||||
|
Download the file based on the file_id
|
||||||
|
- Query file information from the database
|
||||||
|
- Construct the file path and check if it exists
|
||||||
|
- Return a FileResponse to download the file
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Download the file based on the file_id: file_id={file_id}")
|
||||||
|
|
||||||
|
# 1. Query file information from the database
|
||||||
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
||||||
if not db_file:
|
if not db_file:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
|
)
|
||||||
|
|
||||||
if not db_file.file_key:
|
# 2. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File has no storage key (legacy data not migrated)")
|
file_path = os.path.join(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
# 3. Check if the file exists
|
||||||
content = await storage_service.download_file(db_file.file_key)
|
if not os.path.exists(file_path):
|
||||||
except Exception as e:
|
raise HTTPException(
|
||||||
api_logger.error(f"Storage download failed: {e}")
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage")
|
detail="File not found (possibly deleted)"
|
||||||
|
)
|
||||||
|
|
||||||
import mimetypes
|
# 4.Return FileResponse (automatically handle download)
|
||||||
media_type = mimetypes.guess_type(db_file.file_name)[0] or "application/octet-stream"
|
return FileResponse(
|
||||||
return Response(
|
path=file_path,
|
||||||
content=content,
|
filename=db_file.file_name, # Use original file name
|
||||||
media_type=media_type,
|
media_type="application/octet-stream" # Universal binary stream type
|
||||||
headers={"Content-Disposition": f'attachment; filename="{db_file.file_name}"'}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -243,22 +348,50 @@ async def update_file(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""Update file information (such as file name)"""
|
"""
|
||||||
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
Update file information (such as file name)
|
||||||
if not db_file:
|
- Only specified fields such as file_name are allowed to be modified
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
"""
|
||||||
|
api_logger.debug(f"Query the file to be updated: {file_id}")
|
||||||
|
|
||||||
|
# 1. Check if the file exists
|
||||||
|
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
||||||
|
|
||||||
|
if not db_file:
|
||||||
|
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Update fields (only update non-null fields)
|
||||||
|
api_logger.debug(f"Start updating the file fields: {file_id}")
|
||||||
|
updated_fields = []
|
||||||
for field, value in update_data.dict(exclude_unset=True).items():
|
for field, value in update_data.dict(exclude_unset=True).items():
|
||||||
if hasattr(db_file, field):
|
if hasattr(db_file, field):
|
||||||
setattr(db_file, field, value)
|
old_value = getattr(db_file, field)
|
||||||
|
if old_value != value:
|
||||||
|
# update value
|
||||||
|
setattr(db_file, field, value)
|
||||||
|
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||||
|
|
||||||
|
if updated_fields:
|
||||||
|
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||||
|
|
||||||
|
# 3. Save to database
|
||||||
try:
|
try:
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_file)
|
db.refresh(db_file)
|
||||||
|
api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File update failed: {str(e)}")
|
api_logger.error(f"File update failed: file_id={file_id} - {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"File update failed: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Return the updated file
|
||||||
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
|
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
|
||||||
|
|
||||||
|
|
||||||
@@ -266,43 +399,60 @@ async def update_file(
|
|||||||
async def delete_file(
|
async def delete_file(
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
):
|
||||||
"""Delete a file or folder"""
|
"""
|
||||||
api_logger.info(f"Request to delete file: file_id={file_id}")
|
Delete a file or folder
|
||||||
await _delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
|
"""
|
||||||
|
api_logger.info(f"Request to delete file: file_id={file_id}, username: {current_user.username}")
|
||||||
|
await _delete_file(db=db, file_id=file_id, current_user=current_user)
|
||||||
return success(msg="File deleted successfully")
|
return success(msg="File deleted successfully")
|
||||||
|
|
||||||
|
|
||||||
async def _delete_file(
|
async def _delete_file(
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session,
|
db: Session = Depends(get_db),
|
||||||
current_user: User,
|
current_user: User = Depends(get_current_user)
|
||||||
storage_service: FileStorageService,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Delete a file or folder from storage and database"""
|
"""
|
||||||
|
Delete a file or folder
|
||||||
|
"""
|
||||||
|
# 1. Check if the file exists
|
||||||
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
||||||
|
|
||||||
if not db_file:
|
if not db_file:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
|
)
|
||||||
|
|
||||||
# Delete from storage backend
|
# 2. Construct physical path
|
||||||
|
file_path = Path(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.id)
|
||||||
|
) if db_file.file_ext == 'folder' else Path(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Delete physical files/folders
|
||||||
|
try:
|
||||||
|
if file_path.exists():
|
||||||
|
if db_file.file_ext == 'folder':
|
||||||
|
shutil.rmtree(file_path) # Recursively delete folders
|
||||||
|
else:
|
||||||
|
file_path.unlink() # Delete a single file
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to delete physical file/folder: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4.Delete db_file
|
||||||
if db_file.file_ext == 'folder':
|
if db_file.file_ext == 'folder':
|
||||||
# For folders, delete all child files from storage first
|
|
||||||
child_files = db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).all()
|
|
||||||
for child in child_files:
|
|
||||||
if child.file_key:
|
|
||||||
try:
|
|
||||||
await storage_service.delete_file(child.file_key)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"Failed to delete child file from storage: {child.file_key} - {e}")
|
|
||||||
db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
|
db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
|
||||||
else:
|
|
||||||
if db_file.file_key:
|
|
||||||
try:
|
|
||||||
await storage_service.delete_file(db_file.file_key)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"Failed to delete file from storage: {db_file.file_key} - {e}")
|
|
||||||
|
|
||||||
db.delete(db_file)
|
db.delete(db_file)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -296,7 +296,7 @@ async def chat(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 多 Agent 非流式返回
|
# workflow 非流式返回
|
||||||
result = await app_chat_service.workflow_chat(
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ def update_workspace_members(
|
|||||||
|
|
||||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def delete_workspace_member(
|
async def delete_workspace_member(
|
||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -230,7 +230,7 @@ def delete_workspace_member(
|
|||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
|
|
||||||
workspace_service.delete_workspace_member(
|
await workspace_service.delete_workspace_member(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
member_id=member_id,
|
member_id=member_id,
|
||||||
|
|||||||
@@ -241,6 +241,8 @@ class Settings:
|
|||||||
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
||||||
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
||||||
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
||||||
|
|
||||||
|
SANDBOX_URL: str = os.getenv("SANDBOX_URL", "")
|
||||||
|
|
||||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||||
|
|||||||
@@ -216,7 +216,7 @@ class RedBearModelFactory:
|
|||||||
# 深度思考模式:Claude 3.7 Sonnet 等支持思考的模型
|
# 深度思考模式:Claude 3.7 Sonnet 等支持思考的模型
|
||||||
# 通过 additional_model_request_fields 传递 thinking 块,关闭时不传(Bedrock 无 disabled 选项)
|
# 通过 additional_model_request_fields 传递 thinking 块,关闭时不传(Bedrock 无 disabled 选项)
|
||||||
if config.deep_thinking:
|
if config.deep_thinking:
|
||||||
budget = config.thinking_budget_tokens or 10000
|
budget = config.thinking_budget_tokens or 1024
|
||||||
params["additional_model_request_fields"] = {
|
params["additional_model_request_fields"] = {
|
||||||
"thinking": {"type": "enabled", "budget_tokens": budget}
|
"thinking": {"type": "enabled", "budget_tokens": budget}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ Transcribe the content from the provided PDF page image into clean Markdown form
|
|||||||
6. Do NOT wrap the output in ```markdown or ``` blocks.
|
6. Do NOT wrap the output in ```markdown or ``` blocks.
|
||||||
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
|
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
|
||||||
8. Preserve the original language, information, and order exactly as shown in the image.
|
8. Preserve the original language, information, and order exactly as shown in the image.
|
||||||
9. Your output language MUST match the language of the content in the image. If the image contains Chinese text, output in Chinese. If English, output in English. Never translate.
|
|
||||||
|
|
||||||
{% if page %}
|
{% if page %}
|
||||||
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
|
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
# Author: Eternity
|
# Author: Eternity
|
||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/10 13:33
|
# @Time : 2026/2/10 13:33
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
@@ -142,10 +141,9 @@ class GraphBuilder:
|
|||||||
|
|
||||||
for node_info in source_nodes:
|
for node_info in source_nodes:
|
||||||
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
||||||
if node_info.get("branch") is not None:
|
branch_nodes.append(
|
||||||
branch_nodes.append(
|
(node_info["id"], node_info["branch"])
|
||||||
(node_info["id"], node_info["branch"])
|
)
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
|
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
|
||||||
output_nodes.append(node_info["id"])
|
output_nodes.append(node_info["id"])
|
||||||
@@ -316,12 +314,9 @@ class GraphBuilder:
|
|||||||
for idx in range(len(related_edge)):
|
for idx in range(len(related_edge)):
|
||||||
# Generate a condition expression for each edge
|
# Generate a condition expression for each edge
|
||||||
# Used later to determine which branch to take based on the node's output
|
# Used later to determine which branch to take based on the node's output
|
||||||
# For LLM nodes, use branch_signal field for routing (output is dynamic text)
|
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||||
# For other branch nodes (e.g. HTTP), use output field
|
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
||||||
route_field = "branch_signal" if node_type == NodeType.LLM else "output"
|
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
|
||||||
related_edge[idx]['condition'] = (
|
|
||||||
f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(related_edge[idx]['label'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if node_instance:
|
if node_instance:
|
||||||
# Wrap node's run method to avoid closure issues
|
# Wrap node's run method to avoid closure issues
|
||||||
|
|||||||
@@ -18,17 +18,10 @@ class AssignerNode(BaseNode):
|
|||||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.variable_updater = True
|
self.variable_updater = True
|
||||||
self.typed_config: AssignerNodeConfig | None = None
|
self.typed_config: AssignerNodeConfig | None = None
|
||||||
self._input_data: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
|
||||||
"""提取节点输入,如果有缓存的执行前数据则使用缓存"""
|
|
||||||
if self._input_data is not None:
|
|
||||||
return self._input_data
|
|
||||||
return {"config": self._resolve_config(self.config, variable_pool)}
|
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute the assignment operation defined by this node.
|
Execute the assignment operation defined by this node.
|
||||||
@@ -41,9 +34,6 @@ class AssignerNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
None or the result of the assignment operation.
|
None or the result of the assignment operation.
|
||||||
"""
|
"""
|
||||||
# 在执行前提取并缓存输入数据(捕获执行前的变量值)
|
|
||||||
self._input_data = {"config": self._resolve_config(self.config, variable_pool)}
|
|
||||||
|
|
||||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||||
self.typed_config = AssignerNodeConfig(**self.config)
|
self.typed_config = AssignerNodeConfig(**self.config)
|
||||||
logger.info(f"节点 {self.node_id} 开始执行")
|
logger.info(f"节点 {self.node_id} 开始执行")
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -23,9 +22,6 @@ from app.services.multimodal_service import MultimodalService
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 匹配模板变量 {{xxx}} 的正则
|
|
||||||
_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
|
||||||
|
|
||||||
|
|
||||||
class NodeExecutionError(Exception):
|
class NodeExecutionError(Exception):
|
||||||
"""节点执行失败异常。
|
"""节点执行失败异常。
|
||||||
@@ -507,29 +503,10 @@ class BaseNode(ABC):
|
|||||||
variable_pool: The variable pool used for reading and writing variables.
|
variable_pool: The variable pool used for reading and writing variables.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary containing the node's input data with all template
|
A dictionary containing the node's input data.
|
||||||
variables resolved to their actual runtime values.
|
|
||||||
"""
|
"""
|
||||||
return {"config": self._resolve_config(self.config, variable_pool)}
|
# Default implementation returns the node configuration
|
||||||
|
return {"config": self.config}
|
||||||
@staticmethod
|
|
||||||
def _resolve_config(config: Any, variable_pool: VariablePool) -> Any:
|
|
||||||
"""递归解析 config 中的模板变量,将 {{xxx}} 替换为实际值。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config: 节点的原始配置(可能包含模板变量)。
|
|
||||||
variable_pool: 变量池,用于解析模板变量。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
解析后的配置,所有字符串中的 {{变量}} 已被替换为真实值。
|
|
||||||
"""
|
|
||||||
if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
|
|
||||||
return BaseNode._render_template(config, variable_pool, strict=False)
|
|
||||||
elif isinstance(config, dict):
|
|
||||||
return {k: BaseNode._resolve_config(v, variable_pool) for k, v in config.items()}
|
|
||||||
elif isinstance(config, list):
|
|
||||||
return [BaseNode._resolve_config(item, variable_pool) for item in config]
|
|
||||||
return config
|
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> Any:
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
"""Extracts the actual output from the business result.
|
"""Extracts the actual output from the business result.
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes import BaseNode
|
from app.core.workflow.nodes import BaseNode
|
||||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -131,7 +132,7 @@ class CodeNode(BaseNode):
|
|||||||
|
|
||||||
async with httpx.AsyncClient(timeout=60) as client:
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"http://sandbox:8194/v1/sandbox/run",
|
f"{settings.SANDBOX_URL}:8194/v1/sandbox/run",
|
||||||
headers={
|
headers={
|
||||||
"x-api-key": 'redbear-sandbox'
|
"x-api-key": 'redbear-sandbox'
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -121,10 +121,7 @@ class DocExtractorNode(BaseNode):
|
|||||||
return business_result
|
return business_result
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
file_selector = self.config.get("file_selector", "")
|
return {"file_selector": self.config.get("file_selector")}
|
||||||
# 将变量选择器(如 sys.files)解析为实际值
|
|
||||||
resolved = self.get_variable(file_selector, variable_pool, strict=False, default=file_selector)
|
|
||||||
return {"file_selector": resolved}
|
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
config = DocExtractorNodeConfig(**self.config)
|
config = DocExtractorNodeConfig(**self.config)
|
||||||
@@ -185,7 +182,7 @@ class DocExtractorNode(BaseNode):
|
|||||||
mime_type=f"image/{ext}",
|
mime_type=f"image/{ext}",
|
||||||
is_file=True,
|
is_file=True,
|
||||||
).model_dump())
|
).model_dump())
|
||||||
text = text + f"\n{placeholder}: {url}"
|
text = text + f"\n{placeholder}: <img src=\"{url}\" data-url=\"{url}\">"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}")
|
logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class NodeType(StrEnum):
|
|||||||
NOTES = "notes"
|
NOTES = "notes"
|
||||||
|
|
||||||
|
|
||||||
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM})
|
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
|
||||||
|
|
||||||
|
|
||||||
class ComparisonOperator(StrEnum):
|
class ComparisonOperator(StrEnum):
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import uuid
|
|||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||||
from app.core.workflow.nodes.enums import HttpErrorHandle
|
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
|
||||||
|
|
||||||
@@ -50,20 +49,6 @@ class MemoryWindowSetting(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMErrorHandleConfig(BaseModel):
|
|
||||||
"""LLM 异常处理配置"""
|
|
||||||
|
|
||||||
method: HttpErrorHandle = Field(
|
|
||||||
default=HttpErrorHandle.NONE,
|
|
||||||
description="异常处理策略:'none' 抛出异常, 'default' 返回默认值, 'branch' 走异常分支",
|
|
||||||
)
|
|
||||||
|
|
||||||
output: str = Field(
|
|
||||||
default="",
|
|
||||||
description="LLM 异常时返回的默认输出文本(method=default 时生效)",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LLMNodeConfig(BaseNodeConfig):
|
class LLMNodeConfig(BaseNodeConfig):
|
||||||
"""LLM 节点配置
|
"""LLM 节点配置
|
||||||
|
|
||||||
@@ -167,11 +152,6 @@ class LLMNodeConfig(BaseNodeConfig):
|
|||||||
description="输出变量定义(自动生成,通常不需要修改)"
|
description="输出变量定义(自动生成,通常不需要修改)"
|
||||||
)
|
)
|
||||||
|
|
||||||
error_handle: LLMErrorHandleConfig = Field(
|
|
||||||
default_factory=LLMErrorHandleConfig,
|
|
||||||
description="LLM 异常处理配置",
|
|
||||||
)
|
|
||||||
|
|
||||||
@field_validator("messages", "prompt")
|
@field_validator("messages", "prompt")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_input_mode(cls, v):
|
def validate_input_mode(cls, v):
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from app.core.models import RedBearLLM, RedBearModelConfig
|
|||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.enums import HttpErrorHandle
|
|
||||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
@@ -77,7 +76,7 @@ class LLMNode(BaseNode):
|
|||||||
self.messages = []
|
self.messages = []
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {"output": VariableType.STRING, "branch_signal": VariableType.STRING}
|
return {"output": VariableType.STRING}
|
||||||
|
|
||||||
def _render_context(self, message: str, variable_pool: VariablePool):
|
def _render_context(self, message: str, variable_pool: VariablePool):
|
||||||
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
||||||
@@ -240,7 +239,7 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool):
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
||||||
"""非流式执行 LLM 调用
|
"""非流式执行 LLM 调用
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -248,36 +247,28 @@ class LLMNode(BaseNode):
|
|||||||
variable_pool: 变量池
|
variable_pool: 变量池
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success,
|
LLM 响应消息
|
||||||
{"llm_result": None, "branch_signal": "ERROR"} on branch error
|
|
||||||
"""
|
"""
|
||||||
try:
|
# self.typed_config = LLMNodeConfig(**self.config)
|
||||||
# self.typed_config = LLMNodeConfig(**self.config)
|
llm = await self._prepare_llm(state, variable_pool, False)
|
||||||
llm = await self._prepare_llm(state, variable_pool, False)
|
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||||
|
|
||||||
# 调用 LLM(支持字符串或消息列表)
|
# 调用 LLM(支持字符串或消息列表)
|
||||||
response = await llm.ainvoke(self.messages)
|
response = await llm.ainvoke(self.messages)
|
||||||
# 提取内容
|
# 提取内容
|
||||||
if hasattr(response, 'content'):
|
if hasattr(response, 'content'):
|
||||||
content = self.process_model_output(response.content)
|
content = self.process_model_output(response.content)
|
||||||
else:
|
else:
|
||||||
content = str(response)
|
content = str(response)
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||||
|
|
||||||
# 返回 AIMessage(包含响应元数据)
|
# 返回 AIMessage(包含响应元数据)
|
||||||
return {
|
return AIMessage(content=content, response_metadata={
|
||||||
"llm_result": AIMessage(content=content, response_metadata={
|
**response.response_metadata,
|
||||||
**response.response_metadata,
|
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
|
||||||
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
|
})
|
||||||
}),
|
|
||||||
"branch_signal": "SUCCESS",
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"节点 {self.node_id} LLM 调用失败: {e}")
|
|
||||||
return self._handle_llm_error(e)
|
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""提取输入数据(用于记录)"""
|
"""提取输入数据(用于记录)"""
|
||||||
@@ -295,36 +286,16 @@ class LLMNode(BaseNode):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
def _extract_output(self, business_result: Any) -> dict:
|
def _extract_output(self, business_result: Any) -> str:
|
||||||
"""从业务结果中提取输出变量
|
"""从 AIMessage 中提取文本内容"""
|
||||||
|
|
||||||
支持新旧两种格式:
|
|
||||||
- 新格式:{"llm_result": AIMessage, "branch_signal": "SUCCESS"}
|
|
||||||
- 旧格式:AIMessage(向后兼容)
|
|
||||||
"""
|
|
||||||
if isinstance(business_result, dict) and "branch_signal" in business_result:
|
|
||||||
llm_result = business_result.get("llm_result")
|
|
||||||
if isinstance(llm_result, AIMessage):
|
|
||||||
return {
|
|
||||||
"output": llm_result.content,
|
|
||||||
"branch_signal": business_result["branch_signal"],
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"output": str(llm_result) if llm_result else "",
|
|
||||||
"branch_signal": business_result["branch_signal"],
|
|
||||||
}
|
|
||||||
# 旧格式向后兼容
|
|
||||||
if isinstance(business_result, AIMessage):
|
if isinstance(business_result, AIMessage):
|
||||||
return {"output": business_result.content, "branch_signal": "SUCCESS"}
|
return business_result.content
|
||||||
return {"output": str(business_result), "branch_signal": "SUCCESS"}
|
return str(business_result)
|
||||||
|
|
||||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||||
"""从业务结果中提取 token 使用情况"""
|
"""从 AIMessage 中提取 token 使用情况"""
|
||||||
llm_result = business_result
|
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
||||||
if isinstance(business_result, dict):
|
usage = business_result.response_metadata.get('token_usage')
|
||||||
llm_result = business_result.get("llm_result", business_result)
|
|
||||||
if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'):
|
|
||||||
usage = llm_result.response_metadata.get('token_usage')
|
|
||||||
if usage:
|
if usage:
|
||||||
return {
|
return {
|
||||||
"prompt_tokens": usage.get('input_tokens', 0),
|
"prompt_tokens": usage.get('input_tokens', 0),
|
||||||
@@ -333,44 +304,6 @@ class LLMNode(BaseNode):
|
|||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _handle_llm_error(self, error: Exception) -> dict:
|
|
||||||
"""处理 LLM 调用异常,根据 error_handle 配置决定行为
|
|
||||||
|
|
||||||
Args:
|
|
||||||
error: LLM 调用中捕获的异常
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: {"llm_result": None, "branch_signal": "ERROR"} for branch mode,
|
|
||||||
or default output for default mode
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
原异常(当 error_handle.method 为 NONE 时)
|
|
||||||
"""
|
|
||||||
if self.typed_config is None:
|
|
||||||
raise error
|
|
||||||
|
|
||||||
match self.typed_config.error_handle.method:
|
|
||||||
case HttpErrorHandle.NONE:
|
|
||||||
raise error
|
|
||||||
case HttpErrorHandle.DEFAULT:
|
|
||||||
logger.warning(
|
|
||||||
f"节点 {self.node_id}: LLM 调用失败,返回默认输出"
|
|
||||||
)
|
|
||||||
default_output = self.typed_config.error_handle.output or ""
|
|
||||||
return {
|
|
||||||
"llm_result": AIMessage(content=default_output, response_metadata={}),
|
|
||||||
"branch_signal": "SUCCESS",
|
|
||||||
}
|
|
||||||
case HttpErrorHandle.BRANCH:
|
|
||||||
logger.warning(
|
|
||||||
f"节点 {self.node_id}: LLM 调用失败,切换到异常处理分支"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"llm_result": None,
|
|
||||||
"branch_signal": "ERROR",
|
|
||||||
}
|
|
||||||
raise error
|
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||||
"""流式执行 LLM 调用
|
"""流式执行 LLM 调用
|
||||||
|
|
||||||
@@ -383,58 +316,54 @@ class LLMNode(BaseNode):
|
|||||||
"""
|
"""
|
||||||
self.typed_config = LLMNodeConfig(**self.config)
|
self.typed_config = LLMNodeConfig(**self.config)
|
||||||
|
|
||||||
try:
|
llm = await self._prepare_llm(state, variable_pool, True)
|
||||||
llm = await self._prepare_llm(state, variable_pool, True)
|
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||||
|
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||||
|
|
||||||
# 累积完整响应
|
# 累积完整响应
|
||||||
full_response = ""
|
full_response = ""
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
|
|
||||||
# 调用 LLM(流式,支持字符串或消息列表)
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
last_meta_data = {}
|
last_meta_data = {}
|
||||||
last_usage_metadata = {}
|
last_usage_metadata = {}
|
||||||
async for chunk in llm.astream(self.messages):
|
async for chunk in llm.astream(self.messages):
|
||||||
if hasattr(chunk, 'content'):
|
if hasattr(chunk, 'content'):
|
||||||
content = self.process_model_output(chunk.content)
|
content = self.process_model_output(chunk.content)
|
||||||
else:
|
else:
|
||||||
content = str(chunk)
|
content = str(chunk)
|
||||||
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
|
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
|
||||||
last_meta_data = chunk.response_metadata
|
last_meta_data = chunk.response_metadata
|
||||||
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
|
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
|
||||||
last_usage_metadata = chunk.usage_metadata
|
last_usage_metadata = chunk.usage_metadata
|
||||||
|
|
||||||
# 只有当内容不为空时才处理
|
# 只有当内容不为空时才处理
|
||||||
if content:
|
if content:
|
||||||
full_response += content
|
full_response += content
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
|
|
||||||
# 流式返回每个文本片段
|
# 流式返回每个文本片段
|
||||||
yield {
|
yield {
|
||||||
"__final__": False,
|
"__final__": False,
|
||||||
"chunk": content
|
"chunk": content
|
||||||
}
|
|
||||||
|
|
||||||
yield {
|
|
||||||
"__final__": False,
|
|
||||||
"chunk": "",
|
|
||||||
"done": True
|
|
||||||
}
|
|
||||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
|
||||||
|
|
||||||
# 构建完整的 AIMessage(包含元数据)
|
|
||||||
final_message = AIMessage(
|
|
||||||
content=full_response,
|
|
||||||
response_metadata={
|
|
||||||
**last_meta_data,
|
|
||||||
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
|
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
|
||||||
# yield 完成标记
|
yield {
|
||||||
yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}}
|
"__final__": False,
|
||||||
except Exception as e:
|
"chunk": "",
|
||||||
logger.error(f"节点 {self.node_id} LLM 流式调用失败: {e}")
|
"done": True
|
||||||
error_result = self._handle_llm_error(e)
|
}
|
||||||
yield {"__final__": True, "result": error_result}
|
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||||
|
|
||||||
|
# 构建完整的 AIMessage(包含元数据)
|
||||||
|
final_message = AIMessage(
|
||||||
|
content=full_response,
|
||||||
|
response_metadata={
|
||||||
|
**last_meta_data,
|
||||||
|
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# yield 完成标记
|
||||||
|
yield {"__final__": True, "result": final_message}
|
||||||
|
|||||||
@@ -15,5 +15,4 @@ class File(Base):
|
|||||||
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
|
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
|
||||||
file_size = Column(Integer, default=0, comment="file size(byte)")
|
file_size = Column(Integer, default=0, comment="file size(byte)")
|
||||||
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
|
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
|
||||||
file_key = Column(String(512), nullable=True, index=True, comment="storage file key for FileStorageService")
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||||
@@ -250,7 +250,7 @@ class ModelParameters(BaseModel):
|
|||||||
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
|
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
|
||||||
stop: Optional[List[str]] = Field(default=None, description="停止序列")
|
stop: Optional[List[str]] = Field(default=None, description="停止序列")
|
||||||
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
|
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
|
||||||
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
|
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1, le=131072, description="深度思考 token 预算(仅部分模型支持)")
|
||||||
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")
|
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ class FileBase(BaseModel):
|
|||||||
file_ext: str
|
file_ext: str
|
||||||
file_size: int
|
file_size: int
|
||||||
file_url: str | None = None
|
file_url: str | None = None
|
||||||
file_key: str | None = None
|
|
||||||
created_at: datetime.datetime | None = None
|
created_at: datetime.datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -161,7 +161,10 @@ class AppChatService:
|
|||||||
f.type == FileType.DOCUMENT for f in files
|
f.type == FileType.DOCUMENT for f in files
|
||||||
):
|
):
|
||||||
system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
||||||
|
"请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
|
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
||||||
|
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
# 创建 LangChain Agent
|
||||||
@@ -448,7 +451,10 @@ class AppChatService:
|
|||||||
):
|
):
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
||||||
|
"请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
|
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
||||||
|
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
# 创建 LangChain Agent
|
||||||
|
|||||||
@@ -102,11 +102,6 @@ class AppDslService:
|
|||||||
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
||||||
]
|
]
|
||||||
return enriched
|
return enriched
|
||||||
if app_type == AppType.WORKFLOW:
|
|
||||||
enriched = {**cfg}
|
|
||||||
if "nodes" in cfg:
|
|
||||||
enriched["nodes"] = self._enrich_workflow_nodes(cfg["nodes"])
|
|
||||||
return enriched
|
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||||
@@ -115,7 +110,7 @@ class AppDslService:
|
|||||||
config_data = {
|
config_data = {
|
||||||
"variables": config.variables if config else [],
|
"variables": config.variables if config else [],
|
||||||
"edges": config.edges if config else [],
|
"edges": config.edges if config else [],
|
||||||
"nodes": self._enrich_workflow_nodes(config.nodes) if config else [],
|
"nodes": config.nodes if config else [],
|
||||||
"features": config.features if config else {},
|
"features": config.features if config else {},
|
||||||
"execution_config": config.execution_config if config else {},
|
"execution_config": config.execution_config if config else {},
|
||||||
"triggers": config.triggers if config else [],
|
"triggers": config.triggers if config else [],
|
||||||
@@ -195,23 +190,6 @@ class AppDslService:
|
|||||||
def _enrich_tools(self, tools: list) -> list:
|
def _enrich_tools(self, tools: list) -> list:
|
||||||
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
||||||
|
|
||||||
def _enrich_workflow_nodes(self, nodes: list) -> list:
|
|
||||||
"""enrich 工作流节点中的模型引用,添加 name、provider、type 信息"""
|
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
|
||||||
enriched_nodes = []
|
|
||||||
for node in (nodes or []):
|
|
||||||
node_type = node.get("type")
|
|
||||||
config = dict(node.get("config") or {})
|
|
||||||
|
|
||||||
if node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
|
||||||
model_id = config.get("model_id")
|
|
||||||
if model_id:
|
|
||||||
config["model_ref"] = self._model_ref(model_id)
|
|
||||||
del config["model_id"]
|
|
||||||
|
|
||||||
enriched_nodes.append({**node, "config": config})
|
|
||||||
return enriched_nodes
|
|
||||||
|
|
||||||
def _skill_ref(self, skill_id) -> Optional[dict]:
|
def _skill_ref(self, skill_id) -> Optional[dict]:
|
||||||
if not skill_id:
|
if not skill_id:
|
||||||
return None
|
return None
|
||||||
@@ -642,16 +620,16 @@ class AppDslService:
|
|||||||
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
|
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
|
||||||
config["knowledge_bases"] = resolved_kbs
|
config["knowledge_bases"] = resolved_kbs
|
||||||
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
||||||
model_ref = config.get("model_ref") or config.get("model_id")
|
model_ref = config.get("model_id")
|
||||||
if model_ref:
|
if model_ref:
|
||||||
ref_dict = None
|
ref_dict = None
|
||||||
if isinstance(model_ref, dict):
|
if isinstance(model_ref, dict):
|
||||||
ref_dict = {
|
ref_id = model_ref.get("id")
|
||||||
"id": model_ref.get("id"),
|
ref_name = model_ref.get("name")
|
||||||
"name": model_ref.get("name"),
|
if ref_id:
|
||||||
"provider": model_ref.get("provider"),
|
ref_dict = {"id": ref_id}
|
||||||
"type": model_ref.get("type")
|
elif ref_name is not None:
|
||||||
}
|
ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")}
|
||||||
elif isinstance(model_ref, str):
|
elif isinstance(model_ref, str):
|
||||||
try:
|
try:
|
||||||
uuid.UUID(model_ref)
|
uuid.UUID(model_ref)
|
||||||
@@ -662,18 +640,12 @@ class AppDslService:
|
|||||||
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
|
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
|
||||||
if resolved_model_id:
|
if resolved_model_id:
|
||||||
config["model_id"] = resolved_model_id
|
config["model_id"] = resolved_model_id
|
||||||
if "model_ref" in config:
|
|
||||||
del config["model_ref"]
|
|
||||||
else:
|
else:
|
||||||
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
||||||
config["model_id"] = None
|
config["model_id"] = None
|
||||||
if "model_ref" in config:
|
|
||||||
del config["model_ref"]
|
|
||||||
else:
|
else:
|
||||||
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
||||||
config["model_id"] = None
|
config["model_id"] = None
|
||||||
if "model_ref" in config:
|
|
||||||
del config["model_ref"]
|
|
||||||
resolved_nodes.append({**node, "config": config})
|
resolved_nodes.append({**node, "config": config})
|
||||||
return resolved_nodes
|
return resolved_nodes
|
||||||
|
|
||||||
|
|||||||
@@ -650,7 +650,10 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
if has_doc_with_images:
|
if has_doc_with_images:
|
||||||
system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
||||||
|
"请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
|
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
||||||
|
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = LangChainAgent(
|
agent = LangChainAgent(
|
||||||
@@ -924,7 +927,10 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
if has_doc_with_images:
|
if has_doc_with_images:
|
||||||
system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
||||||
|
"请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
|
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
||||||
|
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
# 创建 LangChain Agent
|
||||||
|
|||||||
@@ -34,7 +34,26 @@ def generate_file_key(
|
|||||||
Generate a unique file key for storage.
|
Generate a unique file key for storage.
|
||||||
|
|
||||||
The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext}
|
The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: The tenant UUID.
|
||||||
|
workspace_id: The workspace UUID.
|
||||||
|
file_id: The file UUID.
|
||||||
|
file_ext: The file extension (e.g., '.pdf', '.txt').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A unique file key string.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> generate_file_key(
|
||||||
|
... uuid.UUID('550e8400-e29b-41d4-a716-446655440000'),
|
||||||
|
... uuid.UUID('660e8400-e29b-41d4-a716-446655440001'),
|
||||||
|
... uuid.UUID('770e8400-e29b-41d4-a716-446655440002'),
|
||||||
|
... '.pdf'
|
||||||
|
... )
|
||||||
|
'550e8400-e29b-41d4-a716-446655440000/660e8400-e29b-41d4-a716-446655440001/770e8400-e29b-41d4-a716-446655440002.pdf'
|
||||||
"""
|
"""
|
||||||
|
# Ensure file_ext starts with a dot
|
||||||
if file_ext and not file_ext.startswith('.'):
|
if file_ext and not file_ext.startswith('.'):
|
||||||
file_ext = f'.{file_ext}'
|
file_ext = f'.{file_ext}'
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
@@ -42,21 +61,6 @@ def generate_file_key(
|
|||||||
return f"{tenant_id}/{file_id}{file_ext}"
|
return f"{tenant_id}/{file_id}{file_ext}"
|
||||||
|
|
||||||
|
|
||||||
def generate_kb_file_key(
|
|
||||||
kb_id: uuid.UUID,
|
|
||||||
file_id: uuid.UUID,
|
|
||||||
file_ext: str,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Generate a file key for knowledge base files.
|
|
||||||
|
|
||||||
Format: kb/{kb_id}/{file_id}{file_ext}
|
|
||||||
"""
|
|
||||||
if file_ext and not file_ext.startswith('.'):
|
|
||||||
file_ext = f'.{file_ext}'
|
|
||||||
return f"kb/{kb_id}/{file_id}{file_ext}"
|
|
||||||
|
|
||||||
|
|
||||||
class FileStorageService:
|
class FileStorageService:
|
||||||
"""
|
"""
|
||||||
High-level service for file storage operations.
|
High-level service for file storage operations.
|
||||||
|
|||||||
@@ -400,7 +400,7 @@ class MultimodalService:
|
|||||||
# 在文本内容中追加图片位置标记
|
# 在文本内容中追加图片位置标记
|
||||||
if result and result[-1].get("type") in ("text", "document"):
|
if result and result[-1].get("type") in ("text", "document"):
|
||||||
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
|
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
|
||||||
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: {img_url}"
|
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: <img src=\"{img_url}\" data-url=\"{img_url}\">"
|
||||||
# 将图片以视觉格式追加到消息内容中
|
# 将图片以视觉格式追加到消息内容中
|
||||||
img_file = FileInput(
|
img_file = FileInput(
|
||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
|
|||||||
@@ -554,13 +554,16 @@ class WorkflowService:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "workflow_end":
|
case "workflow_end":
|
||||||
|
data = {
|
||||||
|
"elapsed_time": payload.get("elapsed_time"),
|
||||||
|
"message_length": len(payload.get("output", "")),
|
||||||
|
"error": payload.get("error", "")
|
||||||
|
}
|
||||||
|
if "citations" in payload and payload["citations"]:
|
||||||
|
data["citations"] = payload["citations"]
|
||||||
return {
|
return {
|
||||||
"event": "end",
|
"event": "end",
|
||||||
"data": {
|
"data": data
|
||||||
"elapsed_time": payload.get("elapsed_time"),
|
|
||||||
"message_length": len(payload.get("output", "")),
|
|
||||||
"error": payload.get("error", "")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from app.models.workspace_model import (
|
|||||||
)
|
)
|
||||||
from app.repositories import workspace_repository
|
from app.repositories import workspace_repository
|
||||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||||
|
from app.services.session_service import SessionService
|
||||||
from app.schemas.workspace_schema import (
|
from app.schemas.workspace_schema import (
|
||||||
InviteAcceptRequest,
|
InviteAcceptRequest,
|
||||||
InviteValidateResponse,
|
InviteValidateResponse,
|
||||||
@@ -58,7 +59,7 @@ def switch_workspace(
|
|||||||
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||||
|
|
||||||
|
|
||||||
def delete_workspace_member(
|
async def delete_workspace_member(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
@@ -76,10 +77,29 @@ def delete_workspace_member(
|
|||||||
BizCode.WORKSPACE_NOT_FOUND)
|
BizCode.WORKSPACE_NOT_FOUND)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
deleted_user = workspace_member.user
|
||||||
workspace_member.is_active = False
|
workspace_member.is_active = False
|
||||||
workspace_member.user.current_workspace_id = None
|
deleted_user.current_workspace_id = None
|
||||||
|
|
||||||
|
# 若被删除成员不是超级管理员且没有其他可用工作空间,则禁用该用户
|
||||||
|
if not deleted_user.is_superuser:
|
||||||
|
remaining = (
|
||||||
|
db.query(WorkspaceMember)
|
||||||
|
.filter(
|
||||||
|
WorkspaceMember.user_id == deleted_user.id,
|
||||||
|
WorkspaceMember.workspace_id != workspace_id,
|
||||||
|
WorkspaceMember.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
if remaining == 0:
|
||||||
|
deleted_user.is_active = False
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
|
|
||||||
|
# 使被删除成员的所有 token 立即失效
|
||||||
|
await SessionService.invalidate_all_user_tokens(str(workspace_member.user_id))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||||
|
|||||||
@@ -210,14 +210,9 @@ def _build_vision_model(file_path: str, db_knowledge):
|
|||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.rag.tasks.parse_document")
|
@celery_app.task(name="app.core.rag.tasks.parse_document")
|
||||||
def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
def parse_document(file_path: str, document_id: uuid.UUID):
|
||||||
"""
|
"""
|
||||||
Document parsing, vectorization, and storage.
|
Document parsing, vectorization, and storage
|
||||||
|
|
||||||
Args:
|
|
||||||
file_key: Storage key for FileStorageService (e.g. "kb/{kb_id}/{file_id}.docx")
|
|
||||||
document_id: Document UUID
|
|
||||||
file_name: Original file name (used for extension detection in chunk())
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db_document = None
|
db_document = None
|
||||||
@@ -228,6 +223,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
|||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
|
# Celery JSON 序列化会将 UUID 转为字符串,需要确保类型正确
|
||||||
if not isinstance(document_id, uuid.UUID):
|
if not isinstance(document_id, uuid.UUID):
|
||||||
document_id = uuid.UUID(str(document_id))
|
document_id = uuid.UUID(str(document_id))
|
||||||
|
|
||||||
@@ -238,11 +234,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
|||||||
if db_knowledge is None:
|
if db_knowledge is None:
|
||||||
raise ValueError(f"Knowledge {db_document.kb_id} not found")
|
raise ValueError(f"Knowledge {db_document.kb_id} not found")
|
||||||
|
|
||||||
# Use file_name from argument or fall back to document record
|
# 1. Document parsing & segmentation
|
||||||
if not file_name:
|
|
||||||
file_name = db_document.file_name
|
|
||||||
|
|
||||||
# 1. Download file from storage backend
|
|
||||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
|
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
db_document.progress = 0.0
|
db_document.progress = 0.0
|
||||||
@@ -253,36 +245,45 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_document)
|
db.refresh(db_document)
|
||||||
|
|
||||||
# Read file content from storage backend (no NFS dependency)
|
|
||||||
from app.services.file_storage_service import FileStorageService
|
|
||||||
import asyncio
|
|
||||||
storage_service = FileStorageService()
|
|
||||||
|
|
||||||
async def _download():
|
|
||||||
return await storage_service.download_file(file_key)
|
|
||||||
|
|
||||||
try:
|
|
||||||
file_binary = asyncio.run(_download())
|
|
||||||
except RuntimeError:
|
|
||||||
# If there's already a running loop (e.g. in some worker configurations)
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
try:
|
|
||||||
file_binary = loop.run_until_complete(_download())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
if not file_binary:
|
|
||||||
raise IOError(f"Downloaded empty file from storage: {file_key}")
|
|
||||||
logger.info(f"[ParseDoc] Downloaded {len(file_binary)} bytes from storage key: {file_key}")
|
|
||||||
|
|
||||||
def progress_callback(prog=None, msg=None):
|
def progress_callback(prog=None, msg=None):
|
||||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
|
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
|
||||||
|
|
||||||
# Prepare vision_model for parsing
|
# Prepare vision_model for parsing
|
||||||
vision_model = _build_vision_model(file_name, db_knowledge)
|
vision_model = _build_vision_model(file_path, db_knowledge)
|
||||||
|
|
||||||
|
# 先将文件读入内存,避免解析过程中依赖 NFS 文件持续可访问
|
||||||
|
# python-docx 等库在 binary=None 时会用路径直接打开文件,
|
||||||
|
# 在 NFS/共享存储上可能因缓存失效导致 "Package not found"
|
||||||
|
max_wait_seconds = 30
|
||||||
|
wait_interval = 2
|
||||||
|
waited = 0
|
||||||
|
file_binary = None
|
||||||
|
while waited <= max_wait_seconds:
|
||||||
|
# os.listdir 强制 NFS 客户端刷新目录缓存
|
||||||
|
parent_dir = os.path.dirname(file_path)
|
||||||
|
try:
|
||||||
|
os.listdir(parent_dir)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
file_binary = f.read()
|
||||||
|
if not file_binary:
|
||||||
|
# NFS 上文件存在但内容为空(可能还在同步中)
|
||||||
|
raise IOError(f"File is empty (0 bytes), NFS may still be syncing: {file_path}")
|
||||||
|
break
|
||||||
|
except (FileNotFoundError, IOError) as e:
|
||||||
|
if waited >= max_wait_seconds:
|
||||||
|
raise type(e)(
|
||||||
|
f"File not accessible at '{file_path}' after waiting {max_wait_seconds}s: {e}"
|
||||||
|
)
|
||||||
|
logger.warning(f"File not ready on this node, retrying in {wait_interval}s: {file_path} ({e})")
|
||||||
|
time.sleep(wait_interval)
|
||||||
|
waited += wait_interval
|
||||||
|
|
||||||
from app.core.rag.app.naive import chunk
|
from app.core.rag.app.naive import chunk
|
||||||
logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
|
logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
|
||||||
res = chunk(filename=file_name,
|
res = chunk(filename=file_path,
|
||||||
binary=file_binary,
|
binary=file_binary,
|
||||||
from_page=0,
|
from_page=0,
|
||||||
to_page=DEFAULT_PARSE_TO_PAGE,
|
to_page=DEFAULT_PARSE_TO_PAGE,
|
||||||
|
|||||||
@@ -8,12 +8,11 @@ import { type FC, useRef, useEffect, useState } from 'react'
|
|||||||
import clsx from 'clsx'
|
import clsx from 'clsx'
|
||||||
import Markdown from '@/components/Markdown'
|
import Markdown from '@/components/Markdown'
|
||||||
import type { ChatContentProps } from './types'
|
import type { ChatContentProps } from './types'
|
||||||
import { Spin, Image, Flex, Button } from 'antd'
|
import { Spin, Flex, Button } from 'antd'
|
||||||
import { SoundOutlined } from '@ant-design/icons'
|
import { SoundOutlined } from '@ant-design/icons'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
import AudioPlayer from './AudioPlayer'
|
import MessageFiles from './MessageFiles'
|
||||||
import VideoPlayer from './VideoPlayer'
|
|
||||||
|
|
||||||
const getFileUrl = (file: any) => {
|
const getFileUrl = (file: any) => {
|
||||||
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
||||||
@@ -149,72 +148,7 @@ const ChatContent: FC<ChatContentProps> = ({
|
|||||||
{labelFormat(item)}
|
{labelFormat(item)}
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
{item?.meta_data?.files && item.meta_data?.files.length > 0 && <Flex gap={8} vertical align="end" className="rb:mb-2!">
|
<MessageFiles files={item.meta_data?.files ?? []} contentClassNames={contentClassNames} onDownload={handleDownload} />
|
||||||
{item.meta_data?.files?.map((file) => {
|
|
||||||
if (file.type.includes('image')) {
|
|
||||||
return (
|
|
||||||
<div key={file.url || file.uid} className={`rb:inline-block rb:group rb:relative rb:rounded-lg ${contentClassNames}`}>
|
|
||||||
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if (file.type.includes('video')) {
|
|
||||||
return (
|
|
||||||
<div key={file.url || file.uid} className="rb:w-50">
|
|
||||||
{/* <video src={getFileUrl(file)} controls className="rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" /> */}
|
|
||||||
<VideoPlayer key={file.url || file.uid} src={getFileUrl(file)} />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if (file.type.includes('audio')) {
|
|
||||||
return (
|
|
||||||
<div key={file.url || file.uid} className="rb:w-50">
|
|
||||||
<AudioPlayer key={file.url || file.uid} src={getFileUrl(file)} />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const documentType = (file.file_type || file.type)?.split('/')
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
key={file.url || file.uid}
|
|
||||||
align="center"
|
|
||||||
gap={10}
|
|
||||||
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
|
|
||||||
onClick={() => handleDownload(file)}
|
|
||||||
>
|
|
||||||
<div
|
|
||||||
className={clsx(
|
|
||||||
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
|
|
||||||
file.type?.includes('pdf')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/pdf.svg')]"
|
|
||||||
: (file.type?.includes('excel') || file.type?.includes('spreadsheetml.sheet')) || file.type?.includes('xls') || file.type?.includes('xlsx')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/excel.svg')]"
|
|
||||||
: file.type?.includes('csv')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/csv.svg')]"
|
|
||||||
: file.type?.includes('html')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/html.svg')]"
|
|
||||||
: file.type?.includes('json')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/json.svg')]"
|
|
||||||
: file.type?.includes('ppt')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/ppt.svg')]"
|
|
||||||
: file.type?.includes('markdown')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/md.svg')]"
|
|
||||||
: file.type?.includes('text')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/txt.svg')]"
|
|
||||||
: (file.type?.includes('doc') || file.type?.includes('docx') || file.type?.includes('word') || file.type?.includes('wordprocessingml.document'))
|
|
||||||
? "rb:bg-[url('@/assets/images/file/word.svg')]"
|
|
||||||
: "rb:bg-[url('@/assets/images/file/txt.svg')]"
|
|
||||||
)}
|
|
||||||
></div>
|
|
||||||
<div className="rb:flex-1 rb:w-32.5">
|
|
||||||
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
|
|
||||||
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{documentType?.[documentType.length - 1]} · {file.size}</div>
|
|
||||||
</div>
|
|
||||||
</Flex>
|
|
||||||
)
|
|
||||||
})}
|
|
||||||
</Flex>}
|
|
||||||
{/* Message bubble */}
|
{/* Message bubble */}
|
||||||
<div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', {
|
<div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', {
|
||||||
// Error message style (content is null and not assistant message)
|
// Error message style (content is null and not assistant message)
|
||||||
|
|||||||
87
web/src/components/Chat/MessageFiles.tsx
Normal file
87
web/src/components/Chat/MessageFiles.tsx
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import { Image, Flex } from 'antd'
|
||||||
|
import clsx from 'clsx'
|
||||||
|
import AudioPlayer from './AudioPlayer'
|
||||||
|
import VideoPlayer from './VideoPlayer'
|
||||||
|
|
||||||
|
const getFileUrl = (file: any) =>
|
||||||
|
file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
||||||
|
|
||||||
|
const DOC_ICONS: [string[], string][] = [
|
||||||
|
[['pdf'], "rb:bg-[url('@/assets/images/file/pdf.svg')]"],
|
||||||
|
[['excel', 'spreadsheetml.sheet', 'xls', 'xlsx'], "rb:bg-[url('@/assets/images/file/excel.svg')]"],
|
||||||
|
[['csv'], "rb:bg-[url('@/assets/images/file/csv.svg')]"],
|
||||||
|
[['html'], "rb:bg-[url('@/assets/images/file/html.svg')]"],
|
||||||
|
[['json'], "rb:bg-[url('@/assets/images/file/json.svg')]"],
|
||||||
|
[['ppt'], "rb:bg-[url('@/assets/images/file/ppt.svg')]"],
|
||||||
|
[['markdown'], "rb:bg-[url('@/assets/images/file/md.svg')]"],
|
||||||
|
[['text'], "rb:bg-[url('@/assets/images/file/txt.svg')]"],
|
||||||
|
[['doc', 'docx', 'word', 'wordprocessingml.document'], "rb:bg-[url('@/assets/images/file/word.svg')]"],
|
||||||
|
]
|
||||||
|
|
||||||
|
const getDocIcon = (parts: string[]) => {
|
||||||
|
const match = DOC_ICONS.find(([keys]) => keys.some(k => parts.includes(k)))
|
||||||
|
return match ? match[1] : "rb:bg-[url('@/assets/images/file/txt.svg')]"
|
||||||
|
}
|
||||||
|
|
||||||
|
interface MessageFilesProps {
|
||||||
|
files: any[]
|
||||||
|
contentClassNames?: string | Record<string, boolean>
|
||||||
|
onDownload: (file: any) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
const MessageFiles = ({ files, contentClassNames, onDownload }: MessageFilesProps) => {
|
||||||
|
if (!files?.length) return null
|
||||||
|
return (
|
||||||
|
<Flex gap={8} vertical align="end" className="rb:mb-2!">
|
||||||
|
{files.map((file) => {
|
||||||
|
const key = file.url || file.uid
|
||||||
|
if (file.type.includes('image')) {
|
||||||
|
return (
|
||||||
|
<div key={key} className={clsx('rb:inline-block rb:group rb:relative rb:rounded-lg', contentClassNames)}>
|
||||||
|
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (file.type.includes('video')) {
|
||||||
|
return (
|
||||||
|
<div key={key} className="rb:w-50">
|
||||||
|
<VideoPlayer src={getFileUrl(file)} />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (file.type.includes('audio')) {
|
||||||
|
return (
|
||||||
|
<div key={key} className="rb:w-50">
|
||||||
|
<AudioPlayer src={getFileUrl(file)} />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
const documentType = (file.file_type || file.type)?.split('/') ?? []
|
||||||
|
return (
|
||||||
|
<Flex
|
||||||
|
key={key}
|
||||||
|
align="center"
|
||||||
|
gap={10}
|
||||||
|
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
|
||||||
|
onClick={() => onDownload(file)}
|
||||||
|
>
|
||||||
|
<div
|
||||||
|
className={clsx(
|
||||||
|
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
|
||||||
|
getDocIcon(documentType)
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<div className="rb:flex-1 rb:w-32.5">
|
||||||
|
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
|
||||||
|
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">
|
||||||
|
{documentType?.[documentType.length - 1]} · {file.size}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Flex>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</Flex>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default MessageFiles
|
||||||
@@ -3,14 +3,14 @@ import { Popover, type PopoverProps } from 'antd'
|
|||||||
import Tag, { type TagProps } from '@/components/Tag'
|
import Tag, { type TagProps } from '@/components/Tag'
|
||||||
|
|
||||||
interface OverflowTagsProps {
|
interface OverflowTagsProps {
|
||||||
items: ReactNode[];
|
items?: ReactNode[];
|
||||||
gap?: number;
|
gap?: number;
|
||||||
numTagColor?: TagProps['color'];
|
numTagColor?: TagProps['color'];
|
||||||
numTag?: (num?: number) => ReactNode;
|
numTag?: (num?: number) => ReactNode;
|
||||||
popoverProps?: PopoverProps | false;
|
popoverProps?: PopoverProps | false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const OverflowTags = ({ items, gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => {
|
const OverflowTags = ({ items = [], gap = 8, numTagColor = 'default', numTag, popoverProps }: OverflowTagsProps) => {
|
||||||
const containerRef = useRef<HTMLDivElement>(null)
|
const containerRef = useRef<HTMLDivElement>(null)
|
||||||
const measureRef = useRef<HTMLDivElement>(null)
|
const measureRef = useRef<HTMLDivElement>(null)
|
||||||
const [visibleCount, setVisibleCount] = useState(items.length)
|
const [visibleCount, setVisibleCount] = useState(items.length)
|
||||||
@@ -20,7 +20,7 @@ const OverflowTags = ({ items, gap = 8, numTagColor = 'default', numTag, popover
|
|||||||
if (!measure || containerWidth === 0) return
|
if (!measure || containerWidth === 0) return
|
||||||
|
|
||||||
const children = Array.from(measure.children) as HTMLElement[]
|
const children = Array.from(measure.children) as HTMLElement[]
|
||||||
if (!children.length) return
|
if (!children.length) { setVisibleCount(0); return }
|
||||||
|
|
||||||
// last child is the sample +N tag
|
// last child is the sample +N tag
|
||||||
const extraTagWidth = (children[children.length - 1] as HTMLElement).offsetWidth
|
const extraTagWidth = (children[children.length - 1] as HTMLElement).offsetWidth
|
||||||
|
|||||||
@@ -399,7 +399,7 @@ const Menu: FC<{
|
|||||||
className="rb:overflow-y-auto rb:flex-1!"
|
className="rb:overflow-y-auto rb:flex-1!"
|
||||||
/>
|
/>
|
||||||
{/* Return to space button for superusers */}
|
{/* Return to space button for superusers */}
|
||||||
{user?.is_superuser && source === 'space' &&
|
{source === 'space' &&
|
||||||
<Flex gap={4} vertical className="rb:my-3! rb:mx-3!">
|
<Flex gap={4} vertical className="rb:my-3! rb:mx-3!">
|
||||||
<Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" />
|
<Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" />
|
||||||
<Flex
|
<Flex
|
||||||
@@ -412,16 +412,18 @@ const Menu: FC<{
|
|||||||
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div>
|
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div>
|
||||||
{collapsed ? null : t('common.switchSpace')}
|
{collapsed ? null : t('common.switchSpace')}
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex
|
{user?.is_superuser &&
|
||||||
gap={8}
|
<Flex
|
||||||
align="center"
|
gap={8}
|
||||||
justify="start"
|
align="center"
|
||||||
onClick={goToSpace}
|
justify="start"
|
||||||
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
|
onClick={goToSpace}
|
||||||
>
|
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
|
||||||
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
|
>
|
||||||
{collapsed ? null : t('common.returnToSpace')}
|
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
|
||||||
</Flex>
|
{collapsed ? null : t('common.returnToSpace')}
|
||||||
|
</Flex>
|
||||||
|
}
|
||||||
</Flex>
|
</Flex>
|
||||||
}
|
}
|
||||||
{source === 'manage' && subscription && !collapsed &&
|
{source === 'manage' && subscription && !collapsed &&
|
||||||
|
|||||||
@@ -1538,6 +1538,7 @@ export const en = {
|
|||||||
json_output: 'Support JSON formatted output',
|
json_output: 'Support JSON formatted output',
|
||||||
thinking_budget_tokens: 'thinking budget tokens',
|
thinking_budget_tokens: 'thinking budget tokens',
|
||||||
thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})",
|
thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})",
|
||||||
|
thinking_budget_tokens_min_error: "Cannot be less than {{min}}",
|
||||||
logSearchPlaceholder: 'Search log content',
|
logSearchPlaceholder: 'Search log content',
|
||||||
},
|
},
|
||||||
userMemory: {
|
userMemory: {
|
||||||
|
|||||||
@@ -868,6 +868,7 @@ export const zh = {
|
|||||||
json_output: '支持JSON格式化输出',
|
json_output: '支持JSON格式化输出',
|
||||||
thinking_budget_tokens: '深度思考预算Token数',
|
thinking_budget_tokens: '深度思考预算Token数',
|
||||||
thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})",
|
thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})",
|
||||||
|
thinking_budget_tokens_min_error: "不能小于 {{min}}",
|
||||||
logSearchPlaceholder: '搜索日志内容',
|
logSearchPlaceholder: '搜索日志内容',
|
||||||
},
|
},
|
||||||
table: {
|
table: {
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ const configFields = [
|
|||||||
{ key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 },
|
{ key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
const minThinkingBudgetTokens = 128;
|
||||||
|
const defaultThinkingBudgetTokens = 1000;
|
||||||
const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({
|
const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({
|
||||||
refresh,
|
refresh,
|
||||||
data,
|
data,
|
||||||
@@ -108,7 +110,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
const newValues: ModelConfig = {
|
const newValues: ModelConfig = {
|
||||||
capability: (option as Model).capability,
|
capability: (option as Model).capability,
|
||||||
deep_thinking: false,
|
deep_thinking: false,
|
||||||
thinking_budget_tokens: undefined,
|
thinking_budget_tokens: defaultThinkingBudgetTokens,
|
||||||
json_output: false,
|
json_output: false,
|
||||||
}
|
}
|
||||||
if (source === 'chat') {
|
if (source === 'chat') {
|
||||||
@@ -128,6 +130,12 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
form.setFieldsValue({ ...rest })
|
form.setFieldsValue({ ...rest })
|
||||||
}, [data?.default_model_config_id])
|
}, [data?.default_model_config_id])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (values?.deep_thinking && !values?.thinking_budget_tokens) {
|
||||||
|
form.setFieldValue('thinking_budget_tokens', defaultThinkingBudgetTokens)
|
||||||
|
}
|
||||||
|
}, [values?.deep_thinking])
|
||||||
|
|
||||||
const handleReset = () => {
|
const handleReset = () => {
|
||||||
if (!id) return
|
if (!id) return
|
||||||
resetAppModelConfig(id).then((res) => {
|
resetAppModelConfig(id).then((res) => {
|
||||||
@@ -178,15 +186,20 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
name="thinking_budget_tokens"
|
name="thinking_budget_tokens"
|
||||||
label={t('application.thinking_budget_tokens')}
|
label={t('application.thinking_budget_tokens')}
|
||||||
hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))}
|
hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))}
|
||||||
extra={<>{t('application.range')}: [{0}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
|
extra={<>{t('application.range')}: [{minThinkingBudgetTokens}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
|
||||||
rules={[
|
rules={[
|
||||||
{ required: values?.deep_thinking, message: t('common.pleaseEnter') },
|
{ required: values?.deep_thinking, message: t('common.pleaseEnter') },
|
||||||
{
|
{
|
||||||
validator: (_, value) => {
|
validator: (_, value) => {
|
||||||
const maxTokens = values?.max_tokens
|
const maxTokens = values?.max_tokens
|
||||||
const deep_thinking = values?.deep_thinking;
|
const deep_thinking = values?.deep_thinking;
|
||||||
if (deep_thinking && value !== undefined && maxTokens !== undefined && value > maxTokens) {
|
if (deep_thinking && value !== undefined) {
|
||||||
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
|
if (value < minThinkingBudgetTokens) {
|
||||||
|
return Promise.reject(t('application.thinking_budget_tokens_min_error', { min: minThinkingBudgetTokens }))
|
||||||
|
}
|
||||||
|
if (maxTokens !== undefined && value > maxTokens) {
|
||||||
|
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return Promise.resolve()
|
return Promise.resolve()
|
||||||
}
|
}
|
||||||
@@ -195,7 +208,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
>
|
>
|
||||||
<RbSlider
|
<RbSlider
|
||||||
step={1}
|
step={1}
|
||||||
min={0}
|
min={minThinkingBudgetTokens}
|
||||||
max={32000}
|
max={32000}
|
||||||
isInput={true}
|
isInput={true}
|
||||||
disabled={!values?.deep_thinking}
|
disabled={!values?.deep_thinking}
|
||||||
|
|||||||
@@ -166,10 +166,10 @@ const Ontology: FC = () => {
|
|||||||
<div className="rb:h-10 rb:wrap-break-word rb:line-clamp-2 rb:leading-5">{item.scene_description}</div>
|
<div className="rb:h-10 rb:wrap-break-word rb:line-clamp-2 rb:leading-5">{item.scene_description}</div>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
|
|
||||||
<div className="rb:mt-2">
|
<div className="rb:mt-2 rb:h-5.5">
|
||||||
<OverflowTags
|
<OverflowTags
|
||||||
popoverProps={false}
|
popoverProps={false}
|
||||||
items={[...item.entity_type?.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>]}
|
items={item.entity_type ? [...item.entity_type.map((type, i) => <Tag key={i} variant="borderless" color="dark">{type}</Tag>), <Tag variant="borderless" color="dark">{`+${item.type_num - 3}`}</Tag>] : []}
|
||||||
numTag={(num?: number) => <Tag variant="borderless" color="dark">{`+${item.type_num - 3 + (num ? num - 1 : 0)}`}</Tag>}
|
numTag={(num?: number) => <Tag variant="borderless" color="dark">{`+${item.type_num - 3 + (num ? num - 1 : 0)}`}</Tag>}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ const CustomToolModal = forwardRef<CustomToolModalRef, CustomToolModalProps>(({
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
const formatSchema = (value: string) => {
|
const formatSchema = (value: string) => {
|
||||||
|
if (!value || value.trim() === '') return
|
||||||
setParseSchemaData({} as ParseSchemaData)
|
setParseSchemaData({} as ParseSchemaData)
|
||||||
parseSchema({ schema_content: value })
|
parseSchema({ schema_content: value })
|
||||||
.then(res => {
|
.then(res => {
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
|||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
labelRender={(props) => {
|
labelRender={(props) => {
|
||||||
console.log('props', props)
|
|
||||||
return `${props.value}%`
|
return `${props.value}%`
|
||||||
}}
|
}}
|
||||||
className="rb:w-20 rb:h-4!"
|
className="rb:w-20 rb:h-4!"
|
||||||
|
|||||||
@@ -66,8 +66,6 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
const [fileList, setFileList] = useState<any[]>([])
|
const [fileList, setFileList] = useState<any[]>([])
|
||||||
const [message, setMessage] = useState<string | undefined>(undefined)
|
const [message, setMessage] = useState<string | undefined>(undefined)
|
||||||
|
|
||||||
console.log('abortRef', abortRef, chatList)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Opens the chat drawer and loads workflow variables from the start node
|
* Opens the chat drawer and loads workflow variables from the start node
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
|||||||
|
|
||||||
// Handle node selection from popover and create new node replacing the add-node placeholder
|
// Handle node selection from popover and create new node replacing the add-node placeholder
|
||||||
const handleNodeSelect = (selectedNodeType: any) => {
|
const handleNodeSelect = (selectedNodeType: any) => {
|
||||||
|
graph.startBatch('add-node');
|
||||||
const parentBBox = node.getBBox();
|
const parentBBox = node.getBBox();
|
||||||
const cycleId = data.cycle;
|
const cycleId = data.cycle;
|
||||||
const horizontalSpacing = 0;
|
const horizontalSpacing = 0;
|
||||||
@@ -43,7 +44,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
|||||||
if (cycleId) {
|
if (cycleId) {
|
||||||
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
parentNode.addChild(newNode);
|
parentNode.addChild(newNode, { silent: true });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,55 +77,40 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
setTimeout(() => {
|
|
||||||
addedEdges.forEach(e => {
|
|
||||||
const src = graph.getCellById(e.getSourceCellId());
|
|
||||||
const tgt = graph.getCellById(e.getTargetCellId());
|
|
||||||
if (src?.isNode()) src.toFront();
|
|
||||||
if (tgt?.isNode()) tgt.toFront();
|
|
||||||
});
|
|
||||||
}, 50);
|
|
||||||
|
|
||||||
// Automatically adjust loop node size
|
// Automatically adjust loop node size
|
||||||
const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
||||||
if (loopNode) {
|
if (loopNode) {
|
||||||
const adjustLoopSize = () => {
|
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
|
||||||
if (childNodes.length > 0) {
|
|
||||||
const bounds = childNodes.reduce((acc, child) => {
|
|
||||||
const bbox = child.getBBox();
|
|
||||||
return {
|
|
||||||
minX: Math.min(acc.minX, bbox.x),
|
|
||||||
minY: Math.min(acc.minY, bbox.y),
|
|
||||||
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
|
|
||||||
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
|
|
||||||
};
|
|
||||||
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
|
||||||
|
|
||||||
const padding = 50;
|
|
||||||
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
|
||||||
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
|
||||||
|
|
||||||
loopNode.prop('size', { width: newWidth, height: newHeight });
|
|
||||||
|
|
||||||
// Update right port x position
|
|
||||||
const ports = loopNode.getPorts();
|
|
||||||
ports.forEach(port => {
|
|
||||||
if (port.group === 'right' && port.args) {
|
|
||||||
loopNode.portProp(port.id!, 'args/x', newWidth);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
adjustLoopSize();
|
|
||||||
|
|
||||||
// Listen to child node movement events
|
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
||||||
childNodes.forEach((childNode: any) => {
|
if (childNodes.length > 0) {
|
||||||
childNode.on('change:position', adjustLoopSize);
|
const bounds = childNodes.reduce((acc, child) => {
|
||||||
});
|
const bbox = child.getBBox();
|
||||||
|
return {
|
||||||
|
minX: Math.min(acc.minX, bbox.x),
|
||||||
|
minY: Math.min(acc.minY, bbox.y),
|
||||||
|
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
|
||||||
|
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
|
||||||
|
};
|
||||||
|
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
||||||
|
const padding = 50;
|
||||||
|
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
||||||
|
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
||||||
|
loopNode.prop('size', { width: newWidth, height: newHeight });
|
||||||
|
loopNode.getPorts().forEach(port => {
|
||||||
|
if (port.group === 'right' && port.args) {
|
||||||
|
loopNode.portProp(port.id!, 'args/x', newWidth);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addedEdges.forEach(e => {
|
||||||
|
const src = graph.getCellById(e.getSourceCellId());
|
||||||
|
const tgt = graph.getCellById(e.getTargetCellId());
|
||||||
|
if (src?.isNode()) src.toFront();
|
||||||
|
if (tgt?.isNode()) tgt.toFront();
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.stopBatch('add-node');
|
||||||
setOpen(false);
|
setOpen(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
|
|||||||
{data.type === 'if-else' &&
|
{data.type === 'if-else' &&
|
||||||
<Flex vertical gap={4} className="rb:mt-3!">
|
<Flex vertical gap={4} className="rb:mt-3!">
|
||||||
{data.config?.cases?.defaultValue.map((item: any, index: number) => (
|
{data.config?.cases?.defaultValue.map((item: any, index: number) => (
|
||||||
<div key={index} className={item.expressions.length > 0 ? '' : 'rb:mb-1'}>
|
<div key={index}>
|
||||||
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
|
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
|
||||||
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
|
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
|
||||||
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>
|
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>
|
||||||
|
|||||||
@@ -1,134 +1,15 @@
|
|||||||
import { useEffect } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next'
|
|
||||||
import clsx from 'clsx';
|
import clsx from 'clsx';
|
||||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||||
import { Flex } from 'antd';
|
import { Flex } from 'antd';
|
||||||
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons';
|
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons';
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
import { graphNodeLibrary, edgeAttrs } from '../../constant';
|
|
||||||
import NodeTools from './NodeTools'
|
import NodeTools from './NodeTools'
|
||||||
|
|
||||||
const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
const LoopNode: ReactShapeConfig['component'] = ({ node }) => {
|
||||||
const data = node.getData() || {};
|
const data = node.getData() || {};
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
// 使用setTimeout确保在所有节点都添加完成后再创建连线
|
|
||||||
const timer = setTimeout(() => {
|
|
||||||
initNodes()
|
|
||||||
checkAndAddAddNode()
|
|
||||||
}, 50)
|
|
||||||
|
|
||||||
return () => clearTimeout(timer)
|
|
||||||
}, [graph])
|
|
||||||
|
|
||||||
const checkAndAddAddNode = () => {
|
|
||||||
if (!graph) return;
|
|
||||||
|
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === data.id);
|
|
||||||
const cycleStartNodes = childNodes.filter((n: any) => n.getData()?.type === 'cycle-start');
|
|
||||||
|
|
||||||
// 如果只有一个cycle-start节点且没有其他类型的子节点,则添加add-node
|
|
||||||
if (cycleStartNodes.length === 1 && childNodes.length === 1) {
|
|
||||||
const cycleStartNode = cycleStartNodes[0];
|
|
||||||
const cycleStartBBox = cycleStartNode.getBBox();
|
|
||||||
|
|
||||||
const addNode = graph.addNode({
|
|
||||||
...graphNodeLibrary.addStart,
|
|
||||||
x: cycleStartBBox.x + 84,
|
|
||||||
y: cycleStartBBox.y + 4,
|
|
||||||
data: {
|
|
||||||
type: 'add-node',
|
|
||||||
label: t('workflow.addNode'),
|
|
||||||
icon: '+',
|
|
||||||
parentId: node.id,
|
|
||||||
cycle: data.id,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
node.addChild(addNode);
|
|
||||||
|
|
||||||
// 连接cycle-start和add-node
|
|
||||||
const sourcePorts = cycleStartNode.getPorts();
|
|
||||||
const targetPorts = addNode.getPorts();
|
|
||||||
const sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
|
|
||||||
const targetPort = targetPorts.find((port: any) => port.group === 'left')?.id || 'left';
|
|
||||||
|
|
||||||
// 然后创建连线
|
|
||||||
graph.addEdge({
|
|
||||||
source: { cell: cycleStartNode.id, port: sourcePort },
|
|
||||||
target: { cell: addNode.id, port: targetPort },
|
|
||||||
...edgeAttrs,
|
|
||||||
});
|
|
||||||
|
|
||||||
cycleStartNode.toFront()
|
|
||||||
addNode.toFront()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const initNodes = () => {
|
|
||||||
// 检查是否存在cycle为当前节点ID的子节点,若存在则不调用initNodes,避免重复创建
|
|
||||||
const existingCycleNodes = graph.getNodes().filter((n: any) =>
|
|
||||||
n.getData()?.cycle === data.id
|
|
||||||
);
|
|
||||||
if (existingCycleNodes.length > 0) return;
|
|
||||||
// 添加默认子节点
|
|
||||||
const parentBBox = node.getBBox();
|
|
||||||
const centerX = parentBBox.x + 24;
|
|
||||||
const centerY = parentBBox.y + 70;
|
|
||||||
|
|
||||||
const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
|
||||||
const cycleStartNode = graph.addNode({
|
|
||||||
...graphNodeLibrary.cycleStart,
|
|
||||||
x: centerX,
|
|
||||||
y: centerY,
|
|
||||||
id: cycleStartNodeId,
|
|
||||||
data: {
|
|
||||||
id: cycleStartNodeId,
|
|
||||||
type: 'cycle-start',
|
|
||||||
parentId: node.id,
|
|
||||||
isDefault: true, // 标记为默认节点,不可删除
|
|
||||||
cycle: data.id,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
const addNode = graph.addNode({
|
|
||||||
...graphNodeLibrary.addStart,
|
|
||||||
x: centerX + 84,
|
|
||||||
y: centerY + 4,
|
|
||||||
data: {
|
|
||||||
type: 'add-node',
|
|
||||||
label: t('workflow.addNode'),
|
|
||||||
icon: '+',
|
|
||||||
parentId: node.id,
|
|
||||||
cycle: data.id,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
node.addChild(cycleStartNode)
|
|
||||||
node.addChild(addNode)
|
|
||||||
const sourcePorts = cycleStartNode.getPorts()
|
|
||||||
const targetPorts = addNode.getPorts()
|
|
||||||
let sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
|
|
||||||
|
|
||||||
const edgeConfig = {
|
|
||||||
source: {
|
|
||||||
cell: cycleStartNode.id,
|
|
||||||
port: sourcePort
|
|
||||||
},
|
|
||||||
target: {
|
|
||||||
cell: addNode.id,
|
|
||||||
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
|
|
||||||
},
|
|
||||||
...edgeAttrs
|
|
||||||
}
|
|
||||||
graph.addEdge(edgeConfig)
|
|
||||||
|
|
||||||
setTimeout(() => {
|
|
||||||
|
|
||||||
cycleStartNode.toFront()
|
|
||||||
addNode.toFront()
|
|
||||||
}, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
||||||
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
|
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
|
||||||
|
|||||||
@@ -43,70 +43,52 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
};
|
};
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
// Handle node selection from popover menu and create new node with edge connection
|
|
||||||
const handleNodeSelect = (selectedNodeType: any) => {
|
const handleNodeSelect = (selectedNodeType: any) => {
|
||||||
if (!sourceNode || !graph) return;
|
if (!sourceNode || !graph) return;
|
||||||
|
|
||||||
const sourceNodeData = sourceNode.getData();
|
const sourceNodeData = sourceNode.getData();
|
||||||
const sourceNodeType = sourceNodeData?.type;
|
const sourceNodeType = sourceNodeData?.type;
|
||||||
|
const isCycleSubNode = !!sourceNodeData.cycle;
|
||||||
// If it's a cycle-start node, handle the add-node placeholder
|
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
|
||||||
|
const newNodeType = selectedNodeType.type;
|
||||||
|
|
||||||
|
// Save add-node placeholder position before disabling history
|
||||||
let addNodePosition = null;
|
let addNodePosition = null;
|
||||||
const isCycleSubNode = sourceNodeData.cycle
|
|
||||||
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
|
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
|
||||||
const cycleId = sourceNodeData.cycle;
|
const cycleId = sourceNodeData.cycle;
|
||||||
const addNodes = graph.getNodes().filter((n: any) =>
|
const addNodes = graph.getNodes().filter((n: any) =>
|
||||||
n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId
|
n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId
|
||||||
);
|
);
|
||||||
|
if (addNodes.length > 0) addNodePosition = addNodes[0].getBBox();
|
||||||
if (addNodes.length > 0) {
|
|
||||||
const addNode = addNodes[0];
|
|
||||||
addNodePosition = addNode.getBBox();
|
|
||||||
addNode.remove();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate new node position to avoid overlapping
|
// Calculate position
|
||||||
const sourceBBox = sourceNode.getBBox();
|
const sourceBBox = sourceNode.getBBox();
|
||||||
const nodeWidth = graphNodeLibrary[selectedNodeType.type]?.width || 120;
|
const nw = graphNodeLibrary[newNodeType]?.width || 120;
|
||||||
const nodeHeight = graphNodeLibrary[selectedNodeType.type]?.height || 88;
|
const nh = graphNodeLibrary[newNodeType]?.height || 88;
|
||||||
const horizontalSpacing = isCycleSubNode ? 48 : 80;
|
const hSpacing = isCycleSubNode ? 48 : 80;
|
||||||
const verticalSpacing = 10;
|
const vSpacing = 10;
|
||||||
|
|
||||||
// Get source port group information
|
|
||||||
const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort);
|
const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort);
|
||||||
const sourcePortGroup = sourcePortInfo?.group || sourcePort;
|
const sourcePortGroup = sourcePortInfo?.group || sourcePort;
|
||||||
|
|
||||||
// Calculate new node position
|
let newX: number, newY: number;
|
||||||
let newX, newY;
|
|
||||||
if (edgeInsertion) {
|
if (edgeInsertion) {
|
||||||
// Edge insertion: place new node on the same row as target, between source and target
|
|
||||||
const targetBBox = edgeInsertion.targetCell.getBBox();
|
const targetBBox = edgeInsertion.targetCell.getBBox();
|
||||||
const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width);
|
const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width);
|
||||||
const requiredSpace = nodeWidth + horizontalSpacing * 4;
|
const requiredSpace = nw + hSpacing * 4;
|
||||||
|
newX = sourceBBox.x + sourceBBox.width + hSpacing;
|
||||||
// New node x: right after source + spacing
|
newY = targetBBox.y + (targetBBox.height - nh) / 2;
|
||||||
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
|
|
||||||
// Same row as target node
|
|
||||||
newY = targetBBox.y + (targetBBox.height - nodeHeight) / 2;
|
|
||||||
|
|
||||||
// If not enough space, shift target and all downstream nodes to the right
|
|
||||||
if (gap < requiredSpace) {
|
if (gap < requiredSpace) {
|
||||||
const shiftX = requiredSpace - gap;
|
const shiftX = requiredSpace - gap;
|
||||||
const visited = new Set<string>();
|
const visited = new Set<string>();
|
||||||
const shiftDownstream = (cell: any) => {
|
const shiftDownstream = (cell: any) => {
|
||||||
const cellId = cell.id;
|
if (visited.has(cell.id)) return;
|
||||||
if (visited.has(cellId)) return;
|
visited.add(cell.id);
|
||||||
visited.add(cellId);
|
|
||||||
const pos = cell.getPosition();
|
const pos = cell.getPosition();
|
||||||
cell.setPosition(pos.x + shiftX, pos.y);
|
cell.setPosition(pos.x + shiftX, pos.y);
|
||||||
// Recursively shift nodes connected from right ports
|
|
||||||
graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => {
|
graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => {
|
||||||
const tId = e.getTargetCellId();
|
const tCell = graph.getCellById(e.getTargetCellId());
|
||||||
if (tId && !visited.has(tId)) {
|
if (tCell?.isNode()) shiftDownstream(tCell);
|
||||||
const tCell = graph.getCellById(tId);
|
|
||||||
if (tCell?.isNode()) shiftDownstream(tCell);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
shiftDownstream(edgeInsertion.targetCell);
|
shiftDownstream(edgeInsertion.targetCell);
|
||||||
@@ -114,208 +96,170 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
} else if (addNodePosition) {
|
} else if (addNodePosition) {
|
||||||
newX = addNodePosition.x;
|
newX = addNodePosition.x;
|
||||||
newY = addNodePosition.y;
|
newY = addNodePosition.y;
|
||||||
|
} else if (sourcePortGroup === 'left') {
|
||||||
|
newX = sourceBBox.x - nw * 2 - hSpacing;
|
||||||
|
newY = sourceBBox.y;
|
||||||
} else {
|
} else {
|
||||||
// Determine node placement direction based on port position
|
newX = sourceBBox.x + sourceBBox.width + hSpacing;
|
||||||
if (sourcePortGroup === 'left') {
|
newY = sourceBBox.y;
|
||||||
// Left port: add node to the left
|
const connectedNodes = new Set<string>();
|
||||||
newX = sourceBBox.x - nodeWidth*2 - horizontalSpacing;
|
graph.getConnectedEdges(sourceNode).forEach((e: any) => {
|
||||||
newY = sourceBBox.y;
|
[e.getSourceCellId(), e.getTargetCellId()].forEach((cid: string) => {
|
||||||
} else {
|
if (cid !== sourceNode.id) connectedNodes.add(cid);
|
||||||
// Right port: add node to the right
|
|
||||||
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
|
|
||||||
newY = sourceBBox.y;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if position overlaps with existing nodes (only consider connected nodes)
|
|
||||||
const checkOverlap = (x: number, y: number) => {
|
|
||||||
// Get nodes connected to the source node
|
|
||||||
const connectedNodes = new Set();
|
|
||||||
graph.getConnectedEdges(sourceNode).forEach((edge: any) => {
|
|
||||||
const sourceId = edge.getSourceCellId();
|
|
||||||
const targetId = edge.getTargetCellId();
|
|
||||||
if (sourceId !== sourceNode.id) connectedNodes.add(sourceId);
|
|
||||||
if (targetId !== sourceNode.id) connectedNodes.add(targetId);
|
|
||||||
});
|
});
|
||||||
|
});
|
||||||
return graph.getNodes().some((node: any) => {
|
const checkOverlap = (x: number, y: number) =>
|
||||||
if (node.id === sourceNode.id) return false;
|
graph.getNodes().some((n: any) => {
|
||||||
if (!connectedNodes.has(node.id)) return false; // Only consider connected nodes
|
if (n.id === sourceNode.id || !connectedNodes.has(n.id)) return false;
|
||||||
const bbox = node.getBBox();
|
const b = n.getBBox();
|
||||||
return !(x + nodeWidth < bbox.x || x > bbox.x + bbox.width ||
|
return !(x + nw < b.x || x > b.x + b.width || y + nh < b.y || y > b.y + b.height);
|
||||||
y + nodeHeight < bbox.y || y > bbox.y + bbox.height);
|
|
||||||
});
|
});
|
||||||
};
|
while (checkOverlap(newX, newY)) newY += nh + vSpacing;
|
||||||
|
|
||||||
// If position is occupied, search downward for empty space
|
|
||||||
while (checkOverlap(newX, newY)) {
|
|
||||||
newY += nodeHeight + verticalSpacing;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new node
|
// Disable history for all graph mutations
|
||||||
const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
graph.disableHistory();
|
||||||
|
|
||||||
|
// Remove add-node placeholder
|
||||||
|
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
|
||||||
|
const cycleId = sourceNodeData.cycle;
|
||||||
|
graph.getNodes()
|
||||||
|
.filter((n: any) => n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId)
|
||||||
|
.forEach((n: any) => n.remove());
|
||||||
|
}
|
||||||
|
|
||||||
|
const id = `${newNodeType.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
||||||
const newNode = graph.addNode({
|
const newNode = graph.addNode({
|
||||||
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
|
...(graphNodeLibrary[newNodeType] || graphNodeLibrary.default),
|
||||||
x: newX,
|
x: newX,
|
||||||
y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0),
|
y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0),
|
||||||
id,
|
id,
|
||||||
data: {
|
data: {
|
||||||
id,
|
id,
|
||||||
type: selectedNodeType.type,
|
type: newNodeType,
|
||||||
icon: selectedNodeType.icon,
|
icon: selectedNodeType.icon,
|
||||||
name: t(`workflow.${selectedNodeType.type}`),
|
name: t(`workflow.${newNodeType}`),
|
||||||
cycle: sourceNodeData.cycle, // Inherit cycle from source node
|
cycle: sourceNodeData.cycle,
|
||||||
config: selectedNodeType.config || {}
|
config: selectedNodeType.config || {}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add new node as child of parent node
|
|
||||||
if (sourceNodeData.cycle) {
|
if (sourceNodeData.cycle) {
|
||||||
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle);
|
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle);
|
||||||
if (parentNode) {
|
if (parentNode) parentNode.addChild(newNode, { silent: true });
|
||||||
parentNode.addChild(newNode);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Edge insertion: remove old edge immediately before creating new edges
|
|
||||||
if (edgeInsertion) {
|
if (edgeInsertion) {
|
||||||
const { edge: oldEdge } = edgeInsertion;
|
const { edge: oldEdge } = edgeInsertion;
|
||||||
if (oldEdge.id && graph.getCellById(oldEdge.id)) {
|
if (oldEdge.id && graph.getCellById(oldEdge.id)) graph.removeCell(oldEdge.id);
|
||||||
graph.removeCell(oldEdge.id);
|
else graph.removeEdge(oldEdge);
|
||||||
} else {
|
|
||||||
graph.removeEdge(oldEdge);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create edge connection
|
const newPorts = newNode.getPorts();
|
||||||
setTimeout(() => {
|
const addedCells: any[] = [newNode];
|
||||||
const newPorts = newNode.getPorts();
|
|
||||||
|
|
||||||
const addedEdges: any[] = [];
|
if (edgeInsertion) {
|
||||||
if (edgeInsertion) {
|
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
|
||||||
// Edge insertion: create source→new and new→target edges
|
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
|
||||||
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
|
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
|
||||||
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
|
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: newLeftPort }, ...edgeAttrs }));
|
||||||
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
|
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: newRightPort }, target: { cell: targetCell.id, port: origTargetPort }, ...edgeAttrs }));
|
||||||
addedEdges.push(graph.addEdge({
|
setEdgeInsertion(null);
|
||||||
source: { cell: sourceNode.id, port: sourcePort },
|
} else if (sourcePortGroup === 'left') {
|
||||||
target: { cell: newNode.id, port: newLeftPort },
|
const tp = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
|
||||||
...edgeAttrs
|
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: tp }, target: { cell: sourceNode.id, port: sourcePort }, ...edgeAttrs }));
|
||||||
}));
|
} else {
|
||||||
addedEdges.push(graph.addEdge({
|
const tp = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
|
||||||
source: { cell: newNode.id, port: newRightPort },
|
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: tp }, ...edgeAttrs }));
|
||||||
target: { cell: targetCell.id, port: origTargetPort },
|
}
|
||||||
...edgeAttrs
|
|
||||||
}));
|
|
||||||
setEdgeInsertion(null);
|
|
||||||
} else if (sourcePortGroup === 'left') {
|
|
||||||
// Connect from left port to new node's right side
|
|
||||||
const targetPort = newPorts.find((port: any) => port.group === 'right')?.id || 'right';
|
|
||||||
addedEdges.push(graph.addEdge({
|
|
||||||
source: { cell: newNode.id, port: targetPort },
|
|
||||||
target: { cell: sourceNode.id, port: sourcePort },
|
|
||||||
...edgeAttrs
|
|
||||||
}));
|
|
||||||
} else {
|
|
||||||
// Connect from right port to new node's left side
|
|
||||||
const targetPort = newPorts.find((port: any) => port.group === 'left')?.id || 'left';
|
|
||||||
addedEdges.push(graph.addEdge({
|
|
||||||
source: { cell: sourceNode.id, port: sourcePort },
|
|
||||||
target: { cell: newNode.id, port: targetPort },
|
|
||||||
...edgeAttrs
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adjust loop node size when child node is added via port within loop node
|
|
||||||
const cycleId = sourceNodeData.cycle;
|
|
||||||
if (cycleId) {
|
|
||||||
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
|
||||||
|
|
||||||
if (parentNode) {
|
// If adding a loop/iteration node, create cycle-start, add-node and inner edge regardless of source type
|
||||||
const adjustLoopSize = () => {
|
if (isCycleContainer(newNodeType)) {
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
const parentBBox = newNode.getBBox();
|
||||||
if (childNodes.length > 0) {
|
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
||||||
const bounds = childNodes.reduce((acc: any, child: any) => {
|
const cycleStartNode = graph.addNode({
|
||||||
const bbox = child.getBBox();
|
...graphNodeLibrary.cycleStart,
|
||||||
return {
|
x: parentBBox.x + 24,
|
||||||
minX: Math.min(acc.minX, bbox.x),
|
y: parentBBox.y + 70,
|
||||||
minY: Math.min(acc.minY, bbox.y),
|
id: cycleStartId,
|
||||||
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
|
data: { id: cycleStartId, type: 'cycle-start', parentId: id, isDefault: true, cycle: id },
|
||||||
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
|
});
|
||||||
};
|
const addNodePlaceholder = graph.addNode({
|
||||||
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
...graphNodeLibrary.addStart,
|
||||||
|
x: parentBBox.x + 24 + 84,
|
||||||
|
y: parentBBox.y + 70 + 4,
|
||||||
|
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: id, cycle: id },
|
||||||
|
});
|
||||||
|
newNode.addChild(cycleStartNode, { silent: true });
|
||||||
|
newNode.addChild(addNodePlaceholder, { silent: true });
|
||||||
|
const innerEdge = graph.addEdge({
|
||||||
|
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find((p: any) => p.group === 'right')?.id || 'right' },
|
||||||
|
target: { cell: addNodePlaceholder.id, port: addNodePlaceholder.getPorts().find((p: any) => p.group === 'left')?.id || 'left' },
|
||||||
|
...edgeAttrs,
|
||||||
|
});
|
||||||
|
addedCells.push(cycleStartNode, addNodePlaceholder, innerEdge);
|
||||||
|
}
|
||||||
|
|
||||||
const padding = 50;
|
// Adjust parent size if adding inside a cycle container
|
||||||
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
const cycleId = sourceNodeData.cycle;
|
||||||
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
if (cycleId) {
|
||||||
|
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
||||||
parentNode.prop('size', { width: newWidth, height: newHeight });
|
if (parentNode) {
|
||||||
|
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
||||||
// Update right port x position
|
if (childNodes.length > 0) {
|
||||||
const ports = parentNode.getPorts();
|
const bounds = childNodes.reduce((acc: any, child: any) => {
|
||||||
ports.forEach((port: any) => {
|
const b = child.getBBox();
|
||||||
if (port.group === 'right' && port.args) {
|
return { minX: Math.min(acc.minX, b.x), minY: Math.min(acc.minY, b.y), maxX: Math.max(acc.maxX, b.x + b.width), maxY: Math.max(acc.maxY, b.y + b.height) };
|
||||||
parentNode.portProp(port.id!, 'args/x', newWidth);
|
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
||||||
}
|
const padding = 50;
|
||||||
});
|
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
||||||
}
|
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
||||||
};
|
parentNode.prop('size', { width: newWidth, height: newHeight });
|
||||||
|
parentNode.getPorts().forEach((port: any) => {
|
||||||
adjustLoopSize();
|
if (port.group === 'right' && port.args) parentNode.portProp(port.id!, 'args/x', newWidth);
|
||||||
|
|
||||||
// Listen to child node movement events
|
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
|
||||||
childNodes.forEach((childNode: any) => {
|
|
||||||
childNode.on('change:position', adjustLoopSize);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
|
// toFront
|
||||||
const newNodeType = selectedNodeType.type;
|
const bringCycleChildrenToFront = (cycleContainerId: string) => {
|
||||||
|
graph.getEdges().forEach((e: any) => {
|
||||||
|
const src = graph.getCellById(e.getSourceCellId());
|
||||||
|
const tgt = graph.getCellById(e.getTargetCellId());
|
||||||
|
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
|
||||||
|
});
|
||||||
|
graph.getNodes().forEach((n: any) => { if (n.getData()?.cycle === cycleContainerId) n.toFront(); });
|
||||||
|
};
|
||||||
|
|
||||||
// Helper: bring all child nodes and their edges of a cycle container to front
|
if (isCycleContainer(sourceNodeType)) {
|
||||||
const bringCycleChildrenToFront = (cycleContainerId: string) => {
|
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(sourceNodeData.id);
|
||||||
|
if (isCycleContainer(newNodeType)) bringCycleChildrenToFront(id);
|
||||||
graph.getEdges().forEach((e: any) => {
|
} else if (isCycleContainer(newNodeType)) {
|
||||||
const src = graph.getCellById(e.getSourceCellId());
|
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(id);
|
||||||
const tgt = graph.getCellById(e.getTargetCellId());
|
} else {
|
||||||
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
|
addedCells.forEach(c => { if (c.isNode?.()) c.toFront(); });
|
||||||
});
|
}
|
||||||
graph.getNodes().forEach((n: any) => {
|
|
||||||
if (n.getData()?.cycle === cycleContainerId) n.toFront();
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
if (isCycleContainer(sourceNodeType)) {
|
// Re-enable history and manually push one batch frame for all added cells
|
||||||
console.log('isCycleContainer(sourceNodeType)')
|
graph.enableHistory();
|
||||||
// Case 4: source is a loop/iteration node — bring new node to front, then its children
|
const history = graph.getPlugin('history') as any;
|
||||||
newNode.toFront();
|
if (history) {
|
||||||
sourceNode.toFront();
|
const batchFrame = addedCells.map((cell: any) => ({
|
||||||
bringCycleChildrenToFront(sourceNodeData.id);
|
batch: true,
|
||||||
} else if (isCycleContainer(newNodeType)) {
|
event: 'cell:added',
|
||||||
console.log('isCycleContainer(newNodeType)')
|
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
|
||||||
// Case 3: adding a loop/iteration node from a normal node — bring new node to front, then its children
|
options: {},
|
||||||
newNode.toFront();
|
}));
|
||||||
sourceNode.toFront()
|
history.undoStack.push(batchFrame);
|
||||||
bringCycleChildrenToFront(id);
|
history.redoStack = [];
|
||||||
} else {
|
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-node' } });
|
||||||
// Case 2: normal node → normal node
|
}
|
||||||
addedEdges.forEach(e => {
|
|
||||||
const src = graph.getCellById(e.getSourceCellId());
|
|
||||||
const tgt = graph.getCellById(e.getTargetCellId());
|
|
||||||
if (src?.isNode()) src.toFront();
|
|
||||||
if (tgt?.isNode()) tgt.toFront();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, 50);
|
|
||||||
|
|
||||||
// Clean up temporary element
|
|
||||||
if (tempElement) {
|
if (tempElement) {
|
||||||
document.body.removeChild(tempElement);
|
document.body.removeChild(tempElement);
|
||||||
setTempElement(null);
|
setTempElement(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
setPopoverVisible(false);
|
setPopoverVisible(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -391,4 +335,4 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default PortClickHandler;
|
export default PortClickHandler;
|
||||||
|
|||||||
@@ -242,10 +242,11 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
|
|||||||
className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''}
|
className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''}
|
||||||
>
|
>
|
||||||
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
|
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
|
||||||
? <Select size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
|
? <Select key={values.tool_id} size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
|
||||||
: parameter.type === 'boolean'
|
: parameter.type === 'boolean'
|
||||||
? <Switch size="small" />
|
? <Switch key={values.tool_id} size="small" />
|
||||||
: <Editor
|
: <Editor
|
||||||
|
key={values.tool_id}
|
||||||
variant="outlined"
|
variant="outlined"
|
||||||
type="input"
|
type="input"
|
||||||
size="small"
|
size="small"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 15:06:18
|
* @Date: 2026-02-03 15:06:18
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-21 18:23:31
|
* @Last Modified time: 2026-04-27 14:07:14
|
||||||
*/
|
*/
|
||||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||||
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
|
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
|
||||||
@@ -948,6 +948,15 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
|
|||||||
width: nodeWidth,
|
width: nodeWidth,
|
||||||
height: 120,
|
height: 120,
|
||||||
shape: 'notes-node',
|
shape: 'notes-node',
|
||||||
|
},
|
||||||
|
output: {
|
||||||
|
width: nodeWidth,
|
||||||
|
height: 76,
|
||||||
|
shape: 'normal-node',
|
||||||
|
ports: {
|
||||||
|
groups: { left: defaultPortGroup },
|
||||||
|
items: [defaultPortItems[0]],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,9 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 15:17:48
|
* @Date: 2026-02-03 15:17:48
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-24 17:21:09
|
* @Last Modified time: 2026-04-28 13:49:11
|
||||||
*/
|
*/
|
||||||
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
|
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
|
||||||
import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type';
|
|
||||||
import { register } from '@antv/x6-react-shape';
|
import { register } from '@antv/x6-react-shape';
|
||||||
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
||||||
import { App } from 'antd';
|
import { App } from 'antd';
|
||||||
@@ -17,7 +16,7 @@ import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application';
|
|||||||
import { useUser } from '@/store/user';
|
import { useUser } from '@/store/user';
|
||||||
import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types';
|
import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types';
|
||||||
import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant';
|
import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant';
|
||||||
import type { ChatVariable, NodeProperties, WorkflowConfig } from '../types';
|
import type { ChatVariable, HistoryRecord, NodeProperties, WorkflowConfig } from '../types';
|
||||||
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
|
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
|
||||||
import { useWorkflowStore } from '@/store/workflow';
|
import { useWorkflowStore } from '@/store/workflow';
|
||||||
|
|
||||||
@@ -86,6 +85,10 @@ export interface UseWorkflowGraphReturn {
|
|||||||
/** Get start node output variable list (user-defined + system variables) */
|
/** Get start node output variable list (user-defined + system variables) */
|
||||||
getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>;
|
getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>;
|
||||||
nodeClick: ({ node }: { node: Node }) => void;
|
nodeClick: ({ node }: { node: Node }) => void;
|
||||||
|
/** All recorded history operations */
|
||||||
|
historyRecords: HistoryRecord[];
|
||||||
|
/** Clear history records */
|
||||||
|
clearHistoryRecords: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -119,14 +122,19 @@ export const useWorkflowGraph = ({
|
|||||||
const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined)
|
const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined)
|
||||||
const [canUndo, setCanUndo] = useState(false)
|
const [canUndo, setCanUndo] = useState(false)
|
||||||
const [canRedo, setCanRedo] = useState(false)
|
const [canRedo, setCanRedo] = useState(false)
|
||||||
|
const [historyRecords, setHistoryRecords] = useState<HistoryRecord[]>([])
|
||||||
|
const lastHistoryRef = useRef<{ cellIds: string[]; timestamp: number; type: string } | null>(null)
|
||||||
|
const undoRef = useRef<() => void>(() => {})
|
||||||
|
const redoRef = useRef<() => void>(() => {})
|
||||||
|
const syncChildRelationshipsRef = useRef<() => void>(() => {})
|
||||||
|
const isSyncingRef = useRef(false)
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!graphRef.current) return
|
if (!graphRef.current) return
|
||||||
graphRef.current.getNodes().forEach(node => {
|
graphRef.current.getNodes().forEach(node => {
|
||||||
const data = node.getData()
|
const data = node.getData()
|
||||||
if (data?.type === 'if-else' || data?.type === 'question-classifier') {
|
if (data?.type === 'if-else' || data?.type === 'question-classifier') {
|
||||||
console.log('chatVariables', chatVariables)
|
console.log('chatVariables', chatVariables)
|
||||||
node.setData({ ...data, chatVariables }, { silent: true })
|
node.setData({ ...data, chatVariables })
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}, [chatVariables])
|
}, [chatVariables])
|
||||||
@@ -343,7 +351,7 @@ export const useWorkflowGraph = ({
|
|||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
const addedChild = graphRef.current?.addNode(childNode)
|
const addedChild = graphRef.current?.addNode(childNode)
|
||||||
if (addedChild) {
|
if (addedChild) {
|
||||||
parentNode.addChild(addedChild)
|
parentNode.addChild(addedChild, { silent: true })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -374,8 +382,6 @@ export const useWorkflowGraph = ({
|
|||||||
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
|
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
|
||||||
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
|
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
|
||||||
|
|
||||||
console.log('newWidth', newHeight, newWidth)
|
|
||||||
|
|
||||||
parentNode.prop('size', { width: newWidth, height: newHeight })
|
parentNode.prop('size', { width: newWidth, height: newHeight })
|
||||||
|
|
||||||
// Update x position of right group ports
|
// Update x position of right group ports
|
||||||
@@ -488,8 +494,77 @@ export const useWorkflowGraph = ({
|
|||||||
graphRef.current.cleanHistory()
|
graphRef.current.cleanHistory()
|
||||||
}
|
}
|
||||||
}, 200)
|
}, 200)
|
||||||
|
} else {
|
||||||
|
graphRef.current.enableHistory()
|
||||||
|
graphRef.current.cleanHistory()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const resizeGroupNodes = (graph: Graph) => {
|
||||||
|
graph.getNodes().forEach(parentNode => {
|
||||||
|
const parentType = parentNode.getData()?.type
|
||||||
|
if (parentType !== 'loop' && parentType !== 'iteration') return
|
||||||
|
const children = graph.getNodes().filter(
|
||||||
|
n => n.getData()?.cycle === parentNode.getData()?.id && n.getData()?.type !== 'add-node'
|
||||||
|
)
|
||||||
|
if (!children.length) return
|
||||||
|
const padding = 24
|
||||||
|
const headerHeight = 50
|
||||||
|
const childBounds = children.map(c => c.getBBox())
|
||||||
|
const minX = Math.min(...childBounds.map(b => b.x))
|
||||||
|
const minY = Math.min(...childBounds.map(b => b.y))
|
||||||
|
const maxX = Math.max(...childBounds.map(b => b.x + b.width))
|
||||||
|
const maxY = Math.max(...childBounds.map(b => b.y + b.height))
|
||||||
|
const parentBBox = parentNode.getBBox()
|
||||||
|
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
|
||||||
|
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
|
||||||
|
parentNode.prop('size', { width: newWidth, height: newHeight })
|
||||||
|
parentNode.getPorts().forEach(port => {
|
||||||
|
if (port.group === 'right' && port.args) {
|
||||||
|
parentNode.portProp(port.id!, 'args/x', newWidth)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const syncChildRelationships = () => {
|
||||||
|
if (!graphRef.current) return
|
||||||
|
const graph = graphRef.current
|
||||||
|
graph.disableHistory()
|
||||||
|
graph.getNodes().forEach(node => {
|
||||||
|
const cycleId = node.getData()?.cycle
|
||||||
|
if (!cycleId) return
|
||||||
|
const parentNode = graph.getCellById(cycleId) as Node | null
|
||||||
|
if (!parentNode) return
|
||||||
|
if (!parentNode.getChildren()?.some(c => c.id === node.id)) {
|
||||||
|
parentNode.addChild(node, { silent: true })
|
||||||
|
}
|
||||||
|
})
|
||||||
|
graph.getNodes().forEach(node => {
|
||||||
|
const children = node.getChildren()
|
||||||
|
if (!children?.length) return
|
||||||
|
children.forEach(child => {
|
||||||
|
if (!child.isNode()) return
|
||||||
|
const childCycleId = (child as Node).getData?.()?.cycle
|
||||||
|
if (childCycleId !== node.id && childCycleId !== node.getData?.()?.id) {
|
||||||
|
node.removeChild(child, { silent: true })
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
resizeGroupNodes(graph)
|
||||||
|
graph.getEdges().forEach(edge => {
|
||||||
|
const src = graph.getCellById(edge.getSourceCellId())
|
||||||
|
const tgt = graph.getCellById(edge.getTargetCellId())
|
||||||
|
if (src?.getData()?.cycle || tgt?.getData()?.cycle) {
|
||||||
|
edge.toFront()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
graph.getNodes().forEach(node => {
|
||||||
|
if (node.getData()?.cycle) node.toFront()
|
||||||
|
})
|
||||||
|
graph.enableHistory()
|
||||||
|
}
|
||||||
|
syncChildRelationshipsRef.current = syncChildRelationships
|
||||||
/**
|
/**
|
||||||
* Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard)
|
* Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard)
|
||||||
*/
|
*/
|
||||||
@@ -525,18 +600,44 @@ export const useWorkflowGraph = ({
|
|||||||
new History({
|
new History({
|
||||||
enabled: false,
|
enabled: false,
|
||||||
beforeAddCommand(_event, args: any) {
|
beforeAddCommand(_event, args: any) {
|
||||||
const event = args?.key ? `cell:change:${args.key}` : _event;
|
const key = args?.key
|
||||||
if (event.startsWith('cell:change:') &&
|
if (key === 'attrs' || key === 'tools') return false
|
||||||
event !== 'cell:change:position' &&
|
|
||||||
event !== 'cell:change:source' &&
|
|
||||||
event !== 'cell:change:target') return false;
|
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
graphRef.current.on('history:change', ({ cmds }: { cmds: Command[] }) => {
|
const MERGE_INTERVAL = 1000
|
||||||
|
graphRef.current.on('history:change', ({ cmds, options }: { cmds: any[]; options: any }) => {
|
||||||
setCanUndo(graphRef.current?.canUndo() ?? false)
|
setCanUndo(graphRef.current?.canUndo() ?? false)
|
||||||
setCanRedo(graphRef.current?.canRedo() ?? false)
|
setCanRedo(graphRef.current?.canRedo() ?? false)
|
||||||
|
console.log('history:change', cmds, options)
|
||||||
|
const batchName: string | undefined = options?.name
|
||||||
|
const actionType = batchName === 'undo' ? 'undo' : batchName === 'redo' ? 'redo' : batchName ? 'batch' : 'change'
|
||||||
|
const cellIds = [...new Set(cmds?.map((cmd: any) => cmd.data?.id).filter(Boolean))]
|
||||||
|
const now = Date.now()
|
||||||
|
const last = lastHistoryRef.current
|
||||||
|
const canMerge =
|
||||||
|
actionType === 'change' &&
|
||||||
|
last?.type === 'change' &&
|
||||||
|
now - last.timestamp < MERGE_INTERVAL &&
|
||||||
|
cellIds.length > 0 &&
|
||||||
|
cellIds.length === last.cellIds.length &&
|
||||||
|
cellIds.every((id, i) => id === last.cellIds[i])
|
||||||
|
if (canMerge) {
|
||||||
|
lastHistoryRef.current!.timestamp = now
|
||||||
|
setHistoryRecords(prev => {
|
||||||
|
const next = [...prev]
|
||||||
|
next[next.length - 1] = { ...next[next.length - 1], timestamp: now }
|
||||||
|
return next
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const record: HistoryRecord = { type: actionType, timestamp: now, batchName, cellIds }
|
||||||
|
lastHistoryRef.current = { cellIds, timestamp: now, type: actionType }
|
||||||
|
setHistoryRecords(prev => [...prev, record])
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
graphRef.current.on('history:undo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
|
||||||
|
graphRef.current.on('history:redo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
|
||||||
};
|
};
|
||||||
// 显示/隐藏连接桩
|
// 显示/隐藏连接桩
|
||||||
// const showPorts = (show: boolean) => {
|
// const showPorts = (show: boolean) => {
|
||||||
@@ -569,13 +670,13 @@ export const useWorkflowGraph = ({
|
|||||||
vo.setData({
|
vo.setData({
|
||||||
...data,
|
...data,
|
||||||
isSelected: false,
|
isSelected: false,
|
||||||
});
|
}, { silent: true });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
node.setData({
|
node.setData({
|
||||||
...nodeData,
|
...nodeData,
|
||||||
isSelected: true,
|
isSelected: true,
|
||||||
});
|
}, { silent: true });
|
||||||
clearEdgeSelect()
|
clearEdgeSelect()
|
||||||
if (nodeData.type !== 'notes') {
|
if (nodeData.type !== 'notes') {
|
||||||
setSelectedNode(node);
|
setSelectedNode(node);
|
||||||
@@ -589,7 +690,7 @@ export const useWorkflowGraph = ({
|
|||||||
const edgeClick = ({ edge }: { edge: Edge }) => {
|
const edgeClick = ({ edge }: { edge: Edge }) => {
|
||||||
clearEdgeSelect();
|
clearEdgeSelect();
|
||||||
edge.setAttrByPath('line/stroke', edge_selected_color);
|
edge.setAttrByPath('line/stroke', edge_selected_color);
|
||||||
edge.setData({ ...edge.getData(), isSelected: true });
|
edge.setData({ ...edge.getData(), isSelected: true }, { silent: true });
|
||||||
clearNodeSelect();
|
clearNodeSelect();
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
@@ -604,7 +705,7 @@ export const useWorkflowGraph = ({
|
|||||||
node.setData({
|
node.setData({
|
||||||
...data,
|
...data,
|
||||||
isSelected: false,
|
isSelected: false,
|
||||||
});
|
}, { silent: true });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
setSelectedNode(null);
|
setSelectedNode(null);
|
||||||
@@ -614,7 +715,7 @@ export const useWorkflowGraph = ({
|
|||||||
*/
|
*/
|
||||||
const clearEdgeSelect = () => {
|
const clearEdgeSelect = () => {
|
||||||
graphRef.current?.getEdges().forEach(e => {
|
graphRef.current?.getEdges().forEach(e => {
|
||||||
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false });
|
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false }, { silent: true });
|
||||||
e.setAttrByPath('line/stroke', edge_color);
|
e.setAttrByPath('line/stroke', edge_color);
|
||||||
e.setAttrByPath('line/strokeWidth', edge_width);
|
e.setAttrByPath('line/strokeWidth', edge_width);
|
||||||
});
|
});
|
||||||
@@ -753,8 +854,6 @@ export const useWorkflowGraph = ({
|
|||||||
// Find corresponding parent node
|
// Find corresponding parent node
|
||||||
const parentNode = nodes?.find(n => n.id === nodeData.cycle);
|
const parentNode = nodes?.find(n => n.id === nodeData.cycle);
|
||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
// Use removeChild method to delete child node
|
|
||||||
parentNode.removeChild(nodeToDelete);
|
|
||||||
parentNodesToUpdate.push(parentNode);
|
parentNodesToUpdate.push(parentNode);
|
||||||
}
|
}
|
||||||
// Add child node to deletion list
|
// Add child node to deletion list
|
||||||
@@ -782,42 +881,51 @@ export const useWorkflowGraph = ({
|
|||||||
|
|
||||||
// Delete all collected nodes and edges
|
// Delete all collected nodes and edges
|
||||||
if (cells.length > 0) {
|
if (cells.length > 0) {
|
||||||
|
// Pre-calculate which parents need an add-node restored (before removal changes the graph)
|
||||||
|
const parentsNeedingAddNode = parentNodesToUpdate
|
||||||
|
.filter(parentNode => {
|
||||||
|
const parentShape = parentNode.shape;
|
||||||
|
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return false;
|
||||||
|
const parentData = parentNode.getData();
|
||||||
|
const allChildren = graphRef.current!.getNodes().filter(n => n.getData()?.cycle === parentData.id);
|
||||||
|
const cycleStartNodes = allChildren.filter(n => n.getData()?.type === 'cycle-start');
|
||||||
|
// After deletion, only cycle-start will remain
|
||||||
|
const nonCycleStartToDelete = cells.filter(c =>
|
||||||
|
c.isNode() &&
|
||||||
|
(c as Node).getData()?.cycle === parentData.id &&
|
||||||
|
(c as Node).getData()?.type !== 'cycle-start'
|
||||||
|
);
|
||||||
|
return cycleStartNodes.length === 1 && (allChildren.length - nonCycleStartToDelete.length) === 1;
|
||||||
|
})
|
||||||
|
.map(parentNode => ({
|
||||||
|
parentNode,
|
||||||
|
cycleStartNode: graphRef.current!.getNodes().find(
|
||||||
|
n => n.getData()?.cycle === parentNode.getData().id && n.getData()?.type === 'cycle-start'
|
||||||
|
)!
|
||||||
|
}))
|
||||||
|
.filter(({ cycleStartNode }) => !!cycleStartNode);
|
||||||
|
|
||||||
|
graphRef.current?.startBatch('delete');
|
||||||
graphRef.current?.removeCells(cells);
|
graphRef.current?.removeCells(cells);
|
||||||
|
|
||||||
// If parent is iteration/loop and only cycle-start remains, add add-node connected to it
|
parentsNeedingAddNode.forEach(({ parentNode, cycleStartNode }) => {
|
||||||
parentNodesToUpdate.forEach(parentNode => {
|
|
||||||
const parentShape = parentNode.shape;
|
|
||||||
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return;
|
|
||||||
const parentData = parentNode.getData();
|
const parentData = parentNode.getData();
|
||||||
const remainingChildren = graphRef.current!.getNodes().filter(
|
const bbox = cycleStartNode.getBBox();
|
||||||
n => n.getData()?.cycle === parentData.id
|
const addNode = graphRef.current!.addNode({
|
||||||
);
|
...graphNodeLibrary.addStart,
|
||||||
const cycleStartNodes = remainingChildren.filter(n => n.getData()?.type === 'cycle-start');
|
x: bbox.x + 84,
|
||||||
if (cycleStartNodes.length === 1 && remainingChildren.length === 1) {
|
y: bbox.y + 4,
|
||||||
const cycleStartNode = cycleStartNodes[0];
|
data: { type: 'add-node', parentId: parentNode.id, cycle: parentData.id, label: t('workflow.addNode'), icon: '+' },
|
||||||
const bbox = cycleStartNode.getBBox();
|
});
|
||||||
const addNode = graphRef.current!.addNode({
|
parentNode.addChild(addNode, { silent: true });
|
||||||
...graphNodeLibrary.addStart,
|
graphRef.current!.addEdge({
|
||||||
x: bbox.x + 84,
|
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
|
||||||
y: bbox.y + 4,
|
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
|
||||||
data: {
|
...edgeAttrs,
|
||||||
type: 'add-node',
|
});
|
||||||
parentId: parentNode.id,
|
|
||||||
cycle: parentData.id,
|
|
||||||
label: t('workflow.addNode'),
|
|
||||||
icon: '+',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
parentNode.addChild(addNode);
|
|
||||||
const sourcePort = cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right';
|
|
||||||
const targetPort = addNode.getPorts().find(p => p.group === 'left')?.id || 'left';
|
|
||||||
graphRef.current!.addEdge({
|
|
||||||
source: { cell: cycleStartNode.id, port: sourcePort },
|
|
||||||
target: { cell: addNode.id, port: targetPort },
|
|
||||||
...edgeAttrs,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
graphRef.current?.stopBatch('delete');
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
@@ -1036,7 +1144,7 @@ export const useWorkflowGraph = ({
|
|||||||
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
||||||
if (!edge.getData()?.isSelected) {
|
if (!edge.getData()?.isSelected) {
|
||||||
edge.setAttrByPath('line/stroke', edge_selected_color);
|
edge.setAttrByPath('line/stroke', edge_selected_color);
|
||||||
edge.setData({ ...edge.getData(), isNodeHover: true });
|
edge.setData({ ...edge.getData(), isNodeHover: true }, { silent: true });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -1044,7 +1152,7 @@ export const useWorkflowGraph = ({
|
|||||||
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
||||||
if (!edge.getData()?.isSelected) {
|
if (!edge.getData()?.isSelected) {
|
||||||
edge.setAttrByPath('line/stroke', edge_color);
|
edge.setAttrByPath('line/stroke', edge_color);
|
||||||
edge.setData({ ...edge.getData(), isNodeHover: false });
|
edge.setData({ ...edge.getData(), isNodeHover: false }, { silent: true });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -1126,8 +1234,8 @@ export const useWorkflowGraph = ({
|
|||||||
// Delete selected nodes and edges
|
// Delete selected nodes and edges
|
||||||
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
|
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
|
||||||
// Undo / Redo
|
// Undo / Redo
|
||||||
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { graphRef.current?.undo(); return false; });
|
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { undo(); return false; });
|
||||||
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { graphRef.current?.redo(); return false; });
|
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { redo(); return false; });
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1193,13 +1301,51 @@ export const useWorkflowGraph = ({
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (dragData.type === 'loop' || dragData.type === 'iteration') {
|
if (dragData.type === 'loop' || dragData.type === 'iteration') {
|
||||||
graphRef.current.addNode({
|
graph.disableHistory()
|
||||||
|
const parentNode = graphRef.current.addNode({
|
||||||
...graphNodeLibrary[dragData.type],
|
...graphNodeLibrary[dragData.type],
|
||||||
x: point.x - 150,
|
x: point.x - 150,
|
||||||
y: point.y - 100,
|
y: point.y - 100,
|
||||||
id: cleanNodeData.id,
|
id: cleanNodeData.id,
|
||||||
data: { ...cleanNodeData, isGroup: true },
|
data: { ...cleanNodeData, isGroup: true },
|
||||||
});
|
})
|
||||||
|
const parentBBox = parentNode.getBBox()
|
||||||
|
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||||
|
const cycleStartNode = graphRef.current.addNode({
|
||||||
|
...graphNodeLibrary.cycleStart,
|
||||||
|
x: parentBBox.x + 24,
|
||||||
|
y: parentBBox.y + 70,
|
||||||
|
id: cycleStartId,
|
||||||
|
data: { id: cycleStartId, type: 'cycle-start', parentId: cleanNodeData.id, isDefault: true, cycle: cleanNodeData.id },
|
||||||
|
})
|
||||||
|
const addNode = graphRef.current.addNode({
|
||||||
|
...graphNodeLibrary.addStart,
|
||||||
|
x: parentBBox.x + 24 + 84,
|
||||||
|
y: parentBBox.y + 70 + 4,
|
||||||
|
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: cleanNodeData.id, cycle: cleanNodeData.id },
|
||||||
|
})
|
||||||
|
parentNode.addChild(cycleStartNode, { silent: true })
|
||||||
|
parentNode.addChild(addNode, { silent: true })
|
||||||
|
const newEdge = graphRef.current.addEdge({
|
||||||
|
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
|
||||||
|
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
|
||||||
|
...edgeAttrs,
|
||||||
|
})
|
||||||
|
cycleStartNode.toFront()
|
||||||
|
addNode.toFront()
|
||||||
|
graph.enableHistory()
|
||||||
|
// Manually push a single batch frame covering all 4 cells into undoStack
|
||||||
|
const history = graph.getPlugin('history') as History
|
||||||
|
const makeBatchCmd = (cell: any) => ({
|
||||||
|
batch: true,
|
||||||
|
event: 'cell:added',
|
||||||
|
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
|
||||||
|
options: {},
|
||||||
|
})
|
||||||
|
const batchFrame = [parentNode, cycleStartNode, addNode, newEdge].map(makeBatchCmd)
|
||||||
|
;(history as any).undoStack.push(batchFrame)
|
||||||
|
;(history as any).redoStack = []
|
||||||
|
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-group' } })
|
||||||
} else if (dragData.type === 'if-else') {
|
} else if (dragData.type === 'if-else') {
|
||||||
// Create condition node
|
// Create condition node
|
||||||
graphRef.current.addNode({
|
graphRef.current.addNode({
|
||||||
@@ -1446,8 +1592,80 @@ export const useWorkflowGraph = ({
|
|||||||
return userVars
|
return userVars
|
||||||
}
|
}
|
||||||
|
|
||||||
const undo = () => graphRef.current?.undo()
|
const clearHistoryRecords = () => {
|
||||||
const redo = () => graphRef.current?.redo()
|
setHistoryRecords([])
|
||||||
|
lastHistoryRef.current = null
|
||||||
|
}
|
||||||
|
|
||||||
|
const getStackCellIds = (cmds: any): string[] => {
|
||||||
|
const arr = Array.isArray(cmds) ? cmds : [cmds]
|
||||||
|
return [...new Set(arr.map((c: any) => c.data?.id).filter(Boolean))]
|
||||||
|
}
|
||||||
|
|
||||||
|
const isSkippableFrame = (frame: any): boolean => {
|
||||||
|
const arr = Array.isArray(frame) ? frame : [frame]
|
||||||
|
return arr.every((c: any) => ['zIndex', 'attrs', 'tools'].includes(c.data?.key))
|
||||||
|
}
|
||||||
|
|
||||||
|
const undo = () => {
|
||||||
|
const history = graphRef.current?.getPlugin('history') as History | undefined
|
||||||
|
if (!history || history.getUndoSize() === 0) return
|
||||||
|
const undoStack = (history as any).undoStack as any[]
|
||||||
|
isSyncingRef.current = true
|
||||||
|
while (undoStack.length > 0 && isSkippableFrame(undoStack[undoStack.length - 1])) {
|
||||||
|
graphRef.current!.undo()
|
||||||
|
}
|
||||||
|
if (undoStack.length === 0) {
|
||||||
|
isSyncingRef.current = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const topIds = getStackCellIds(undoStack[undoStack.length - 1])
|
||||||
|
graphRef.current!.undo()
|
||||||
|
while (undoStack.length > 0) {
|
||||||
|
if (isSkippableFrame(undoStack[undoStack.length - 1])) {
|
||||||
|
graphRef.current!.undo()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
const nextIds = getStackCellIds(undoStack[undoStack.length - 1])
|
||||||
|
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
|
||||||
|
graphRef.current!.undo()
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
isSyncingRef.current = false
|
||||||
|
syncChildRelationships()
|
||||||
|
}
|
||||||
|
|
||||||
|
const redo = () => {
|
||||||
|
const history = graphRef.current?.getPlugin('history') as History | undefined
|
||||||
|
if (!history || history.getRedoSize() === 0) return
|
||||||
|
const redoStack = (history as any).redoStack as any[]
|
||||||
|
isSyncingRef.current = true
|
||||||
|
while (redoStack.length > 0 && isSkippableFrame(redoStack[redoStack.length - 1])) {
|
||||||
|
graphRef.current!.redo()
|
||||||
|
}
|
||||||
|
if (redoStack.length === 0) {
|
||||||
|
isSyncingRef.current = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const topIds = getStackCellIds(redoStack[redoStack.length - 1])
|
||||||
|
graphRef.current!.redo()
|
||||||
|
while (redoStack.length > 0) {
|
||||||
|
if (isSkippableFrame(redoStack[redoStack.length - 1])) {
|
||||||
|
graphRef.current!.redo()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
const nextIds = getStackCellIds(redoStack[redoStack.length - 1])
|
||||||
|
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
|
||||||
|
graphRef.current!.redo()
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
isSyncingRef.current = false
|
||||||
|
syncChildRelationships()
|
||||||
|
}
|
||||||
|
|
||||||
const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => {
|
const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => {
|
||||||
const { statement = '' } = value?.opening_statement || {}
|
const { statement = '' } = value?.opening_statement || {}
|
||||||
@@ -1488,20 +1706,16 @@ export const useWorkflowGraph = ({
|
|||||||
if (!graphRef.current) return;
|
if (!graphRef.current) return;
|
||||||
const nodes = graphRef.current.getNodes();
|
const nodes = graphRef.current.getNodes();
|
||||||
|
|
||||||
const lastWithSub = [...chatHistory].reverse().find(item => item.subContent?.length);
|
// Reset all node execution status on every chatHistory change
|
||||||
// Reset all node execution status first
|
|
||||||
nodes.forEach(node => {
|
nodes.forEach(node => {
|
||||||
const data = node.getData();
|
const data = node.getData();
|
||||||
if (typeof data.executionStatus === 'string') {
|
node.setData({ ...data, executionStatus: '' });
|
||||||
node.setData({ ...data, executionStatus: undefined });
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
if (!lastWithSub?.subContent) return;
|
|
||||||
// Build a nodeId -> status map first
|
const lastAssistant = [...chatHistory].reverse().find(item => item.role === 'assistant');
|
||||||
const statusMap: Record<string, string> = {};
|
if (!lastAssistant?.subContent?.length) return;
|
||||||
lastWithSub.subContent.forEach(sub => {
|
lastAssistant.subContent.forEach(sub => {
|
||||||
if (typeof sub.status === 'string') {
|
if (typeof sub.status === 'string') {
|
||||||
statusMap[sub.node_id] = sub.status;
|
|
||||||
const node = nodes.find(n => n.getData()?.id === sub.node_id);
|
const node = nodes.find(n => n.getData()?.id === sub.node_id);
|
||||||
if (node) {
|
if (node) {
|
||||||
node.setData({ ...node.getData(), executionStatus: sub.status });
|
node.setData({ ...node.getData(), executionStatus: sub.status });
|
||||||
@@ -1537,5 +1751,7 @@ export const useWorkflowGraph = ({
|
|||||||
canRedo,
|
canRedo,
|
||||||
undo,
|
undo,
|
||||||
redo,
|
redo,
|
||||||
|
historyRecords,
|
||||||
|
clearHistoryRecords,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -113,4 +113,13 @@ export interface ChatVariable {
|
|||||||
}
|
}
|
||||||
export interface AddChatVariableRef {
|
export interface AddChatVariableRef {
|
||||||
handleOpen: (value?: ChatVariable) => void;
|
handleOpen: (value?: ChatVariable) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type HistoryActionType = 'add' | 'remove' | 'change' | 'undo' | 'redo' | 'batch'
|
||||||
|
|
||||||
|
export interface HistoryRecord {
|
||||||
|
type: HistoryActionType;
|
||||||
|
timestamp: number;
|
||||||
|
batchName?: string;
|
||||||
|
cellIds?: string[];
|
||||||
}
|
}
|
||||||
@@ -17,6 +17,7 @@ export const isSubExprSet = (sub: any) => {
|
|||||||
* Uses the same per-expression height logic as getConditionNodeCasePortY.
|
* Uses the same per-expression height logic as getConditionNodeCasePortY.
|
||||||
*/
|
*/
|
||||||
export const calcConditionNodeTotalHeight = (cases: any[]) => {
|
export const calcConditionNodeTotalHeight = (cases: any[]) => {
|
||||||
|
if (!cases?.length) return conditionNodeHeight;
|
||||||
const casesHeight = cases.reduce((acc: number, c: any) => {
|
const casesHeight = cases.reduce((acc: number, c: any) => {
|
||||||
const exprs = c?.expressions ?? [];
|
const exprs = c?.expressions ?? [];
|
||||||
const n = exprs.length;
|
const n = exprs.length;
|
||||||
|
|||||||
Reference in New Issue
Block a user