Compare commits
31 Commits
feature/ra
...
fix/Timebo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f10296969 | ||
|
|
89228825cf | ||
|
|
cab4deb2ff | ||
|
|
4048a10858 | ||
|
|
d6ef0f4923 | ||
|
|
75fbe44839 | ||
|
|
06597c567b | ||
|
|
28694fefb0 | ||
|
|
7a0f08148e | ||
|
|
d3058ce379 | ||
|
|
8d88df391d | ||
|
|
7621321d1b | ||
|
|
0e29b0b2a5 | ||
|
|
2fa4d29548 | ||
|
|
7bb181c1c7 | ||
|
|
a9c87b03ff | ||
|
|
720af8d261 | ||
|
|
09d32ed446 | ||
|
|
9a5ce7f7c6 | ||
|
|
531d785629 | ||
|
|
6d80d74f4a | ||
|
|
3d9882643e | ||
|
|
b4e4be1133 | ||
|
|
16926d9db5 | ||
|
|
f369a63c8d | ||
|
|
1861b0fbc9 | ||
|
|
750d4ca841 | ||
|
|
8baa466b31 | ||
|
|
dd7f9f6cee | ||
|
|
d5d81f0c4f | ||
|
|
610ae27cf9 |
7
.github/workflows/sync-to-gitee.yml
vendored
7
.github/workflows/sync-to-gitee.yml
vendored
@@ -3,12 +3,9 @@ name: Sync to Gitee
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- main # Production
|
- '**' # All branchs
|
||||||
- develop # Integration
|
|
||||||
- 'release/*' # Release preparation
|
|
||||||
- 'hotfix/*' # Urgent fixes
|
|
||||||
tags:
|
tags:
|
||||||
- '*' # All version tags (v1.0.0, etc.)
|
- '**' # All version tags (v1.0.0, etc.)
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
sync:
|
sync:
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import csv
|
|
||||||
import io
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -25,7 +23,6 @@ from app.models.user_model import User
|
|||||||
from app.schemas import chunk_schema
|
from app.schemas import chunk_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||||
from app.services.file_storage_service import FileStorageService, get_file_storage_service, generate_kb_file_key
|
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
@@ -85,32 +82,19 @@ async def get_preview_chunks(
|
|||||||
detail="The file does not exist or you do not have permission to access it"
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Get file content from storage backend
|
# 5. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||||||
if not db_file.file_key:
|
file_path = os.path.join(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Check if the file exists
|
||||||
|
if not os.path.exists(file_path):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="File has no storage key (legacy data not migrated)"
|
detail="File not found (possibly deleted)"
|
||||||
)
|
|
||||||
|
|
||||||
from app.services.file_storage_service import FileStorageService
|
|
||||||
import asyncio
|
|
||||||
storage_service = FileStorageService()
|
|
||||||
|
|
||||||
async def _download():
|
|
||||||
return await storage_service.download_file(db_file.file_key)
|
|
||||||
|
|
||||||
try:
|
|
||||||
file_binary = asyncio.run(_download())
|
|
||||||
except RuntimeError:
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
try:
|
|
||||||
file_binary = loop.run_until_complete(_download())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"File not found in storage: {e}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 7. Document parsing & segmentation
|
# 7. Document parsing & segmentation
|
||||||
@@ -120,12 +104,11 @@ async def get_preview_chunks(
|
|||||||
vision_model = QWenCV(
|
vision_model = QWenCV(
|
||||||
key=db_knowledge.image2text.api_keys[0].api_key,
|
key=db_knowledge.image2text.api_keys[0].api_key,
|
||||||
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
||||||
lang="Chinese",
|
lang="Chinese", # Default to Chinese
|
||||||
base_url=db_knowledge.image2text.api_keys[0].api_base
|
base_url=db_knowledge.image2text.api_keys[0].api_base
|
||||||
)
|
)
|
||||||
from app.core.rag.app.naive import chunk
|
from app.core.rag.app.naive import chunk
|
||||||
res = chunk(filename=db_file.file_name,
|
res = chunk(filename=file_path,
|
||||||
binary=file_binary,
|
|
||||||
from_page=0,
|
from_page=0,
|
||||||
to_page=5,
|
to_page=5,
|
||||||
callback=progress_callback,
|
callback=progress_callback,
|
||||||
@@ -274,9 +257,6 @@ async def create_chunk(
|
|||||||
"sort_id": sort_id,
|
"sort_id": sort_id,
|
||||||
"status": 1,
|
"status": 1,
|
||||||
}
|
}
|
||||||
# QA chunk: 注入 chunk_type/question/answer 到 metadata
|
|
||||||
if create_data.is_qa:
|
|
||||||
metadata.update(create_data.qa_metadata)
|
|
||||||
chunk = DocumentChunk(page_content=content, metadata=metadata)
|
chunk = DocumentChunk(page_content=content, metadata=metadata)
|
||||||
# 3. Segmented vector storage
|
# 3. Segmented vector storage
|
||||||
vector_service.add_chunks([chunk])
|
vector_service.add_chunks([chunk])
|
||||||
@@ -288,187 +268,6 @@ async def create_chunk(
|
|||||||
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
return success(data=jsonable_encoder(chunk), msg="Document chunk creation successful")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
|
|
||||||
async def create_chunks_batch(
|
|
||||||
kb_id: uuid.UUID,
|
|
||||||
document_id: uuid.UUID,
|
|
||||||
batch_data: chunk_schema.ChunkBatchCreate,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Batch create chunks (max 8)
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Batch create chunks: kb_id={kb_id}, document_id={document_id}, count={len(batch_data.items)}, username: {current_user.username}")
|
|
||||||
|
|
||||||
if len(batch_data.items) > settings.MAX_CHUNK_BATCH_SIZE:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Batch size exceeds limit: max {settings.MAX_CHUNK_BATCH_SIZE}, got {len(batch_data.items)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
||||||
if not db_knowledge:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The knowledge base does not exist or access is denied")
|
|
||||||
|
|
||||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
|
||||||
if not db_document:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The document does not exist or you do not have permission to access it")
|
|
||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
|
|
||||||
# Get current max sort_id
|
|
||||||
sort_id = 0
|
|
||||||
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
|
|
||||||
if items:
|
|
||||||
sort_id = items[0].metadata["sort_id"]
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
for create_data in batch_data.items:
|
|
||||||
sort_id += 1
|
|
||||||
doc_id = uuid.uuid4().hex
|
|
||||||
metadata = {
|
|
||||||
"doc_id": doc_id,
|
|
||||||
"file_id": str(db_document.file_id),
|
|
||||||
"file_name": db_document.file_name,
|
|
||||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
|
||||||
"document_id": str(document_id),
|
|
||||||
"knowledge_id": str(kb_id),
|
|
||||||
"sort_id": sort_id,
|
|
||||||
"status": 1,
|
|
||||||
}
|
|
||||||
if create_data.is_qa:
|
|
||||||
metadata.update(create_data.qa_metadata)
|
|
||||||
chunks.append(DocumentChunk(page_content=create_data.chunk_content, metadata=metadata))
|
|
||||||
|
|
||||||
vector_service.add_chunks(chunks)
|
|
||||||
|
|
||||||
db_document.chunk_num += len(chunks)
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return success(data=jsonable_encoder(chunks), msg=f"Batch created {len(chunks)} chunks successfully")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{kb_id}/import_qa", response_model=ApiResponse)
|
|
||||||
async def import_qa_new_doc(
|
|
||||||
kb_id: uuid.UUID,
|
|
||||||
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user),
|
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
导入 QA 问答对并新建文档(CSV/Excel),异步处理
|
|
||||||
"""
|
|
||||||
from app.schemas import file_schema, document_schema
|
|
||||||
|
|
||||||
api_logger.info(f"Import QA (new doc): kb_id={kb_id}, file={file.filename}, username: {current_user.username}")
|
|
||||||
|
|
||||||
# 1. 校验文件格式
|
|
||||||
filename = file.filename or ""
|
|
||||||
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
|
|
||||||
|
|
||||||
# 2. 校验知识库
|
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
||||||
if not db_knowledge:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
|
|
||||||
|
|
||||||
# 3. 读取文件
|
|
||||||
contents = await file.read()
|
|
||||||
file_size = len(contents)
|
|
||||||
if file_size == 0:
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="文件为空")
|
|
||||||
|
|
||||||
_, file_extension = os.path.splitext(filename)
|
|
||||||
file_ext = file_extension.lower()
|
|
||||||
|
|
||||||
# 4. 创建 File 记录
|
|
||||||
file_data = file_schema.FileCreate(
|
|
||||||
kb_id=kb_id, created_by=current_user.id,
|
|
||||||
parent_id=uuid.UUID("00000000-0000-0000-0000-000000000000"),
|
|
||||||
file_name=filename, file_ext=file_ext, file_size=file_size,
|
|
||||||
)
|
|
||||||
db_file = file_service.create_file(db=db, file=file_data, current_user=current_user)
|
|
||||||
|
|
||||||
# 5. 上传文件到存储后端
|
|
||||||
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
|
|
||||||
try:
|
|
||||||
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Storage upload failed: {e}")
|
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"文件存储失败: {str(e)}")
|
|
||||||
|
|
||||||
db_file.file_key = file_key
|
|
||||||
db.commit()
|
|
||||||
db.refresh(db_file)
|
|
||||||
|
|
||||||
# 6. 创建 Document 记录(标记为 QA 类型)
|
|
||||||
doc_data = document_schema.DocumentCreate(
|
|
||||||
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
|
|
||||||
file_name=filename, file_ext=file_ext, file_size=file_size,
|
|
||||||
file_meta={}, parser_id="qa",
|
|
||||||
parser_config={"doc_type": "qa", "auto_questions": 0}
|
|
||||||
)
|
|
||||||
db_document = document_service.create_document(db=db, document=doc_data, current_user=current_user)
|
|
||||||
|
|
||||||
api_logger.info(f"Created doc for QA import: file_id={db_file.id}, document_id={db_document.id}, file_key={file_key}")
|
|
||||||
|
|
||||||
# 7. 派发异步任务
|
|
||||||
from app.celery_app import celery_app
|
|
||||||
task = celery_app.send_task(
|
|
||||||
"app.core.rag.tasks.import_qa_chunks",
|
|
||||||
args=[str(kb_id), str(db_document.id), filename, contents],
|
|
||||||
queue="qa_import"
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data={
|
|
||||||
"task_id": task.id,
|
|
||||||
"document_id": str(db_document.id),
|
|
||||||
"file_id": str(db_file.id),
|
|
||||||
}, msg="QA 导入任务已提交,后台处理中")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{kb_id}/{document_id}/import_qa", response_model=ApiResponse)
|
|
||||||
async def import_qa_chunks(
|
|
||||||
kb_id: uuid.UUID,
|
|
||||||
document_id: uuid.UUID,
|
|
||||||
file: UploadFile = File(..., description="CSV 或 Excel 文件(第一行标题跳过,第一列问题,第二列答案)"),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
current_user: User = Depends(get_current_user)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
导入 QA 问答对(CSV/Excel),异步处理
|
|
||||||
"""
|
|
||||||
api_logger.info(f"Import QA chunks: kb_id={kb_id}, document_id={document_id}, file={file.filename}, username: {current_user.username}")
|
|
||||||
|
|
||||||
# 1. 校验文件格式
|
|
||||||
filename = file.filename or ""
|
|
||||||
if not (filename.endswith(".csv") or filename.endswith(".xlsx") or filename.endswith(".xls")):
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="仅支持 CSV (.csv) 或 Excel (.xlsx) 格式")
|
|
||||||
|
|
||||||
# 2. 校验知识库和文档
|
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
||||||
if not db_knowledge:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="知识库不存在或无权访问")
|
|
||||||
|
|
||||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
|
||||||
if not db_document:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="文档不存在或无权访问")
|
|
||||||
|
|
||||||
# 3. 读取文件内容,派发异步任务
|
|
||||||
contents = await file.read()
|
|
||||||
|
|
||||||
from app.celery_app import celery_app
|
|
||||||
task = celery_app.send_task(
|
|
||||||
"app.core.rag.tasks.import_qa_chunks",
|
|
||||||
args=[str(kb_id), str(document_id), filename, contents],
|
|
||||||
queue="qa_import"
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data={"task_id": task.id}, msg="QA 导入任务已提交,后台处理中")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||||
async def get_chunk(
|
async def get_chunk(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
@@ -529,9 +328,6 @@ async def update_chunk(
|
|||||||
if total:
|
if total:
|
||||||
chunk = items[0]
|
chunk = items[0]
|
||||||
chunk.page_content = content
|
chunk.page_content = content
|
||||||
# QA chunk: 更新 metadata 中的 question/answer
|
|
||||||
if update_data.is_qa:
|
|
||||||
chunk.metadata.update(update_data.qa_metadata)
|
|
||||||
vector_service.update_by_segment(chunk)
|
vector_service.update_by_segment(chunk)
|
||||||
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
|
return success(data=jsonable_encoder(chunk), msg="The document chunk has been successfully updated")
|
||||||
else:
|
else:
|
||||||
@@ -546,7 +342,6 @@ async def delete_chunk(
|
|||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
document_id: uuid.UUID,
|
document_id: uuid.UUID,
|
||||||
doc_id: str,
|
doc_id: str,
|
||||||
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
|
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
@@ -564,7 +359,7 @@ async def delete_chunk(
|
|||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
if vector_service.text_exists(doc_id):
|
if vector_service.text_exists(doc_id):
|
||||||
vector_service.delete_by_ids([doc_id], refresh=force_refresh)
|
vector_service.delete_by_ids([doc_id])
|
||||||
# 更新 chunk_num
|
# 更新 chunk_num
|
||||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||||
db_document.chunk_num -= 1
|
db_document.chunk_num -= 1
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from app.models.user_model import User
|
|||||||
from app.schemas import document_schema
|
from app.schemas import document_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import document_service, file_service, knowledge_service
|
from app.services import document_service, file_service, knowledge_service
|
||||||
from app.services.file_storage_service import FileStorageService, get_file_storage_service
|
|
||||||
|
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
@@ -232,8 +231,7 @@ async def update_document(
|
|||||||
async def delete_document(
|
async def delete_document(
|
||||||
document_id: uuid.UUID,
|
document_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Delete document
|
Delete document
|
||||||
@@ -259,7 +257,7 @@ async def delete_document(
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# 3. Delete file
|
# 3. Delete file
|
||||||
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
|
await file_controller._delete_file(db=db, file_id=file_id, current_user=current_user)
|
||||||
|
|
||||||
# 4. Delete vector index
|
# 4. Delete vector index
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
||||||
@@ -307,25 +305,38 @@ async def parse_documents(
|
|||||||
detail="The file does not exist or you do not have permission to access it"
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Get file_key for storage backend
|
# 3. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||||||
if not db_file.file_key:
|
file_path = os.path.join(
|
||||||
api_logger.error(f"File has no storage key (legacy data not migrated): file_id={db_file.id}")
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Check if the file exists
|
||||||
|
api_logger.debug(f"Constructed file path: {file_path}")
|
||||||
|
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="File has no storage key (legacy data not migrated)"
|
detail="File not found (possibly deleted)"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Obtain knowledge base information
|
# 5. Obtain knowledge base information
|
||||||
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
|
api_logger.info( f"Obtain details of the knowledge base: knowledge_id={db_document.kb_id}")
|
||||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id, current_user=current_user)
|
||||||
if not db_knowledge:
|
if not db_knowledge:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Knowledge base not found")
|
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={db_document.kb_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The knowledge base does not exist or access is denied"
|
||||||
|
)
|
||||||
|
|
||||||
# 5. Dispatch parse task with file_key (not file_path)
|
# 6. Task: Document parsing, vectorization, and storage
|
||||||
task = celery_app.send_task(
|
# from app.tasks import parse_document
|
||||||
"app.core.rag.tasks.parse_document",
|
# parse_document(file_path, document_id)
|
||||||
args=[db_file.file_key, document_id, db_file.file_name]
|
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
|
||||||
)
|
|
||||||
result = {
|
result = {
|
||||||
"task_id": task.id
|
"task_id": task.id
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import FileResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -17,14 +19,10 @@ from app.models.user_model import User
|
|||||||
from app.schemas import file_schema, document_schema
|
from app.schemas import file_schema, document_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import file_service, document_service
|
from app.services import file_service, document_service
|
||||||
from app.services.knowledge_service import get_knowledge_by_id as get_kb_by_id
|
|
||||||
from app.services.file_storage_service import (
|
|
||||||
FileStorageService,
|
|
||||||
generate_kb_file_key,
|
|
||||||
get_file_storage_service,
|
|
||||||
)
|
|
||||||
from app.core.quota_stub import check_knowledge_capacity_quota
|
from app.core.quota_stub import check_knowledge_capacity_quota
|
||||||
|
|
||||||
|
|
||||||
|
# Obtain a dedicated API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
@@ -37,37 +35,67 @@ router = APIRouter(
|
|||||||
async def get_files(
|
async def get_files(
|
||||||
kb_id: uuid.UUID,
|
kb_id: uuid.UUID,
|
||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
page: int = Query(1, gt=0),
|
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||||
pagesize: int = Query(20, gt=0, le=100),
|
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
|
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at"),
|
||||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||||
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
keywords: Optional[str] = Query(None, description="Search keywords (file name)"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""Paged query file list"""
|
"""
|
||||||
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}")
|
Paged query file list
|
||||||
|
- Support filtering by kb_id and parent_id
|
||||||
|
- Support keyword search for file names
|
||||||
|
- Support dynamic sorting
|
||||||
|
- Return paging metadata + file list
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Query file list: kb_id={kb_id}, parent_id={parent_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||||
|
# 1. parameter validation
|
||||||
if page < 1 or pagesize < 1:
|
if page < 1 or pagesize < 1:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The paging parameter must be greater than 0")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="The paging parameter must be greater than 0"
|
||||||
|
)
|
||||||
|
|
||||||
filters = [file_model.File.kb_id == kb_id]
|
# 2. Construct query conditions
|
||||||
|
filters = [
|
||||||
|
file_model.File.kb_id == kb_id
|
||||||
|
]
|
||||||
if parent_id:
|
if parent_id:
|
||||||
filters.append(file_model.File.parent_id == parent_id)
|
filters.append(file_model.File.parent_id == parent_id)
|
||||||
|
# Keyword search (fuzzy matching of file name)
|
||||||
if keywords:
|
if keywords:
|
||||||
filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
|
filters.append(file_model.File.file_name.ilike(f"%{keywords}%"))
|
||||||
|
|
||||||
|
# 3. Execute paged query
|
||||||
try:
|
try:
|
||||||
|
api_logger.debug("Start executing file paging query")
|
||||||
total, items = file_service.get_files_paginated(
|
total, items = file_service.get_files_paginated(
|
||||||
db=db, filters=filters, page=page, pagesize=pagesize,
|
db=db,
|
||||||
orderby=orderby, desc=desc, current_user=current_user
|
filters=filters,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize,
|
||||||
|
orderby=orderby,
|
||||||
|
desc=desc,
|
||||||
|
current_user=current_user
|
||||||
)
|
)
|
||||||
|
api_logger.info(f"File query successful: total={total}, returned={len(items)} records")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Query failed: {str(e)}")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Query failed: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Return structured response
|
||||||
result = {
|
result = {
|
||||||
"items": items,
|
"items": items,
|
||||||
"page": {"page": page, "pagesize": pagesize, "total": total, "has_next": page * pagesize < total}
|
"page": {
|
||||||
|
"page": page,
|
||||||
|
"pagesize": pagesize,
|
||||||
|
"total": total,
|
||||||
|
"has_next": True if page * pagesize < total else False
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
|
return success(data=jsonable_encoder(result), msg="Query of file list succeeded")
|
||||||
|
|
||||||
@@ -80,14 +108,23 @@ async def create_folder(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""Create a new folder"""
|
"""
|
||||||
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}")
|
Create a new folder
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Create folder request: kb_id={kb_id}, parent_id={parent_id}, folder_name={folder_name}, username: {current_user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
create_folder_data = file_schema.FileCreate(
|
api_logger.debug(f"Start creating a folder: {folder_name}")
|
||||||
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
|
create_folder = file_schema.FileCreate(
|
||||||
file_name=folder_name, file_ext='folder', file_size=0,
|
kb_id=kb_id,
|
||||||
|
created_by=current_user.id,
|
||||||
|
parent_id=parent_id,
|
||||||
|
file_name=folder_name,
|
||||||
|
file_ext='folder',
|
||||||
|
file_size=0,
|
||||||
)
|
)
|
||||||
db_file = file_service.create_file(db=db, file=create_folder_data, current_user=current_user)
|
db_file = file_service.create_file(db=db, file=create_folder, current_user=current_user)
|
||||||
|
api_logger.info(f"Folder created successfully: {db_file.file_name} (ID: {db_file.id})")
|
||||||
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
|
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="Folder creation successful")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
|
api_logger.error(f"Folder creation failed: {folder_name} - {str(e)}")
|
||||||
@@ -101,58 +138,76 @@ async def upload_file(
|
|||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
):
|
||||||
"""Upload file to storage backend"""
|
"""
|
||||||
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}")
|
upload file
|
||||||
|
"""
|
||||||
|
api_logger.info(f"upload file request: kb_id={kb_id}, parent_id={parent_id}, filename={file.filename}, username: {current_user.username}")
|
||||||
|
|
||||||
|
# Read the contents of the file
|
||||||
contents = await file.read()
|
contents = await file.read()
|
||||||
|
# Check file size
|
||||||
file_size = len(contents)
|
file_size = len(contents)
|
||||||
|
print(f"file size: {file_size} byte")
|
||||||
if file_size == 0:
|
if file_size == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The file is empty.")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="The file is empty."
|
||||||
|
)
|
||||||
|
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"File size exceeds {settings.MAX_FILE_SIZE} byte limit")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE}byte limit"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract the extension using `os.path.splitext`
|
||||||
_, file_extension = os.path.splitext(file.filename)
|
_, file_extension = os.path.splitext(file.filename)
|
||||||
file_ext = file_extension.lower()
|
upload_file = file_schema.FileCreate(
|
||||||
|
kb_id=kb_id,
|
||||||
# Create File record
|
created_by=current_user.id,
|
||||||
upload_file_data = file_schema.FileCreate(
|
parent_id=parent_id,
|
||||||
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
|
file_name=file.filename,
|
||||||
file_name=file.filename, file_ext=file_ext, file_size=file_size,
|
file_ext=file_extension.lower(),
|
||||||
|
file_size=file_size,
|
||||||
)
|
)
|
||||||
db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
|
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
|
||||||
|
|
||||||
# Upload to storage backend
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=file_ext)
|
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
|
||||||
try:
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
await storage_service.storage.upload(file_key=file_key, content=contents, content_type=file.content_type)
|
save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}")
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Storage upload failed: {e}")
|
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
|
|
||||||
|
|
||||||
# Save file_key
|
# Save file
|
||||||
db_file.file_key = file_key
|
with open(save_path, "wb") as f:
|
||||||
db.commit()
|
f.write(contents)
|
||||||
db.refresh(db_file)
|
|
||||||
|
|
||||||
# Create document (inherit parser_config from knowledge base)
|
# Verify whether the file has been saved successfully
|
||||||
default_parser_config = {
|
if not os.path.exists(save_path):
|
||||||
"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
|
raise HTTPException(
|
||||||
"auto_keywords": 0, "auto_questions": 0, "html4excel": "false"
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
}
|
detail="File save failed"
|
||||||
try:
|
)
|
||||||
db_knowledge = get_kb_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
|
||||||
if db_knowledge and db_knowledge.parser_config:
|
|
||||||
default_parser_config.update(dict(db_knowledge.parser_config))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
# Create a document
|
||||||
create_data = document_schema.DocumentCreate(
|
create_data = document_schema.DocumentCreate(
|
||||||
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
|
kb_id=kb_id,
|
||||||
file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
|
created_by=current_user.id,
|
||||||
file_meta={}, parser_id="naive", parser_config=default_parser_config
|
file_id=db_file.id,
|
||||||
|
file_name=db_file.file_name,
|
||||||
|
file_ext=db_file.file_ext,
|
||||||
|
file_size=db_file.file_size,
|
||||||
|
file_meta={},
|
||||||
|
parser_id="naive",
|
||||||
|
parser_config={
|
||||||
|
"layout_recognize": "DeepDOC",
|
||||||
|
"chunk_token_num": 128,
|
||||||
|
"delimiter": "\n",
|
||||||
|
"auto_keywords": 0,
|
||||||
|
"auto_questions": 0,
|
||||||
|
"html4excel": "false"
|
||||||
|
}
|
||||||
)
|
)
|
||||||
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
db_document = document_service.create_document(db=db, document=create_data, current_user=current_user)
|
||||||
|
|
||||||
@@ -166,73 +221,123 @@ async def custom_text(
|
|||||||
parent_id: uuid.UUID,
|
parent_id: uuid.UUID,
|
||||||
create_data: file_schema.CustomTextFileCreate,
|
create_data: file_schema.CustomTextFileCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
):
|
||||||
"""Custom text upload"""
|
"""
|
||||||
|
custom text
|
||||||
|
"""
|
||||||
|
api_logger.info(f"custom text upload request: kb_id={kb_id}, parent_id={parent_id}, title={create_data.title}, content={create_data.content}, username: {current_user.username}")
|
||||||
|
|
||||||
|
# Check file content size
|
||||||
|
# 将内容编码为字节(UTF-8)
|
||||||
content_bytes = create_data.content.encode('utf-8')
|
content_bytes = create_data.content.encode('utf-8')
|
||||||
file_size = len(content_bytes)
|
file_size = len(content_bytes)
|
||||||
|
print(f"file size: {file_size} byte")
|
||||||
if file_size == 0:
|
if file_size == 0:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="The content is empty.")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="The content is empty."
|
||||||
|
)
|
||||||
|
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Content size exceeds {settings.MAX_FILE_SIZE} byte limit")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
|
||||||
|
)
|
||||||
|
|
||||||
upload_file_data = file_schema.FileCreate(
|
upload_file = file_schema.FileCreate(
|
||||||
kb_id=kb_id, created_by=current_user.id, parent_id=parent_id,
|
kb_id=kb_id,
|
||||||
file_name=f"{create_data.title}.txt", file_ext=".txt", file_size=file_size,
|
created_by=current_user.id,
|
||||||
|
parent_id=parent_id,
|
||||||
|
file_name=f"{create_data.title}.txt",
|
||||||
|
file_ext=".txt",
|
||||||
|
file_size=file_size,
|
||||||
)
|
)
|
||||||
db_file = file_service.create_file(db=db, file=upload_file_data, current_user=current_user)
|
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
|
||||||
|
|
||||||
# Upload to storage backend
|
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||||
file_key = generate_kb_file_key(kb_id=kb_id, file_id=db_file.id, file_ext=".txt")
|
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
|
||||||
try:
|
Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists
|
||||||
await storage_service.storage.upload(file_key=file_key, content=content_bytes, content_type="text/plain")
|
save_path = os.path.join(save_dir, f"{db_file.id}.txt")
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"Storage upload failed: {e}")
|
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File storage failed: {str(e)}")
|
|
||||||
|
|
||||||
db_file.file_key = file_key
|
# Save file
|
||||||
db.commit()
|
with open(save_path, "wb") as f:
|
||||||
db.refresh(db_file)
|
f.write(content_bytes)
|
||||||
|
|
||||||
|
# Verify whether the file has been saved successfully
|
||||||
|
if not os.path.exists(save_path):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail="File save failed"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a document
|
||||||
create_document_data = document_schema.DocumentCreate(
|
create_document_data = document_schema.DocumentCreate(
|
||||||
kb_id=kb_id, created_by=current_user.id, file_id=db_file.id,
|
kb_id=kb_id,
|
||||||
file_name=db_file.file_name, file_ext=db_file.file_ext, file_size=db_file.file_size,
|
created_by=current_user.id,
|
||||||
file_meta={}, parser_id="naive",
|
file_id=db_file.id,
|
||||||
parser_config={"layout_recognize": "DeepDOC", "chunk_token_num": 128, "delimiter": "\n",
|
file_name=db_file.file_name,
|
||||||
"auto_keywords": 0, "auto_questions": 0, "html4excel": "false"}
|
file_ext=db_file.file_ext,
|
||||||
|
file_size=db_file.file_size,
|
||||||
|
file_meta={},
|
||||||
|
parser_id="naive",
|
||||||
|
parser_config={
|
||||||
|
"layout_recognize": "DeepDOC",
|
||||||
|
"chunk_token_num": 128,
|
||||||
|
"delimiter": "\n",
|
||||||
|
"auto_keywords": 0,
|
||||||
|
"auto_questions": 0,
|
||||||
|
"html4excel": "false"
|
||||||
|
}
|
||||||
)
|
)
|
||||||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||||||
|
|
||||||
|
api_logger.info(f"custom text upload successfully: {create_data.title} (file_id: {db_file.id}, document_id: {db_document.id})")
|
||||||
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
|
return success(data=jsonable_encoder(document_schema.Document.model_validate(db_document)), msg="custom text upload successful")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{file_id}", response_model=Any)
|
@router.get("/{file_id}", response_model=Any)
|
||||||
async def get_file(
|
async def get_file(
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db)
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Download file by file_id"""
|
"""
|
||||||
|
Download the file based on the file_id
|
||||||
|
- Query file information from the database
|
||||||
|
- Construct the file path and check if it exists
|
||||||
|
- Return a FileResponse to download the file
|
||||||
|
"""
|
||||||
|
api_logger.info(f"Download the file based on the file_id: file_id={file_id}")
|
||||||
|
|
||||||
|
# 1. Query file information from the database
|
||||||
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
||||||
if not db_file:
|
if not db_file:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
|
)
|
||||||
|
|
||||||
if not db_file.file_key:
|
# 2. Construct file path:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File has no storage key (legacy data not migrated)")
|
file_path = os.path.join(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
# 3. Check if the file exists
|
||||||
content = await storage_service.download_file(db_file.file_key)
|
if not os.path.exists(file_path):
|
||||||
except Exception as e:
|
raise HTTPException(
|
||||||
api_logger.error(f"Storage download failed: {e}")
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found in storage")
|
detail="File not found (possibly deleted)"
|
||||||
|
)
|
||||||
|
|
||||||
import mimetypes
|
# 4.Return FileResponse (automatically handle download)
|
||||||
media_type = mimetypes.guess_type(db_file.file_name)[0] or "application/octet-stream"
|
return FileResponse(
|
||||||
return Response(
|
path=file_path,
|
||||||
content=content,
|
filename=db_file.file_name, # Use original file name
|
||||||
media_type=media_type,
|
media_type="application/octet-stream" # Universal binary stream type
|
||||||
headers={"Content-Disposition": f'attachment; filename="{db_file.file_name}"'}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -243,22 +348,50 @@ async def update_file(
|
|||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""Update file information (such as file name)"""
|
"""
|
||||||
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
Update file information (such as file name)
|
||||||
if not db_file:
|
- Only specified fields such as file_name are allowed to be modified
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
"""
|
||||||
|
api_logger.debug(f"Query the file to be updated: {file_id}")
|
||||||
|
|
||||||
|
# 1. Check if the file exists
|
||||||
|
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
||||||
|
|
||||||
|
if not db_file:
|
||||||
|
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Update fields (only update non-null fields)
|
||||||
|
api_logger.debug(f"Start updating the file fields: {file_id}")
|
||||||
|
updated_fields = []
|
||||||
for field, value in update_data.dict(exclude_unset=True).items():
|
for field, value in update_data.dict(exclude_unset=True).items():
|
||||||
if hasattr(db_file, field):
|
if hasattr(db_file, field):
|
||||||
setattr(db_file, field, value)
|
old_value = getattr(db_file, field)
|
||||||
|
if old_value != value:
|
||||||
|
# update value
|
||||||
|
setattr(db_file, field, value)
|
||||||
|
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||||
|
|
||||||
|
if updated_fields:
|
||||||
|
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||||
|
|
||||||
|
# 3. Save to database
|
||||||
try:
|
try:
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_file)
|
db.refresh(db_file)
|
||||||
|
api_logger.info(f"The file has been successfully updated: {db_file.file_name} (ID: {db_file.id})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"File update failed: {str(e)}")
|
api_logger.error(f"File update failed: file_id={file_id} - {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"File update failed: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Return the updated file
|
||||||
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
|
return success(data=jsonable_encoder(file_schema.File.model_validate(db_file)), msg="File information updated successfully")
|
||||||
|
|
||||||
|
|
||||||
@@ -266,43 +399,60 @@ async def update_file(
|
|||||||
async def delete_file(
|
async def delete_file(
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user)
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
|
||||||
):
|
):
|
||||||
"""Delete a file or folder"""
|
"""
|
||||||
api_logger.info(f"Request to delete file: file_id={file_id}")
|
Delete a file or folder
|
||||||
await _delete_file(db=db, file_id=file_id, current_user=current_user, storage_service=storage_service)
|
"""
|
||||||
|
api_logger.info(f"Request to delete file: file_id={file_id}, username: {current_user.username}")
|
||||||
|
await _delete_file(db=db, file_id=file_id, current_user=current_user)
|
||||||
return success(msg="File deleted successfully")
|
return success(msg="File deleted successfully")
|
||||||
|
|
||||||
|
|
||||||
async def _delete_file(
|
async def _delete_file(
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session,
|
db: Session = Depends(get_db),
|
||||||
current_user: User,
|
current_user: User = Depends(get_current_user)
|
||||||
storage_service: FileStorageService,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Delete a file or folder from storage and database"""
|
"""
|
||||||
|
Delete a file or folder
|
||||||
|
"""
|
||||||
|
# 1. Check if the file exists
|
||||||
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
db_file = file_service.get_file_by_id(db, file_id=file_id)
|
||||||
|
|
||||||
if not db_file:
|
if not db_file:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
|
api_logger.warning(f"The file does not exist or you do not have permission to access it: file_id={file_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The file does not exist or you do not have permission to access it"
|
||||||
|
)
|
||||||
|
|
||||||
# Delete from storage backend
|
# 2. Construct physical path
|
||||||
|
file_path = Path(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.id)
|
||||||
|
) if db_file.file_ext == 'folder' else Path(
|
||||||
|
settings.FILE_PATH,
|
||||||
|
str(db_file.kb_id),
|
||||||
|
str(db_file.parent_id),
|
||||||
|
f"{db_file.id}{db_file.file_ext}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Delete physical files/folders
|
||||||
|
try:
|
||||||
|
if file_path.exists():
|
||||||
|
if db_file.file_ext == 'folder':
|
||||||
|
shutil.rmtree(file_path) # Recursively delete folders
|
||||||
|
else:
|
||||||
|
file_path.unlink() # Delete a single file
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to delete physical file/folder: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4.Delete db_file
|
||||||
if db_file.file_ext == 'folder':
|
if db_file.file_ext == 'folder':
|
||||||
# For folders, delete all child files from storage first
|
|
||||||
child_files = db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).all()
|
|
||||||
for child in child_files:
|
|
||||||
if child.file_key:
|
|
||||||
try:
|
|
||||||
await storage_service.delete_file(child.file_key)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"Failed to delete child file from storage: {child.file_key} - {e}")
|
|
||||||
db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
|
db.query(file_model.File).filter(file_model.File.parent_id == db_file.id).delete()
|
||||||
else:
|
|
||||||
if db_file.file_key:
|
|
||||||
try:
|
|
||||||
await storage_service.delete_file(db_file.file_key)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"Failed to delete file from storage: {db_file.file_key} - {e}")
|
|
||||||
|
|
||||||
db.delete(db_file)
|
db.delete(db_file)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -296,7 +296,7 @@ async def chat(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 多 Agent 非流式返回
|
# workflow 非流式返回
|
||||||
result = await app_chat_service.workflow_chat(
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
|
|||||||
@@ -113,33 +113,6 @@ async def create_chunk(
|
|||||||
current_user=current_user)
|
current_user=current_user)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{kb_id}/{document_id}/chunk/batch", response_model=ApiResponse)
|
|
||||||
@require_api_key(scopes=["rag"])
|
|
||||||
async def create_chunks_batch(
|
|
||||||
kb_id: uuid.UUID,
|
|
||||||
document_id: uuid.UUID,
|
|
||||||
request: Request,
|
|
||||||
api_key_auth: ApiKeyAuth = None,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
items: list = Body(..., description="chunk items list"),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Batch create chunks (max 8)
|
|
||||||
"""
|
|
||||||
body = await request.json()
|
|
||||||
batch_data = chunk_schema.ChunkBatchCreate(**body)
|
|
||||||
# 0. Obtain the creator of the api key
|
|
||||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
|
||||||
current_user = api_key.creator
|
|
||||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
|
||||||
|
|
||||||
return await chunk_controller.create_chunks_batch(kb_id=kb_id,
|
|
||||||
document_id=document_id,
|
|
||||||
batch_data=batch_data,
|
|
||||||
db=db,
|
|
||||||
current_user=current_user)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
@router.get("/{kb_id}/{document_id}/{doc_id}", response_model=ApiResponse)
|
||||||
@require_api_key(scopes=["rag"])
|
@require_api_key(scopes=["rag"])
|
||||||
async def get_chunk(
|
async def get_chunk(
|
||||||
@@ -203,7 +176,6 @@ async def delete_chunk(
|
|||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
force_refresh: bool = Query(False, description="Force Elasticsearch refresh after deletion"),
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
delete document chunk
|
delete document chunk
|
||||||
@@ -216,7 +188,6 @@ async def delete_chunk(
|
|||||||
return await chunk_controller.delete_chunk(kb_id=kb_id,
|
return await chunk_controller.delete_chunk(kb_id=kb_id,
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
doc_id=doc_id,
|
doc_id=doc_id,
|
||||||
force_refresh=force_refresh,
|
|
||||||
db=db,
|
db=db,
|
||||||
current_user=current_user)
|
current_user=current_user)
|
||||||
|
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ def update_workspace_members(
|
|||||||
|
|
||||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def delete_workspace_member(
|
async def delete_workspace_member(
|
||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -230,7 +230,7 @@ def delete_workspace_member(
|
|||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
|
|
||||||
workspace_service.delete_workspace_member(
|
await workspace_service.delete_workspace_member(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
member_id=member_id,
|
member_id=member_id,
|
||||||
|
|||||||
@@ -98,7 +98,6 @@ class Settings:
|
|||||||
# File Upload
|
# File Upload
|
||||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||||
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||||
MAX_CHUNK_BATCH_SIZE: int = int(os.getenv("MAX_CHUNK_BATCH_SIZE", "8"))
|
|
||||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||||
|
|
||||||
|
|||||||
@@ -216,7 +216,7 @@ class RedBearModelFactory:
|
|||||||
# 深度思考模式:Claude 3.7 Sonnet 等支持思考的模型
|
# 深度思考模式:Claude 3.7 Sonnet 等支持思考的模型
|
||||||
# 通过 additional_model_request_fields 传递 thinking 块,关闭时不传(Bedrock 无 disabled 选项)
|
# 通过 additional_model_request_fields 传递 thinking 块,关闭时不传(Bedrock 无 disabled 选项)
|
||||||
if config.deep_thinking:
|
if config.deep_thinking:
|
||||||
budget = config.thinking_budget_tokens or 10000
|
budget = config.thinking_budget_tokens or 1024
|
||||||
params["additional_model_request_fields"] = {
|
params["additional_model_request_fields"] = {
|
||||||
"thinking": {"type": "enabled", "budget_tokens": budget}
|
"thinking": {"type": "enabled", "budget_tokens": budget}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,10 +46,7 @@ async def run_graphrag(
|
|||||||
start = trio.current_time()
|
start = trio.current_time()
|
||||||
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
|
workspace_id, kb_id, document_id = row["workspace_id"], str(row["kb_id"]), row["document_id"]
|
||||||
chunks = []
|
chunks = []
|
||||||
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id", "chunk_type"], sort_by_position=True):
|
for d in settings.retriever.chunk_list(document_id, workspace_id, [kb_id], fields=["page_content", "document_id"], sort_by_position=True):
|
||||||
# 跳过 QA chunks,只用原文 chunks 构建图谱
|
|
||||||
if d.get("chunk_type") == "qa":
|
|
||||||
continue
|
|
||||||
chunks.append(d["page_content"])
|
chunks.append(d["page_content"])
|
||||||
|
|
||||||
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
with trio.fail_after(max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000):
|
||||||
@@ -153,9 +150,6 @@ async def run_graphrag_for_kb(
|
|||||||
|
|
||||||
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
|
total, items = vector_service.search_by_segment(document_id=str(document_id), query=None, pagesize=9999, page=1, asc=True)
|
||||||
for doc in items:
|
for doc in items:
|
||||||
# 跳过 QA chunks,只用原文 chunks 构建图谱
|
|
||||||
if (doc.metadata or {}).get("chunk_type") == "qa":
|
|
||||||
continue
|
|
||||||
content = doc.page_content
|
content = doc.page_content
|
||||||
if num_tokens_from_string(current_chunk + content) < 1024:
|
if num_tokens_from_string(current_chunk + content) < 1024:
|
||||||
current_chunk += content
|
current_chunk += content
|
||||||
|
|||||||
@@ -131,52 +131,18 @@ def keyword_extraction(chat_mdl, content, topn=3):
|
|||||||
|
|
||||||
|
|
||||||
def question_proposal(chat_mdl, content, topn=3):
|
def question_proposal(chat_mdl, content, topn=3):
|
||||||
"""生成问题(向后兼容,返回纯文本问题列表)"""
|
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
|
||||||
pairs = qa_proposal(chat_mdl, content, topn)
|
rendered_prompt = template.render(content=content, topn=topn)
|
||||||
if not pairs:
|
|
||||||
return ""
|
|
||||||
return "\n".join([p["question"] for p in pairs])
|
|
||||||
|
|
||||||
|
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||||
def qa_proposal(chat_mdl, content, topn=3, custom_prompt=None):
|
|
||||||
"""生成 QA 对,返回 [{"question": ..., "answer": ...}, ...]
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_mdl: LLM 模型
|
|
||||||
content: 文本内容
|
|
||||||
topn: 生成 QA 对数量
|
|
||||||
custom_prompt: 自定义 prompt 模板(支持 Jinja2,可用变量: content, topn)
|
|
||||||
"""
|
|
||||||
if custom_prompt:
|
|
||||||
template = PROMPT_JINJA_ENV.from_string(custom_prompt)
|
|
||||||
sys_prompt = template.render(topn=topn)
|
|
||||||
else:
|
|
||||||
sys_prompt = QUESTION_PROMPT_TEMPLATE
|
|
||||||
msg = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": content}]
|
|
||||||
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
|
_, msg = message_fit_in(msg, getattr(chat_mdl, 'max_length', 8096))
|
||||||
raw = chat_mdl.chat(sys_prompt, msg[1:], {"temperature": 0.2})
|
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
||||||
if isinstance(raw, tuple):
|
if isinstance(kwd, tuple):
|
||||||
raw = raw[0]
|
kwd = kwd[0]
|
||||||
raw = re.sub(r"^.*</think>", "", raw, flags=re.DOTALL)
|
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||||
if raw.find("**ERROR**") >= 0:
|
if kwd.find("**ERROR**") >= 0:
|
||||||
return []
|
return ""
|
||||||
return parse_qa_pairs(raw)
|
return kwd
|
||||||
|
|
||||||
|
|
||||||
def parse_qa_pairs(text: str) -> list:
|
|
||||||
"""解析 LLM 返回的 QA 对文本,格式: Q: xxx A: xxx"""
|
|
||||||
pairs = []
|
|
||||||
for line in text.strip().split("\n"):
|
|
||||||
line = line.strip()
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
# 匹配 Q: ... A: ... 格式
|
|
||||||
match = re.match(r'^Q:\s*(.+?)\s+A:\s*(.+)$', line, re.IGNORECASE)
|
|
||||||
if match:
|
|
||||||
q, a = match.group(1).strip(), match.group(2).strip()
|
|
||||||
if q and a:
|
|
||||||
pairs.append({"question": q, "answer": a})
|
|
||||||
return pairs
|
|
||||||
|
|
||||||
|
|
||||||
def graph_entity_types(chat_mdl, scenario):
|
def graph_entity_types(chat_mdl, scenario):
|
||||||
|
|||||||
@@ -1,20 +1,19 @@
|
|||||||
## Role
|
## Role
|
||||||
You are a text analyzer and knowledge extraction expert.
|
You are a text analyzer.
|
||||||
|
|
||||||
## Task
|
## Task
|
||||||
Generate question-answer pairs from the given text content.
|
Propose {{ topn }} questions about a given piece of text content.
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
- Understand and summarize the text content, then generate up to {{ topn }} important question-answer pairs.
|
- Understand and summarize the text content, and propose the top {{ topn }} important questions.
|
||||||
- Each question-answer pair MUST be on a single line, formatted as: Q: <question> A: <answer>
|
|
||||||
- The questions SHOULD NOT have overlapping meanings.
|
- The questions SHOULD NOT have overlapping meanings.
|
||||||
- The questions SHOULD cover the main content of the text as much as possible.
|
- The questions SHOULD cover the main content of the text as much as possible.
|
||||||
- The answers MUST be concise, accurate, and directly derived from the text content.
|
- The questions MUST be in the same language as the given piece of text content.
|
||||||
- The answers SHOULD be self-contained and understandable without additional context.
|
- One question per line.
|
||||||
- Both questions and answers MUST be in the same language as the given text content.
|
- Output questions ONLY.
|
||||||
- If the text is too short or lacks substantive content, generate fewer pairs rather than padding.
|
|
||||||
- Output question-answer pairs ONLY, no extra explanation or commentary.
|
---
|
||||||
|
|
||||||
|
## Text Content
|
||||||
|
{{ content }}
|
||||||
|
|
||||||
## Example Output
|
|
||||||
Q: What is the capital of France? A: The capital of France is Paris.
|
|
||||||
Q: When was the Eiffel Tower built? A: The Eiffel Tower was built in 1889.
|
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ Transcribe the content from the provided PDF page image into clean Markdown form
|
|||||||
6. Do NOT wrap the output in ```markdown or ``` blocks.
|
6. Do NOT wrap the output in ```markdown or ``` blocks.
|
||||||
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
|
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
|
||||||
8. Preserve the original language, information, and order exactly as shown in the image.
|
8. Preserve the original language, information, and order exactly as shown in the image.
|
||||||
9. Your output language MUST match the language of the content in the image. If the image contains Chinese text, output in Chinese. If English, output in English. Never translate.
|
|
||||||
|
|
||||||
{% if page %}
|
{% if page %}
|
||||||
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
|
At the end of the transcription, add the page divider: `--- Page {{ page }} ---`.
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from elasticsearch import Elasticsearch, helpers, NotFoundError
|
from elasticsearch import Elasticsearch, helpers
|
||||||
from elasticsearch.helpers import BulkIndexError
|
from elasticsearch.helpers import BulkIndexError
|
||||||
from packaging.version import parse as parse_version
|
from packaging.version import parse as parse_version
|
||||||
# langchain-community
|
# langchain-community
|
||||||
@@ -53,30 +53,13 @@ class ElasticSearchVector(BaseVector):
|
|||||||
return "elasticsearch"
|
return "elasticsearch"
|
||||||
|
|
||||||
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
||||||
# QA chunks: embedding 只对 question 字段做;source chunks: 不做 embedding
|
# 实现 Elasticsearch 保存向量
|
||||||
texts_for_embedding = []
|
texts = [chunk.page_content for chunk in chunks]
|
||||||
for chunk in chunks:
|
|
||||||
chunk_type = (chunk.metadata or {}).get("chunk_type", "chunk")
|
|
||||||
if chunk_type == "source":
|
|
||||||
# source chunk 不需要向量索引
|
|
||||||
texts_for_embedding.append("")
|
|
||||||
elif chunk_type == "qa":
|
|
||||||
# QA chunk: 用 question 字段做 embedding
|
|
||||||
texts_for_embedding.append((chunk.metadata or {}).get("question", chunk.page_content))
|
|
||||||
else:
|
|
||||||
# 普通 chunk: 用 page_content 做 embedding
|
|
||||||
texts_for_embedding.append(chunk.page_content)
|
|
||||||
|
|
||||||
if self.is_multimodal_embedding:
|
if self.is_multimodal_embedding:
|
||||||
embeddings = self.embeddings.embed_batch(texts_for_embedding)
|
# 火山引擎多模态 Embedding
|
||||||
|
embeddings = self.embeddings.embed_batch(texts)
|
||||||
else:
|
else:
|
||||||
embeddings = self.embeddings.embed_documents(texts_for_embedding)
|
embeddings = self.embeddings.embed_documents(list(texts))
|
||||||
|
|
||||||
# source chunk 的向量置空
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
if (chunk.metadata or {}).get("chunk_type") == "source":
|
|
||||||
embeddings[i] = None
|
|
||||||
|
|
||||||
self.create(chunks, embeddings, **kwargs)
|
self.create(chunks, embeddings, **kwargs)
|
||||||
|
|
||||||
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
||||||
@@ -89,25 +72,13 @@ class ElasticSearchVector(BaseVector):
|
|||||||
uuids = self._get_uuids(chunks)
|
uuids = self._get_uuids(chunks)
|
||||||
actions = []
|
actions = []
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
source = {
|
|
||||||
Field.CONTENT_KEY.value: chunk.page_content,
|
|
||||||
Field.METADATA_KEY.value: chunk.metadata or {},
|
|
||||||
Field.VECTOR.value: embeddings[i] or None
|
|
||||||
}
|
|
||||||
# 写入 QA 相关字段
|
|
||||||
meta = chunk.metadata or {}
|
|
||||||
if meta.get("chunk_type"):
|
|
||||||
source[Field.CHUNK_TYPE.value] = meta["chunk_type"]
|
|
||||||
if meta.get("question"):
|
|
||||||
source[Field.QUESTION.value] = meta["question"]
|
|
||||||
if meta.get("answer"):
|
|
||||||
source[Field.ANSWER.value] = meta["answer"]
|
|
||||||
if meta.get("source_chunk_id"):
|
|
||||||
source[Field.SOURCE_CHUNK_ID.value] = meta["source_chunk_id"]
|
|
||||||
|
|
||||||
action = {
|
action = {
|
||||||
"_index": self._collection_name,
|
"_index": self._collection_name,
|
||||||
"_source": source
|
"_source": {
|
||||||
|
Field.CONTENT_KEY.value: chunk.page_content,
|
||||||
|
Field.METADATA_KEY.value: chunk.metadata or {},
|
||||||
|
Field.VECTOR.value: embeddings[i] or None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
actions.append(action)
|
actions.append(action)
|
||||||
# using bulk mode
|
# using bulk mode
|
||||||
@@ -142,7 +113,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def delete_by_ids(self, ids: list[str], *, refresh: bool = False):
|
def delete_by_ids(self, ids: list[str]):
|
||||||
if not ids:
|
if not ids:
|
||||||
return
|
return
|
||||||
if not self._client.indices.exists(index=self._collection_name):
|
if not self._client.indices.exists(index=self._collection_name):
|
||||||
@@ -163,8 +134,6 @@ class ElasticSearchVector(BaseVector):
|
|||||||
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
||||||
try:
|
try:
|
||||||
helpers.bulk(self._client, actions)
|
helpers.bulk(self._client, actions)
|
||||||
if refresh:
|
|
||||||
self._client.indices.refresh(index=self._collection_name)
|
|
||||||
except BulkIndexError as e:
|
except BulkIndexError as e:
|
||||||
for error in e.errors:
|
for error in e.errors:
|
||||||
delete_error = error.get('delete', {})
|
delete_error = error.get('delete', {})
|
||||||
@@ -184,7 +153,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False):
|
def delete_by_metadata_field(self, key: str, value: str):
|
||||||
if not self._client.indices.exists(index=self._collection_name):
|
if not self._client.indices.exists(index=self._collection_name):
|
||||||
return False
|
return False
|
||||||
actual_ids = self.get_ids_by_metadata_field(key, value)
|
actual_ids = self.get_ids_by_metadata_field(key, value)
|
||||||
@@ -193,8 +162,6 @@ class ElasticSearchVector(BaseVector):
|
|||||||
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
actions = [{"_op_type": "delete", "_index": self._collection_name, "_id": es_id} for es_id in actual_ids]
|
||||||
try:
|
try:
|
||||||
helpers.bulk(self._client, actions)
|
helpers.bulk(self._client, actions)
|
||||||
if refresh:
|
|
||||||
self._client.indices.refresh(index=self._collection_name)
|
|
||||||
except BulkIndexError as e:
|
except BulkIndexError as e:
|
||||||
for error in e.errors:
|
for error in e.errors:
|
||||||
delete_error = error.get('delete', {})
|
delete_error = error.get('delete', {})
|
||||||
@@ -225,8 +192,6 @@ class ElasticSearchVector(BaseVector):
|
|||||||
List of DocumentChunk objects that match the query.
|
List of DocumentChunk objects that match the query.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3"
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multiple indexes are also supported, such as "index1, index2, index3"
|
||||||
if not self._client.indices.exists(index=indices):
|
|
||||||
return 0, []
|
|
||||||
|
|
||||||
# Calculate the start position for the current page
|
# Calculate the start position for the current page
|
||||||
from_ = pagesize * (page-1)
|
from_ = pagesize * (page-1)
|
||||||
@@ -261,15 +226,12 @@ class ElasticSearchVector(BaseVector):
|
|||||||
})
|
})
|
||||||
|
|
||||||
# For simplicity, we use from/size here which has a limit (usually up to 10,000).
|
# For simplicity, we use from/size here which has a limit (usually up to 10,000).
|
||||||
try:
|
result = self._client.search(
|
||||||
result = self._client.search(
|
index=indices,
|
||||||
index=indices,
|
from_=from_, # Only use from_ for the first page (simplified)
|
||||||
from_=from_, # Only use from_ for the first page (simplified)
|
size=pagesize,
|
||||||
size=pagesize,
|
body=query_str,
|
||||||
body=query_str,
|
)
|
||||||
)
|
|
||||||
except NotFoundError:
|
|
||||||
return 0, []
|
|
||||||
|
|
||||||
if "errors" in result:
|
if "errors" in result:
|
||||||
raise ValueError(f"Error during query: {result['errors']}")
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
@@ -279,19 +241,10 @@ class ElasticSearchVector(BaseVector):
|
|||||||
for res in result["hits"]["hits"]:
|
for res in result["hits"]["hits"]:
|
||||||
source = res["_source"]
|
source = res["_source"]
|
||||||
page_content = source.get(Field.CONTENT_KEY.value)
|
page_content = source.get(Field.CONTENT_KEY.value)
|
||||||
|
# vector = source.get(Field.VECTOR.value)
|
||||||
vector = None
|
vector = None
|
||||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||||
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
|
||||||
score = res["_score"]
|
score = res["_score"]
|
||||||
|
|
||||||
# 将 QA 字段注入 metadata 供前端展示
|
|
||||||
if chunk_type:
|
|
||||||
metadata["chunk_type"] = chunk_type
|
|
||||||
if chunk_type == "qa":
|
|
||||||
metadata["question"] = source.get(Field.QUESTION.value, "")
|
|
||||||
metadata["answer"] = source.get(Field.ANSWER.value, "")
|
|
||||||
page_content = f"Q: {metadata['question']}\nA: {metadata['answer']}"
|
|
||||||
|
|
||||||
docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score))
|
docs_and_scores.append((DocumentChunk(page_content=page_content, vector=vector, metadata=metadata), score))
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
@@ -314,18 +267,13 @@ class ElasticSearchVector(BaseVector):
|
|||||||
List of DocumentChunk objects that match the query.
|
List of DocumentChunk objects that match the query.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||||
if not self._client.indices.exists(index=indices):
|
|
||||||
return 0, []
|
|
||||||
query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}}
|
query_str = {"query": {"term": {f"{Field.DOC_ID.value}": doc_id}}}
|
||||||
try:
|
result = self._client.search(
|
||||||
result = self._client.search(
|
index=indices,
|
||||||
index=indices,
|
from_=0, # Only use from_ for the first page (simplified)
|
||||||
from_=0, # Only use from_ for the first page (simplified)
|
size=1,
|
||||||
size=1,
|
body=query_str,
|
||||||
body=query_str,
|
)
|
||||||
)
|
|
||||||
except NotFoundError:
|
|
||||||
return 0, []
|
|
||||||
# print(result)
|
# print(result)
|
||||||
if "errors" in result:
|
if "errors" in result:
|
||||||
raise ValueError(f"Error during query: {result['errors']}")
|
raise ValueError(f"Error during query: {result['errors']}")
|
||||||
@@ -360,43 +308,27 @@ class ElasticSearchVector(BaseVector):
|
|||||||
Returns:
|
Returns:
|
||||||
updated count.
|
updated count.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name)
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||||
chunk_type = (chunk.metadata or {}).get("chunk_type")
|
if self.is_multimodal_embedding:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
# QA chunk: embedding 基于 question;source chunk: 不更新向量
|
chunk.vector = self.embeddings.embed_text(chunk.page_content)
|
||||||
if chunk_type == "source":
|
|
||||||
embed_text = ""
|
|
||||||
elif chunk_type == "qa":
|
|
||||||
embed_text = (chunk.metadata or {}).get("question", chunk.page_content)
|
|
||||||
else:
|
else:
|
||||||
embed_text = chunk.page_content
|
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||||
|
|
||||||
if chunk_type != "source":
|
|
||||||
if self.is_multimodal_embedding:
|
|
||||||
chunk.vector = self.embeddings.embed_text(embed_text)
|
|
||||||
else:
|
|
||||||
chunk.vector = self.embeddings.embed_query(embed_text)
|
|
||||||
|
|
||||||
script_source = "ctx._source.page_content = params.new_content; ctx._source.vector = params.new_vector;"
|
|
||||||
params = {
|
|
||||||
"new_content": chunk.page_content,
|
|
||||||
"new_vector": chunk.vector if chunk_type != "source" else None
|
|
||||||
}
|
|
||||||
|
|
||||||
# QA chunk: 同时更新 question/answer 字段
|
|
||||||
if chunk_type == "qa":
|
|
||||||
script_source += " ctx._source.question = params.new_question; ctx._source.answer = params.new_answer;"
|
|
||||||
params["new_question"] = (chunk.metadata or {}).get("question", "")
|
|
||||||
params["new_answer"] = (chunk.metadata or {}).get("answer", "")
|
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
"script": {
|
"script": {
|
||||||
"source": script_source,
|
"source": """
|
||||||
"params": params
|
ctx._source.page_content = params.new_content;
|
||||||
|
ctx._source.vector = params.new_vector;
|
||||||
|
""",
|
||||||
|
"params": {
|
||||||
|
"new_content": chunk.page_content,
|
||||||
|
"new_vector": chunk.vector
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"query": {
|
"query": {
|
||||||
"term": {
|
"term": {
|
||||||
Field.DOC_ID.value: chunk.metadata["doc_id"]
|
Field.DOC_ID.value: chunk.metadata["doc_id"] # exact match doc_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -404,6 +336,9 @@ class ElasticSearchVector(BaseVector):
|
|||||||
index=indices,
|
index=indices,
|
||||||
body=body,
|
body=body,
|
||||||
)
|
)
|
||||||
|
# Remove debug printing and use logging instead
|
||||||
|
# print(result)
|
||||||
|
# print(f"Update successful, number of affected documents: {result['updated']}")
|
||||||
return result['updated']
|
return result['updated']
|
||||||
|
|
||||||
def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str:
|
def change_status_by_document_id(self, document_id: str, status: int, **kwargs) -> str:
|
||||||
@@ -462,11 +397,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": [
|
"filter": { # Add the filter condition of status=1
|
||||||
{"term": {"metadata.status": 1}},
|
"term": {
|
||||||
# 排除 source chunk(仅供 GraphRAG 使用,不参与检索)
|
"metadata.status": 1
|
||||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
}
|
||||||
]
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# If file_names_filter is passed in, merge the filtering conditions
|
# If file_names_filter is passed in, merge the filtering conditions
|
||||||
@@ -480,14 +415,22 @@ class ElasticSearchVector(BaseVector):
|
|||||||
},
|
},
|
||||||
"script": {
|
"script": {
|
||||||
"source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0",
|
"source": f"cosineSimilarity(params.query_vector, '{Field.VECTOR.value}') + 1.0",
|
||||||
|
# The script_score query calculates the cosine similarity between the embedding field of each document and the query vector. The addition of +1.0 is to ensure that the scores returned by the script are non-negative, as the range of cosine similarity is [-1, 1]
|
||||||
"params": {"query_vector": query_vector}
|
"params": {"query_vector": query_vector}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": [
|
"filter": [
|
||||||
{"term": {"metadata.status": 1}},
|
{
|
||||||
{"terms": {"metadata.file_name": file_names_filter}},
|
"term": {
|
||||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
"metadata.status": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"terms": {
|
||||||
|
"metadata.file_name": file_names_filter # Additional file_name filtering
|
||||||
|
}
|
||||||
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -508,19 +451,8 @@ class ElasticSearchVector(BaseVector):
|
|||||||
source = res["_source"]
|
source = res["_source"]
|
||||||
page_content = source.get(Field.CONTENT_KEY.value)
|
page_content = source.get(Field.CONTENT_KEY.value)
|
||||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||||
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
|
||||||
score = res["_score"]
|
score = res["_score"]
|
||||||
score = score / 2 # Normalized [0-1]
|
score = score / 2 # Normalized [0-1]
|
||||||
|
|
||||||
# QA chunk: 返回 Q+A 拼接作为上下文
|
|
||||||
if chunk_type == "qa":
|
|
||||||
question = source.get(Field.QUESTION.value, "")
|
|
||||||
answer = source.get(Field.ANSWER.value, "")
|
|
||||||
page_content = f"Q: {question}\nA: {answer}"
|
|
||||||
metadata["chunk_type"] = "qa"
|
|
||||||
metadata["question"] = question
|
|
||||||
metadata["answer"] = answer
|
|
||||||
|
|
||||||
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score))
|
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), score))
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
@@ -559,10 +491,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": [
|
"filter": { # Add the filter condition of status=1
|
||||||
{"term": {"metadata.status": 1}},
|
"term": {
|
||||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
"metadata.status": 1
|
||||||
]
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -579,9 +512,16 @@ class ElasticSearchVector(BaseVector):
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"filter": [
|
"filter": [
|
||||||
{"term": {"metadata.status": 1}},
|
{
|
||||||
{"terms": {"metadata.file_name": file_names_filter}},
|
"term": {
|
||||||
{"bool": {"must_not": {"term": {Field.CHUNK_TYPE.value: "source"}}}}
|
"metadata.status": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"terms": {
|
||||||
|
"metadata.file_name": file_names_filter # Additional file_name filtering
|
||||||
|
}
|
||||||
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -603,17 +543,6 @@ class ElasticSearchVector(BaseVector):
|
|||||||
source = res["_source"]
|
source = res["_source"]
|
||||||
page_content = source.get(Field.CONTENT_KEY.value)
|
page_content = source.get(Field.CONTENT_KEY.value)
|
||||||
metadata = source.get(Field.METADATA_KEY.value, {})
|
metadata = source.get(Field.METADATA_KEY.value, {})
|
||||||
chunk_type = source.get(Field.CHUNK_TYPE.value)
|
|
||||||
|
|
||||||
# QA chunk: 返回 Q+A 拼接作为上下文
|
|
||||||
if chunk_type == "qa":
|
|
||||||
question = source.get(Field.QUESTION.value, "")
|
|
||||||
answer = source.get(Field.ANSWER.value, "")
|
|
||||||
page_content = f"Q: {question}\nA: {answer}"
|
|
||||||
metadata["chunk_type"] = "qa"
|
|
||||||
metadata["question"] = question
|
|
||||||
metadata["answer"] = answer
|
|
||||||
|
|
||||||
# Normalize the score to the [0,1] interval
|
# Normalize the score to the [0,1] interval
|
||||||
normalized_score = res["_score"] / max_score
|
normalized_score = res["_score"] / max_score
|
||||||
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score))
|
docs_and_scores.append((DocumentChunk(page_content=page_content, metadata=metadata), normalized_score))
|
||||||
@@ -723,7 +652,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
},
|
},
|
||||||
Field.VECTOR.value: {
|
Field.VECTOR.value: {
|
||||||
"type": "dense_vector",
|
"type": "dense_vector",
|
||||||
"dims": len(next((e for e in embeddings if e is not None), [0]*768)), # 跳过 None 获取向量维度,fallback 768
|
"dims": len(embeddings[0]), # Make sure the dimension is correct here,The dimension size of the vector. When index is true, it cannot exceed 1024; when index is false or not specified, it cannot exceed 2048, which can improve retrieval efficiency
|
||||||
"index": True,
|
"index": True,
|
||||||
"similarity": "cosine"
|
"similarity": "cosine"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,3 @@ class Field(StrEnum):
|
|||||||
DOCUMENT_ID = "metadata.document_id"
|
DOCUMENT_ID = "metadata.document_id"
|
||||||
KNOWLEDGE_ID = "metadata.knowledge_id"
|
KNOWLEDGE_ID = "metadata.knowledge_id"
|
||||||
SORT_ID = "metadata.sort_id"
|
SORT_ID = "metadata.sort_id"
|
||||||
# QA fields
|
|
||||||
CHUNK_TYPE = "chunk_type" # "chunk" | "source" | "qa"
|
|
||||||
QUESTION = "question"
|
|
||||||
ANSWER = "answer"
|
|
||||||
SOURCE_CHUNK_ID = "source_chunk_id"
|
|
||||||
|
|||||||
@@ -27,14 +27,14 @@ class BaseVector(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_ids(self, ids: list[str], *, refresh: bool = False):
|
def delete_by_ids(self, ids: list[str]):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete_by_metadata_field(self, key: str, value: str, *, refresh: bool = False):
|
def delete_by_metadata_field(self, key: str, value: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ class DocExtractorNode(BaseNode):
|
|||||||
mime_type=f"image/{ext}",
|
mime_type=f"image/{ext}",
|
||||||
is_file=True,
|
is_file=True,
|
||||||
).model_dump())
|
).model_dump())
|
||||||
text = text + f"\n{placeholder}: {url}"
|
text = text + f"\n{placeholder}: <img src=\"{url}\" data-url=\"{url}\">"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}")
|
logger.error(f"Node {self.node_id}: failed to save image {placeholder}: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -15,5 +15,4 @@ class File(Base):
|
|||||||
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
|
file_ext = Column(String, index=True, nullable=False, comment="file extension:folder|pdf")
|
||||||
file_size = Column(Integer, default=0, comment="file size(byte)")
|
file_size = Column(Integer, default=0, comment="file size(byte)")
|
||||||
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
|
file_url = Column(String, index=True, nullable=True, comment="file comes from a website url")
|
||||||
file_key = Column(String(512), nullable=True, index=True, comment="storage file key for FileStorageService")
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||||
@@ -250,7 +250,7 @@ class ModelParameters(BaseModel):
|
|||||||
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
|
n: int = Field(default=1, ge=1, le=10, description="生成的回复数量")
|
||||||
stop: Optional[List[str]] = Field(default=None, description="停止序列")
|
stop: Optional[List[str]] = Field(default=None, description="停止序列")
|
||||||
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
|
deep_thinking: bool = Field(default=False, description="是否启用深度思考模式(需模型支持,如 DeepSeek-R1、QwQ 等)")
|
||||||
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1024, le=131072, description="深度思考 token 预算(仅部分模型支持)")
|
thinking_budget_tokens: Optional[int] = Field(default=None, ge=1, le=131072, description="深度思考 token 预算(仅部分模型支持)")
|
||||||
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")
|
json_output: bool = Field(default=False, description="是否强制 JSON 格式输出(需模型支持 json_output 能力)")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -20,26 +20,13 @@ class ChunkCreate(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def chunk_content(self) -> str:
|
def chunk_content(self) -> str:
|
||||||
"""Get the actual content string regardless of input type"""
|
"""
|
||||||
|
Get the actual content string regardless of input type
|
||||||
|
"""
|
||||||
if isinstance(self.content, QAChunk):
|
if isinstance(self.content, QAChunk):
|
||||||
return self.content.question # QA 模式下 page_content 存 question
|
return f"question: {self.content.question} answer: {self.content.answer}"
|
||||||
return self.content
|
return self.content
|
||||||
|
|
||||||
@property
|
|
||||||
def is_qa(self) -> bool:
|
|
||||||
return isinstance(self.content, QAChunk)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def qa_metadata(self) -> dict:
|
|
||||||
"""返回 QA 相关的 metadata 字段"""
|
|
||||||
if isinstance(self.content, QAChunk):
|
|
||||||
return {
|
|
||||||
"chunk_type": "qa",
|
|
||||||
"question": self.content.question,
|
|
||||||
"answer": self.content.answer,
|
|
||||||
}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkUpdate(BaseModel):
|
class ChunkUpdate(BaseModel):
|
||||||
content: Union[str, QAChunk] = Field(
|
content: Union[str, QAChunk] = Field(
|
||||||
@@ -48,26 +35,13 @@ class ChunkUpdate(BaseModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def chunk_content(self) -> str:
|
def chunk_content(self) -> str:
|
||||||
"""Get the actual content string regardless of input type"""
|
"""
|
||||||
|
Get the actual content string regardless of input type
|
||||||
|
"""
|
||||||
if isinstance(self.content, QAChunk):
|
if isinstance(self.content, QAChunk):
|
||||||
return self.content.question # QA 模式下 page_content 存 question
|
return f"question: {self.content.question} answer: {self.content.answer}"
|
||||||
return self.content
|
return self.content
|
||||||
|
|
||||||
@property
|
|
||||||
def is_qa(self) -> bool:
|
|
||||||
return isinstance(self.content, QAChunk)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def qa_metadata(self) -> dict:
|
|
||||||
"""返回 QA 相关的 metadata 字段"""
|
|
||||||
if isinstance(self.content, QAChunk):
|
|
||||||
return {
|
|
||||||
"chunk_type": "qa",
|
|
||||||
"question": self.content.question,
|
|
||||||
"answer": self.content.answer,
|
|
||||||
}
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkRetrieve(BaseModel):
|
class ChunkRetrieve(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
@@ -77,8 +51,3 @@ class ChunkRetrieve(BaseModel):
|
|||||||
vector_similarity_weight: float | None = Field(None)
|
vector_similarity_weight: float | None = Field(None)
|
||||||
top_k: int | None = Field(None)
|
top_k: int | None = Field(None)
|
||||||
retrieve_type: RetrieveType | None = Field(None)
|
retrieve_type: RetrieveType | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class ChunkBatchCreate(BaseModel):
|
|
||||||
"""批量创建 chunk"""
|
|
||||||
items: list[ChunkCreate] = Field(..., min_length=1, description="chunk 列表")
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ class FileBase(BaseModel):
|
|||||||
file_ext: str
|
file_ext: str
|
||||||
file_size: int
|
file_size: int
|
||||||
file_url: str | None = None
|
file_url: str | None = None
|
||||||
file_key: str | None = None
|
|
||||||
created_at: datetime.datetime | None = None
|
created_at: datetime.datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -161,7 +161,10 @@ class AppChatService:
|
|||||||
f.type == FileType.DOCUMENT for f in files
|
f.type == FileType.DOCUMENT for f in files
|
||||||
):
|
):
|
||||||
system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
||||||
|
"请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
|
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
||||||
|
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
# 创建 LangChain Agent
|
||||||
@@ -448,7 +451,10 @@ class AppChatService:
|
|||||||
):
|
):
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
||||||
|
"请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
|
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
||||||
|
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
# 创建 LangChain Agent
|
||||||
|
|||||||
@@ -650,7 +650,10 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
if has_doc_with_images:
|
if has_doc_with_images:
|
||||||
system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
||||||
|
"请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
|
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
||||||
|
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
||||||
)
|
)
|
||||||
|
|
||||||
agent = LangChainAgent(
|
agent = LangChainAgent(
|
||||||
@@ -924,7 +927,10 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
if has_doc_with_images:
|
if has_doc_with_images:
|
||||||
system_prompt += (
|
system_prompt += (
|
||||||
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: http://...,请在回答中用 Markdown 格式  展示对应图片。"
|
"\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>,"
|
||||||
|
"请在回答中用 Markdown 格式  展示对应图片。"
|
||||||
|
"重要:图片 URL 中包含 UUID(如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx),"
|
||||||
|
"必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
# 创建 LangChain Agent
|
||||||
|
|||||||
@@ -34,7 +34,26 @@ def generate_file_key(
|
|||||||
Generate a unique file key for storage.
|
Generate a unique file key for storage.
|
||||||
|
|
||||||
The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext}
|
The file key follows the format: {tenant_id}/{workspace_id}/{file_id}{file_ext}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: The tenant UUID.
|
||||||
|
workspace_id: The workspace UUID.
|
||||||
|
file_id: The file UUID.
|
||||||
|
file_ext: The file extension (e.g., '.pdf', '.txt').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A unique file key string.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> generate_file_key(
|
||||||
|
... uuid.UUID('550e8400-e29b-41d4-a716-446655440000'),
|
||||||
|
... uuid.UUID('660e8400-e29b-41d4-a716-446655440001'),
|
||||||
|
... uuid.UUID('770e8400-e29b-41d4-a716-446655440002'),
|
||||||
|
... '.pdf'
|
||||||
|
... )
|
||||||
|
'550e8400-e29b-41d4-a716-446655440000/660e8400-e29b-41d4-a716-446655440001/770e8400-e29b-41d4-a716-446655440002.pdf'
|
||||||
"""
|
"""
|
||||||
|
# Ensure file_ext starts with a dot
|
||||||
if file_ext and not file_ext.startswith('.'):
|
if file_ext and not file_ext.startswith('.'):
|
||||||
file_ext = f'.{file_ext}'
|
file_ext = f'.{file_ext}'
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
@@ -42,21 +61,6 @@ def generate_file_key(
|
|||||||
return f"{tenant_id}/{file_id}{file_ext}"
|
return f"{tenant_id}/{file_id}{file_ext}"
|
||||||
|
|
||||||
|
|
||||||
def generate_kb_file_key(
|
|
||||||
kb_id: uuid.UUID,
|
|
||||||
file_id: uuid.UUID,
|
|
||||||
file_ext: str,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Generate a file key for knowledge base files.
|
|
||||||
|
|
||||||
Format: kb/{kb_id}/{file_id}{file_ext}
|
|
||||||
"""
|
|
||||||
if file_ext and not file_ext.startswith('.'):
|
|
||||||
file_ext = f'.{file_ext}'
|
|
||||||
return f"kb/{kb_id}/{file_id}{file_ext}"
|
|
||||||
|
|
||||||
|
|
||||||
class FileStorageService:
|
class FileStorageService:
|
||||||
"""
|
"""
|
||||||
High-level service for file storage operations.
|
High-level service for file storage operations.
|
||||||
|
|||||||
@@ -400,7 +400,7 @@ class MultimodalService:
|
|||||||
# 在文本内容中追加图片位置标记
|
# 在文本内容中追加图片位置标记
|
||||||
if result and result[-1].get("type") in ("text", "document"):
|
if result and result[-1].get("type") in ("text", "document"):
|
||||||
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
|
key = "text" if "text" in result[-1] else list(result[-1].keys())[-1]
|
||||||
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: {img_url}"
|
result[-1][key] = result[-1].get(key, "") + f"\n[图片 {placeholder}]: <img src=\"{img_url}\" data-url=\"{img_url}\">"
|
||||||
# 将图片以视觉格式追加到消息内容中
|
# 将图片以视觉格式追加到消息内容中
|
||||||
img_file = FileInput(
|
img_file = FileInput(
|
||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
|
|||||||
@@ -554,13 +554,16 @@ class WorkflowService:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "workflow_end":
|
case "workflow_end":
|
||||||
|
data = {
|
||||||
|
"elapsed_time": payload.get("elapsed_time"),
|
||||||
|
"message_length": len(payload.get("output", "")),
|
||||||
|
"error": payload.get("error", "")
|
||||||
|
}
|
||||||
|
if "citations" in payload and payload["citations"]:
|
||||||
|
data["citations"] = payload["citations"]
|
||||||
return {
|
return {
|
||||||
"event": "end",
|
"event": "end",
|
||||||
"data": {
|
"data": data
|
||||||
"elapsed_time": payload.get("elapsed_time"),
|
|
||||||
"message_length": len(payload.get("output", "")),
|
|
||||||
"error": payload.get("error", "")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from app.models.workspace_model import (
|
|||||||
)
|
)
|
||||||
from app.repositories import workspace_repository
|
from app.repositories import workspace_repository
|
||||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||||
|
from app.services.session_service import SessionService
|
||||||
from app.schemas.workspace_schema import (
|
from app.schemas.workspace_schema import (
|
||||||
InviteAcceptRequest,
|
InviteAcceptRequest,
|
||||||
InviteValidateResponse,
|
InviteValidateResponse,
|
||||||
@@ -58,7 +59,7 @@ def switch_workspace(
|
|||||||
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||||
|
|
||||||
|
|
||||||
def delete_workspace_member(
|
async def delete_workspace_member(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
@@ -76,10 +77,29 @@ def delete_workspace_member(
|
|||||||
BizCode.WORKSPACE_NOT_FOUND)
|
BizCode.WORKSPACE_NOT_FOUND)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
deleted_user = workspace_member.user
|
||||||
workspace_member.is_active = False
|
workspace_member.is_active = False
|
||||||
workspace_member.user.current_workspace_id = None
|
deleted_user.current_workspace_id = None
|
||||||
|
|
||||||
|
# 若被删除成员不是超级管理员且没有其他可用工作空间,则禁用该用户
|
||||||
|
if not deleted_user.is_superuser:
|
||||||
|
remaining = (
|
||||||
|
db.query(WorkspaceMember)
|
||||||
|
.filter(
|
||||||
|
WorkspaceMember.user_id == deleted_user.id,
|
||||||
|
WorkspaceMember.workspace_id != workspace_id,
|
||||||
|
WorkspaceMember.is_active.is_(True),
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
if remaining == 0:
|
||||||
|
deleted_user.is_active = False
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
|
|
||||||
|
# 使被删除成员的所有 token 立即失效
|
||||||
|
await SessionService.invalidate_all_user_tokens(str(workspace_member.user_id))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||||
|
|||||||
325
api/app/tasks.py
325
api/app/tasks.py
@@ -30,7 +30,7 @@ from app.core.rag.llm.cv_model import QWenCV
|
|||||||
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
from app.core.rag.llm.embedding_model import OpenAIEmbed
|
||||||
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
from app.core.rag.llm.sequence2txt_model import QWenSeq2txt
|
||||||
from app.core.rag.models.chunk import DocumentChunk
|
from app.core.rag.models.chunk import DocumentChunk
|
||||||
from app.core.rag.prompts.generator import question_proposal, qa_proposal
|
from app.core.rag.prompts.generator import question_proposal
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||||
ElasticSearchVectorFactory,
|
ElasticSearchVectorFactory,
|
||||||
)
|
)
|
||||||
@@ -210,14 +210,9 @@ def _build_vision_model(file_path: str, db_knowledge):
|
|||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.rag.tasks.parse_document")
|
@celery_app.task(name="app.core.rag.tasks.parse_document")
|
||||||
def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
def parse_document(file_path: str, document_id: uuid.UUID):
|
||||||
"""
|
"""
|
||||||
Document parsing, vectorization, and storage.
|
Document parsing, vectorization, and storage
|
||||||
|
|
||||||
Args:
|
|
||||||
file_key: Storage key for FileStorageService (e.g. "kb/{kb_id}/{file_id}.docx")
|
|
||||||
document_id: Document UUID
|
|
||||||
file_name: Original file name (used for extension detection in chunk())
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db_document = None
|
db_document = None
|
||||||
@@ -228,6 +223,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
|||||||
|
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
try:
|
try:
|
||||||
|
# Celery JSON 序列化会将 UUID 转为字符串,需要确保类型正确
|
||||||
if not isinstance(document_id, uuid.UUID):
|
if not isinstance(document_id, uuid.UUID):
|
||||||
document_id = uuid.UUID(str(document_id))
|
document_id = uuid.UUID(str(document_id))
|
||||||
|
|
||||||
@@ -238,11 +234,7 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
|||||||
if db_knowledge is None:
|
if db_knowledge is None:
|
||||||
raise ValueError(f"Knowledge {db_document.kb_id} not found")
|
raise ValueError(f"Knowledge {db_document.kb_id} not found")
|
||||||
|
|
||||||
# Use file_name from argument or fall back to document record
|
# 1. Document parsing & segmentation
|
||||||
if not file_name:
|
|
||||||
file_name = db_document.file_name
|
|
||||||
|
|
||||||
# 1. Download file from storage backend
|
|
||||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
|
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} Start to parse.")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
db_document.progress = 0.0
|
db_document.progress = 0.0
|
||||||
@@ -253,36 +245,45 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_document)
|
db.refresh(db_document)
|
||||||
|
|
||||||
# Read file content from storage backend (no NFS dependency)
|
|
||||||
from app.services.file_storage_service import FileStorageService
|
|
||||||
import asyncio
|
|
||||||
storage_service = FileStorageService()
|
|
||||||
|
|
||||||
async def _download():
|
|
||||||
return await storage_service.download_file(file_key)
|
|
||||||
|
|
||||||
try:
|
|
||||||
file_binary = asyncio.run(_download())
|
|
||||||
except RuntimeError:
|
|
||||||
# If there's already a running loop (e.g. in some worker configurations)
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
try:
|
|
||||||
file_binary = loop.run_until_complete(_download())
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
if not file_binary:
|
|
||||||
raise IOError(f"Downloaded empty file from storage: {file_key}")
|
|
||||||
logger.info(f"[ParseDoc] Downloaded {len(file_binary)} bytes from storage key: {file_key}")
|
|
||||||
|
|
||||||
def progress_callback(prog=None, msg=None):
|
def progress_callback(prog=None, msg=None):
|
||||||
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
|
progress_lines.append(f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.")
|
||||||
|
|
||||||
# Prepare vision_model for parsing
|
# Prepare vision_model for parsing
|
||||||
vision_model = _build_vision_model(file_name, db_knowledge)
|
vision_model = _build_vision_model(file_path, db_knowledge)
|
||||||
|
|
||||||
|
# 先将文件读入内存,避免解析过程中依赖 NFS 文件持续可访问
|
||||||
|
# python-docx 等库在 binary=None 时会用路径直接打开文件,
|
||||||
|
# 在 NFS/共享存储上可能因缓存失效导致 "Package not found"
|
||||||
|
max_wait_seconds = 30
|
||||||
|
wait_interval = 2
|
||||||
|
waited = 0
|
||||||
|
file_binary = None
|
||||||
|
while waited <= max_wait_seconds:
|
||||||
|
# os.listdir 强制 NFS 客户端刷新目录缓存
|
||||||
|
parent_dir = os.path.dirname(file_path)
|
||||||
|
try:
|
||||||
|
os.listdir(parent_dir)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
file_binary = f.read()
|
||||||
|
if not file_binary:
|
||||||
|
# NFS 上文件存在但内容为空(可能还在同步中)
|
||||||
|
raise IOError(f"File is empty (0 bytes), NFS may still be syncing: {file_path}")
|
||||||
|
break
|
||||||
|
except (FileNotFoundError, IOError) as e:
|
||||||
|
if waited >= max_wait_seconds:
|
||||||
|
raise type(e)(
|
||||||
|
f"File not accessible at '{file_path}' after waiting {max_wait_seconds}s: {e}"
|
||||||
|
)
|
||||||
|
logger.warning(f"File not ready on this node, retrying in {wait_interval}s: {file_path} ({e})")
|
||||||
|
time.sleep(wait_interval)
|
||||||
|
waited += wait_interval
|
||||||
|
|
||||||
from app.core.rag.app.naive import chunk
|
from app.core.rag.app.naive import chunk
|
||||||
logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
|
logger.info(f"[ParseDoc] file_binary size={len(file_binary)} bytes, type={type(file_binary).__name__}, bool={bool(file_binary)}")
|
||||||
res = chunk(filename=file_name,
|
res = chunk(filename=file_path,
|
||||||
binary=file_binary,
|
binary=file_binary,
|
||||||
from_page=0,
|
from_page=0,
|
||||||
to_page=DEFAULT_PARSE_TO_PAGE,
|
to_page=DEFAULT_PARSE_TO_PAGE,
|
||||||
@@ -311,7 +312,6 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
|||||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||||
# 2.2 Vectorize and import batch documents
|
# 2.2 Vectorize and import batch documents
|
||||||
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
|
auto_questions_topn = db_document.parser_config.get("auto_questions", 0)
|
||||||
qa_prompt = db_document.parser_config.get("qa_prompt", None)
|
|
||||||
chat_model = None
|
chat_model = None
|
||||||
if auto_questions_topn:
|
if auto_questions_topn:
|
||||||
chat_model = Base(
|
chat_model = Base(
|
||||||
@@ -319,123 +319,62 @@ def parse_document(file_key: str, document_id: uuid.UUID, file_name: str = ""):
|
|||||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||||
)
|
)
|
||||||
logger.info(f"[QA] LLM model: {db_knowledge.llm.api_keys[0].model_name}, base_url: {db_knowledge.llm.api_keys[0].api_base}")
|
|
||||||
if qa_prompt:
|
|
||||||
logger.info(f"[QA] Using custom prompt ({len(qa_prompt)} chars)")
|
|
||||||
|
|
||||||
# 预先构建所有 batch 的 chunks,保证 sort_id 全局有序
|
# 预先构建所有 batch 的 chunks,保证 sort_id 全局有序
|
||||||
all_batch_chunks: list[list[DocumentChunk]] = []
|
all_batch_chunks: list[list[DocumentChunk]] = []
|
||||||
|
|
||||||
if auto_questions_topn:
|
if auto_questions_topn:
|
||||||
# QA 模式(FastGPT 方案):
|
# auto_questions 开启:先并发生成所有 chunk 的问题,再按 batch 分组
|
||||||
# 1. 原 chunk 标记为 source(保留供 GraphRAG 使用,不参与检索)
|
# 构建 (global_idx, item) 列表
|
||||||
# 2. LLM 生成 QA 对,每个 QA 对独立存储为 qa chunk
|
|
||||||
indexed_items = list(enumerate(res))
|
indexed_items = list(enumerate(res))
|
||||||
|
|
||||||
def _generate_qa(idx_item: tuple[int, dict]) -> tuple[int, list]:
|
def _generate_question(idx_item: tuple[int, dict]) -> tuple[int, str]:
|
||||||
"""为单个 chunk 生成 QA 对(带缓存),返回 (global_idx, qa_pairs)"""
|
"""为单个 chunk 生成问题(带缓存),返回 (global_idx, question_text)"""
|
||||||
global_idx, item = idx_item
|
global_idx, item = idx_item
|
||||||
content = item["content_with_weight"]
|
content = item["content_with_weight"]
|
||||||
cache_params = {"topn": auto_questions_topn}
|
cached = get_llm_cache(chat_model.model_name, content, "question",
|
||||||
if qa_prompt:
|
{"topn": auto_questions_topn})
|
||||||
import hashlib
|
|
||||||
cache_params["prompt_hash"] = hashlib.md5(qa_prompt.encode()).hexdigest()[:8]
|
|
||||||
cached = get_llm_cache(chat_model.model_name, content, "qa", cache_params)
|
|
||||||
if not cached:
|
if not cached:
|
||||||
logger.info(f"[QA] Cache miss for chunk {global_idx}, calling LLM. cache_params={cache_params}")
|
cached = question_proposal(chat_model, content, auto_questions_topn)
|
||||||
try:
|
set_llm_cache(chat_model.model_name, content, cached, "question",
|
||||||
pairs = qa_proposal(chat_model, content, auto_questions_topn, custom_prompt=qa_prompt)
|
{"topn": auto_questions_topn})
|
||||||
except Exception as e:
|
return global_idx, cached
|
||||||
logger.error(f"[QA] LLM call failed: model={chat_model.model_name}, base_url={getattr(chat_model, 'base_url', 'N/A')}, error={e}")
|
|
||||||
return global_idx, []
|
|
||||||
logger.info(f"[QA] Chunk {global_idx} generated {len(pairs)} QA pairs")
|
|
||||||
# 缓存存 JSON 字符串
|
|
||||||
set_llm_cache(chat_model.model_name, content, json.dumps(pairs, ensure_ascii=False), "qa",
|
|
||||||
cache_params)
|
|
||||||
return global_idx, pairs
|
|
||||||
logger.info(f"[QA] Cache hit for chunk {global_idx}, cache_params={cache_params}, cached_type={type(cached).__name__}")
|
|
||||||
# 从缓存读取:可能是 JSON 字符串或旧格式纯文本
|
|
||||||
if isinstance(cached, str):
|
|
||||||
try:
|
|
||||||
parsed = json.loads(cached)
|
|
||||||
if isinstance(parsed, list):
|
|
||||||
logger.info(f"[QA] Chunk {global_idx} loaded {len(parsed)} QA pairs from cache")
|
|
||||||
return global_idx, parsed
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
pass
|
|
||||||
# 旧缓存格式(纯文本问题),尝试解析
|
|
||||||
from app.core.rag.prompts.generator import parse_qa_pairs
|
|
||||||
return global_idx, parse_qa_pairs(cached) if cached else []
|
|
||||||
return global_idx, cached if isinstance(cached, list) else []
|
|
||||||
|
|
||||||
# 并发调用 LLM 生成 QA 对
|
# 并发调用 LLM 生成问题
|
||||||
qa_map: dict[int, list] = {}
|
question_map: dict[int, str] = {}
|
||||||
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
|
with ThreadPoolExecutor(max_workers=AUTO_QUESTIONS_MAX_WORKERS) as q_executor:
|
||||||
futures = {q_executor.submit(_generate_qa, item): item[0]
|
futures = {q_executor.submit(_generate_question, item): item[0]
|
||||||
for item in indexed_items}
|
for item in indexed_items}
|
||||||
for future in futures:
|
for future in futures:
|
||||||
global_idx, pairs = future.result()
|
global_idx, cached = future.result()
|
||||||
qa_map[global_idx] = pairs
|
question_map[global_idx] = cached
|
||||||
|
|
||||||
progress_lines.append(
|
progress_lines.append(
|
||||||
f"{datetime.now().strftime('%H:%M:%S')} QA pairs generated for {total_chunks} chunks "
|
f"{datetime.now().strftime('%H:%M:%S')} Auto questions generated for {total_chunks} chunks "
|
||||||
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
|
f"(workers={AUTO_QUESTIONS_MAX_WORKERS}).")
|
||||||
|
|
||||||
# 组装 chunks:source chunks + qa chunks
|
# 按 batch 分组组装 DocumentChunk
|
||||||
source_chunks = []
|
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||||
qa_chunks = []
|
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, total_chunks)
|
||||||
qa_sort_id = 0
|
chunks = []
|
||||||
|
for global_idx in range(batch_start, batch_end):
|
||||||
for global_idx in range(total_chunks):
|
item = res[global_idx]
|
||||||
item = res[global_idx]
|
metadata = {
|
||||||
source_chunk_id = uuid.uuid4().hex
|
|
||||||
|
|
||||||
# source chunk:保留原文,供 GraphRAG 使用,不参与向量检索
|
|
||||||
source_meta = {
|
|
||||||
"doc_id": source_chunk_id,
|
|
||||||
"file_id": str(db_document.file_id),
|
|
||||||
"file_name": db_document.file_name,
|
|
||||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
|
||||||
"document_id": str(db_document.id),
|
|
||||||
"knowledge_id": str(db_document.kb_id),
|
|
||||||
"sort_id": global_idx,
|
|
||||||
"status": 1,
|
|
||||||
"chunk_type": "source",
|
|
||||||
}
|
|
||||||
source_chunks.append(
|
|
||||||
DocumentChunk(page_content=item["content_with_weight"], metadata=source_meta))
|
|
||||||
|
|
||||||
# qa chunks:每个 QA 对独立存储
|
|
||||||
pairs = qa_map.get(global_idx, [])
|
|
||||||
for pair in pairs:
|
|
||||||
qa_meta = {
|
|
||||||
"doc_id": uuid.uuid4().hex,
|
"doc_id": uuid.uuid4().hex,
|
||||||
"file_id": str(db_document.file_id),
|
"file_id": str(db_document.file_id),
|
||||||
"file_name": db_document.file_name,
|
"file_name": db_document.file_name,
|
||||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||||
"document_id": str(db_document.id),
|
"document_id": str(db_document.id),
|
||||||
"knowledge_id": str(db_document.kb_id),
|
"knowledge_id": str(db_document.kb_id),
|
||||||
"sort_id": qa_sort_id,
|
"sort_id": global_idx,
|
||||||
"status": 1,
|
"status": 1,
|
||||||
"chunk_type": "qa",
|
|
||||||
"question": pair["question"],
|
|
||||||
"answer": pair["answer"],
|
|
||||||
"source_chunk_id": source_chunk_id,
|
|
||||||
}
|
}
|
||||||
# page_content 存 question,用于向量索引
|
cached = question_map[global_idx]
|
||||||
qa_chunks.append(
|
chunks.append(
|
||||||
DocumentChunk(page_content=pair["question"], metadata=qa_meta))
|
DocumentChunk(
|
||||||
qa_sort_id += 1
|
page_content=f"question: {cached} answer: {item['content_with_weight']}",
|
||||||
|
metadata=metadata))
|
||||||
# 按 batch 分组(source + qa 一起)
|
all_batch_chunks.append(chunks)
|
||||||
all_chunks = source_chunks + qa_chunks
|
|
||||||
for batch_start in range(0, len(all_chunks), EMBEDDING_BATCH_SIZE):
|
|
||||||
batch_end = min(batch_start + EMBEDDING_BATCH_SIZE, len(all_chunks))
|
|
||||||
all_batch_chunks.append(all_chunks[batch_start:batch_end])
|
|
||||||
|
|
||||||
progress_lines.append(
|
|
||||||
f"{datetime.now().strftime('%H:%M:%S')} QA mode: {len(source_chunks)} source chunks + "
|
|
||||||
f"{len(qa_chunks)} QA chunks prepared.")
|
|
||||||
else:
|
else:
|
||||||
# 无 auto_questions:直接构建 chunks
|
# 无 auto_questions:直接构建 chunks
|
||||||
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
for batch_start in range(0, total_chunks, EMBEDDING_BATCH_SIZE):
|
||||||
@@ -697,136 +636,6 @@ def build_graphrag_for_document(document_id: str, knowledge_id: str):
|
|||||||
return f"build_graphrag_for_document '{document_id}' failed: {e}"
|
return f"build_graphrag_for_document '{document_id}' failed: {e}"
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.rag.tasks.import_qa_chunks", queue="qa_import")
|
|
||||||
def import_qa_chunks(kb_id: str, document_id: str, filename: str, contents: bytes):
|
|
||||||
"""
|
|
||||||
异步导入 QA 问答对(CSV/Excel)
|
|
||||||
|
|
||||||
文件格式:第一行标题(跳过),第一列问题,第二列答案
|
|
||||||
"""
|
|
||||||
import csv as csv_module
|
|
||||||
import io
|
|
||||||
|
|
||||||
db = None
|
|
||||||
try:
|
|
||||||
from app.db import get_db_context
|
|
||||||
with get_db_context() as db:
|
|
||||||
db_document = db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
|
|
||||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == uuid.UUID(kb_id)).first()
|
|
||||||
if not db_document or not db_knowledge:
|
|
||||||
logger.error(f"[ImportQA] document={document_id} or knowledge={kb_id} not found")
|
|
||||||
return {"error": "document or knowledge not found", "imported": 0}
|
|
||||||
|
|
||||||
# 1. 解析文件
|
|
||||||
qa_pairs = []
|
|
||||||
failed_rows = []
|
|
||||||
|
|
||||||
if filename.endswith(".csv"):
|
|
||||||
try:
|
|
||||||
text = contents.decode("utf-8-sig")
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
text = contents.decode("gbk", errors="ignore")
|
|
||||||
|
|
||||||
sniffer = csv_module.Sniffer()
|
|
||||||
try:
|
|
||||||
dialect = sniffer.sniff(text[:2048])
|
|
||||||
delimiter = dialect.delimiter
|
|
||||||
except csv_module.Error:
|
|
||||||
delimiter = "," if "," in text[:500] else "\t"
|
|
||||||
|
|
||||||
reader = csv_module.reader(io.StringIO(text), delimiter=delimiter)
|
|
||||||
for i, row in enumerate(reader):
|
|
||||||
if i == 0:
|
|
||||||
continue
|
|
||||||
if len(row) >= 2 and row[0].strip() and row[1].strip():
|
|
||||||
qa_pairs.append({"question": row[0].strip(), "answer": row[1].strip()})
|
|
||||||
elif len(row) >= 1 and row[0].strip():
|
|
||||||
failed_rows.append(i + 1)
|
|
||||||
|
|
||||||
elif filename.endswith(".xlsx") or filename.endswith(".xls"):
|
|
||||||
try:
|
|
||||||
import openpyxl
|
|
||||||
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
|
|
||||||
for sheet in wb.worksheets:
|
|
||||||
for i, row in enumerate(sheet.iter_rows(values_only=True)):
|
|
||||||
if i == 0:
|
|
||||||
continue
|
|
||||||
if len(row) >= 2 and row[0] and row[1]:
|
|
||||||
q = str(row[0]).strip()
|
|
||||||
a = str(row[1]).strip()
|
|
||||||
if q and a:
|
|
||||||
qa_pairs.append({"question": q, "answer": a})
|
|
||||||
elif len(row) >= 1 and row[0]:
|
|
||||||
failed_rows.append(i + 1)
|
|
||||||
wb.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ImportQA] Excel parse failed: {e}")
|
|
||||||
return {"error": f"Excel parse failed: {e}", "imported": 0}
|
|
||||||
|
|
||||||
if not qa_pairs:
|
|
||||||
logger.warning(f"[ImportQA] No valid QA pairs found in {filename}")
|
|
||||||
return {"error": "No valid QA pairs found", "imported": 0}
|
|
||||||
|
|
||||||
logger.info(f"[ImportQA] Parsed {len(qa_pairs)} QA pairs from {filename}, failed_rows={failed_rows}")
|
|
||||||
|
|
||||||
# 2. 写入 ES
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
|
|
||||||
sort_id = 0
|
|
||||||
total, items = vector_service.search_by_segment(document_id=document_id, pagesize=1, page=1, asc=False)
|
|
||||||
if items:
|
|
||||||
sort_id = items[0].metadata["sort_id"]
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
for pair in qa_pairs:
|
|
||||||
sort_id += 1
|
|
||||||
doc_id = uuid.uuid4().hex
|
|
||||||
metadata = {
|
|
||||||
"doc_id": doc_id,
|
|
||||||
"file_id": str(db_document.file_id),
|
|
||||||
"file_name": db_document.file_name,
|
|
||||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
|
||||||
"document_id": document_id,
|
|
||||||
"knowledge_id": kb_id,
|
|
||||||
"sort_id": sort_id,
|
|
||||||
"status": 1,
|
|
||||||
"chunk_type": "qa",
|
|
||||||
"question": pair["question"],
|
|
||||||
"answer": pair["answer"],
|
|
||||||
}
|
|
||||||
chunks.append(DocumentChunk(page_content=pair["question"], metadata=metadata))
|
|
||||||
|
|
||||||
batch_size = 50
|
|
||||||
for i in range(0, len(chunks), batch_size):
|
|
||||||
batch = chunks[i:i + batch_size]
|
|
||||||
vector_service.add_chunks(batch)
|
|
||||||
|
|
||||||
# 3. 更新 chunk_num 和 progress
|
|
||||||
db_document.chunk_num += len(chunks)
|
|
||||||
db_document.progress = 1.0
|
|
||||||
db_document.progress_msg = f"QA 导入完成: {len(chunks)} 条"
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
result = {"imported": len(chunks), "failed_rows": failed_rows}
|
|
||||||
logger.info(f"[ImportQA] Done: imported={len(chunks)}, failed={len(failed_rows)}")
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[ImportQA] Failed: {e}", exc_info=True)
|
|
||||||
# 尝试更新文档状态为失败
|
|
||||||
try:
|
|
||||||
from app.db import get_db_context
|
|
||||||
with get_db_context() as err_db:
|
|
||||||
doc = err_db.query(Document).filter(Document.id == uuid.UUID(document_id)).first()
|
|
||||||
if doc:
|
|
||||||
doc.progress = -1.0
|
|
||||||
doc.progress_msg = f"QA 导入失败: {str(e)[:200]}"
|
|
||||||
err_db.commit()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return {"error": str(e), "imported": 0}
|
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
||||||
def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
def sync_knowledge_for_kb(kb_id: uuid.UUID):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -8,12 +8,11 @@ import { type FC, useRef, useEffect, useState } from 'react'
|
|||||||
import clsx from 'clsx'
|
import clsx from 'clsx'
|
||||||
import Markdown from '@/components/Markdown'
|
import Markdown from '@/components/Markdown'
|
||||||
import type { ChatContentProps } from './types'
|
import type { ChatContentProps } from './types'
|
||||||
import { Spin, Image, Flex, Button } from 'antd'
|
import { Spin, Flex, Button } from 'antd'
|
||||||
import { SoundOutlined } from '@ant-design/icons'
|
import { SoundOutlined } from '@ant-design/icons'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
import AudioPlayer from './AudioPlayer'
|
import MessageFiles from './MessageFiles'
|
||||||
import VideoPlayer from './VideoPlayer'
|
|
||||||
|
|
||||||
const getFileUrl = (file: any) => {
|
const getFileUrl = (file: any) => {
|
||||||
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
||||||
@@ -149,72 +148,7 @@ const ChatContent: FC<ChatContentProps> = ({
|
|||||||
{labelFormat(item)}
|
{labelFormat(item)}
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
{item?.meta_data?.files && item.meta_data?.files.length > 0 && <Flex gap={8} vertical align="end" className="rb:mb-2!">
|
<MessageFiles files={item.meta_data?.files ?? []} contentClassNames={contentClassNames} onDownload={handleDownload} />
|
||||||
{item.meta_data?.files?.map((file) => {
|
|
||||||
if (file.type.includes('image')) {
|
|
||||||
return (
|
|
||||||
<div key={file.url || file.uid} className={`rb:inline-block rb:group rb:relative rb:rounded-lg ${contentClassNames}`}>
|
|
||||||
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if (file.type.includes('video')) {
|
|
||||||
return (
|
|
||||||
<div key={file.url || file.uid} className="rb:w-50">
|
|
||||||
{/* <video src={getFileUrl(file)} controls className="rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" /> */}
|
|
||||||
<VideoPlayer key={file.url || file.uid} src={getFileUrl(file)} />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
if (file.type.includes('audio')) {
|
|
||||||
return (
|
|
||||||
<div key={file.url || file.uid} className="rb:w-50">
|
|
||||||
<AudioPlayer key={file.url || file.uid} src={getFileUrl(file)} />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const documentType = (file.file_type || file.type)?.split('/')
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
key={file.url || file.uid}
|
|
||||||
align="center"
|
|
||||||
gap={10}
|
|
||||||
className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
|
|
||||||
onClick={() => handleDownload(file)}
|
|
||||||
>
|
|
||||||
<div
|
|
||||||
className={clsx(
|
|
||||||
"rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
|
|
||||||
file.type?.includes('pdf')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/pdf.svg')]"
|
|
||||||
: (file.type?.includes('excel') || file.type?.includes('spreadsheetml.sheet')) || file.type?.includes('xls') || file.type?.includes('xlsx')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/excel.svg')]"
|
|
||||||
: file.type?.includes('csv')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/csv.svg')]"
|
|
||||||
: file.type?.includes('html')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/html.svg')]"
|
|
||||||
: file.type?.includes('json')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/json.svg')]"
|
|
||||||
: file.type?.includes('ppt')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/ppt.svg')]"
|
|
||||||
: file.type?.includes('markdown')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/md.svg')]"
|
|
||||||
: file.type?.includes('text')
|
|
||||||
? "rb:bg-[url('@/assets/images/file/txt.svg')]"
|
|
||||||
: (file.type?.includes('doc') || file.type?.includes('docx') || file.type?.includes('word') || file.type?.includes('wordprocessingml.document'))
|
|
||||||
? "rb:bg-[url('@/assets/images/file/word.svg')]"
|
|
||||||
: "rb:bg-[url('@/assets/images/file/txt.svg')]"
|
|
||||||
)}
|
|
||||||
></div>
|
|
||||||
<div className="rb:flex-1 rb:w-32.5">
|
|
||||||
<div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
|
|
||||||
<div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{documentType?.[documentType.length - 1]} · {file.size}</div>
|
|
||||||
</div>
|
|
||||||
</Flex>
|
|
||||||
)
|
|
||||||
})}
|
|
||||||
</Flex>}
|
|
||||||
{/* Message bubble */}
|
{/* Message bubble */}
|
||||||
<div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', {
|
<div className={clsx('rb:text-left rb:leading-5 rb:inline-block rb:wrap-break-word rb:relative', item.role === 'user' ? contentClassNames : '', {
|
||||||
// Error message style (content is null and not assistant message)
|
// Error message style (content is null and not assistant message)
|
||||||
|
|||||||
87
web/src/components/Chat/MessageFiles.tsx
Normal file
87
web/src/components/Chat/MessageFiles.tsx
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import { Image, Flex } from 'antd'
|
||||||
|
import clsx from 'clsx'
|
||||||
|
import AudioPlayer from './AudioPlayer'
|
||||||
|
import VideoPlayer from './VideoPlayer'
|
||||||
|
|
||||||
|
const getFileUrl = (file: any) =>
|
||||||
|
file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
||||||
|
|
||||||
|
// Icon shown when no keyword matches.
const DEFAULT_DOC_ICON = "rb:bg-[url('@/assets/images/file/txt.svg')]"

// Ordered [keywords, icon class] table; the first row with a matching keyword wins.
// Order matters: specific formats must come before the generic `text` and `doc`
// rows (e.g. `text/csv` has to hit `csv`, and every OOXML MIME type contains
// "officedocument", which would otherwise match `doc`).
const DOC_ICONS: [string[], string][] = [
  [['pdf'], "rb:bg-[url('@/assets/images/file/pdf.svg')]"],
  [['excel', 'spreadsheetml.sheet', 'xls', 'xlsx'], "rb:bg-[url('@/assets/images/file/excel.svg')]"],
  [['csv'], "rb:bg-[url('@/assets/images/file/csv.svg')]"],
  [['html'], "rb:bg-[url('@/assets/images/file/html.svg')]"],
  [['json'], "rb:bg-[url('@/assets/images/file/json.svg')]"],
  // `powerpoint` / `presentationml` cover application/vnd.ms-powerpoint and the
  // OOXML presentation type, neither of which contains the substring "ppt".
  [['ppt', 'powerpoint', 'presentationml'], "rb:bg-[url('@/assets/images/file/ppt.svg')]"],
  [['markdown'], "rb:bg-[url('@/assets/images/file/md.svg')]"],
  [['text'], "rb:bg-[url('@/assets/images/file/txt.svg')]"],
  [['doc', 'docx', 'word', 'wordprocessingml.document'], "rb:bg-[url('@/assets/images/file/word.svg')]"],
]

/**
 * Pick the icon class for a document.
 *
 * @param parts Segments of the file type split on "/", e.g. ['application', 'pdf']
 *              or a bare extension such as ['docx'].
 * @returns The icon class of the first DOC_ICONS row whose keyword occurs in any
 *          segment, or the txt icon when nothing matches.
 *
 * Keywords are matched as case-insensitive substrings, not by equality: real MIME
 * subtypes look like `vnd.openxmlformats-officedocument.wordprocessingml.document`
 * or `vnd.ms-excel`, which never equal a bare keyword.
 */
const getDocIcon = (parts: string[]) => {
  const segments = parts.map(part => String(part).toLowerCase())
  const match = DOC_ICONS.find(([keys]) =>
    keys.some(key => segments.some(segment => segment.includes(key)))
  )
  return match ? match[1] : DEFAULT_DOC_ICON
}
|
||||||
|
|
||||||
|
interface MessageFilesProps {
  /** Attachments of a single chat message. */
  files: any[]
  /** Extra class names merged (via clsx) onto the wrapper of image attachments. */
  contentClassNames?: string | Record<string, boolean>
  /** Called with the file object when a document card is clicked. */
  onDownload: (file: any) => void
}

/**
 * Render the attachments of a chat message.
 *
 * Each file is dispatched on its MIME type (`file.type`):
 * - contains "image": antd `Image`
 * - contains "video": `VideoPlayer`
 * - contains "audio": `AudioPlayer`
 * - anything else:    a document card (icon, name, type and size) that calls
 *                     `onDownload(file)` on click
 *
 * Renders nothing when `files` is empty or missing.
 */
const MessageFiles = ({ files, contentClassNames, onDownload }: MessageFilesProps) => {
  if (!files?.length) return null

  return (
    <Flex gap={8} vertical align="end" className="rb:mb-2!">
      {files.map((file, index) => {
        // Fall back to the list position so React never receives an undefined key.
        const key = file.url || file.uid || index
        // `type` may be absent (the document branch below already treats it as
        // optional), so normalise before calling string methods on it.
        const mimeType: string = typeof file.type === 'string' ? file.type : ''

        if (mimeType.includes('image')) {
          return (
            <div key={key} className={clsx('rb:inline-block rb:group rb:relative rb:rounded-lg', contentClassNames)}>
              <Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
            </div>
          )
        }

        if (mimeType.includes('video')) {
          return (
            <div key={key} className="rb:w-50">
              <VideoPlayer src={getFileUrl(file)} />
            </div>
          )
        }

        if (mimeType.includes('audio')) {
          return (
            <div key={key} className="rb:w-50">
              <AudioPlayer src={getFileUrl(file)} />
            </div>
          )
        }

        // Prefer the explicit `file_type` when present, otherwise the MIME type;
        // the last segment is what gets displayed next to the size.
        const documentType = (file.file_type || file.type)?.split('/') ?? []

        return (
          <Flex
            key={key}
            align="center"
            gap={10}
            className="rb:text-left rb:w-45 rb:text-[12px] rb:group rb:relative rb:rounded-lg rb-border rb:py-2! rb:px-2.5! rb:border rb:border-[#F6F6F6]"
            onClick={() => onDownload(file)}
          >
            <div
              className={clsx(
                "rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf_disabled.svg')]",
                getDocIcon(documentType)
              )}
            />
            <div className="rb:flex-1 rb:w-32.5">
              <div className="rb:leading-4 rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">{file.name}</div>
              <div className="rb:leading-3.5 rb:mt-0.5 rb:text-[#5B6167] rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap">
                {documentType[documentType.length - 1]} · {file.size}
              </div>
            </div>
          </Flex>
        )
      })}
    </Flex>
  )
}
|
||||||
|
|
||||||
|
export default MessageFiles
|
||||||
@@ -399,7 +399,7 @@ const Menu: FC<{
|
|||||||
className="rb:overflow-y-auto rb:flex-1!"
|
className="rb:overflow-y-auto rb:flex-1!"
|
||||||
/>
|
/>
|
||||||
{/* Return to space button for superusers */}
|
{/* Return to space button for superusers */}
|
||||||
{user?.is_superuser && source === 'space' &&
|
{source === 'space' &&
|
||||||
<Flex gap={4} vertical className="rb:my-3! rb:mx-3!">
|
<Flex gap={4} vertical className="rb:my-3! rb:mx-3!">
|
||||||
<Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" />
|
<Divider className="rb:mb-2.5! rb:mt-0! rb:border-[#DFE4ED]! rb:mx-2! rb:min-w-[calc(100%-20px)]! rb:w-[calc(100%-20px)]!" />
|
||||||
<Flex
|
<Flex
|
||||||
@@ -412,16 +412,18 @@ const Menu: FC<{
|
|||||||
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div>
|
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/switch.svg')]"></div>
|
||||||
{collapsed ? null : t('common.switchSpace')}
|
{collapsed ? null : t('common.switchSpace')}
|
||||||
</Flex>
|
</Flex>
|
||||||
<Flex
|
{user?.is_superuser &&
|
||||||
gap={8}
|
<Flex
|
||||||
align="center"
|
gap={8}
|
||||||
justify="start"
|
align="center"
|
||||||
onClick={goToSpace}
|
justify="start"
|
||||||
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
|
onClick={goToSpace}
|
||||||
>
|
className="rb:p-2.5! rb:text-[13px] rb:hover:bg-[rgba(223,228,237,0.5)] rb:rounded-lg rb:leading-3.5 rb:font-regular rb:text-center rb:cursor-pointer"
|
||||||
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
|
>
|
||||||
{collapsed ? null : t('common.returnToSpace')}
|
<div className="rb:cursor-pointer rb:size-4 rb:bg-cover rb:bg-[url('@/assets/images/menuNew/return.svg')]"></div>
|
||||||
</Flex>
|
{collapsed ? null : t('common.returnToSpace')}
|
||||||
|
</Flex>
|
||||||
|
}
|
||||||
</Flex>
|
</Flex>
|
||||||
}
|
}
|
||||||
{source === 'manage' && subscription && !collapsed &&
|
{source === 'manage' && subscription && !collapsed &&
|
||||||
|
|||||||
@@ -1538,6 +1538,7 @@ export const en = {
|
|||||||
json_output: 'Support JSON formatted output',
|
json_output: 'Support JSON formatted output',
|
||||||
thinking_budget_tokens: 'thinking budget tokens',
|
thinking_budget_tokens: 'thinking budget tokens',
|
||||||
thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})",
|
thinking_budget_tokens_max_error: "Cannot exceed the max tokens limit ({{max}})",
|
||||||
|
thinking_budget_tokens_min_error: "Cannot be less than {{min}}",
|
||||||
logSearchPlaceholder: 'Search log content',
|
logSearchPlaceholder: 'Search log content',
|
||||||
},
|
},
|
||||||
userMemory: {
|
userMemory: {
|
||||||
|
|||||||
@@ -868,6 +868,7 @@ export const zh = {
|
|||||||
json_output: '支持JSON格式化输出',
|
json_output: '支持JSON格式化输出',
|
||||||
thinking_budget_tokens: '深度思考预算Token数',
|
thinking_budget_tokens: '深度思考预算Token数',
|
||||||
thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})",
|
thinking_budget_tokens_max_error: "不能超过 最大令牌数 ({{max}})",
|
||||||
|
thinking_budget_tokens_min_error: "不能小于 {{min}}",
|
||||||
logSearchPlaceholder: '搜索日志内容',
|
logSearchPlaceholder: '搜索日志内容',
|
||||||
},
|
},
|
||||||
table: {
|
table: {
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ const configFields = [
|
|||||||
{ key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 },
|
{ key: 'n', max: 10, min: 1, step: 1, defaultValue: 1 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
const minThinkingBudgetTokens = 128;
|
||||||
|
const defaultThinkingBudgetTokens = 1000;
|
||||||
const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({
|
const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(({
|
||||||
refresh,
|
refresh,
|
||||||
data,
|
data,
|
||||||
@@ -108,7 +110,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
const newValues: ModelConfig = {
|
const newValues: ModelConfig = {
|
||||||
capability: (option as Model).capability,
|
capability: (option as Model).capability,
|
||||||
deep_thinking: false,
|
deep_thinking: false,
|
||||||
thinking_budget_tokens: undefined,
|
thinking_budget_tokens: defaultThinkingBudgetTokens,
|
||||||
json_output: false,
|
json_output: false,
|
||||||
}
|
}
|
||||||
if (source === 'chat') {
|
if (source === 'chat') {
|
||||||
@@ -128,6 +130,12 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
form.setFieldsValue({ ...rest })
|
form.setFieldsValue({ ...rest })
|
||||||
}, [data?.default_model_config_id])
|
}, [data?.default_model_config_id])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (values?.deep_thinking && !values?.thinking_budget_tokens) {
|
||||||
|
form.setFieldValue('thinking_budget_tokens', defaultThinkingBudgetTokens)
|
||||||
|
}
|
||||||
|
}, [values?.deep_thinking])
|
||||||
|
|
||||||
const handleReset = () => {
|
const handleReset = () => {
|
||||||
if (!id) return
|
if (!id) return
|
||||||
resetAppModelConfig(id).then((res) => {
|
resetAppModelConfig(id).then((res) => {
|
||||||
@@ -178,15 +186,20 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
name="thinking_budget_tokens"
|
name="thinking_budget_tokens"
|
||||||
label={t('application.thinking_budget_tokens')}
|
label={t('application.thinking_budget_tokens')}
|
||||||
hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))}
|
hidden={!['model', 'chat'].includes(source) || !(values?.deep_thinking || values?.capability?.includes('thinking'))}
|
||||||
extra={<>{t('application.range')}: [{0}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
|
extra={<>{t('application.range')}: [{minThinkingBudgetTokens}, {t(`application.max_tokens`)}: {values?.max_tokens}]</>}
|
||||||
rules={[
|
rules={[
|
||||||
{ required: values?.deep_thinking, message: t('common.pleaseEnter') },
|
{ required: values?.deep_thinking, message: t('common.pleaseEnter') },
|
||||||
{
|
{
|
||||||
validator: (_, value) => {
|
validator: (_, value) => {
|
||||||
const maxTokens = values?.max_tokens
|
const maxTokens = values?.max_tokens
|
||||||
const deep_thinking = values?.deep_thinking;
|
const deep_thinking = values?.deep_thinking;
|
||||||
if (deep_thinking && value !== undefined && maxTokens !== undefined && value > maxTokens) {
|
if (deep_thinking && value !== undefined) {
|
||||||
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
|
if (value < minThinkingBudgetTokens) {
|
||||||
|
return Promise.reject(t('application.thinking_budget_tokens_min_error', { min: minThinkingBudgetTokens }))
|
||||||
|
}
|
||||||
|
if (maxTokens !== undefined && value > maxTokens) {
|
||||||
|
return Promise.reject(t('application.thinking_budget_tokens_max_error', { max: maxTokens }))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return Promise.resolve()
|
return Promise.resolve()
|
||||||
}
|
}
|
||||||
@@ -195,7 +208,7 @@ const ModelConfigModal = forwardRef<ModelConfigModalRef, ModelConfigModalProps>(
|
|||||||
>
|
>
|
||||||
<RbSlider
|
<RbSlider
|
||||||
step={1}
|
step={1}
|
||||||
min={0}
|
min={minThinkingBudgetTokens}
|
||||||
max={32000}
|
max={32000}
|
||||||
isInput={true}
|
isInput={true}
|
||||||
disabled={!values?.deep_thinking}
|
disabled={!values?.deep_thinking}
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ const CustomToolModal = forwardRef<CustomToolModalRef, CustomToolModalProps>(({
|
|||||||
});
|
});
|
||||||
};
|
};
|
||||||
const formatSchema = (value: string) => {
|
const formatSchema = (value: string) => {
|
||||||
|
if (!value || value.trim() === '') return
|
||||||
setParseSchemaData({} as ParseSchemaData)
|
setParseSchemaData({} as ParseSchemaData)
|
||||||
parseSchema({ schema_content: value })
|
parseSchema({ schema_content: value })
|
||||||
.then(res => {
|
.then(res => {
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ const CanvasToolbar: FC<CanvasToolbarProps> = ({
|
|||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
labelRender={(props) => {
|
labelRender={(props) => {
|
||||||
console.log('props', props)
|
|
||||||
return `${props.value}%`
|
return `${props.value}%`
|
||||||
}}
|
}}
|
||||||
className="rb:w-20 rb:h-4!"
|
className="rb:w-20 rb:h-4!"
|
||||||
|
|||||||
@@ -66,8 +66,6 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
const [fileList, setFileList] = useState<any[]>([])
|
const [fileList, setFileList] = useState<any[]>([])
|
||||||
const [message, setMessage] = useState<string | undefined>(undefined)
|
const [message, setMessage] = useState<string | undefined>(undefined)
|
||||||
|
|
||||||
console.log('abortRef', abortRef, chatList)
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Opens the chat drawer and loads workflow variables from the start node
|
* Opens the chat drawer and loads workflow variables from the start node
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
|||||||
|
|
||||||
// Handle node selection from popover and create new node replacing the add-node placeholder
|
// Handle node selection from popover and create new node replacing the add-node placeholder
|
||||||
const handleNodeSelect = (selectedNodeType: any) => {
|
const handleNodeSelect = (selectedNodeType: any) => {
|
||||||
|
graph.startBatch('add-node');
|
||||||
const parentBBox = node.getBBox();
|
const parentBBox = node.getBBox();
|
||||||
const cycleId = data.cycle;
|
const cycleId = data.cycle;
|
||||||
const horizontalSpacing = 0;
|
const horizontalSpacing = 0;
|
||||||
@@ -43,7 +44,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
|||||||
if (cycleId) {
|
if (cycleId) {
|
||||||
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
parentNode.addChild(newNode);
|
parentNode.addChild(newNode, { silent: true });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,55 +77,40 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
setTimeout(() => {
|
|
||||||
addedEdges.forEach(e => {
|
|
||||||
const src = graph.getCellById(e.getSourceCellId());
|
|
||||||
const tgt = graph.getCellById(e.getTargetCellId());
|
|
||||||
if (src?.isNode()) src.toFront();
|
|
||||||
if (tgt?.isNode()) tgt.toFront();
|
|
||||||
});
|
|
||||||
}, 50);
|
|
||||||
|
|
||||||
// Automatically adjust loop node size
|
// Automatically adjust loop node size
|
||||||
const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
const loopNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
||||||
if (loopNode) {
|
if (loopNode) {
|
||||||
const adjustLoopSize = () => {
|
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
|
||||||
if (childNodes.length > 0) {
|
|
||||||
const bounds = childNodes.reduce((acc, child) => {
|
|
||||||
const bbox = child.getBBox();
|
|
||||||
return {
|
|
||||||
minX: Math.min(acc.minX, bbox.x),
|
|
||||||
minY: Math.min(acc.minY, bbox.y),
|
|
||||||
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
|
|
||||||
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
|
|
||||||
};
|
|
||||||
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
|
||||||
|
|
||||||
const padding = 50;
|
|
||||||
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
|
||||||
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
|
||||||
|
|
||||||
loopNode.prop('size', { width: newWidth, height: newHeight });
|
|
||||||
|
|
||||||
// Update right port x position
|
|
||||||
const ports = loopNode.getPorts();
|
|
||||||
ports.forEach(port => {
|
|
||||||
if (port.group === 'right' && port.args) {
|
|
||||||
loopNode.portProp(port.id!, 'args/x', newWidth);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
adjustLoopSize();
|
|
||||||
|
|
||||||
// Listen to child node movement events
|
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
||||||
childNodes.forEach((childNode: any) => {
|
if (childNodes.length > 0) {
|
||||||
childNode.on('change:position', adjustLoopSize);
|
const bounds = childNodes.reduce((acc, child) => {
|
||||||
});
|
const bbox = child.getBBox();
|
||||||
|
return {
|
||||||
|
minX: Math.min(acc.minX, bbox.x),
|
||||||
|
minY: Math.min(acc.minY, bbox.y),
|
||||||
|
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
|
||||||
|
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
|
||||||
|
};
|
||||||
|
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
||||||
|
const padding = 50;
|
||||||
|
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
||||||
|
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
||||||
|
loopNode.prop('size', { width: newWidth, height: newHeight });
|
||||||
|
loopNode.getPorts().forEach(port => {
|
||||||
|
if (port.group === 'right' && port.args) {
|
||||||
|
loopNode.portProp(port.id!, 'args/x', newWidth);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addedEdges.forEach(e => {
|
||||||
|
const src = graph.getCellById(e.getSourceCellId());
|
||||||
|
const tgt = graph.getCellById(e.getTargetCellId());
|
||||||
|
if (src?.isNode()) src.toFront();
|
||||||
|
if (tgt?.isNode()) tgt.toFront();
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.stopBatch('add-node');
|
||||||
setOpen(false);
|
setOpen(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => {
|
|||||||
{data.type === 'if-else' &&
|
{data.type === 'if-else' &&
|
||||||
<Flex vertical gap={4} className="rb:mt-3!">
|
<Flex vertical gap={4} className="rb:mt-3!">
|
||||||
{data.config?.cases?.defaultValue.map((item: any, index: number) => (
|
{data.config?.cases?.defaultValue.map((item: any, index: number) => (
|
||||||
<div key={index} className={item.expressions.length > 0 ? '' : 'rb:mb-1'}>
|
<div key={index}>
|
||||||
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
|
<Flex justify={item.expressions.length > 0 ? "space-between" : 'end'} className="rb:mb-1! rb:leading-4">
|
||||||
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
|
{item.expressions.length > 0 && <span className="rb:text-[#5B6167] rb:text-[10px] rb:pl-1">CASE{index + 1}</span>}
|
||||||
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>
|
<span className="rb:text-[#212332] rb:font-medium rb:text-[12px]">{index === 0 ? 'IF' : `ELIF`}</span>
|
||||||
|
|||||||
@@ -1,134 +1,15 @@
|
|||||||
import { useEffect } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next'
|
|
||||||
import clsx from 'clsx';
|
import clsx from 'clsx';
|
||||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||||
import { Flex } from 'antd';
|
import { Flex } from 'antd';
|
||||||
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons';
|
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons';
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
import { graphNodeLibrary, edgeAttrs } from '../../constant';
|
|
||||||
import NodeTools from './NodeTools'
|
import NodeTools from './NodeTools'
|
||||||
|
|
||||||
const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => {
|
const LoopNode: ReactShapeConfig['component'] = ({ node }) => {
|
||||||
const data = node.getData() || {};
|
const data = node.getData() || {};
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
// 使用setTimeout确保在所有节点都添加完成后再创建连线
|
|
||||||
const timer = setTimeout(() => {
|
|
||||||
initNodes()
|
|
||||||
checkAndAddAddNode()
|
|
||||||
}, 50)
|
|
||||||
|
|
||||||
return () => clearTimeout(timer)
|
|
||||||
}, [graph])
|
|
||||||
|
|
||||||
const checkAndAddAddNode = () => {
|
|
||||||
if (!graph) return;
|
|
||||||
|
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === data.id);
|
|
||||||
const cycleStartNodes = childNodes.filter((n: any) => n.getData()?.type === 'cycle-start');
|
|
||||||
|
|
||||||
// 如果只有一个cycle-start节点且没有其他类型的子节点,则添加add-node
|
|
||||||
if (cycleStartNodes.length === 1 && childNodes.length === 1) {
|
|
||||||
const cycleStartNode = cycleStartNodes[0];
|
|
||||||
const cycleStartBBox = cycleStartNode.getBBox();
|
|
||||||
|
|
||||||
const addNode = graph.addNode({
|
|
||||||
...graphNodeLibrary.addStart,
|
|
||||||
x: cycleStartBBox.x + 84,
|
|
||||||
y: cycleStartBBox.y + 4,
|
|
||||||
data: {
|
|
||||||
type: 'add-node',
|
|
||||||
label: t('workflow.addNode'),
|
|
||||||
icon: '+',
|
|
||||||
parentId: node.id,
|
|
||||||
cycle: data.id,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
node.addChild(addNode);
|
|
||||||
|
|
||||||
// 连接cycle-start和add-node
|
|
||||||
const sourcePorts = cycleStartNode.getPorts();
|
|
||||||
const targetPorts = addNode.getPorts();
|
|
||||||
const sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
|
|
||||||
const targetPort = targetPorts.find((port: any) => port.group === 'left')?.id || 'left';
|
|
||||||
|
|
||||||
// 然后创建连线
|
|
||||||
graph.addEdge({
|
|
||||||
source: { cell: cycleStartNode.id, port: sourcePort },
|
|
||||||
target: { cell: addNode.id, port: targetPort },
|
|
||||||
...edgeAttrs,
|
|
||||||
});
|
|
||||||
|
|
||||||
cycleStartNode.toFront()
|
|
||||||
addNode.toFront()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const initNodes = () => {
|
|
||||||
// 检查是否存在cycle为当前节点ID的子节点,若存在则不调用initNodes,避免重复创建
|
|
||||||
const existingCycleNodes = graph.getNodes().filter((n: any) =>
|
|
||||||
n.getData()?.cycle === data.id
|
|
||||||
);
|
|
||||||
if (existingCycleNodes.length > 0) return;
|
|
||||||
// 添加默认子节点
|
|
||||||
const parentBBox = node.getBBox();
|
|
||||||
const centerX = parentBBox.x + 24;
|
|
||||||
const centerY = parentBBox.y + 70;
|
|
||||||
|
|
||||||
const cycleStartNodeId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
|
||||||
const cycleStartNode = graph.addNode({
|
|
||||||
...graphNodeLibrary.cycleStart,
|
|
||||||
x: centerX,
|
|
||||||
y: centerY,
|
|
||||||
id: cycleStartNodeId,
|
|
||||||
data: {
|
|
||||||
id: cycleStartNodeId,
|
|
||||||
type: 'cycle-start',
|
|
||||||
parentId: node.id,
|
|
||||||
isDefault: true, // 标记为默认节点,不可删除
|
|
||||||
cycle: data.id,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
const addNode = graph.addNode({
|
|
||||||
...graphNodeLibrary.addStart,
|
|
||||||
x: centerX + 84,
|
|
||||||
y: centerY + 4,
|
|
||||||
data: {
|
|
||||||
type: 'add-node',
|
|
||||||
label: t('workflow.addNode'),
|
|
||||||
icon: '+',
|
|
||||||
parentId: node.id,
|
|
||||||
cycle: data.id,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
node.addChild(cycleStartNode)
|
|
||||||
node.addChild(addNode)
|
|
||||||
const sourcePorts = cycleStartNode.getPorts()
|
|
||||||
const targetPorts = addNode.getPorts()
|
|
||||||
let sourcePort = sourcePorts.find((port: any) => port.group === 'right')?.id || 'right';
|
|
||||||
|
|
||||||
const edgeConfig = {
|
|
||||||
source: {
|
|
||||||
cell: cycleStartNode.id,
|
|
||||||
port: sourcePort
|
|
||||||
},
|
|
||||||
target: {
|
|
||||||
cell: addNode.id,
|
|
||||||
port: targetPorts.find((port: any) => port.group === 'left')?.id || 'left'
|
|
||||||
},
|
|
||||||
...edgeAttrs
|
|
||||||
}
|
|
||||||
graph.addEdge(edgeConfig)
|
|
||||||
|
|
||||||
setTimeout(() => {
|
|
||||||
|
|
||||||
cycleStartNode.toFront()
|
|
||||||
addNode.toFront()
|
|
||||||
}, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
<div className={clsx('rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:p-3 rb:border rb:rounded-2xl rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)]', {
|
||||||
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
|
'rb:border-[#171719]!': data.isSelected && !data.executionStatus,
|
||||||
|
|||||||
@@ -43,70 +43,52 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
};
|
};
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
// Handle node selection from popover menu and create new node with edge connection
|
|
||||||
const handleNodeSelect = (selectedNodeType: any) => {
|
const handleNodeSelect = (selectedNodeType: any) => {
|
||||||
if (!sourceNode || !graph) return;
|
if (!sourceNode || !graph) return;
|
||||||
|
|
||||||
const sourceNodeData = sourceNode.getData();
|
const sourceNodeData = sourceNode.getData();
|
||||||
const sourceNodeType = sourceNodeData?.type;
|
const sourceNodeType = sourceNodeData?.type;
|
||||||
|
const isCycleSubNode = !!sourceNodeData.cycle;
|
||||||
// If it's a cycle-start node, handle the add-node placeholder
|
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
|
||||||
|
const newNodeType = selectedNodeType.type;
|
||||||
|
|
||||||
|
// Save add-node placeholder position before disabling history
|
||||||
let addNodePosition = null;
|
let addNodePosition = null;
|
||||||
const isCycleSubNode = sourceNodeData.cycle
|
|
||||||
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
|
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
|
||||||
const cycleId = sourceNodeData.cycle;
|
const cycleId = sourceNodeData.cycle;
|
||||||
const addNodes = graph.getNodes().filter((n: any) =>
|
const addNodes = graph.getNodes().filter((n: any) =>
|
||||||
n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId
|
n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId
|
||||||
);
|
);
|
||||||
|
if (addNodes.length > 0) addNodePosition = addNodes[0].getBBox();
|
||||||
if (addNodes.length > 0) {
|
|
||||||
const addNode = addNodes[0];
|
|
||||||
addNodePosition = addNode.getBBox();
|
|
||||||
addNode.remove();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate new node position to avoid overlapping
|
// Calculate position
|
||||||
const sourceBBox = sourceNode.getBBox();
|
const sourceBBox = sourceNode.getBBox();
|
||||||
const nodeWidth = graphNodeLibrary[selectedNodeType.type]?.width || 120;
|
const nw = graphNodeLibrary[newNodeType]?.width || 120;
|
||||||
const nodeHeight = graphNodeLibrary[selectedNodeType.type]?.height || 88;
|
const nh = graphNodeLibrary[newNodeType]?.height || 88;
|
||||||
const horizontalSpacing = isCycleSubNode ? 48 : 80;
|
const hSpacing = isCycleSubNode ? 48 : 80;
|
||||||
const verticalSpacing = 10;
|
const vSpacing = 10;
|
||||||
|
|
||||||
// Get source port group information
|
|
||||||
const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort);
|
const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort);
|
||||||
const sourcePortGroup = sourcePortInfo?.group || sourcePort;
|
const sourcePortGroup = sourcePortInfo?.group || sourcePort;
|
||||||
|
|
||||||
// Calculate new node position
|
let newX: number, newY: number;
|
||||||
let newX, newY;
|
|
||||||
if (edgeInsertion) {
|
if (edgeInsertion) {
|
||||||
// Edge insertion: place new node on the same row as target, between source and target
|
|
||||||
const targetBBox = edgeInsertion.targetCell.getBBox();
|
const targetBBox = edgeInsertion.targetCell.getBBox();
|
||||||
const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width);
|
const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width);
|
||||||
const requiredSpace = nodeWidth + horizontalSpacing * 4;
|
const requiredSpace = nw + hSpacing * 4;
|
||||||
|
newX = sourceBBox.x + sourceBBox.width + hSpacing;
|
||||||
// New node x: right after source + spacing
|
newY = targetBBox.y + (targetBBox.height - nh) / 2;
|
||||||
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
|
|
||||||
// Same row as target node
|
|
||||||
newY = targetBBox.y + (targetBBox.height - nodeHeight) / 2;
|
|
||||||
|
|
||||||
// If not enough space, shift target and all downstream nodes to the right
|
|
||||||
if (gap < requiredSpace) {
|
if (gap < requiredSpace) {
|
||||||
const shiftX = requiredSpace - gap;
|
const shiftX = requiredSpace - gap;
|
||||||
const visited = new Set<string>();
|
const visited = new Set<string>();
|
||||||
const shiftDownstream = (cell: any) => {
|
const shiftDownstream = (cell: any) => {
|
||||||
const cellId = cell.id;
|
if (visited.has(cell.id)) return;
|
||||||
if (visited.has(cellId)) return;
|
visited.add(cell.id);
|
||||||
visited.add(cellId);
|
|
||||||
const pos = cell.getPosition();
|
const pos = cell.getPosition();
|
||||||
cell.setPosition(pos.x + shiftX, pos.y);
|
cell.setPosition(pos.x + shiftX, pos.y);
|
||||||
// Recursively shift nodes connected from right ports
|
|
||||||
graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => {
|
graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => {
|
||||||
const tId = e.getTargetCellId();
|
const tCell = graph.getCellById(e.getTargetCellId());
|
||||||
if (tId && !visited.has(tId)) {
|
if (tCell?.isNode()) shiftDownstream(tCell);
|
||||||
const tCell = graph.getCellById(tId);
|
|
||||||
if (tCell?.isNode()) shiftDownstream(tCell);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
shiftDownstream(edgeInsertion.targetCell);
|
shiftDownstream(edgeInsertion.targetCell);
|
||||||
@@ -114,208 +96,170 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
} else if (addNodePosition) {
|
} else if (addNodePosition) {
|
||||||
newX = addNodePosition.x;
|
newX = addNodePosition.x;
|
||||||
newY = addNodePosition.y;
|
newY = addNodePosition.y;
|
||||||
|
} else if (sourcePortGroup === 'left') {
|
||||||
|
newX = sourceBBox.x - nw * 2 - hSpacing;
|
||||||
|
newY = sourceBBox.y;
|
||||||
} else {
|
} else {
|
||||||
// Determine node placement direction based on port position
|
newX = sourceBBox.x + sourceBBox.width + hSpacing;
|
||||||
if (sourcePortGroup === 'left') {
|
newY = sourceBBox.y;
|
||||||
// Left port: add node to the left
|
const connectedNodes = new Set<string>();
|
||||||
newX = sourceBBox.x - nodeWidth*2 - horizontalSpacing;
|
graph.getConnectedEdges(sourceNode).forEach((e: any) => {
|
||||||
newY = sourceBBox.y;
|
[e.getSourceCellId(), e.getTargetCellId()].forEach((cid: string) => {
|
||||||
} else {
|
if (cid !== sourceNode.id) connectedNodes.add(cid);
|
||||||
// Right port: add node to the right
|
|
||||||
newX = sourceBBox.x + sourceBBox.width + horizontalSpacing;
|
|
||||||
newY = sourceBBox.y;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if position overlaps with existing nodes (only consider connected nodes)
|
|
||||||
const checkOverlap = (x: number, y: number) => {
|
|
||||||
// Get nodes connected to the source node
|
|
||||||
const connectedNodes = new Set();
|
|
||||||
graph.getConnectedEdges(sourceNode).forEach((edge: any) => {
|
|
||||||
const sourceId = edge.getSourceCellId();
|
|
||||||
const targetId = edge.getTargetCellId();
|
|
||||||
if (sourceId !== sourceNode.id) connectedNodes.add(sourceId);
|
|
||||||
if (targetId !== sourceNode.id) connectedNodes.add(targetId);
|
|
||||||
});
|
});
|
||||||
|
});
|
||||||
return graph.getNodes().some((node: any) => {
|
const checkOverlap = (x: number, y: number) =>
|
||||||
if (node.id === sourceNode.id) return false;
|
graph.getNodes().some((n: any) => {
|
||||||
if (!connectedNodes.has(node.id)) return false; // Only consider connected nodes
|
if (n.id === sourceNode.id || !connectedNodes.has(n.id)) return false;
|
||||||
const bbox = node.getBBox();
|
const b = n.getBBox();
|
||||||
return !(x + nodeWidth < bbox.x || x > bbox.x + bbox.width ||
|
return !(x + nw < b.x || x > b.x + b.width || y + nh < b.y || y > b.y + b.height);
|
||||||
y + nodeHeight < bbox.y || y > bbox.y + bbox.height);
|
|
||||||
});
|
});
|
||||||
};
|
while (checkOverlap(newX, newY)) newY += nh + vSpacing;
|
||||||
|
|
||||||
// If position is occupied, search downward for empty space
|
|
||||||
while (checkOverlap(newX, newY)) {
|
|
||||||
newY += nodeHeight + verticalSpacing;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new node
|
// Disable history for all graph mutations
|
||||||
const id = `${selectedNodeType.type.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
graph.disableHistory();
|
||||||
|
|
||||||
|
// Remove add-node placeholder
|
||||||
|
if (isCycleSubNode && sourceNodeType === 'cycle-start') {
|
||||||
|
const cycleId = sourceNodeData.cycle;
|
||||||
|
graph.getNodes()
|
||||||
|
.filter((n: any) => n.getData()?.type === 'add-node' && n.getData()?.cycle === cycleId)
|
||||||
|
.forEach((n: any) => n.remove());
|
||||||
|
}
|
||||||
|
|
||||||
|
const id = `${newNodeType.replace(/-/g, '_')}_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
||||||
const newNode = graph.addNode({
|
const newNode = graph.addNode({
|
||||||
...(graphNodeLibrary[selectedNodeType.type] || graphNodeLibrary.default),
|
...(graphNodeLibrary[newNodeType] || graphNodeLibrary.default),
|
||||||
x: newX,
|
x: newX,
|
||||||
y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0),
|
y: newY - (isCycleSubNode && sourceNodeType === 'cycle-start' ? 12 : 0),
|
||||||
id,
|
id,
|
||||||
data: {
|
data: {
|
||||||
id,
|
id,
|
||||||
type: selectedNodeType.type,
|
type: newNodeType,
|
||||||
icon: selectedNodeType.icon,
|
icon: selectedNodeType.icon,
|
||||||
name: t(`workflow.${selectedNodeType.type}`),
|
name: t(`workflow.${newNodeType}`),
|
||||||
cycle: sourceNodeData.cycle, // Inherit cycle from source node
|
cycle: sourceNodeData.cycle,
|
||||||
config: selectedNodeType.config || {}
|
config: selectedNodeType.config || {}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add new node as child of parent node
|
|
||||||
if (sourceNodeData.cycle) {
|
if (sourceNodeData.cycle) {
|
||||||
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle);
|
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === sourceNodeData.cycle);
|
||||||
if (parentNode) {
|
if (parentNode) parentNode.addChild(newNode, { silent: true });
|
||||||
parentNode.addChild(newNode);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Edge insertion: remove old edge immediately before creating new edges
|
|
||||||
if (edgeInsertion) {
|
if (edgeInsertion) {
|
||||||
const { edge: oldEdge } = edgeInsertion;
|
const { edge: oldEdge } = edgeInsertion;
|
||||||
if (oldEdge.id && graph.getCellById(oldEdge.id)) {
|
if (oldEdge.id && graph.getCellById(oldEdge.id)) graph.removeCell(oldEdge.id);
|
||||||
graph.removeCell(oldEdge.id);
|
else graph.removeEdge(oldEdge);
|
||||||
} else {
|
|
||||||
graph.removeEdge(oldEdge);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create edge connection
|
const newPorts = newNode.getPorts();
|
||||||
setTimeout(() => {
|
const addedCells: any[] = [newNode];
|
||||||
const newPorts = newNode.getPorts();
|
|
||||||
|
|
||||||
const addedEdges: any[] = [];
|
if (edgeInsertion) {
|
||||||
if (edgeInsertion) {
|
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
|
||||||
// Edge insertion: create source→new and new→target edges
|
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
|
||||||
const { targetCell, targetPort: origTargetPort } = edgeInsertion;
|
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
|
||||||
const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
|
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: newLeftPort }, ...edgeAttrs }));
|
||||||
const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
|
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: newRightPort }, target: { cell: targetCell.id, port: origTargetPort }, ...edgeAttrs }));
|
||||||
addedEdges.push(graph.addEdge({
|
setEdgeInsertion(null);
|
||||||
source: { cell: sourceNode.id, port: sourcePort },
|
} else if (sourcePortGroup === 'left') {
|
||||||
target: { cell: newNode.id, port: newLeftPort },
|
const tp = newPorts.find((p: any) => p.group === 'right')?.id || 'right';
|
||||||
...edgeAttrs
|
addedCells.push(graph.addEdge({ source: { cell: newNode.id, port: tp }, target: { cell: sourceNode.id, port: sourcePort }, ...edgeAttrs }));
|
||||||
}));
|
} else {
|
||||||
addedEdges.push(graph.addEdge({
|
const tp = newPorts.find((p: any) => p.group === 'left')?.id || 'left';
|
||||||
source: { cell: newNode.id, port: newRightPort },
|
addedCells.push(graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: tp }, ...edgeAttrs }));
|
||||||
target: { cell: targetCell.id, port: origTargetPort },
|
}
|
||||||
...edgeAttrs
|
|
||||||
}));
|
|
||||||
setEdgeInsertion(null);
|
|
||||||
} else if (sourcePortGroup === 'left') {
|
|
||||||
// Connect from left port to new node's right side
|
|
||||||
const targetPort = newPorts.find((port: any) => port.group === 'right')?.id || 'right';
|
|
||||||
addedEdges.push(graph.addEdge({
|
|
||||||
source: { cell: newNode.id, port: targetPort },
|
|
||||||
target: { cell: sourceNode.id, port: sourcePort },
|
|
||||||
...edgeAttrs
|
|
||||||
}));
|
|
||||||
} else {
|
|
||||||
// Connect from right port to new node's left side
|
|
||||||
const targetPort = newPorts.find((port: any) => port.group === 'left')?.id || 'left';
|
|
||||||
addedEdges.push(graph.addEdge({
|
|
||||||
source: { cell: sourceNode.id, port: sourcePort },
|
|
||||||
target: { cell: newNode.id, port: targetPort },
|
|
||||||
...edgeAttrs
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adjust loop node size when child node is added via port within loop node
|
|
||||||
const cycleId = sourceNodeData.cycle;
|
|
||||||
if (cycleId) {
|
|
||||||
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
|
||||||
|
|
||||||
if (parentNode) {
|
// If adding a loop/iteration node, create cycle-start, add-node and inner edge regardless of source type
|
||||||
const adjustLoopSize = () => {
|
if (isCycleContainer(newNodeType)) {
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
const parentBBox = newNode.getBBox();
|
||||||
if (childNodes.length > 0) {
|
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
|
||||||
const bounds = childNodes.reduce((acc: any, child: any) => {
|
const cycleStartNode = graph.addNode({
|
||||||
const bbox = child.getBBox();
|
...graphNodeLibrary.cycleStart,
|
||||||
return {
|
x: parentBBox.x + 24,
|
||||||
minX: Math.min(acc.minX, bbox.x),
|
y: parentBBox.y + 70,
|
||||||
minY: Math.min(acc.minY, bbox.y),
|
id: cycleStartId,
|
||||||
maxX: Math.max(acc.maxX, bbox.x + bbox.width),
|
data: { id: cycleStartId, type: 'cycle-start', parentId: id, isDefault: true, cycle: id },
|
||||||
maxY: Math.max(acc.maxY, bbox.y + bbox.height)
|
});
|
||||||
};
|
const addNodePlaceholder = graph.addNode({
|
||||||
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
...graphNodeLibrary.addStart,
|
||||||
|
x: parentBBox.x + 24 + 84,
|
||||||
|
y: parentBBox.y + 70 + 4,
|
||||||
|
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: id, cycle: id },
|
||||||
|
});
|
||||||
|
newNode.addChild(cycleStartNode, { silent: true });
|
||||||
|
newNode.addChild(addNodePlaceholder, { silent: true });
|
||||||
|
const innerEdge = graph.addEdge({
|
||||||
|
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find((p: any) => p.group === 'right')?.id || 'right' },
|
||||||
|
target: { cell: addNodePlaceholder.id, port: addNodePlaceholder.getPorts().find((p: any) => p.group === 'left')?.id || 'left' },
|
||||||
|
...edgeAttrs,
|
||||||
|
});
|
||||||
|
addedCells.push(cycleStartNode, addNodePlaceholder, innerEdge);
|
||||||
|
}
|
||||||
|
|
||||||
const padding = 50;
|
// Adjust parent size if adding inside a cycle container
|
||||||
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
const cycleId = sourceNodeData.cycle;
|
||||||
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
if (cycleId) {
|
||||||
|
const parentNode = graph.getNodes().find((n: any) => n.getData()?.id === cycleId);
|
||||||
parentNode.prop('size', { width: newWidth, height: newHeight });
|
if (parentNode) {
|
||||||
|
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
||||||
// Update right port x position
|
if (childNodes.length > 0) {
|
||||||
const ports = parentNode.getPorts();
|
const bounds = childNodes.reduce((acc: any, child: any) => {
|
||||||
ports.forEach((port: any) => {
|
const b = child.getBBox();
|
||||||
if (port.group === 'right' && port.args) {
|
return { minX: Math.min(acc.minX, b.x), minY: Math.min(acc.minY, b.y), maxX: Math.max(acc.maxX, b.x + b.width), maxY: Math.max(acc.maxY, b.y + b.height) };
|
||||||
parentNode.portProp(port.id!, 'args/x', newWidth);
|
}, { minX: Infinity, minY: Infinity, maxX: -Infinity, maxY: -Infinity });
|
||||||
}
|
const padding = 50;
|
||||||
});
|
const newWidth = Math.max(nodeWidth, bounds.maxX - bounds.minX + padding * 2);
|
||||||
}
|
const newHeight = Math.max(120, bounds.maxY - bounds.minY + padding * 2);
|
||||||
};
|
parentNode.prop('size', { width: newWidth, height: newHeight });
|
||||||
|
parentNode.getPorts().forEach((port: any) => {
|
||||||
adjustLoopSize();
|
if (port.group === 'right' && port.args) parentNode.portProp(port.id!, 'args/x', newWidth);
|
||||||
|
|
||||||
// Listen to child node movement events
|
|
||||||
const childNodes = graph.getNodes().filter((n: any) => n.getData()?.cycle === cycleId);
|
|
||||||
childNodes.forEach((childNode: any) => {
|
|
||||||
childNode.on('change:position', adjustLoopSize);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const isCycleContainer = (type: string) => type === 'loop' || type === 'iteration';
|
// toFront
|
||||||
const newNodeType = selectedNodeType.type;
|
const bringCycleChildrenToFront = (cycleContainerId: string) => {
|
||||||
|
graph.getEdges().forEach((e: any) => {
|
||||||
|
const src = graph.getCellById(e.getSourceCellId());
|
||||||
|
const tgt = graph.getCellById(e.getTargetCellId());
|
||||||
|
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
|
||||||
|
});
|
||||||
|
graph.getNodes().forEach((n: any) => { if (n.getData()?.cycle === cycleContainerId) n.toFront(); });
|
||||||
|
};
|
||||||
|
|
||||||
// Helper: bring all child nodes and their edges of a cycle container to front
|
if (isCycleContainer(sourceNodeType)) {
|
||||||
const bringCycleChildrenToFront = (cycleContainerId: string) => {
|
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(sourceNodeData.id);
|
||||||
|
if (isCycleContainer(newNodeType)) bringCycleChildrenToFront(id);
|
||||||
graph.getEdges().forEach((e: any) => {
|
} else if (isCycleContainer(newNodeType)) {
|
||||||
const src = graph.getCellById(e.getSourceCellId());
|
newNode.toFront(); sourceNode.toFront(); bringCycleChildrenToFront(id);
|
||||||
const tgt = graph.getCellById(e.getTargetCellId());
|
} else {
|
||||||
if (src?.getData()?.cycle === cycleContainerId || tgt?.getData()?.cycle === cycleContainerId) e.toFront();
|
addedCells.forEach(c => { if (c.isNode?.()) c.toFront(); });
|
||||||
});
|
}
|
||||||
graph.getNodes().forEach((n: any) => {
|
|
||||||
if (n.getData()?.cycle === cycleContainerId) n.toFront();
|
|
||||||
});
|
|
||||||
};
|
|
||||||
|
|
||||||
if (isCycleContainer(sourceNodeType)) {
|
// Re-enable history and manually push one batch frame for all added cells
|
||||||
console.log('isCycleContainer(sourceNodeType)')
|
graph.enableHistory();
|
||||||
// Case 4: source is a loop/iteration node — bring new node to front, then its children
|
const history = graph.getPlugin('history') as any;
|
||||||
newNode.toFront();
|
if (history) {
|
||||||
sourceNode.toFront();
|
const batchFrame = addedCells.map((cell: any) => ({
|
||||||
bringCycleChildrenToFront(sourceNodeData.id);
|
batch: true,
|
||||||
} else if (isCycleContainer(newNodeType)) {
|
event: 'cell:added',
|
||||||
console.log('isCycleContainer(newNodeType)')
|
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
|
||||||
// Case 3: adding a loop/iteration node from a normal node — bring new node to front, then its children
|
options: {},
|
||||||
newNode.toFront();
|
}));
|
||||||
sourceNode.toFront()
|
history.undoStack.push(batchFrame);
|
||||||
bringCycleChildrenToFront(id);
|
history.redoStack = [];
|
||||||
} else {
|
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-node' } });
|
||||||
// Case 2: normal node → normal node
|
}
|
||||||
addedEdges.forEach(e => {
|
|
||||||
const src = graph.getCellById(e.getSourceCellId());
|
|
||||||
const tgt = graph.getCellById(e.getTargetCellId());
|
|
||||||
if (src?.isNode()) src.toFront();
|
|
||||||
if (tgt?.isNode()) tgt.toFront();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}, 50);
|
|
||||||
|
|
||||||
// Clean up temporary element
|
|
||||||
if (tempElement) {
|
if (tempElement) {
|
||||||
document.body.removeChild(tempElement);
|
document.body.removeChild(tempElement);
|
||||||
setTempElement(null);
|
setTempElement(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
setPopoverVisible(false);
|
setPopoverVisible(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -391,4 +335,4 @@ const PortClickHandler: React.FC<PortClickHandlerProps> = ({ graph }) => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default PortClickHandler;
|
export default PortClickHandler;
|
||||||
|
|||||||
@@ -242,10 +242,11 @@ const ToolConfig: FC<{ options: Suggestion[]; }> = ({
|
|||||||
className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''}
|
className={parameter.type === 'boolean' ? 'rb:mb-0!' : ''}
|
||||||
>
|
>
|
||||||
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
|
{parameter.type === 'string' && parameter.enum && parameter.enum.length > 0
|
||||||
? <Select size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
|
? <Select key={values.tool_id} size="small" options={parameter.enum.map(vo => ({ value: vo, label: vo }))} placeholder={t('common.pleaseSelect')} />
|
||||||
: parameter.type === 'boolean'
|
: parameter.type === 'boolean'
|
||||||
? <Switch size="small" />
|
? <Switch key={values.tool_id} size="small" />
|
||||||
: <Editor
|
: <Editor
|
||||||
|
key={values.tool_id}
|
||||||
variant="outlined"
|
variant="outlined"
|
||||||
type="input"
|
type="input"
|
||||||
size="small"
|
size="small"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 15:06:18
|
* @Date: 2026-02-03 15:06:18
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-21 18:23:31
|
* @Last Modified time: 2026-04-27 14:07:14
|
||||||
*/
|
*/
|
||||||
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
import type { ReactShapeConfig } from '@antv/x6-react-shape';
|
||||||
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
|
import type { GroupMetadata, PortMetadata } from '@antv/x6/lib/model/port';
|
||||||
@@ -948,6 +948,15 @@ export const graphNodeLibrary: Record<string, NodeConfig> = {
|
|||||||
width: nodeWidth,
|
width: nodeWidth,
|
||||||
height: 120,
|
height: 120,
|
||||||
shape: 'notes-node',
|
shape: 'notes-node',
|
||||||
|
},
|
||||||
|
output: {
|
||||||
|
width: nodeWidth,
|
||||||
|
height: 76,
|
||||||
|
shape: 'normal-node',
|
||||||
|
ports: {
|
||||||
|
groups: { left: defaultPortGroup },
|
||||||
|
items: [defaultPortItems[0]],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,9 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 15:17:48
|
* @Date: 2026-02-03 15:17:48
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-04-24 17:21:09
|
* @Last Modified time: 2026-04-28 13:49:11
|
||||||
*/
|
*/
|
||||||
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
|
import { Clipboard, Graph, Keyboard, MiniMap, Node, Snapline, History, type Edge } from '@antv/x6';
|
||||||
import type { HistoryCommand as Command } from '@antv/x6/lib/plugin/history/type';
|
|
||||||
import { register } from '@antv/x6-react-shape';
|
import { register } from '@antv/x6-react-shape';
|
||||||
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
import type { PortMetadata } from '@antv/x6/lib/model/port';
|
||||||
import { App } from 'antd';
|
import { App } from 'antd';
|
||||||
@@ -17,7 +16,7 @@ import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application';
|
|||||||
import { useUser } from '@/store/user';
|
import { useUser } from '@/store/user';
|
||||||
import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types';
|
import type { FeaturesConfigForm } from '@/views/ApplicationConfig/types';
|
||||||
import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant';
|
import { conditionNodeHeight, conditionNodeItemHeight, conditionNodePortItemArgsY, defaultAbsolutePortGroups, defaultPortItems, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, edge_width, graphNodeLibrary, nodeLibrary, nodeRegisterLibrary, nodeWidth, notesConfig, portAttrs, portItemArgsY, portMarkup, portTextAttrs, unknownNode } from '../constant';
|
||||||
import type { ChatVariable, NodeProperties, WorkflowConfig } from '../types';
|
import type { ChatVariable, HistoryRecord, NodeProperties, WorkflowConfig } from '../types';
|
||||||
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
|
import { calcConditionNodeTotalHeight, getConditionNodeCasePortY } from '../utils';
|
||||||
import { useWorkflowStore } from '@/store/workflow';
|
import { useWorkflowStore } from '@/store/workflow';
|
||||||
|
|
||||||
@@ -86,6 +85,10 @@ export interface UseWorkflowGraphReturn {
|
|||||||
/** Get start node output variable list (user-defined + system variables) */
|
/** Get start node output variable list (user-defined + system variables) */
|
||||||
getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>;
|
getStartNodeVariables: () => Array<{ name: string; type: string; readonly?: boolean }>;
|
||||||
nodeClick: ({ node }: { node: Node }) => void;
|
nodeClick: ({ node }: { node: Node }) => void;
|
||||||
|
/** All recorded history operations */
|
||||||
|
historyRecords: HistoryRecord[];
|
||||||
|
/** Clear history records */
|
||||||
|
clearHistoryRecords: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -119,7 +122,12 @@ export const useWorkflowGraph = ({
|
|||||||
const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined)
|
const featuresRef = useRef<FeaturesConfigForm | undefined>(undefined)
|
||||||
const [canUndo, setCanUndo] = useState(false)
|
const [canUndo, setCanUndo] = useState(false)
|
||||||
const [canRedo, setCanRedo] = useState(false)
|
const [canRedo, setCanRedo] = useState(false)
|
||||||
|
const [historyRecords, setHistoryRecords] = useState<HistoryRecord[]>([])
|
||||||
|
const lastHistoryRef = useRef<{ cellIds: string[]; timestamp: number; type: string } | null>(null)
|
||||||
|
const undoRef = useRef<() => void>(() => {})
|
||||||
|
const redoRef = useRef<() => void>(() => {})
|
||||||
|
const syncChildRelationshipsRef = useRef<() => void>(() => {})
|
||||||
|
const isSyncingRef = useRef(false)
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!graphRef.current) return
|
if (!graphRef.current) return
|
||||||
graphRef.current.getNodes().forEach(node => {
|
graphRef.current.getNodes().forEach(node => {
|
||||||
@@ -343,7 +351,7 @@ export const useWorkflowGraph = ({
|
|||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
const addedChild = graphRef.current?.addNode(childNode)
|
const addedChild = graphRef.current?.addNode(childNode)
|
||||||
if (addedChild) {
|
if (addedChild) {
|
||||||
parentNode.addChild(addedChild)
|
parentNode.addChild(addedChild, { silent: true })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -374,8 +382,6 @@ export const useWorkflowGraph = ({
|
|||||||
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
|
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
|
||||||
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
|
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
|
||||||
|
|
||||||
console.log('newWidth', newHeight, newWidth)
|
|
||||||
|
|
||||||
parentNode.prop('size', { width: newWidth, height: newHeight })
|
parentNode.prop('size', { width: newWidth, height: newHeight })
|
||||||
|
|
||||||
// Update x position of right group ports
|
// Update x position of right group ports
|
||||||
@@ -488,8 +494,77 @@ export const useWorkflowGraph = ({
|
|||||||
graphRef.current.cleanHistory()
|
graphRef.current.cleanHistory()
|
||||||
}
|
}
|
||||||
}, 200)
|
}, 200)
|
||||||
|
} else {
|
||||||
|
graphRef.current.enableHistory()
|
||||||
|
graphRef.current.cleanHistory()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const resizeGroupNodes = (graph: Graph) => {
|
||||||
|
graph.getNodes().forEach(parentNode => {
|
||||||
|
const parentType = parentNode.getData()?.type
|
||||||
|
if (parentType !== 'loop' && parentType !== 'iteration') return
|
||||||
|
const children = graph.getNodes().filter(
|
||||||
|
n => n.getData()?.cycle === parentNode.getData()?.id && n.getData()?.type !== 'add-node'
|
||||||
|
)
|
||||||
|
if (!children.length) return
|
||||||
|
const padding = 24
|
||||||
|
const headerHeight = 50
|
||||||
|
const childBounds = children.map(c => c.getBBox())
|
||||||
|
const minX = Math.min(...childBounds.map(b => b.x))
|
||||||
|
const minY = Math.min(...childBounds.map(b => b.y))
|
||||||
|
const maxX = Math.max(...childBounds.map(b => b.x + b.width))
|
||||||
|
const maxY = Math.max(...childBounds.map(b => b.y + b.height))
|
||||||
|
const parentBBox = parentNode.getBBox()
|
||||||
|
const newWidth = Math.max(parentBBox.width, maxX - minX + padding * 2)
|
||||||
|
const newHeight = Math.max(parentBBox.height, maxY - minY + padding * 2 + headerHeight)
|
||||||
|
parentNode.prop('size', { width: newWidth, height: newHeight })
|
||||||
|
parentNode.getPorts().forEach(port => {
|
||||||
|
if (port.group === 'right' && port.args) {
|
||||||
|
parentNode.portProp(port.id!, 'args/x', newWidth)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const syncChildRelationships = () => {
|
||||||
|
if (!graphRef.current) return
|
||||||
|
const graph = graphRef.current
|
||||||
|
graph.disableHistory()
|
||||||
|
graph.getNodes().forEach(node => {
|
||||||
|
const cycleId = node.getData()?.cycle
|
||||||
|
if (!cycleId) return
|
||||||
|
const parentNode = graph.getCellById(cycleId) as Node | null
|
||||||
|
if (!parentNode) return
|
||||||
|
if (!parentNode.getChildren()?.some(c => c.id === node.id)) {
|
||||||
|
parentNode.addChild(node, { silent: true })
|
||||||
|
}
|
||||||
|
})
|
||||||
|
graph.getNodes().forEach(node => {
|
||||||
|
const children = node.getChildren()
|
||||||
|
if (!children?.length) return
|
||||||
|
children.forEach(child => {
|
||||||
|
if (!child.isNode()) return
|
||||||
|
const childCycleId = (child as Node).getData?.()?.cycle
|
||||||
|
if (childCycleId !== node.id && childCycleId !== node.getData?.()?.id) {
|
||||||
|
node.removeChild(child, { silent: true })
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
resizeGroupNodes(graph)
|
||||||
|
graph.getEdges().forEach(edge => {
|
||||||
|
const src = graph.getCellById(edge.getSourceCellId())
|
||||||
|
const tgt = graph.getCellById(edge.getTargetCellId())
|
||||||
|
if (src?.getData()?.cycle || tgt?.getData()?.cycle) {
|
||||||
|
edge.toFront()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
graph.getNodes().forEach(node => {
|
||||||
|
if (node.getData()?.cycle) node.toFront()
|
||||||
|
})
|
||||||
|
graph.enableHistory()
|
||||||
|
}
|
||||||
|
syncChildRelationshipsRef.current = syncChildRelationships
|
||||||
/**
|
/**
|
||||||
* Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard)
|
* Setup X6 graph plugins (MiniMap, Snapline, Clipboard, Keyboard)
|
||||||
*/
|
*/
|
||||||
@@ -525,18 +600,44 @@ export const useWorkflowGraph = ({
|
|||||||
new History({
|
new History({
|
||||||
enabled: false,
|
enabled: false,
|
||||||
beforeAddCommand(_event, args: any) {
|
beforeAddCommand(_event, args: any) {
|
||||||
const event = args?.key ? `cell:change:${args.key}` : _event;
|
const key = args?.key
|
||||||
if (event.startsWith('cell:change:') &&
|
if (key === 'attrs' || key === 'tools') return false
|
||||||
event !== 'cell:change:position' &&
|
|
||||||
event !== 'cell:change:source' &&
|
|
||||||
event !== 'cell:change:target') return false;
|
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
graphRef.current.on('history:change', ({ cmds }: { cmds: Command[] }) => {
|
const MERGE_INTERVAL = 1000
|
||||||
|
graphRef.current.on('history:change', ({ cmds, options }: { cmds: any[]; options: any }) => {
|
||||||
setCanUndo(graphRef.current?.canUndo() ?? false)
|
setCanUndo(graphRef.current?.canUndo() ?? false)
|
||||||
setCanRedo(graphRef.current?.canRedo() ?? false)
|
setCanRedo(graphRef.current?.canRedo() ?? false)
|
||||||
|
console.log('history:change', cmds, options)
|
||||||
|
const batchName: string | undefined = options?.name
|
||||||
|
const actionType = batchName === 'undo' ? 'undo' : batchName === 'redo' ? 'redo' : batchName ? 'batch' : 'change'
|
||||||
|
const cellIds = [...new Set(cmds?.map((cmd: any) => cmd.data?.id).filter(Boolean))]
|
||||||
|
const now = Date.now()
|
||||||
|
const last = lastHistoryRef.current
|
||||||
|
const canMerge =
|
||||||
|
actionType === 'change' &&
|
||||||
|
last?.type === 'change' &&
|
||||||
|
now - last.timestamp < MERGE_INTERVAL &&
|
||||||
|
cellIds.length > 0 &&
|
||||||
|
cellIds.length === last.cellIds.length &&
|
||||||
|
cellIds.every((id, i) => id === last.cellIds[i])
|
||||||
|
if (canMerge) {
|
||||||
|
lastHistoryRef.current!.timestamp = now
|
||||||
|
setHistoryRecords(prev => {
|
||||||
|
const next = [...prev]
|
||||||
|
next[next.length - 1] = { ...next[next.length - 1], timestamp: now }
|
||||||
|
return next
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const record: HistoryRecord = { type: actionType, timestamp: now, batchName, cellIds }
|
||||||
|
lastHistoryRef.current = { cellIds, timestamp: now, type: actionType }
|
||||||
|
setHistoryRecords(prev => [...prev, record])
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
graphRef.current.on('history:undo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
|
||||||
|
graphRef.current.on('history:redo', () => { if (!isSyncingRef.current) syncChildRelationshipsRef.current() })
|
||||||
};
|
};
|
||||||
// 显示/隐藏连接桩
|
// 显示/隐藏连接桩
|
||||||
// const showPorts = (show: boolean) => {
|
// const showPorts = (show: boolean) => {
|
||||||
@@ -569,13 +670,13 @@ export const useWorkflowGraph = ({
|
|||||||
vo.setData({
|
vo.setData({
|
||||||
...data,
|
...data,
|
||||||
isSelected: false,
|
isSelected: false,
|
||||||
});
|
}, { silent: true });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
node.setData({
|
node.setData({
|
||||||
...nodeData,
|
...nodeData,
|
||||||
isSelected: true,
|
isSelected: true,
|
||||||
});
|
}, { silent: true });
|
||||||
clearEdgeSelect()
|
clearEdgeSelect()
|
||||||
if (nodeData.type !== 'notes') {
|
if (nodeData.type !== 'notes') {
|
||||||
setSelectedNode(node);
|
setSelectedNode(node);
|
||||||
@@ -589,7 +690,7 @@ export const useWorkflowGraph = ({
|
|||||||
const edgeClick = ({ edge }: { edge: Edge }) => {
|
const edgeClick = ({ edge }: { edge: Edge }) => {
|
||||||
clearEdgeSelect();
|
clearEdgeSelect();
|
||||||
edge.setAttrByPath('line/stroke', edge_selected_color);
|
edge.setAttrByPath('line/stroke', edge_selected_color);
|
||||||
edge.setData({ ...edge.getData(), isSelected: true });
|
edge.setData({ ...edge.getData(), isSelected: true }, { silent: true });
|
||||||
clearNodeSelect();
|
clearNodeSelect();
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
@@ -604,7 +705,7 @@ export const useWorkflowGraph = ({
|
|||||||
node.setData({
|
node.setData({
|
||||||
...data,
|
...data,
|
||||||
isSelected: false,
|
isSelected: false,
|
||||||
});
|
}, { silent: true });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
setSelectedNode(null);
|
setSelectedNode(null);
|
||||||
@@ -614,7 +715,7 @@ export const useWorkflowGraph = ({
|
|||||||
*/
|
*/
|
||||||
const clearEdgeSelect = () => {
|
const clearEdgeSelect = () => {
|
||||||
graphRef.current?.getEdges().forEach(e => {
|
graphRef.current?.getEdges().forEach(e => {
|
||||||
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false });
|
e.setData({ ...e.getData(), isSelected: false, isNodeHover: false }, { silent: true });
|
||||||
e.setAttrByPath('line/stroke', edge_color);
|
e.setAttrByPath('line/stroke', edge_color);
|
||||||
e.setAttrByPath('line/strokeWidth', edge_width);
|
e.setAttrByPath('line/strokeWidth', edge_width);
|
||||||
});
|
});
|
||||||
@@ -753,8 +854,6 @@ export const useWorkflowGraph = ({
|
|||||||
// Find corresponding parent node
|
// Find corresponding parent node
|
||||||
const parentNode = nodes?.find(n => n.id === nodeData.cycle);
|
const parentNode = nodes?.find(n => n.id === nodeData.cycle);
|
||||||
if (parentNode) {
|
if (parentNode) {
|
||||||
// Use removeChild method to delete child node
|
|
||||||
parentNode.removeChild(nodeToDelete);
|
|
||||||
parentNodesToUpdate.push(parentNode);
|
parentNodesToUpdate.push(parentNode);
|
||||||
}
|
}
|
||||||
// Add child node to deletion list
|
// Add child node to deletion list
|
||||||
@@ -782,42 +881,51 @@ export const useWorkflowGraph = ({
|
|||||||
|
|
||||||
// Delete all collected nodes and edges
|
// Delete all collected nodes and edges
|
||||||
if (cells.length > 0) {
|
if (cells.length > 0) {
|
||||||
|
// Pre-calculate which parents need an add-node restored (before removal changes the graph)
|
||||||
|
const parentsNeedingAddNode = parentNodesToUpdate
|
||||||
|
.filter(parentNode => {
|
||||||
|
const parentShape = parentNode.shape;
|
||||||
|
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return false;
|
||||||
|
const parentData = parentNode.getData();
|
||||||
|
const allChildren = graphRef.current!.getNodes().filter(n => n.getData()?.cycle === parentData.id);
|
||||||
|
const cycleStartNodes = allChildren.filter(n => n.getData()?.type === 'cycle-start');
|
||||||
|
// After deletion, only cycle-start will remain
|
||||||
|
const nonCycleStartToDelete = cells.filter(c =>
|
||||||
|
c.isNode() &&
|
||||||
|
(c as Node).getData()?.cycle === parentData.id &&
|
||||||
|
(c as Node).getData()?.type !== 'cycle-start'
|
||||||
|
);
|
||||||
|
return cycleStartNodes.length === 1 && (allChildren.length - nonCycleStartToDelete.length) === 1;
|
||||||
|
})
|
||||||
|
.map(parentNode => ({
|
||||||
|
parentNode,
|
||||||
|
cycleStartNode: graphRef.current!.getNodes().find(
|
||||||
|
n => n.getData()?.cycle === parentNode.getData().id && n.getData()?.type === 'cycle-start'
|
||||||
|
)!
|
||||||
|
}))
|
||||||
|
.filter(({ cycleStartNode }) => !!cycleStartNode);
|
||||||
|
|
||||||
|
graphRef.current?.startBatch('delete');
|
||||||
graphRef.current?.removeCells(cells);
|
graphRef.current?.removeCells(cells);
|
||||||
|
|
||||||
// If parent is iteration/loop and only cycle-start remains, add add-node connected to it
|
parentsNeedingAddNode.forEach(({ parentNode, cycleStartNode }) => {
|
||||||
parentNodesToUpdate.forEach(parentNode => {
|
|
||||||
const parentShape = parentNode.shape;
|
|
||||||
if (parentShape !== 'loop-node' && parentShape !== 'iteration-node') return;
|
|
||||||
const parentData = parentNode.getData();
|
const parentData = parentNode.getData();
|
||||||
const remainingChildren = graphRef.current!.getNodes().filter(
|
const bbox = cycleStartNode.getBBox();
|
||||||
n => n.getData()?.cycle === parentData.id
|
const addNode = graphRef.current!.addNode({
|
||||||
);
|
...graphNodeLibrary.addStart,
|
||||||
const cycleStartNodes = remainingChildren.filter(n => n.getData()?.type === 'cycle-start');
|
x: bbox.x + 84,
|
||||||
if (cycleStartNodes.length === 1 && remainingChildren.length === 1) {
|
y: bbox.y + 4,
|
||||||
const cycleStartNode = cycleStartNodes[0];
|
data: { type: 'add-node', parentId: parentNode.id, cycle: parentData.id, label: t('workflow.addNode'), icon: '+' },
|
||||||
const bbox = cycleStartNode.getBBox();
|
});
|
||||||
const addNode = graphRef.current!.addNode({
|
parentNode.addChild(addNode, { silent: true });
|
||||||
...graphNodeLibrary.addStart,
|
graphRef.current!.addEdge({
|
||||||
x: bbox.x + 84,
|
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
|
||||||
y: bbox.y + 4,
|
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
|
||||||
data: {
|
...edgeAttrs,
|
||||||
type: 'add-node',
|
});
|
||||||
parentId: parentNode.id,
|
|
||||||
cycle: parentData.id,
|
|
||||||
label: t('workflow.addNode'),
|
|
||||||
icon: '+',
|
|
||||||
},
|
|
||||||
});
|
|
||||||
parentNode.addChild(addNode);
|
|
||||||
const sourcePort = cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right';
|
|
||||||
const targetPort = addNode.getPorts().find(p => p.group === 'left')?.id || 'left';
|
|
||||||
graphRef.current!.addEdge({
|
|
||||||
source: { cell: cycleStartNode.id, port: sourcePort },
|
|
||||||
target: { cell: addNode.id, port: targetPort },
|
|
||||||
...edgeAttrs,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
graphRef.current?.stopBatch('delete');
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
@@ -1036,7 +1144,7 @@ export const useWorkflowGraph = ({
|
|||||||
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
||||||
if (!edge.getData()?.isSelected) {
|
if (!edge.getData()?.isSelected) {
|
||||||
edge.setAttrByPath('line/stroke', edge_selected_color);
|
edge.setAttrByPath('line/stroke', edge_selected_color);
|
||||||
edge.setData({ ...edge.getData(), isNodeHover: true });
|
edge.setData({ ...edge.getData(), isNodeHover: true }, { silent: true });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -1044,7 +1152,7 @@ export const useWorkflowGraph = ({
|
|||||||
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
graphRef.current?.getConnectedEdges(node).forEach(edge => {
|
||||||
if (!edge.getData()?.isSelected) {
|
if (!edge.getData()?.isSelected) {
|
||||||
edge.setAttrByPath('line/stroke', edge_color);
|
edge.setAttrByPath('line/stroke', edge_color);
|
||||||
edge.setData({ ...edge.getData(), isNodeHover: false });
|
edge.setData({ ...edge.getData(), isNodeHover: false }, { silent: true });
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -1126,8 +1234,8 @@ export const useWorkflowGraph = ({
|
|||||||
// Delete selected nodes and edges
|
// Delete selected nodes and edges
|
||||||
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
|
graphRef.current.bindKey(['ctrl+d', 'cmd+d', 'delete', 'backspace'], deleteEvent);
|
||||||
// Undo / Redo
|
// Undo / Redo
|
||||||
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { graphRef.current?.undo(); return false; });
|
graphRef.current.bindKey(['ctrl+z', 'cmd+z'], () => { undo(); return false; });
|
||||||
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { graphRef.current?.redo(); return false; });
|
graphRef.current.bindKey(['ctrl+y', 'cmd+y', 'ctrl+shift+z', 'cmd+shift+z'], () => { redo(); return false; });
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -1193,13 +1301,51 @@ export const useWorkflowGraph = ({
|
|||||||
};
|
};
|
||||||
|
|
||||||
if (dragData.type === 'loop' || dragData.type === 'iteration') {
|
if (dragData.type === 'loop' || dragData.type === 'iteration') {
|
||||||
graphRef.current.addNode({
|
graph.disableHistory()
|
||||||
|
const parentNode = graphRef.current.addNode({
|
||||||
...graphNodeLibrary[dragData.type],
|
...graphNodeLibrary[dragData.type],
|
||||||
x: point.x - 150,
|
x: point.x - 150,
|
||||||
y: point.y - 100,
|
y: point.y - 100,
|
||||||
id: cleanNodeData.id,
|
id: cleanNodeData.id,
|
||||||
data: { ...cleanNodeData, isGroup: true },
|
data: { ...cleanNodeData, isGroup: true },
|
||||||
});
|
})
|
||||||
|
const parentBBox = parentNode.getBBox()
|
||||||
|
const cycleStartId = `cycle_start_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||||
|
const cycleStartNode = graphRef.current.addNode({
|
||||||
|
...graphNodeLibrary.cycleStart,
|
||||||
|
x: parentBBox.x + 24,
|
||||||
|
y: parentBBox.y + 70,
|
||||||
|
id: cycleStartId,
|
||||||
|
data: { id: cycleStartId, type: 'cycle-start', parentId: cleanNodeData.id, isDefault: true, cycle: cleanNodeData.id },
|
||||||
|
})
|
||||||
|
const addNode = graphRef.current.addNode({
|
||||||
|
...graphNodeLibrary.addStart,
|
||||||
|
x: parentBBox.x + 24 + 84,
|
||||||
|
y: parentBBox.y + 70 + 4,
|
||||||
|
data: { type: 'add-node', label: t('workflow.addNode'), icon: '+', parentId: cleanNodeData.id, cycle: cleanNodeData.id },
|
||||||
|
})
|
||||||
|
parentNode.addChild(cycleStartNode, { silent: true })
|
||||||
|
parentNode.addChild(addNode, { silent: true })
|
||||||
|
const newEdge = graphRef.current.addEdge({
|
||||||
|
source: { cell: cycleStartNode.id, port: cycleStartNode.getPorts().find(p => p.group === 'right')?.id || 'right' },
|
||||||
|
target: { cell: addNode.id, port: addNode.getPorts().find(p => p.group === 'left')?.id || 'left' },
|
||||||
|
...edgeAttrs,
|
||||||
|
})
|
||||||
|
cycleStartNode.toFront()
|
||||||
|
addNode.toFront()
|
||||||
|
graph.enableHistory()
|
||||||
|
// Manually push a single batch frame covering all 4 cells into undoStack
|
||||||
|
const history = graph.getPlugin('history') as History
|
||||||
|
const makeBatchCmd = (cell: any) => ({
|
||||||
|
batch: true,
|
||||||
|
event: 'cell:added',
|
||||||
|
data: { id: cell.id, node: cell.isNode(), edge: cell.isEdge(), props: cell.toJSON() },
|
||||||
|
options: {},
|
||||||
|
})
|
||||||
|
const batchFrame = [parentNode, cycleStartNode, addNode, newEdge].map(makeBatchCmd)
|
||||||
|
;(history as any).undoStack.push(batchFrame)
|
||||||
|
;(history as any).redoStack = []
|
||||||
|
graph.trigger('history:change', { cmds: batchFrame, options: { name: 'add-group' } })
|
||||||
} else if (dragData.type === 'if-else') {
|
} else if (dragData.type === 'if-else') {
|
||||||
// Create condition node
|
// Create condition node
|
||||||
graphRef.current.addNode({
|
graphRef.current.addNode({
|
||||||
@@ -1446,8 +1592,80 @@ export const useWorkflowGraph = ({
|
|||||||
return userVars
|
return userVars
|
||||||
}
|
}
|
||||||
|
|
||||||
const undo = () => graphRef.current?.undo()
|
const clearHistoryRecords = () => {
|
||||||
const redo = () => graphRef.current?.redo()
|
setHistoryRecords([])
|
||||||
|
lastHistoryRef.current = null
|
||||||
|
}
|
||||||
|
|
||||||
|
const getStackCellIds = (cmds: any): string[] => {
|
||||||
|
const arr = Array.isArray(cmds) ? cmds : [cmds]
|
||||||
|
return [...new Set(arr.map((c: any) => c.data?.id).filter(Boolean))]
|
||||||
|
}
|
||||||
|
|
||||||
|
const isSkippableFrame = (frame: any): boolean => {
|
||||||
|
const arr = Array.isArray(frame) ? frame : [frame]
|
||||||
|
return arr.every((c: any) => ['zIndex', 'attrs', 'tools'].includes(c.data?.key))
|
||||||
|
}
|
||||||
|
|
||||||
|
const undo = () => {
|
||||||
|
const history = graphRef.current?.getPlugin('history') as History | undefined
|
||||||
|
if (!history || history.getUndoSize() === 0) return
|
||||||
|
const undoStack = (history as any).undoStack as any[]
|
||||||
|
isSyncingRef.current = true
|
||||||
|
while (undoStack.length > 0 && isSkippableFrame(undoStack[undoStack.length - 1])) {
|
||||||
|
graphRef.current!.undo()
|
||||||
|
}
|
||||||
|
if (undoStack.length === 0) {
|
||||||
|
isSyncingRef.current = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const topIds = getStackCellIds(undoStack[undoStack.length - 1])
|
||||||
|
graphRef.current!.undo()
|
||||||
|
while (undoStack.length > 0) {
|
||||||
|
if (isSkippableFrame(undoStack[undoStack.length - 1])) {
|
||||||
|
graphRef.current!.undo()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
const nextIds = getStackCellIds(undoStack[undoStack.length - 1])
|
||||||
|
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
|
||||||
|
graphRef.current!.undo()
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
isSyncingRef.current = false
|
||||||
|
syncChildRelationships()
|
||||||
|
}
|
||||||
|
|
||||||
|
const redo = () => {
|
||||||
|
const history = graphRef.current?.getPlugin('history') as History | undefined
|
||||||
|
if (!history || history.getRedoSize() === 0) return
|
||||||
|
const redoStack = (history as any).redoStack as any[]
|
||||||
|
isSyncingRef.current = true
|
||||||
|
while (redoStack.length > 0 && isSkippableFrame(redoStack[redoStack.length - 1])) {
|
||||||
|
graphRef.current!.redo()
|
||||||
|
}
|
||||||
|
if (redoStack.length === 0) {
|
||||||
|
isSyncingRef.current = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
const topIds = getStackCellIds(redoStack[redoStack.length - 1])
|
||||||
|
graphRef.current!.redo()
|
||||||
|
while (redoStack.length > 0) {
|
||||||
|
if (isSkippableFrame(redoStack[redoStack.length - 1])) {
|
||||||
|
graphRef.current!.redo()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
const nextIds = getStackCellIds(redoStack[redoStack.length - 1])
|
||||||
|
if (nextIds.length === topIds.length && nextIds.every((id, i) => id === topIds[i])) {
|
||||||
|
graphRef.current!.redo()
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
isSyncingRef.current = false
|
||||||
|
syncChildRelationships()
|
||||||
|
}
|
||||||
|
|
||||||
const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => {
|
const handleSaveFeaturesConfig = (value?: FeaturesConfigForm) => {
|
||||||
const { statement = '' } = value?.opening_statement || {}
|
const { statement = '' } = value?.opening_statement || {}
|
||||||
@@ -1488,23 +1706,19 @@ export const useWorkflowGraph = ({
|
|||||||
if (!graphRef.current) return;
|
if (!graphRef.current) return;
|
||||||
const nodes = graphRef.current.getNodes();
|
const nodes = graphRef.current.getNodes();
|
||||||
|
|
||||||
const lastWithSub = [...chatHistory].reverse().find(item => item.subContent?.length);
|
// Reset all node execution status on every chatHistory change
|
||||||
// Reset all node execution status first
|
|
||||||
nodes.forEach(node => {
|
nodes.forEach(node => {
|
||||||
const data = node.getData();
|
const data = node.getData();
|
||||||
if (typeof data.executionStatus === 'string') {
|
node.setData({ ...data, executionStatus: '' }, { silent: true });
|
||||||
node.setData({ ...data, executionStatus: undefined });
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
if (!lastWithSub?.subContent) return;
|
|
||||||
// Build a nodeId -> status map first
|
const lastAssistant = [...chatHistory].reverse().find(item => item.role === 'assistant');
|
||||||
const statusMap: Record<string, string> = {};
|
if (!lastAssistant?.subContent?.length) return;
|
||||||
lastWithSub.subContent.forEach(sub => {
|
lastAssistant.subContent.forEach(sub => {
|
||||||
if (typeof sub.status === 'string') {
|
if (typeof sub.status === 'string') {
|
||||||
statusMap[sub.node_id] = sub.status;
|
|
||||||
const node = nodes.find(n => n.getData()?.id === sub.node_id);
|
const node = nodes.find(n => n.getData()?.id === sub.node_id);
|
||||||
if (node) {
|
if (node) {
|
||||||
node.setData({ ...node.getData(), executionStatus: sub.status });
|
node.setData({ ...node.getData(), executionStatus: sub.status }, { silent: true });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -1537,5 +1751,7 @@ export const useWorkflowGraph = ({
|
|||||||
canRedo,
|
canRedo,
|
||||||
undo,
|
undo,
|
||||||
redo,
|
redo,
|
||||||
|
historyRecords,
|
||||||
|
clearHistoryRecords,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -113,4 +113,13 @@ export interface ChatVariable {
|
|||||||
}
|
}
|
||||||
export interface AddChatVariableRef {
|
export interface AddChatVariableRef {
|
||||||
handleOpen: (value?: ChatVariable) => void;
|
handleOpen: (value?: ChatVariable) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type HistoryActionType = 'add' | 'remove' | 'change' | 'undo' | 'redo' | 'batch'
|
||||||
|
|
||||||
|
export interface HistoryRecord {
|
||||||
|
type: HistoryActionType;
|
||||||
|
timestamp: number;
|
||||||
|
batchName?: string;
|
||||||
|
cellIds?: string[];
|
||||||
}
|
}
|
||||||
@@ -17,6 +17,7 @@ export const isSubExprSet = (sub: any) => {
|
|||||||
* Uses the same per-expression height logic as getConditionNodeCasePortY.
|
* Uses the same per-expression height logic as getConditionNodeCasePortY.
|
||||||
*/
|
*/
|
||||||
export const calcConditionNodeTotalHeight = (cases: any[]) => {
|
export const calcConditionNodeTotalHeight = (cases: any[]) => {
|
||||||
|
if (!cases?.length) return conditionNodeHeight;
|
||||||
const casesHeight = cases.reduce((acc: number, c: any) => {
|
const casesHeight = cases.reduce((acc: number, c: any) => {
|
||||||
const exprs = c?.expressions ?? [];
|
const exprs = c?.expressions ?? [];
|
||||||
const n = exprs.length;
|
const n = exprs.length;
|
||||||
|
|||||||
Reference in New Issue
Block a user