Merge branch 'develop' into feature/ui_upgrade_zy
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -25,6 +25,8 @@ examples/
|
|||||||
time.log
|
time.log
|
||||||
celerybeat-schedule.db
|
celerybeat-schedule.db
|
||||||
search_results.json
|
search_results.json
|
||||||
|
redbear-mem-metrics/
|
||||||
|
pitch-deck/
|
||||||
|
|
||||||
api/migrations/versions
|
api/migrations/versions
|
||||||
tmp
|
tmp
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from . import (
|
|||||||
document_controller,
|
document_controller,
|
||||||
emotion_config_controller,
|
emotion_config_controller,
|
||||||
emotion_controller,
|
emotion_controller,
|
||||||
|
end_user_controller,
|
||||||
file_controller,
|
file_controller,
|
||||||
file_storage_controller,
|
file_storage_controller,
|
||||||
home_page_controller,
|
home_page_controller,
|
||||||
@@ -96,5 +97,6 @@ manager_router.include_router(file_storage_controller.router)
|
|||||||
manager_router.include_router(ontology_controller.router)
|
manager_router.include_router(ontology_controller.router)
|
||||||
manager_router.include_router(skill_controller.router)
|
manager_router.include_router(skill_controller.router)
|
||||||
manager_router.include_router(i18n_controller.router)
|
manager_router.include_router(i18n_controller.router)
|
||||||
|
manager_router.include_router(end_user_controller.router)
|
||||||
|
|
||||||
__all__ = ["manager_router"]
|
__all__ = ["manager_router"]
|
||||||
|
|||||||
48
api/app/controllers/end_user_controller.py
Normal file
48
api/app/controllers/end_user_controller.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""End User 管理接口 - 无需认证"""
|
||||||
|
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
from app.schemas.memory_api_schema import (
|
||||||
|
CreateEndUserRequest,
|
||||||
|
CreateEndUserResponse,
|
||||||
|
)
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/end_users", tags=["End Users"])
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
async def create_end_user(
|
||||||
|
data: CreateEndUserRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create an end user.
|
||||||
|
|
||||||
|
Creates a new end user for the given workspace.
|
||||||
|
If an end user with the same other_id already exists in the workspace,
|
||||||
|
returns the existing one.
|
||||||
|
"""
|
||||||
|
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
|
||||||
|
|
||||||
|
end_user_repo = EndUserRepository(db)
|
||||||
|
end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=None,
|
||||||
|
workspace_id=data.workspace_id,
|
||||||
|
other_id=data.other_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"End user ready: {end_user.id}")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"id": str(end_user.id),
|
||||||
|
"other_id": end_user.other_id or "",
|
||||||
|
"other_name": end_user.other_name or "",
|
||||||
|
"workspace_id": str(end_user.workspace_id),
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||||
@@ -91,7 +91,7 @@ async def upload_file(
|
|||||||
|
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
||||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,7 +172,6 @@ async def upload_file_with_share_token(
|
|||||||
|
|
||||||
# Get share and release info from share_token
|
# Get share and release info from share_token
|
||||||
service = ReleaseShareService(db)
|
service = ReleaseShareService(db)
|
||||||
share_info = service.get_shared_release_info(share_token=share_data.share_token)
|
|
||||||
|
|
||||||
# Get share object to access app_id
|
# Get share object to access app_id
|
||||||
share = service.repo.get_by_share_token(share_data.share_token)
|
share = service.repo.get_by_share_token(share_data.share_token)
|
||||||
@@ -499,6 +498,51 @@ async def get_file_url(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
|
||||||
|
async def get_permanent_file_url(
|
||||||
|
file_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取文件的永久公开 URL(无过期时间)。
|
||||||
|
|
||||||
|
- 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置)
|
||||||
|
- 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限)
|
||||||
|
"""
|
||||||
|
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||||
|
if not file_metadata:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist")
|
||||||
|
|
||||||
|
if file_metadata.status != "completed":
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"File upload not completed, status: {file_metadata.status}")
|
||||||
|
|
||||||
|
file_key = file_metadata.file_key
|
||||||
|
storage = storage_service.storage
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(storage, LocalStorage):
|
||||||
|
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
|
||||||
|
else:
|
||||||
|
url = await storage.get_permanent_url(file_key)
|
||||||
|
if not url:
|
||||||
|
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="Permanent URL not supported for current storage backend")
|
||||||
|
|
||||||
|
api_logger.info(f"Generated permanent URL: file_id={file_id}")
|
||||||
|
return success(
|
||||||
|
data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name},
|
||||||
|
msg="Permanent file URL generated successfully"
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to generate permanent URL: {e}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to generate permanent URL: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/public/{file_id}", response_model=Any)
|
@router.get("/public/{file_id}", response_model=Any)
|
||||||
async def public_download_file(
|
async def public_download_file(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|||||||
@@ -195,10 +195,9 @@ async def get_workspace_end_users(
|
|||||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||||
|
|
||||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||||
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
|
|
||||||
try:
|
try:
|
||||||
from app.tasks import init_community_clustering_for_users
|
from app.tasks import init_community_clustering_for_users
|
||||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
|
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
|
||||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||||
|
|||||||
@@ -33,35 +33,47 @@ def get_memory_count(
|
|||||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||||
def get_conversations(
|
def get_conversations(
|
||||||
end_user_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 20,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Retrieve all conversations for the current user in a specific group.
|
Retrieve conversations for the current user in a specific group with pagination.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id (UUID): The group identifier.
|
end_user_id (UUID): The group identifier.
|
||||||
|
page (int): Page number (1-based). Defaults to 1.
|
||||||
|
pagesize (int): Number of items per page. Defaults to 20.
|
||||||
current_user (User, optional): The authenticated user.
|
current_user (User, optional): The authenticated user.
|
||||||
db (Session, optional): SQLAlchemy session.
|
db (Session, optional): SQLAlchemy session.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiResponse: Contains a list of conversation IDs.
|
ApiResponse: Contains a paginated list of conversations.
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Initializes the ConversationService with the current DB session.
|
|
||||||
- Returns only conversation IDs for lightweight response.
|
|
||||||
- Logs can be added to trace requests in production.
|
|
||||||
"""
|
"""
|
||||||
|
page = max(1, page)
|
||||||
|
page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100
|
||||||
conversation_service = ConversationService(db)
|
conversation_service = ConversationService(db)
|
||||||
conversations = conversation_service.get_user_conversations(
|
conversations, total = conversation_service.get_user_conversations(
|
||||||
end_user_id
|
end_user_id,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size
|
||||||
)
|
)
|
||||||
return success(data=[
|
return success(data={
|
||||||
{
|
"items": [
|
||||||
"id": conversation.id,
|
{
|
||||||
"title": conversation.title
|
"id": conversation.id,
|
||||||
} for conversation in conversations
|
"title": conversation.title
|
||||||
], msg="get conversations success")
|
} for conversation in conversations
|
||||||
|
],
|
||||||
|
"total": total,
|
||||||
|
"page": {
|
||||||
|
"page": page,
|
||||||
|
"pagesize": page_size,
|
||||||
|
"total": total,
|
||||||
|
"hasnext": (page * page_size) < total
|
||||||
|
},
|
||||||
|
}, msg="get conversations success")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from app.core.response_utils import success
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
from app.schemas.memory_api_schema import (
|
from app.schemas.memory_api_schema import (
|
||||||
|
ListConfigsResponse,
|
||||||
MemoryReadRequest,
|
MemoryReadRequest,
|
||||||
MemoryReadResponse,
|
MemoryReadResponse,
|
||||||
MemoryWriteRequest,
|
MemoryWriteRequest,
|
||||||
@@ -31,14 +32,15 @@ async def write_memory_api_service(
|
|||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
payload: MemoryWriteRequest = Body(..., embed=False),
|
message: str = Body(..., description="Message content"),
|
||||||
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Write memory to storage.
|
Write memory to storage.
|
||||||
|
|
||||||
Stores memory content for the specified end user using the Memory API Service.
|
Stores memory content for the specified end user using the Memory API Service.
|
||||||
"""
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = MemoryWriteRequest(**body)
|
||||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
@@ -62,13 +64,15 @@ async def read_memory_api_service(
|
|||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
payload: MemoryReadRequest = Body(..., embed=False),
|
message: str = Body(..., description="Query message"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Read memory from storage.
|
Read memory from storage.
|
||||||
|
|
||||||
Queries and retrieves memories for the specified end user with context-aware responses.
|
Queries and retrieves memories for the specified end user with context-aware responses.
|
||||||
"""
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = MemoryReadRequest(**body)
|
||||||
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
@@ -85,3 +89,27 @@ async def read_memory_api_service(
|
|||||||
|
|
||||||
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
||||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/configs")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def list_memory_configs(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List all memory configs for the workspace.
|
||||||
|
|
||||||
|
Returns all available memory configurations associated with the authorized workspace.
|
||||||
|
"""
|
||||||
|
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
|
result = memory_api_service.list_memory_configs(
|
||||||
|
workspace_id=api_key_auth.workspace_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||||
|
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||||
|
|||||||
@@ -76,6 +76,8 @@ async def get_tool_methods(
|
|||||||
if methods is None:
|
if methods is None:
|
||||||
raise HTTPException(status_code=404, detail="工具不存在")
|
raise HTTPException(status_code=404, detail="工具不存在")
|
||||||
return success(data=methods, msg="获取工具方法成功")
|
return success(data=methods, msg="获取工具方法成功")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -121,6 +123,8 @@ async def create_tool(
|
|||||||
raise HTTPException(status_code=400, detail=e.message)
|
raise HTTPException(status_code=400, detail=e.message)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -149,6 +153,8 @@ async def update_tool(
|
|||||||
return success(msg="工具更新成功")
|
return success(msg="工具更新成功")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -191,6 +197,8 @@ async def set_tool_active(
|
|||||||
return success(msg=f"工具已{action}")
|
return success(msg=f"工具已{action}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -223,6 +231,8 @@ async def execute_tool(
|
|||||||
},
|
},
|
||||||
msg="工具执行完成"
|
msg="工具执行完成"
|
||||||
)
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ class Settings:
|
|||||||
|
|
||||||
# File Upload
|
# File Upload
|
||||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||||
|
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||||
|
|
||||||
|
|||||||
@@ -166,15 +166,12 @@ async def write(
|
|||||||
statement_entity_edges=all_statement_entity_edges,
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
entity_edges=all_entity_entity_edges,
|
entity_edges=all_entity_entity_edges,
|
||||||
connector=neo4j_connector,
|
connector=neo4j_connector,
|
||||||
config_id=config_id,
|
|
||||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||||
schedule_clustering_after_write(
|
schedule_clustering_after_write(
|
||||||
all_entity_nodes,
|
all_entity_nodes,
|
||||||
config_id=config_id,
|
|
||||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -69,15 +69,15 @@ class LabelPropagationEngine:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
config_id: Optional[str] = None,
|
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
embedding_model_id: Optional[str] = None,
|
embedding_model_id: Optional[str] = None,
|
||||||
|
embedding_model_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.connector = connector
|
self.connector = connector
|
||||||
self.repo = CommunityRepository(connector)
|
self.repo = CommunityRepository(connector)
|
||||||
self.config_id = config_id
|
|
||||||
self.llm_model_id = llm_model_id
|
self.llm_model_id = llm_model_id
|
||||||
self.embedding_model_id = embedding_model_id
|
self.embedding_model_id = embedding_model_id
|
||||||
|
self.embedding_model_id = embedding_model_id
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# 公开接口
|
# 公开接口
|
||||||
@@ -439,15 +439,17 @@ class LabelPropagationEngine:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_entity_lines(members: List[Dict]) -> List[str]:
|
def _build_entity_lines(members: List[Dict]) -> List[str]:
|
||||||
"""将实体列表格式化为 prompt 行,包含 name、aliases、description。"""
|
"""将实体列表格式化为 prompt 行,包含 name、aliases、description、example。"""
|
||||||
lines = []
|
lines = []
|
||||||
for m in members:
|
for m in members:
|
||||||
m_name = m.get("name", "")
|
m_name = m.get("name", "")
|
||||||
aliases = m.get("aliases") or []
|
aliases = m.get("aliases") or []
|
||||||
description = m.get("description") or ""
|
description = m.get("description") or ""
|
||||||
|
example = m.get("example") or ""
|
||||||
aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else ""
|
aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else ""
|
||||||
desc_str = f":{description}" if description else ""
|
desc_str = f":{description}" if description else ""
|
||||||
lines.append(f"- {m_name}{aliases_str}{desc_str}")
|
example_str = f"(示例:{example})" if example else ""
|
||||||
|
lines.append(f"- {m_name}{aliases_str}{desc_str}{example_str}")
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
async def _generate_community_metadata(
|
async def _generate_community_metadata(
|
||||||
@@ -481,11 +483,24 @@ class LabelPropagationEngine:
|
|||||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||||
|
|
||||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||||
|
|
||||||
|
# 方案四:注入社区内实体间关系三元组
|
||||||
|
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||||
|
rel_lines = [
|
||||||
|
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||||
|
for r in relationships
|
||||||
|
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||||
|
]
|
||||||
|
rel_section = (
|
||||||
|
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||||
|
if rel_lines else ""
|
||||||
|
)
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
f"以下是一组语义相关的实体:\n{entity_list_str}\n\n"
|
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||||
f"请为这组实体所代表的主题:\n"
|
f"请为这组实体所代表的主题:\n"
|
||||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||||
f"2. 写一句话摘要(不超过50个字)\n\n"
|
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||||
f"严格按以下格式输出,不要有其他内容:\n"
|
f"严格按以下格式输出,不要有其他内容:\n"
|
||||||
f"名称:<名称>\n摘要:<摘要>"
|
f"名称:<名称>\n摘要:<摘要>"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -121,3 +121,18 @@ class StorageBackend(ABC):
|
|||||||
URL for accessing the file.
|
URL for accessing the file.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def get_permanent_url(self, file_key: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get a permanent public URL for the file (no expiration).
|
||||||
|
|
||||||
|
Returns None by default; remote storage backends should override this
|
||||||
|
if the bucket is configured for public read access.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_key: Unique identifier for the file in the storage system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A permanent public URL, or None if not supported.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|||||||
@@ -261,3 +261,13 @@ class OSSStorage(StorageBackend):
|
|||||||
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
||||||
# Return a basic URL format as fallback
|
# Return a basic URL format as fallback
|
||||||
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
|
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
|
||||||
|
|
||||||
|
async def get_permanent_url(self, file_key: str) -> str:
|
||||||
|
"""
|
||||||
|
Get a permanent public URL for the file (requires bucket public read).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A permanent URL in the format: https://{bucket}.{endpoint}/{file_key}
|
||||||
|
"""
|
||||||
|
host = self.endpoint.replace("https://", "").replace("http://", "")
|
||||||
|
return f"https://{self.bucket_name}.{host}/{file_key}"
|
||||||
|
|||||||
@@ -378,3 +378,12 @@ class S3Storage(StorageBackend):
|
|||||||
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
||||||
# Return a basic URL format as fallback
|
# Return a basic URL format as fallback
|
||||||
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
||||||
|
|
||||||
|
async def get_permanent_url(self, file_key: str) -> str:
|
||||||
|
"""
|
||||||
|
Get a permanent public URL for the file (requires bucket public read).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A permanent URL in the format: https://{bucket}.s3.{region}.amazonaws.com/{file_key}
|
||||||
|
"""
|
||||||
|
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
||||||
|
|||||||
@@ -20,9 +20,21 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes import NodeFactory
|
from app.core.workflow.nodes import NodeFactory
|
||||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||||
|
from app.core.workflow.validator import WorkflowValidator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Regex to split output into:
|
||||||
|
# - variable placeholders: {{ ... }}
|
||||||
|
# - normal literal text
|
||||||
|
#
|
||||||
|
# Example:
|
||||||
|
# "Hello {{user.name}}!" ->
|
||||||
|
# ["Hello ", "{{user.name}}", "!"]
|
||||||
|
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
|
||||||
|
# Strict variable format: {{ node_id.field_name }}
|
||||||
|
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
|
||||||
|
|
||||||
|
|
||||||
class GraphBuilder:
|
class GraphBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -37,13 +49,13 @@ class GraphBuilder:
|
|||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.subgraph = subgraph
|
self.subgraph = subgraph
|
||||||
|
|
||||||
self.start_node_id = None
|
self.start_node_id: str | None = None
|
||||||
self.end_node_ids = []
|
|
||||||
self.node_map = {node["id"]: node for node in self.nodes}
|
self.node_map = {node["id"]: node for node in self.nodes}
|
||||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||||
self._find_upstream_branch_node = lru_cache(
|
self._find_upstream_activation_dep = lru_cache(
|
||||||
maxsize=len(self.nodes) * 2
|
maxsize=len(self.nodes) * 2
|
||||||
)(self._find_upstream_branch_node)
|
)(self._find_upstream_activation_dep)
|
||||||
if variable_pool:
|
if variable_pool:
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
else:
|
else:
|
||||||
@@ -51,10 +63,19 @@ class GraphBuilder:
|
|||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
|
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||||
|
self.end_nodes = [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||||
|
]
|
||||||
self.add_edges()
|
self.add_edges()
|
||||||
self._analyze_end_node_output()
|
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||||
|
|
||||||
|
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||||
|
self._build_reverse_adj()
|
||||||
|
self._analyze_end_node_output()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nodes(self) -> list[dict[str, Any]]:
|
def nodes(self) -> list[dict[str, Any]]:
|
||||||
return self.workflow_config.get("nodes", [])
|
return self.workflow_config.get("nodes", [])
|
||||||
@@ -87,60 +108,50 @@ class GraphBuilder:
|
|||||||
result[node[0]].append(node[1])
|
result[node[0]].append(node[1])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
def _build_reverse_adj(self):
|
||||||
"""
|
for edge in self.edges:
|
||||||
Recursively find all upstream branch (control) nodes that influence the execution
|
if edge["source"] not in self.reachable_nodes:
|
||||||
of the given target node.
|
continue
|
||||||
|
self._reverse_adj[edge.get("target")].append({
|
||||||
|
"id": edge["source"], "branch": edge.get("label")
|
||||||
|
})
|
||||||
|
|
||||||
This method walks upstream along the workflow graph starting from `target_node`.
|
def _find_upstream_activation_dep(
|
||||||
It distinguishes between:
|
self,
|
||||||
- branch nodes (node types listed in `BRANCH_NODES`)
|
target_node: str
|
||||||
- non-branch nodes (ordinary processing nodes)
|
) -> tuple[tuple[tuple[str, str]], tuple[str]]:
|
||||||
|
"""Find upstream dependencies that affect the activation of a target node.
|
||||||
|
|
||||||
Traversal rules:
|
Walks upstream along the workflow graph from the target node, collecting
|
||||||
1. For each immediate upstream node:
|
two types of dependencies:
|
||||||
- If it is a branch node, it is recorded as an affecting control node.
|
- Branch control nodes: upstream branch nodes (e.g. if-else) whose
|
||||||
- If it is a non-branch node, the traversal continues recursively upstream.
|
routing outcome determines whether the target node executes.
|
||||||
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
- Output nodes: upstream END nodes that must complete their output
|
||||||
a branch node, the traversal is considered invalid:
|
before the target node can activate.
|
||||||
- `has_branch` will be False
|
|
||||||
- no branch nodes are returned.
|
|
||||||
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
|
||||||
branch node will `has_branch` be True.
|
|
||||||
|
|
||||||
Special case:
|
The traversal terminates early and returns empty tuples if any upstream
|
||||||
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
path reaches START/CYCLE_START without encountering a branch or output
|
||||||
it is considered directly reachable from the workflow entry, and therefore
|
node, indicating the target node is directly reachable and should be
|
||||||
has no controlling branch nodes.
|
activated immediately.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_node (str):
|
target_node: The ID of the node whose upstream activation
|
||||||
The identifier of the node whose upstream control branches
|
dependencies are to be resolved.
|
||||||
are to be resolved.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[bool, tuple[tuple[str, str]]]:
|
A tuple of two elements:
|
||||||
- has_branch (bool):
|
- A deduplicated tuple of (branch_node_id, branch_label) pairs
|
||||||
True if every upstream path from `target_node` encounters
|
representing upstream branch control dependencies. Empty if
|
||||||
at least one branch node.
|
any clean path to START exists.
|
||||||
False if any path reaches a start node without a branch.
|
- A deduplicated tuple of upstream output node IDs that must
|
||||||
- branch_nodes (tuple[tuple[str, str]]):
|
complete before this node activates.
|
||||||
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
|
||||||
representing all branch nodes that can influence `target_node`.
|
|
||||||
Returns an empty tuple if `has_branch` is False.
|
|
||||||
"""
|
"""
|
||||||
source_nodes = [
|
source_nodes = self._reverse_adj[target_node]
|
||||||
{
|
|
||||||
"id": edge.get("source"),
|
|
||||||
"branch": edge.get("label")
|
|
||||||
}
|
|
||||||
for edge in self.edges
|
|
||||||
if edge.get("target") == target_node
|
|
||||||
]
|
|
||||||
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
return False, tuple()
|
return tuple(), tuple()
|
||||||
|
|
||||||
branch_nodes = []
|
branch_nodes = []
|
||||||
|
output_nodes = []
|
||||||
non_branch_nodes = []
|
non_branch_nodes = []
|
||||||
|
|
||||||
for node_info in source_nodes:
|
for node_info in source_nodes:
|
||||||
@@ -149,19 +160,23 @@ class GraphBuilder:
|
|||||||
(node_info["id"], node_info["branch"])
|
(node_info["id"], node_info["branch"])
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if self.get_node_type(node_info["id"]) == NodeType.END:
|
||||||
|
output_nodes.append(node_info["id"])
|
||||||
non_branch_nodes.append(node_info["id"])
|
non_branch_nodes.append(node_info["id"])
|
||||||
|
|
||||||
has_branch = True
|
has_branch = True
|
||||||
for node_id in non_branch_nodes:
|
for node_id in non_branch_nodes:
|
||||||
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
|
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id)
|
||||||
has_branch = has_branch and node_has_branch
|
if not upstream_control_nodes:
|
||||||
if not has_branch:
|
if not upstream_output_nodes and node_id not in output_nodes:
|
||||||
break
|
return tuple(), tuple()
|
||||||
branch_nodes.extend(nodes)
|
branch_nodes = []
|
||||||
if not has_branch:
|
has_branch = False
|
||||||
branch_nodes = []
|
if has_branch:
|
||||||
|
branch_nodes.extend(upstream_control_nodes)
|
||||||
|
output_nodes.extend(upstream_output_nodes)
|
||||||
|
|
||||||
return has_branch, tuple(set(branch_nodes))
|
return tuple(set(branch_nodes)), tuple(set(output_nodes))
|
||||||
|
|
||||||
def _analyze_end_node_output(self):
|
def _analyze_end_node_output(self):
|
||||||
"""
|
"""
|
||||||
@@ -182,11 +197,10 @@ class GraphBuilder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Collect all End nodes in the workflow
|
# Collect all End nodes in the workflow
|
||||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes")
|
||||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
|
||||||
|
|
||||||
# Iterate through each End node to analyze its output
|
# Iterate through each End node to analyze its output
|
||||||
for end_node in end_nodes:
|
for end_node in self.end_nodes:
|
||||||
end_node_id = end_node.get("id")
|
end_node_id = end_node.get("id")
|
||||||
config = end_node.get("config", {})
|
config = end_node.get("config", {})
|
||||||
output = config.get("output")
|
output = config.get("output")
|
||||||
@@ -195,42 +209,33 @@ class GraphBuilder:
|
|||||||
if not output:
|
if not output:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Regex to split output into:
|
|
||||||
# - variable placeholders: {{ ... }}
|
|
||||||
# - normal literal text
|
|
||||||
#
|
|
||||||
# Example:
|
|
||||||
# "Hello {{user.name}}!" ->
|
|
||||||
# ["Hello ", "{{user.name}}", "!"]
|
|
||||||
pattern = r'\{\{.*?\}\}|[^{}]+'
|
|
||||||
|
|
||||||
# Strict variable format: {{ node_id.field_name }}
|
|
||||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
|
||||||
variable_pattern = re.compile(variable_pattern_string)
|
|
||||||
|
|
||||||
# Split output into ordered segments
|
# Split output into ordered segments
|
||||||
output_template = list(re.findall(pattern, output))
|
output_template = list(_OUTPUT_PATTERN.findall(output))
|
||||||
|
|
||||||
# Determine whether each segment is literal text
|
# Determine whether each segment is literal text
|
||||||
# True -> literal (can be directly output)
|
# True -> literal (can be directly output)
|
||||||
# False -> variable placeholder (needs runtime value)
|
# False -> variable placeholder (needs runtime value)
|
||||||
output_flag = [
|
output_flag = [
|
||||||
not bool(variable_pattern.match(item))
|
not bool(_VARIABLE_PATTERN.match(item))
|
||||||
for item in output_template
|
for item in output_template
|
||||||
]
|
]
|
||||||
|
|
||||||
# Stream mode: output activation depends on upstream branch nodes
|
# Stream mode: output activation depends on upstream branch nodes and upstream output nodes
|
||||||
if self.stream:
|
if self.stream:
|
||||||
# Find upstream branch nodes that can control this End node
|
# Find upstream branch nodes and upstream output nodes that gate this End node
|
||||||
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
|
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id)
|
||||||
|
activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes)
|
||||||
# Build StreamOutputConfig for this End node
|
# Build StreamOutputConfig for this End node
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
|
id=end_node_id,
|
||||||
# If there is no upstream branch, output is active immediately
|
# If there is no upstream branch and no upstream output dependency, output is active immediately
|
||||||
activate=not has_branch,
|
activate=activate,
|
||||||
|
|
||||||
# Branch nodes that control activation of this End node
|
# Branch nodes that control activation of this End node
|
||||||
control_nodes=self._merge_control_nodes(control_nodes),
|
control_nodes=self._merge_control_nodes(upstream_control_nodes),
|
||||||
|
upstream_output_nodes=list(upstream_output_nodes),
|
||||||
|
control_resolved=not bool(upstream_control_nodes),
|
||||||
|
output_resolved=not bool(upstream_output_nodes),
|
||||||
|
|
||||||
# Convert output segments into OutputContent objects
|
# Convert output segments into OutputContent objects
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -249,14 +254,16 @@ class GraphBuilder:
|
|||||||
cursor=0
|
cursor=0
|
||||||
)
|
)
|
||||||
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
||||||
f"activate: {not has_branch}, "
|
f"activate: {activate}, "
|
||||||
f"control_nodes: {control_nodes},"
|
f"control_nodes: {upstream_control_nodes},"
|
||||||
|
f"ref_outputs: {upstream_output_nodes},"
|
||||||
f"output: {output_template},"
|
f"output: {output_template},"
|
||||||
f"output_activate: {output_flag}")
|
f"output_activate: {output_flag}")
|
||||||
|
|
||||||
# Non-stream mode: all outputs are activated by default
|
# Non-stream mode: all outputs are activated by default
|
||||||
else:
|
else:
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
|
id=end_node_id,
|
||||||
activate=True,
|
activate=True,
|
||||||
control_nodes={},
|
control_nodes={},
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -269,7 +276,10 @@ class GraphBuilder:
|
|||||||
for output_string, activate in zip(output_template, output_flag)
|
for output_string, activate in zip(output_template, output_flag)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
cursor=0
|
cursor=0,
|
||||||
|
upstream_output_nodes=[],
|
||||||
|
control_resolved=True,
|
||||||
|
output_resolved=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_nodes(self):
|
def add_nodes(self):
|
||||||
@@ -304,8 +314,6 @@ class GraphBuilder:
|
|||||||
# Record start and end node IDs
|
# Record start and end node IDs
|
||||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
self.start_node_id = node_id
|
self.start_node_id = node_id
|
||||||
elif node_type == NodeType.END:
|
|
||||||
self.end_node_ids.append(node_id)
|
|
||||||
|
|
||||||
# Create node instance (start and end nodes are also created)
|
# Create node instance (start and end nodes are also created)
|
||||||
# NOTE: Loop node creation automatically removes the subgraph's nodes and edges from the current graph
|
# NOTE: Loop node creation automatically removes the subgraph's nodes and edges from the current graph
|
||||||
@@ -448,7 +456,7 @@ class GraphBuilder:
|
|||||||
branch_activate = []
|
branch_activate = []
|
||||||
new_state = state.copy()
|
new_state = state.copy()
|
||||||
new_state["activate"] = dict(state.get("activate", {})) # shallow copy of activate (top-level dict only)
|
new_state["activate"] = dict(state.get("activate", {})) # shallow copy of activate (top-level dict only)
|
||||||
node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False)
|
node_output = variable_pool.get_node_output(src, default=dict(), strict=False)
|
||||||
for label, branch in unique_branch.items():
|
for label, branch in unique_branch.items():
|
||||||
if node_output and evaluate_condition(
|
if node_output and evaluate_condition(
|
||||||
branch["condition"],
|
branch["condition"],
|
||||||
@@ -494,9 +502,11 @@ class GraphBuilder:
|
|||||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||||
|
|
||||||
# Connect End nodes to the global END node
|
# Connect End nodes to the global END node
|
||||||
for end_node_id in self.end_node_ids:
|
for end_node in self.end_nodes:
|
||||||
self.graph.add_edge(end_node_id, END)
|
end_node_id = end_node.get("id")
|
||||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
if end_node_id:
|
||||||
|
self.graph.add_edge(end_node_id, END)
|
||||||
|
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||||
return
|
return
|
||||||
|
|
||||||
def build(self) -> CompiledStateGraph:
|
def build(self) -> CompiledStateGraph:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class WorkflowResultBuilder:
|
|||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
elapsed_time: float,
|
elapsed_time: float,
|
||||||
final_output: str,
|
final_output: str,
|
||||||
|
success: bool
|
||||||
):
|
):
|
||||||
"""Construct the final standardized output of the workflow execution.
|
"""Construct the final standardized output of the workflow execution.
|
||||||
|
|
||||||
@@ -29,6 +30,7 @@ class WorkflowResultBuilder:
|
|||||||
elapsed_time (float): Total execution time in seconds.
|
elapsed_time (float): Total execution time in seconds.
|
||||||
final_output (str): The aggregated or final output content of the workflow
|
final_output (str): The aggregated or final output content of the workflow
|
||||||
(e.g., combined messages from all End nodes).
|
(e.g., combined messages from all End nodes).
|
||||||
|
success (bool): Whether the execution was successful.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing the final workflow execution result with keys:
|
dict: A dictionary containing the final workflow execution result with keys:
|
||||||
@@ -49,7 +51,7 @@ class WorkflowResultBuilder:
|
|||||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
conversation_id = variable_pool.get_value("sys.conversation_id")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "completed",
|
"status": "completed" if success else "failed",
|
||||||
"output": final_output,
|
"output": final_output,
|
||||||
"variables": {
|
"variables": {
|
||||||
"conv": variable_pool.get_all_conversation_vars(),
|
"conv": variable_pool.get_all_conversation_vars(),
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/9 15:11
|
# @Time : 2026/2/9 15:11
|
||||||
import re
|
import re
|
||||||
|
from queue import Queue
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
@@ -37,8 +38,8 @@ class OutputContent(BaseModel):
|
|||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Whether this output segment is currently active.\n"
|
"Whether this output segment is currently active."
|
||||||
"- True: allowed to be emitted/output\n"
|
"- True: allowed to be emitted/output"
|
||||||
"- False: blocked until activated by branch control"
|
"- False: blocked until activated by branch control"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -46,8 +47,8 @@ class OutputContent(BaseModel):
|
|||||||
is_variable: bool = Field(
|
is_variable: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Whether this segment represents a variable placeholder.\n"
|
"Whether this segment represents a variable placeholder."
|
||||||
"True -> variable (e.g. {{ node.field }})\n"
|
"True -> variable (e.g. {{ node.field }})"
|
||||||
"False -> literal text"
|
"False -> literal text"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -86,12 +87,16 @@ class StreamOutputConfig(BaseModel):
|
|||||||
- which upstream branch/control nodes gate the activation
|
- which upstream branch/control nodes gate the activation
|
||||||
- how each parsed output segment is streamed and activated
|
- how each parsed output segment is streamed and activated
|
||||||
"""
|
"""
|
||||||
|
id: str = Field(
|
||||||
|
...,
|
||||||
|
description="ID of the End node this configuration belongs to."
|
||||||
|
)
|
||||||
|
|
||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Global activation flag for the End node output.\n"
|
"Global activation flag for the End node output."
|
||||||
"When False, output segments should not be emitted even if available.\n"
|
"When False, output segments should not be emitted even if available."
|
||||||
"This flag typically becomes True once required control branch conditions "
|
"This flag typically becomes True once required control branch conditions "
|
||||||
"are satisfied."
|
"are satisfied."
|
||||||
)
|
)
|
||||||
@@ -100,17 +105,46 @@ class StreamOutputConfig(BaseModel):
|
|||||||
control_nodes: dict[str, list[str]] = Field(
|
control_nodes: dict[str, list[str]] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Control branch conditions for this End node output.\n"
|
"Control branch conditions for this End node output."
|
||||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
"Mapping of `branch_node_id -> expected_branch_label`."
|
||||||
"The End node output becomes globally active when a controlling branch node "
|
"The End node output becomes globally active when a controlling branch node "
|
||||||
"reports a matching completion status."
|
"reports a matching completion status."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upstream_output_nodes: list[str] = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Upstream output node dependencies (data flow)."
|
||||||
|
"Represents END/output nodes that this output depends on."
|
||||||
|
"These nodes provide data sources required before this output can be activated "
|
||||||
|
"or streamed."
|
||||||
|
"Used to ensure correct ordering and dependency resolution in streaming mode."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
control_resolved: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether all upstream branch control dependencies have been satisfied."
|
||||||
|
"True if no upstream branch nodes exist or the required branch "
|
||||||
|
"conditions have been met."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
output_resolved: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether all upstream output node dependencies have been completed."
|
||||||
|
"True if no upstream output nodes exist or all upstream output "
|
||||||
|
"nodes have finished their output."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
outputs: list[OutputContent] = Field(
|
outputs: list[OutputContent] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Ordered list of output segments parsed from the output template.\n"
|
"Ordered list of output segments parsed from the output template."
|
||||||
"Each segment represents either a literal text block or a variable placeholder "
|
"Each segment represents either a literal text block or a variable placeholder "
|
||||||
"that may be activated independently."
|
"that may be activated independently."
|
||||||
)
|
)
|
||||||
@@ -119,49 +153,97 @@ class StreamOutputConfig(BaseModel):
|
|||||||
cursor: int = Field(
|
cursor: int = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Streaming cursor index.\n"
|
"Streaming cursor index."
|
||||||
"Indicates the next output segment index to be emitted.\n"
|
"Indicates the next output segment index to be emitted."
|
||||||
"Segments with index < cursor are considered already streamed."
|
"Segments with index < cursor are considered already streamed."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
force: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"Force flag for output emission."
|
||||||
|
"When True, all output segments are emitted regardless of activation state."
|
||||||
|
"Triggered when this output node has finished execution."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def update_activate(self, scope: str, status=None):
|
def update_activate(self, scope: str, status=None):
|
||||||
"""
|
"""
|
||||||
Update streaming activation state based on an upstream node or special variable.
|
Update streaming activation state based on upstream events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scope (str):
|
scope (str):
|
||||||
Identifier of the completed upstream entity.
|
Identifier of the completed upstream entity.
|
||||||
- If a control branch node, it should match a key in `control_nodes`.
|
- If a control branch node, it should match a key in `control_nodes`.
|
||||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
- If an upstream output node, it should match an entry in `upstream_output_nodes`.
|
||||||
|
- If a variable placeholder (e.g., "sys.xxx" or "node_id.field"),
|
||||||
|
it may appear in output segments.
|
||||||
|
|
||||||
status (optional):
|
status (optional):
|
||||||
Completion status of the control branch node.
|
Completion status of the control branch node.
|
||||||
Required when `scope` refers to a control node.
|
Required when `scope` refers to a control node.
|
||||||
|
|
||||||
Behavior:
|
Behavior:
|
||||||
1. Control branch nodes:
|
1. Force activation:
|
||||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
- If `self.force` is True, the method returns immediately.
|
||||||
branch label, the End node output becomes globally active (`activate = True`).
|
- If `scope == self.id`, the node marks itself as completed:
|
||||||
|
- `activate = True`
|
||||||
|
- `force = True`
|
||||||
|
This is typically used for final flushing when the node finishes execution.
|
||||||
|
|
||||||
2. Variable output segments:
|
2. Control dependency resolution:
|
||||||
- For each segment that is a variable (`is_variable=True`):
|
- If `scope` matches a key in `control_nodes`:
|
||||||
- If the segment literal references `scope`, mark the segment as active.
|
- `status` must be provided.
|
||||||
- This applies both to regular node variables (e.g., "node_id.field")
|
- If `status` matches expected branch labels, mark control as resolved
|
||||||
and special system variables (e.g., "sys.xxx").
|
(`control_resolved = True`).
|
||||||
|
|
||||||
|
3. Upstream output dependency resolution:
|
||||||
|
- If `scope` is in `upstream_output_nodes`, remove it from that list;
|
||||||
|
once the list is empty, mark data dependency as resolved (`output_resolved = True`).
|
||||||
|
|
||||||
|
4. Global activation condition:
|
||||||
|
- The node becomes active when BOTH conditions are satisfied:
|
||||||
|
- control_resolved == True
|
||||||
|
- output_resolved == True
|
||||||
|
- Once activated, `activate` remains True.
|
||||||
|
|
||||||
|
5. Variable segment activation:
|
||||||
|
- For each output segment that is a variable (`is_variable=True`):
|
||||||
|
- If the segment depends on the given `scope`,
|
||||||
|
mark the segment as active.
|
||||||
|
- This applies to both node variables (e.g., "node_id.field")
|
||||||
|
and system variables (e.g., "sys.xxx").
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- This method does not emit output or advance the streaming cursor.
|
- This method does NOT emit output or advance the streaming cursor.
|
||||||
- It only updates activation flags based on upstream events or special variables.
|
- It only updates activation and dependency resolution states.
|
||||||
|
- Activation is driven by both control flow (branch nodes) and
|
||||||
|
data flow (upstream output nodes).
|
||||||
"""
|
"""
|
||||||
|
if self.force:
|
||||||
|
return
|
||||||
|
|
||||||
# Case 1: resolve control branch dependency
|
if scope == self.id:
|
||||||
|
self.activate = True
|
||||||
|
self.force = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# resolve control branch dependency
|
||||||
if scope in self.control_nodes:
|
if scope in self.control_nodes:
|
||||||
if status is None:
|
if status is None:
|
||||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||||
if status in self.control_nodes[scope]:
|
if status in self.control_nodes[scope]:
|
||||||
self.activate = True
|
self.control_resolved = True
|
||||||
|
|
||||||
# Case 2: activate variable segments related to this node
|
if scope in self.upstream_output_nodes:
|
||||||
|
self.upstream_output_nodes.remove(scope)
|
||||||
|
if not self.upstream_output_nodes:
|
||||||
|
self.output_resolved = True
|
||||||
|
|
||||||
|
self.activate = self.activate or (self.control_resolved and self.output_resolved)
|
||||||
|
|
||||||
|
# activate variable segments related to this node
|
||||||
for i in range(len(self.outputs)):
|
for i in range(len(self.outputs)):
|
||||||
if (
|
if (
|
||||||
self.outputs[i].is_variable
|
self.outputs[i].is_variable
|
||||||
@@ -174,12 +256,17 @@ class StreamOutputCoordinator:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||||
self.activate_end: str | None = None
|
self.activate_end: str | None = None
|
||||||
|
self.output_queue: Queue = Queue()
|
||||||
|
self.processed_outputs = []
|
||||||
|
|
||||||
def initialize_end_outputs(
|
def initialize_end_outputs(
|
||||||
self,
|
self,
|
||||||
end_node_map: dict[str, StreamOutputConfig]
|
end_node_map: dict[str, StreamOutputConfig]
|
||||||
):
|
):
|
||||||
self.end_outputs = end_node_map
|
self.end_outputs = end_node_map
|
||||||
|
self.processed_outputs = []
|
||||||
|
self.activate_end = None
|
||||||
|
self.output_queue = Queue()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_activate_end_info(self):
|
def current_activate_end_info(self):
|
||||||
@@ -211,8 +298,11 @@ class StreamOutputCoordinator:
|
|||||||
"""
|
"""
|
||||||
for node in self.end_outputs.keys():
|
for node in self.end_outputs.keys():
|
||||||
self.end_outputs[node].update_activate(scope, status)
|
self.end_outputs[node].update_activate(scope, status)
|
||||||
if self.end_outputs[node].activate and self.activate_end is None:
|
if self.end_outputs[node].activate and node not in self.processed_outputs:
|
||||||
self.activate_end = node
|
self.output_queue.put(node)
|
||||||
|
self.processed_outputs.append(node)
|
||||||
|
if self.activate_end is None and not self.output_queue.empty():
|
||||||
|
self.activate_end = self.output_queue.get_nowait()
|
||||||
|
|
||||||
async def emit_activate_chunk(
|
async def emit_activate_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -256,7 +346,7 @@ class StreamOutputCoordinator:
|
|||||||
final_chunk = ''
|
final_chunk = ''
|
||||||
current_segment = end_info.outputs[end_info.cursor]
|
current_segment = end_info.outputs[end_info.cursor]
|
||||||
|
|
||||||
if not current_segment.activate and not force:
|
if not current_segment.activate and not force and not end_info.force:
|
||||||
# Stop processing until this segment becomes active
|
# Stop processing until this segment becomes active
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -273,7 +363,7 @@ class StreamOutputCoordinator:
|
|||||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
||||||
|
|
||||||
if final_chunk:
|
if final_chunk:
|
||||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}")
|
||||||
yield {
|
yield {
|
||||||
"event": "message",
|
"event": "message",
|
||||||
"data": {
|
"data": {
|
||||||
@@ -285,8 +375,7 @@ class StreamOutputCoordinator:
|
|||||||
end_info.cursor += 1
|
end_info.cursor += 1
|
||||||
|
|
||||||
if end_info.cursor >= len(end_info.outputs):
|
if end_info.cursor >= len(end_info.outputs):
|
||||||
self.end_outputs.pop(self.activate_end)
|
self.pop_current_activate_end()
|
||||||
self.activate_end = None
|
|
||||||
|
|
||||||
async def flush_remaining_chunk(
|
async def flush_remaining_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -325,6 +414,8 @@ class StreamOutputCoordinator:
|
|||||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||||
yield msg_event
|
yield msg_event
|
||||||
|
|
||||||
|
if not self.output_queue.empty():
|
||||||
|
self.activate_end = self.output_queue.get_nowait()
|
||||||
# Move to next active End node if current one is done
|
# Move to next active End node if current one is done
|
||||||
if not self.activate_end and self.end_outputs:
|
if not self.activate_end and self.end_outputs:
|
||||||
self.activate_end = list(self.end_outputs.keys())[0]
|
self.activate_end = list(self.end_outputs.keys())[0]
|
||||||
|
|||||||
@@ -351,12 +351,12 @@ class VariablePool:
|
|||||||
}
|
}
|
||||||
return runtime_vars
|
return runtime_vars
|
||||||
|
|
||||||
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
def get_node_output(self, node_id: str, default: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
||||||
"""获取指定节点的输出(运行时变量)
|
"""获取指定节点的输出(运行时变量)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: 节点 ID
|
node_id: 节点 ID
|
||||||
defalut: 默认值
|
default: 默认值
|
||||||
strict: 是否严格模式
|
strict: 是否严格模式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -368,7 +368,7 @@ class VariablePool:
|
|||||||
if strict:
|
if strict:
|
||||||
raise KeyError(f"node {node_id} output not exist")
|
raise KeyError(f"node {node_id} output not exist")
|
||||||
else:
|
else:
|
||||||
return defalut
|
return default
|
||||||
|
|
||||||
def copy(self, pool: 'VariablePool'):
|
def copy(self, pool: 'VariablePool'):
|
||||||
self.variables = deepcopy(pool.variables)
|
self.variables = deepcopy(pool.variables)
|
||||||
|
|||||||
@@ -128,89 +128,100 @@ class WorkflowExecutor:
|
|||||||
- token_usage: aggregated token usage if available
|
- token_usage: aggregated token usage if available
|
||||||
- error: error message if any
|
- error: error message if any
|
||||||
"""
|
"""
|
||||||
logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
start = datetime.datetime.now()
|
||||||
|
async for event in self.execute_stream(input_data):
|
||||||
start_time = datetime.datetime.now()
|
if event.get("event") == "workflow_end":
|
||||||
|
return event.get("data")
|
||||||
# Execute the workflow
|
return self.result_builder.build_final_output(
|
||||||
try:
|
{"error": "Workflow execution did not end as expected"},
|
||||||
# Build the workflow graph
|
self.variable_pool,
|
||||||
graph = self.build_graph()
|
(datetime.datetime.now() - start).total_seconds(),
|
||||||
|
"",
|
||||||
# Initialize the variable pool with input data
|
success=False
|
||||||
await self.variable_initializer.initialize(
|
)
|
||||||
variable_pool=self.variable_pool,
|
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||||
input_data=input_data,
|
#
|
||||||
execution_context=self.execution_context
|
# start_time = datetime.datetime.now()
|
||||||
)
|
#
|
||||||
initial_state = self.state_manager.create_initial_state(
|
# # Execute the workflow
|
||||||
workflow_config=self.workflow_config,
|
# try:
|
||||||
input_data=input_data,
|
# # Build the workflow graph
|
||||||
execution_context=self.execution_context,
|
# graph = self.build_graph()
|
||||||
start_node_id=self.start_node_id
|
#
|
||||||
)
|
# # Initialize the variable pool with input data
|
||||||
|
# await self.variable_initializer.initialize(
|
||||||
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
# variable_pool=self.variable_pool,
|
||||||
|
# input_data=input_data,
|
||||||
# Aggregate output from all End nodes
|
# execution_context=self.execution_context
|
||||||
full_content = ''
|
# )
|
||||||
for end_id in self.stream_coordinator.end_outputs.keys():
|
# initial_state = self.state_manager.create_initial_state(
|
||||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
# workflow_config=self.workflow_config,
|
||||||
|
# input_data=input_data,
|
||||||
# Append messages for user and assistant
|
# execution_context=self.execution_context,
|
||||||
if input_data.get("files"):
|
# start_node_id=self.start_node_id
|
||||||
result["messages"].extend(
|
# )
|
||||||
[
|
#
|
||||||
{
|
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||||
"role": "user",
|
#
|
||||||
"content": input_data.get("message", '')
|
# # Aggregate output from all End nodes
|
||||||
},
|
# full_content = ''
|
||||||
{
|
# for end_id in self.stream_coordinator.end_outputs.keys():
|
||||||
"role": "user",
|
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||||
"content": input_data.get("files")
|
#
|
||||||
},
|
# # Append messages for user and assistant
|
||||||
{
|
# if input_data.get("files"):
|
||||||
"role": "assistant",
|
# result["messages"].extend(
|
||||||
"content": full_content
|
# [
|
||||||
}
|
# {
|
||||||
]
|
# "role": "user",
|
||||||
)
|
# "content": input_data.get("message", '')
|
||||||
else:
|
# },
|
||||||
result["messages"].extend(
|
# {
|
||||||
[
|
# "role": "user",
|
||||||
{
|
# "content": input_data.get("files")
|
||||||
"role": "user",
|
# },
|
||||||
"content": input_data.get("message", '')
|
# {
|
||||||
},
|
# "role": "assistant",
|
||||||
{
|
# "content": full_content
|
||||||
"role": "assistant",
|
# }
|
||||||
"content": full_content
|
# ]
|
||||||
}
|
# )
|
||||||
]
|
# else:
|
||||||
)
|
# result["messages"].extend(
|
||||||
# Calculate elapsed time
|
# [
|
||||||
end_time = datetime.datetime.now()
|
# {
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
# "role": "user",
|
||||||
|
# "content": input_data.get("message", '')
|
||||||
logger.info(
|
# },
|
||||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
# {
|
||||||
|
# "role": "assistant",
|
||||||
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
# "content": full_content
|
||||||
|
# }
|
||||||
except Exception as e:
|
# ]
|
||||||
end_time = datetime.datetime.now()
|
# )
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
# # Calculate elapsed time
|
||||||
|
# end_time = datetime.datetime.now()
|
||||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
# elapsed_time = (end_time - start_time).total_seconds()
|
||||||
exc_info=True)
|
#
|
||||||
return {
|
# logger.info(
|
||||||
"status": "failed",
|
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||||
"error": str(e),
|
#
|
||||||
"output": None,
|
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||||
"node_outputs": {},
|
#
|
||||||
"elapsed_time": elapsed_time,
|
# except Exception as e:
|
||||||
"token_usage": None
|
# end_time = datetime.datetime.now()
|
||||||
}
|
# elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
#
|
||||||
|
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||||
|
# exc_info=True)
|
||||||
|
# return {
|
||||||
|
# "status": "failed",
|
||||||
|
# "error": str(e),
|
||||||
|
# "output": None,
|
||||||
|
# "node_outputs": {},
|
||||||
|
# "elapsed_time": elapsed_time,
|
||||||
|
# "token_usage": None
|
||||||
|
# }
|
||||||
|
|
||||||
async def execute_stream(
|
async def execute_stream(
|
||||||
self,
|
self,
|
||||||
@@ -248,7 +259,8 @@ class WorkflowExecutor:
|
|||||||
"timestamp": int(start_time.timestamp() * 1000)
|
"timestamp": int(start_time.timestamp() * 1000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
result = None
|
||||||
|
full_content = ''
|
||||||
try:
|
try:
|
||||||
# Build the workflow graph in streaming mode
|
# Build the workflow graph in streaming mode
|
||||||
graph = self.build_graph(stream=True)
|
graph = self.build_graph(stream=True)
|
||||||
@@ -266,7 +278,6 @@ class WorkflowExecutor:
|
|||||||
start_node_id=self.start_node_id
|
start_node_id=self.start_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
full_content = ''
|
|
||||||
self.stream_coordinator.update_scope_activation("sys")
|
self.stream_coordinator.update_scope_activation("sys")
|
||||||
|
|
||||||
# Execute the workflow with streaming
|
# Execute the workflow with streaming
|
||||||
@@ -363,7 +374,12 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
"data": self.result_builder.build_final_output(
|
||||||
|
result,
|
||||||
|
self.variable_pool,
|
||||||
|
elapsed_time,
|
||||||
|
full_content,
|
||||||
|
success=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -372,16 +388,19 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||||
exc_info=True)
|
exc_info=True)
|
||||||
|
if result is None:
|
||||||
|
result = {"error": str(e)}
|
||||||
|
else:
|
||||||
|
result["error"] = str(e)
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": {
|
"data": self.result_builder.build_final_output(
|
||||||
"execution_id": self.execution_context.execution_id,
|
result,
|
||||||
"status": "failed",
|
self.variable_pool,
|
||||||
"error": str(e),
|
elapsed_time,
|
||||||
"elapsed_time": elapsed_time,
|
full_content,
|
||||||
"timestamp": end_time.isoformat()
|
success=False
|
||||||
}
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ class CodeNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"http://sandbox:8194/v1/sandbox/run",
|
"http://sandbox:8194/v1/sandbox/run",
|
||||||
headers={
|
headers={
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class ConditionDetail(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
right: Any = Field(
|
right: Any = Field(
|
||||||
...,
|
default=None,
|
||||||
description="Right-hand operand of the comparison expression"
|
description="Right-hand operand of the comparison expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ class LoopRuntime:
|
|||||||
self.variable_pool.variables["conv"].update(
|
self.variable_pool.variables["conv"].update(
|
||||||
self.child_variable_pool.variables["conv"]
|
self.child_variable_pool.variables["conv"]
|
||||||
)
|
)
|
||||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
loop_vars = self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False)
|
||||||
loopstate["node_outputs"][self.node_id] = loop_vars
|
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||||
|
|
||||||
def evaluate_conditional(self) -> bool:
|
def evaluate_conditional(self) -> bool:
|
||||||
@@ -261,4 +261,4 @@ class LoopRuntime:
|
|||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
logger.info(f"loop node {self.node_id}: execution completed")
|
logger.info(f"loop node {self.node_id}: execution completed")
|
||||||
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
return self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) | {"__child_state": child_state}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class ConditionDetail(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
right: Any = Field(
|
right: Any = Field(
|
||||||
...,
|
default=None,
|
||||||
description="Value to compare with"
|
description="Value to compare with"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -31,13 +31,13 @@ class IfElseNode(BaseNode):
|
|||||||
expressions.append({
|
expressions.append({
|
||||||
"left": self.get_variable(expression.left, variable_pool, strict=False),
|
"left": self.get_variable(expression.left, variable_pool, strict=False),
|
||||||
"right": expression.right
|
"right": expression.right
|
||||||
if expression.input_type == ValueInputType.CONSTANT
|
if expression.input_type == ValueInputType.CONSTANT or expression.right is None
|
||||||
else self.get_variable(expression.right, variable_pool, strict=False),
|
else self.get_variable(expression.right, variable_pool, strict=False),
|
||||||
"operator": expression.operator,
|
"operator": str(expression.operator),
|
||||||
})
|
})
|
||||||
result.append({
|
result.append({
|
||||||
"expressions": expressions,
|
"expressions": expressions,
|
||||||
"logical_operator": case.logical_operator,
|
"logical_operator": str(case.logical_operator),
|
||||||
})
|
})
|
||||||
return {
|
return {
|
||||||
"cases": result
|
"cases": result
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
@@ -24,6 +24,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||||
|
self.vector_service: ElasticSearchVector | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
@@ -163,6 +164,50 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
)
|
)
|
||||||
return reranker
|
return reranker
|
||||||
|
|
||||||
|
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
||||||
|
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||||
|
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||||
|
for child in children:
|
||||||
|
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||||
|
continue
|
||||||
|
kb_config.kb_id = child.id
|
||||||
|
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
||||||
|
return
|
||||||
|
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||||
|
match kb_config.retrieve_type:
|
||||||
|
case RetrieveType.PARTICIPLE:
|
||||||
|
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.similarity_threshold))
|
||||||
|
case RetrieveType.SEMANTIC:
|
||||||
|
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.vector_similarity_weight))
|
||||||
|
case RetrieveType.HYBRID:
|
||||||
|
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.vector_similarity_weight)
|
||||||
|
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.similarity_threshold)
|
||||||
|
|
||||||
|
# Deduplicate hybrid retrieval results
|
||||||
|
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||||
|
if not unique_rs:
|
||||||
|
return
|
||||||
|
if self.typed_config.reranker_id:
|
||||||
|
self.vector_service.reranker = self.get_reranker_model()
|
||||||
|
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||||
|
else:
|
||||||
|
rs.extend(sorted(
|
||||||
|
unique_rs,
|
||||||
|
key=lambda d: d.metadata.get("score", 0),
|
||||||
|
reverse=True
|
||||||
|
)[:kb_config.top_k])
|
||||||
|
case _:
|
||||||
|
raise RuntimeError("Unknown retrieval type")
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute the knowledge retrieval workflow node.
|
Execute the knowledge retrieval workflow node.
|
||||||
@@ -191,56 +236,19 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
query = self._render_template(self.typed_config.query, variable_pool)
|
query = self._render_template(self.typed_config.query, variable_pool)
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
knowledge_bases = self.typed_config.knowledge_bases
|
knowledge_bases = self.typed_config.knowledge_bases
|
||||||
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
|
|
||||||
|
|
||||||
if not existing_ids:
|
|
||||||
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
|
||||||
|
|
||||||
rs = []
|
rs = []
|
||||||
for kb_config in knowledge_bases:
|
for kb_config in knowledge_bases:
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||||
if not db_knowledge:
|
if not db_knowledge:
|
||||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||||
|
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
|
||||||
match kb_config.retrieve_type:
|
|
||||||
case RetrieveType.PARTICIPLE:
|
|
||||||
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.similarity_threshold))
|
|
||||||
case RetrieveType.SEMANTIC:
|
|
||||||
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.vector_similarity_weight))
|
|
||||||
case RetrieveType.HYBRID:
|
|
||||||
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.vector_similarity_weight)
|
|
||||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.similarity_threshold)
|
|
||||||
|
|
||||||
# Deduplicate hy brid retrieval results
|
|
||||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
|
||||||
if not unique_rs:
|
|
||||||
continue
|
|
||||||
if self.typed_config.reranker_id:
|
|
||||||
vector_service.reranker = self.get_reranker_model()
|
|
||||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
|
||||||
else:
|
|
||||||
rs.extend(sorted(
|
|
||||||
unique_rs,
|
|
||||||
key=lambda d: d.metadata.get("score", 0),
|
|
||||||
reverse=True
|
|
||||||
)[:kb_config.top_k])
|
|
||||||
case _:
|
|
||||||
raise RuntimeError("Unknown retrieval type")
|
|
||||||
if not rs:
|
if not rs:
|
||||||
return []
|
return []
|
||||||
if self.typed_config.reranker_id:
|
if self.typed_config.reranker_id:
|
||||||
vector_service.reranker = self.get_reranker_model()
|
self.vector_service.reranker = self.get_reranker_model()
|
||||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||||
else:
|
else:
|
||||||
final_rs = sorted(
|
final_rs = sorted(
|
||||||
rs,
|
rs,
|
||||||
|
|||||||
@@ -250,6 +250,8 @@ class ConditionBase(ABC):
|
|||||||
self.type_limit = getattr(self, "type_limit", None)
|
self.type_limit = getattr(self, "type_limit", None)
|
||||||
|
|
||||||
def resolve_right_literal_value(self):
|
def resolve_right_literal_value(self):
|
||||||
|
if self.right_selector is None:
|
||||||
|
return None
|
||||||
if self.input_type == ValueInputType.VARIABLE:
|
if self.input_type == ValueInputType.VARIABLE:
|
||||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ class WorkflowValidator:
|
|||||||
# 仅在发布时验证所有节点可达
|
# 仅在发布时验证所有节点可达
|
||||||
# 6. 验证所有节点可达(从 start 节点出发)
|
# 6. 验证所有节点可达(从 start 节点出发)
|
||||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||||
reachable = WorkflowValidator._get_reachable_nodes(
|
reachable = WorkflowValidator.get_reachable_nodes(
|
||||||
start_nodes[0]["id"],
|
start_nodes[0]["id"],
|
||||||
edges
|
edges
|
||||||
)
|
)
|
||||||
@@ -194,7 +194,7 @@ class WorkflowValidator:
|
|||||||
return len(errors) == 0, errors
|
return len(errors) == 0, errors
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||||
"""获取从 start 节点可达的所有节点
|
"""获取从 start 节点可达的所有节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from enum import StrEnum
|
|||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
from app.schemas import FileType
|
from app.schemas import FileType
|
||||||
|
|
||||||
@@ -41,10 +41,10 @@ class VariableType(StrEnum):
|
|||||||
"""
|
"""
|
||||||
if isinstance(var, str):
|
if isinstance(var, str):
|
||||||
return cls.STRING
|
return cls.STRING
|
||||||
elif isinstance(var, (int, float)):
|
|
||||||
return cls.NUMBER
|
|
||||||
elif isinstance(var, bool):
|
elif isinstance(var, bool):
|
||||||
return cls.BOOLEAN
|
return cls.BOOLEAN
|
||||||
|
elif isinstance(var, (int, float)):
|
||||||
|
return cls.NUMBER
|
||||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
||||||
return cls.FILE
|
return cls.FILE
|
||||||
elif isinstance(var, dict):
|
elif isinstance(var, dict):
|
||||||
@@ -116,7 +116,7 @@ class FileObject(BaseModel):
|
|||||||
content_cache: dict = Field(default_factory=dict)
|
content_cache: dict = Field(default_factory=dict)
|
||||||
is_file: bool
|
is_file: bool
|
||||||
|
|
||||||
_byte_content: bytes | None = None
|
_byte_content: bytes | None = PrivateAttr(default=None)
|
||||||
|
|
||||||
def get_content(self):
|
def get_content(self):
|
||||||
return self._byte_content
|
return self._byte_content
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable)
|
|||||||
|
|
||||||
|
|
||||||
class StringVariable(BaseVariable):
|
class StringVariable(BaseVariable):
|
||||||
|
value: str
|
||||||
type = 'str'
|
type = 'str'
|
||||||
|
|
||||||
def valid_value(self, value) -> str:
|
def valid_value(self, value) -> str:
|
||||||
@@ -22,6 +23,7 @@ class StringVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class NumberVariable(BaseVariable):
|
class NumberVariable(BaseVariable):
|
||||||
|
value: int | float
|
||||||
type = 'number'
|
type = 'number'
|
||||||
|
|
||||||
def valid_value(self, value) -> int | float:
|
def valid_value(self, value) -> int | float:
|
||||||
@@ -34,6 +36,7 @@ class NumberVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class BooleanVariable(BaseVariable):
|
class BooleanVariable(BaseVariable):
|
||||||
|
value: bool
|
||||||
type = 'boolean'
|
type = 'boolean'
|
||||||
|
|
||||||
def valid_value(self, value) -> bool:
|
def valid_value(self, value) -> bool:
|
||||||
@@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class DictVariable(BaseVariable):
|
class DictVariable(BaseVariable):
|
||||||
|
value: dict
|
||||||
type = 'object'
|
type = 'object'
|
||||||
|
|
||||||
def valid_value(self, value) -> dict:
|
def valid_value(self, value) -> dict:
|
||||||
@@ -58,6 +62,7 @@ class DictVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class FileVariable(BaseVariable):
|
class FileVariable(BaseVariable):
|
||||||
|
value: FileObject
|
||||||
type = 'file'
|
type = 'file'
|
||||||
|
|
||||||
def valid_value(self, value) -> FileObject:
|
def valid_value(self, value) -> FileObject:
|
||||||
@@ -102,6 +107,7 @@ class FileVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class ArrayVariable(BaseVariable, Generic[T]):
|
class ArrayVariable(BaseVariable, Generic[T]):
|
||||||
|
value: list[T]
|
||||||
type = 'array'
|
type = 'array'
|
||||||
|
|
||||||
def __init__(self, child_type: Type[T], value: list[Any]):
|
def __init__(self, child_type: Type[T], value: list[Any]):
|
||||||
@@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]):
|
|||||||
|
|
||||||
|
|
||||||
class NestedArrayVariable(BaseVariable):
|
class NestedArrayVariable(BaseVariable):
|
||||||
|
value: list[ArrayVariable]
|
||||||
type = 'array_nest'
|
type = 'array_nest'
|
||||||
|
|
||||||
def valid_value(self, value: list[T]) -> list[T]:
|
def valid_value(self, value: list[T]) -> list[T]:
|
||||||
@@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable):
|
|||||||
category=RuntimeWarning
|
category=RuntimeWarning
|
||||||
)
|
)
|
||||||
class AnyVariable(BaseVariable):
|
class AnyVariable(BaseVariable):
|
||||||
|
value: Any
|
||||||
type = 'any'
|
type = 'any'
|
||||||
|
|
||||||
def valid_value(self, value: Any) -> Any:
|
def valid_value(self, value: Any) -> Any:
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ def get_db_read() -> Generator[Session, None, None]:
|
|||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
db.rollback() # 只读任务无需 commit
|
db.rollback() # 只读任务无需 commit
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
def get_pool_status():
|
def get_pool_status():
|
||||||
|
|||||||
@@ -506,10 +506,13 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
|||||||
404: "errors.common.not_found",
|
404: "errors.common.not_found",
|
||||||
405: "errors.common.method_not_allowed",
|
405: "errors.common.method_not_allowed",
|
||||||
409: "errors.common.conflict",
|
409: "errors.common.conflict",
|
||||||
|
413: "errors.common.payload_too_large",
|
||||||
422: "errors.common.validation_failed",
|
422: "errors.common.validation_failed",
|
||||||
429: "errors.common.too_many_requests",
|
429: "errors.common.too_many_requests",
|
||||||
500: "errors.common.internal_error",
|
500: "errors.common.internal_error",
|
||||||
|
502: "errors.common.bad_gateway",
|
||||||
503: "errors.common.service_unavailable",
|
503: "errors.common.service_unavailable",
|
||||||
|
504: "errors.common.gateway_timeout",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 如果有对应的翻译键,使用翻译
|
# 如果有对应的翻译键,使用翻译
|
||||||
@@ -534,7 +537,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
|||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=exc.status_code,
|
status_code=exc.status_code,
|
||||||
content=fail(code=exc.status_code, msg=translated_message, error=translated_message)
|
content=fail(code=exc.status_code, msg=translated_message, error=exc.detail)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -90,27 +90,27 @@ class ConversationRepository:
|
|||||||
self,
|
self,
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
workspace_id: uuid.UUID = None,
|
workspace_id: uuid.UUID = None,
|
||||||
limit: int = 10,
|
is_activate: bool = True,
|
||||||
is_activate: bool = True
|
page: int = 1,
|
||||||
) -> list[Conversation]:
|
page_size: int = 20
|
||||||
|
) -> tuple[list[Conversation], int]:
|
||||||
"""
|
"""
|
||||||
Retrieve recent conversations for a specific user.
|
Retrieve recent conversations for a specific user with pagination.
|
||||||
|
|
||||||
This method queries conversations associated with the given user ID,
|
This method queries conversations associated with the given user ID,
|
||||||
optionally scoped to a specific workspace. Results are ordered by the
|
optionally scoped to a specific workspace. Results are ordered by the
|
||||||
most recently updated conversations and limited to a fixed number.
|
most recently updated conversations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (uuid.UUID): Unique identifier of the user.
|
user_id (uuid.UUID): Unique identifier of the user.
|
||||||
workspace_id (uuid.UUID, optional): Workspace scope for the query.
|
workspace_id (uuid.UUID, optional): Workspace scope for the query.
|
||||||
If provided, only conversations under this workspace will be returned.
|
If provided, only conversations under this workspace will be returned.
|
||||||
limit (int): Maximum number of conversations to return.
|
is_activate (bool): Conversation State limit.
|
||||||
Defaults to 10.
|
page (int): Page number (1-based). Defaults to 1.
|
||||||
is_activate (bool): Convsersation State limit
|
page_size (int): Number of items per page. Defaults to 20.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[Conversation]: A list of conversation entities ordered by
|
tuple[list[Conversation], int]: A list of conversation entities and total count.
|
||||||
last updated time (descending).
|
|
||||||
"""
|
"""
|
||||||
logger.info(f"Fetching conversation by user_id: {user_id}")
|
logger.info(f"Fetching conversation by user_id: {user_id}")
|
||||||
|
|
||||||
@@ -122,18 +122,25 @@ class ConversationRepository:
|
|||||||
if workspace_id:
|
if workspace_id:
|
||||||
stmt = stmt.where(Conversation.workspace_id == workspace_id)
|
stmt = stmt.where(Conversation.workspace_id == workspace_id)
|
||||||
|
|
||||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
# Calculate total count
|
||||||
stmt = stmt.limit(limit)
|
total = int(self.db.execute(
|
||||||
|
select(func.count()).select_from(stmt.subquery())
|
||||||
|
).scalar_one())
|
||||||
|
|
||||||
convsersations = list(self.db.scalars(stmt).all())
|
# Apply ordering and pagination
|
||||||
|
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||||
|
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
|
||||||
|
|
||||||
|
conversations = list(self.db.scalars(stmt).all())
|
||||||
logger.info(
|
logger.info(
|
||||||
"Conversation fetched successfully",
|
"Conversation fetched successfully",
|
||||||
extra={
|
extra={
|
||||||
"user_id": str(user_id),
|
"user_id": str(user_id),
|
||||||
"workspace_id": str(workspace_id),
|
"workspace_id": str(workspace_id),
|
||||||
|
"total": total,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return convsersations
|
return conversations, total
|
||||||
|
|
||||||
def list_conversations(
|
def list_conversations(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -17,12 +17,17 @@ from app.repositories.neo4j.cypher_queries import (
|
|||||||
GET_ALL_ENTITY_IDS_FOR_USER,
|
GET_ALL_ENTITY_IDS_FOR_USER,
|
||||||
GET_ENTITIES_PAGE,
|
GET_ENTITIES_PAGE,
|
||||||
GET_COMMUNITY_MEMBERS,
|
GET_COMMUNITY_MEMBERS,
|
||||||
|
GET_COMMUNITY_RELATIONSHIPS,
|
||||||
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
||||||
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
||||||
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
|
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
|
||||||
CHECK_USER_HAS_COMMUNITIES,
|
CHECK_USER_HAS_COMMUNITIES,
|
||||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||||
UPDATE_COMMUNITY_METADATA,
|
UPDATE_COMMUNITY_METADATA,
|
||||||
|
GET_INCOMPLETE_COMMUNITIES,
|
||||||
|
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING,
|
||||||
|
CHECK_COMMUNITY_IS_COMPLETE,
|
||||||
|
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING,
|
||||||
BATCH_UPDATE_COMMUNITY_METADATA,
|
BATCH_UPDATE_COMMUNITY_METADATA,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -177,7 +182,7 @@ class CommunityRepository:
|
|||||||
async def get_community_members(
|
async def get_community_members(
|
||||||
self, community_id: str, end_user_id: str
|
self, community_id: str, end_user_id: str
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""查询社区成员列表。"""
|
"""查询社区成员列表(含 example 字段)。"""
|
||||||
try:
|
try:
|
||||||
return await self.connector.execute_query(
|
return await self.connector.execute_query(
|
||||||
GET_COMMUNITY_MEMBERS,
|
GET_COMMUNITY_MEMBERS,
|
||||||
@@ -188,6 +193,20 @@ class CommunityRepository:
|
|||||||
logger.error(f"get_community_members failed: {e}")
|
logger.error(f"get_community_members failed: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def get_community_relationships(
|
||||||
|
self, community_id: str, end_user_id: str
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""查询社区内实体间的关系三元组(subject, predicate, object)。"""
|
||||||
|
try:
|
||||||
|
return await self.connector.execute_query(
|
||||||
|
GET_COMMUNITY_RELATIONSHIPS,
|
||||||
|
community_id=community_id,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_community_relationships failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
async def get_all_community_members_batch(
|
async def get_all_community_members_batch(
|
||||||
self, community_ids: List[str], end_user_id: str
|
self, community_ids: List[str], end_user_id: str
|
||||||
) -> Dict[str, List[Dict]]:
|
) -> Dict[str, List[Dict]]:
|
||||||
@@ -234,6 +253,31 @@ class CommunityRepository:
|
|||||||
logger.error(f"refresh_member_count failed: {e}")
|
logger.error(f"refresh_member_count failed: {e}")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
async def get_incomplete_communities(self, end_user_id: str, check_embedding: bool = False) -> List[str]:
|
||||||
|
"""查询该用户下属性不完整的 Community 节点 ID 列表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 用户 ID
|
||||||
|
check_embedding: 为 True 时额外检查 summary_embedding 是否缺失(仅当用户有 embedding 模型配置时传 True)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING if check_embedding else GET_INCOMPLETE_COMMUNITIES
|
||||||
|
result = await self.connector.execute_query(query, end_user_id=end_user_id)
|
||||||
|
return [row["community_id"] for row in result]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_incomplete_communities failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def is_community_complete(self, community_id: str, end_user_id: str, check_embedding: bool = False) -> bool:
|
||||||
|
"""检查单个社区节点的属性是否完整。"""
|
||||||
|
try:
|
||||||
|
query = CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING if check_embedding else CHECK_COMMUNITY_IS_COMPLETE
|
||||||
|
result = await self.connector.execute_query(query, community_id=community_id, end_user_id=end_user_id)
|
||||||
|
return result[0]["is_complete"] if result else False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"is_community_complete failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
async def update_community_metadata(
|
async def update_community_metadata(
|
||||||
self,
|
self,
|
||||||
community_id: str,
|
community_id: str,
|
||||||
@@ -243,7 +287,7 @@ class CommunityRepository:
|
|||||||
core_entities: List[str],
|
core_entities: List[str],
|
||||||
summary_embedding: Optional[List[float]] = None,
|
summary_embedding: Optional[List[float]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""更新社区的名称、摘要、核心实体列表和摘要向量。"""
|
"""更新社区的名称、摘要、核心实体列表及 summary_embedding。"""
|
||||||
try:
|
try:
|
||||||
result = await self.connector.execute_query(
|
result = await self.connector.execute_query(
|
||||||
UPDATE_COMMUNITY_METADATA,
|
UPDATE_COMMUNITY_METADATA,
|
||||||
|
|||||||
@@ -1137,10 +1137,20 @@ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(
|
|||||||
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
||||||
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
||||||
e.name_embedding AS name_embedding,
|
e.name_embedding AS name_embedding,
|
||||||
e.aliases AS aliases, e.description AS description
|
e.aliases AS aliases, e.description AS description,
|
||||||
|
e.example AS example
|
||||||
ORDER BY coalesce(e.activation_value, 0) DESC
|
ORDER BY coalesce(e.activation_value, 0) DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
GET_COMMUNITY_RELATIONSHIPS = """
|
||||||
|
MATCH (e1:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
|
||||||
|
MATCH (e2:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c)
|
||||||
|
MATCH (e1)-[r:EXTRACTED_RELATIONSHIP]->(e2)
|
||||||
|
RETURN e1.name AS subject, r.predicate AS predicate, e2.name AS object
|
||||||
|
ORDER BY e1.name, r.predicate, e2.name
|
||||||
|
LIMIT 20
|
||||||
|
"""
|
||||||
|
|
||||||
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
|
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
|
||||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||||
RETURN c.community_id AS community_id,
|
RETURN c.community_id AS community_id,
|
||||||
@@ -1316,3 +1326,38 @@ RETURN s.statement AS statement,
|
|||||||
ORDER BY COALESCE(s.activation_value, 0) DESC
|
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Cypher: return a single boolean column `is_complete` for one community.
# Parameters: $community_id, $end_user_id.
# A community counts as complete when name and summary are both non-null and
# non-empty and core_entities is non-null. The summary embedding is not
# considered by this variant.
CHECK_COMMUNITY_IS_COMPLETE = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (
c.name IS NOT NULL AND c.name <> '' AND
c.summary IS NOT NULL AND c.summary <> '' AND
c.core_entities IS NOT NULL
) AS is_complete
"""
|
||||||
|
|
||||||
|
# Cypher: same completeness check as the plain variant, but additionally
# requires summary_embedding to be non-null.
# Parameters: $community_id, $end_user_id. Returns one boolean `is_complete`.
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (
c.name IS NOT NULL AND c.name <> '' AND
c.summary IS NOT NULL AND c.summary <> '' AND
c.core_entities IS NOT NULL AND
c.summary_embedding IS NOT NULL
) AS is_complete
"""
|
||||||
|
|
||||||
|
# Cypher: list the community_id of every community of one end user that is
# missing metadata, i.e. name, summary or core_entities is null, or name or
# summary is the empty string.
# Parameter: $end_user_id.
GET_INCOMPLETE_COMMUNITIES = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
OR c.name = '' OR c.summary = ''
RETURN c.community_id AS community_id
"""
|
||||||
|
|
||||||
|
# Cypher: like GET_INCOMPLETE_COMMUNITIES, but a community is also reported
# when it has a summary other than the literal '(empty)' and no
# summary_embedding.
# Parameter: $end_user_id.
# NOTE(review): '(empty)' looks like a placeholder summary that is deliberately
# left without an embedding -- confirm. CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING
# has no such exception, so a community with that summary is neither listed
# here nor considered complete there.
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.name = ''
OR c.summary IS NULL OR c.summary = ''
OR c.core_entities IS NULL
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
RETURN c.community_id AS community_id
"""
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
# 使用新的仓储层
|
# 使用新的仓储层
|
||||||
@@ -304,7 +305,6 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
|
|
||||||
def schedule_clustering_after_write(
|
def schedule_clustering_after_write(
|
||||||
entity_nodes: List,
|
entity_nodes: List,
|
||||||
config_id: Optional[str] = None,
|
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
embedding_model_id: Optional[str] = None,
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -325,13 +325,12 @@ def schedule_clustering_after_write(
|
|||||||
end_user_id = entity_nodes[0].end_user_id
|
end_user_id = entity_nodes[0].end_user_id
|
||||||
new_entity_ids = [e.id for e in entity_nodes]
|
new_entity_ids = [e.id for e in entity_nodes]
|
||||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||||
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
|
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
|
||||||
|
|
||||||
|
|
||||||
async def _trigger_clustering(
|
async def _trigger_clustering(
|
||||||
new_entity_ids: List[str],
|
new_entity_ids: List[str],
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
config_id: Optional[str] = None,
|
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
embedding_model_id: Optional[str] = None,
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -343,7 +342,7 @@ async def _trigger_clustering(
|
|||||||
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
|
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
|
||||||
logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
|
logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
|
engine = LabelPropagationEngine(connector, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
|
||||||
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
||||||
logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
|
logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class WorkflowConfigRepository:
|
|||||||
edges: list[dict[str, Any]],
|
edges: list[dict[str, Any]],
|
||||||
variables: list[dict[str, Any]] | None = None,
|
variables: list[dict[str, Any]] | None = None,
|
||||||
execution_config: dict[str, Any] | None = None,
|
execution_config: dict[str, Any] | None = None,
|
||||||
|
features: dict[str, Any] | None = None,
|
||||||
triggers: list[dict[str, Any]] | None = None
|
triggers: list[dict[str, Any]] | None = None
|
||||||
) -> WorkflowConfig:
|
) -> WorkflowConfig:
|
||||||
"""创建或更新工作流配置
|
"""创建或更新工作流配置
|
||||||
@@ -53,6 +54,7 @@ class WorkflowConfigRepository:
|
|||||||
edges: 边列表
|
edges: 边列表
|
||||||
variables: 变量列表
|
variables: 变量列表
|
||||||
execution_config: 执行配置
|
execution_config: 执行配置
|
||||||
|
features: 功能特性
|
||||||
triggers: 触发器列表
|
triggers: 触发器列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -82,6 +84,7 @@ class WorkflowConfigRepository:
|
|||||||
edges=edges,
|
edges=edges,
|
||||||
variables=variables or [],
|
variables=variables or [],
|
||||||
execution_config=execution_config or {},
|
execution_config=execution_config or {},
|
||||||
|
features=features or {},
|
||||||
triggers=triggers or []
|
triggers=triggers or []
|
||||||
)
|
)
|
||||||
self.db.add(config)
|
self.db.add(config)
|
||||||
|
|||||||
@@ -149,18 +149,26 @@ class FileUploadConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
# 通用文件:PDF/DOCX/XLSX/TXT/CSV/JSON,最大 100MB
|
# Documents: PDF/DOC/DOCX/XLS/XLSX/TXT/CSV/JSON/MD; default size limit 50 MB
|
||||||
document_enabled: bool = Field(default=False)
|
document_enabled: bool = Field(default=False)
|
||||||
document_max_size_mb: int = Field(default=100)
|
document_max_size_mb: int = Field(default=50)
|
||||||
document_allowed_extensions: List[str] = Field(
|
document_allowed_extensions: List[str] = Field(
|
||||||
default=["pdf", "docx", "xlsx", "txt", "csv", "json", "md"]
|
default=["pdf", "docx", "doc", "xlsx", "xls", "txt", "csv", "json", "md"]
|
||||||
)
|
)
|
||||||
# 视频文件:MP4/MOV/AVI/WebM,最大 500MB
|
# Video files: MP4 only by default; default size limit 50 MB
|
||||||
video_enabled: bool = Field(default=False)
|
video_enabled: bool = Field(default=False)
|
||||||
video_max_size_mb: int = Field(default=500)
|
video_max_size_mb: int = Field(default=50)
|
||||||
video_allowed_extensions: List[str] = Field(
|
video_allowed_extensions: List[str] = Field(
|
||||||
default=["mp4", "mov"]
|
default=["mp4"]
|
||||||
)
|
)
|
||||||
# 最大文件数量
|
# 最大文件数量
|
||||||
max_file_count: int = Field(default=5, ge=1, le=20)
|
max_file_count: int = Field(default=5, ge=1)
|
||||||
|
|
||||||
|
@field_validator("max_file_count")
@classmethod
def validate_max_file_count(cls, v: int) -> int:
    """Reject a ``max_file_count`` above the configured ceiling.

    Args:
        v: The value supplied for ``max_file_count``.

    Returns:
        ``v`` unchanged when it does not exceed ``settings.MAX_FILE_COUNT``.

    Raises:
        ValueError: If ``v`` is greater than ``settings.MAX_FILE_COUNT``.
    """
    # Imported at call time rather than at module level -- presumably to
    # avoid an import cycle with the config module; confirm before moving.
    from app.core.config import settings

    if v > settings.MAX_FILE_COUNT:
        raise ValueError(f"max_file_count 不能超过 {settings.MAX_FILE_COUNT}")
    return v
|
||||||
|
|
||||||
|
|
||||||
class OpeningStatementConfig(BaseModel):
|
class OpeningStatementConfig(BaseModel):
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class MemoryWriteRequest(BaseModel):
|
|||||||
"""
|
"""
|
||||||
end_user_id: str = Field(..., description="End user ID (required)")
|
end_user_id: str = Field(..., description="End user ID (required)")
|
||||||
message: str = Field(..., description="Message content to store")
|
message: str = Field(..., description="Message content to store")
|
||||||
config_id: Optional[str] = Field(None, description="Memory configuration ID")
|
config_id: str = Field(..., description="Memory configuration ID (required)")
|
||||||
storage_type: str = Field("neo4j", description="Storage type: neo4j or rag")
|
storage_type: str = Field("neo4j", description="Storage type: neo4j or rag")
|
||||||
user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID")
|
user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID")
|
||||||
|
|
||||||
@@ -68,7 +68,7 @@ class MemoryReadRequest(BaseModel):
|
|||||||
"0",
|
"0",
|
||||||
description="Search mode: 0=verify, 1=direct, 2=context"
|
description="Search mode: 0=verify, 1=direct, 2=context"
|
||||||
)
|
)
|
||||||
config_id: Optional[str] = Field(None, description="Memory configuration ID")
|
config_id: str = Field(..., description="Memory configuration ID (required)")
|
||||||
storage_type: str = Field("neo4j", description="Storage type: neo4j or rag")
|
storage_type: str = Field("neo4j", description="Storage type: neo4j or rag")
|
||||||
user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID")
|
user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID")
|
||||||
|
|
||||||
@@ -132,3 +132,79 @@ class MemoryReadResponse(BaseModel):
|
|||||||
description="Intermediate retrieval outputs"
|
description="Intermediate retrieval outputs"
|
||||||
)
|
)
|
||||||
end_user_id: str = Field(..., description="End user ID")
|
end_user_id: str = Field(..., description="End user ID")
|
||||||
|
|
||||||
|
|
||||||
|
class CreateEndUserRequest(BaseModel):
    """Request schema for creating an end user.

    Attributes:
        workspace_id: Workspace ID (required, must not be blank).
        other_id: External user identifier (required, must not be blank).
        other_name: Display name for the end user; defaults to ``""``.
    """

    workspace_id: str = Field(..., description="Workspace ID (required)")
    other_id: str = Field(..., description="External user identifier (required)")
    other_name: Optional[str] = Field("", description="Display name")

    @field_validator("workspace_id")
    @classmethod
    def validate_workspace_id(cls, v: str) -> str:
        """Return ``workspace_id`` stripped of surrounding whitespace.

        Raises:
            ValueError: If the value is empty or whitespace-only.
        """
        if not v or not v.strip():
            raise ValueError("workspace_id is required and cannot be empty")
        return v.strip()

    @field_validator("other_id")
    @classmethod
    def validate_other_id(cls, v: str) -> str:
        """Return ``other_id`` stripped of surrounding whitespace.

        Raises:
            ValueError: If the value is empty or whitespace-only.
        """
        if not v or not v.strip():
            raise ValueError("other_id is required and cannot be empty")
        return v.strip()
|
||||||
|
|
||||||
|
|
||||||
|
class CreateEndUserResponse(BaseModel):
    """Response schema for end user creation.

    Attributes:
        id: Created end user UUID, as a string.
        other_id: External user identifier.
        other_name: Display name; defaults to ``""``.
        workspace_id: Workspace the user belongs to.
    """

    id: str = Field(..., description="End user UUID")
    other_id: str = Field(..., description="External user identifier")
    other_name: str = Field("", description="Display name")
    workspace_id: str = Field(..., description="Workspace ID")
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryConfigItem(BaseModel):
    """Schema for a single memory config in the list response.

    Attributes:
        config_id: Configuration UUID, as a string.
        config_name: Configuration name.
        config_desc: Configuration description (optional).
        is_default: Whether this is the workspace default config.
        scene_name: Associated ontology scene name (optional).
        created_at: Creation timestamp, as a string (optional).
        updated_at: Last update timestamp, as a string (optional).
    """

    config_id: str = Field(..., description="Configuration ID")
    config_name: str = Field(..., description="Configuration name")
    config_desc: Optional[str] = Field(None, description="Configuration description")
    is_default: bool = Field(False, description="Whether this is the workspace default")
    scene_name: Optional[str] = Field(None, description="Associated ontology scene name")
    created_at: Optional[str] = Field(None, description="Creation timestamp")
    updated_at: Optional[str] = Field(None, description="Last update timestamp")
|
||||||
|
|
||||||
|
|
||||||
|
class ListConfigsResponse(BaseModel):
    """Response schema for listing memory configs.

    Attributes:
        configs: List of memory config items; defaults to an empty list.
        total: Total number of configs; defaults to 0.
    """

    configs: List[MemoryConfigItem] = Field(default_factory=list, description="List of configs")
    total: int = Field(0, description="Total number of configs")
|
||||||
|
|||||||
@@ -118,28 +118,54 @@ class AppChatService:
|
|||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_info = ModelInfo(
|
||||||
|
model_name=api_key_obj.model_name,
|
||||||
|
provider=api_key_obj.provider,
|
||||||
|
api_key=api_key_obj.api_key,
|
||||||
|
api_base=api_key_obj.api_base,
|
||||||
|
capability=api_key_obj.capability,
|
||||||
|
is_omni=api_key_obj.is_omni,
|
||||||
|
model_type=ModelType.LLM
|
||||||
|
)
|
||||||
|
|
||||||
# 加载历史消息
|
# 加载历史消息
|
||||||
messages = self.conversation_service.get_messages(
|
messages = self.conversation_service.get_messages(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
limit=10
|
limit=10
|
||||||
)
|
)
|
||||||
history = [
|
history = []
|
||||||
{"role": msg.role, "content": msg.content}
|
for msg in messages:
|
||||||
for msg in messages
|
content = [{"type": "text", "text": msg.content}]
|
||||||
]
|
|
||||||
|
# 处理 meta_data 中的 files
|
||||||
|
if msg.meta_data and msg.meta_data.get("files"):
|
||||||
|
files = msg.meta_data.get("files", [])
|
||||||
|
# 使用 MultimodalService 处理文件
|
||||||
|
multimodal_service = MultimodalService(self.db, api_config=model_info)
|
||||||
|
|
||||||
|
# 将 files 转换为 FileInput 格式
|
||||||
|
file_inputs = []
|
||||||
|
for file in files:
|
||||||
|
from app.schemas.app_schema import FileInput, TransferMethod
|
||||||
|
file_input = FileInput(
|
||||||
|
type=file.get("type"),
|
||||||
|
transfer_method=TransferMethod.REMOTE_URL,
|
||||||
|
url=file.get("url")
|
||||||
|
)
|
||||||
|
file_inputs.append(file_input)
|
||||||
|
|
||||||
|
history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
|
||||||
|
|
||||||
|
content.extend(history_processed_files)
|
||||||
|
|
||||||
|
history.append({
|
||||||
|
"role": msg.role,
|
||||||
|
"content": content
|
||||||
|
})
|
||||||
|
|
||||||
# 处理多模态文件
|
# 处理多模态文件
|
||||||
processed_files = None
|
processed_files = None
|
||||||
if files:
|
if files:
|
||||||
model_info = ModelInfo(
|
|
||||||
model_name=api_key_obj.model_name,
|
|
||||||
provider=api_key_obj.provider,
|
|
||||||
api_key=api_key_obj.api_key,
|
|
||||||
api_base=api_key_obj.api_base,
|
|
||||||
capability=api_key_obj.capability,
|
|
||||||
is_omni=api_key_obj.is_omni,
|
|
||||||
model_type=ModelType.LLM
|
|
||||||
)
|
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(user_id, files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
@@ -313,31 +339,54 @@ class AppChatService:
|
|||||||
streaming=True
|
streaming=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_info = ModelInfo(
|
||||||
|
model_name=api_key_obj.model_name,
|
||||||
|
provider=api_key_obj.provider,
|
||||||
|
api_key=api_key_obj.api_key,
|
||||||
|
api_base=api_key_obj.api_base,
|
||||||
|
capability=api_key_obj.capability,
|
||||||
|
is_omni=api_key_obj.is_omni,
|
||||||
|
model_type=ModelType.LLM
|
||||||
|
)
|
||||||
|
|
||||||
# 加载历史消息
|
# 加载历史消息
|
||||||
|
messages = self.conversation_service.get_messages(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
limit=10
|
||||||
|
)
|
||||||
history = []
|
history = []
|
||||||
memory_config = {"enabled": True, 'max_history': 10}
|
for msg in messages:
|
||||||
if memory_config.get("enabled"):
|
content = [{"type": "text", "text": msg.content}]
|
||||||
messages = self.conversation_service.get_messages(
|
|
||||||
conversation_id=conversation_id,
|
# 处理 meta_data 中的 files
|
||||||
limit=memory_config.get("max_history", 10)
|
if msg.meta_data and msg.meta_data.get("files"):
|
||||||
)
|
history_files = msg.meta_data.get("files", [])
|
||||||
history = [
|
# 使用 MultimodalService 处理文件
|
||||||
{"role": msg.role, "content": msg.content}
|
multimodal_service = MultimodalService(self.db, api_config=model_info)
|
||||||
for msg in messages
|
|
||||||
]
|
# 将 files 转换为 FileInput 格式
|
||||||
|
file_inputs = []
|
||||||
|
for file in history_files:
|
||||||
|
from app.schemas.app_schema import FileInput, TransferMethod
|
||||||
|
file_input = FileInput(
|
||||||
|
type=file.get("type"),
|
||||||
|
transfer_method=TransferMethod.REMOTE_URL,
|
||||||
|
url=file.get("url")
|
||||||
|
)
|
||||||
|
file_inputs.append(file_input)
|
||||||
|
|
||||||
|
history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
|
||||||
|
|
||||||
|
content.extend(history_processed_files)
|
||||||
|
|
||||||
|
history.append({
|
||||||
|
"role": msg.role,
|
||||||
|
"content": content
|
||||||
|
})
|
||||||
|
|
||||||
# 处理多模态文件
|
# 处理多模态文件
|
||||||
processed_files = None
|
processed_files = None
|
||||||
if files:
|
if files:
|
||||||
model_info = ModelInfo(
|
|
||||||
model_name=api_key_obj.model_name,
|
|
||||||
provider=api_key_obj.provider,
|
|
||||||
api_key=api_key_obj.api_key,
|
|
||||||
api_base=api_key_obj.api_base,
|
|
||||||
capability=api_key_obj.capability,
|
|
||||||
is_omni=api_key_obj.is_omni,
|
|
||||||
model_type=ModelType.LLM
|
|
||||||
)
|
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(user_id, files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
@@ -347,8 +396,14 @@ class AppChatService:
|
|||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
|
||||||
text_queue: asyncio.Queue = asyncio.Queue()
|
text_queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
api_key_config = {
|
||||||
|
"model_name": api_key_obj.model_name,
|
||||||
|
"api_key": api_key_obj.api_key,
|
||||||
|
"api_base": api_key_obj.api_base,
|
||||||
|
"provider": api_key_obj.provider,
|
||||||
|
}
|
||||||
stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming(
|
stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming(
|
||||||
features_config, api_key_obj,
|
features_config, api_key_config,
|
||||||
text_queue=text_queue,
|
text_queue=text_queue,
|
||||||
tenant_id=tenant_id, workspace_id=workspace_id
|
tenant_id=tenant_id, workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from app.models.app_release_model import AppRelease
|
|||||||
from app.models.knowledge_model import Knowledge
|
from app.models.knowledge_model import Knowledge
|
||||||
from app.models.models_model import ModelConfig
|
from app.models.models_model import ModelConfig
|
||||||
from app.models.tool_model import ToolConfig as ToolConfigModel
|
from app.models.tool_model import ToolConfig as ToolConfigModel
|
||||||
|
from app.models.skill_model import Skill
|
||||||
from app.models.workflow_model import WorkflowConfig
|
from app.models.workflow_model import WorkflowConfig
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
||||||
@@ -84,7 +85,9 @@ class AppDslService:
|
|||||||
if "knowledge_retrieval" in cfg:
|
if "knowledge_retrieval" in cfg:
|
||||||
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
|
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
|
||||||
if "tools" in cfg:
|
if "tools" in cfg:
|
||||||
enriched["tools"] = self._enrich_tools(cfg["tools"])
|
enriched["tools"] = self._enrich_tools(cfg.get("tools"))
|
||||||
|
if "skills" in cfg:
|
||||||
|
enriched["skills"] = self._enrich_skills(cfg.get("skills"))
|
||||||
return enriched
|
return enriched
|
||||||
if app_type == AppType.MULTI_AGENT:
|
if app_type == AppType.MULTI_AGENT:
|
||||||
enriched = {**cfg}
|
enriched = {**cfg}
|
||||||
@@ -108,6 +111,7 @@ class AppDslService:
|
|||||||
"variables": config.variables if config else [],
|
"variables": config.variables if config else [],
|
||||||
"edges": config.edges if config else [],
|
"edges": config.edges if config else [],
|
||||||
"nodes": config.nodes if config else [],
|
"nodes": config.nodes if config else [],
|
||||||
|
"features": config.features if config else {},
|
||||||
"execution_config": config.execution_config if config else {},
|
"execution_config": config.execution_config if config else {},
|
||||||
"triggers": config.triggers if config else [],
|
"triggers": config.triggers if config else [],
|
||||||
} if config else {}
|
} if config else {}
|
||||||
@@ -123,7 +127,8 @@ class AppDslService:
|
|||||||
"memory": config.memory if config else None,
|
"memory": config.memory if config else None,
|
||||||
"variables": config.variables if config else [],
|
"variables": config.variables if config else [],
|
||||||
"tools": self._enrich_tools(config.tools) if config else [],
|
"tools": self._enrich_tools(config.tools) if config else [],
|
||||||
"skills": config.skills if config else {},
|
"skills": self._enrich_skills(config.skills) if config else {},
|
||||||
|
"features": config.features if config else {}
|
||||||
} if config else {}
|
} if config else {}
|
||||||
dsl = {**meta, "app": app_meta, "agent_config": config_data}
|
dsl = {**meta, "app": app_meta, "agent_config": config_data}
|
||||||
|
|
||||||
@@ -185,6 +190,22 @@ class AppDslService:
|
|||||||
def _enrich_tools(self, tools: list) -> list:
|
def _enrich_tools(self, tools: list) -> list:
|
||||||
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
||||||
|
|
||||||
|
def _skill_ref(self, skill_id) -> Optional[dict]:
|
||||||
|
if not skill_id:
|
||||||
|
return None
|
||||||
|
s = self.db.query(Skill).filter(Skill.id == skill_id).first()
|
||||||
|
return {"id": str(skill_id), "name": s.name} if s else {"id": str(skill_id)}
|
||||||
|
|
||||||
|
def _enrich_skills(self, skills: Optional[dict]) -> Optional[dict]:
|
||||||
|
if not skills:
|
||||||
|
return skills
|
||||||
|
skill_ids = skills.get("skill_ids", [])
|
||||||
|
enriched_ids = [
|
||||||
|
{"id": sid, "_ref": self._skill_ref(sid)}
|
||||||
|
for sid in (skill_ids or [])
|
||||||
|
]
|
||||||
|
return {**skills, "skill_ids": enriched_ids}
|
||||||
|
|
||||||
def _agent_ref(self, agent_id) -> Optional[dict]:
|
def _agent_ref(self, agent_id) -> Optional[dict]:
|
||||||
if not agent_id:
|
if not agent_id:
|
||||||
return None
|
return None
|
||||||
@@ -249,7 +270,8 @@ class AppDslService:
|
|||||||
memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings),
|
memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings),
|
||||||
variables=cfg.get("variables", []),
|
variables=cfg.get("variables", []),
|
||||||
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
|
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
|
||||||
skills=cfg.get("skills", {}),
|
skills=self._resolve_skills(cfg.get("skills", {}), tenant_id, warnings),
|
||||||
|
features=cfg.get("features", {}),
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -290,6 +312,7 @@ class AppDslService:
|
|||||||
edges=[e.model_dump() for e in result.edges],
|
edges=[e.model_dump() for e in result.edges],
|
||||||
variables=[v.model_dump() for v in result.variables],
|
variables=[v.model_dump() for v in result.variables],
|
||||||
execution_config=wf.get("execution_config", {}),
|
execution_config=wf.get("execution_config", {}),
|
||||||
|
features=wf.get("features", {}),
|
||||||
triggers=wf.get("triggers", []),
|
triggers=wf.get("triggers", []),
|
||||||
validate=False,
|
validate=False,
|
||||||
)
|
)
|
||||||
@@ -444,6 +467,46 @@ class AppDslService:
|
|||||||
return {**memory, "memory_config_id": None, "enabled": False}
|
return {**memory, "memory_config_id": None, "enabled": False}
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
|
def _resolve_skills(self, skills: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> dict:
|
||||||
|
if not skills:
|
||||||
|
return skills or {}
|
||||||
|
resolved_ids = []
|
||||||
|
for entry in (skills.get("skill_ids") or []):
|
||||||
|
# entry 可能是 {"id": "...", "_ref": {...}} 或直接是字符串
|
||||||
|
if isinstance(entry, dict):
|
||||||
|
ref = entry.get("_ref") or ({"name": None, "id": entry.get("id")} if entry.get("id") else None)
|
||||||
|
skill_id = self._resolve_skill(ref, tenant_id, warnings)
|
||||||
|
else:
|
||||||
|
skill_id = self._resolve_skill({"id": str(entry)}, tenant_id, warnings)
|
||||||
|
if skill_id:
|
||||||
|
resolved_ids.append(str(skill_id))
|
||||||
|
return {**{k: v for k, v in skills.items() if k != "skill_ids"}, "skill_ids": resolved_ids}
|
||||||
|
|
||||||
|
def _resolve_skill(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
|
||||||
|
if not ref:
|
||||||
|
return None
|
||||||
|
# 先按 id 匹配
|
||||||
|
if ref.get("id"):
|
||||||
|
try:
|
||||||
|
s = self.db.query(Skill).filter(
|
||||||
|
Skill.id == uuid.UUID(str(ref["id"])),
|
||||||
|
Skill.tenant_id == tenant_id
|
||||||
|
).first()
|
||||||
|
if s:
|
||||||
|
return str(s.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# 再按名称匹配
|
||||||
|
if ref.get("name"):
|
||||||
|
s = self.db.query(Skill).filter(
|
||||||
|
Skill.name == ref["name"],
|
||||||
|
Skill.tenant_id == tenant_id
|
||||||
|
).first()
|
||||||
|
if s:
|
||||||
|
return str(s.id)
|
||||||
|
warnings.append(f"未找到技能: {ref}")
|
||||||
|
return None
|
||||||
|
|
||||||
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
|
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
|
||||||
result = []
|
result = []
|
||||||
for t in (tools or []):
|
for t in (tools or []):
|
||||||
|
|||||||
@@ -833,8 +833,6 @@ class AppService:
|
|||||||
|
|
||||||
# 跨工作空间时,获取目标工作空间的 tenant_id 用于判断模型配置是否可用
|
# 跨工作空间时,获取目标工作空间的 tenant_id 用于判断模型配置是否可用
|
||||||
target_tenant_id = None
|
target_tenant_id = None
|
||||||
available_model_ids: set = set()
|
|
||||||
available_kb_ids: set = set()
|
|
||||||
if is_cross_workspace:
|
if is_cross_workspace:
|
||||||
target_ws = self.db.get(Workspace, target_workspace_id)
|
target_ws = self.db.get(Workspace, target_workspace_id)
|
||||||
if not target_ws:
|
if not target_ws:
|
||||||
@@ -849,28 +847,29 @@ class AppService:
|
|||||||
|
|
||||||
if source_config:
|
if source_config:
|
||||||
if is_cross_workspace:
|
if is_cross_workspace:
|
||||||
# Batch-collect and preload all referenced resources
|
# 跨工作空间:model/tools/skills 属于 tenant 级别直接保留,
|
||||||
model_ids, kb_ids = self._collect_resource_ids_from_config(
|
# knowledge_bases 属于 workspace 级别需过滤,memory_config 需清空
|
||||||
source_config.default_model_config_id,
|
_, kb_ids = self._collect_resource_ids_from_config(
|
||||||
source_config.knowledge_retrieval,
|
None, source_config.knowledge_retrieval
|
||||||
source_config.tools
|
|
||||||
)
|
)
|
||||||
available_model_ids, available_kb_ids = self._preload_cross_workspace_resources(
|
_, available_kb_ids = self._preload_cross_workspace_resources(
|
||||||
target_tenant_id, target_workspace_id, model_ids, kb_ids
|
target_tenant_id, target_workspace_id, set(), kb_ids
|
||||||
)
|
|
||||||
new_model_config_id = self._is_model_available(
|
|
||||||
source_config.default_model_config_id, available_model_ids
|
|
||||||
)
|
)
|
||||||
|
new_model_config_id = source_config.default_model_config_id
|
||||||
new_knowledge_retrieval = self._clean_knowledge_retrieval(
|
new_knowledge_retrieval = self._clean_knowledge_retrieval(
|
||||||
source_config.knowledge_retrieval, available_kb_ids
|
source_config.knowledge_retrieval, available_kb_ids
|
||||||
)
|
)
|
||||||
new_tools = self._clean_tools(
|
new_tools = copy.deepcopy(source_config.tools) if source_config.tools else []
|
||||||
source_config.tools, available_kb_ids
|
new_memory = self._clean_memory_cross_workspace(
|
||||||
|
source_config.memory, target_workspace_id
|
||||||
)
|
)
|
||||||
|
new_skills = copy.deepcopy(source_config.skills) if source_config.skills else {}
|
||||||
else:
|
else:
|
||||||
new_model_config_id = source_config.default_model_config_id
|
new_model_config_id = source_config.default_model_config_id
|
||||||
new_knowledge_retrieval = copy.deepcopy(source_config.knowledge_retrieval) if source_config.knowledge_retrieval else None
|
new_knowledge_retrieval = copy.deepcopy(source_config.knowledge_retrieval) if source_config.knowledge_retrieval else None
|
||||||
new_tools = copy.deepcopy(source_config.tools) if source_config.tools else []
|
new_tools = copy.deepcopy(source_config.tools) if source_config.tools else []
|
||||||
|
new_memory = copy.deepcopy(source_config.memory) if source_config.memory else None
|
||||||
|
new_skills = copy.deepcopy(source_config.skills) if source_config.skills else {}
|
||||||
|
|
||||||
new_config = AgentConfig(
|
new_config = AgentConfig(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
@@ -879,9 +878,11 @@ class AppService:
|
|||||||
default_model_config_id=new_model_config_id,
|
default_model_config_id=new_model_config_id,
|
||||||
model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None,
|
model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None,
|
||||||
knowledge_retrieval=new_knowledge_retrieval,
|
knowledge_retrieval=new_knowledge_retrieval,
|
||||||
memory=copy.deepcopy(source_config.memory) if source_config.memory else None,
|
memory=new_memory,
|
||||||
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
|
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
|
||||||
tools=new_tools,
|
tools=new_tools,
|
||||||
|
skills=new_skills,
|
||||||
|
features=copy.deepcopy(source_config.features) if source_config.features else {},
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -894,28 +895,14 @@ class AppService:
|
|||||||
).first()
|
).first()
|
||||||
|
|
||||||
if source_config:
|
if source_config:
|
||||||
if is_cross_workspace:
|
|
||||||
model_ids, kb_ids = self._collect_resource_ids_from_workflow_nodes(
|
|
||||||
source_config.nodes
|
|
||||||
)
|
|
||||||
available_model_ids, available_kb_ids = self._preload_cross_workspace_resources(
|
|
||||||
target_tenant_id, target_workspace_id, model_ids, kb_ids
|
|
||||||
)
|
|
||||||
new_nodes = self._clean_workflow_nodes_for_cross_workspace(
|
|
||||||
source_config.nodes or [],
|
|
||||||
available_model_ids,
|
|
||||||
available_kb_ids
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
new_nodes = copy.deepcopy(source_config.nodes) if source_config.nodes else []
|
|
||||||
|
|
||||||
new_config = WorkflowConfig(
|
new_config = WorkflowConfig(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
app_id=new_app.id,
|
app_id=new_app.id,
|
||||||
nodes=new_nodes,
|
nodes=copy.deepcopy(source_config.nodes) if source_config.nodes else [],
|
||||||
edges=copy.deepcopy(source_config.edges) if source_config.edges else [],
|
edges=copy.deepcopy(source_config.edges) if source_config.edges else [],
|
||||||
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
|
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
|
||||||
execution_config=copy.deepcopy(source_config.execution_config) if source_config.execution_config else {},
|
execution_config=copy.deepcopy(source_config.execution_config) if source_config.execution_config else {},
|
||||||
|
features=copy.deepcopy(source_config.features) if source_config.features else {},
|
||||||
triggers=copy.deepcopy(source_config.triggers) if source_config.triggers else [],
|
triggers=copy.deepcopy(source_config.triggers) if source_config.triggers else [],
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
@@ -929,24 +916,15 @@ class AppService:
|
|||||||
).first()
|
).first()
|
||||||
|
|
||||||
if source_config:
|
if source_config:
|
||||||
if is_cross_workspace:
|
# multi_agent 的 model_config_id/sub_agents/routing_rules 均属于 tenant 级别直接保留
|
||||||
model_ids = {source_config.default_model_config_id} if source_config.default_model_config_id else set()
|
# 跨空间时 master_agent_id(AppRelease)属于源空间,需清空
|
||||||
available_model_ids, _ = self._preload_cross_workspace_resources(
|
|
||||||
target_tenant_id, target_workspace_id, model_ids, set()
|
|
||||||
)
|
|
||||||
new_model_config_id = self._is_model_available(
|
|
||||||
source_config.default_model_config_id, available_model_ids
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
new_model_config_id = source_config.default_model_config_id
|
|
||||||
|
|
||||||
new_config = MultiAgentConfig(
|
new_config = MultiAgentConfig(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
app_id=new_app.id,
|
app_id=new_app.id,
|
||||||
master_agent_id=source_config.master_agent_id if not is_cross_workspace else None,
|
master_agent_id=source_config.master_agent_id if not is_cross_workspace else None,
|
||||||
master_agent_name=source_config.master_agent_name,
|
master_agent_name=source_config.master_agent_name,
|
||||||
default_model_config_id=new_model_config_id,
|
default_model_config_id=source_config.default_model_config_id,
|
||||||
model_parameters=source_config.model_parameters,
|
model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None,
|
||||||
orchestration_mode=source_config.orchestration_mode,
|
orchestration_mode=source_config.orchestration_mode,
|
||||||
sub_agents=copy.deepcopy(source_config.sub_agents) if source_config.sub_agents else [],
|
sub_agents=copy.deepcopy(source_config.sub_agents) if source_config.sub_agents else [],
|
||||||
routing_rules=copy.deepcopy(source_config.routing_rules) if source_config.routing_rules else None,
|
routing_rules=copy.deepcopy(source_config.routing_rules) if source_config.routing_rules else None,
|
||||||
@@ -1037,8 +1015,7 @@ class AppService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _collect_resource_ids_from_config(
|
def _collect_resource_ids_from_config(
|
||||||
model_config_id: Optional[uuid.UUID],
|
model_config_id: Optional[uuid.UUID],
|
||||||
knowledge_retrieval: Optional[dict],
|
knowledge_retrieval: Optional[dict]
|
||||||
tools: Optional[list]
|
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""Extract all model config IDs and knowledge base IDs from an app config."""
|
"""Extract all model config IDs and knowledge base IDs from an app config."""
|
||||||
model_ids: set = set()
|
model_ids: set = set()
|
||||||
@@ -1048,62 +1025,12 @@ class AppService:
|
|||||||
model_ids.add(model_config_id)
|
model_ids.add(model_config_id)
|
||||||
|
|
||||||
if knowledge_retrieval and isinstance(knowledge_retrieval, dict):
|
if knowledge_retrieval and isinstance(knowledge_retrieval, dict):
|
||||||
if "kb_ids" in knowledge_retrieval:
|
if "knowledge_bases" in knowledge_retrieval:
|
||||||
for kid in knowledge_retrieval.get("kb_ids", []):
|
for kid in knowledge_retrieval.get("knowledge_bases", []):
|
||||||
if kid:
|
kb_ids.add(str(kid.get("kb_id")))
|
||||||
kb_ids.add(str(kid))
|
|
||||||
if knowledge_retrieval.get("knowledge_id"):
|
|
||||||
kb_ids.add(str(knowledge_retrieval["knowledge_id"]))
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
for tool in tools:
|
|
||||||
if isinstance(tool, dict):
|
|
||||||
kid = tool.get("knowledge_id") or tool.get("kb_id")
|
|
||||||
if kid:
|
|
||||||
kb_ids.add(str(kid))
|
|
||||||
|
|
||||||
return model_ids, kb_ids
|
return model_ids, kb_ids
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _collect_resource_ids_from_workflow_nodes(nodes: list) -> tuple:
|
|
||||||
"""Extract all model config IDs and knowledge base IDs from workflow nodes."""
|
|
||||||
model_ids: set = set()
|
|
||||||
kb_ids: set = set()
|
|
||||||
|
|
||||||
for node in (nodes or []):
|
|
||||||
if not isinstance(node, dict):
|
|
||||||
continue
|
|
||||||
data = node.get("data", {})
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
continue
|
|
||||||
for key in ("model_config_id", "default_model_config_id"):
|
|
||||||
val = data.get(key)
|
|
||||||
if val:
|
|
||||||
try:
|
|
||||||
model_ids.add(uuid.UUID(str(val)))
|
|
||||||
except (ValueError, AttributeError):
|
|
||||||
pass
|
|
||||||
kr = data.get("knowledge_retrieval")
|
|
||||||
if isinstance(kr, dict):
|
|
||||||
for kid in kr.get("kb_ids", []):
|
|
||||||
if kid:
|
|
||||||
kb_ids.add(str(kid))
|
|
||||||
if kr.get("knowledge_id"):
|
|
||||||
kb_ids.add(str(kr["knowledge_id"]))
|
|
||||||
if data.get("knowledge_id"):
|
|
||||||
kb_ids.add(str(data["knowledge_id"]))
|
|
||||||
for kid in data.get("kb_ids", []):
|
|
||||||
if kid:
|
|
||||||
kb_ids.add(str(kid))
|
|
||||||
|
|
||||||
return model_ids, kb_ids
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_model_available(model_config_id: Optional[uuid.UUID], available_model_ids: set) -> Optional[uuid.UUID]:
|
|
||||||
if not model_config_id:
|
|
||||||
return None
|
|
||||||
return model_config_id if model_config_id in available_model_ids else None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_kb_available(kb_id: Optional[str], available_kb_ids: set) -> Optional[str]:
|
def _is_kb_available(kb_id: Optional[str], available_kb_ids: set) -> Optional[str]:
|
||||||
if not kb_id:
|
if not kb_id:
|
||||||
@@ -1124,95 +1051,53 @@ class AppService:
|
|||||||
|
|
||||||
cleaned = copy.deepcopy(knowledge_retrieval)
|
cleaned = copy.deepcopy(knowledge_retrieval)
|
||||||
|
|
||||||
if "kb_ids" in cleaned and isinstance(cleaned["kb_ids"], list):
|
if "knowledge_bases" in cleaned and isinstance(cleaned["knowledge_bases"], list):
|
||||||
cleaned["kb_ids"] = [
|
cleaned["knowledge_bases"] = [
|
||||||
kid for kid in cleaned["kb_ids"]
|
kb for kb in cleaned["knowledge_bases"]
|
||||||
if self._is_kb_available(kid, available_kb_ids)
|
if self._is_kb_available(kb.get("kb_id"), available_kb_ids)
|
||||||
]
|
]
|
||||||
|
|
||||||
if "knowledge_id" in cleaned:
|
|
||||||
cleaned["knowledge_id"] = self._is_kb_available(
|
|
||||||
cleaned.get("knowledge_id"), available_kb_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
return cleaned
|
return cleaned
|
||||||
|
|
||||||
def _clean_tools(
|
def _clean_memory_cross_workspace(
|
||||||
self,
|
self,
|
||||||
tools: Optional[list],
|
memory: Optional[dict],
|
||||||
available_kb_ids: set
|
target_workspace_id: uuid.UUID
|
||||||
) -> list:
|
) -> Optional[dict]:
|
||||||
"""Clean tools config, keeping built-in tools and tools with available KBs."""
|
"""Clear memory_config_id/memory_content if it doesn't belong to target workspace."""
|
||||||
if not tools:
|
if not memory:
|
||||||
return []
|
return None
|
||||||
|
|
||||||
cleaned = []
|
from app.models.memory_config_model import MemoryConfig
|
||||||
for tool in tools:
|
|
||||||
if not isinstance(tool, dict):
|
|
||||||
cleaned.append(tool)
|
|
||||||
continue
|
|
||||||
|
|
||||||
tool_type = tool.get("type", "")
|
cleaned = copy.deepcopy(memory)
|
||||||
if tool_type in ("builtin", "built_in", "system"):
|
# 兼容旧字段 memory_content 和新字段 memory_config_id
|
||||||
cleaned.append(copy.deepcopy(tool))
|
mid = cleaned.get("memory_config_id") or cleaned.get("memory_content")
|
||||||
continue
|
if mid:
|
||||||
|
try:
|
||||||
|
mid_uuid = uuid.UUID(str(mid))
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
exists = self.db.query(MemoryConfig).filter(
|
||||||
|
MemoryConfig.config_id_old == int(mid),
|
||||||
|
MemoryConfig.workspace_id == target_workspace_id
|
||||||
|
).first()
|
||||||
|
if not exists:
|
||||||
|
cleaned["memory_config_id"] = None
|
||||||
|
cleaned.pop("memory_content", None)
|
||||||
|
cleaned["enabled"] = False
|
||||||
|
return cleaned
|
||||||
|
|
||||||
kb_id = tool.get("knowledge_id") or tool.get("kb_id")
|
exists = self.db.query(
|
||||||
if kb_id:
|
self.db.query(MemoryConfig).filter(
|
||||||
if self._is_kb_available(kb_id, available_kb_ids):
|
MemoryConfig.config_id == mid_uuid,
|
||||||
cleaned.append(copy.deepcopy(tool))
|
MemoryConfig.workspace_id == target_workspace_id
|
||||||
continue
|
).exists()
|
||||||
|
).scalar()
|
||||||
|
if not exists:
|
||||||
|
cleaned["memory_config_id"] = None
|
||||||
|
cleaned.pop("memory_content", None)
|
||||||
|
cleaned["enabled"] = False
|
||||||
|
|
||||||
cleaned.append(copy.deepcopy(tool))
|
|
||||||
|
|
||||||
return cleaned
|
|
||||||
|
|
||||||
def _clean_workflow_nodes_for_cross_workspace(
|
|
||||||
self,
|
|
||||||
nodes: list,
|
|
||||||
available_model_ids: set,
|
|
||||||
available_kb_ids: set
|
|
||||||
) -> list:
|
|
||||||
"""Clean workflow nodes, using pre-loaded resource sets. Uses deepcopy to avoid mutating source."""
|
|
||||||
if not nodes:
|
|
||||||
return []
|
|
||||||
|
|
||||||
cleaned = []
|
|
||||||
for node in nodes:
|
|
||||||
if not isinstance(node, dict):
|
|
||||||
cleaned.append(node)
|
|
||||||
continue
|
|
||||||
|
|
||||||
node_copy = copy.deepcopy(node)
|
|
||||||
data = node_copy.get("data")
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
cleaned.append(node_copy)
|
|
||||||
continue
|
|
||||||
|
|
||||||
for key in ("model_config_id", "default_model_config_id"):
|
|
||||||
if key in data and data[key]:
|
|
||||||
try:
|
|
||||||
mid = uuid.UUID(str(data[key]))
|
|
||||||
except (ValueError, AttributeError):
|
|
||||||
data[key] = None
|
|
||||||
continue
|
|
||||||
data[key] = str(mid) if mid in available_model_ids else None
|
|
||||||
|
|
||||||
if "knowledge_retrieval" in data and data["knowledge_retrieval"]:
|
|
||||||
data["knowledge_retrieval"] = self._clean_knowledge_retrieval(
|
|
||||||
data["knowledge_retrieval"], available_kb_ids
|
|
||||||
)
|
|
||||||
if "knowledge_id" in data:
|
|
||||||
data["knowledge_id"] = self._is_kb_available(
|
|
||||||
data.get("knowledge_id"), available_kb_ids
|
|
||||||
)
|
|
||||||
if "kb_ids" in data and isinstance(data["kb_ids"], list):
|
|
||||||
data["kb_ids"] = [
|
|
||||||
kid for kid in data["kb_ids"]
|
|
||||||
if self._is_kb_available(kid, available_kb_ids)
|
|
||||||
]
|
|
||||||
|
|
||||||
cleaned.append(node_copy)
|
|
||||||
return cleaned
|
return cleaned
|
||||||
|
|
||||||
def list_apps(
|
def list_apps(
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from app.models.conversation_model import ConversationDetail
|
|||||||
from app.models.prompt_optimizer_model import RoleType
|
from app.models.prompt_optimizer_model import RoleType
|
||||||
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
||||||
from app.schemas.conversation_schema import ConversationOut
|
from app.schemas.conversation_schema import ConversationOut
|
||||||
|
from app.schemas.model_schema import ModelInfo
|
||||||
from app.services import workspace_service
|
from app.services import workspace_service
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
@@ -119,25 +120,27 @@ class ConversationService:
|
|||||||
|
|
||||||
def get_user_conversations(
|
def get_user_conversations(
|
||||||
self,
|
self,
|
||||||
user_id: uuid.UUID
|
user_id: uuid.UUID,
|
||||||
) -> list[Conversation]:
|
page: int = 1,
|
||||||
|
page_size: int = 20
|
||||||
|
) -> tuple[list[Conversation], int]:
|
||||||
"""
|
"""
|
||||||
Retrieve recent conversations for a specific user
|
Retrieve recent conversations for a specific user with pagination.
|
||||||
|
|
||||||
This method delegates persistence logic to the repository layer and
|
|
||||||
applies service-level defaults (e.g. recent conversation limit).
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (uuid.UUID): Unique identifier of the user.
|
user_id (uuid.UUID): Unique identifier of the user.
|
||||||
|
page (int): Page number (1-based). Defaults to 1.
|
||||||
|
page_size (int): Number of items per page. Defaults to 20.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[Conversation]: A list of recent conversation entities.
|
tuple[list[Conversation], int]: A list of recent conversation entities and total count.
|
||||||
"""
|
"""
|
||||||
conversations = self.conversation_repo.get_conversation_by_user_id(
|
conversations, total = self.conversation_repo.get_conversation_by_user_id(
|
||||||
user_id,
|
user_id,
|
||||||
limit=10
|
page=page,
|
||||||
|
page_size=page_size
|
||||||
)
|
)
|
||||||
return conversations
|
return conversations, total
|
||||||
|
|
||||||
def list_conversations(
|
def list_conversations(
|
||||||
self,
|
self,
|
||||||
@@ -267,10 +270,11 @@ class ConversationService:
|
|||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def get_conversation_history(
|
async def get_conversation_history(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
max_history: Optional[int] = None
|
max_history: Optional[int] = None,
|
||||||
|
api_config: Optional[ModelInfo] = None
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
Retrieve historical conversation messages formatted as dictionaries.
|
Retrieve historical conversation messages formatted as dictionaries.
|
||||||
@@ -278,6 +282,7 @@ class ConversationService:
|
|||||||
Args:
|
Args:
|
||||||
conversation_id (uuid.UUID): Conversation UUID.
|
conversation_id (uuid.UUID): Conversation UUID.
|
||||||
max_history (Optional[int]): Maximum number of messages to retrieve.
|
max_history (Optional[int]): Maximum number of messages to retrieve.
|
||||||
|
api_config (Optional[ModelInfo]): Model API configuration for multimodal processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[dict]: List of message dictionaries with keys 'role' and 'content'.
|
List[dict]: List of message dictionaries with keys 'role' and 'content'.
|
||||||
@@ -288,13 +293,37 @@ class ConversationService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 转换为字典格式
|
# 转换为字典格式
|
||||||
history = [
|
history = []
|
||||||
{
|
for msg in messages:
|
||||||
|
content = [{"type": "text", "text": msg.content}]
|
||||||
|
|
||||||
|
# 处理 meta_data 中的 files
|
||||||
|
if msg.meta_data and msg.meta_data.get("files"):
|
||||||
|
files = msg.meta_data.get("files", [])
|
||||||
|
if api_config:
|
||||||
|
# 使用 MultimodalService 处理文件
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
multimodal_service = MultimodalService(self.db, api_config=api_config)
|
||||||
|
|
||||||
|
# 将 files 转换为 FileInput 格式
|
||||||
|
file_inputs = []
|
||||||
|
for file in files:
|
||||||
|
from app.schemas.app_schema import FileInput, TransferMethod
|
||||||
|
file_input = FileInput(
|
||||||
|
type=file.get("type"),
|
||||||
|
transfer_method=TransferMethod.REMOTE_URL,
|
||||||
|
url=file.get("url")
|
||||||
|
)
|
||||||
|
file_inputs.append(file_input)
|
||||||
|
|
||||||
|
processed_files = await multimodal_service.history_process_files(files=file_inputs)
|
||||||
|
|
||||||
|
content.extend(processed_files)
|
||||||
|
|
||||||
|
history.append({
|
||||||
"role": msg.role,
|
"role": msg.role,
|
||||||
"content": msg.content
|
"content": content
|
||||||
}
|
})
|
||||||
for msg in messages
|
|
||||||
]
|
|
||||||
|
|
||||||
return history
|
return history
|
||||||
|
|
||||||
@@ -522,9 +551,18 @@ class ConversationService:
|
|||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_messages = self.get_conversation_history(
|
conversation_messages = await self.get_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
max_history=20
|
max_history=20,
|
||||||
|
api_config=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
provider=provider,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
capability=api_config.capability,
|
||||||
|
is_omni=api_config.is_omni,
|
||||||
|
model_type=model_type
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if len(conversation_messages) == 0:
|
if len(conversation_messages) == 0:
|
||||||
return ConversationOut(
|
return ConversationOut(
|
||||||
|
|||||||
@@ -579,9 +579,20 @@ class AgentRunService:
|
|||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_info = ModelInfo(
|
||||||
|
model_name=api_key_config["model_name"],
|
||||||
|
provider=api_key_config["provider"],
|
||||||
|
api_key=api_key_config["api_key"],
|
||||||
|
api_base=api_key_config["api_base"],
|
||||||
|
capability=api_key_config["capability"],
|
||||||
|
is_omni=api_key_config["is_omni"],
|
||||||
|
model_type=model_config.type
|
||||||
|
)
|
||||||
|
|
||||||
# 6. 加载历史消息
|
# 6. 加载历史消息
|
||||||
history = await self._load_conversation_history(
|
history = await self._load_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
api_config=model_info,
|
||||||
max_history=10
|
max_history=10
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -589,15 +600,6 @@ class AgentRunService:
|
|||||||
processed_files = None
|
processed_files = None
|
||||||
if files:
|
if files:
|
||||||
# 获取 provider 信息
|
# 获取 provider 信息
|
||||||
model_info = ModelInfo(
|
|
||||||
model_name=api_key_config["model_name"],
|
|
||||||
provider=api_key_config["provider"],
|
|
||||||
api_key=api_key_config["api_key"],
|
|
||||||
api_base=api_key_config["api_base"],
|
|
||||||
capability=api_key_config["capability"],
|
|
||||||
is_omni=api_key_config["is_omni"],
|
|
||||||
model_type=ModelType.LLM
|
|
||||||
)
|
|
||||||
provider = api_key_config.get("provider", "openai")
|
provider = api_key_config.get("provider", "openai")
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(user_id, files)
|
||||||
@@ -815,9 +817,20 @@ class AgentRunService:
|
|||||||
sub_agent=sub_agent
|
sub_agent=sub_agent
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_info = ModelInfo(
|
||||||
|
model_name=api_key_config["model_name"],
|
||||||
|
provider=api_key_config["provider"],
|
||||||
|
api_key=api_key_config["api_key"],
|
||||||
|
api_base=api_key_config["api_base"],
|
||||||
|
capability=api_key_config["capability"],
|
||||||
|
is_omni=api_key_config["is_omni"],
|
||||||
|
model_type=model_config.type
|
||||||
|
)
|
||||||
|
|
||||||
# 6. 加载历史消息
|
# 6. 加载历史消息
|
||||||
history = await self._load_conversation_history(
|
history = await self._load_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
api_config=model_info,
|
||||||
max_history=memory_config.get("max_history", 10)
|
max_history=memory_config.get("max_history", 10)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -825,15 +838,6 @@ class AgentRunService:
|
|||||||
processed_files = None
|
processed_files = None
|
||||||
if files:
|
if files:
|
||||||
# 获取 provider 信息
|
# 获取 provider 信息
|
||||||
model_info = ModelInfo(
|
|
||||||
model_name=api_key_config["model_name"],
|
|
||||||
provider=api_key_config["provider"],
|
|
||||||
api_key=api_key_config["api_key"],
|
|
||||||
api_base=api_key_config["api_base"],
|
|
||||||
capability=api_key_config["capability"],
|
|
||||||
is_omni=api_key_config["is_omni"],
|
|
||||||
model_type=ModelType.LLM
|
|
||||||
)
|
|
||||||
provider = api_key_config.get("provider", "openai")
|
provider = api_key_config.get("provider", "openai")
|
||||||
multimodal_service = MultimodalService(self.db, model_info)
|
multimodal_service = MultimodalService(self.db, model_info)
|
||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(user_id, files)
|
||||||
@@ -1115,6 +1119,7 @@ class AgentRunService:
|
|||||||
async def _load_conversation_history(
|
async def _load_conversation_history(
|
||||||
self,
|
self,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
|
api_config: ModelInfo | None = None,
|
||||||
max_history: int = 10
|
max_history: int = 10
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
"""加载会话历史消息
|
"""加载会话历史消息
|
||||||
@@ -1129,9 +1134,11 @@ class AgentRunService:
|
|||||||
try:
|
try:
|
||||||
|
|
||||||
conversation_service = ConversationService(self.db)
|
conversation_service = ConversationService(self.db)
|
||||||
history = conversation_service.get_conversation_history(
|
# 获取 API 配置用于多模态处理
|
||||||
|
history = await conversation_service.get_conversation_history(
|
||||||
conversation_id=uuid.UUID(conversation_id),
|
conversation_id=uuid.UUID(conversation_id),
|
||||||
max_history=max_history
|
max_history=max_history,
|
||||||
|
api_config=api_config
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|||||||
@@ -1179,7 +1179,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
app = db.query(App).filter(App.id == app_id).first()
|
app = db.query(App).filter(App.id == app_id).first()
|
||||||
if not app:
|
if not app:
|
||||||
logger.warning(f"App not found: {app_id}")
|
logger.warning(f"App not found: {app_id}")
|
||||||
raise ValueError(f"应用不存在: {app_id}")
|
# raise ValueError(f"应用不存在: {app_id}")
|
||||||
# TODO: temp fix for draft run
|
# TODO: temp fix for draft run
|
||||||
# if not app.current_release_id:
|
# if not app.current_release_id:
|
||||||
# logger.warning(f"No current release for app: {app_id}")
|
# logger.warning(f"No current release for app: {app_id}")
|
||||||
@@ -1252,17 +1252,15 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
memory_config_service = MemoryConfigService(db)
|
memory_config_service = MemoryConfigService(db)
|
||||||
memory_config = memory_config_service.get_config_with_fallback(
|
memory_config = memory_config_service.get_config_with_fallback(
|
||||||
memory_config_id=memory_config_id_to_use,
|
memory_config_id=memory_config_id_to_use,
|
||||||
workspace_id=app.workspace_id
|
workspace_id=end_user.workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_config_id = str(memory_config.config_id) if memory_config else None
|
memory_config_id = str(memory_config.config_id) if memory_config else None
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"end_user_id": str(end_user_id),
|
"end_user_id": str(end_user_id),
|
||||||
"app_id": str(app_id),
|
|
||||||
"release_id": str(app.current_release_id) if app.current_release_id else None,
|
|
||||||
"memory_config_id": memory_config_id,
|
"memory_config_id": memory_config_id,
|
||||||
"workspace_id": str(app.workspace_id)
|
"workspace_id": str(end_user.workspace_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -84,43 +84,65 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
logger.warning(f"App not found for end_user: {end_user_id}")
|
logger.warning(f"App not found for end_user: {end_user_id}")
|
||||||
raise ResourceNotFoundException(
|
# raise ResourceNotFoundException(
|
||||||
resource_type="App",
|
# resource_type="App",
|
||||||
resource_id=str(end_user.app_id)
|
# resource_id=str(end_user.app_id)
|
||||||
)
|
# )
|
||||||
|
# temporally allow any workspace to access
|
||||||
if app.workspace_id != workspace_id:
|
# if end_user.workspace_id != workspace_id:
|
||||||
logger.warning(
|
# print(f"[DEBUG] end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}")
|
||||||
f"End user {end_user_id} belongs to workspace {app.workspace_id}, "
|
# logger.warning(
|
||||||
f"not authorized workspace {workspace_id}"
|
# f"End user {end_user_id} belongs to workspace {end_user.workspace_id}, "
|
||||||
)
|
# f"not authorized workspace {workspace_id}"
|
||||||
raise BusinessException(
|
# )
|
||||||
message="End user does not belong to authorized workspace",
|
# raise BusinessException(
|
||||||
code=BizCode.FORBIDDEN
|
# message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}",
|
||||||
)
|
# code=BizCode.FORBIDDEN
|
||||||
|
# )
|
||||||
|
|
||||||
logger.info(f"End user {end_user_id} validated successfully")
|
logger.info(f"End user {end_user_id} validated successfully")
|
||||||
return end_user
|
return end_user
|
||||||
|
|
||||||
|
def _update_end_user_config(self, end_user_id: str, config_id: str) -> None:
|
||||||
|
"""Update the end user's memory_config_id.
|
||||||
|
|
||||||
|
Silently updates the config association. Logs warnings on failure
|
||||||
|
but does not raise, so it won't block the main read/write operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: End user identifier
|
||||||
|
config_id: Memory configuration ID to assign
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
config_uuid = uuid.UUID(config_id)
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
end_user_repo = EndUserRepository(self.db)
|
||||||
|
end_user_repo.update_memory_config_id(
|
||||||
|
end_user_id=uuid.UUID(end_user_id),
|
||||||
|
memory_config_id=config_uuid,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
||||||
|
|
||||||
async def write_memory(
|
async def write_memory(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
config_id: Optional[str] = None,
|
config_id: str,
|
||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Write memory with validation.
|
"""Write memory with validation.
|
||||||
|
|
||||||
Validates end_user exists and belongs to workspace, then delegates
|
Validates end_user exists and belongs to workspace, updates the end user's
|
||||||
to MemoryAgentService.write_memory.
|
memory_config_id, then delegates to MemoryAgentService.write_memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workspace_id: Workspace ID for resource validation
|
workspace_id: Workspace ID for resource validation
|
||||||
end_user_id: End user identifier (used as end_user_id)
|
end_user_id: End user identifier (used as end_user_id)
|
||||||
message: Message content to store
|
message: Message content to store
|
||||||
config_id: Optional memory configuration ID
|
config_id: Memory configuration ID (required)
|
||||||
storage_type: Storage backend (neo4j or rag)
|
storage_type: Storage backend (neo4j or rag)
|
||||||
user_rag_memory_id: Optional RAG memory ID
|
user_rag_memory_id: Optional RAG memory ID
|
||||||
|
|
||||||
@@ -136,7 +158,8 @@ class MemoryAPIService:
|
|||||||
# Validate end_user exists and belongs to workspace
|
# Validate end_user exists and belongs to workspace
|
||||||
self.validate_end_user(end_user_id, workspace_id)
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
# Use end_user_id as end_user_id for memory operations
|
# Update end user's memory_config_id
|
||||||
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
# Delegate to MemoryAgentService
|
||||||
@@ -188,21 +211,21 @@ class MemoryAPIService:
|
|||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
search_switch: str = "0",
|
search_switch: str = "0",
|
||||||
config_id: Optional[str] = None,
|
config_id: str = "",
|
||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Read memory with validation.
|
"""Read memory with validation.
|
||||||
|
|
||||||
Validates end_user exists and belongs to workspace, then delegates
|
Validates end_user exists and belongs to workspace, updates the end user's
|
||||||
to MemoryAgentService.read_memory.
|
memory_config_id, then delegates to MemoryAgentService.read_memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workspace_id: Workspace ID for resource validation
|
workspace_id: Workspace ID for resource validation
|
||||||
end_user_id: End user identifier (used as end_user_id)
|
end_user_id: End user identifier (used as end_user_id)
|
||||||
message: Query message
|
message: Query message
|
||||||
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
|
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
|
||||||
config_id: Optional memory configuration ID
|
config_id: Memory configuration ID (required)
|
||||||
storage_type: Storage backend (neo4j or rag)
|
storage_type: Storage backend (neo4j or rag)
|
||||||
user_rag_memory_id: Optional RAG memory ID
|
user_rag_memory_id: Optional RAG memory ID
|
||||||
|
|
||||||
@@ -218,7 +241,8 @@ class MemoryAPIService:
|
|||||||
# Validate end_user exists and belongs to workspace
|
# Validate end_user exists and belongs to workspace
|
||||||
self.validate_end_user(end_user_id, workspace_id)
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
# Use end_user_id as end_user_id for memory operations
|
# Update end user's memory_config_id
|
||||||
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -256,3 +280,50 @@ class MemoryAPIService:
|
|||||||
message=f"Memory read failed: {str(e)}",
|
message=f"Memory read failed: {str(e)}",
|
||||||
code=BizCode.MEMORY_READ_FAILED
|
code=BizCode.MEMORY_READ_FAILED
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def list_memory_configs(
    self,
    workspace_id: uuid.UUID,
) -> Dict[str, Any]:
    """List every memory configuration that belongs to a workspace.

    Args:
        workspace_id: Workspace ID from API key authorization.

    Returns:
        Dict with ``configs`` (one serialized entry per configuration) and
        ``total`` (the number of entries in ``configs``).

    Raises:
        BusinessException: If listing fails. The triggering error is attached
            as the exception cause.
    """
    logger.info(f"Listing memory configs for workspace: {workspace_id}")

    try:
        from app.repositories.memory_config_repository import MemoryConfigRepository

        results = MemoryConfigRepository.get_all(self.db, workspace_id=workspace_id)

        # Each result row unpacks into (config, scene_name).
        configs = [
            {
                "config_id": str(config.config_id),
                "config_name": config.config_name,
                "config_desc": config.config_desc,
                # A missing/None flag is reported as False.
                "is_default": config.is_default or False,
                "scene_name": scene_name,
                "created_at": config.created_at.isoformat() if config.created_at else None,
                "updated_at": config.updated_at.isoformat() if config.updated_at else None,
            }
            for config, scene_name in results
        ]

        logger.info(f"Found {len(configs)} memory configs for workspace {workspace_id}")
        return {
            "configs": configs,
            "total": len(configs),
        }

    except Exception as e:
        # Keep the traceback in the log and chain the cause so the root error
        # is not lost behind the business-level wrapper.
        logger.error(
            f"Failed to list memory configs for workspace {workspace_id}: {e}",
            exc_info=True,
        )
        raise BusinessException(
            message=f"Failed to list memory configs: {str(e)}",
            code=BizCode.MEMORY_READ_FAILED
        ) from e
|
||||||
|
|||||||
@@ -619,7 +619,7 @@ class MemoryForgetService:
|
|||||||
recent_trends.append({
|
recent_trends.append({
|
||||||
'date': date_str,
|
'date': date_str,
|
||||||
'merged_count': record.merged_count,
|
'merged_count': record.merged_count,
|
||||||
'average_activation': record.average_activation_value,
|
'average_activation': round(record.average_activation_value, 2) if record.average_activation_value is not None else None,
|
||||||
'total_nodes': record.total_nodes,
|
'total_nodes': record.total_nodes,
|
||||||
'execution_time': int(record.execution_time.timestamp() * 1000)
|
'execution_time': int(record.execution_time.timestamp() * 1000)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -11,6 +11,8 @@
|
|||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import uuid
|
import uuid
|
||||||
|
import zipfile
|
||||||
|
import chardet
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
@@ -42,12 +44,10 @@ PDF_MIME = ['application/pdf']
|
|||||||
DOC_MIME = [
|
DOC_MIME = [
|
||||||
'application/msword',
|
'application/msword',
|
||||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||||
'application/zip'
|
|
||||||
]
|
]
|
||||||
XLSX_MIME = [
|
XLSX_MIME = [
|
||||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||||
'application/vnd.ms-excel',
|
'application/vnd.ms-excel',
|
||||||
'application/zip'
|
|
||||||
]
|
]
|
||||||
CSV_MIME = ['text/csv', 'application/csv']
|
CSV_MIME = ['text/csv', 'application/csv']
|
||||||
JSON_MIME = ['application/json']
|
JSON_MIME = ['application/json']
|
||||||
@@ -418,6 +418,71 @@ class MultimodalService:
|
|||||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def history_process_files(
    self,
    files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]:
    """Convert a list of file inputs into LLM-ready content parts.

    Args:
        files: File inputs to process; ``None`` or an empty list yields ``[]``.

    Returns:
        A list of content dicts in the format of the selected provider
        strategy. A file whose processing raises contributes a
        ``{"type": "text", ...}`` placeholder carrying the error message.
        A file whose type/capability combination matches no branch is logged
        and contributes nothing.
    """
    if not files:
        return []

    # Pick the formatting strategy.
    # DashScope omni models use the OpenAI-compatible format.
    if self.provider == "dashscope" and self.is_omni:
        strategy_class = OpenAIFormatStrategy
    else:
        strategy_class = PROVIDER_STRATEGIES.get(self.provider)
        if not strategy_class:
            logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
            strategy_class = DashScopeFormatStrategy

    result = []
    for idx, file in enumerate(files):
        try:
            # Strategy construction and URL resolution live inside the try so
            # that a failure here is isolated to this file, like any other
            # per-file processing error.
            strategy = strategy_class(file)
            if not file.url:
                file.url = await self.get_file_url(file)

            if file.type == FileType.IMAGE and "vision" in self.capability:
                _, content = await self._process_image(file, strategy)
                result.append(content)
            elif file.type == FileType.DOCUMENT:
                _, content = await self._process_document(file, strategy)
                result.append(content)
            elif file.type == FileType.AUDIO and "audio" in self.capability:
                _, content = await self._process_audio(file, strategy)
                result.append(content)
            elif file.type == FileType.VIDEO and "video" in self.capability:
                _, content = await self._process_video(file, strategy)
                result.append(content)
            else:
                logger.warning(f"不支持的文件类型: {file.type}")
        except Exception as e:
            logger.error(
                "处理文件失败",
                extra={
                    "file_index": idx,
                    "file_type": file.type,
                    "error": str(e)
                },
                exc_info=True
            )
            # Keep going with the remaining files instead of aborting the batch.
            result.append({
                "type": "text",
                "text": f"[文件处理失败: {str(e)}]"
            })

    logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
    return result
|
||||||
|
|
||||||
def write_perceptual_memory(
|
def write_perceptual_memory(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
@@ -588,12 +653,12 @@ class MultimodalService:
|
|||||||
file.set_content(file_content)
|
file.set_content(file_content)
|
||||||
file_mime_type = magic.from_buffer(file_content, mime=True)
|
file_mime_type = magic.from_buffer(file_content, mime=True)
|
||||||
if file_mime_type in TEXT_MIME:
|
if file_mime_type in TEXT_MIME:
|
||||||
return file_content.decode("utf-8")
|
return self._decode_text_safe(file_content)
|
||||||
elif file_mime_type in PDF_MIME:
|
elif file_mime_type in PDF_MIME:
|
||||||
return await self._extract_pdf_text(file_content)
|
return await self._extract_pdf_text(file_content)
|
||||||
elif file_mime_type in DOC_MIME and file.file_type.endswith(('docx', 'doc')):
|
elif self._is_word_file(file_content, file_mime_type):
|
||||||
return await self._extract_word_text(file_content)
|
return await self._extract_word_text(file_content)
|
||||||
elif file_mime_type in XLSX_MIME and file.file_type.endswith(("xlsx", "xls")):
|
elif self._is_excel_file(file_content, file_mime_type):
|
||||||
return await self._extract_xlsx_text(file_content)
|
return await self._extract_xlsx_text(file_content)
|
||||||
elif file_mime_type in CSV_MIME:
|
elif file_mime_type in CSV_MIME:
|
||||||
return await self._extract_csv_text(file_content)
|
return await self._extract_csv_text(file_content)
|
||||||
@@ -622,52 +687,156 @@ class MultimodalService:
|
|||||||
|
|
||||||
@staticmethod
async def _extract_word_text(file_content: bytes) -> str:
    """Extract text from a Word document (.docx or legacy .doc).

    Args:
        file_content: Raw bytes of the document.

    Returns:
        The extracted text. On failure a bracketed error marker string is
        returned instead of raising.
    """
    # .docx is a ZIP container, which starts with the "PK" signature.
    if file_content[:2] == b'PK':
        try:
            doc = Document(io.BytesIO(file_content))
            return '\n'.join(p.text for p in doc.paragraphs)
        except Exception as e:
            logger.error(f"提取 docx 文本失败: {e}")
            return f"[docx 提取失败: {str(e)}]"

    # Everything else is treated as a legacy OLE2 .doc container.
    try:
        import olefile
        import re

        ole = olefile.OleFileIO(io.BytesIO(file_content))
        try:
            if not ole.exists('WordDocument'):
                return "[doc 提取失败: 未找到 WordDocument 流]"
            stream = ole.openstream('WordDocument').read()
        finally:
            # Release the container on every path, including the early
            # return above and any read error.
            ole.close()

        # Crude extraction: decode the whole stream as UTF-16-LE, dropping
        # undecodable bytes (errors='ignore' never raises), then strip
        # control characters and collapse runs of spaces.
        text = stream.decode('utf-16-le', errors='ignore')
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
        text = re.sub(r' +', ' ', text).strip()
        return text
    except Exception as e:
        logger.error(f"提取 doc 文本失败: {e}")
        return f"[doc 提取失败: {str(e)}]"
|
||||||
|
|
||||||
@staticmethod
async def _extract_xlsx_text(file_content: bytes) -> str:
    """Extract text from an Excel workbook (.xlsx or legacy .xls).

    Each sheet is emitted as a ``[Sheet: <name>]`` header line followed by
    one tab-separated line per row.

    Args:
        file_content: Raw bytes of the workbook.

    Returns:
        The extracted text. On failure a bracketed error marker string is
        returned instead of raising.
    """
    # .xlsx is a ZIP container, which starts with the "PK" signature.
    if file_content[:2] == b'PK':
        try:
            wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True)
            try:
                parts = []
                for sheet in wb.worksheets:
                    parts.append(f"[Sheet: {sheet.title}]")
                    parts.extend(
                        '\t'.join('' if v is None else str(v) for v in row)
                        for row in sheet.iter_rows(values_only=True)
                    )
                return '\n'.join(parts)
            finally:
                # Read-only workbooks must be closed explicitly.
                wb.close()
        except Exception as e:
            logger.error(f"提取 xlsx 文本失败: {e}")
            return f"[xlsx 提取失败: {str(e)}]"

    # Everything else is handed to xlrd as a legacy .xls workbook.
    try:
        import xlrd

        wb = xlrd.open_workbook(file_contents=file_content)
        parts = []
        for sheet in wb.sheets():
            parts.append(f"[Sheet: {sheet.name}]")
            for row_idx in range(sheet.nrows):
                parts.append('\t'.join(str(sheet.cell_value(row_idx, col)) for col in range(sheet.ncols)))
        return '\n'.join(parts)
    except Exception as e:
        logger.error(f"提取 xls 文本失败: {e}")
        return f"[xls 提取失败: {str(e)}]"
|
||||||
|
|
||||||
async def _extract_csv_text(self, file_content: bytes) -> str:
    """Extract CSV content as tab-separated lines.

    Args:
        file_content: Raw bytes of the CSV file.

    Returns:
        One line per CSV record with fields joined by tabs. On failure a
        bracketed error marker string is returned instead of raising.
    """
    try:
        text = self._decode_text_safe(file_content)
        # newline='' passes line endings to the csv module untouched, so
        # CR-only and CRLF files are split into records by the parser
        # rather than failing on an embedded '\r'.
        reader = csv.reader(io.StringIO(text, newline=''))
        return '\n'.join('\t'.join(row) for row in reader)
    except Exception as e:
        logger.error(f"提取 CSV 文本失败: {e}")
        return f"[CSV 提取失败: {str(e)}]"
|
||||||
|
|
||||||
async def _extract_json_text(self, file_content: bytes) -> str:
    """Extract JSON content as pretty-printed text.

    The bytes are decoded with ``_decode_text_safe``, parsed, and re-serialized
    with a two-space indent and non-ASCII characters left unescaped.

    Args:
        file_content: Raw bytes of the JSON file.

    Returns:
        The re-serialized JSON text. On failure a bracketed error marker
        string is returned instead of raising.
    """
    try:
        parsed = json.loads(self._decode_text_safe(file_content))
        pretty = json.dumps(parsed, ensure_ascii=False, indent=2)
    except Exception as e:
        logger.error(f"提取 JSON 文本失败: {e}")
        return f"[JSON 提取失败: {str(e)}]"
    return pretty
|
||||||
|
|
||||||
|
def _is_word_file(self, file_content: bytes, mime_type: str) -> bool:
    """Return whether the content is a Word file (.doc / .docx), ignoring the file suffix.

    Args:
        file_content: Raw bytes of the file.
        mime_type: MIME type already detected for ``file_content``.

    Returns:
        True when ``mime_type`` is ``application/msword``, or when the bytes
        are a ZIP archive containing ``word/document.xml``; False otherwise.
    """
    # Legacy .doc is recognized by MIME type alone.
    if mime_type == 'application/msword':
        return True

    # .docx is a ZIP archive that contains word/document.xml.
    if file_content[:4] == b'PK\x03\x04':
        try:
            with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
                return "word/document.xml" in zf.namelist()
        except Exception:
            # Best-effort probe: an unreadable archive is simply "not Word".
            # Unlike a bare except, this lets KeyboardInterrupt/SystemExit
            # propagate.
            return False

    return False
|
||||||
|
|
||||||
|
def _is_excel_file(self, file_content: bytes, mime_type: str) -> bool:
    """Return whether the content is an Excel file (.xls / .xlsx), ignoring the file suffix.

    Args:
        file_content: Raw bytes of the file.
        mime_type: MIME type already detected for ``file_content``.

    Returns:
        True when ``mime_type`` is ``application/vnd.ms-excel``, or when the
        bytes are a ZIP archive containing ``xl/workbook.xml``; False otherwise.
    """
    # Legacy .xls is recognized by MIME type alone.
    if mime_type == 'application/vnd.ms-excel':
        return True

    # .xlsx is a ZIP archive that contains xl/workbook.xml.
    if file_content[:4] == b'PK\x03\x04':
        try:
            with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
                return "xl/workbook.xml" in zf.namelist()
        except Exception:
            # Best-effort probe: an unreadable archive is simply "not Excel".
            # Unlike a bare except, this lets KeyboardInterrupt/SystemExit
            # propagate.
            return False

    return False
|
||||||
|
|
||||||
|
@staticmethod
def _decode_text_safe(file_content: bytes) -> str:
    """Decode raw bytes to text, choosing the encoding automatically.

    Order of attempts:
      1. Strict UTF-8 (a leading BOM is dropped). Validation is deterministic,
         so it is trusted ahead of any statistical guess.
      2. The encoding reported by ``chardet.detect``.
      3. Common Chinese encodings, then ``latin-1``.

    ``latin-1`` accepts every byte sequence, so this function never raises;
    the price is that input in an unrecognized encoding can come back as
    wrong characters rather than an error.

    Args:
        file_content: Raw bytes to decode.

    Returns:
        The decoded text; ``""`` for empty input.
    """
    if not file_content:
        return ""

    try:
        # utf-8-sig decodes plain UTF-8 and also strips a BOM when present.
        return file_content.decode("utf-8-sig")
    except UnicodeDecodeError:
        pass

    detected = (chardet.detect(file_content).get("encoding") or "").strip().lower()

    tried = set()
    for enc in (detected, "gbk", "gb18030", "gb2312", "ascii", "latin-1"):
        # Skip an empty guess and encodings that already failed.
        if not enc or enc in tried:
            continue
        tried.add(enc)
        try:
            return file_content.decode(enc)
        except (UnicodeDecodeError, LookupError):
            continue

    # Defensive only: the latin-1 attempt above succeeds for any bytes.
    return file_content.decode("utf-8", errors="replace")
|
||||||
|
|
||||||
|
|
||||||
def get_multimodal_service(db: Session) -> MultimodalService:
|
def get_multimodal_service(db: Session) -> MultimodalService:
|
||||||
"""获取多模态服务实例(依赖注入)"""
|
"""获取多模态服务实例(依赖注入)"""
|
||||||
|
|||||||
@@ -1408,12 +1408,11 @@ async def analytics_memory_types(
|
|||||||
if end_user_id:
|
if end_user_id:
|
||||||
try:
|
try:
|
||||||
conversation_repo = ConversationRepository(db)
|
conversation_repo = ConversationRepository(db)
|
||||||
conversations = conversation_repo.get_conversation_by_user_id(
|
conversations, total = conversation_repo.get_conversation_by_user_id(
|
||||||
user_id=uuid.UUID(end_user_id),
|
user_id=uuid.UUID(end_user_id),
|
||||||
limit=100, # 获取更多会话以准确统计
|
|
||||||
is_activate=True
|
is_activate=True
|
||||||
)
|
)
|
||||||
work_count = len(conversations)
|
work_count = total
|
||||||
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
|
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}")
|
logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}")
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
|
|||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
WorkflowNodeExecutionRepository
|
WorkflowNodeExecutionRepository
|
||||||
)
|
)
|
||||||
from app.schemas import DraftRunRequest, FileInput, FileType
|
from app.schemas import DraftRunRequest, FileInput
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
@@ -55,6 +55,7 @@ class WorkflowService:
|
|||||||
edges: list[dict[str, Any]],
|
edges: list[dict[str, Any]],
|
||||||
variables: list[dict[str, Any]] | None = None,
|
variables: list[dict[str, Any]] | None = None,
|
||||||
execution_config: dict[str, Any] | None = None,
|
execution_config: dict[str, Any] | None = None,
|
||||||
|
features: dict[str, Any] | None = None,
|
||||||
triggers: list[dict[str, Any]] | None = None,
|
triggers: list[dict[str, Any]] | None = None,
|
||||||
validate: bool = True
|
validate: bool = True
|
||||||
) -> WorkflowConfig:
|
) -> WorkflowConfig:
|
||||||
@@ -66,6 +67,7 @@ class WorkflowService:
|
|||||||
edges: 边列表
|
edges: 边列表
|
||||||
variables: 变量列表
|
variables: 变量列表
|
||||||
execution_config: 执行配置
|
execution_config: 执行配置
|
||||||
|
features: 功能特性
|
||||||
triggers: 触发器列表
|
triggers: 触发器列表
|
||||||
validate: 是否验证配置
|
validate: 是否验证配置
|
||||||
|
|
||||||
@@ -81,6 +83,7 @@ class WorkflowService:
|
|||||||
"edges": edges,
|
"edges": edges,
|
||||||
"variables": variables or [],
|
"variables": variables or [],
|
||||||
"execution_config": execution_config or {},
|
"execution_config": execution_config or {},
|
||||||
|
"features": features or {},
|
||||||
"triggers": triggers or []
|
"triggers": triggers or []
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,6 +104,7 @@ class WorkflowService:
|
|||||||
edges=edges,
|
edges=edges,
|
||||||
variables=variables,
|
variables=variables,
|
||||||
execution_config=execution_config,
|
execution_config=execution_config,
|
||||||
|
features=features,
|
||||||
triggers=triggers
|
triggers=triggers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2675,13 +2675,15 @@ def write_perceptual_memory(
|
|||||||
time_limit=7200, # 2小时硬超时
|
time_limit=7200, # 2小时硬超时
|
||||||
soft_time_limit=6900,
|
soft_time_limit=6900,
|
||||||
)
|
)
|
||||||
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]:
|
||||||
"""触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。
|
"""触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。
|
||||||
|
|
||||||
由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。
|
由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。
|
||||||
|
任务完成且所有用户数据均完整时,写入 Redis 标记,避免下次重复投递。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_ids: 需要检查的用户 ID 列表
|
end_user_ids: 需要检查的用户 ID 列表
|
||||||
|
workspace_id: 工作空间 ID,用于完成标记
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含任务执行结果的字典
|
包含任务执行结果的字典
|
||||||
@@ -2707,6 +2709,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
|
|||||||
|
|
||||||
# 批量预取所有用户的配置(内置兜底:用户配置不可用时自动回退到工作空间默认配置)
|
# 批量预取所有用户的配置(内置兜底:用户配置不可用时自动回退到工作空间默认配置)
|
||||||
user_llm_map: Dict[str, Optional[str]] = {}
|
user_llm_map: Dict[str, Optional[str]] = {}
|
||||||
|
user_embedding_map: Dict[str, Optional[str]] = {}
|
||||||
try:
|
try:
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||||
@@ -2718,21 +2721,54 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
|
|||||||
try:
|
try:
|
||||||
cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
|
cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
|
||||||
user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
|
user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
|
||||||
|
user_embedding_map[uid] = str(cfg.embedding_model_id) if cfg.embedding_model_id else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}")
|
logger.warning(f"[CommunityCluster] 用户 {uid} 加载配置失败,将使用 None: {e}")
|
||||||
user_llm_map[uid] = None
|
user_llm_map[uid] = None
|
||||||
|
user_embedding_map[uid] = None
|
||||||
else:
|
else:
|
||||||
user_llm_map[uid] = None
|
user_llm_map[uid] = None
|
||||||
|
user_embedding_map[uid] = None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}")
|
logger.warning(f"[CommunityCluster] 批量获取配置失败,所有用户将使用 None: {e}")
|
||||||
|
|
||||||
for end_user_id in end_user_ids:
|
for end_user_id in end_user_ids:
|
||||||
try:
|
try:
|
||||||
# 已有社区节点则跳过
|
# 已有社区节点时,检查是否存在属性不完整的节点
|
||||||
has_communities = await repo.has_communities(end_user_id)
|
has_communities = await repo.has_communities(end_user_id)
|
||||||
if has_communities:
|
if has_communities:
|
||||||
skipped += 1
|
llm_model_id = user_llm_map.get(end_user_id)
|
||||||
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过")
|
embedding_model_id = user_embedding_map.get(end_user_id)
|
||||||
|
incomplete_ids = await repo.get_incomplete_communities(
|
||||||
|
end_user_id, check_embedding=bool(embedding_model_id)
|
||||||
|
)
|
||||||
|
if not incomplete_ids:
|
||||||
|
skipped += 1
|
||||||
|
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 社区节点均完整,跳过")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 对不完整的社区节点逐一补全元数据
|
||||||
|
engine = LabelPropagationEngine(
|
||||||
|
connector=connector,
|
||||||
|
llm_model_id=llm_model_id,
|
||||||
|
embedding_model_id=embedding_model_id,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[CommunityCluster] 用户 {end_user_id} 发现 {len(incomplete_ids)} 个属性不完整的社区,开始补全"
|
||||||
|
)
|
||||||
|
patch_ok = 0
|
||||||
|
patch_fail = 0
|
||||||
|
for cid in incomplete_ids:
|
||||||
|
try:
|
||||||
|
await engine._generate_community_metadata(cid, end_user_id)
|
||||||
|
patch_ok += 1
|
||||||
|
except Exception as patch_err:
|
||||||
|
patch_fail += 1
|
||||||
|
logger.error(f"[CommunityCluster] 社区 {cid} 元数据补全失败: {patch_err}")
|
||||||
|
logger.info(
|
||||||
|
f"[CommunityCluster] 用户 {end_user_id} 社区补全完成: 成功={patch_ok}, 失败={patch_fail}"
|
||||||
|
)
|
||||||
|
initialized += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查是否有 ExtractedEntity 节点
|
# 检查是否有 ExtractedEntity 节点
|
||||||
@@ -2742,11 +2778,13 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
|
|||||||
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
|
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 每个用户使用自己的 llm_model_id
|
# 每个用户使用自己的 llm_model_id / embedding_model_id
|
||||||
llm_model_id = user_llm_map.get(end_user_id)
|
llm_model_id = user_llm_map.get(end_user_id)
|
||||||
|
embedding_model_id = user_embedding_map.get(end_user_id)
|
||||||
engine = LabelPropagationEngine(
|
engine = LabelPropagationEngine(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
llm_model_id=llm_model_id,
|
llm_model_id=llm_model_id,
|
||||||
|
embedding_model_id=embedding_model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
|
logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
|
||||||
|
|||||||
@@ -1,4 +1,38 @@
|
|||||||
{
|
{
|
||||||
|
"v0.2.8": {
|
||||||
|
"introduction": {
|
||||||
|
"codeName": "景玉",
|
||||||
|
"releaseDate": "2026-3-20",
|
||||||
|
"upgradePosition": "🐻 MemoryBear v0.2.8 社区版全面升级应用共享、多模态交互与平台基础设施,引入语音交互、感知记忆和云端存储,打造更强大的开放 AI 记忆平台",
|
||||||
|
"coreUpgrades": [
|
||||||
|
"1. 应用共享与发布<br>* 应用共享(Agent、工作流、Agent 集群):全类型应用共享至其他空间<br>* 分享应用默认开启记忆功能:发布分享后记忆默认开启,关闭时提醒<br>* 工作流记忆分享规则:按记忆配置自动控制分享页记忆开关<br>* 分享会话联网搜索修复:恢复分享应用的联网搜索能力",
|
||||||
|
"2. 多模态与交互 💬<br>* 语音输入:模型接口和应用支持语音输入<br>* 语音回复:应用支持语音回复模态<br>* 多模态感知记忆:记忆系统支持视觉、音频、图片和文件的感知记忆<br>* 对话框文件展示:试运行和体验分享中正确展示上传文件",
|
||||||
|
"3. 平台与基础设施 ⚙️<br>* i18n 国际化:全面多语言多地区支持<br>* 云端文件存储(OSS + S3):支持阿里云 OSS 和 S3 云端上传<br>* Flower 容器监控:Celery 异步任务监控与管理",
|
||||||
|
"4. EndUser 身份迁移 🔐<br>* EndUser 从 app_id 迁移至 workspace_id:身份从应用级迁移至工作空间级",
|
||||||
|
"5. 情景记忆 🧠<br>* 情景记忆聚类算法:基于社区图谱的聚类算法,支持老用户图谱生成",
|
||||||
|
"6. 稳健性与缺陷修复 🔧<br>* MCP 服务删除后工具 404:修复删除 MCP 服务后接口报错<br>* 应用导出配置不一致:导出已保存配置而非画布状态<br>* 工作流节点 ID 重复:修复复制节点后 ID 冲突<br>* 条件分支连线错误:修复保存刷新后连线错乱<br>* 回复节点内容丢失:修复点击画布后内容消失<br>* 连接桩规则优化:禁止非法连接方向<br>* 知识库状态列宽度:锁定或自适应宽度<br>* 等待中文档预览:支持未完成解析文档预览<br>* 知识库关联修复:统一修复关联问题<br>* 多模态对话连续性:修复多模态内容后无法继续对话<br>* 时区统一:环境变量统一控制存储和任务时区<br>* 遗忘强度精度:修复小数显示过长",
|
||||||
|
"<br>",
|
||||||
|
"v0.2.8 社区版在应用共享和多模态交互方面实现重大升级,感知记忆扩展了平台的认知维度。后续将深化多智能体协作、情景记忆聚类,并持续优化平台稳定性与开放生态。",
|
||||||
|
"MemoryBear —— 让 AI 拥有记忆 🐻✨"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"introduction_en": {
|
||||||
|
"codeName": "JingYu",
|
||||||
|
"releaseDate": "2026-3-20",
|
||||||
|
"upgradePosition": "🐻 MemoryBear v0.2.8 Community delivers multimodal interaction, perceptual memory, cloud storage, and workspace-level identity for a more capable open AI memory platform",
|
||||||
|
"coreUpgrades": [
|
||||||
|
"1. Application Sharing & Publishing<br>* Application Sharing (Agent, Workflow, Agent Cluster): Full sharing across all app types<br>* Memory Enabled by Default: Memory auto-enabled on shared apps with disable reminder<br>* Workflow Memory Sharing Rules: Auto-controlled based on memory configuration<br>* Shared Session Web Search Fix: Restored web search for shared apps",
|
||||||
|
"2. Multimodal & Interaction 💬<br>* Voice Input: Model interfaces and apps support voice input<br>* Voice Reply: Apps support voice reply modality<br>* Multimodal Perceptual Memory: Memory system supports visual, audio, image, and file perception<br>* File Display in Chat: Uploaded files display correctly in dry-run and sharing",
|
||||||
|
"3. Platform & Infrastructure ⚙️<br>* i18n Internationalization: Full multi-language multi-region support<br>* Cloud File Storage (OSS + S3): Alibaba Cloud OSS and S3 cloud uploads<br>* Flower Container Monitoring: Celery async task monitoring and management",
|
||||||
|
"4. EndUser Identity Migration 🔐<br>* EndUser Migration from app_id to workspace_id: Identity migrated to workspace level",
|
||||||
|
"5. Episodic Memory 🧠<br>* Episodic Memory Clustering: Community-graph-based clustering with legacy user support",
|
||||||
|
"6. Robustness & Bug Fixes 🔧<br>* MCP Service Deletion 404: Fixed tool endpoint error after MCP removal<br>* App Export Config Mismatch: Exports saved config instead of canvas state<br>* Workflow Duplicate Node ID: Fixed ID conflict on node duplication<br>* Conditional Branch Wiring: Fixed wiring reset after save/refresh<br>* Reply Node Content Loss: Fixed content disappearing on canvas click<br>* Port Connection Rules: Prohibited invalid connection directions<br>* Knowledge Base Status Width: Locked or adaptive column width<br>* Pending Document Preview: Preview support for unparsed documents<br>* Knowledge Base Association Fixes: Consolidated association fixes<br>* Multimodal Conversation Continuity: Fixed single-round limit after multimodal input<br>* Timezone Unification: Env-var controlled unified timezone<br>* Forgetting Strength Precision: Fixed excessive decimal display",
|
||||||
|
"<br>",
|
||||||
|
"v0.2.8 Community delivers major upgrades in application sharing and multimodal interaction, with perceptual memory expanding the platform's cognitive dimensions. Multi-agent collaboration, episodic clustering, and continued platform stability improvements are ahead.",
|
||||||
|
"MemoryBear — Give AI Memory 🐻✨"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
"v0.2.7": {
|
"v0.2.7": {
|
||||||
"introduction": {
|
"introduction": {
|
||||||
"codeName": "武陵",
|
"codeName": "武陵",
|
||||||
|
|||||||
@@ -303,7 +303,7 @@ async def test_get_node_output_not_exist_with_default():
|
|||||||
"""测试获取不存在的节点输出(使用默认值)"""
|
"""测试获取不存在的节点输出(使用默认值)"""
|
||||||
pool = VariablePool()
|
pool = VariablePool()
|
||||||
|
|
||||||
result = pool.get_node_output("nonexistent_node", defalut=None, strict=False)
|
result = pool.get_node_output("nonexistent_node", default=None, strict=False)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ export const getKnowledgeBaseTypeList = async (): Promise<string[]> => {
|
|||||||
// 如果不是数组,返回空数组
|
// 如果不是数组,返回空数组
|
||||||
return [];
|
return [];
|
||||||
};
|
};
|
||||||
|
/** Build the URL of a file under `${apiPrefix}/files/` from its ID. */
export const getFileUrl = (fileId: string): string => `${apiPrefix}/files/${fileId}`;
|
||||||
// 知识库文档解析类型
|
// 知识库文档解析类型
|
||||||
export const getKnowledgeBaseDocumentParseTypeList = async () => {
|
export const getKnowledgeBaseDocumentParseTypeList = async () => {
|
||||||
const response = await request.get(`${apiPrefix}/knowledges/parsertype`);
|
const response = await request.get(`${apiPrefix}/knowledges/parsertype`);
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:00:06
|
* @Date: 2026-02-03 14:00:06
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-13 10:48:41
|
* @Last Modified time: 2026-03-19 18:35:10
|
||||||
*/
|
*/
|
||||||
import { request } from '@/utils/request'
|
import { request } from '@/utils/request'
|
||||||
import type { AxiosRequestConfig } from 'axios'
|
import type { AxiosRequestConfig } from 'axios'
|
||||||
@@ -218,8 +218,8 @@ export const getExplicitMemory = (end_user_id: string) => {
|
|||||||
export const getExplicitMemoryDetails = (data: { end_user_id: string, memory_id: string; }) => {
|
export const getExplicitMemoryDetails = (data: { end_user_id: string, memory_id: string; }) => {
|
||||||
return request.post(`/memory/explicit-memory/details`, data)
|
return request.post(`/memory/explicit-memory/details`, data)
|
||||||
}
|
}
|
||||||
/**
 * Request one page of an end user's conversations.
 * `page` defaults to 1 and `pagesize` to 20; both are passed as the second
 * argument of `request.get`.
 */
export const getConversations = (end_user_id: string, page = 1, pagesize = 20) => {
  const paging = { page, pagesize }
  return request.get(`/memory/work/${end_user_id}/conversations`, paging)
}
|
||||||
export const getConversationMessages = (end_user_id: string, conversation_id: string) => {
|
export const getConversationMessages = (end_user_id: string, conversation_id: string) => {
|
||||||
return request.get(`/memory/work/${end_user_id}/messages`, { conversation_id })
|
return request.get(`/memory/work/${end_user_id}/messages`, { conversation_id })
|
||||||
|
|||||||
@@ -143,15 +143,20 @@ const ChatContent: FC<ChatContentProps> = ({
|
|||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
<div key={file.url || file.uid} className="rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:p-1! rb:cursor-pointer" onClick={() => handleDownload(file)}>
|
<div key={file.url || file.uid} className="rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:p-1! rb:cursor-pointer" onClick={() => handleDownload(file)}>
|
||||||
{(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
|
{(file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv'))
|
||||||
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word.svg')]"
|
? <div
|
||||||
></div>}
|
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/excel.svg')]"
|
||||||
{(file.type.includes('pdf')) && <div
|
></div>
|
||||||
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf.svg')]"
|
:(file.type.includes('pdf'))
|
||||||
></div>}
|
? <div
|
||||||
{(file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv')) && <div
|
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf.svg')]"
|
||||||
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/excel.svg')]"
|
></div>
|
||||||
></div>}
|
: (file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document'))
|
||||||
|
? <div
|
||||||
|
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word.svg')]"
|
||||||
|
></div>
|
||||||
|
: null
|
||||||
|
}
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
})}
|
})}
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ interface FormValues {
|
|||||||
memory?: boolean;
|
memory?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const max_file_count = 1;
|
||||||
const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
||||||
features,
|
features,
|
||||||
leftExtra,
|
leftExtra,
|
||||||
@@ -86,10 +87,16 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
|||||||
|
|
||||||
// Append newly uploaded file to the file list when upload is complete
|
// Append newly uploaded file to the file list when upload is complete
|
||||||
const fileChange = (file?: any) => {
|
const fileChange = (file?: any) => {
|
||||||
if (file?.status !== 'done') return
|
console.log('file', file)
|
||||||
const files = [...(queryValues?.files || []), file]
|
const lastFiles = form.getFieldValue('files') || [];
|
||||||
form.setFieldValue('files', files)
|
const index = lastFiles.findIndex((item: any) => item.uid === file.uid)
|
||||||
onFilesChange?.(files)
|
if (index > -1) {
|
||||||
|
lastFiles[index] = file
|
||||||
|
} else {
|
||||||
|
lastFiles.push(file)
|
||||||
|
}
|
||||||
|
form.setFieldValue('files', [...lastFiles])
|
||||||
|
onFilesChange?.([...lastFiles])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append recorded audio file to the file list and notify parent
|
// Append recorded audio file to the file list and notify parent
|
||||||
@@ -129,8 +136,8 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
|||||||
key: 'url',
|
key: 'url',
|
||||||
label: t('memoryConversation.addRemoteFile'),
|
label: t('memoryConversation.addRemoteFile'),
|
||||||
onClick: () => {
|
onClick: () => {
|
||||||
if ((queryValues?.files?.length || 0) >= file_upload.max_file_count) {
|
if ((queryValues?.files?.length || 0) >= max_file_count) {
|
||||||
messageApi.warning(t('common.fileNumTip', { num: file_upload.max_file_count }))
|
messageApi.warning(t('common.fileNumTip', { num: max_file_count }))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
uploadFileListModalRef.current?.handleOpen()
|
uploadFileListModalRef.current?.handleOpen()
|
||||||
@@ -146,7 +153,7 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
|||||||
onChange={fileChange}
|
onChange={fileChange}
|
||||||
requestConfig={uploadRequestConfig}
|
requestConfig={uploadRequestConfig}
|
||||||
featureConfig={file_upload}
|
featureConfig={file_upload}
|
||||||
disabled={(queryValues?.files?.length || 0) >= file_upload.max_file_count}
|
disabled={(queryValues?.files?.length || 0) >= max_file_count}
|
||||||
/>
|
/>
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
@@ -184,7 +191,7 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
|
|||||||
{rightExtra}
|
{rightExtra}
|
||||||
{file_upload?.audio_enabled && file_upload?.allowed_transfer_methods?.includes('local_file') &&
|
{file_upload?.audio_enabled && file_upload?.allowed_transfer_methods?.includes('local_file') &&
|
||||||
<AudioRecorder
|
<AudioRecorder
|
||||||
disabled={(queryValues?.files?.length || 0) >= file_upload.max_file_count}
|
disabled={(queryValues?.files?.length || 0) >= max_file_count}
|
||||||
action={uploadAction}
|
action={uploadAction}
|
||||||
requestConfig={uploadRequestConfig}
|
requestConfig={uploadRequestConfig}
|
||||||
onRecordingComplete={handleRecordingComplete}
|
onRecordingComplete={handleRecordingComplete}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
* @Author: yujiangping
|
* @Author: yujiangping
|
||||||
* @Date: 2026-03-16 19:01:12
|
* @Date: 2026-03-16 19:01:12
|
||||||
* @LastEditors: yujiangping
|
* @LastEditors: yujiangping
|
||||||
* @LastEditTime: 2026-03-18 18:35:53
|
* @LastEditTime: 2026-03-20 12:12:20
|
||||||
*/
|
*/
|
||||||
import { useState, useEffect, useRef, useCallback, type FC } from 'react';
|
import { useState, useEffect, useRef, useCallback, type FC } from 'react';
|
||||||
import { Spin, Alert, Button, Table, InputNumber, Image } from 'antd';
|
import { Spin, Alert, Button, Table, InputNumber, Image } from 'antd';
|
||||||
@@ -309,23 +309,64 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const [csvTruncated, setCsvTruncated] = useState(false);
|
||||||
|
|
||||||
const isCsvFile = () => getFileExtension() === '.csv';
|
const isCsvFile = () => getFileExtension() === '.csv';
|
||||||
|
|
||||||
|
// CSV 预览大小限制:1MB
|
||||||
|
const CSV_PREVIEW_SIZE = 1 * 1024 * 1024;
|
||||||
|
// 最大预览行数
|
||||||
|
const MAX_PREVIEW_ROWS = 500;
|
||||||
|
|
||||||
|
const fetchFileBufferWithLimit = async (url: string, maxBytes?: number): Promise<ArrayBuffer> => {
|
||||||
|
const requestUrl = getRequestUrl(url);
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
'Authorization': `Bearer ${cookieUtils.get('authToken') || ''}`,
|
||||||
|
};
|
||||||
|
if (maxBytes) {
|
||||||
|
headers['Range'] = `bytes=0-${maxBytes - 1}`;
|
||||||
|
}
|
||||||
|
const response = await fetch(requestUrl, {
|
||||||
|
credentials: 'include',
|
||||||
|
headers,
|
||||||
|
});
|
||||||
|
if (!response.ok && response.status !== 206) {
|
||||||
|
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
|
||||||
|
}
|
||||||
|
return response.arrayBuffer();
|
||||||
|
};
|
||||||
|
|
||||||
const loadExcelFile = async () => {
|
const loadExcelFile = async () => {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
setError(false);
|
setError(false);
|
||||||
setErrorMessage('');
|
setErrorMessage('');
|
||||||
|
setCsvTruncated(false);
|
||||||
try {
|
try {
|
||||||
const arrayBuffer = await fetchFileBuffer(fileUrl);
|
// CSV 文件需要处理编码问题(可能是 GBK/GB2312),且大文件只取前 1MB
|
||||||
|
|
||||||
// CSV 文件需要处理编码问题(可能是 GBK/GB2312)
|
|
||||||
if (isCsvFile()) {
|
if (isCsvFile()) {
|
||||||
|
let arrayBuffer: ArrayBuffer;
|
||||||
|
let truncated = false;
|
||||||
|
try {
|
||||||
|
// 先尝试 Range 请求只取前 1MB
|
||||||
|
arrayBuffer = await fetchFileBufferWithLimit(fileUrl, CSV_PREVIEW_SIZE);
|
||||||
|
// 如果返回的数据刚好等于限制大小,说明可能被截断了
|
||||||
|
if (arrayBuffer.byteLength >= CSV_PREVIEW_SIZE) {
|
||||||
|
truncated = true;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Range 请求不支持时,全量获取后截断
|
||||||
|
const fullBuffer = await fetchFileBuffer(fileUrl);
|
||||||
|
if (fullBuffer.byteLength > CSV_PREVIEW_SIZE) {
|
||||||
|
arrayBuffer = fullBuffer.slice(0, CSV_PREVIEW_SIZE);
|
||||||
|
truncated = true;
|
||||||
|
} else {
|
||||||
|
arrayBuffer = fullBuffer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let csvText: string;
|
let csvText: string;
|
||||||
// 先尝试 UTF-8 解码
|
|
||||||
const utf8Text = new TextDecoder('utf-8').decode(arrayBuffer);
|
const utf8Text = new TextDecoder('utf-8').decode(arrayBuffer);
|
||||||
// 检测是否有乱码特征(常见的 GBK 被错误解析为 UTF-8 的替换字符)
|
|
||||||
if (utf8Text.includes('\uFFFD') || /[\x80-\xff]/.test(utf8Text.slice(0, 200))) {
|
if (utf8Text.includes('\uFFFD') || /[\x80-\xff]/.test(utf8Text.slice(0, 200))) {
|
||||||
// 尝试 GBK 解码
|
|
||||||
try {
|
try {
|
||||||
csvText = new TextDecoder('gbk').decode(arrayBuffer);
|
csvText = new TextDecoder('gbk').decode(arrayBuffer);
|
||||||
} catch {
|
} catch {
|
||||||
@@ -334,19 +375,35 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
|||||||
} else {
|
} else {
|
||||||
csvText = utf8Text;
|
csvText = utf8Text;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 如果被截断,去掉最后一行不完整的数据
|
||||||
|
if (truncated) {
|
||||||
|
const lastNewline = csvText.lastIndexOf('\n');
|
||||||
|
if (lastNewline > 0) {
|
||||||
|
csvText = csvText.substring(0, lastNewline);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const workbook = XLSX.read(csvText, { type: 'string' });
|
const workbook = XLSX.read(csvText, { type: 'string' });
|
||||||
const sheets = workbook.SheetNames.map(sheetName => {
|
const sheets = workbook.SheetNames.map(sheetName => {
|
||||||
const worksheet = workbook.Sheets[sheetName];
|
const worksheet = workbook.Sheets[sheetName];
|
||||||
const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
|
let data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
|
||||||
|
// 限制最大行数
|
||||||
|
if (data.length > MAX_PREVIEW_ROWS + 1) {
|
||||||
|
data = data.slice(0, MAX_PREVIEW_ROWS + 1); // +1 保留表头
|
||||||
|
truncated = true;
|
||||||
|
}
|
||||||
return { sheetName, data };
|
return { sheetName, data };
|
||||||
});
|
});
|
||||||
|
setCsvTruncated(truncated);
|
||||||
setExcelData(sheets);
|
setExcelData(sheets);
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const arrayBuffer = await fetchFileBuffer(fileUrl);
|
||||||
const workbook = XLSX.read(arrayBuffer, { type: 'array' });
|
const workbook = XLSX.read(arrayBuffer, { type: 'array' });
|
||||||
const sheets = workbook.SheetNames.map(sheetName => {
|
const sheets = workbook.SheetNames.map((sheetName: string) => {
|
||||||
const worksheet = workbook.Sheets[sheetName];
|
const worksheet = workbook.Sheets[sheetName];
|
||||||
const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
|
const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
|
||||||
return { sheetName, data };
|
return { sheetName, data };
|
||||||
@@ -522,9 +579,14 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
|||||||
)
|
)
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* Excel 预览 */}
|
{/* Excel/CSV 预览 */}
|
||||||
{isExcelFile() && !error && !loading && (
|
{isExcelFile() && !error && !loading && (
|
||||||
<div className="rb:w-full rb:flex-1 rb:overflow-auto rb:bg-white rb:p-4 rb:rounded rb:border rb:border-gray-200">
|
<div className="rb:w-full rb:flex-1 rb:overflow-auto rb:bg-white rb:p-4 rb:rounded rb:border rb:border-gray-200">
|
||||||
|
{csvTruncated && (
|
||||||
|
<div className="rb:mb-3 rb:px-3 rb:py-2 rb:bg-yellow-50 rb:border rb:border-yellow-200 rb:rounded rb:text-sm rb:text-yellow-700">
|
||||||
|
文件较大,仅预览前 {MAX_PREVIEW_ROWS} 行数据
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
{excelData.map((sheet, index) => (
|
{excelData.map((sheet, index) => (
|
||||||
<div key={index} className="rb:mb-6">
|
<div key={index} className="rb:mb-6">
|
||||||
<h3 className="rb:text-lg rb:font-semibold rb:mb-3">{sheet.sheetName}</h3>
|
<h3 className="rb:text-lg rb:font-semibold rb:mb-3">{sheet.sheetName}</h3>
|
||||||
@@ -541,6 +603,7 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
|
|||||||
scroll={{ x: 'max-content' }}
|
scroll={{ x: 'max-content' }}
|
||||||
size="small"
|
size="small"
|
||||||
bordered
|
bordered
|
||||||
|
virtual
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -469,6 +469,7 @@ export const en = {
|
|||||||
download: 'Download',
|
download: 'Download',
|
||||||
view: 'View',
|
view: 'View',
|
||||||
updated_at: 'Updated At',
|
updated_at: 'Updated At',
|
||||||
|
callbackUrlInvalid: 'Please enter a valid URL',
|
||||||
},
|
},
|
||||||
model: {
|
model: {
|
||||||
searchPlaceholder: 'search model…',
|
searchPlaceholder: 'search model…',
|
||||||
|
|||||||
@@ -1106,6 +1106,7 @@ export const zh = {
|
|||||||
download: '下载',
|
download: '下载',
|
||||||
view: '查看',
|
view: '查看',
|
||||||
updated_at: '更新时间',
|
updated_at: '更新时间',
|
||||||
|
callbackUrlInvalid: '请输入有效的 URL',
|
||||||
},
|
},
|
||||||
model: {
|
model: {
|
||||||
searchPlaceholder: '搜索模型…',
|
searchPlaceholder: '搜索模型…',
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ const TestChat: FC<TestChatProps> = ({
|
|||||||
|
|
||||||
const handleSend = () => {
|
const handleSend = () => {
|
||||||
if (loading || !application || !message || !message?.trim()) return
|
if (loading || !application || !message || !message?.trim()) return
|
||||||
const files = toolbarRef.current?.getFiles() || []
|
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
|
||||||
const variables = toolbarRef.current?.getVariables() || []
|
const variables = toolbarRef.current?.getVariables() || []
|
||||||
const { isCanSend, params } = buildVariableParams(variables)
|
const { isCanSend, params } = buildVariableParams(variables)
|
||||||
if (!isCanSend) return
|
if (!isCanSend) return
|
||||||
@@ -235,7 +235,7 @@ const TestChat: FC<TestChatProps> = ({
|
|||||||
|
|
||||||
const handleWorkflowSend = () => {
|
const handleWorkflowSend = () => {
|
||||||
if (loading || !application || !message || !message?.trim()) return
|
if (loading || !application || !message || !message?.trim()) return
|
||||||
const files = toolbarRef.current?.getFiles() || []
|
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
|
||||||
const variables = toolbarRef.current?.getVariables() || []
|
const variables = toolbarRef.current?.getVariables() || []
|
||||||
const { isCanSend, params } = buildVariableParams(variables)
|
const { isCanSend, params } = buildVariableParams(variables)
|
||||||
if (!isCanSend) return
|
if (!isCanSend) return
|
||||||
|
|||||||
@@ -189,7 +189,7 @@ const Chat: FC<ChatProps> = ({
|
|||||||
.then(() => {
|
.then(() => {
|
||||||
const message = msg
|
const message = msg
|
||||||
if (!message?.trim()) return
|
if (!message?.trim()) return
|
||||||
const files = toolbarRef.current?.getFiles() || []
|
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
|
||||||
// Validate required variables before sending
|
// Validate required variables before sending
|
||||||
let isCanSend = true
|
let isCanSend = true
|
||||||
const params: Record<string, any> = {}
|
const params: Record<string, any> = {}
|
||||||
@@ -350,7 +350,7 @@ const Chat: FC<ChatProps> = ({
|
|||||||
.then(() => {
|
.then(() => {
|
||||||
const message = msg
|
const message = msg
|
||||||
if (!message || message.trim() === '') return
|
if (!message || message.trim() === '') return
|
||||||
const files = toolbarRef.current?.getFiles() || []
|
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
|
||||||
addUserMessage(message, files)
|
addUserMessage(message, files)
|
||||||
setMessage(undefined)
|
setMessage(undefined)
|
||||||
toolbarRef.current?.setFiles([])
|
toolbarRef.current?.setFiles([])
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ interface FeaturesConfigModalProps {
|
|||||||
refresh: (value: FeaturesConfigForm) => void;
|
refresh: (value: FeaturesConfigForm) => void;
|
||||||
source?: Application['type'];
|
source?: Application['type'];
|
||||||
}
|
}
|
||||||
|
const max_file_count = 1;
|
||||||
/**
|
/**
|
||||||
* Modal for copying applications
|
* Modal for copying applications
|
||||||
*/
|
*/
|
||||||
@@ -133,7 +133,7 @@ const FeaturesConfigModal = forwardRef<FeaturesConfigModalRef, FeaturesConfigMod
|
|||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
<div className="rb:text-[12px] rb:text-[#5B6167] rb:py-1">{t('application.maxCount')}</div>
|
<div className="rb:text-[12px] rb:text-[#5B6167] rb:py-1">{t('application.maxCount')}</div>
|
||||||
{fu.max_file_count} {t('application.unix')}
|
{max_file_count} {t('application.unix')}
|
||||||
</div>
|
</div>
|
||||||
</Flex>
|
</Flex>
|
||||||
<Button block onClick={handleOpenSettings}>{t('application.setting')}</Button>
|
<Button block onClick={handleOpenSettings}>{t('application.setting')}</Button>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-03-05
|
* @Date: 2026-03-05
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-19 15:18:20
|
* @Last Modified time: 2026-03-19 20:19:14
|
||||||
*/
|
*/
|
||||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||||
import { Form, InputNumber, Flex, Switch, Row, Col, Radio } from 'antd';
|
import { Form, InputNumber, Flex, Switch, Row, Col, Radio } from 'antd';
|
||||||
@@ -82,28 +82,27 @@ const defaultValues: FileUpload = {
|
|||||||
"mp3",
|
"mp3",
|
||||||
"wav",
|
"wav",
|
||||||
"m4a",
|
"m4a",
|
||||||
"ogg",
|
|
||||||
"flac"
|
|
||||||
],
|
],
|
||||||
document_enabled: false,
|
document_enabled: false,
|
||||||
document_max_size_mb: 100,
|
document_max_size_mb: 100,
|
||||||
document_allowed_extensions: [
|
document_allowed_extensions: [
|
||||||
"pdf",
|
"pdf",
|
||||||
"docx",
|
"docx",
|
||||||
|
"doc",
|
||||||
"xlsx",
|
"xlsx",
|
||||||
|
"xls",
|
||||||
"txt",
|
"txt",
|
||||||
"csv",
|
"csv",
|
||||||
"json"
|
"json",
|
||||||
|
"md",
|
||||||
],
|
],
|
||||||
video_enabled: false,
|
video_enabled: false,
|
||||||
video_max_size_mb: 100,
|
video_max_size_mb: 100,
|
||||||
video_allowed_extensions: [
|
video_allowed_extensions: [
|
||||||
"mp4",
|
"mp4",
|
||||||
"mov",
|
"mov",
|
||||||
"avi",
|
|
||||||
"webm"
|
|
||||||
],
|
],
|
||||||
max_file_count: 5,
|
max_file_count: 1,
|
||||||
allowed_transfer_methods: 'both'
|
allowed_transfer_methods: 'both'
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,8 +167,8 @@ const FileUploadSettingModal = forwardRef<FileUploadSettingModalRef, FileUploadS
|
|||||||
</Radio.Group>
|
</Radio.Group>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
<div className="rb:text-[12px] rb:text-[#5B6167] rb:mb-1">{t('application.maxCount')}</div>
|
{/* <div className="rb:text-[12px] rb:text-[#5B6167] rb:mb-1">{t('application.maxCount')}</div> */}
|
||||||
<Form.Item label={t('application.maxCount')} name="max_file_count">
|
<Form.Item label={t('application.maxCount')} name="max_file_count" hidden>
|
||||||
<InputNumber min={1} max={20} precision={0} className="rb:w-full!" placeholder={t('common.pleaseEnter')} />
|
<InputNumber min={1} max={20} precision={0} className="rb:w-full!" placeholder={t('common.pleaseEnter')} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@
|
|||||||
import { useState, useEffect, forwardRef, useImperativeHandle, useMemo } from 'react';
|
import { useState, useEffect, forwardRef, useImperativeHandle, useMemo } from 'react';
|
||||||
import { Upload, Progress, App, Flex } from 'antd';
|
import { Upload, Progress, App, Flex } from 'antd';
|
||||||
import type { UploadProps, UploadFile } from 'antd';
|
import type { UploadProps, UploadFile } from 'antd';
|
||||||
import type { UploadProps as RcUploadProps } from 'antd/es/upload/interface';
|
import type { UploadProps as RcUploadProps, RcFile, UploadFileStatus } from 'antd/es/upload/interface';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { request } from '@/utils/request'
|
import { request } from '@/utils/request'
|
||||||
@@ -221,17 +221,29 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
|
|||||||
*/
|
*/
|
||||||
const handleCustomRequest: RcUploadProps['customRequest'] = async (options) => {
|
const handleCustomRequest: RcUploadProps['customRequest'] = async (options) => {
|
||||||
const { file, onSuccess, onError } = options;
|
const { file, onSuccess, onError } = options;
|
||||||
|
if (typeof file === 'string') return;
|
||||||
try {
|
const rcFile = file as RcFile;
|
||||||
const formData = new FormData();
|
const formData = new FormData();
|
||||||
formData.append('file', file);
|
formData.append('file', rcFile);
|
||||||
|
const fileVo: UploadFile = {
|
||||||
const response = await request.uploadFile(action, formData, requestConfig);
|
uid: rcFile.uid,
|
||||||
|
name: rcFile.name,
|
||||||
onSuccess?.({data: response});
|
status: 'uploading' as UploadFileStatus,
|
||||||
} catch (error) {
|
percent: 0,
|
||||||
onError?.(error as Error);
|
type: rcFile.type,
|
||||||
|
originFileObj: rcFile,
|
||||||
|
thumbUrl: URL.createObjectURL(rcFile)
|
||||||
}
|
}
|
||||||
|
onChange?.(fileVo)
|
||||||
|
request.uploadFile(action, formData, requestConfig)
|
||||||
|
.then(res => {
|
||||||
|
onSuccess?.({ data: res });
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
onError?.(error as Error);
|
||||||
|
fileVo.status = 'error'
|
||||||
|
onChange?.(fileVo)
|
||||||
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-06 21:09:47
|
* @Date: 2026-02-06 21:09:47
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-18 21:10:01
|
* @Last Modified time: 2026-03-19 20:32:32
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Upload File List Modal Component
|
* Upload File List Modal Component
|
||||||
@@ -19,7 +19,10 @@
|
|||||||
* @component
|
* @component
|
||||||
*/
|
*/
|
||||||
import { forwardRef, useImperativeHandle, useState, useMemo } from 'react';
|
import { forwardRef, useImperativeHandle, useState, useMemo } from 'react';
|
||||||
import { Form, Input, Select, Button, Flex } from 'antd';
|
import { Form, Input, Select,
|
||||||
|
// Button,
|
||||||
|
Flex
|
||||||
|
} from 'antd';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import type { UploadFileListModalRef } from '../types'
|
import type { UploadFileListModalRef } from '../types'
|
||||||
@@ -105,9 +108,11 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
|
|||||||
onOk={handleSave}
|
onOk={handleSave}
|
||||||
confirmLoading={loading}
|
confirmLoading={loading}
|
||||||
>
|
>
|
||||||
<Form form={form} layout="vertical">
|
<Form form={form} layout="vertical" initialValues={{ files: [{ type: undefined, url: undefined }] }}>
|
||||||
<Form.List name="files">
|
<Form.List name="files">
|
||||||
{(fields, { add, remove }) => (
|
{(fields,
|
||||||
|
// { add, remove }
|
||||||
|
) => (
|
||||||
<>
|
<>
|
||||||
{/* Render each file entry with type selector and URL input */}
|
{/* Render each file entry with type selector and URL input */}
|
||||||
{fields.map(({ key, name, ...restField }) => (
|
{fields.map(({ key, name, ...restField }) => (
|
||||||
@@ -116,6 +121,9 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
|
|||||||
{...restField}
|
{...restField}
|
||||||
name={[name, 'type']}
|
name={[name, 'type']}
|
||||||
className="rb:mb-0!"
|
className="rb:mb-0!"
|
||||||
|
rules={[
|
||||||
|
{ required: true, message: t('common.pleaseSelect') }
|
||||||
|
]}
|
||||||
>
|
>
|
||||||
<Select
|
<Select
|
||||||
placeholder={t('memoryConversation.fileType')}
|
placeholder={t('memoryConversation.fileType')}
|
||||||
@@ -126,22 +134,25 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
|
|||||||
<FormItem
|
<FormItem
|
||||||
{...restField}
|
{...restField}
|
||||||
name={[name, 'url']}
|
name={[name, 'url']}
|
||||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
rules={[
|
||||||
|
{ required: true, message: t('common.pleaseEnter') },
|
||||||
|
{ type: 'url', message: t('common.callbackUrlInvalid') },
|
||||||
|
]}
|
||||||
className="rb:mb-0! rb:flex-1!"
|
className="rb:mb-0! rb:flex-1!"
|
||||||
>
|
>
|
||||||
<Input placeholder={t('memoryConversation.fileUrl')} />
|
<Input placeholder={t('memoryConversation.fileUrl')} />
|
||||||
</FormItem>
|
</FormItem>
|
||||||
<div
|
{/* <div
|
||||||
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
|
||||||
onClick={() => remove(name)}
|
onClick={() => remove(name)}
|
||||||
></div>
|
></div> */}
|
||||||
</Flex>
|
</Flex>
|
||||||
))}
|
))}
|
||||||
<Form.Item noStyle>
|
{/* <Form.Item noStyle>
|
||||||
<Button type="dashed" onClick={() => add()} block>
|
<Button type="dashed" onClick={() => add()} block>
|
||||||
+ {t('common.add')}
|
+ {t('common.add')}
|
||||||
</Button>
|
</Button>
|
||||||
</Form.Item>
|
</Form.Item> */}
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</Form.List>
|
</Form.List>
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ const Conversation: FC = () => {
|
|||||||
/** Send message and handle streaming response */
|
/** Send message and handle streaming response */
|
||||||
const handleSend = () => {
|
const handleSend = () => {
|
||||||
if (!token || !shareToken) return
|
if (!token || !shareToken) return
|
||||||
const files = toolbarRef.current?.getFiles() || []
|
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
|
||||||
const variables = toolbarRef.current?.getVariables() || []
|
const variables = toolbarRef.current?.getVariables() || []
|
||||||
let isCanSend = true
|
let isCanSend = true
|
||||||
const params: Record<string, any> = {}
|
const params: Record<string, any> = {}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import { useNavigate, useParams, useLocation } from 'react-router-dom';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useBreadcrumbManager, type BreadcrumbPath } from '@/hooks/useBreadcrumbManager';
|
import { useBreadcrumbManager, type BreadcrumbPath } from '@/hooks/useBreadcrumbManager';
|
||||||
import { Button, Spin, message, Switch } from 'antd';
|
import { Button, Spin, message, Switch } from 'antd';
|
||||||
import { getDocumentDetail, getDocumentChunkList, downloadFile, updateDocument, updateDocumentChunk, createDocumentChunk } from '@/api/knowledgeBase';
|
import { getDocumentDetail, getDocumentChunkList, downloadFile, updateDocument, updateDocumentChunk, createDocumentChunk, getFileUrl } from '@/api/knowledgeBase';
|
||||||
import type { KnowledgeBaseDocumentData, RecallTestData } from '@/views/KnowledgeBase/types';
|
import type { KnowledgeBaseDocumentData, RecallTestData } from '@/views/KnowledgeBase/types';
|
||||||
import { formatDateTime } from '@/utils/format';
|
import { formatDateTime } from '@/utils/format';
|
||||||
import InfoPanel, { type InfoItem } from '../components/InfoPanel';
|
import InfoPanel, { type InfoItem } from '../components/InfoPanel';
|
||||||
@@ -138,7 +138,7 @@ const DocumentDetails: FC = () => {
|
|||||||
const response = await getDocumentDetail(documentId);
|
const response = await getDocumentDetail(documentId);
|
||||||
setDocument(response);
|
setDocument(response);
|
||||||
setInfoItems(formatDocumentInfo(response));
|
setInfoItems(formatDocumentInfo(response));
|
||||||
const url = `${imagePath}/api/files/${response.file_id}`
|
const url = `${window.location.origin}/api/files/${response.file_id}`;
|
||||||
setFileUrl(url);
|
setFileUrl(url);
|
||||||
setParserMode(response?.parser_config?.auto_questions || 0)
|
setParserMode(response?.parser_config?.auto_questions || 0)
|
||||||
// ChunkList will be called automatically in useEffect based on document.progress
|
// ChunkList will be called automatically in useEffect based on document.progress
|
||||||
|
|||||||
@@ -191,24 +191,28 @@ const RelationshipNetwork: FC = () => {
|
|||||||
})}>
|
})}>
|
||||||
{(selectedNode as RawCommunityNode).properties.community_id
|
{(selectedNode as RawCommunityNode).properties.community_id
|
||||||
? <div>
|
? <div>
|
||||||
<div className="rb:font-medium rb:text-[#212332] rb:text-[16px] rb:leading-5.5 rb:pl-1">
|
<div className="rb:font-medium rb:text-[#212332] rb:text-[16px] rb:leading-5.5 rb:pl-1">
|
||||||
{(selectedNode as RawCommunityNode).properties.name}
|
{(selectedNode as RawCommunityNode).properties.name || selectedNode.id}
|
||||||
</div>
|
</div>
|
||||||
<div className="rb:mt-3 rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.summary')}</div>
|
{(selectedNode as RawCommunityNode).properties.summary && <>
|
||||||
<div className="rb:bg-[#F6F6F6] rb:rounded-xl rb:px-3 rb:py-2.5 rb:mt-2">
|
<div className="rb:mt-3 rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.summary')}</div>
|
||||||
{(selectedNode as RawCommunityNode).properties.summary}
|
<div className="rb:bg-[#F6F6F6] rb:rounded-xl rb:px-3 rb:py-2.5 rb:mt-2">
|
||||||
</div>
|
{(selectedNode as RawCommunityNode).properties.summary}
|
||||||
<Flex align="center" justify="space-between" className="rb:mt-5!">
|
</div>
|
||||||
<span className="rb:text-[#5B6167] rb:font-regular rb:pl-1">{t('userMemory.member_count')}</span>
|
</>}
|
||||||
<span className="rb:font-medium">{(selectedNode as RawCommunityNode).properties.member_count}{t('userMemory.member_count_desc')}</span>
|
<Flex align="center" justify="space-between" className="rb:mt-5!">
|
||||||
</Flex>
|
<span className="rb:text-[#5B6167] rb:font-regular rb:pl-1">{t('userMemory.member_count')}</span>
|
||||||
|
<span className="rb:font-medium">{(selectedNode as RawCommunityNode).properties.member_count}{t('userMemory.member_count_desc')}</span>
|
||||||
|
</Flex>
|
||||||
|
|
||||||
<Divider className='rb:my-2.5!' />
|
{(selectedNode as RawCommunityNode).properties.core_entities && <>
|
||||||
<div className="rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.core_entities')}</div>
|
<Divider className='rb:my-2.5!' />
|
||||||
<ul className="rb:list-disc rb:pl-4 rb:text-[#5B6167] rb:mt-2">
|
<div className="rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.core_entities')}</div>
|
||||||
{(selectedNode as RawCommunityNode).properties.core_entities.map((entity, index) => <li key={index}>{entity}</li>)}
|
<ul className="rb:list-disc rb:pl-4 rb:text-[#5B6167] rb:mt-2">
|
||||||
</ul>
|
{(selectedNode as RawCommunityNode).properties.core_entities?.map((entity, index) => <li key={index}>{entity}</li>)}
|
||||||
</div>
|
</ul>
|
||||||
|
</>}
|
||||||
|
</div>
|
||||||
: <>
|
: <>
|
||||||
{(selectedNode as Node).name &&
|
{(selectedNode as Node).name &&
|
||||||
<div className="rb:font-medium rb:text-[16px] rb:text-[#212332] rb:leading-5.5 rb:mb-3">
|
<div className="rb:font-medium rb:text-[16px] rb:text-[#212332] rb:leading-5.5 rb:mb-3">
|
||||||
|
|||||||
@@ -4,12 +4,14 @@
|
|||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-16 15:10:17
|
* @Last Modified time: 2026-03-16 15:10:17
|
||||||
*/
|
*/
|
||||||
import { type FC, useEffect, useState, useMemo } from 'react'
|
import { type FC, useEffect, useState, useMemo, useRef } from 'react'
|
||||||
import clsx from 'clsx'
|
import clsx from 'clsx'
|
||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { useParams } from 'react-router-dom'
|
import { useParams } from 'react-router-dom'
|
||||||
import { Row, Col, Skeleton, Button, Divider, Tooltip, Flex } from 'antd'
|
import { Row, Col, Skeleton, Button, Divider, Tooltip, Flex } from 'antd'
|
||||||
|
|
||||||
|
|
||||||
|
import InfiniteScroll from 'react-infinite-scroll-component'
|
||||||
import RbCard from '@/components/RbCard/Card'
|
import RbCard from '@/components/RbCard/Card'
|
||||||
import {
|
import {
|
||||||
getConversations,
|
getConversations,
|
||||||
@@ -61,6 +63,8 @@ const WorkingDetail: FC = () => {
|
|||||||
const { id } = useParams()
|
const { id } = useParams()
|
||||||
const [loading, setLoading] = useState<boolean>(false)
|
const [loading, setLoading] = useState<boolean>(false)
|
||||||
const [data, setData] = useState<Conversation[]>([])
|
const [data, setData] = useState<Conversation[]>([])
|
||||||
|
const [hasMore, setHasMore] = useState<boolean>(true)
|
||||||
|
const pageRef = useRef<number>(1)
|
||||||
const [messagesLoading, setMessagesLoading] = useState<boolean>(false)
|
const [messagesLoading, setMessagesLoading] = useState<boolean>(false)
|
||||||
const [messages, setMessages] = useState<ChatItem[]>([])
|
const [messages, setMessages] = useState<ChatItem[]>([])
|
||||||
const [detailLoading, setDetailLoading] = useState<boolean>(false)
|
const [detailLoading, setDetailLoading] = useState<boolean>(false)
|
||||||
@@ -80,17 +84,30 @@ const WorkingDetail: FC = () => {
|
|||||||
setSelected(null)
|
setSelected(null)
|
||||||
setDetail(null)
|
setDetail(null)
|
||||||
setData([])
|
setData([])
|
||||||
getConversations(id).then((res) => {
|
setHasMore(true)
|
||||||
const response = res as Conversation[]
|
pageRef.current = 1
|
||||||
setData(response)
|
getConversations(id, 1).then((res) => {
|
||||||
setSelected(response[0] || null)
|
const response = res as { items: Conversation[], page: { hasnext: boolean } }
|
||||||
|
setData(response.items)
|
||||||
|
setSelected(response.items[0] || null)
|
||||||
|
setHasMore(response.page.hasnext)
|
||||||
})
|
})
|
||||||
.finally(() => {
|
.finally(() => {
|
||||||
setLoading(false)
|
setLoading(false)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Load messages and AI insight whenever the selected conversation changes. */
|
const loadMore = () => {
|
||||||
|
if (!id) return
|
||||||
|
const nextPage = pageRef.current + 1
|
||||||
|
getConversations(id, nextPage).then((res) => {
|
||||||
|
const response = res as {items: Conversation[], page: { hasnext: boolean }}
|
||||||
|
setData(prev => [...prev, ...response.items])
|
||||||
|
pageRef.current = nextPage
|
||||||
|
setHasMore(response.page.hasnext)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!id || !selected || !selected.id) return
|
if (!id || !selected || !selected.id) return
|
||||||
getDetail(selected.id)
|
getDetail(selected.id)
|
||||||
@@ -138,16 +155,16 @@ const WorkingDetail: FC = () => {
|
|||||||
: data.length === 0
|
: data.length === 0
|
||||||
? <Empty />
|
? <Empty />
|
||||||
:(
|
:(
|
||||||
<Row gutter={16} className="rb:h-full">
|
<Row gutter={16}>
|
||||||
<Col flex='360px' className="rb:h-full">
|
<Col span={5}>
|
||||||
<RbCard
|
<div id="conversation-list" className="rb:h-[calc(100vh-76px)]! rb:border-r rb:border-[#EAECEE] rb:py-3 rb:px-4 rb:overflow-y-auto">
|
||||||
title={t('workingDetail.conversation')}
|
<InfiniteScroll
|
||||||
headerType="borderless"
|
dataLength={data.length}
|
||||||
headerClassName="rb:min-h-[54px]! rb:font-[MiSans-Bold] rb:font-bold"
|
next={loadMore}
|
||||||
bodyClassName='rb:p-3! rb:pt-0! rb:h-[calc(100%-54px)]'
|
hasMore={hasMore}
|
||||||
className="rb:h-full!"
|
loader={null}
|
||||||
>
|
scrollableTarget="conversation-list"
|
||||||
<Flex gap={8} vertical>
|
>
|
||||||
{data.map(item => (
|
{data.map(item => (
|
||||||
<Flex
|
<Flex
|
||||||
key={item.id}
|
key={item.id}
|
||||||
@@ -166,8 +183,8 @@ const WorkingDetail: FC = () => {
|
|||||||
</Tooltip>
|
</Tooltip>
|
||||||
</Flex>
|
</Flex>
|
||||||
))}
|
))}
|
||||||
</Flex>
|
</InfiniteScroll>
|
||||||
</RbCard>
|
</div>
|
||||||
</Col>
|
</Col>
|
||||||
{selected && <>
|
{selected && <>
|
||||||
<Col flex="auto" className="rb:h-full">
|
<Col flex="auto" className="rb:h-full">
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
|
|||||||
|
|
||||||
setLoading(true)
|
setLoading(true)
|
||||||
const message = msg
|
const message = msg
|
||||||
const files = toolbarRef.current?.getFiles() || []
|
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
|
||||||
setChatList(prev => [...prev, {
|
setChatList(prev => [...prev, {
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: message,
|
content: message,
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
|||||||
const isUserInputRef = useRef(false);
|
const isUserInputRef = useRef(false);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
// 监听编辑器变化,标记是否为用户输入
|
const removeListener = editor.registerUpdateListener(({ editorState, tags }) => {
|
||||||
const removeListener = editor.registerUpdateListener(({ editorState }) => {
|
if (tags.has('programmatic')) return;
|
||||||
editorState.read(() => {
|
editorState.read(() => {
|
||||||
const root = $getRoot();
|
const root = $getRoot();
|
||||||
const textContent = root.getTextContent();
|
const textContent = root.getTextContent();
|
||||||
@@ -107,7 +107,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
|||||||
});
|
});
|
||||||
root.append(paragraph);
|
root.append(paragraph);
|
||||||
}
|
}
|
||||||
}, { discrete: true });
|
}, { discrete: true, tag: 'programmatic' });
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user