合并 feature/20251219_yjp 分支到 web 分支
冲突解决策略: - web/src/views/KnowledgeBase/ 文件夹下的所有冲突以 feature/20251219_yjp 分支为主 - 其他冲突(如 vite.config.ts)以 web 分支为主 主要更改: - 保留了 feature 分支中的知识库相关功能和组件 - 保持了 web 分支的配置和其他功能 - 添加了自定义文本数据集创建功能 - 更新了知识库管理界面
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -20,7 +20,8 @@ examples/
|
||||
.idea
|
||||
|
||||
# Temporary outputs
|
||||
**/.DS_Store
|
||||
app/core/memory/agent/.DS_Store
|
||||
app/core/memory/src/utils/.DS_Store
|
||||
time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
|
||||
@@ -27,6 +27,7 @@ from . import (
|
||||
release_share_controller,
|
||||
public_share_controller,
|
||||
multi_agent_controller,
|
||||
workflow_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -56,5 +57,6 @@ manager_router.include_router(release_share_controller.router)
|
||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(multi_agent_controller.router)
|
||||
manager_router.include_router(workflow_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""API Key 管理接口 - 基于 JWT 认证"""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -14,6 +13,7 @@ from app.core.response_utils import success
|
||||
from app.schemas import api_key_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.api_key_service import ApiKeyService
|
||||
from app.core.api_key_utils import timestamp_to_datetime
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.exceptions import (
|
||||
BusinessException,
|
||||
@@ -41,18 +41,14 @@ def create_api_key(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 创建 API Key
|
||||
api_key_obj, api_key = ApiKeyService.create_api_key(
|
||||
api_key_obj = ApiKeyService.create_api_key(
|
||||
db,
|
||||
workspace_id=workspace_id,
|
||||
user_id=current_user.id,
|
||||
data=data
|
||||
)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
response_data = api_key_schema.ApiKeyResponse.model_validate(api_key_obj)
|
||||
|
||||
return success(data=response_data, msg="API Key 创建成功")
|
||||
except BusinessException:
|
||||
@@ -223,13 +219,9 @@ def regenerate_api_key(
|
||||
"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
|
||||
api_key_obj = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
response_data = api_key_schema.ApiKeyResponse.model_validate(api_key_obj)
|
||||
|
||||
logger.info("API Key 重新生成成功", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
@@ -283,8 +275,8 @@ def get_api_key_stats(
|
||||
@cur_workspace_access_guard()
|
||||
def get_api_key_logs(
|
||||
api_key_id: uuid.UUID,
|
||||
start_date: Optional[datetime] = Query(None, description="开始日期"),
|
||||
end_date: Optional[datetime] = Query(None, description="结束日期"),
|
||||
start_date: Optional[int] = Query(None, description="开始日期时间戳"),
|
||||
end_date: Optional[int] = Query(None, description="结束日期时间戳"),
|
||||
status_code: Optional[int] = Query(None, description="HTTP状态码过滤"),
|
||||
endpoint: Optional[str] = Query(None, description="端点路径过滤"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
@@ -302,14 +294,17 @@ def get_api_key_logs(
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
start_datetime = timestamp_to_datetime(start_date) if start_date else None
|
||||
end_datetime = timestamp_to_datetime(end_date) if end_date else None
|
||||
|
||||
# 验证日期范围
|
||||
if start_date and end_date and start_date > end_date:
|
||||
if start_datetime and end_datetime and start_datetime > end_datetime:
|
||||
logger.warning("开始日期晚于结束日期", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat()
|
||||
"start_date": start_datetime.isoformat(),
|
||||
"end_date": end_datetime.isoformat()
|
||||
})
|
||||
raise BusinessException("开始日期不能晚于结束日期", BizCode.INVALID_PARAMETER)
|
||||
|
||||
@@ -325,8 +320,8 @@ def get_api_key_logs(
|
||||
|
||||
# 构建过滤条件
|
||||
filters = {
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"start_date": start_datetime,
|
||||
"end_date": end_datetime,
|
||||
"status_code": status_code,
|
||||
"endpoint": endpoint
|
||||
}
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends
|
||||
from typing import Optional, Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models import User
|
||||
from app.models.app_model import AppType, App
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
from app.services import app_service, workspace_service
|
||||
from app.services.app_service import AppService
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard, workspace_access_guard
|
||||
from fastapi.responses import StreamingResponse
|
||||
from app.models.app_model import AppType
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.app_service import AppService
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -48,7 +52,7 @@ def list_apps(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出应用
|
||||
|
||||
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
"""
|
||||
@@ -63,8 +67,8 @@ def list_apps(
|
||||
include_shared=include_shared,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
# 使用 AppService 的转换方法来设置 is_shared 字段
|
||||
service = app_service.AppService(db)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
@@ -79,14 +83,14 @@ def get_app(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用详细信息
|
||||
|
||||
|
||||
- 支持获取本工作空间的应用
|
||||
- 支持获取分享给本工作空间的应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
app = service.get_app(app_id, workspace_id)
|
||||
|
||||
|
||||
# 转换为 Schema 并设置 is_shared 字段
|
||||
app_schema_obj = service._convert_to_schema(app, workspace_id)
|
||||
return success(data=app_schema_obj)
|
||||
@@ -113,7 +117,7 @@ def delete_app(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""删除应用
|
||||
|
||||
|
||||
会级联删除:
|
||||
- Agent 配置
|
||||
- 发布版本
|
||||
@@ -128,9 +132,9 @@ def delete_app(
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
app_service.delete_app(db, app_id=app_id, workspace_id=workspace_id)
|
||||
|
||||
|
||||
return success(msg="应用删除成功")
|
||||
|
||||
|
||||
@@ -143,7 +147,7 @@ def copy_app(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""复制应用(包括基础信息和配置)
|
||||
|
||||
|
||||
- 复制应用的基础信息(名称、描述、图标等)
|
||||
- 复制 Agent 配置(如果是 agent 类型)
|
||||
- 新应用默认为草稿状态
|
||||
@@ -159,7 +163,7 @@ def copy_app(
|
||||
"new_name": new_name
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
service = AppService(db)
|
||||
new_app = service.copy_app(
|
||||
app_id=app_id,
|
||||
@@ -167,7 +171,7 @@ def copy_app(
|
||||
workspace_id=workspace_id,
|
||||
new_name=new_name
|
||||
)
|
||||
|
||||
|
||||
return success(data=app_schema.App.model_validate(new_app), msg="应用复制成功")
|
||||
|
||||
|
||||
@@ -209,9 +213,9 @@ def publish_app(
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
release = app_service.publish(
|
||||
db,
|
||||
app_id=app_id,
|
||||
publisher_id=current_user.id,
|
||||
db,
|
||||
app_id=app_id,
|
||||
publisher_id=current_user.id,
|
||||
workspace_id=workspace_id,
|
||||
version_name = payload.version_name,
|
||||
release_notes=payload.release_notes
|
||||
@@ -268,13 +272,13 @@ def share_app(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""分享应用到其他工作空间
|
||||
|
||||
|
||||
- 只能分享自己工作空间的应用
|
||||
- 不能分享到自己的工作空间
|
||||
- 同一个应用不能重复分享到同一个工作空间
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
service = app_service.AppService(db)
|
||||
shares = service.share_app(
|
||||
app_id=app_id,
|
||||
@@ -282,7 +286,7 @@ def share_app(
|
||||
user_id=current_user.id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
return success(data=data, msg=f"应用已分享到 {len(shares)} 个工作空间")
|
||||
|
||||
@@ -296,18 +300,18 @@ def unshare_app(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""取消应用分享
|
||||
|
||||
|
||||
- 只能取消自己工作空间应用的分享
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
service = app_service.AppService(db)
|
||||
service.unshare_app(
|
||||
app_id=app_id,
|
||||
target_workspace_id=target_workspace_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
|
||||
return success(msg="应用分享已取消")
|
||||
|
||||
|
||||
@@ -319,17 +323,17 @@ def list_app_shares(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出应用的所有分享记录
|
||||
|
||||
|
||||
- 只能查看自己工作空间应用的分享记录
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
service = app_service.AppService(db)
|
||||
shares = service.list_app_shares(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
return success(data=data)
|
||||
|
||||
@@ -340,10 +344,11 @@ async def draft_run(
|
||||
payload: app_schema.DraftRunRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
|
||||
):
|
||||
"""
|
||||
试运行 Agent,使用当前的草稿配置(未发布的配置)
|
||||
|
||||
|
||||
- 不需要发布应用即可测试
|
||||
- 使用当前的 AgentConfig 配置
|
||||
- 支持流式和非流式返回
|
||||
@@ -367,33 +372,44 @@ async def draft_run(
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
from app.services.app_service import AppService
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from sqlalchemy import select
|
||||
from app.core.exceptions import BusinessException
|
||||
|
||||
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
service = AppService(db)
|
||||
|
||||
draft_service = DraftRunService(db)
|
||||
|
||||
# 1. 验证应用
|
||||
app = service._get_app_or_404(app_id)
|
||||
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT:
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT and app.type != AppType.WORKFLOW:
|
||||
raise BusinessException("只有 Agent , Workflow 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
# 只读操作,允许访问共享应用
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
# 处理会话ID(创建或验证)
|
||||
conversation_id = await draft_service._ensure_conversation(
|
||||
conversation_id=payload.conversation_id,
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=payload.user_id
|
||||
)
|
||||
payload.conversation_id = conversation_id
|
||||
|
||||
if app.type == AppType.AGENT:
|
||||
service._check_agent_config(app_id)
|
||||
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
@@ -401,12 +417,12 @@ async def draft_run(
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
|
||||
|
||||
async for event in draft_service.run_stream(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
@@ -419,7 +435,7 @@ async def draft_run(
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
@@ -429,7 +445,7 @@ async def draft_run(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 非流式返回
|
||||
logger.debug(
|
||||
"开始非流式试运行",
|
||||
@@ -440,7 +456,7 @@ async def draft_run(
|
||||
"has_variables": bool(payload.variables)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
result = await draft_service.run(
|
||||
@@ -454,7 +470,7 @@ async def draft_run(
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
|
||||
logger.debug(
|
||||
"试运行返回结果",
|
||||
extra={
|
||||
@@ -462,7 +478,7 @@ async def draft_run(
|
||||
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 验证结果
|
||||
try:
|
||||
validated_result = app_schema.DraftRunResponse.model_validate(result)
|
||||
@@ -481,10 +497,10 @@ async def draft_run(
|
||||
elif app.type == AppType.MULTI_AGENT:
|
||||
# 1. 检查多智能体配置完整性
|
||||
service._check_multi_agent_config(app_id)
|
||||
|
||||
|
||||
# 2. 构建多智能体运行请求
|
||||
from app.schemas.multi_agent_schema import MultiAgentRunRequest
|
||||
|
||||
|
||||
multi_agent_request = MultiAgentRunRequest(
|
||||
message=payload.message,
|
||||
conversation_id=payload.conversation_id,
|
||||
@@ -492,7 +508,7 @@ async def draft_run(
|
||||
variables=payload.variables or {},
|
||||
use_llm_routing=True # 默认启用 LLM 路由
|
||||
)
|
||||
|
||||
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
@@ -503,11 +519,11 @@ async def draft_run(
|
||||
"has_conversation_id": bool(payload.conversation_id)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def event_generator():
|
||||
"""多智能体流式事件生成器"""
|
||||
multiservice = MultiAgentService(db)
|
||||
|
||||
|
||||
# 调用多智能体服务的流式方法
|
||||
async for event in multiservice.run_stream(
|
||||
app_id=app_id,
|
||||
@@ -517,7 +533,7 @@ async def draft_run(
|
||||
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
@@ -527,7 +543,7 @@ async def draft_run(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 4. 非流式返回
|
||||
logger.debug(
|
||||
"开始多智能体非流式试运行",
|
||||
@@ -537,10 +553,10 @@ async def draft_run(
|
||||
"has_conversation_id": bool(payload.conversation_id)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
multiservice = MultiAgentService(db)
|
||||
result = await multiservice.run(app_id, multi_agent_request)
|
||||
|
||||
|
||||
logger.debug(
|
||||
"多智能体试运行返回结果",
|
||||
extra={
|
||||
@@ -548,12 +564,71 @@ async def draft_run(
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(
|
||||
data=result,
|
||||
msg="多 Agent 任务执行成功"
|
||||
)
|
||||
|
||||
elif app.type == AppType.WORKFLOW: #工作流
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
"开始多智能体流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
"has_conversation_id": bool(payload.conversation_id)
|
||||
}
|
||||
)
|
||||
|
||||
async def event_generator():
|
||||
"""多智能体流式事件生成器"""
|
||||
multiservice = MultiAgentService(db)
|
||||
|
||||
# 调用多智能体服务的流式方法
|
||||
async for event in multiservice.run_stream(
|
||||
app_id=app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 4. 非流式返回
|
||||
logger.debug(
|
||||
"开始非流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
"has_conversation_id": bool(payload.conversation_id)
|
||||
}
|
||||
)
|
||||
|
||||
result = await workflow_service.run(app_id, payload,config)
|
||||
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
}
|
||||
)
|
||||
return success(
|
||||
data=result,
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -567,21 +642,21 @@ async def draft_run_compare(
|
||||
):
|
||||
"""
|
||||
多模型对比试运行
|
||||
|
||||
|
||||
- 支持对比 1-5 个模型
|
||||
- 可以是不同的模型,也可以是同一模型的不同参数配置
|
||||
- 通过 model_parameters 覆盖默认参数
|
||||
- 支持并行或串行执行(非流式)
|
||||
- 支持流式返回(串行执行)
|
||||
- 返回每个模型的运行结果和性能对比
|
||||
|
||||
|
||||
使用场景:
|
||||
1. 对比不同模型的效果(GPT-4 vs Claude vs Gemini)
|
||||
2. 调优模型参数(不同 temperature 的效果对比)
|
||||
3. 性能和成本分析
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
@@ -597,7 +672,7 @@ async def draft_run_compare(
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
|
||||
logger.info(
|
||||
"多模型对比试运行",
|
||||
extra={
|
||||
@@ -607,13 +682,13 @@ async def draft_run_compare(
|
||||
"stream": payload.stream
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
from app.services.app_service import AppService
|
||||
from app.models import ModelConfig
|
||||
|
||||
|
||||
service = AppService(db)
|
||||
|
||||
|
||||
# 1. 验证应用和权限
|
||||
app = service._get_app_or_404(app_id)
|
||||
if app.type != "agent":
|
||||
@@ -621,7 +696,7 @@ async def draft_run_compare(
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
from sqlalchemy import select
|
||||
from app.models import AgentConfig
|
||||
@@ -631,7 +706,7 @@ async def draft_run_compare(
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
# 3. 验证所有模型配置
|
||||
model_configs = []
|
||||
for model_item in payload.models:
|
||||
@@ -639,12 +714,12 @@ async def draft_run_compare(
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
||||
|
||||
|
||||
merged_parameters = {
|
||||
**(agent_cfg.model_parameters or {}),
|
||||
**(model_item.model_parameters or {})
|
||||
}
|
||||
|
||||
|
||||
model_configs.append({
|
||||
"model_config": model_config,
|
||||
"parameters": merged_parameters,
|
||||
@@ -652,7 +727,7 @@ async def draft_run_compare(
|
||||
"model_config_id": model_item.model_config_id,
|
||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||
})
|
||||
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
@@ -674,7 +749,7 @@ async def draft_run_compare(
|
||||
timeout=payload.timeout or 60
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
@@ -684,7 +759,7 @@ async def draft_run_compare(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 非流式返回
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
@@ -703,7 +778,7 @@ async def draft_run_compare(
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
"多模型对比完成",
|
||||
extra={
|
||||
@@ -712,5 +787,36 @@ async def draft_run_compare(
|
||||
"failed": result["failed_count"]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(data=app_schema.DraftRunCompareResponse(**result))
|
||||
|
||||
|
||||
@router.get("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def get_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
|
||||
):
|
||||
"""获取工作流配置
|
||||
|
||||
获取应用的工作流配置详情。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
|
||||
# 配置总是存在(不存在时返回默认模板)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
@router.put("/{app_id}/workflow", summary="更新 Workflow 配置")
|
||||
@cur_workspace_access_guard()
|
||||
async def update_workflow_config(
|
||||
app_id: uuid.UUID,
|
||||
payload: WorkflowConfigUpdate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@@ -29,10 +29,10 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{kb_id}/{parent_id}/documents", response_model=ApiResponse)
|
||||
@router.get("/{kb_id}/documents", response_model=ApiResponse)
|
||||
async def get_documents(
|
||||
kb_id: uuid.UUID,
|
||||
parent_id: uuid.UUID,
|
||||
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id when type is Folder"),
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
import os
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile
|
||||
from fastapi import APIRouter, Depends, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
@@ -322,36 +323,24 @@ def read_all_config(
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/pilot_run", response_model=ApiResponse) # 试运行:触发执行主管线,使用 POST 更为合理
|
||||
@router.post("/pilot_run", response_model=None)
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> StreamingResponse:
|
||||
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
|
||||
|
||||
# 先尝试从数据库加载配置
|
||||
try:
|
||||
config_loaded = reload_configuration_from_database(str(payload.config_id))
|
||||
if not config_loaded:
|
||||
api_logger.error(f"Failed to load configuration for config_id: {payload.config_id}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={payload.config_id} 的配置")
|
||||
api_logger.info(f"Configuration loaded successfully for config_id: {payload.config_id}")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Exception while loading configuration: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e))
|
||||
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = await svc.pilot_run(payload)
|
||||
return success(data=result, msg="试运行完成")
|
||||
except ValueError as e:
|
||||
# 捕获参数验证错误
|
||||
api_logger.error(f"Pilot run parameter validation failed: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "参数验证失败", str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"Pilot run failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "试运行失败", str(e))
|
||||
svc = DataConfigService(db)
|
||||
return StreamingResponse(
|
||||
svc.pilot_run_stream(payload),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
"""
|
||||
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""App 服务接口 - 基于 API Key 认证"""
|
||||
from fastapi import APIRouter, Depends
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Request, Body
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["V1 - App API"])
|
||||
logger = get_business_logger()
|
||||
@@ -14,3 +17,30 @@ logger = get_business_logger()
|
||||
async def list_apps():
|
||||
"""列出可访问的应用(占位)"""
|
||||
return success(data=[], msg="App API - Coming Soon")
|
||||
|
||||
# /v1/apps/{resource_id}/chat
|
||||
@router.post("/{resource_id}/chat")
|
||||
@require_api_key(scopes=["app"])
|
||||
async def chat_with_agent_demo(
|
||||
resource_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
):
|
||||
"""
|
||||
Agent 聊天接口demo
|
||||
|
||||
scopes: 所需的权限范围列表["app", "rag", "memory"]
|
||||
|
||||
Args:
|
||||
resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
|
||||
message: 请求参数
|
||||
request: 声明请求
|
||||
api_key_auth: 包含验证后的API Key 信息
|
||||
db: db_session
|
||||
"""
|
||||
logger.info(f"API Key Auth: {api_key_auth}")
|
||||
logger.info(f"Resource ID: {resource_id}")
|
||||
logger.info(f"Message: {message}")
|
||||
return success(data={"received": True}, msg="消息已接收")
|
||||
|
||||
587
api/app/controllers/workflow_controller.py
Normal file
587
api/app/controllers/workflow_controller.py
Normal file
@@ -0,0 +1,587 @@
|
||||
"""
|
||||
工作流 API 控制器
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.schemas.workflow_schema import (
|
||||
WorkflowConfigCreate,
|
||||
WorkflowConfigUpdate,
|
||||
WorkflowConfig,
|
||||
WorkflowValidationResponse,
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowExecutionRequest,
|
||||
WorkflowExecutionResponse
|
||||
)
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||
|
||||
|
||||
# ==================== 工作流配置管理 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def create_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""创建工作流配置
|
||||
|
||||
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 创建工作流配置
|
||||
workflow_config = service.create_workflow_config(
|
||||
app_id=app_id,
|
||||
nodes=[node.model_dump() for node in config.nodes],
|
||||
edges=[edge.model_dump() for edge in config.edges],
|
||||
variables=[var.model_dump() for var in config.variables],
|
||||
execution_config=config.execution_config.model_dump(),
|
||||
triggers=[trigger.model_dump() for trigger in config.triggers],
|
||||
validate=True # 进行基础验证
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowConfig.model_validate(workflow_config),
|
||||
msg="工作流配置创建成功"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"创建工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"创建工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
#
|
||||
# @router.get("/{app_id}/workflow")
|
||||
# async def get_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)]
|
||||
#
|
||||
# ):
|
||||
# """获取工作流配置
|
||||
#
|
||||
# 获取应用的工作流配置详情。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
#
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
#
|
||||
# # 获取工作流配置
|
||||
# service = WorkflowService(db)
|
||||
# workflow_config = service.get_workflow_config(app_id)
|
||||
#
|
||||
# if not workflow_config:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="工作流配置不存在"
|
||||
# )
|
||||
#
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config)
|
||||
# )
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"获取工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
# @router.put("/{app_id}/workflow")
|
||||
# async def update_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# config: WorkflowConfigUpdate,
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)],
|
||||
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
# ):
|
||||
# """更新工作流配置
|
||||
|
||||
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
|
||||
# # 更新工作流配置
|
||||
# workflow_config = service.update_workflow_config(
|
||||
# app_id=app_id,
|
||||
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
|
||||
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
|
||||
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
|
||||
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
|
||||
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
|
||||
# validate=True
|
||||
# )
|
||||
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config),
|
||||
# msg="工作流配置更新成功"
|
||||
# )
|
||||
|
||||
# except BusinessException as e:
|
||||
# logger.warning(f"更新工作流配置失败: {e.message}")
|
||||
# return fail(code=e.error_code, msg=e.message)
|
||||
# except Exception as e:
|
||||
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"更新工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.delete("/{app_id}/workflow")
|
||||
async def delete_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""删除工作流配置
|
||||
|
||||
删除应用的工作流配置。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 删除工作流配置
|
||||
deleted = service.delete_workflow_config(app_id)
|
||||
|
||||
if not deleted:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
return success(msg="工作流配置删除成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"删除工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{app_id}/workflow/validate")
|
||||
async def validate_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
):
|
||||
"""验证工作流配置
|
||||
|
||||
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证工作流配置
|
||||
|
||||
if for_publish:
|
||||
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
|
||||
else:
|
||||
workflow_config = service.get_workflow_config(app_id)
|
||||
if not workflow_config:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
from app.core.workflow.validator import validate_workflow_config as validate_config
|
||||
config_dict = {
|
||||
"nodes": workflow_config.nodes,
|
||||
"edges": workflow_config.edges,
|
||||
"variables": workflow_config.variables,
|
||||
"execution_config": workflow_config.execution_config,
|
||||
"triggers": workflow_config.triggers
|
||||
}
|
||||
is_valid, errors = validate_config(config_dict, for_publish=False)
|
||||
|
||||
return success(
|
||||
data=WorkflowValidationResponse(
|
||||
is_valid=is_valid,
|
||||
errors=errors,
|
||||
warnings=[]
|
||||
)
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"验证工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"验证工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工作流执行管理 ====================
|
||||
|
||||
@router.get("/{app_id}/workflow/executions")
|
||||
async def get_workflow_executions(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
):
|
||||
"""获取工作流执行记录列表
|
||||
|
||||
获取应用的工作流执行历史记录。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 获取执行记录
|
||||
executions = service.get_executions_by_app(app_id, limit, offset)
|
||||
|
||||
# 获取统计信息
|
||||
statistics = service.get_execution_statistics(app_id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"executions": [WorkflowExecution.model_validate(e) for e in executions],
|
||||
"statistics": statistics,
|
||||
"pagination": {
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"total": statistics["total"]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行记录失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/workflow/executions/{execution_id}")
|
||||
async def get_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""获取工作流执行详情
|
||||
|
||||
获取单个工作流执行的详细信息,包括所有节点的执行记录。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 获取节点执行记录
|
||||
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"execution": WorkflowExecution.model_validate(execution),
|
||||
"node_executions": [
|
||||
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行详情失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow/run")
|
||||
async def run_workflow(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""执行工作流
|
||||
|
||||
执行工作流并返回结果。支持流式和非流式两种模式。
|
||||
|
||||
**非流式模式**:等待工作流执行完成后返回完整结果。
|
||||
|
||||
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 准备输入数据
|
||||
input_data = {
|
||||
"message": request.message or "",
|
||||
"variables": request.variables
|
||||
}
|
||||
|
||||
# 执行工作流
|
||||
|
||||
if request.stream:
|
||||
# 流式执行
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
|
||||
async def event_generator():
|
||||
"""生成 SSE 事件"""
|
||||
try:
|
||||
async for event in service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 转换为 SSE 格式
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
error_event = {
|
||||
"type": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
result = await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=False
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowExecutionResponse(
|
||||
execution_id=result["execution_id"],
|
||||
status=result["status"],
|
||||
output=result.get("output"),
|
||||
output_data=result.get("output_data"),
|
||||
error_message=result.get("error_message"),
|
||||
elapsed_time=result.get("elapsed_time"),
|
||||
token_usage=result.get("token_usage")
|
||||
),
|
||||
msg="工作流执行完成"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"执行工作流失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"执行工作流失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||
async def cancel_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""取消工作流执行
|
||||
|
||||
取消正在运行的工作流执行。
|
||||
|
||||
**注意**:当前版本仅更新状态为 cancelled,实际的执行取消功能待实现。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 检查执行状态
|
||||
if execution.status not in ["pending", "running"]:
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"无法取消状态为 {execution.status} 的执行"
|
||||
)
|
||||
|
||||
# 更新状态为 cancelled
|
||||
service.update_execution_status(execution_id, "cancelled")
|
||||
|
||||
return success(msg="工作流执行已取消")
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"取消工作流执行失败: {str(e)}"
|
||||
)
|
||||
@@ -1,10 +1,12 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from functools import wraps
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.api_key_utils import add_rate_limit_headers
|
||||
@@ -22,21 +24,17 @@ logger = get_api_logger()
|
||||
|
||||
|
||||
def require_api_key(
|
||||
scopes: Optional[List[str]] = None,
|
||||
resource_type: Optional[str] = None
|
||||
scopes: Optional[List[str]] = None
|
||||
):
|
||||
"""
|
||||
API Key 鉴权装饰器
|
||||
|
||||
Args:
|
||||
scopes: 所需的权限范围列表["app:all",
|
||||
"rag:search", "rag:upload", "rag:delete",
|
||||
"memory:read", "memory:write", "memory:delete", "memory:search"]
|
||||
resource_type: 所需的资源类型("Agent", "Cluster", "Workflow", "Knowledge", "Memory_Engine")
|
||||
scopes: 所需的权限范围列表[“app”, "rag", "memory"]
|
||||
|
||||
Usage:
|
||||
@router.get("/app/{resource_id}/chat")
|
||||
@require_api_key(scopes=["app:all"], resource_type="Agent")
|
||||
@require_api_key(scopes=["app"])
|
||||
def chat_with_app(
|
||||
resource_id: uuid.UUID,
|
||||
api_key_auth: ApiKeyAuth = Depends(),
|
||||
@@ -113,31 +111,25 @@ def require_api_key(
|
||||
context={"required_scopes": scopes, "missing_scopes": missing_scopes}
|
||||
)
|
||||
|
||||
if resource_type:
|
||||
resource_id = kwargs.get("resource_id")
|
||||
if resource_id and not ApiKeyAuthService.check_resource(
|
||||
api_key_obj,
|
||||
resource_type,
|
||||
resource_id
|
||||
):
|
||||
logger.warning("API Key 资源访问被拒绝", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
"required_resource_type": resource_type,
|
||||
resource_id = kwargs.get("resource_id")
|
||||
if resource_id and not ApiKeyAuthService.check_resource(
|
||||
api_key_obj,
|
||||
resource_id
|
||||
):
|
||||
logger.warning("API Key 资源访问被拒绝", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
"required_resource_id": str(resource_id),
|
||||
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
|
||||
"endpoint": str(request.url)
|
||||
})
|
||||
return BusinessException(
|
||||
"API Key 未授权访问该资源",
|
||||
BizCode.API_KEY_INVALID_RESOURCE,
|
||||
context={
|
||||
"required_resource_id": str(resource_id),
|
||||
"bound_resource_type": api_key_obj.resource_type,
|
||||
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
|
||||
"endpoint": str(request.url)
|
||||
})
|
||||
return BusinessException(
|
||||
"API Key 未授权访问该资源",
|
||||
BizCode.API_KEY_INVALID_RESOURCE,
|
||||
context={
|
||||
"required_resource_type": resource_type,
|
||||
"required_resource_id": str(resource_id),
|
||||
"bound_resource_type": api_key_obj.resource_type,
|
||||
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None
|
||||
}
|
||||
)
|
||||
"bound_resource_id": str(api_key_obj.resource_id)
|
||||
}
|
||||
)
|
||||
|
||||
kwargs["api_key_auth"] = ApiKeyAuth(
|
||||
api_key_id=api_key_obj.id,
|
||||
@@ -145,14 +137,17 @@ def require_api_key(
|
||||
type=api_key_obj.type,
|
||||
scopes=api_key_obj.scopes,
|
||||
resource_id=api_key_obj.resource_id,
|
||||
resource_type=api_key_obj.resource_type
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
response = await func(*args, **kwargs)
|
||||
end_time = time.perf_counter()
|
||||
response_time = (end_time - start_time) * 1000
|
||||
if not isinstance(response, Response):
|
||||
response = JSONResponse(content=response)
|
||||
response = add_rate_limit_headers(response, rate_headers)
|
||||
|
||||
asyncio.create_task(log_api_key_usage(
|
||||
db, api_key_obj.id, request, response
|
||||
db, api_key_obj.id, request, response, response_time
|
||||
))
|
||||
return response
|
||||
|
||||
@@ -204,7 +199,8 @@ async def log_api_key_usage(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
request: Request,
|
||||
response: Response
|
||||
response: Response,
|
||||
response_time: float
|
||||
):
|
||||
"""记录 API Key 使用日志"""
|
||||
try:
|
||||
@@ -216,8 +212,8 @@ async def log_api_key_usage(
|
||||
"ip_address": request.client.host if request.client else None,
|
||||
"user_agent": request.headers.get("User-Agent"),
|
||||
"status_code": response.status_code if hasattr(response, "status_code") else None,
|
||||
"response_time": None, # 需要在 middleware 中计算
|
||||
"tokens_used": None, # 需要从响应中提取
|
||||
"response_time": round(response_time),
|
||||
"tokens_used": None,
|
||||
"created_at": datetime.now()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,33 +1,14 @@
|
||||
"""API Key 工具函数"""
|
||||
import secrets
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from app.schemas.api_key_schema import ApiKeyType
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
class ResourceType:
|
||||
"""资源类型常量"""
|
||||
AGENT = "Agent"
|
||||
CLUSTER = "Cluster"
|
||||
WORKFLOW = "Workflow"
|
||||
KNOWLEDGE = "Knowledge"
|
||||
MEMORY_ENGINE = "Memory_Engine"
|
||||
|
||||
@classmethod
|
||||
def get_all_types(cls) -> list[str]:
|
||||
"""获取所有支持的资源类型"""
|
||||
return [cls.AGENT, cls.CLUSTER, cls.WORKFLOW, cls.KNOWLEDGE, cls.MEMORY_ENGINE]
|
||||
|
||||
@classmethod
|
||||
def is_valid_type(cls, resource_type: str) -> bool:
|
||||
"""验证资源类型是否有效"""
|
||||
return resource_type in cls.get_all_types()
|
||||
|
||||
|
||||
def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
|
||||
def generate_api_key(key_type: ApiKeyType) -> str:
|
||||
"""
|
||||
生成 API Key
|
||||
|
||||
@@ -39,102 +20,17 @@ def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
|
||||
"""
|
||||
# 前缀映射
|
||||
prefix_map = {
|
||||
ApiKeyType.APP: "sk-app-",
|
||||
ApiKeyType.RAG: "sk-rag-",
|
||||
ApiKeyType.MEMORY: "sk-mem-",
|
||||
ApiKeyType.AGENT: "sk-agent-",
|
||||
ApiKeyType.CLUSTER: "sk-multi_agent-",
|
||||
ApiKeyType.WORKFLOW: "sk-workflow-",
|
||||
ApiKeyType.SERVICE: "sk-service-"
|
||||
}
|
||||
|
||||
prefix = prefix_map[key_type]
|
||||
random_string = secrets.token_urlsafe(32)[:32] # 32 字符
|
||||
api_key = f"{prefix}{random_string}"
|
||||
|
||||
# 生成哈希值存储
|
||||
key_hash = hash_api_key(api_key)
|
||||
|
||||
return api_key, key_hash, prefix
|
||||
|
||||
|
||||
def hash_api_key(api_key: str) -> str:
|
||||
"""对 API Key 进行哈希
|
||||
|
||||
Args:
|
||||
api_key: API Key 明文
|
||||
|
||||
Returns:
|
||||
str: 哈希值
|
||||
"""
|
||||
return hashlib.sha256(api_key.encode()).hexdigest()
|
||||
|
||||
|
||||
def verify_api_key(api_key: str, key_hash: str) -> bool:
|
||||
"""
|
||||
验证 API Key
|
||||
|
||||
Args:
|
||||
api_key: API Key 明文
|
||||
key_hash: 存储的哈希值
|
||||
|
||||
Returns:
|
||||
bool: 是否匹配
|
||||
"""
|
||||
computed_hash = hash_api_key(api_key)
|
||||
return secrets.compare_digest(computed_hash, key_hash)
|
||||
|
||||
|
||||
def validate_resource_binding(
|
||||
resource_type: Optional[str],
|
||||
resource_id: Optional[str]
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
验证资源绑定的有效性
|
||||
|
||||
Args:
|
||||
resource_type: 资源类型
|
||||
resource_id: 资源ID
|
||||
|
||||
Returns:
|
||||
tuple: (是否有效, 错误信息)
|
||||
"""
|
||||
# 如果都为空,表示不绑定资源,这是有效的
|
||||
if not resource_type and not resource_id:
|
||||
return True, ""
|
||||
|
||||
# 如果只有一个为空,这是无效的
|
||||
if not resource_type or not resource_id:
|
||||
return False, "resource_type 和 resource_id 必须同时提供或同时为空"
|
||||
|
||||
# 验证资源类型是否支持
|
||||
if not ResourceType.is_valid_type(resource_type):
|
||||
valid_types = ", ".join(ResourceType.get_all_types())
|
||||
return False, f"不支持的资源类型 '{resource_type}',支持的类型:{valid_types}"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
def get_resource_scope_mapping() -> dict[str, list[str]]:
|
||||
"""
|
||||
获取资源类型与权限范围的映射关系
|
||||
|
||||
Returns:
|
||||
dict: 资源类型到推荐权限范围的映射
|
||||
"""
|
||||
return {
|
||||
ResourceType.AGENT: [
|
||||
"app:all"
|
||||
],
|
||||
ResourceType.CLUSTER: [
|
||||
"app:all"
|
||||
],
|
||||
ResourceType.WORKFLOW: [
|
||||
"app:all"
|
||||
],
|
||||
ResourceType.KNOWLEDGE: [
|
||||
"rag:search", "rag:upload", "rag:delete"
|
||||
],
|
||||
ResourceType.MEMORY_ENGINE: [
|
||||
"memory:read", "memory:write", "memory:delete", "memory:search"
|
||||
]
|
||||
}
|
||||
return api_key
|
||||
|
||||
|
||||
def add_rate_limit_headers(response, headers: dict):
|
||||
@@ -151,3 +47,21 @@ def add_rate_limit_headers(response, headers: dict):
|
||||
return response
|
||||
|
||||
|
||||
def timestamp_to_datetime(timestamp: Optional[Union[int, float]]) -> Optional[datetime]:
|
||||
"""将时间戳转换为datetime对象"""
|
||||
if timestamp is None:
|
||||
return None
|
||||
|
||||
# 处理毫秒级时间戳
|
||||
if timestamp > 1e10:
|
||||
timestamp = timestamp / 1000
|
||||
|
||||
return datetime.fromtimestamp(timestamp)
|
||||
|
||||
|
||||
def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
|
||||
"""将datetime对象转换为时间戳(毫秒)"""
|
||||
if dt is None:
|
||||
return None
|
||||
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
@@ -59,6 +59,7 @@ class BizCode(IntEnum):
|
||||
EMBED_NOT_ALLOWED = 6009
|
||||
PERMISSION_DENIED = 6010
|
||||
INVALID_CONVERSATION = 6011
|
||||
CONFIG_MISSING = 6012
|
||||
|
||||
# 模型(7xxx)
|
||||
MODEL_CONFIG_INVALID = 7001
|
||||
@@ -96,7 +97,7 @@ HTTP_MAPPING = {
|
||||
BizCode.TOKEN_INVALID: 401,
|
||||
BizCode.TOKEN_EXPIRED: 401,
|
||||
BizCode.TOKEN_BLACKLISTED: 401,
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.TENANT_NOT_FOUND: 404,
|
||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||
BizCode.NOT_FOUND: 404,
|
||||
@@ -151,4 +152,4 @@ HTTP_MAPPING = {
|
||||
BizCode.DB_ERROR: 500,
|
||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||
BizCode.RATE_LIMITED: 429,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,6 +106,8 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
|
||||
all_statement_chunk_edges,
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
all_dedup_details,
|
||||
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"selections": {
|
||||
"config_id": "1"
|
||||
"config_id": ""
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ os.environ["LANGCHAIN_TRACING"] = "false"
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Optional, Callable, Awaitable
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 导入重构后的模块
|
||||
@@ -50,7 +50,11 @@ logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
async def main(
|
||||
dialogue_text: Optional[str] = None,
|
||||
is_pilot_run: bool = False,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
|
||||
):
|
||||
"""
|
||||
记忆系统主流程 - 重构版本
|
||||
|
||||
@@ -61,6 +65,12 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
is_pilot_run: 是否为试运行模式
|
||||
- True: 试运行模式,不保存到 Neo4j
|
||||
- False: 正常运行模式,保存到 Neo4j
|
||||
progress_callback: 可选的进度回调函数
|
||||
- 类型: Callable[[str, str, Optional[dict]], Awaitable[None]]
|
||||
- 参数1 (stage): 当前处理阶段标识符
|
||||
- 参数2 (message): 人类可读的进度消息
|
||||
- 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果
|
||||
- 在管线关键点调用以报告进度和结果数据
|
||||
|
||||
工作流程:
|
||||
1. 初始化客户端和配置
|
||||
@@ -141,6 +151,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"}
|
||||
)
|
||||
|
||||
# 进度回调:开始预处理文本
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
# 对前端传入的对话进行分块处理
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
@@ -148,6 +162,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dialog in chunked_dialogs:
|
||||
for i, chunk in enumerate(dialog.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dialog.id,
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
# 进度回调:预处理文本完成
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
else:
|
||||
# 正常运行模式:从 testdata.json 文件加载
|
||||
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
|
||||
@@ -159,6 +194,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
if not os.path.exists(test_data_path):
|
||||
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
|
||||
|
||||
# 进度回调:开始预处理文本
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
|
||||
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
|
||||
group_id=config_defs.SELECTED_GROUP_ID,
|
||||
@@ -170,6 +209,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
skip_cleaning=True,
|
||||
)
|
||||
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dialog in chunked_dialogs:
|
||||
for i, chunk in enumerate(dialog.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dialog.id,
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
# 进度回调:预处理文本完成
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
@@ -188,6 +248,7 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback, # 传递进度回调
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
@@ -196,6 +257,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
|
||||
# 进度回调:正在知识抽取
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=is_pilot_run, # 传递试运行模式标志
|
||||
@@ -216,6 +282,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
# 进度回调:生成结果
|
||||
if progress_callback:
|
||||
await progress_callback("generating_results", "正在生成结果...")
|
||||
|
||||
|
||||
# 步骤 5: 保存结果或输出结果
|
||||
if is_pilot_run:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
去重功能函数
|
||||
"""
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
from typing import List, Dict, Tuple
|
||||
from typing import List, Dict, Tuple, Any
|
||||
from app.core.memory.models.graph_models import(
|
||||
StatementEntityEdge,
|
||||
EntityEntityEdge,
|
||||
@@ -895,7 +895,12 @@ async def deduplicate_entities_and_edges(
|
||||
report_append: bool = False,
|
||||
report_stage_notes: List[str] | None = None,
|
||||
dedup_config: DedupConfig | None = None,
|
||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||
) -> Tuple[
|
||||
List[ExtractedEntityNode],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
Dict[str, Any] # 新增:返回详细的去重消歧记录
|
||||
]:
|
||||
"""
|
||||
主流程:依次执行精确匹配、模糊匹配与(可选)LLM 决策融合,随后对边做重定向与去重。之后再处理边,是关系去重和消歧
|
||||
返回:去重后的实体、语句→实体边、实体↔实体边。
|
||||
@@ -981,8 +986,18 @@ async def deduplicate_entities_and_edges(
|
||||
append=report_append,
|
||||
stage_notes=report_stage_notes,
|
||||
)
|
||||
|
||||
# 构建详细的去重消歧记录(用于内存访问,避免解析日志文件)
|
||||
dedup_details = {
|
||||
"exact_merge_map": exact_merge_map,
|
||||
"fuzzy_merge_records": fuzzy_merge_records,
|
||||
"llm_decision_records": local_llm_records,
|
||||
"disamb_records": disamb_records,
|
||||
"id_redirect": id_redirect,
|
||||
"blocked_pairs": blocked_pairs,
|
||||
}
|
||||
|
||||
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values())
|
||||
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values()), dedup_details
|
||||
|
||||
# 独立模块:去重融合报告写入(与实体/边的计算解耦)
|
||||
def _write_dedup_fusion_report(
|
||||
|
||||
@@ -39,6 +39,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
dict, # 新增:返回去重详情
|
||||
]:
|
||||
"""
|
||||
执行两层实体去重与融合:
|
||||
@@ -62,7 +63,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
break
|
||||
|
||||
# 第一层去重消歧
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
|
||||
entity_nodes,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
@@ -103,4 +104,5 @@ async def dedup_layers_and_merge_and_return(
|
||||
statement_chunk_edges,
|
||||
fused_statement_entity_edges,
|
||||
fused_entity_entity_edges,
|
||||
dedup_details, # 返回去重详情
|
||||
)
|
||||
|
||||
@@ -12,13 +12,14 @@
|
||||
5. 提供错误处理和日志记录
|
||||
6. 支持试运行模式(不写入数据库)
|
||||
|
||||
作者:Memory Refactoring Team
|
||||
作者:
|
||||
日期:2025-11-21
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
import os
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable, Awaitable
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
@@ -94,6 +95,7 @@ class ExtractionOrchestrator:
|
||||
embedder_client: OpenAIEmbedderClient,
|
||||
connector: Neo4jConnector,
|
||||
config: Optional[ExtractionPipelineConfig] = None,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||
):
|
||||
"""
|
||||
初始化流水线编排器
|
||||
@@ -103,12 +105,21 @@ class ExtractionOrchestrator:
|
||||
embedder_client: 嵌入模型客户端
|
||||
connector: Neo4j 连接器
|
||||
config: 流水线配置,如果为 None 则使用默认配置
|
||||
progress_callback: 进度回调函数
|
||||
- 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None]
|
||||
- 在管线关键点调用以报告进度和结果数据
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.embedder_client = embedder_client
|
||||
self.connector = connector
|
||||
self.config = config or ExtractionPipelineConfig()
|
||||
self.is_pilot_run = False # 默认非试运行模式
|
||||
self.progress_callback = progress_callback # 保存进度回调函数
|
||||
|
||||
# 保存去重消歧的详细记录(内存中的数据结构)
|
||||
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
|
||||
self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录
|
||||
self.id_redirect_map: Dict[str, str] = {} # ID重定向映射
|
||||
|
||||
# 初始化各个提取器
|
||||
self.statement_extractor = StatementExtractor(
|
||||
@@ -160,6 +171,13 @@ class ExtractionOrchestrator:
|
||||
# 步骤 1: 陈述句提取
|
||||
logger.info("步骤 1/6: 陈述句提取(全局分块级并行)")
|
||||
dialog_data_list = await self._extract_statements(dialog_data_list)
|
||||
|
||||
# 收集陈述句内容和统计数量
|
||||
all_statements_list = []
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
all_statements_list.extend(chunk.statements)
|
||||
total_statements = len(all_statements_list)
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成")
|
||||
@@ -170,11 +188,90 @@ class ExtractionOrchestrator:
|
||||
chunk_embedding_maps,
|
||||
dialog_embeddings,
|
||||
) = await self._parallel_extract_and_embed(dialog_data_list)
|
||||
|
||||
# 收集实体和三元组内容,并统计数量
|
||||
all_entities_list = []
|
||||
all_triplets_list = []
|
||||
for triplet_map in triplet_maps:
|
||||
for triplet_info in triplet_map.values():
|
||||
if triplet_info:
|
||||
all_entities_list.extend(triplet_info.entities)
|
||||
all_triplets_list.extend(triplet_info.triplets)
|
||||
|
||||
total_entities = len(all_entities_list)
|
||||
total_triplets = len(all_triplets_list)
|
||||
total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps)
|
||||
|
||||
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
||||
logger.info("步骤 3/6: 生成实体嵌入")
|
||||
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
||||
|
||||
# 进度回调:按三个阶段分别输出知识抽取结果
|
||||
if self.progress_callback:
|
||||
# 第一阶段:陈述句提取结果
|
||||
for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句
|
||||
stmt_result = {
|
||||
"extraction_type": "statement",
|
||||
"statement_index": i + 1,
|
||||
"statement": stmt.statement,
|
||||
"statement_id": stmt.id
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result)
|
||||
|
||||
# 第二阶段:三元组提取结果
|
||||
for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组
|
||||
triplet_result = {
|
||||
"extraction_type": "triplet",
|
||||
"triplet_index": i + 1,
|
||||
"subject": triplet.subject_name,
|
||||
"predicate": triplet.predicate,
|
||||
"object": triplet.object_name
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result)
|
||||
|
||||
# 第三阶段:时间提取结果
|
||||
if total_temporal > 0:
|
||||
# 收集时间信息
|
||||
temporal_results = []
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
temporal_results.append({
|
||||
"statement_id": statement.id,
|
||||
"statement": statement.statement,
|
||||
"valid_at": statement.temporal_validity.valid_at,
|
||||
"invalid_at": statement.temporal_validity.invalid_at
|
||||
})
|
||||
|
||||
# 输出时间提取结果
|
||||
for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果
|
||||
time_result = {
|
||||
"extraction_type": "temporal",
|
||||
"temporal_index": i + 1,
|
||||
"statement": temporal_result["statement"],
|
||||
"valid_at": temporal_result["valid_at"],
|
||||
"invalid_at": temporal_result["invalid_at"]
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||
else:
|
||||
# 如果没有时间信息,也发送一个时间提取完成的消息
|
||||
time_result = {
|
||||
"extraction_type": "temporal",
|
||||
"temporal_index": 0,
|
||||
"message": "未发现时间信息"
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
|
||||
|
||||
# 进度回调:知识抽取完成,传递知识抽取的统计信息
|
||||
extraction_stats = {
|
||||
"statements_count": total_statements,
|
||||
"entities_count": total_entities,
|
||||
"triplets_count": total_triplets,
|
||||
"temporal_ranges_count": total_temporal,
|
||||
}
|
||||
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
|
||||
|
||||
# 步骤 4: 将提取的数据赋值到语句
|
||||
logger.info("步骤 4/6: 数据赋值")
|
||||
dialog_data_list = await self._assign_extracted_data(
|
||||
@@ -218,6 +315,8 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
|
||||
|
||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||
return result
|
||||
|
||||
@@ -732,6 +831,10 @@ class ExtractionOrchestrator:
|
||||
包含所有节点和边的元组
|
||||
"""
|
||||
logger.info("开始创建节点和边")
|
||||
|
||||
# 进度回调:正在创建节点和边
|
||||
if self.progress_callback:
|
||||
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
|
||||
|
||||
dialogue_nodes = []
|
||||
chunk_nodes = []
|
||||
@@ -904,6 +1007,23 @@ class ExtractionOrchestrator:
|
||||
f"陈述句-实体边: {len(statement_entity_edges)}, "
|
||||
f"实体-实体边: {len(entity_entity_edges)}"
|
||||
)
|
||||
|
||||
# 进度回调:只输出关系创建结果
|
||||
if self.progress_callback:
|
||||
# 输出关系创建结果
|
||||
await self._output_relationship_creation_results(entity_entity_edges, entity_nodes)
|
||||
|
||||
# 进度回调:创建节点和边完成,传递结果统计
|
||||
nodes_edges_stats = {
|
||||
"dialogue_nodes_count": len(dialogue_nodes),
|
||||
"chunk_nodes_count": len(chunk_nodes),
|
||||
"statement_nodes_count": len(statement_nodes),
|
||||
"entity_nodes_count": len(entity_nodes),
|
||||
"statement_chunk_edges_count": len(statement_chunk_edges),
|
||||
"statement_entity_edges_count": len(statement_entity_edges),
|
||||
"entity_entity_edges_count": len(entity_entity_edges),
|
||||
}
|
||||
await self.progress_callback("creating_nodes_edges_complete", "创建节点和边完成", nodes_edges_stats)
|
||||
|
||||
return (
|
||||
dialogue_nodes,
|
||||
@@ -950,6 +1070,11 @@ class ExtractionOrchestrator:
|
||||
- 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表)
|
||||
"""
|
||||
logger.info("开始两阶段实体去重和消歧")
|
||||
|
||||
# 进度回调:正在去重消歧
|
||||
if self.progress_callback:
|
||||
await self.progress_callback("deduplication", "正在去重消歧...")
|
||||
|
||||
logger.info(
|
||||
f"去重前: {len(entity_nodes)} 个实体节点, "
|
||||
f"{len(statement_entity_edges)} 条陈述句-实体边, "
|
||||
@@ -963,7 +1088,7 @@ class ExtractionOrchestrator:
|
||||
# 只执行第一层去重
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
|
||||
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
|
||||
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
|
||||
entity_nodes,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
@@ -972,6 +1097,9 @@ class ExtractionOrchestrator:
|
||||
dedup_config=self.config.deduplication,
|
||||
)
|
||||
|
||||
# 保存去重消歧的详细记录到实例变量
|
||||
self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes)
|
||||
|
||||
result_tuple = (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -1009,7 +1137,11 @@ class ExtractionOrchestrator:
|
||||
_,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dedup_details,
|
||||
) = result_tuple
|
||||
|
||||
# 保存去重消歧的详细记录到实例变量
|
||||
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
|
||||
|
||||
logger.info(
|
||||
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
||||
@@ -1021,6 +1153,46 @@ class ExtractionOrchestrator:
|
||||
f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, "
|
||||
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
|
||||
)
|
||||
|
||||
# 进度回调:输出去重消歧的具体结果
|
||||
if self.progress_callback:
|
||||
# 分析实体合并情况
|
||||
merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes)
|
||||
|
||||
# 输出去重合并的实体示例
|
||||
for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果
|
||||
dedup_result = {
|
||||
"result_type": "entity_merge",
|
||||
"merged_entity_name": merge_detail["main_entity_name"],
|
||||
"merged_count": merge_detail["merged_count"],
|
||||
"message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result)
|
||||
|
||||
# 分析实体消歧情况
|
||||
disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes)
|
||||
|
||||
# 输出实体消歧的结果
|
||||
for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果
|
||||
disamb_result = {
|
||||
"result_type": "entity_disambiguation",
|
||||
"disambiguated_entity_name": disamb_detail["entity_name"],
|
||||
"disambiguation_type": disamb_detail["disamb_type"],
|
||||
"confidence": disamb_detail.get("confidence", "unknown"),
|
||||
"reason": disamb_detail.get("reason", ""),
|
||||
"message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result)
|
||||
|
||||
|
||||
|
||||
# 进度回调:去重消歧完成,传递去重和消歧的具体效果
|
||||
await self._send_dedup_progress_callback(
|
||||
len(entity_nodes), len(final_entity_nodes),
|
||||
len(statement_entity_edges), len(final_statement_entity_edges),
|
||||
len(entity_entity_edges), len(final_entity_entity_edges)
|
||||
)
|
||||
|
||||
|
||||
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
||||
try:
|
||||
@@ -1041,6 +1213,378 @@ class ExtractionOrchestrator:
|
||||
logger.error(f"两阶段去重失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _save_dedup_details(
|
||||
self,
|
||||
dedup_details: Dict[str, Any],
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
):
|
||||
"""
|
||||
保存去重消歧的详细记录到实例变量(基于内存数据结构)
|
||||
|
||||
Args:
|
||||
dedup_details: 去重函数返回的详细记录
|
||||
original_entities: 去重前的实体列表
|
||||
final_entities: 去重后的实体列表
|
||||
"""
|
||||
try:
|
||||
# 保存ID重定向映射
|
||||
self.id_redirect_map = dedup_details.get("id_redirect", {})
|
||||
|
||||
# 处理精确匹配的合并记录
|
||||
exact_merge_map = dedup_details.get("exact_merge_map", {})
|
||||
for key, info in exact_merge_map.items():
|
||||
merged_ids = info.get("merged_ids", set())
|
||||
if merged_ids:
|
||||
self.dedup_merge_records.append({
|
||||
"type": "精确匹配",
|
||||
"canonical_id": info.get("canonical_id"),
|
||||
"entity_name": info.get("name"),
|
||||
"entity_type": info.get("entity_type"),
|
||||
"merged_count": len(merged_ids),
|
||||
"merged_ids": list(merged_ids)
|
||||
})
|
||||
|
||||
# 处理模糊匹配的合并记录
|
||||
fuzzy_merge_records = dedup_details.get("fuzzy_merge_records", [])
|
||||
for record in fuzzy_merge_records:
|
||||
# 解析模糊匹配记录字符串
|
||||
# 格式: "[模糊] 规范实体 id (group|name|type) <- 合并实体 id (group|name|type) | s_name=0.xxx, ..."
|
||||
try:
|
||||
import re
|
||||
match = re.search(r"规范实体 (\S+) \(([^|]+)\|([^|]+)\|([^)]+)\) <- 合并实体 (\S+)", record)
|
||||
if match:
|
||||
self.dedup_merge_records.append({
|
||||
"type": "模糊匹配",
|
||||
"canonical_id": match.group(1),
|
||||
"entity_name": match.group(3),
|
||||
"entity_type": match.group(4),
|
||||
"merged_count": 1,
|
||||
"merged_ids": [match.group(5)]
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"解析模糊匹配记录失败: {record}, 错误: {e}")
|
||||
|
||||
# 处理LLM去重的合并记录
|
||||
llm_decision_records = dedup_details.get("llm_decision_records", [])
|
||||
for record in llm_decision_records:
|
||||
if "[LLM去重]" in str(record):
|
||||
try:
|
||||
import re
|
||||
# 格式: "[LLM去重] 同名类型相似 name1(type1)|name2(type2) | conf=0.xx | reason=..."
|
||||
match = re.search(r"同名类型相似 ([^(]+)(([^)]+))\|([^(]+)(([^)]+))", record)
|
||||
if match:
|
||||
self.dedup_merge_records.append({
|
||||
"type": "LLM去重",
|
||||
"entity_name": match.group(1),
|
||||
"entity_type": f"{match.group(2)}|{match.group(4)}",
|
||||
"merged_count": 1,
|
||||
"merged_ids": []
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"解析LLM去重记录失败: {record}, 错误: {e}")
|
||||
|
||||
# 处理消歧记录
|
||||
disamb_records = dedup_details.get("disamb_records", [])
|
||||
for record in disamb_records:
|
||||
if "[DISAMB阻断]" in str(record):
|
||||
try:
|
||||
import re
|
||||
# 格式: "[DISAMB阻断] name1(type1)|name2(type2) | conf=0.xx | reason=..."
|
||||
content = str(record).replace("[DISAMB阻断]", "").strip()
|
||||
match = re.search(r"([^(]+)(([^)]+))\|([^(]+)(([^)]+))", content)
|
||||
if match:
|
||||
entity1_name = match.group(1).strip()
|
||||
entity1_type = match.group(2)
|
||||
entity2_name = match.group(3).strip()
|
||||
entity2_type = match.group(4)
|
||||
|
||||
# 提取置信度和原因
|
||||
conf_match = re.search(r"conf=([0-9.]+)", str(record))
|
||||
confidence = conf_match.group(1) if conf_match else "unknown"
|
||||
|
||||
reason_match = re.search(r"reason=([^|]+)", str(record))
|
||||
reason = reason_match.group(1).strip() if reason_match else ""
|
||||
|
||||
self.dedup_disamb_records.append({
|
||||
"entity_name": entity1_name,
|
||||
"disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}",
|
||||
"confidence": confidence,
|
||||
"reason": reason[:100] + "..." if len(reason) > 100 else reason
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
|
||||
|
||||
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
|
||||
|
||||
async def _analyze_entity_merges(
|
||||
self,
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
|
||||
|
||||
Args:
|
||||
original_entities: 去重前的实体列表
|
||||
final_entities: 去重后的实体列表
|
||||
|
||||
Returns:
|
||||
合并详情列表,每个元素包含主实体名称和合并数量
|
||||
"""
|
||||
try:
|
||||
# 直接使用保存的合并记录
|
||||
if self.dedup_merge_records:
|
||||
# 按合并数量排序,返回前几个
|
||||
sorted_records = sorted(
|
||||
self.dedup_merge_records,
|
||||
key=lambda x: x.get("merged_count", 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
merge_info = []
|
||||
for record in sorted_records:
|
||||
merge_info.append({
|
||||
"main_entity_name": record.get("entity_name", "未知实体"),
|
||||
"merged_count": record.get("merged_count", 1)
|
||||
})
|
||||
|
||||
return merge_info
|
||||
|
||||
# 如果没有保存的记录,返回空列表
|
||||
logger.info("未找到实体合并记录")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析实体合并情况失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def _analyze_entity_disambiguation(
|
||||
self,
|
||||
original_entities: List[ExtractedEntityNode],
|
||||
final_entities: List[ExtractedEntityNode]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
|
||||
|
||||
Args:
|
||||
original_entities: 去重前的实体列表
|
||||
final_entities: 去重后的实体列表
|
||||
|
||||
Returns:
|
||||
消歧详情列表,每个元素包含实体名称和消歧类型
|
||||
"""
|
||||
try:
|
||||
# 直接使用保存的消歧记录
|
||||
if self.dedup_disamb_records:
|
||||
return self.dedup_disamb_records
|
||||
|
||||
# 如果没有保存的记录,返回空列表
|
||||
logger.info("未找到实体消歧记录")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析实体消歧情况失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def _get_entity_type_display_name(self, entity_type: str) -> str:
|
||||
"""
|
||||
获取实体类型的中文显示名称
|
||||
|
||||
Args:
|
||||
entity_type: 英文实体类型
|
||||
|
||||
Returns:
|
||||
中文显示名称
|
||||
"""
|
||||
type_mapping = {
|
||||
"Person": "人物实体节点",
|
||||
"Organization": "组织实体节点",
|
||||
"ORG": "组织实体节点",
|
||||
"Location": "地点实体节点",
|
||||
"LOC": "地点实体节点",
|
||||
"Event": "事件实体节点",
|
||||
"Concept": "概念实体节点",
|
||||
"Time": "时间实体节点",
|
||||
"Position": "职位实体节点",
|
||||
"WorkRole": "职业实体节点",
|
||||
"System": "系统实体节点",
|
||||
"Policy": "政策实体节点",
|
||||
"HistoricalPeriod": "历史时期实体节点",
|
||||
"HistoricalState": "历史国家实体节点",
|
||||
"HistoricalEvent": "历史事件实体节点",
|
||||
"EconomicFactor": "经济因素实体节点",
|
||||
"Condition": "条件实体节点",
|
||||
"Numeric": "数值实体节点"
|
||||
}
|
||||
return type_mapping.get(entity_type, f"{entity_type}实体节点")
|
||||
|
||||
async def _output_relationship_creation_results(
|
||||
self,
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
):
|
||||
"""
|
||||
输出关系创建结果
|
||||
|
||||
Args:
|
||||
entity_entity_edges: 实体-实体边列表
|
||||
entity_nodes: 实体节点列表
|
||||
"""
|
||||
try:
|
||||
# 创建实体ID到名称的映射
|
||||
entity_id_to_name = {node.id: node.name for node in entity_nodes}
|
||||
|
||||
# 输出关系创建结果
|
||||
for i, edge in enumerate(entity_entity_edges[:10]): # 只输出前10个关系
|
||||
source_name = entity_id_to_name.get(edge.source, f"Entity_{edge.source}")
|
||||
target_name = entity_id_to_name.get(edge.target, f"Entity_{edge.target}")
|
||||
relation_type = edge.relation_type
|
||||
|
||||
relationship_result = {
|
||||
"result_type": "relationship_creation",
|
||||
"relationship_index": i + 1,
|
||||
"source_entity": source_name,
|
||||
"relation_type": relation_type,
|
||||
"target_entity": target_name,
|
||||
"relationship_text": f"{source_name} -[{relation_type}]-> {target_name}"
|
||||
}
|
||||
|
||||
await self.progress_callback("creating_nodes_edges_result", "关系创建", relationship_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
|
||||
|
||||
async def _send_dedup_progress_callback(
|
||||
self,
|
||||
original_entities: int,
|
||||
final_entities: int,
|
||||
original_stmt_edges: int,
|
||||
final_stmt_edges: int,
|
||||
original_ent_edges: int,
|
||||
final_ent_edges: int,
|
||||
):
|
||||
"""
|
||||
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
|
||||
|
||||
Args:
|
||||
original_entities: 去重前实体数量
|
||||
final_entities: 去重后实体数量
|
||||
original_stmt_edges: 去重前陈述句-实体边数量
|
||||
final_stmt_edges: 去重后陈述句-实体边数量
|
||||
original_ent_edges: 去重前实体-实体边数量
|
||||
final_ent_edges: 去重后实体-实体边数量
|
||||
"""
|
||||
try:
|
||||
# 解析去重消歧报告文件,获取具体的去重和消歧效果
|
||||
dedup_details = await self._parse_dedup_report()
|
||||
|
||||
# 计算去重效果统计
|
||||
entities_reduced = original_entities - final_entities
|
||||
stmt_edges_reduced = original_stmt_edges - final_stmt_edges
|
||||
ent_edges_reduced = original_ent_edges - final_ent_edges
|
||||
|
||||
# 构建进度回调数据
|
||||
dedup_stats = {
|
||||
"entities": {
|
||||
"original_count": original_entities,
|
||||
"final_count": final_entities,
|
||||
"reduced_count": entities_reduced,
|
||||
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0,
|
||||
},
|
||||
"statement_entity_edges": {
|
||||
"original_count": original_stmt_edges,
|
||||
"final_count": final_stmt_edges,
|
||||
"reduced_count": stmt_edges_reduced,
|
||||
},
|
||||
"entity_entity_edges": {
|
||||
"original_count": original_ent_edges,
|
||||
"final_count": final_ent_edges,
|
||||
"reduced_count": ent_edges_reduced,
|
||||
},
|
||||
"dedup_examples": dedup_details.get("dedup_examples", []),
|
||||
"disamb_examples": dedup_details.get("disamb_examples", []),
|
||||
"summary": {
|
||||
"total_merges": dedup_details.get("total_merges", 0),
|
||||
"total_disambiguations": dedup_details.get("total_disambiguations", 0),
|
||||
}
|
||||
}
|
||||
|
||||
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True)
|
||||
# 即使解析失败,也发送基本的统计信息
|
||||
try:
|
||||
basic_stats = {
|
||||
"entities": {
|
||||
"original_count": original_entities,
|
||||
"final_count": final_entities,
|
||||
"reduced_count": original_entities - final_entities,
|
||||
},
|
||||
"summary": f"实体去重合并{original_entities - final_entities}个"
|
||||
}
|
||||
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", basic_stats)
|
||||
except Exception as e2:
|
||||
logger.error(f"发送基本去重统计失败: {e2}", exc_info=True)
|
||||
|
||||
async def _parse_dedup_report(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取去重消歧报告,直接使用内存中的记录(不再解析日志文件)
|
||||
|
||||
Returns:
|
||||
包含去重和消歧详细信息的字典
|
||||
"""
|
||||
try:
|
||||
# 直接使用保存的记录构建报告
|
||||
dedup_examples = []
|
||||
disamb_examples = []
|
||||
total_merges = 0
|
||||
total_disambiguations = 0
|
||||
|
||||
# 处理合并记录
|
||||
for record in self.dedup_merge_records:
|
||||
merge_count = record.get("merged_count", 0)
|
||||
total_merges += merge_count
|
||||
|
||||
dedup_examples.append({
|
||||
"type": record.get("type", "未知"),
|
||||
"entity_name": record.get("entity_name", "未知实体"),
|
||||
"entity_type": record.get("entity_type", "未知类型"),
|
||||
"merge_count": merge_count,
|
||||
"description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}个"
|
||||
})
|
||||
|
||||
# 处理消歧记录
|
||||
for record in self.dedup_disamb_records:
|
||||
total_disambiguations += 1
|
||||
|
||||
# 从消歧类型中提取实体类型信息
|
||||
disamb_type = record.get("disamb_type", "")
|
||||
entity_name = record.get("entity_name", "未知实体")
|
||||
|
||||
disamb_examples.append({
|
||||
"entity1_name": entity_name,
|
||||
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知",
|
||||
"entity2_name": entity_name,
|
||||
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
|
||||
"description": f"{entity_name},消歧区分成功"
|
||||
})
|
||||
|
||||
return {
|
||||
"dedup_examples": dedup_examples[:5], # 只返回前5个示例
|
||||
"disamb_examples": disamb_examples[:5], # 只返回前5个示例
|
||||
"total_merges": total_merges,
|
||||
"total_disambiguations": total_disambiguations,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取去重报告失败: {e}", exc_info=True)
|
||||
return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 数据加载和预处理函数
|
||||
|
||||
436
api/app/core/workflow/executor.py
Normal file
436
api/app/core/workflow/executor.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""
|
||||
工作流执行器
|
||||
|
||||
基于 LangGraph 的工作流执行引擎。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
|
||||
from app.db import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowExecutor:
|
||||
"""工作流执行器
|
||||
|
||||
负责将工作流配置转换为 LangGraph 并执行。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""初始化执行器
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
"""
|
||||
self.workflow_config = workflow_config
|
||||
self.execution_id = execution_id
|
||||
self.workspace_id = workspace_id
|
||||
self.user_id = user_id
|
||||
self.nodes = workflow_config.get("nodes", [])
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||
"""准备初始状态(注入系统变量和会话变量)
|
||||
|
||||
变量命名空间:
|
||||
- sys.xxx - 系统变量(execution_id, workspace_id, user_id, message, input_variables 等)
|
||||
- conv.xxx - 会话变量(跨多轮对话保持)
|
||||
- node_id.xxx - 节点输出(执行时动态生成)
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
|
||||
Returns:
|
||||
初始化的工作流状态
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
conversation_vars = input_data.get("conversation_vars") or {}
|
||||
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||
|
||||
# 构建分层的变量结构
|
||||
variables = {
|
||||
"sys": {
|
||||
"message": user_message, # 用户消息
|
||||
"conversation_id": input_data.get("conversation_id"), # 会话 ID
|
||||
"execution_id": self.execution_id, # 执行 ID
|
||||
"workspace_id": self.workspace_id, # 工作空间 ID
|
||||
"user_id": self.user_id, # 用户 ID
|
||||
"input_variables": input_variables, # 自定义输入变量(给 Start 节点使用)
|
||||
},
|
||||
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [HumanMessage(content=user_message)],
|
||||
"variables": variables,
|
||||
"node_outputs": {},
|
||||
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
|
||||
"execution_id": self.execution_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None
|
||||
}
|
||||
|
||||
|
||||
|
||||
def build_graph(self) -> StateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
Returns:
|
||||
编译后的状态图
|
||||
"""
|
||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||
|
||||
# 1. 创建状态图
|
||||
workflow = StateGraph(WorkflowState)
|
||||
|
||||
# 2. 添加所有节点(包括 start 和 end)
|
||||
start_node_id = None
|
||||
end_node_ids = []
|
||||
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id")
|
||||
|
||||
# 记录 start 和 end 节点 ID
|
||||
if node_type == "start":
|
||||
start_node_id = node_id
|
||||
elif node_type == "end":
|
||||
end_node_ids.append(node_id)
|
||||
|
||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||
if node_instance:
|
||||
# 包装节点的 run 方法
|
||||
# 使用函数工厂避免闭包问题
|
||||
def make_node_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
return node_func
|
||||
|
||||
workflow.add_node(node_id, make_node_func(node_instance))
|
||||
logger.debug(f"添加节点: {node_id} (type={node_type})")
|
||||
|
||||
# 3. 添加边
|
||||
# 从 START 连接到 start 节点
|
||||
if start_node_id:
|
||||
workflow.add_edge(START, start_node_id)
|
||||
logger.debug(f"添加边: START -> {start_node_id}")
|
||||
|
||||
for edge in self.edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
condition = edge.get("condition")
|
||||
|
||||
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
||||
if source == start_node_id:
|
||||
# 但要连接 start 到下一个节点
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
# 处理到 end 节点的边
|
||||
if target in end_node_ids:
|
||||
# 连接到 end 节点
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
continue
|
||||
|
||||
# 跳过错误边(在节点内部处理)
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
if condition:
|
||||
# 条件边
|
||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
||||
"""条件路由函数"""
|
||||
if evaluate_condition(
|
||||
cond,
|
||||
state.get("variables", {}),
|
||||
state.get("node_outputs", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
):
|
||||
return tgt
|
||||
return END # 条件不满足,结束
|
||||
|
||||
workflow.add_conditional_edges(source, router)
|
||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
||||
else:
|
||||
# 普通边
|
||||
workflow.add_edge(source, target)
|
||||
logger.debug(f"添加边: {source} -> {target}")
|
||||
|
||||
# 从 end 节点连接到 END
|
||||
for end_node_id in end_node_ids:
|
||||
workflow.add_edge(end_node_id, END)
|
||||
logger.debug(f"添加边: {end_node_id} -> END")
|
||||
|
||||
# 4. 编译图
|
||||
graph = workflow.compile()
|
||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||
|
||||
return graph
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""执行工作流(非流式)
|
||||
|
||||
Args:
|
||||
input_data: 输入数据,包含 message 和 variables
|
||||
|
||||
Returns:
|
||||
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
|
||||
"""
|
||||
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
|
||||
|
||||
# 记录开始时间
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph()
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
# 3. 执行工作流
|
||||
try:
|
||||
result = await graph.ainvoke(initial_state)
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
# 提取节点输出(现在包含 start 和 end 节点)
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
|
||||
# 提取最终输出(从最后一个非 start/end 节点)
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
|
||||
# 聚合 token 使用情况
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
|
||||
# 提取 conversation_id(从 start 节点输出)
|
||||
conversation_id = None
|
||||
for node_id, node_output in node_outputs.items():
|
||||
if node_output.get("node_type") == "start":
|
||||
conversation_id = node_output.get("output", {}).get("conversation_id")
|
||||
break
|
||||
|
||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"output": None,
|
||||
"node_outputs": {},
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": None
|
||||
}
|
||||
|
||||
async def execute_stream(
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
):
|
||||
"""执行工作流(流式)
|
||||
|
||||
Args:
|
||||
input_data: 输入数据
|
||||
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
|
||||
|
||||
# 1. 构建图
|
||||
graph = self.build_graph()
|
||||
|
||||
# 2. 初始化状态(自动注入系统变量)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
|
||||
# 3. 流式执行工作流
|
||||
try:
|
||||
# 使用 astream 获取节点级别的更新
|
||||
async for event in graph.astream(initial_state, stream_mode="updates"):
|
||||
for node_name, state_update in event.items():
|
||||
yield {
|
||||
"type": "node_complete",
|
||||
"node": node_name,
|
||||
"data": state_update,
|
||||
"execution_id": self.execution_id
|
||||
}
|
||||
|
||||
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
|
||||
|
||||
# 发送完成事件
|
||||
yield {
|
||||
"type": "workflow_complete",
|
||||
"execution_id": self.execution_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
yield {
|
||||
"type": "workflow_error",
|
||||
"execution_id": self.execution_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
||||
"""从节点输出中提取最终输出
|
||||
|
||||
优先级:
|
||||
1. 最后一个执行的非 start/end 节点的 output
|
||||
2. 如果没有节点输出,返回 None
|
||||
|
||||
Args:
|
||||
node_outputs: 所有节点的输出
|
||||
|
||||
Returns:
|
||||
最终输出字符串或 None
|
||||
"""
|
||||
if not node_outputs:
|
||||
return None
|
||||
|
||||
# 获取最后一个节点的输出
|
||||
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
|
||||
|
||||
if last_node_output and isinstance(last_node_output, dict):
|
||||
return last_node_output.get("output")
|
||||
|
||||
return None
|
||||
|
||||
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
"""聚合所有节点的 token 使用情况
|
||||
|
||||
Args:
|
||||
node_outputs: 所有节点的输出
|
||||
|
||||
Returns:
|
||||
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
|
||||
如果没有 token 使用信息,返回 None
|
||||
"""
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
has_token_info = False
|
||||
|
||||
for node_output in node_outputs.values():
|
||||
if isinstance(node_output, dict):
|
||||
token_usage = node_output.get("token_usage")
|
||||
if token_usage and isinstance(token_usage, dict):
|
||||
has_token_info = True
|
||||
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||
total_tokens += token_usage.get("total_tokens", 0)
|
||||
|
||||
if not has_token_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
|
||||
|
||||
async def execute_workflow(
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""执行工作流(便捷函数)
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
input_data: 输入数据
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
)
|
||||
return await executor.execute(input_data)
|
||||
|
||||
|
||||
async def execute_workflow_stream(
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""执行工作流(流式,便捷函数)
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
input_data: 输入数据
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
)
|
||||
async for event in executor.execute_stream(input_data):
|
||||
yield event
|
||||
195
api/app/core/workflow/expression_evaluator.py
Normal file
195
api/app/core/workflow/expression_evaluator.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
安全的表达式求值器
|
||||
|
||||
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExpressionEvaluator:
|
||||
"""安全的表达式求值器"""
|
||||
|
||||
# 保留的命名空间
|
||||
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""安全地评估表达式
|
||||
|
||||
Args:
|
||||
expression: 表达式字符串,如 "{{var.score}} > 0.8"
|
||||
variables: 用户定义的变量
|
||||
node_outputs: 节点输出结果
|
||||
system_vars: 系统变量
|
||||
|
||||
Returns:
|
||||
表达式求值结果
|
||||
|
||||
Raises:
|
||||
ValueError: 表达式无效或求值失败
|
||||
|
||||
Examples:
|
||||
>>> evaluator = ExpressionEvaluator()
|
||||
>>> evaluator.evaluate(
|
||||
... "var.score > 0.8",
|
||||
... {"score": 0.9},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
|
||||
>>> evaluator.evaluate(
|
||||
... "node.intent.output == '售前咨询'",
|
||||
... {},
|
||||
... {"intent": {"output": "售前咨询"}},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
"""
|
||||
# 移除 Jinja2 模板语法的花括号(如果存在)
|
||||
expression = expression.strip()
|
||||
if expression.startswith("{{") and expression.endswith("}}"):
|
||||
expression = expression[2:-2].strip()
|
||||
|
||||
# 构建命名空间上下文
|
||||
context = {
|
||||
"var": variables, # 用户变量
|
||||
"node": node_outputs, # 节点输出
|
||||
"sys": system_vars or {}, # 系统变量
|
||||
}
|
||||
|
||||
# 为了向后兼容,也支持直接访问(但会在日志中警告)
|
||||
context.update(variables)
|
||||
context["nodes"] = node_outputs
|
||||
|
||||
try:
|
||||
# simpleeval 只支持安全的操作:
|
||||
# - 算术运算: +, -, *, /, //, %, **
|
||||
# - 比较运算: ==, !=, <, <=, >, >=
|
||||
# - 逻辑运算: and, or, not
|
||||
# - 成员运算: in, not in
|
||||
# - 属性访问: obj.attr
|
||||
# - 字典/列表访问: obj["key"], obj[0]
|
||||
# 不支持:函数调用、导入、赋值等危险操作
|
||||
result = simple_eval(expression, names=context)
|
||||
return result
|
||||
|
||||
except NameNotDefined as e:
|
||||
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
|
||||
raise ValueError(f"未定义的变量: {e}")
|
||||
|
||||
except InvalidExpression as e:
|
||||
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式语法无效: {e}")
|
||||
|
||||
except SyntaxError as e:
|
||||
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式语法错误: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
|
||||
raise ValueError(f"表达式求值失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def evaluate_bool(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""评估布尔表达式(用于条件判断)
|
||||
|
||||
Args:
|
||||
expression: 布尔表达式
|
||||
variables: 用户变量
|
||||
node_outputs: 节点输出
|
||||
system_vars: 系统变量
|
||||
|
||||
Returns:
|
||||
布尔值结果
|
||||
|
||||
Examples:
|
||||
>>> ExpressionEvaluator.evaluate_bool(
|
||||
... "var.count >= 10 and var.status == 'active'",
|
||||
... {"count": 15, "status": "active"},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
True
|
||||
"""
|
||||
result = ExpressionEvaluator.evaluate(
|
||||
expression, variables, node_outputs, system_vars
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
@staticmethod
|
||||
def validate_variable_names(variables: list[dict]) -> list[str]:
|
||||
"""验证变量名是否合法
|
||||
|
||||
Args:
|
||||
variables: 变量定义列表
|
||||
|
||||
Returns:
|
||||
错误列表,如果为空则验证通过
|
||||
|
||||
Examples:
|
||||
>>> ExpressionEvaluator.validate_variable_names([
|
||||
... {"name": "user_input"},
|
||||
... {"name": "var"} # 保留字
|
||||
... ])
|
||||
["变量名 'var' 是保留的命名空间,请使用其他名称"]
|
||||
"""
|
||||
errors = []
|
||||
|
||||
for var in variables:
|
||||
var_name = var.get("name", "")
|
||||
|
||||
# 检查是否为保留命名空间
|
||||
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
|
||||
errors.append(
|
||||
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
|
||||
)
|
||||
|
||||
# 检查是否为有效的 Python 标识符
|
||||
if not var_name.isidentifier():
|
||||
errors.append(
|
||||
f"变量名 '{var_name}' 不是有效的标识符"
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def evaluate_expression(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""评估表达式(便捷函数)"""
|
||||
return ExpressionEvaluator.evaluate(
|
||||
expression, variables, node_outputs, system_vars
|
||||
)
|
||||
|
||||
|
||||
def evaluate_condition(
|
||||
expression: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""评估条件表达式(便捷函数)"""
|
||||
return ExpressionEvaluator.evaluate_bool(
|
||||
expression, variables, node_outputs, system_vars
|
||||
)
|
||||
24
api/app/core/workflow/nodes/__init__.py
Normal file
24
api/app/core/workflow/nodes/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
工作流节点实现
|
||||
|
||||
提供各种类型的节点实现,用于工作流执行。
|
||||
"""
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
"WorkflowState",
|
||||
"LLMNode",
|
||||
"AgentNode",
|
||||
"TransformNode",
|
||||
"StartNode",
|
||||
"EndNode",
|
||||
"NodeFactory",
|
||||
]
|
||||
6
api/app/core/workflow/nodes/agent/__init__.py
Normal file
6
api/app/core/workflow/nodes/agent/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Agent 节点"""
|
||||
|
||||
from app.core.workflow.nodes.agent.node import AgentNode
|
||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||
|
||||
__all__ = ["AgentNode", "AgentNodeConfig"]
|
||||
71
api/app/core/workflow/nodes/agent/config.py
Normal file
71
api/app/core/workflow/nodes/agent/config.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Agent 节点配置"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class AgentNodeConfig(BaseNodeConfig):
|
||||
"""Agent 节点配置
|
||||
|
||||
调用已配置的 Agent 执行任务。
|
||||
"""
|
||||
|
||||
agent_id: str = Field(
|
||||
...,
|
||||
description="Agent 配置 ID"
|
||||
)
|
||||
|
||||
message: str = Field(
|
||||
default="{{ sys.message }}",
|
||||
description="发送给 Agent 的消息,支持模板变量"
|
||||
)
|
||||
|
||||
conversation_id: str | None = Field(
|
||||
default=None,
|
||||
description="会话 ID,用于多轮对话"
|
||||
)
|
||||
|
||||
variables: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
description="传递给 Agent 的变量"
|
||||
)
|
||||
|
||||
timeout: int = Field(
|
||||
default=300,
|
||||
ge=1,
|
||||
le=3600,
|
||||
description="超时时间(秒)"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="output",
|
||||
type=VariableType.STRING,
|
||||
description="Agent 的回复内容"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="conversation_id",
|
||||
type=VariableType.STRING,
|
||||
description="会话 ID"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="token_usage",
|
||||
type=VariableType.OBJECT,
|
||||
description="Token 使用情况"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"agent_id": "uuid-here",
|
||||
"message": "{{ sys.message }}",
|
||||
"timeout": 300,
|
||||
"description": "调用客服 Agent"
|
||||
}
|
||||
}
|
||||
152
api/app/core/workflow/nodes/agent/node.py
Normal file
152
api/app/core/workflow/nodes/agent/node.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Agent 节点实现
|
||||
|
||||
调用已发布的 Agent 应用。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.models import AppRelease
|
||||
from app.db import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentNode(BaseNode):
|
||||
"""Agent 节点
|
||||
|
||||
支持流式和非流式输出。
|
||||
|
||||
配置示例:
|
||||
{
|
||||
"type": "agent",
|
||||
"config": {
|
||||
"agent_id": "uuid", # Agent 的 release_id
|
||||
"message": "{{var.user_input}}"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
|
||||
"""准备 Agent(公共逻辑)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
(draft_service, release, message): 服务实例、发布配置、消息
|
||||
"""
|
||||
# 1. 渲染消息
|
||||
message_template = self.config.get("message", "")
|
||||
message = self._render_template(message_template, state)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
agent_id = self.config.get("agent_id")
|
||||
if not agent_id:
|
||||
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
|
||||
|
||||
db = next(get_db())
|
||||
release = db.query(AppRelease).filter(
|
||||
AppRelease.id == agent_id
|
||||
).first()
|
||||
|
||||
if not release:
|
||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||
|
||||
draft_service = DraftRunService(db)
|
||||
|
||||
return draft_service, release, message
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""非流式执行
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
状态更新字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(state)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
||||
|
||||
# 执行 Agent(非流式)
|
||||
result = await draft_service.run(
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=state.get("workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=state.get("variables", {})
|
||||
)
|
||||
|
||||
response = result.get("response", "")
|
||||
|
||||
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(response)}")
|
||||
|
||||
return {
|
||||
"messages": [AIMessage(content=response)],
|
||||
"node_outputs": {
|
||||
self.node_id: {
|
||||
"output": response,
|
||||
"status": "completed",
|
||||
"meta_data": result.get("meta_data", {})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Yields:
|
||||
流式事件字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(state)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
|
||||
# 执行 Agent(流式)
|
||||
async for chunk in draft_service.run_stream(
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=state.get("workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=state.get("variables", {})
|
||||
):
|
||||
# 提取内容
|
||||
content = chunk.get("content", "")
|
||||
full_response += content
|
||||
|
||||
# 流式返回每个 chunk
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": content,
|
||||
"full_content": full_response,
|
||||
"meta_data": chunk.get("meta_data", {})
|
||||
}
|
||||
|
||||
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
|
||||
|
||||
# 最后返回完整结果
|
||||
yield {
|
||||
"type": "complete",
|
||||
"messages": [AIMessage(content=full_response)],
|
||||
"node_outputs": {
|
||||
self.node_id: {
|
||||
"output": full_response,
|
||||
"status": "completed"
|
||||
}
|
||||
}
|
||||
}
|
||||
109
api/app/core/workflow/nodes/base_config.py
Normal file
109
api/app/core/workflow/nodes/base_config.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""节点配置基类
|
||||
|
||||
定义所有节点配置的通用字段和数据结构。
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VariableType(StrEnum):
|
||||
"""变量类型枚举"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
class VariableDefinition(BaseModel):
|
||||
"""变量定义
|
||||
|
||||
定义工作流或节点的输入/输出变量。
|
||||
这是一个通用的数据结构,可以在多个地方使用。
|
||||
"""
|
||||
|
||||
name: str = Field(
|
||||
...,
|
||||
description="变量名称"
|
||||
)
|
||||
|
||||
type: VariableType = Field(
|
||||
default=VariableType.STRING,
|
||||
description="变量类型"
|
||||
)
|
||||
|
||||
required: bool = Field(
|
||||
default=False,
|
||||
description="是否必需"
|
||||
)
|
||||
|
||||
default: str | int | float | bool | list | dict | None = Field(
|
||||
default=None,
|
||||
description="默认值"
|
||||
)
|
||||
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="变量描述"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"name": "language",
|
||||
"type": "string",
|
||||
"required": False,
|
||||
"default": "zh-CN",
|
||||
"description": "语言设置"
|
||||
},
|
||||
{
|
||||
"name": "max_length",
|
||||
"type": "number",
|
||||
"required": False,
|
||||
"default": 1000,
|
||||
"description": "最大长度"
|
||||
},
|
||||
{
|
||||
"name": "enable_search",
|
||||
"type": "boolean",
|
||||
"required": True,
|
||||
"description": "是否启用搜索"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class BaseNodeConfig(BaseModel):
|
||||
"""节点配置基类
|
||||
|
||||
所有节点配置都应该继承此基类。
|
||||
|
||||
通用字段:
|
||||
- name: 节点名称(显示名称)
|
||||
- description: 节点描述
|
||||
- tags: 节点标签(用于分类和搜索)
|
||||
"""
|
||||
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="节点名称(显示名称),如果不设置则使用节点 ID"
|
||||
)
|
||||
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="节点描述,说明节点的作用"
|
||||
)
|
||||
|
||||
tags: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="节点标签,用于分类和搜索"
|
||||
)
|
||||
|
||||
class Config:
|
||||
"""Pydantic 配置"""
|
||||
# 允许额外字段(向后兼容)
|
||||
extra = "allow"
|
||||
556
api/app/core/workflow/nodes/base_node.py
Normal file
556
api/app/core/workflow/nodes/base_node.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""
|
||||
工作流节点基类
|
||||
|
||||
定义节点的基本接口和通用功能。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict, Annotated
|
||||
from operator import add
|
||||
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
|
||||
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowState(TypedDict):
|
||||
"""工作流状态
|
||||
|
||||
在节点间传递的状态对象,包含消息、变量、节点输出等信息。
|
||||
"""
|
||||
# 消息列表(追加模式)
|
||||
messages: Annotated[list[AnyMessage], add]
|
||||
|
||||
# 输入变量(从配置的 variables 传入)
|
||||
variables: dict[str, Any]
|
||||
|
||||
# 节点输出(存储每个节点的执行结果,用于变量引用)
|
||||
# 使用自定义合并函数,将新的节点输出合并到现有字典中
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问)
|
||||
# 格式:{node_id: business_result}
|
||||
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# 执行上下文
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
|
||||
# 错误信息(用于错误边)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
"""节点基类
|
||||
|
||||
所有节点类型都应该继承此基类,实现 execute 方法。
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
"""初始化节点
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
"""
|
||||
self.node_config = node_config
|
||||
self.workflow_config = workflow_config
|
||||
self.node_id = node_config["id"]
|
||||
self.node_type = node_config["type"]
|
||||
self.node_name = node_config.get("name", self.node_id)
|
||||
# 使用 or 运算符处理 None 值
|
||||
self.config = node_config.get("config") or {}
|
||||
self.error_handling = node_config.get("error_handling") or {}
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""执行节点业务逻辑(非流式)
|
||||
|
||||
节点只需要返回业务结果,不需要关心输出格式、时间统计等。
|
||||
BaseNode 会自动包装成标准格式。
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
业务结果(任意类型)
|
||||
|
||||
Examples:
|
||||
>>> # LLM 节点
|
||||
>>> return "这是 AI 的回复"
|
||||
|
||||
>>> # Transform 节点
|
||||
>>> return {"processed_data": [...]}
|
||||
|
||||
>>> # Start/End 节点
|
||||
>>> return {"message": "开始", "conversation_id": "xxx"}
|
||||
"""
|
||||
pass
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""执行节点业务逻辑(流式)
|
||||
|
||||
子类可以重写此方法以支持流式输出。
|
||||
默认实现:执行非流式方法并一次性返回。
|
||||
|
||||
节点需要:
|
||||
1. yield 中间结果(如文本片段)
|
||||
2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result}
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Yields:
|
||||
业务数据(chunk)或完成标记
|
||||
|
||||
Examples:
|
||||
>>> # 流式 LLM 节点
|
||||
>>> full_response = ""
|
||||
>>> async for chunk in llm.astream(prompt):
|
||||
... full_response += chunk
|
||||
... yield chunk # yield 文本片段
|
||||
>>>
|
||||
>>> # 最后 yield 完成标记
|
||||
>>> yield {"__final__": True, "result": AIMessage(content=full_response)}
|
||||
"""
|
||||
result = await self.execute(state)
|
||||
# 默认实现:直接 yield 完成标记
|
||||
yield {"__final__": True, "result": result}
|
||||
|
||||
def supports_streaming(self) -> bool:
|
||||
"""节点是否支持流式输出
|
||||
|
||||
Returns:
|
||||
是否支持流式输出
|
||||
"""
|
||||
# 检查子类是否重写了 execute_stream 方法
|
||||
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
|
||||
|
||||
def get_timeout(self) -> int:
|
||||
"""获取超时时间(秒)
|
||||
|
||||
Returns:
|
||||
超时时间
|
||||
"""
|
||||
return 60
|
||||
# return self.error_handling.get("timeout", 60)
|
||||
|
||||
async def run(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行节点(带错误处理和输出包装,非流式)
|
||||
|
||||
这个方法由 Executor 调用,负责:
|
||||
1. 时间统计
|
||||
2. 调用节点的 execute() 方法
|
||||
3. 将业务结果包装成标准输出格式
|
||||
4. 错误处理
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
标准化的状态更新字典
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
timeout = self.get_timeout()
|
||||
|
||||
# 调用节点的业务逻辑
|
||||
business_result = await asyncio.wait_for(
|
||||
self.execute(state),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 提取处理后的输出(调用子类的 _extract_output)
|
||||
extracted_output = self._extract_output(business_result)
|
||||
|
||||
# 包装成标准输出格式
|
||||
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
|
||||
|
||||
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
|
||||
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
|
||||
if isinstance(extracted_output, dict):
|
||||
runtime_var = extracted_output
|
||||
else:
|
||||
runtime_var = {"output": extracted_output}
|
||||
|
||||
# 返回包装后的输出和运行时变量
|
||||
return {
|
||||
**wrapped_output,
|
||||
"runtime_vars": {
|
||||
self.node_id: runtime_var
|
||||
}
|
||||
}
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||
return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
return self._wrap_error(str(e), elapsed_time, state)
|
||||
|
||||
async def run_stream(self, state: WorkflowState):
|
||||
"""执行节点(带错误处理和输出包装,流式)
|
||||
|
||||
这个方法由 Executor 调用,负责:
|
||||
1. 时间统计
|
||||
2. 调用节点的 execute_stream() 方法
|
||||
3. 将业务数据包装成标准输出格式
|
||||
4. 错误处理
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Yields:
|
||||
标准化的流式事件
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
timeout = self.get_timeout()
|
||||
|
||||
# 累积完整结果(用于最后的包装)
|
||||
chunks = []
|
||||
final_result = None
|
||||
|
||||
# 使用异步生成器包装,支持超时
|
||||
async def stream_with_timeout():
|
||||
nonlocal final_result
|
||||
loop_start = asyncio.get_event_loop().time()
|
||||
|
||||
async for item in self.execute_stream(state):
|
||||
# 检查超时
|
||||
if asyncio.get_event_loop().time() - loop_start > timeout:
|
||||
raise TimeoutError()
|
||||
|
||||
# 检查是否是完成标记
|
||||
if isinstance(item, dict) and item.get("__final__"):
|
||||
final_result = item["result"]
|
||||
elif isinstance(item, str):
|
||||
# 字符串是 chunk
|
||||
chunks.append(item)
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": item,
|
||||
"full_content": "".join(chunks)
|
||||
}
|
||||
else:
|
||||
# 其他类型也当作 chunk 处理
|
||||
chunks.append(str(item))
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": str(item),
|
||||
"full_content": "".join(chunks)
|
||||
}
|
||||
|
||||
async for chunk_event in stream_with_timeout():
|
||||
yield chunk_event
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 包装最终结果
|
||||
final_output = self._wrap_output(final_result, elapsed_time, state)
|
||||
yield {
|
||||
"type": "complete",
|
||||
**final_output
|
||||
}
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
|
||||
yield {
|
||||
"type": "error",
|
||||
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
|
||||
}
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
|
||||
yield {
|
||||
"type": "error",
|
||||
**self._wrap_error(str(e), elapsed_time, state)
|
||||
}
|
||||
|
||||
def _wrap_output(
|
||||
self,
|
||||
business_result: Any,
|
||||
elapsed_time: float,
|
||||
state: WorkflowState
|
||||
) -> dict[str, Any]:
|
||||
"""将业务结果包装成标准输出格式
|
||||
|
||||
Args:
|
||||
business_result: 节点返回的业务结果
|
||||
elapsed_time: 执行耗时
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
标准化的状态更新字典
|
||||
"""
|
||||
# 提取输入数据(用于记录)
|
||||
input_data = self._extract_input(state)
|
||||
|
||||
# 提取 token 使用情况(如果有)
|
||||
token_usage = self._extract_token_usage(business_result)
|
||||
|
||||
# 提取实际输出(去除元数据)
|
||||
output = self._extract_output(business_result)
|
||||
|
||||
# 构建标准节点输出
|
||||
node_output = {
|
||||
"node_id": self.node_id,
|
||||
"node_type": self.node_type,
|
||||
"node_name": self.node_name,
|
||||
"status": "completed",
|
||||
"input": input_data,
|
||||
"output": output,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": None
|
||||
}
|
||||
|
||||
return {
|
||||
"node_outputs": {
|
||||
self.node_id: node_output
|
||||
}
|
||||
}
|
||||
|
||||
def _wrap_error(
|
||||
self,
|
||||
error_message: str,
|
||||
elapsed_time: float,
|
||||
state: WorkflowState
|
||||
) -> dict[str, Any]:
|
||||
"""将错误包装成标准输出格式
|
||||
|
||||
Args:
|
||||
error_message: 错误信息
|
||||
elapsed_time: 执行耗时
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
标准化的状态更新字典
|
||||
"""
|
||||
# 查找错误边
|
||||
error_edge = self._find_error_edge()
|
||||
|
||||
# 提取输入数据
|
||||
input_data = self._extract_input(state)
|
||||
|
||||
# 构建错误输出
|
||||
node_output = {
|
||||
"node_id": self.node_id,
|
||||
"node_type": self.node_type,
|
||||
"node_name": self.node_name,
|
||||
"status": "failed",
|
||||
"input": input_data,
|
||||
"output": None,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": None,
|
||||
"error": error_message
|
||||
}
|
||||
|
||||
if error_edge:
|
||||
# 有错误边:记录错误并继续
|
||||
logger.warning(
|
||||
f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
|
||||
)
|
||||
return {
|
||||
"node_outputs": {
|
||||
self.node_id: node_output
|
||||
},
|
||||
"error": error_message,
|
||||
"error_node": self.node_id
|
||||
}
|
||||
else:
|
||||
# 无错误边:抛出异常停止工作流
|
||||
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
|
||||
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""提取节点输入数据(用于记录)
|
||||
|
||||
子类可以重写此方法来自定义输入记录。
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
输入数据字典
|
||||
"""
|
||||
# 默认返回配置
|
||||
return {"config": self.config}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
"""从业务结果中提取实际输出
|
||||
|
||||
子类可以重写此方法来自定义输出提取。
|
||||
|
||||
Args:
|
||||
business_result: 业务结果
|
||||
|
||||
Returns:
|
||||
实际输出
|
||||
"""
|
||||
# 默认直接返回业务结果
|
||||
return business_result
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
"""从业务结果中提取 token 使用情况
|
||||
|
||||
子类可以重写此方法来提取 token 信息。
|
||||
|
||||
Args:
|
||||
business_result: 业务结果
|
||||
|
||||
Returns:
|
||||
token 使用情况或 None
|
||||
"""
|
||||
# 默认返回 None
|
||||
return None
|
||||
|
||||
def _find_error_edge(self) -> dict[str, Any] | None:
|
||||
"""查找错误边
|
||||
|
||||
Returns:
|
||||
错误边配置或 None
|
||||
"""
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
if edge.get("source") == self.node_id and edge.get("type") == "error":
|
||||
return edge
|
||||
return None
|
||||
|
||||
def _render_template(self, template: str, state: WorkflowState | None) -> str:
|
||||
"""渲染模板
|
||||
|
||||
支持的变量命名空间:
|
||||
- sys.xxx: 系统变量(message, execution_id, workspace_id, user_id, conversation_id)
|
||||
- conv.xxx: 会话变量(跨多轮对话保持)
|
||||
- node_id.xxx: 节点输出
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
"""
|
||||
from app.core.workflow.template_renderer import render_template
|
||||
|
||||
# 处理 state 为 None 的情况
|
||||
if state is None:
|
||||
state = {}
|
||||
|
||||
# 使用变量池获取变量
|
||||
pool = VariablePool(state)
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars()
|
||||
)
|
||||
|
||||
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
|
||||
"""评估条件表达式
|
||||
|
||||
支持的变量命名空间:
|
||||
- sys.xxx: 系统变量
|
||||
- conv.xxx: 会话变量
|
||||
- node_id.xxx: 节点输出
|
||||
|
||||
Args:
|
||||
expression: 条件表达式
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
布尔值结果
|
||||
"""
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
|
||||
# 处理 state 为 None 的情况
|
||||
if state is None:
|
||||
state = {}
|
||||
|
||||
# 使用变量池获取变量
|
||||
pool = VariablePool(state)
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars()
|
||||
)
|
||||
|
||||
def get_variable_pool(self, state: WorkflowState) -> VariablePool:
|
||||
"""获取变量池实例
|
||||
|
||||
VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
VariablePool 实例
|
||||
|
||||
Examples:
|
||||
>>> pool = self.get_variable_pool(state)
|
||||
>>> message = pool.get("sys.message")
|
||||
>>> llm_output = pool.get("llm_qa.output")
|
||||
"""
|
||||
return VariablePool(state)
|
||||
|
||||
def get_variable(
|
||||
self,
|
||||
selector: list[str] | str,
|
||||
state: WorkflowState,
|
||||
default: Any = None
|
||||
) -> Any:
|
||||
"""获取变量值(便捷方法)
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
state: 工作流状态
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
变量值
|
||||
|
||||
Examples:
|
||||
>>> message = self.get_variable("sys.message", state)
|
||||
>>> output = self.get_variable(["llm_qa", "output"], state)
|
||||
>>> custom = self.get_variable("var.custom", state, default="默认值")
|
||||
"""
|
||||
pool = VariablePool(state)
|
||||
return pool.get(selector, default=default)
|
||||
|
||||
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
|
||||
"""检查变量是否存在(便捷方法)
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
变量是否存在
|
||||
|
||||
Examples:
|
||||
>>> if self.has_variable("llm_qa.output", state):
|
||||
... output = self.get_variable("llm_qa.output", state)
|
||||
"""
|
||||
pool = VariablePool(state)
|
||||
return pool.has(selector)
|
||||
29
api/app/core/workflow/nodes/configs.py
Normal file
29
api/app/core/workflow/nodes/configs.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""节点配置类统一导出
|
||||
|
||||
所有节点的配置类都在这里导出,方便使用。
|
||||
"""
|
||||
|
||||
from app.core.workflow.nodes.base_config import (
|
||||
BaseNodeConfig,
|
||||
VariableDefinition,
|
||||
VariableType,
|
||||
)
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.end.config import EndNodeConfig
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
"BaseNodeConfig",
|
||||
"VariableDefinition",
|
||||
"VariableType",
|
||||
# 节点配置
|
||||
"StartNodeConfig",
|
||||
"EndNodeConfig",
|
||||
"LLMNodeConfig",
|
||||
"MessageConfig",
|
||||
"AgentNodeConfig",
|
||||
"TransformNodeConfig",
|
||||
]
|
||||
6
api/app/core/workflow/nodes/end/__init__.py
Normal file
6
api/app/core/workflow/nodes/end/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""End 节点"""
|
||||
|
||||
from app.core.workflow.nodes.end.node import EndNode
|
||||
from app.core.workflow.nodes.end.config import EndNodeConfig
|
||||
|
||||
__all__ = ["EndNode", "EndNodeConfig"]
|
||||
37
api/app/core/workflow/nodes/end/config.py
Normal file
37
api/app/core/workflow/nodes/end/config.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""End 节点配置"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class EndNodeConfig(BaseNodeConfig):
|
||||
"""End 节点配置
|
||||
|
||||
End 节点负责输出工作流的最终结果。
|
||||
"""
|
||||
|
||||
output: str = Field(
|
||||
default="工作流已完成",
|
||||
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="output",
|
||||
type=VariableType.STRING,
|
||||
description="工作流的最终输出"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"output": "{{ llm_qa.output }}",
|
||||
"description": "输出 LLM 的回答"
|
||||
}
|
||||
}
|
||||
53
api/app/core/workflow/nodes/end/node.py
Normal file
53
api/app/core/workflow/nodes/end/node.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
End 节点实现
|
||||
|
||||
工作流的结束节点,输出最终结果。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndNode(BaseNode):
|
||||
"""End 节点
|
||||
|
||||
工作流的结束节点,根据配置的模板输出最终结果。
|
||||
"""
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
"""执行 end 节点业务逻辑
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
最终输出字符串
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (End) 开始执行")
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
pool = self.get_variable_pool(state)
|
||||
|
||||
print("="*20)
|
||||
print( pool.get("start.test"))
|
||||
print("="*20)
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
else:
|
||||
output = "工作流已完成"
|
||||
|
||||
# 统计信息(用于日志)
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
print("="*20)
|
||||
print(output)
|
||||
print("="*20)
|
||||
return output
|
||||
15
api/app/core/workflow/nodes/enums.py
Normal file
15
api/app/core/workflow/nodes/enums.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from enum import StrEnum
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
END = "end"
|
||||
ANSWER = "answer"
|
||||
LLM = "llm"
|
||||
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
|
||||
IF_ELSE = "if-else"
|
||||
CODE = "code"
|
||||
TRANSFORM = "transform"
|
||||
QUESTION_CLASSIFIER = "question-classifier"
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
AGENT = "agent"
|
||||
6
api/app/core/workflow/nodes/llm/__init__.py
Normal file
6
api/app/core/workflow/nodes/llm/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""LLM 节点"""
|
||||
|
||||
from app.core.workflow.nodes.llm.node import LLMNode
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
|
||||
__all__ = ["LLMNode", "LLMNodeConfig", "MessageConfig"]
|
||||
141
api/app/core/workflow/nodes/llm/config.py
Normal file
141
api/app/core/workflow/nodes/llm/config.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""LLM 节点配置"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class MessageConfig(BaseModel):
|
||||
"""消息配置"""
|
||||
|
||||
role: str = Field(
|
||||
...,
|
||||
description="消息角色:system, user, assistant"
|
||||
)
|
||||
|
||||
content: str = Field(
|
||||
...,
|
||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||
)
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, v: str) -> str:
|
||||
"""验证角色"""
|
||||
allowed_roles = ["system", "user", "human", "assistant", "ai"]
|
||||
if v.lower() not in allowed_roles:
|
||||
raise ValueError(f"角色必须是以下之一: {', '.join(allowed_roles)}")
|
||||
return v.lower()
|
||||
|
||||
|
||||
class LLMNodeConfig(BaseNodeConfig):
|
||||
"""LLM 节点配置
|
||||
|
||||
支持两种配置方式:
|
||||
1. 简单模式:使用 prompt 字段
|
||||
2. 消息模式:使用 messages 字段(推荐)
|
||||
"""
|
||||
|
||||
model_id: str = Field(
|
||||
...,
|
||||
description="模型配置 ID"
|
||||
)
|
||||
|
||||
# 简单模式
|
||||
prompt: str | None = Field(
|
||||
default=None,
|
||||
description="提示词模板(简单模式),支持变量引用"
|
||||
)
|
||||
|
||||
# 消息模式(推荐)
|
||||
messages: list[MessageConfig] | None = Field(
|
||||
default=None,
|
||||
description="消息列表(消息模式),支持多轮对话"
|
||||
)
|
||||
|
||||
# 模型参数
|
||||
temperature: float | None = Field(
|
||||
default=0.7,
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
description="温度参数,控制输出的随机性"
|
||||
)
|
||||
|
||||
max_tokens: int | None = Field(
|
||||
default=1000,
|
||||
ge=1,
|
||||
le=32000,
|
||||
description="最大生成 token 数"
|
||||
)
|
||||
|
||||
top_p: float | None = Field(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Top-p 采样参数"
|
||||
)
|
||||
|
||||
frequency_penalty: float | None = Field(
|
||||
default=None,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="频率惩罚"
|
||||
)
|
||||
|
||||
presence_penalty: float | None = Field(
|
||||
default=None,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="存在惩罚"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="output",
|
||||
type=VariableType.STRING,
|
||||
description="LLM 生成的文本输出"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="token_usage",
|
||||
type=VariableType.OBJECT,
|
||||
description="Token 使用情况"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
|
||||
@field_validator("messages", "prompt")
|
||||
@classmethod
|
||||
def validate_input_mode(cls, v, info):
|
||||
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||
# 这个验证在 model_validator 中更合适
|
||||
return v
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"model_id": "uuid-here",
|
||||
"prompt": "请回答:{{ sys.message }}",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000
|
||||
},
|
||||
{
|
||||
"model_id": "uuid-here",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的 AI 助手"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "{{ sys.message }}"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000
|
||||
}
|
||||
]
|
||||
}
|
||||
247
api/app/core/workflow/nodes/llm/node.py
Normal file
247
api/app/core/workflow/nodes/llm/node.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
LLM 节点实现
|
||||
|
||||
调用 LLM 模型进行文本生成。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models import ModelConfig
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMNode(BaseNode):
|
||||
"""LLM 节点
|
||||
|
||||
支持流式和非流式输出,使用 LangChain 标准的消息格式。
|
||||
|
||||
配置示例(支持多种消息格式):
|
||||
|
||||
1. 简单文本格式:
|
||||
{
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"model_id": "uuid",
|
||||
"prompt": "请分析:{{sys.message}}",
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000
|
||||
}
|
||||
}
|
||||
|
||||
2. LangChain 消息格式(推荐):
|
||||
{
|
||||
"type": "llm",
|
||||
"config": {
|
||||
"model_id": "uuid",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一个专业的 AI 助手。"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "{{sys.message}}"
|
||||
}
|
||||
],
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 1000
|
||||
}
|
||||
}
|
||||
|
||||
支持的角色类型:
|
||||
- system: 系统消息(SystemMessage)
|
||||
- user/human: 用户消息(HumanMessage)
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
|
||||
"""
|
||||
|
||||
# 1. 处理消息格式(优先使用 messages)
|
||||
messages_config = self.config.get("messages")
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.get("role", "user").lower()
|
||||
content_template = msg_config.get("content", "")
|
||||
content = self._render_template(content_template, state)
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append(SystemMessage(content=content))
|
||||
elif role in ["user", "human"]:
|
||||
messages.append(HumanMessage(content=content))
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append(AIMessage(content=content))
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
prompt_or_messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
prompt_template = self.config.get("prompt", "")
|
||||
prompt_or_messages = self._render_template(prompt_template, state)
|
||||
|
||||
# 2. 获取模型配置
|
||||
model_id = self.config.get("model_id")
|
||||
if not model_id:
|
||||
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
||||
|
||||
# 3. 在 with 块内完成所有数据库操作和数据提取
|
||||
with get_db_context() as db:
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
|
||||
if not config:
|
||||
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = config.api_keys[0]
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
model_type = config.type
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
),
|
||||
type=model_type
|
||||
)
|
||||
|
||||
return llm, prompt_or_messages
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
"""非流式执行 LLM 调用
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(prompt_or_messages)
|
||||
|
||||
# 提取内容
|
||||
if hasattr(response, 'content'):
|
||||
content = response.content
|
||||
else:
|
||||
content = str(response)
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)"""
|
||||
_, prompt_or_messages = self._prepare_llm(state)
|
||||
|
||||
return {
|
||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||
"messages": [
|
||||
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
|
||||
for msg in prompt_or_messages
|
||||
] if isinstance(prompt_or_messages, list) else None,
|
||||
"config": {
|
||||
"model_id": self.config.get("model_id"),
|
||||
"temperature": self.config.get("temperature"),
|
||||
"max_tokens": self.config.get("max_tokens")
|
||||
}
|
||||
}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> str:
|
||||
"""从 AIMessage 中提取文本内容"""
|
||||
if isinstance(business_result, AIMessage):
|
||||
return business_result.content
|
||||
return str(business_result)
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
"""从 AIMessage 中提取 token 使用情况"""
|
||||
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
||||
usage = business_result.response_metadata.get('token_usage')
|
||||
if usage:
|
||||
return {
|
||||
"prompt_tokens": usage.get('prompt_tokens', 0),
|
||||
"completion_tokens": usage.get('completion_tokens', 0),
|
||||
"total_tokens": usage.get('total_tokens', 0)
|
||||
}
|
||||
return None
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行 LLM 调用
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Yields:
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
async for chunk in llm.astream(prompt_or_messages):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = chunk.content
|
||||
else:
|
||||
content = str(chunk)
|
||||
|
||||
full_response += content
|
||||
last_chunk = chunk
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield content
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
if isinstance(last_chunk, AIMessage):
|
||||
final_message = AIMessage(
|
||||
content=full_response,
|
||||
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
|
||||
)
|
||||
else:
|
||||
final_message = AIMessage(content=full_response)
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": final_message}
|
||||
93
api/app/core/workflow/nodes/node_factory.py
Normal file
93
api/app/core/workflow/nodes/node_factory.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
节点工厂
|
||||
|
||||
根据节点类型创建相应的节点实例。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeFactory:
|
||||
"""节点工厂
|
||||
|
||||
使用工厂模式创建节点实例,便于扩展和维护。
|
||||
"""
|
||||
|
||||
# 节点类型注册表
|
||||
_node_types: dict[str, type[BaseNode]] = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.AGENT: AgentNode,
|
||||
NodeType.TRANSFORM: TransformNode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
|
||||
"""注册新的节点类型
|
||||
|
||||
Args:
|
||||
node_type: 节点类型名称
|
||||
node_class: 节点类
|
||||
|
||||
Examples:
|
||||
>>> class CustomNode(BaseNode):
|
||||
... async def execute(self, state):
|
||||
... return {"node_outputs": {self.node_id: {"output": "custom"}}}
|
||||
>>> NodeFactory.register_node_type("custom", CustomNode)
|
||||
"""
|
||||
cls._node_types[node_type] = node_class
|
||||
logger.info(f"注册节点类型: {node_type} -> {node_class.__name__}")
|
||||
|
||||
@classmethod
|
||||
def create_node(
|
||||
cls,
|
||||
node_config: dict[str, Any],
|
||||
workflow_config: dict[str, Any]
|
||||
) -> BaseNode | None:
|
||||
"""创建节点实例
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
|
||||
Returns:
|
||||
节点实例或 None(对于不支持的节点类型)
|
||||
|
||||
Raises:
|
||||
ValueError: 不支持的节点类型
|
||||
"""
|
||||
node_type = node_config.get("type")
|
||||
|
||||
# 跳过条件节点(由 LangGraph 处理)
|
||||
if node_type == "condition":
|
||||
return None
|
||||
|
||||
# 获取节点类
|
||||
node_class = cls._node_types.get(node_type)
|
||||
if not node_class:
|
||||
raise ValueError(f"不支持的节点类型: {node_type}")
|
||||
|
||||
# 创建节点实例
|
||||
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
|
||||
return node_class(node_config, workflow_config)
|
||||
|
||||
@classmethod
|
||||
def get_supported_types(cls) -> list[str]:
|
||||
"""获取支持的节点类型列表
|
||||
|
||||
Returns:
|
||||
节点类型列表
|
||||
"""
|
||||
return list(cls._node_types.keys())
|
||||
6
api/app/core/workflow/nodes/start/__init__.py
Normal file
6
api/app/core/workflow/nodes/start/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Start 节点"""
|
||||
|
||||
from app.core.workflow.nodes.start.node import StartNode
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
|
||||
__all__ = ["StartNode", "StartNodeConfig"]
|
||||
87
api/app/core/workflow/nodes/start/config.py
Normal file
87
api/app/core/workflow/nodes/start/config.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Start 节点配置"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class StartNodeConfig(BaseNodeConfig):
|
||||
"""Start 节点配置
|
||||
|
||||
Start 节点的作用:
|
||||
1. 标记工作流的起点
|
||||
2. 定义自定义输入变量(会作为节点输出,通过 start_node_id.variable_name 访问)
|
||||
3. 输出系统变量和会话变量
|
||||
"""
|
||||
|
||||
# 自定义输入变量定义
|
||||
variables: list[VariableDefinition] = Field(
|
||||
default_factory=list,
|
||||
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="message",
|
||||
type=VariableType.STRING,
|
||||
description="用户输入的消息"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="conversation_vars",
|
||||
type=VariableType.OBJECT,
|
||||
description="会话级变量"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="execution_id",
|
||||
type=VariableType.STRING,
|
||||
description="执行 ID"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="conversation_id",
|
||||
type=VariableType.STRING,
|
||||
description="会话 ID"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="workspace_id",
|
||||
type=VariableType.STRING,
|
||||
description="工作空间 ID"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="user_id",
|
||||
type=VariableType.STRING,
|
||||
description="用户 ID"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"description": "工作流开始节点",
|
||||
"variables": []
|
||||
},
|
||||
{
|
||||
"description": "带自定义变量的开始节点",
|
||||
"variables": [
|
||||
{
|
||||
"name": "language",
|
||||
"type": "string",
|
||||
"required": False,
|
||||
"default": "zh-CN",
|
||||
"description": "语言设置"
|
||||
},
|
||||
{
|
||||
"name": "max_length",
|
||||
"type": "number",
|
||||
"required": False,
|
||||
"default": 1000,
|
||||
"description": "最大长度"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
136
api/app/core/workflow/nodes/start/node.py
Normal file
136
api/app/core/workflow/nodes/start/node.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Start 节点实现
|
||||
|
||||
工作流的起始节点,定义输入变量并输出系统参数。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StartNode(BaseNode):
|
||||
"""Start 节点
|
||||
|
||||
工作流的起始节点,负责:
|
||||
1. 定义工作流的输入变量(通过配置)
|
||||
2. 输出系统变量(sys.*)
|
||||
3. 输出会话变量(conv.*)
|
||||
|
||||
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
||||
"""
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
"""初始化 Start 节点
|
||||
|
||||
Args:
|
||||
node_config: 节点配置
|
||||
workflow_config: 工作流配置
|
||||
"""
|
||||
super().__init__(node_config, workflow_config)
|
||||
|
||||
# 解析并验证配置
|
||||
self.typed_config = StartNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行 start 节点业务逻辑
|
||||
|
||||
Start 节点输出系统变量、会话变量和自定义变量。
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
包含系统参数、会话变量和自定义变量的字典
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
||||
|
||||
# 创建变量池实例(在方法内复用)
|
||||
pool = self.get_variable_pool(state)
|
||||
|
||||
# 处理自定义变量(传入 pool 避免重复创建)
|
||||
custom_vars = self._process_custom_variables(pool)
|
||||
|
||||
# 返回业务数据(包含自定义变量)
|
||||
result = {
|
||||
"message": pool.get("sys.message"),
|
||||
"execution_id": pool.get("sys.execution_id"),
|
||||
"conversation_id": pool.get("sys.conversation_id"),
|
||||
"workspace_id": pool.get("sys.workspace_id"),
|
||||
"user_id": pool.get("sys.user_id"),
|
||||
**custom_vars # 自定义变量作为节点输出的一部分
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"节点 {self.node_id} (Start) 执行完成,"
|
||||
f"输出了 {len(custom_vars)} 个自定义变量"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _process_custom_variables(self, pool) -> dict[str, Any]:
|
||||
"""处理自定义变量
|
||||
|
||||
从输入数据中提取自定义变量,应用默认值和验证。
|
||||
|
||||
Args:
|
||||
pool: 变量池实例
|
||||
|
||||
Returns:
|
||||
处理后的自定义变量字典
|
||||
|
||||
Raises:
|
||||
ValueError: 缺少必需变量
|
||||
"""
|
||||
# 获取输入数据中的自定义变量
|
||||
input_variables = pool.get("sys.input_variables", default={})
|
||||
|
||||
processed = {}
|
||||
|
||||
# 遍历配置的变量定义
|
||||
for var_def in self.typed_config.variables:
|
||||
var_name = var_def.name
|
||||
|
||||
# 检查变量是否存在
|
||||
if var_name in input_variables:
|
||||
# 使用用户提供的值
|
||||
processed[var_name] = input_variables[var_name]
|
||||
|
||||
elif var_def.required:
|
||||
# 必需变量缺失
|
||||
raise ValueError(
|
||||
f"缺少必需的输入变量: {var_name}"
|
||||
+ (f" ({var_def.description})" if var_def.description else "")
|
||||
)
|
||||
|
||||
elif var_def.default is not None:
|
||||
# 使用默认值
|
||||
processed[var_name] = var_def.default
|
||||
logger.debug(
|
||||
f"变量 '{var_name}' 使用默认值: {var_def.default}"
|
||||
)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
输入数据字典
|
||||
"""
|
||||
pool = self.get_variable_pool(state)
|
||||
|
||||
return {
|
||||
"execution_id": pool.get("sys.execution_id"),
|
||||
"conversation_id": pool.get("sys.conversation_id"),
|
||||
"message": pool.get("sys.message"),
|
||||
"conversation_vars": pool.get_all_conversation_vars()
|
||||
}
|
||||
6
api/app/core/workflow/nodes/transform/__init__.py
Normal file
6
api/app/core/workflow/nodes/transform/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Transform 节点"""
|
||||
|
||||
from app.core.workflow.nodes.transform.node import TransformNode
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
|
||||
__all__ = ["TransformNode", "TransformNodeConfig"]
|
||||
80
api/app/core/workflow/nodes/transform/config.py
Normal file
80
api/app/core/workflow/nodes/transform/config.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Transform 节点配置"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
|
||||
|
||||
class TransformNodeConfig(BaseNodeConfig):
|
||||
"""Transform 节点配置
|
||||
|
||||
用于数据转换和处理。
|
||||
"""
|
||||
|
||||
transform_type: Literal["template", "code", "json"] = Field(
|
||||
default="template",
|
||||
description="转换类型:template(模板), code(代码), json(JSON处理)"
|
||||
)
|
||||
|
||||
# 模板模式
|
||||
template: str | None = Field(
|
||||
default=None,
|
||||
description="转换模板,支持变量引用"
|
||||
)
|
||||
|
||||
# 代码模式
|
||||
code: str | None = Field(
|
||||
default=None,
|
||||
description="Python 代码,用于数据转换"
|
||||
)
|
||||
|
||||
# JSON 模式
|
||||
json_path: str | None = Field(
|
||||
default=None,
|
||||
description="JSON 路径表达式"
|
||||
)
|
||||
|
||||
# 输入变量
|
||||
inputs: dict[str, str] | None = Field(
|
||||
default=None,
|
||||
description="输入变量映射,key 为变量名,value 为变量选择器"
|
||||
)
|
||||
|
||||
# 输出变量
|
||||
output_key: str = Field(
|
||||
default="result",
|
||||
description="输出变量的键名"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="result",
|
||||
type=VariableType.STRING,
|
||||
description="转换后的结果"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(根据 output_key 动态生成)"
|
||||
)
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"transform_type": "template",
|
||||
"template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
|
||||
"output_key": "formatted_result"
|
||||
},
|
||||
{
|
||||
"transform_type": "code",
|
||||
"code": "result = input_text.upper()",
|
||||
"inputs": {
|
||||
"input_text": "{{ sys.message }}"
|
||||
},
|
||||
"output_key": "uppercase_text"
|
||||
}
|
||||
]
|
||||
}
|
||||
60
api/app/core/workflow/nodes/transform/node.py
Normal file
60
api/app/core/workflow/nodes/transform/node.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Transform 节点实现
|
||||
|
||||
数据转换节点,用于处理和转换数据。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransformNode(BaseNode):
|
||||
"""数据转换节点
|
||||
|
||||
配置示例:
|
||||
{
|
||||
"type": "transform",
|
||||
"config": {
|
||||
"mapping": {
|
||||
"output_field": "{{node.previous.output}}",
|
||||
"processed": "{{var.input | upper}}"
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""执行数据转换
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
Returns:
|
||||
状态更新字典
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} 开始执行数据转换")
|
||||
|
||||
# 获取映射配置
|
||||
mapping = self.config.get("mapping", {})
|
||||
|
||||
# 执行数据转换
|
||||
transformed_data = {}
|
||||
for target_key, source_template in mapping.items():
|
||||
# 渲染模板获取值
|
||||
value = self._render_template(str(source_template), state)
|
||||
transformed_data[target_key] = value
|
||||
|
||||
logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}")
|
||||
|
||||
return {
|
||||
"node_outputs": {
|
||||
self.node_id: {
|
||||
"output": transformed_data,
|
||||
"status": "completed"
|
||||
}
|
||||
}
|
||||
}
|
||||
170
api/app/core/workflow/template_loader.py
Normal file
170
api/app/core/workflow/template_loader.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
工作流模板加载器
|
||||
|
||||
从文件系统加载预定义的工作流模板
|
||||
"""
|
||||
|
||||
import os
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class TemplateLoader:
|
||||
"""工作流模板加载器"""
|
||||
|
||||
def __init__(self, templates_dir: str = "app/templates/workflows"):
|
||||
"""初始化模板加载器
|
||||
|
||||
Args:
|
||||
templates_dir: 模板目录路径
|
||||
"""
|
||||
self.templates_dir = Path(templates_dir)
|
||||
if not self.templates_dir.exists():
|
||||
raise ValueError(f"模板目录不存在: {templates_dir}")
|
||||
|
||||
def list_templates(self) -> list[dict]:
|
||||
"""列出所有可用的模板
|
||||
|
||||
Returns:
|
||||
模板列表,每个模板包含 id, name, description 等信息
|
||||
"""
|
||||
templates = []
|
||||
|
||||
# 遍历模板目录
|
||||
for template_dir in self.templates_dir.iterdir():
|
||||
if not template_dir.is_dir():
|
||||
continue
|
||||
|
||||
# 检查是否有 template.yml 文件
|
||||
template_file = template_dir / "template.yml"
|
||||
if not template_file.exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
# 读取模板配置
|
||||
with open(template_file, 'r', encoding='utf-8') as f:
|
||||
template_data = yaml.safe_load(f)
|
||||
|
||||
# 提取模板信息
|
||||
templates.append({
|
||||
"id": template_dir.name,
|
||||
"name": template_data.get("name", template_dir.name),
|
||||
"description": template_data.get("description", ""),
|
||||
"category": template_data.get("category", "general"),
|
||||
"tags": template_data.get("tags", []),
|
||||
"author": template_data.get("author", ""),
|
||||
"version": template_data.get("version", "1.0.0")
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"加载模板 {template_dir.name} 失败: {e}")
|
||||
continue
|
||||
|
||||
return templates
|
||||
|
||||
def load_template(self, template_id: str) -> Optional[dict]:
|
||||
"""加载指定的模板
|
||||
|
||||
Args:
|
||||
template_id: 模板 ID(目录名)
|
||||
|
||||
Returns:
|
||||
模板配置字典,如果模板不存在则返回 None
|
||||
"""
|
||||
template_dir = self.templates_dir / template_id
|
||||
template_file = template_dir / "template.yml"
|
||||
|
||||
if not template_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(template_file, 'r', encoding='utf-8') as f:
|
||||
template_data = yaml.safe_load(f)
|
||||
|
||||
# 返回工作流配置部分
|
||||
return {
|
||||
"name": template_data.get("name", template_id),
|
||||
"description": template_data.get("description", ""),
|
||||
"nodes": template_data.get("nodes", []),
|
||||
"edges": template_data.get("edges", []),
|
||||
"variables": template_data.get("variables", []),
|
||||
"execution_config": template_data.get("execution_config", {}),
|
||||
"triggers": template_data.get("triggers", [])
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"加载模板 {template_id} 失败: {e}")
|
||||
return None
|
||||
|
||||
def get_template_readme(self, template_id: str) -> Optional[str]:
|
||||
"""获取模板的 README 文档
|
||||
|
||||
Args:
|
||||
template_id: 模板 ID
|
||||
|
||||
Returns:
|
||||
README 内容,如果不存在则返回 None
|
||||
"""
|
||||
template_dir = self.templates_dir / template_id
|
||||
readme_file = template_dir / "README.md"
|
||||
|
||||
if not readme_file.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(readme_file, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
print(f"读取模板 {template_id} 的 README 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 全局模板加载器实例
|
||||
_template_loader: Optional[TemplateLoader] = None
|
||||
|
||||
|
||||
def get_template_loader() -> TemplateLoader:
|
||||
"""获取全局模板加载器实例
|
||||
|
||||
Returns:
|
||||
TemplateLoader 实例
|
||||
"""
|
||||
global _template_loader
|
||||
if _template_loader is None:
|
||||
_template_loader = TemplateLoader()
|
||||
return _template_loader
|
||||
|
||||
|
||||
def list_workflow_templates() -> list[dict]:
|
||||
"""列出所有工作流模板
|
||||
|
||||
Returns:
|
||||
模板列表
|
||||
"""
|
||||
loader = get_template_loader()
|
||||
return loader.list_templates()
|
||||
|
||||
|
||||
def load_workflow_template(template_id: str) -> Optional[dict]:
|
||||
"""加载工作流模板
|
||||
|
||||
Args:
|
||||
template_id: 模板 ID
|
||||
|
||||
Returns:
|
||||
模板配置,如果不存在则返回 None
|
||||
"""
|
||||
loader = get_template_loader()
|
||||
return loader.load_template(template_id)
|
||||
|
||||
|
||||
def get_workflow_template_readme(template_id: str) -> Optional[str]:
|
||||
"""获取工作流模板的 README
|
||||
|
||||
Args:
|
||||
template_id: 模板 ID
|
||||
|
||||
Returns:
|
||||
README 内容,如果不存在则返回 None
|
||||
"""
|
||||
loader = get_template_loader()
|
||||
return loader.get_template_readme(template_id)
|
||||
170
api/app/core/workflow/template_renderer.py
Normal file
170
api/app/core/workflow/template_renderer.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
模板渲染器
|
||||
|
||||
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import Template, TemplateSyntaxError, UndefinedError, Environment, StrictUndefined
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TemplateRenderer:
|
||||
"""模板渲染器"""
|
||||
|
||||
def __init__(self, strict: bool = True):
|
||||
"""初始化渲染器
|
||||
|
||||
Args:
|
||||
strict: 是否使用严格模式(未定义变量会抛出异常)
|
||||
"""
|
||||
self.env = Environment(
|
||||
undefined=StrictUndefined if strict else None,
|
||||
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
||||
)
|
||||
|
||||
def render(
|
||||
self,
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> str:
|
||||
"""渲染模板
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
variables: 用户定义的变量
|
||||
node_outputs: 节点输出结果
|
||||
system_vars: 系统变量
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
|
||||
Raises:
|
||||
ValueError: 模板语法错误或变量未定义
|
||||
|
||||
Examples:
|
||||
>>> renderer = TemplateRenderer()
|
||||
>>> renderer.render(
|
||||
... "Hello {{var.name}}!",
|
||||
... {"name": "World"},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
'Hello World!'
|
||||
|
||||
>>> renderer.render(
|
||||
... "分析结果: {{node.analyze.output}}",
|
||||
... {},
|
||||
... {"analyze": {"output": "正面情绪"}},
|
||||
... {}
|
||||
... )
|
||||
'分析结果: 正面情绪'
|
||||
"""
|
||||
# 构建命名空间上下文
|
||||
context = {
|
||||
"var": variables, # 用户变量:{{var.user_input}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": system_vars or {}, # 系统变量:{{sys.execution_id}}
|
||||
}
|
||||
|
||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||
# 将所有节点输出添加到顶层上下文
|
||||
context.update(node_outputs)
|
||||
|
||||
# 为了向后兼容,也支持直接访问用户变量
|
||||
context.update(variables)
|
||||
context["nodes"] = node_outputs # 旧语法兼容
|
||||
|
||||
try:
|
||||
tmpl = self.env.from_string(template)
|
||||
return tmpl.render(**context)
|
||||
|
||||
except TemplateSyntaxError as e:
|
||||
logger.error(f"模板语法错误: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板语法错误: {e}")
|
||||
|
||||
except UndefinedError as e:
|
||||
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
|
||||
raise ValueError(f"未定义的变量: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模板渲染异常: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板渲染失败: {e}")
|
||||
|
||||
def validate(self, template: str) -> list[str]:
|
||||
"""验证模板语法
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
Returns:
|
||||
错误列表,如果为空则验证通过
|
||||
|
||||
Examples:
|
||||
>>> renderer = TemplateRenderer()
|
||||
>>> renderer.validate("Hello {{var.name}}!")
|
||||
[]
|
||||
|
||||
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
|
||||
['模板语法错误: ...']
|
||||
"""
|
||||
errors = []
|
||||
|
||||
try:
|
||||
self.env.from_string(template)
|
||||
except TemplateSyntaxError as e:
|
||||
errors.append(f"模板语法错误: {e}")
|
||||
except Exception as e:
|
||||
errors.append(f"模板验证失败: {e}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# 全局渲染器实例(严格模式)
|
||||
_default_renderer = TemplateRenderer(strict=True)
|
||||
|
||||
|
||||
def render_template(
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> str:
|
||||
"""渲染模板(便捷函数)
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
variables: 用户变量
|
||||
node_outputs: 节点输出
|
||||
system_vars: 系统变量
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
|
||||
Examples:
|
||||
>>> render_template(
|
||||
... "请分析: {{var.text}}",
|
||||
... {"text": "这是一段文本"},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
'请分析: 这是一段文本'
|
||||
"""
|
||||
return _default_renderer.render(template, variables, node_outputs, system_vars)
|
||||
|
||||
|
||||
def validate_template(template: str) -> list[str]:
|
||||
"""验证模板语法(便捷函数)
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
Returns:
|
||||
错误列表
|
||||
"""
|
||||
return _default_renderer.validate(template)
|
||||
277
api/app/core/workflow/validator.py
Normal file
277
api/app/core/workflow/validator.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
工作流配置验证器
|
||||
|
||||
验证工作流配置的有效性,确保配置符合规范。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowValidator:
|
||||
"""工作流配置验证器"""
|
||||
|
||||
@staticmethod
|
||||
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
||||
|
||||
Returns:
|
||||
(is_valid, errors): 是否有效和错误列表
|
||||
|
||||
Examples:
|
||||
>>> config = {
|
||||
... "nodes": [
|
||||
... {"id": "start", "type": "start"},
|
||||
... {"id": "end", "type": "end"}
|
||||
... ],
|
||||
... "edges": [
|
||||
... {"source": "start", "target": "end"}
|
||||
... ]
|
||||
... }
|
||||
>>> is_valid, errors = WorkflowValidator.validate(config)
|
||||
>>> is_valid
|
||||
True
|
||||
"""
|
||||
errors = []
|
||||
|
||||
# 支持字典和 Pydantic 模型
|
||||
if isinstance(workflow_config, dict):
|
||||
nodes = workflow_config.get("nodes", [])
|
||||
edges = workflow_config.get("edges", [])
|
||||
variables = workflow_config.get("variables", [])
|
||||
else:
|
||||
# Pydantic 模型
|
||||
nodes = getattr(workflow_config, "nodes", [])
|
||||
edges = getattr(workflow_config, "edges", [])
|
||||
variables = getattr(workflow_config, "variables", [])
|
||||
|
||||
# 1. 验证 start 节点(有且只有一个)
|
||||
start_nodes = [n for n in nodes if n.get("type") == "start"]
|
||||
if len(start_nodes) == 0:
|
||||
errors.append("工作流必须有一个 start 节点")
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||
|
||||
# 2. 验证 end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == "end"]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
|
||||
# 3. 验证节点 ID 唯一性
|
||||
node_ids = [n.get("id") for n in nodes]
|
||||
if len(node_ids) != len(set(node_ids)):
|
||||
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
||||
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
||||
|
||||
# 4. 验证节点必须有 id 和 type
|
||||
for i, node in enumerate(nodes):
|
||||
if not node.get("id"):
|
||||
errors.append(f"节点 #{i} 缺少 id 字段")
|
||||
if not node.get("type"):
|
||||
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
|
||||
|
||||
# 5. 验证边的有效性
|
||||
node_id_set = set(node_ids)
|
||||
for i, edge in enumerate(edges):
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
|
||||
if not source:
|
||||
errors.append(f"边 #{i} 缺少 source 字段")
|
||||
elif source not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
|
||||
|
||||
if not target:
|
||||
errors.append(f"边 #{i} 缺少 target 字段")
|
||||
elif target not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
unreachable = node_id_set - reachable
|
||||
if unreachable:
|
||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||
|
||||
# 7. 检测循环依赖(非 loop 节点)
|
||||
if not errors: # 只有在前面验证通过时才检查循环
|
||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||
if has_cycle:
|
||||
errors.append(
|
||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
||||
errors.extend(var_errors)
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
@staticmethod
|
||||
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||
"""获取从 start 节点可达的所有节点
|
||||
|
||||
Args:
|
||||
start_id: 起始节点 ID
|
||||
edges: 边列表
|
||||
|
||||
Returns:
|
||||
可达节点 ID 集合
|
||||
"""
|
||||
reachable = {start_id}
|
||||
queue = [start_id]
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
for edge in edges:
|
||||
if edge.get("source") == current:
|
||||
target = edge.get("target")
|
||||
if target and target not in reachable:
|
||||
reachable.add(target)
|
||||
queue.append(target)
|
||||
|
||||
return reachable
|
||||
|
||||
@staticmethod
|
||||
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
|
||||
"""检测是否存在循环依赖(DFS)
|
||||
|
||||
Args:
|
||||
nodes: 节点列表
|
||||
edges: 边列表
|
||||
|
||||
Returns:
|
||||
(has_cycle, cycle_path): 是否有循环和循环路径
|
||||
"""
|
||||
# 排除 loop 类型的节点
|
||||
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
||||
|
||||
# 构建邻接表(排除 loop 节点的边和错误边)
|
||||
graph: dict[str, list[str]] = {}
|
||||
for edge in edges:
|
||||
source = edge.get("source")
|
||||
target = edge.get("target")
|
||||
edge_type = edge.get("type")
|
||||
|
||||
# 跳过错误边
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
# 如果涉及 loop 节点,跳过
|
||||
if source in loop_nodes or target in loop_nodes:
|
||||
continue
|
||||
|
||||
if source and target:
|
||||
if source not in graph:
|
||||
graph[source] = []
|
||||
graph[source].append(target)
|
||||
|
||||
# DFS 检测环
|
||||
visited = set()
|
||||
rec_stack = set()
|
||||
path = []
|
||||
cycle_path = []
|
||||
|
||||
def dfs(node: str) -> bool:
|
||||
"""DFS 检测环,返回是否找到环"""
|
||||
visited.add(node)
|
||||
rec_stack.add(node)
|
||||
path.append(node)
|
||||
|
||||
for neighbor in graph.get(node, []):
|
||||
if neighbor not in visited:
|
||||
if dfs(neighbor):
|
||||
return True
|
||||
elif neighbor in rec_stack:
|
||||
# 找到环,记录环路径
|
||||
cycle_start = path.index(neighbor)
|
||||
cycle_path.extend([*path[cycle_start:], neighbor])
|
||||
return True
|
||||
|
||||
rec_stack.remove(node)
|
||||
path.pop()
|
||||
return False
|
||||
|
||||
# 检查所有节点
|
||||
for node_id in graph:
|
||||
if node_id not in visited:
|
||||
if dfs(node_id):
|
||||
return True, cycle_path
|
||||
|
||||
return False, []
|
||||
|
||||
@staticmethod
|
||||
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置是否可以发布(更严格的验证)
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
|
||||
Returns:
|
||||
(is_valid, errors): 是否有效和错误列表
|
||||
"""
|
||||
# 先执行基础验证
|
||||
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
||||
|
||||
if not is_valid:
|
||||
return False, errors
|
||||
|
||||
# 额外的发布验证
|
||||
nodes = workflow_config.get("nodes", [])
|
||||
|
||||
# 1. 验证所有节点都有名称
|
||||
for node in nodes:
|
||||
if node.get("type") not in ["start", "end"] and not node.get("name"):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
|
||||
)
|
||||
|
||||
# 2. 验证所有非 start/end 节点都有配置
|
||||
for node in nodes:
|
||||
node_type = node.get("type")
|
||||
if node_type not in ["start", "end"]:
|
||||
config = node.get("config")
|
||||
if not config or not isinstance(config, dict):
|
||||
errors.append(
|
||||
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
|
||||
)
|
||||
|
||||
# 3. 验证必填变量
|
||||
variables = workflow_config.get("variables", [])
|
||||
required_vars = [v for v in variables if v.get("required")]
|
||||
if required_vars:
|
||||
# 这里只是提示,实际执行时会检查
|
||||
logger.info(
|
||||
f"工作流包含 {len(required_vars)} 个必填变量: "
|
||||
f"{[v.get('name') for v in required_vars]}"
|
||||
)
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
|
||||
def validate_workflow_config(
|
||||
workflow_config: dict[str, Any],
|
||||
for_publish: bool = False
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置(便捷函数)
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
for_publish: 是否为发布验证(更严格)
|
||||
|
||||
Returns:
|
||||
(is_valid, errors): 是否有效和错误列表
|
||||
"""
|
||||
if for_publish:
|
||||
return WorkflowValidator.validate_for_publish(workflow_config)
|
||||
else:
|
||||
return WorkflowValidator.validate(workflow_config)
|
||||
293
api/app/core/workflow/variable_pool.py
Normal file
293
api/app/core/workflow/variable_pool.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
变量池 (Variable Pool)
|
||||
|
||||
工作流执行的数据中心,管理所有变量的存储和访问。
|
||||
|
||||
变量类型:
|
||||
1. 系统变量 (sys.*) - 系统内置变量(execution_id, workspace_id, user_id, message 等)
|
||||
2. 节点输出 (node_id.*) - 节点执行结果
|
||||
3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VariableSelector:
|
||||
"""变量选择器
|
||||
|
||||
用于引用变量的路径表示。
|
||||
|
||||
Examples:
|
||||
>>> selector = VariableSelector(["sys", "message"])
|
||||
>>> selector = VariableSelector(["node_A", "output"])
|
||||
>>> selector = VariableSelector.from_string("sys.message")
|
||||
"""
|
||||
|
||||
def __init__(self, path: list[str]):
|
||||
"""初始化变量选择器
|
||||
|
||||
Args:
|
||||
path: 变量路径,如 ["sys", "message"] 或 ["node_A", "output"]
|
||||
"""
|
||||
if not path or len(path) < 1:
|
||||
raise ValueError("变量路径不能为空")
|
||||
|
||||
self.path = path
|
||||
self.namespace = path[0] # sys, var, 或 node_id
|
||||
self.key = path[1] if len(path) > 1 else None
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, selector_str: str) -> "VariableSelector":
|
||||
"""从字符串创建选择器
|
||||
|
||||
Args:
|
||||
selector_str: 选择器字符串,如 "sys.message" 或 "node_A.output"
|
||||
|
||||
Returns:
|
||||
VariableSelector 实例
|
||||
|
||||
Examples:
|
||||
>>> selector = VariableSelector.from_string("sys.message")
|
||||
>>> selector = VariableSelector.from_string("llm_qa.output")
|
||||
"""
|
||||
path = selector_str.split(".")
|
||||
return cls(path)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return ".".join(self.path)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"VariableSelector({self.path})"
|
||||
|
||||
|
||||
class VariablePool:
|
||||
"""变量池
|
||||
|
||||
管理工作流执行过程中的所有变量。
|
||||
|
||||
变量命名空间:
|
||||
- sys.*: 系统变量(message, execution_id, workspace_id, user_id, conversation_id)
|
||||
- conv.*: 会话变量(跨多轮对话保持的变量)
|
||||
- <node_id>.*: 节点输出
|
||||
|
||||
Examples:
|
||||
>>> pool = VariablePool(state)
|
||||
>>> pool.get(["sys", "message"])
|
||||
"用户的问题"
|
||||
>>> pool.get(["llm_qa", "output"])
|
||||
"AI 的回答"
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
"""
|
||||
|
||||
def __init__(self, state: dict[str, Any]):
|
||||
"""初始化变量池
|
||||
|
||||
Args:
|
||||
state: 工作流状态(LangGraph State)
|
||||
"""
|
||||
self.state = state
|
||||
|
||||
def get(self, selector: list[str] | str, default: Any = None) -> Any:
|
||||
"""获取变量值
|
||||
|
||||
Args:
|
||||
selector: 变量选择器,可以是列表或字符串
|
||||
default: 默认值(变量不存在时返回)
|
||||
|
||||
Returns:
|
||||
变量值
|
||||
|
||||
Examples:
|
||||
>>> pool.get(["sys", "message"])
|
||||
>>> pool.get("sys.message")
|
||||
>>> pool.get(["llm_qa", "output"])
|
||||
>>> pool.get("llm_qa.output")
|
||||
|
||||
Raises:
|
||||
KeyError: 变量不存在且未提供默认值
|
||||
"""
|
||||
# 转换为 VariableSelector
|
||||
if isinstance(selector, str):
|
||||
selector = VariableSelector.from_string(selector).path
|
||||
|
||||
if not selector or len(selector) < 1:
|
||||
raise ValueError("变量选择器不能为空")
|
||||
|
||||
namespace = selector[0]
|
||||
|
||||
try:
|
||||
# 系统变量
|
||||
if namespace == "sys":
|
||||
key = selector[1] if len(selector) > 1 else None
|
||||
if not key:
|
||||
return self.state.get("variables", {}).get("sys", {})
|
||||
return self.state.get("variables", {}).get("sys", {}).get(key, default)
|
||||
|
||||
# 会话变量
|
||||
elif namespace == "conv":
|
||||
key = selector[1] if len(selector) > 1 else None
|
||||
if not key:
|
||||
return self.state.get("variables", {}).get("conv", {})
|
||||
return self.state.get("variables", {}).get("conv", {}).get(key, default)
|
||||
|
||||
# 节点输出(从 runtime_vars 读取)
|
||||
else:
|
||||
node_id = namespace
|
||||
runtime_vars = self.state.get("runtime_vars", {})
|
||||
|
||||
if node_id not in runtime_vars:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"节点 '{node_id}' 的输出不存在")
|
||||
|
||||
node_var = runtime_vars[node_id]
|
||||
|
||||
# 如果只有节点 ID,返回整个变量
|
||||
if len(selector) == 1:
|
||||
return node_var
|
||||
|
||||
# 获取特定字段
|
||||
# 支持嵌套访问,如 node_id.field.subfield
|
||||
result = node_var
|
||||
for k in selector[1:]:
|
||||
if isinstance(result, dict):
|
||||
result = result.get(k)
|
||||
if result is None:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
|
||||
else:
|
||||
if default is not None:
|
||||
return default
|
||||
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
|
||||
|
||||
return result
|
||||
|
||||
except KeyError:
|
||||
if default is not None:
|
||||
return default
|
||||
raise
|
||||
|
||||
def set(self, selector: list[str] | str, value: Any):
|
||||
"""设置变量值
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
value: 变量值
|
||||
|
||||
Examples:
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
>>> pool.set("conv.user_name", "张三")
|
||||
|
||||
Note:
|
||||
- 只能设置会话变量 (conv.*)
|
||||
- 系统变量和节点输出是只读的
|
||||
"""
|
||||
# 转换为 VariableSelector
|
||||
if isinstance(selector, str):
|
||||
selector = VariableSelector.from_string(selector).path
|
||||
|
||||
if not selector or len(selector) < 2:
|
||||
raise ValueError("变量选择器必须包含命名空间和键名")
|
||||
|
||||
namespace = selector[0]
|
||||
|
||||
if namespace != "conv":
|
||||
raise ValueError("只能设置会话变量 (conv.*)")
|
||||
|
||||
key = selector[1]
|
||||
|
||||
# 确保 variables 结构存在
|
||||
if "variables" not in self.state:
|
||||
self.state["variables"] = {"sys": {}, "conv": {}}
|
||||
if "conv" not in self.state["variables"]:
|
||||
self.state["variables"]["conv"] = {}
|
||||
|
||||
# 设置值
|
||||
self.state["variables"]["conv"][key] = value
|
||||
|
||||
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
|
||||
|
||||
def has(self, selector: list[str] | str) -> bool:
|
||||
"""检查变量是否存在
|
||||
|
||||
Args:
|
||||
selector: 变量选择器
|
||||
|
||||
Returns:
|
||||
变量是否存在
|
||||
|
||||
Examples:
|
||||
>>> pool.has(["sys", "message"])
|
||||
True
|
||||
>>> pool.has("llm_qa.output")
|
||||
False
|
||||
"""
|
||||
try:
|
||||
self.get(selector)
|
||||
return True
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
def get_all_system_vars(self) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
|
||||
Returns:
|
||||
系统变量字典
|
||||
"""
|
||||
return self.state.get("variables", {}).get("sys", {})
|
||||
|
||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||
"""获取所有会话变量
|
||||
|
||||
Returns:
|
||||
会话变量字典
|
||||
"""
|
||||
return self.state.get("variables", {}).get("conv", {})
|
||||
|
||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||
"""获取所有节点输出(运行时变量)
|
||||
|
||||
Returns:
|
||||
节点输出字典,键为节点 ID
|
||||
"""
|
||||
return self.state.get("runtime_vars", {})
|
||||
|
||||
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
|
||||
"""获取指定节点的输出(运行时变量)
|
||||
|
||||
Args:
|
||||
node_id: 节点 ID
|
||||
|
||||
Returns:
|
||||
节点输出或 None
|
||||
"""
|
||||
return self.state.get("runtime_vars", {}).get(node_id)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""导出为字典
|
||||
|
||||
Returns:
|
||||
包含所有变量的字典
|
||||
"""
|
||||
return {
|
||||
"system": self.get_all_system_vars(),
|
||||
"conversation": self.get_all_conversation_vars(),
|
||||
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
sys_vars = self.get_all_system_vars()
|
||||
conv_vars = self.get_all_conversation_vars()
|
||||
runtime_vars = self.get_all_node_outputs()
|
||||
|
||||
return (
|
||||
f"VariablePool(\n"
|
||||
f" system_vars={len(sys_vars)},\n"
|
||||
f" conversation_vars={len(conv_vars)},\n"
|
||||
f" runtime_vars={len(runtime_vars)}\n"
|
||||
f")"
|
||||
)
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
import subprocess
|
||||
from dotenv import load_dotenv
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.core.config import settings
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi.responses import JSONResponse
|
||||
from app.core.response_utils import fail
|
||||
from app.core.logging_config import LoggingConfig, get_logger
|
||||
@@ -38,9 +37,13 @@ router = APIRouter(prefix="/memory", tags=["Memory"])
|
||||
|
||||
# 管理端 API (JWT 认证)
|
||||
from app.controllers import manager_router
|
||||
|
||||
# 服务端 API (API Key 认证)
|
||||
from app.controllers.service import service_router
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode, HTTP_MAPPING
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import LoggingConfig, get_logger
|
||||
from app.core.response_utils import fail
|
||||
|
||||
# Initialize logging system
|
||||
LoggingConfig.setup_logging()
|
||||
@@ -414,5 +417,4 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
@@ -15,9 +15,11 @@ from .end_user_model import EndUser
|
||||
from .appshare_model import AppShare
|
||||
from .release_share_model import ReleaseShare
|
||||
from .conversation_model import Conversation, Message
|
||||
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType, ResourceType
|
||||
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType
|
||||
from .data_config_model import DataConfig
|
||||
from .multi_agent_model import MultiAgentConfig, AgentInvocation
|
||||
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
|
||||
from .retrieval_info import RetrievalInfo
|
||||
|
||||
__all__ = [
|
||||
"Tenants",
|
||||
@@ -46,8 +48,11 @@ __all__ = [
|
||||
"ApiKey",
|
||||
"ApiKeyLog",
|
||||
"ApiKeyType",
|
||||
"ResourceType",
|
||||
"DataConfig",
|
||||
"MultiAgentConfig",
|
||||
"AgentInvocation"
|
||||
"AgentInvocation",
|
||||
"WorkflowConfig",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
"RetrievalInfo"
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text, Enum
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from enum import StrEnum
|
||||
@@ -12,18 +12,10 @@ from app.db import Base
|
||||
|
||||
class ApiKeyType(StrEnum):
|
||||
"""API Key 类型"""
|
||||
APP = "app" # 应用 API Key
|
||||
RAG = "rag" # RAG API Key
|
||||
MEMORY = "memory" # Memory API Key
|
||||
|
||||
|
||||
class ResourceType(StrEnum):
|
||||
"""资源类型枚举"""
|
||||
AGENT = "Agent" # 智能体
|
||||
CLUSTER = "Cluster" # 集群
|
||||
WORKFLOW = "Workflow" # 工作流
|
||||
KNOWLEDGE = "Knowledge" # 知识库
|
||||
MEMORY_ENGINE = "Memory_Engine" # 记忆引擎
|
||||
AGENT = "agent" # 智能体
|
||||
CLUSTER = "multi_agent" # 集群
|
||||
WORKFLOW = "workflow" # 工作流
|
||||
SERVICE = "service" # 服务
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
@@ -35,18 +27,16 @@ class ApiKey(Base):
|
||||
# 基本信息
|
||||
name = Column(String(255), nullable=False, comment="API Key 名称")
|
||||
description = Column(Text, comment="描述")
|
||||
key_prefix = Column(String(20), nullable=False, comment="Key 前缀")
|
||||
key_hash = Column(String(255), nullable=False, unique=True, index=True, comment="Key 哈希值")
|
||||
api_key = Column(String(255), nullable=False, unique=True, index=True, comment="API Key 明文")
|
||||
|
||||
# 类型和权限
|
||||
type = Column(String(50), nullable=False, index=True, comment="API Key 类型")
|
||||
scopes = Column(JSONB, nullable=False, default=list, comment="权限范围列表")
|
||||
scopes = Column(JSONB, default=list, comment="权限范围列表")
|
||||
|
||||
# 关联资源
|
||||
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False,
|
||||
index=True, comment="所属工作空间")
|
||||
resource_id = Column(UUID(as_uuid=True), index=True, comment="关联资源ID")
|
||||
resource_type = Column(String(50), comment="资源类型")
|
||||
|
||||
# 限制和配额
|
||||
rate_limit = Column(Integer, default=10, comment="QPS限制(请求/秒)")
|
||||
|
||||
@@ -86,6 +86,14 @@ class App(Base):
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# 一对一:工作流配置(仅当 type=workflow 时有效)
|
||||
workflow_config = relationship(
|
||||
"WorkflowConfig",
|
||||
back_populates="app",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# 发布版本关联
|
||||
current_release = relationship("AppRelease", foreign_keys=[current_release_id])
|
||||
|
||||
@@ -61,7 +61,7 @@ class ModelConfig(Base):
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
# 关联关系
|
||||
api_keys = relationship("ModelApiKey", back_populates="model_config", cascade="all, delete-orphan")
|
||||
|
||||
196
api/app/models/workflow_model.py
Normal file
196
api/app/models/workflow_model.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
工作流相关数据模型
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, ForeignKey, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class WorkflowConfig(Base):
|
||||
"""工作流配置表"""
|
||||
__tablename__ = "workflow_configs"
|
||||
|
||||
# 主键
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
|
||||
# 关联应用(一对一)
|
||||
app_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("apps.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
index=True
|
||||
)
|
||||
|
||||
# 节点和边的定义(JSON 格式)
|
||||
nodes = Column(JSONB, nullable=False, default=list)
|
||||
edges = Column(JSONB, nullable=False, default=list)
|
||||
|
||||
# 全局变量定义
|
||||
variables = Column(JSONB, default=list)
|
||||
|
||||
# 执行配置
|
||||
execution_config = Column(JSONB, nullable=False, default=dict)
|
||||
|
||||
# 触发器配置(可选)
|
||||
triggers = Column(JSONB, default=list)
|
||||
|
||||
# 状态
|
||||
is_active = Column(Boolean, nullable=False, default=True)
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
default=datetime.datetime.now,
|
||||
onupdate=datetime.datetime.now
|
||||
)
|
||||
|
||||
# 关系
|
||||
app = relationship("App", back_populates="workflow_config")
|
||||
executions = relationship(
|
||||
"WorkflowExecution",
|
||||
back_populates="workflow_config",
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<WorkflowConfig(id={self.id}, app_id={self.app_id})>"
|
||||
|
||||
|
||||
class WorkflowExecution(Base):
|
||||
"""工作流执行记录表"""
|
||||
__tablename__ = "workflow_executions"
|
||||
|
||||
# 主键
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
|
||||
# 关联信息
|
||||
workflow_config_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("workflow_configs.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
app_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("apps.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
conversation_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("conversations.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True
|
||||
)
|
||||
|
||||
# 执行信息
|
||||
execution_id = Column(String(100), nullable=False, unique=True, index=True)
|
||||
trigger_type = Column(String(20), nullable=False) # manual, schedule, webhook, event
|
||||
triggered_by = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("users.id"),
|
||||
nullable=True
|
||||
)
|
||||
|
||||
# 输入输出
|
||||
input_data = Column(JSONB)
|
||||
output_data = Column(JSONB)
|
||||
context = Column(JSONB, default=dict)
|
||||
|
||||
# 状态
|
||||
status = Column(String(20), nullable=False, default="pending", index=True)
|
||||
# 可选值:pending, running, completed, failed, cancelled, timeout
|
||||
|
||||
error_message = Column(Text)
|
||||
error_node_id = Column(String(100))
|
||||
|
||||
# 性能指标
|
||||
started_at = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
|
||||
completed_at = Column(DateTime)
|
||||
elapsed_time = Column(Float) # 耗时(秒)
|
||||
|
||||
# 资源使用
|
||||
token_usage = Column(JSONB)
|
||||
|
||||
# 元数据(使用 meta_data 避免与 SQLAlchemy 保留字 metadata 冲突)
|
||||
meta_data = Column(JSONB, default=dict)
|
||||
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
|
||||
# 关系
|
||||
workflow_config = relationship("WorkflowConfig", back_populates="executions")
|
||||
app = relationship("App")
|
||||
conversation = relationship("Conversation")
|
||||
triggered_by_user = relationship("User", foreign_keys=[triggered_by])
|
||||
node_executions = relationship(
|
||||
"WorkflowNodeExecution",
|
||||
back_populates="execution",
|
||||
cascade="all, delete-orphan",
|
||||
order_by="WorkflowNodeExecution.execution_order"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<WorkflowExecution(id={self.id}, execution_id={self.execution_id}, status={self.status})>"
|
||||
|
||||
|
||||
class WorkflowNodeExecution(Base):
|
||||
"""工作流节点执行记录表"""
|
||||
__tablename__ = "workflow_node_executions"
|
||||
|
||||
# 主键
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
|
||||
# 关联执行
|
||||
execution_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("workflow_executions.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
|
||||
# 节点信息
|
||||
node_id = Column(String(100), nullable=False, index=True)
|
||||
node_type = Column(String(20), nullable=False)
|
||||
node_name = Column(String(100))
|
||||
|
||||
# 执行顺序
|
||||
execution_order = Column(Integer, nullable=False)
|
||||
retry_count = Column(Integer, nullable=False, default=0)
|
||||
|
||||
# 输入输出
|
||||
input_data = Column(JSONB)
|
||||
output_data = Column(JSONB)
|
||||
|
||||
# 状态
|
||||
status = Column(String(20), nullable=False, default="pending", index=True)
|
||||
# 可选值:pending, running, completed, failed, skipped, cached
|
||||
|
||||
error_message = Column(Text)
|
||||
|
||||
# 性能指标
|
||||
started_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
completed_at = Column(DateTime)
|
||||
elapsed_time = Column(Float) # 耗时(秒)
|
||||
|
||||
# 资源使用(针对 LLM 节点)
|
||||
token_usage = Column(JSONB)
|
||||
|
||||
# 缓存信息
|
||||
cache_hit = Column(Boolean, default=False)
|
||||
cache_key = Column(String(255))
|
||||
|
||||
# 元数据(使用 meta_data 避免与 SQLAlchemy 保留字 metadata 冲突)
|
||||
meta_data = Column(JSONB, default=dict)
|
||||
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
|
||||
# 关系
|
||||
execution = relationship("WorkflowExecution", back_populates="node_executions")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<WorkflowNodeExecution(id={self.id}, node_id={self.node_id}, status={self.status})>"
|
||||
@@ -27,9 +27,9 @@ class ApiKeyRepository:
|
||||
return db.get(ApiKey, api_key_id)
|
||||
|
||||
@staticmethod
|
||||
def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]:
|
||||
"""根据哈希值获取 API Key"""
|
||||
stmt = select(ApiKey).where(ApiKey.key_hash == key_hash)
|
||||
def get_by_api_key(db: Session, api_key: str) -> Optional[ApiKey]:
|
||||
"""根据 API Key 获取 API Key"""
|
||||
stmt = select(ApiKey).where(ApiKey.api_key == api_key)
|
||||
return db.scalars(stmt).first()
|
||||
|
||||
@staticmethod
|
||||
@@ -63,11 +63,15 @@ class ApiKeyRepository:
|
||||
@staticmethod
|
||||
def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey | None:
|
||||
"""更新 API Key"""
|
||||
allow_none_fields = {"description", "quota_limit", "expires_at"}
|
||||
api_key = db.get(ApiKey, api_key_id)
|
||||
if api_key:
|
||||
for key, value in update_data.items():
|
||||
if value is not None:
|
||||
if key in allow_none_fields:
|
||||
setattr(api_key, key, value)
|
||||
else:
|
||||
if value is not None:
|
||||
setattr(api_key, key, value)
|
||||
db.flush()
|
||||
return api_key
|
||||
|
||||
@@ -122,6 +126,7 @@ class ApiKeyRepository:
|
||||
"quota_used": api_key.quota_used,
|
||||
"quota_limit": api_key.quota_limit,
|
||||
"last_used_at": api_key.last_used_at,
|
||||
"rate_limit": api_key.rate_limit,
|
||||
"avg_response_time": float(avg_response_time) if avg_response_time else None
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class AppRepository:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> List[App]:
|
||||
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]:
|
||||
"""根据工作空间ID查询应用"""
|
||||
try:
|
||||
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all()
|
||||
@@ -24,7 +24,19 @@ class AppRepository:
|
||||
db_logger.error(f"查询工作空间 {workspace_id} 下应用时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_apps_by_id(self, app_id: uuid.UUID) -> App:
|
||||
try:
|
||||
app = self.db.query(App).filter(App.id == app_id, App.is_active == True).first()
|
||||
return app
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
|
||||
"""根据工作空间ID查询应用"""
|
||||
repo = AppRepository(db)
|
||||
return repo.get_apps_by_workspace_id(workspace_id)
|
||||
|
||||
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
|
||||
"""根据工作空间ID查询应用"""
|
||||
repo = AppRepository(db)
|
||||
return repo.get_apps_by_id(app_id)
|
||||
|
||||
247
api/app/repositories/workflow_repository.py
Normal file
247
api/app/repositories/workflow_repository.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
工作流数据访问层
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Annotated
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from fastapi import Depends
|
||||
|
||||
from app.models.workflow_model import (
|
||||
WorkflowConfig,
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution
|
||||
)
|
||||
from app.db import get_db
|
||||
|
||||
|
||||
class WorkflowConfigRepository:
|
||||
"""工作流配置仓储"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_app_id(self, app_id: uuid.UUID) -> WorkflowConfig | None:
|
||||
"""根据应用 ID 获取工作流配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
工作流配置或 None
|
||||
"""
|
||||
return self.db.query(WorkflowConfig).filter(
|
||||
WorkflowConfig.app_id == app_id,
|
||||
WorkflowConfig.is_active == True
|
||||
).first()
|
||||
|
||||
def create_or_update(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
nodes: list[dict[str, Any]],
|
||||
edges: list[dict[str, Any]],
|
||||
variables: list[dict[str, Any]] | None = None,
|
||||
execution_config: dict[str, Any] | None = None,
|
||||
triggers: list[dict[str, Any]] | None = None
|
||||
) -> WorkflowConfig:
|
||||
"""创建或更新工作流配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
nodes: 节点列表
|
||||
edges: 边列表
|
||||
variables: 变量列表
|
||||
execution_config: 执行配置
|
||||
triggers: 触发器列表
|
||||
|
||||
Returns:
|
||||
工作流配置
|
||||
"""
|
||||
# 查找现有配置
|
||||
existing = self.get_by_app_id(app_id)
|
||||
|
||||
if existing:
|
||||
# 更新现有配置
|
||||
existing.nodes = nodes
|
||||
existing.edges = edges
|
||||
if variables is not None:
|
||||
existing.variables = variables
|
||||
if execution_config is not None:
|
||||
existing.execution_config = execution_config
|
||||
if triggers is not None:
|
||||
existing.triggers = triggers
|
||||
self.db.commit()
|
||||
self.db.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
# 创建新配置
|
||||
config = WorkflowConfig(
|
||||
app_id=app_id,
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
variables=variables or [],
|
||||
execution_config=execution_config or {},
|
||||
triggers=triggers or []
|
||||
)
|
||||
self.db.add(config)
|
||||
self.db.commit()
|
||||
self.db.refresh(config)
|
||||
return config
|
||||
|
||||
|
||||
class WorkflowExecutionRepository:
|
||||
"""工作流执行记录仓储"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_execution_id(self, execution_id: str) -> WorkflowExecution | None:
|
||||
"""根据执行 ID 获取执行记录
|
||||
|
||||
Args:
|
||||
execution_id: 执行 ID
|
||||
|
||||
Returns:
|
||||
执行记录或 None
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.execution_id == execution_id
|
||||
).first()
|
||||
|
||||
def get_by_app_id(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> list[WorkflowExecution]:
|
||||
"""根据应用 ID 获取执行记录列表
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
limit: 返回数量限制
|
||||
offset: 偏移量
|
||||
|
||||
Returns:
|
||||
执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.app_id == app_id
|
||||
).order_by(
|
||||
desc(WorkflowExecution.started_at)
|
||||
).limit(limit).offset(offset).all()
|
||||
|
||||
def get_by_conversation_id(
|
||||
self,
|
||||
conversation_id: uuid.UUID
|
||||
) -> list[WorkflowExecution]:
|
||||
"""根据会话 ID 获取执行记录列表
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.conversation_id == conversation_id
|
||||
).order_by(
|
||||
desc(WorkflowExecution.started_at)
|
||||
).all()
|
||||
|
||||
def count_by_app_id(self, app_id: uuid.UUID) -> int:
|
||||
"""统计应用的执行次数
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
执行次数
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.app_id == app_id
|
||||
).count()
|
||||
|
||||
def count_by_status(self, app_id: uuid.UUID, status: str) -> int:
|
||||
"""统计指定状态的执行次数
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
status: 状态
|
||||
|
||||
Returns:
|
||||
执行次数
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.app_id == app_id,
|
||||
WorkflowExecution.status == status
|
||||
).count()
|
||||
|
||||
|
||||
class WorkflowNodeExecutionRepository:
|
||||
"""工作流节点执行记录仓储"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_execution_id(
|
||||
self,
|
||||
execution_id: uuid.UUID
|
||||
) -> list[WorkflowNodeExecution]:
|
||||
"""根据执行 ID 获取节点执行记录列表
|
||||
|
||||
Args:
|
||||
execution_id: 执行 ID
|
||||
|
||||
Returns:
|
||||
节点执行记录列表(按执行顺序排序)
|
||||
"""
|
||||
return self.db.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.execution_id == execution_id
|
||||
).order_by(
|
||||
WorkflowNodeExecution.execution_order
|
||||
).all()
|
||||
|
||||
def get_by_node_id(
|
||||
self,
|
||||
execution_id: uuid.UUID,
|
||||
node_id: str
|
||||
) -> list[WorkflowNodeExecution]:
|
||||
"""根据节点 ID 获取节点执行记录(可能有多次重试)
|
||||
|
||||
Args:
|
||||
execution_id: 执行 ID
|
||||
node_id: 节点 ID
|
||||
|
||||
Returns:
|
||||
节点执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.execution_id == execution_id,
|
||||
WorkflowNodeExecution.node_id == node_id
|
||||
).order_by(
|
||||
WorkflowNodeExecution.retry_count
|
||||
).all()
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
def get_workflow_config_repository(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> WorkflowConfigRepository:
|
||||
"""获取工作流配置仓储(依赖注入)"""
|
||||
return WorkflowConfigRepository(db)
|
||||
|
||||
|
||||
def get_workflow_execution_repository(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> WorkflowExecutionRepository:
|
||||
"""获取工作流执行记录仓储(依赖注入)"""
|
||||
return WorkflowExecutionRepository(db)
|
||||
|
||||
|
||||
def get_workflow_node_execution_repository(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> WorkflowNodeExecutionRepository:
|
||||
"""获取工作流节点执行记录仓储(依赖注入)"""
|
||||
return WorkflowNodeExecutionRepository(db)
|
||||
@@ -1,11 +1,11 @@
|
||||
"""API Key Schema"""
|
||||
import datetime
|
||||
import uuid
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from pydantic.v1 import validator
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, field_serializer, computed_field
|
||||
from typing import Optional, List
|
||||
|
||||
from app.models.api_key_model import ApiKeyType, ResourceType
|
||||
from app.models.api_key_model import ApiKeyType
|
||||
from app.core.api_key_utils import timestamp_to_datetime, datetime_to_timestamp
|
||||
|
||||
|
||||
class ApiKeyCreate(BaseModel):
|
||||
@@ -15,20 +15,34 @@ class ApiKeyCreate(BaseModel):
|
||||
type: ApiKeyType = Field(..., description="API Key 类型")
|
||||
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
|
||||
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
|
||||
resource_type: Optional[ResourceType] = Field(None, description="资源类型")
|
||||
rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制(请求/秒)")
|
||||
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
|
||||
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
|
||||
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
||||
|
||||
@validator('scopes')
|
||||
@computed_field
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""检查API Key是否已过期"""
|
||||
if not self.expires_at:
|
||||
return False
|
||||
return datetime.datetime.now() > self.expires_at
|
||||
|
||||
@field_validator('expires_at', mode='before')
|
||||
@classmethod
|
||||
def parse_expires_at(cls, v):
|
||||
"""将时间戳转换为datetime"""
|
||||
if isinstance(v, (int, float)):
|
||||
return timestamp_to_datetime(v)
|
||||
return v
|
||||
|
||||
@field_validator('scopes')
|
||||
@classmethod
|
||||
def validate_scopes(cls, v):
|
||||
"""验证权限范围格式"""
|
||||
valid_scopes = [
|
||||
"app:all",
|
||||
"rag:search", "rag:upload", "rag:delete",
|
||||
"memory:read", "memory:write", "memory:delete", "memory:search"
|
||||
]
|
||||
if v is None:
|
||||
return []
|
||||
valid_scopes = ["app", "rag", "memory"]
|
||||
for scope in v:
|
||||
if scope not in valid_scopes:
|
||||
raise ValueError(f"无效范围: {scope}")
|
||||
@@ -46,14 +60,29 @@ class ApiKeyUpdate(BaseModel):
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
||||
|
||||
@validator('scopes')
|
||||
@computed_field
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""检查API Key是否已过期"""
|
||||
if not self.expires_at:
|
||||
return False
|
||||
return datetime.datetime.now() > self.expires_at
|
||||
|
||||
@field_validator('expires_at', mode='before')
|
||||
@classmethod
|
||||
def parse_expires_at(cls, v):
|
||||
"""将时间戳转换为datetime"""
|
||||
if isinstance(v, (int, float)):
|
||||
return timestamp_to_datetime(v)
|
||||
return v
|
||||
|
||||
@field_validator('scopes')
|
||||
@classmethod
|
||||
def validate_scopes(cls, v):
|
||||
"""验证权限范围格式"""
|
||||
valid_scopes = {
|
||||
'app:all',
|
||||
'rag:search', 'rag:upload', 'rag:delete',
|
||||
'memory:read', 'memory:write', 'memory:delete', 'memory:search'
|
||||
}
|
||||
if v is None:
|
||||
return v
|
||||
valid_scopes = ["app", "rag", "memory"]
|
||||
for scope in v:
|
||||
if scope not in valid_scopes:
|
||||
raise ValueError(f"无效范围: {scope}")
|
||||
@@ -67,18 +96,31 @@ class ApiKeyResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: Optional[str]
|
||||
api_key: str = Field(..., description="API Key 明文(仅创建时返回)")
|
||||
key_prefix: str
|
||||
api_key: str
|
||||
type: str
|
||||
scopes: List[str]
|
||||
resource_id: Optional[uuid.UUID]
|
||||
resource_type: Optional[str]
|
||||
rate_limit: int
|
||||
daily_request_limit: int
|
||||
quota_limit: Optional[int]
|
||||
is_active: bool
|
||||
expires_at: Optional[datetime.datetime]
|
||||
created_at: datetime.datetime
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""检查API Key是否已过期"""
|
||||
if not self.expires_at:
|
||||
return False
|
||||
return datetime.datetime.now() > self.expires_at
|
||||
|
||||
@field_serializer('expires_at', 'created_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
|
||||
class ApiKey(BaseModel):
|
||||
"""API Key 信息(不包含明文 Key)"""
|
||||
@@ -87,11 +129,10 @@ class ApiKey(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: Optional[str]
|
||||
key_prefix: str
|
||||
api_key: str
|
||||
type: str
|
||||
scopes: List[str]
|
||||
resource_id: Optional[uuid.UUID]
|
||||
resource_type: Optional[str]
|
||||
rate_limit: int
|
||||
daily_request_limit: int
|
||||
quota_limit: Optional[int]
|
||||
@@ -105,6 +146,20 @@ class ApiKey(BaseModel):
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""检查API Key是否已过期"""
|
||||
if not self.expires_at:
|
||||
return False
|
||||
return datetime.datetime.now() > self.expires_at
|
||||
|
||||
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
|
||||
class ApiKeyStats(BaseModel):
|
||||
"""API Key 使用统计"""
|
||||
@@ -115,6 +170,12 @@ class ApiKeyStats(BaseModel):
|
||||
last_used_at: Optional[datetime.datetime] = Field(None, description="最后使用时间")
|
||||
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
|
||||
|
||||
@field_serializer('last_used_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
|
||||
class ApiKeyQuery(BaseModel):
|
||||
"""API Key 查询参数"""
|
||||
@@ -132,7 +193,6 @@ class ApiKeyAuth(BaseModel):
|
||||
type: str
|
||||
scopes: List[str]
|
||||
resource_id: Optional[uuid.UUID]
|
||||
resource_type: Optional[str]
|
||||
|
||||
|
||||
class ApiKeyLog(BaseModel):
|
||||
@@ -157,3 +217,9 @@ class ApiKeyLog(BaseModel):
|
||||
|
||||
# 时间信息
|
||||
created_at: datetime.datetime
|
||||
|
||||
@field_serializer('created_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v: datetime.datetime) -> int:
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
@@ -46,6 +46,7 @@ class ConflictResultSchema(BaseModel):
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
@@ -60,6 +61,7 @@ class ConflictSchema(BaseModel):
|
||||
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_data(cls, v):
|
||||
if isinstance(v, dict):
|
||||
d = v.get("data")
|
||||
@@ -88,6 +90,7 @@ class ReflexionResultSchema(BaseModel):
|
||||
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_resolved(cls, v):
|
||||
if isinstance(v, dict):
|
||||
conflict = v.get("conflict")
|
||||
@@ -311,7 +314,7 @@ class ApiResponse(BaseModel): # 通用API响应模型
|
||||
|
||||
|
||||
def _now_ms() -> int:
|
||||
return int(round(time.time() * 1000))
|
||||
return round(time.time() * 1000)
|
||||
|
||||
|
||||
def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) -> ApiResponse:
|
||||
|
||||
215
api/app/schemas/workflow_schema.py
Normal file
215
api/app/schemas/workflow_schema.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
工作流相关的 Pydantic Schema
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Any
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
|
||||
|
||||
# ==================== 节点和边定义 ====================
|
||||
|
||||
class NodeConfig(BaseModel):
|
||||
"""节点配置"""
|
||||
model_config = ConfigDict(extra="allow") # 允许额外字段
|
||||
|
||||
|
||||
class NodeDefinition(BaseModel):
|
||||
"""节点定义"""
|
||||
id: str = Field(..., description="节点唯一标识")
|
||||
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
|
||||
name: str | None = Field(None, description="节点名称")
|
||||
description: str | None = Field(None, description="节点描述")
|
||||
config: dict[str, Any] = Field(default_factory=dict, description="节点配置")
|
||||
position: dict[str, float] | None = Field(None, description="节点位置 {x, y}")
|
||||
error_handling: dict[str, Any] | None = Field(None, description="错误处理配置")
|
||||
cache: dict[str, Any] | None = Field(None, description="缓存配置")
|
||||
|
||||
|
||||
class EdgeDefinition(BaseModel):
|
||||
"""边定义"""
|
||||
id: str | None = Field(None, description="边唯一标识(可选)")
|
||||
source: str = Field(..., description="源节点 ID")
|
||||
target: str = Field(..., description="目标节点 ID")
|
||||
type: str | None = Field(None, description="边类型: normal, error")
|
||||
condition: str | None = Field(None, description="条件表达式(条件边)")
|
||||
label: str | None = Field(None, description="边标签")
|
||||
|
||||
|
||||
class VariableDefinition(BaseModel):
|
||||
"""变量定义"""
|
||||
name: str = Field(..., description="变量名称")
|
||||
type: str = Field(default="string", description="变量类型: string, number, boolean, object, array")
|
||||
required: bool = Field(default=False, description="是否必填")
|
||||
default: Any = Field(None, description="默认值")
|
||||
description: str | None = Field(None, description="变量描述")
|
||||
|
||||
|
||||
class ExecutionConfig(BaseModel):
|
||||
"""执行配置"""
|
||||
max_iterations: int = Field(default=100, ge=1, le=1000, description="最大迭代次数")
|
||||
timeout: int = Field(default=600, ge=10, le=3600, description="全局超时时间(秒)")
|
||||
enable_cache: bool = Field(default=True, description="是否启用节点缓存")
|
||||
parallel_limit: int = Field(default=5, ge=1, le=20, description="并行执行限制")
|
||||
|
||||
|
||||
class TriggerConfig(BaseModel):
|
||||
"""触发器配置"""
|
||||
type: str = Field(..., description="触发器类型: schedule, webhook, event")
|
||||
config: dict[str, Any] = Field(default_factory=dict, description="触发器配置")
|
||||
|
||||
|
||||
# ==================== 工作流配置 ====================
|
||||
|
||||
class WorkflowConfigCreate(BaseModel):
|
||||
"""创建工作流配置"""
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list, description="节点列表")
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list, description="边列表")
|
||||
variables: list[VariableDefinition] = Field(default_factory=list, description="变量列表")
|
||||
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
|
||||
triggers: list[TriggerConfig] = Field(default_factory=list, description="触发器列表")
|
||||
|
||||
|
||||
class WorkflowConfigUpdate(BaseModel):
|
||||
"""更新工作流配置"""
|
||||
nodes: list[NodeDefinition] | None = None
|
||||
edges: list[EdgeDefinition] | None = None
|
||||
variables: list[VariableDefinition] | None = None
|
||||
execution_config: ExecutionConfig | None = None
|
||||
triggers: list[TriggerConfig] | None = None
|
||||
|
||||
|
||||
class WorkflowConfig(BaseModel):
|
||||
"""工作流配置输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
nodes: list[dict[str, Any]]
|
||||
edges: list[dict[str, Any]]
|
||||
variables: list[dict[str, Any]]
|
||||
execution_config: dict[str, Any]
|
||||
triggers: list[dict[str, Any]]
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
class WorkflowExecutionRequest(BaseModel):
|
||||
"""工作流执行请求"""
|
||||
message: str | None = Field(None, description="用户消息(可选)")
|
||||
variables: dict[str, Any] = Field(default_factory=dict, description="输入变量")
|
||||
conversation_id: str | None = Field(None, description="会话 ID(用于关联对话)")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
|
||||
|
||||
class WorkflowExecutionResponse(BaseModel):
|
||||
"""工作流执行响应(非流式)"""
|
||||
execution_id: str = Field(..., description="执行 ID")
|
||||
status: str = Field(..., description="执行状态")
|
||||
output: str | None = Field(None, description="最终输出(字符串,便于快速访问)")
|
||||
output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据")
|
||||
error_message: str | None = Field(None, description="错误信息")
|
||||
elapsed_time: float | None = Field(None, description="耗时(秒)")
|
||||
token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}")
|
||||
|
||||
|
||||
class WorkflowExecutionStreamChunk(BaseModel):
|
||||
"""工作流执行流式响应块"""
|
||||
type: str = Field(..., description="事件类型: node_start, token, node_complete, error_redirect, workflow_complete")
|
||||
execution_id: str = Field(..., description="执行 ID")
|
||||
data: dict[str, Any] = Field(default_factory=dict, description="事件数据")
|
||||
|
||||
|
||||
class WorkflowExecution(BaseModel):
|
||||
"""工作流执行记录输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
workflow_config_id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
conversation_id: uuid.UUID | None
|
||||
execution_id: str
|
||||
trigger_type: str
|
||||
triggered_by: uuid.UUID | None
|
||||
input_data: dict[str, Any] | None
|
||||
output_data: dict[str, Any] | None
|
||||
context: dict[str, Any]
|
||||
status: str
|
||||
error_message: str | None
|
||||
error_node_id: str | None
|
||||
started_at: datetime.datetime
|
||||
completed_at: datetime.datetime | None
|
||||
elapsed_time: float | None
|
||||
token_usage: dict[str, Any] | None
|
||||
meta_data: dict[str, Any]
|
||||
created_at: datetime.datetime
|
||||
|
||||
@field_serializer("started_at", when_used="json")
|
||||
def _serialize_started_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("completed_at", when_used="json")
|
||||
def _serialize_completed_at(self, dt: datetime.datetime | None):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
class WorkflowNodeExecution(BaseModel):
|
||||
"""工作流节点执行记录输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
execution_id: uuid.UUID
|
||||
node_id: str
|
||||
node_type: str
|
||||
node_name: str | None
|
||||
execution_order: int
|
||||
retry_count: int
|
||||
input_data: dict[str, Any] | None
|
||||
output_data: dict[str, Any] | None
|
||||
status: str
|
||||
error_message: str | None
|
||||
started_at: datetime.datetime
|
||||
completed_at: datetime.datetime | None
|
||||
elapsed_time: float | None
|
||||
token_usage: dict[str, Any] | None
|
||||
cache_hit: bool
|
||||
cache_key: str | None
|
||||
meta_data: dict[str, Any]
|
||||
created_at: datetime.datetime
|
||||
|
||||
@field_serializer("started_at", when_used="json")
|
||||
def _serialize_started_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("completed_at", when_used="json")
|
||||
def _serialize_completed_at(self, dt: datetime.datetime | None):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
# ==================== 验证响应 ====================
|
||||
|
||||
class WorkflowValidationResponse(BaseModel):
|
||||
"""工作流验证响应"""
|
||||
is_valid: bool = Field(..., description="是否有效")
|
||||
errors: list[str] = Field(default_factory=list, description="错误列表")
|
||||
warnings: list[str] = Field(default_factory=list, description="警告列表")
|
||||
@@ -13,7 +13,7 @@ from app.models.api_key_model import ApiKey
|
||||
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
|
||||
from app.schemas import api_key_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.core.api_key_utils import generate_api_key, hash_api_key, validate_resource_binding
|
||||
from app.core.api_key_utils import generate_api_key
|
||||
from app.core.exceptions import (
|
||||
BusinessException,
|
||||
)
|
||||
@@ -33,48 +33,39 @@ class ApiKeyService:
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyCreate
|
||||
) -> Tuple[ApiKey, str]:
|
||||
) -> ApiKey:
|
||||
"""
|
||||
创建 API Key
|
||||
Returns:
|
||||
Tuple[ApiKey, str]: (API Key 对象, API Key 明文)
|
||||
ApiKey: API Key 对象
|
||||
"""
|
||||
try:
|
||||
# 验证资源绑定
|
||||
if data.resource_type or data.resource_id:
|
||||
is_valid, error_msg = validate_resource_binding(
|
||||
data.resource_type, str(data.resource_id) if data.resource_id else None
|
||||
)
|
||||
if not is_valid:
|
||||
raise BusinessException(error_msg, BizCode.API_KEY_INVALID_RESOURCE)
|
||||
|
||||
existing = db.scalar(
|
||||
select(ApiKey).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.resource_id == data.resource_id,
|
||||
ApiKey.name == data.name,
|
||||
ApiKey.is_active
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||
|
||||
# 生成 API Key
|
||||
api_key, key_hash, key_prefix = generate_api_key(data.type)
|
||||
api_key = generate_api_key(data.type)
|
||||
|
||||
# 创建数据
|
||||
api_key_data = {
|
||||
"id": uuid.uuid4(),
|
||||
"name": data.name,
|
||||
"description": data.description,
|
||||
"key_prefix": key_prefix,
|
||||
"key_hash": key_hash,
|
||||
"api_key": api_key,
|
||||
"type": data.type,
|
||||
"scopes": data.scopes,
|
||||
"workspace_id": workspace_id,
|
||||
"resource_id": data.resource_id,
|
||||
"resource_type": data.resource_type,
|
||||
"rate_limit": data.rate_limit or 10,
|
||||
"daily_request_limit": data.daily_request_limit or 10000,
|
||||
"rate_limit": data.rate_limit,
|
||||
"daily_request_limit": data.daily_request_limit,
|
||||
"quota_limit": data.quota_limit,
|
||||
"expires_at": data.expires_at,
|
||||
"created_by": user_id,
|
||||
@@ -90,7 +81,7 @@ class ApiKeyService:
|
||||
"type": data.type
|
||||
})
|
||||
|
||||
return api_key_obj, api_key
|
||||
return api_key_obj
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
@@ -152,13 +143,14 @@ class ApiKeyService:
|
||||
existing = db.scalar(
|
||||
select(ApiKey).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.resource_id == data.resource_id,
|
||||
ApiKey.name == data.name,
|
||||
ApiKey.is_active,
|
||||
ApiKey.id != api_key_id
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
ApiKeyRepository.update(db, api_key_id, update_data)
|
||||
@@ -188,7 +180,7 @@ class ApiKeyService:
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> Tuple[ApiKey, str]:
|
||||
) -> ApiKey:
|
||||
"""重新生成 API Key"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
@@ -197,18 +189,17 @@ class ApiKeyService:
|
||||
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
|
||||
|
||||
# 生成新的 API Key
|
||||
new_api_key, key_hash, key_prefix = generate_api_key(api_key_schema.ApiKeyType(api_key.type))
|
||||
new_api_key = generate_api_key(api_key.type)
|
||||
|
||||
# 更新
|
||||
ApiKeyRepository.update(db, api_key_id, {
|
||||
"key_hash": key_hash,
|
||||
"key_prefix": key_prefix
|
||||
"api_key": new_api_key
|
||||
})
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
logger.info("API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
|
||||
return api_key, new_api_key
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
def get_stats(
|
||||
@@ -330,7 +321,6 @@ class RateLimiterService:
|
||||
"X-RateLimit-Reset": str(qps_info["reset"])
|
||||
}
|
||||
|
||||
# Check daily requests
|
||||
daily_ok, daily_info = await self.check_daily_requests(
|
||||
api_key.id,
|
||||
api_key.daily_request_limit
|
||||
@@ -342,7 +332,6 @@ class RateLimiterService:
|
||||
"X-RateLimit-Reset": str(daily_info["reset"])
|
||||
}
|
||||
|
||||
# All checks passed
|
||||
headers = {
|
||||
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
|
||||
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
|
||||
@@ -363,13 +352,12 @@ class ApiKeyAuthService:
|
||||
验证API Key 有效性
|
||||
|
||||
检查:
|
||||
1. Key hash 是否存在
|
||||
1. API Key 是否存在
|
||||
2. is_active 是否为true
|
||||
3. expires_at 是否未过期
|
||||
4. quota 是否未超限
|
||||
"""
|
||||
key_hash = hash_api_key(api_key)
|
||||
api_key_obj = ApiKeyRepository.get_by_hash(db, key_hash)
|
||||
api_key_obj = ApiKeyRepository.get_by_api_key(db, api_key)
|
||||
|
||||
if not api_key_obj:
|
||||
return None
|
||||
@@ -393,14 +381,7 @@ class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def check_resource(
|
||||
api_key: ApiKey,
|
||||
resource_type: str,
|
||||
resource_id: uuid.UUID
|
||||
) -> bool:
|
||||
"""检查资源绑定"""
|
||||
if not api_key.resource_id:
|
||||
return True
|
||||
|
||||
return (
|
||||
api_key.resource_type == resource_type and
|
||||
api_key.resource_id == resource_id
|
||||
)
|
||||
return api_key.resource_id == resource_id
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,9 +4,12 @@ Memory Storage Service
|
||||
Handles business logic for memory storage operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -14,6 +17,7 @@ from dotenv import load_dotenv
|
||||
from app.models.user_model import User
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.core.logging_config import get_logger
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
ConfigPilotRun,
|
||||
@@ -225,101 +229,175 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
return self._convert_timestamps_to_format(data_list)
|
||||
|
||||
|
||||
async def pilot_run(self, payload: ConfigPilotRun) -> Dict[str, Any]:
|
||||
async def pilot_run_stream(self, payload: ConfigPilotRun) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
选择策略与内存覆写与同步版保持一致:优先 payload.config_id,其次 dbrun.json;两者皆无时报错。
|
||||
支持 dialogue_text 参数用于试运行模式。
|
||||
流式执行试运行,产生 SSE 格式的进度事件
|
||||
|
||||
Args:
|
||||
payload: 试运行配置和对话文本
|
||||
|
||||
Yields:
|
||||
SSE 格式的字符串,包含以下事件类型:
|
||||
- 各种阶段名称: 进度更新 (如 starting, knowledge_extraction_complete 等)
|
||||
- result: 最终结果
|
||||
- error: 错误信息
|
||||
- done: 完成标记
|
||||
|
||||
Raises:
|
||||
ValueError: 当配置无效或参数缺失时
|
||||
RuntimeError: 当管线执行失败时
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
|
||||
|
||||
try:
|
||||
# 发出初始进度事件
|
||||
yield format_sse_message("starting", {
|
||||
"message": "开始试运行...",
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
# 步骤 1: 配置加载和验证(复用现有逻辑)
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN_STREAM] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
# 应用内存覆写并刷新常量
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
|
||||
# 应用内存覆写并刷新常量(在导入主管线前)
|
||||
# 注意:仅在内存中覆写配置,不修改 runtime.json 文件
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
|
||||
# 导入并 await 主管线(使用当前 ASGI 事件循环)
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
from app.core.memory.utils.self_reflexion_utils import reflexion
|
||||
|
||||
logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
|
||||
logger.info("[PILOT_RUN] pipeline_main completed")
|
||||
|
||||
# 调用自我反思
|
||||
# data = [
|
||||
# {
|
||||
# "data": {
|
||||
# "id": "1",
|
||||
# "statement": "张明现在在谷歌工作。",
|
||||
# "group_id": "1",
|
||||
# "chunk_id": "10",
|
||||
# "created_at": "2023-01-01",
|
||||
# "expired_at": "2023-01-02",
|
||||
# "valid_at": "2023-01-01",
|
||||
# "invalid_at": "2023-01-02",
|
||||
# "entity_ids": []
|
||||
# },
|
||||
# "conflict": True,
|
||||
# "conflict_memory": {
|
||||
# "id": "1",
|
||||
# "statement": "张明现在在清华大学当讲师。",
|
||||
# "group_id": "1",
|
||||
# "chunk_id": "1",
|
||||
# "created_at": "2019-12-01T19:15:05.213210",
|
||||
# "expired_at": None,
|
||||
# "valid_at": None,
|
||||
# "invalid_at": None,
|
||||
# "entity_ids": []
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
from app.core.memory.utils.config.get_example_data import get_example_data
|
||||
data = get_example_data()
|
||||
reflexion_result = await reflexion(data)
|
||||
|
||||
# 读取输出,使用全局配置路径
|
||||
from app.core.config import settings
|
||||
result_path = settings.get_memory_output_path("extracted_result.json")
|
||||
if not os.path.isfile(result_path):
|
||||
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
|
||||
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None
|
||||
return {
|
||||
"config_id": cid,
|
||||
"time_log": os.path.join(project_root, "time.log"),
|
||||
"extracted_result": extracted_result,
|
||||
}
|
||||
# 步骤 2: 创建进度回调函数捕获管线进度
|
||||
# 使用队列在回调和生成器之间传递进度事件
|
||||
progress_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""
|
||||
进度回调函数,将进度事件放入队列
|
||||
|
||||
Args:
|
||||
stage: 阶段标识
|
||||
message: 进度消息
|
||||
data: 可选的结果数据(用于传递节点执行结果)
|
||||
"""
|
||||
await progress_queue.put((stage, message, data))
|
||||
|
||||
# 步骤 3: 在后台任务中执行管线
|
||||
async def run_pipeline():
|
||||
"""在后台执行管线并捕获异常"""
|
||||
try:
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(
|
||||
dialogue_text=dialogue_text,
|
||||
is_pilot_run=True,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
|
||||
|
||||
# 标记管线完成
|
||||
await progress_queue.put(("__PIPELINE_COMPLETE__", "", None))
|
||||
except Exception as e:
|
||||
# 将异常放入队列
|
||||
await progress_queue.put(("__PIPELINE_ERROR__", str(e), None))
|
||||
|
||||
# 启动后台任务
|
||||
pipeline_task = asyncio.create_task(run_pipeline())
|
||||
|
||||
# 步骤 4: 从队列中读取进度事件并发出
|
||||
while True:
|
||||
try:
|
||||
# 等待进度事件,设置超时以检测客户端断开
|
||||
stage, message, data = await asyncio.wait_for(
|
||||
progress_queue.get(),
|
||||
timeout=0.5
|
||||
)
|
||||
|
||||
# 检查特殊标记
|
||||
if stage == "__PIPELINE_COMPLETE__":
|
||||
break
|
||||
elif stage == "__PIPELINE_ERROR__":
|
||||
raise RuntimeError(message)
|
||||
|
||||
# 构建进度事件数据
|
||||
progress_data = {
|
||||
"message": message,
|
||||
"time": int(time.time() * 1000)
|
||||
}
|
||||
|
||||
# 如果有结果数据,添加到事件中
|
||||
if data:
|
||||
progress_data["data"] = data
|
||||
|
||||
# 发出进度事件,使用 stage 作为事件类型
|
||||
yield format_sse_message(stage, progress_data)
|
||||
|
||||
except TimeoutError:
|
||||
# 超时,继续等待(这允许检测客户端断开)
|
||||
continue
|
||||
|
||||
# 等待管线任务完成
|
||||
await pipeline_task
|
||||
|
||||
# 步骤 5: 读取提取结果
|
||||
from app.core.config import settings
|
||||
result_path = settings.get_memory_output_path("extracted_result.json")
|
||||
if not os.path.isfile(result_path):
|
||||
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
|
||||
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
# 步骤 6: 发出结果事件
|
||||
result_data = {
|
||||
"config_id": cid,
|
||||
"time_log": os.path.join(project_root, "logs", "time.log"),
|
||||
"extracted_result": extracted_result,
|
||||
}
|
||||
yield format_sse_message("result", result_data)
|
||||
|
||||
# 步骤 7: 发出完成事件
|
||||
yield format_sse_message("done", {
|
||||
"message": "试运行完成",
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 客户端断开连接
|
||||
logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming")
|
||||
raise
|
||||
except Exception as e:
|
||||
# 发出错误事件
|
||||
logger.error(f"[PILOT_RUN_STREAM] Error during streaming: {e}", exc_info=True)
|
||||
yield format_sse_message("error", {
|
||||
"code": 5000,
|
||||
"message": "试运行失败",
|
||||
"error": str(e),
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
|
||||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||||
|
||||
731
api/app/services/workflow_service.py
Normal file
731
api/app/services/workflow_service.py
Normal file
@@ -0,0 +1,731 @@
|
||||
"""
|
||||
工作流服务层
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Any, Annotated
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import Depends
|
||||
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
from app.repositories.workflow_repository import (
|
||||
WorkflowConfigRepository,
|
||||
WorkflowExecutionRepository,
|
||||
WorkflowNodeExecutionRepository,
|
||||
get_workflow_config_repository,
|
||||
get_workflow_execution_repository,
|
||||
get_workflow_node_execution_repository
|
||||
)
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.db import get_db
|
||||
from app.schemas import DraftRunRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowService:
|
||||
"""工作流服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.config_repo = WorkflowConfigRepository(db)
|
||||
self.execution_repo = WorkflowExecutionRepository(db)
|
||||
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
||||
|
||||
# ==================== 配置管理 ====================
|
||||
|
||||
def create_workflow_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
nodes: list[dict[str, Any]],
|
||||
edges: list[dict[str, Any]],
|
||||
variables: list[dict[str, Any]] | None = None,
|
||||
execution_config: dict[str, Any] | None = None,
|
||||
triggers: list[dict[str, Any]] | None = None,
|
||||
validate: bool = True
|
||||
) -> WorkflowConfig:
|
||||
"""创建工作流配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
nodes: 节点列表
|
||||
edges: 边列表
|
||||
variables: 变量列表
|
||||
execution_config: 执行配置
|
||||
triggers: 触发器列表
|
||||
validate: 是否验证配置
|
||||
|
||||
Returns:
|
||||
工作流配置
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置无效时抛出
|
||||
"""
|
||||
# 构建配置字典
|
||||
config_dict = {
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"variables": variables or [],
|
||||
"execution_config": execution_config or {},
|
||||
"triggers": triggers or []
|
||||
}
|
||||
|
||||
# 验证配置
|
||||
if validate:
|
||||
is_valid, errors = validate_workflow_config(config_dict, for_publish=False)
|
||||
if not is_valid:
|
||||
logger.warning(f"工作流配置验证失败: {errors}")
|
||||
raise BusinessException(
|
||||
error_code=BizCode.INVALID_PARAMETER,
|
||||
message=f"工作流配置无效: {'; '.join(errors)}"
|
||||
)
|
||||
|
||||
# 创建或更新配置
|
||||
config = self.config_repo.create_or_update(
|
||||
app_id=app_id,
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
variables=variables,
|
||||
execution_config=execution_config,
|
||||
triggers=triggers
|
||||
)
|
||||
|
||||
logger.info(f"创建工作流配置成功: app_id={app_id}, config_id={config.id}")
|
||||
return config
|
||||
|
||||
def get_workflow_config(self, app_id: uuid.UUID) -> WorkflowConfig | None:
|
||||
"""获取工作流配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
工作流配置或 None
|
||||
"""
|
||||
return self.config_repo.get_by_app_id(app_id)
|
||||
|
||||
def update_workflow_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
nodes: list[dict[str, Any]] | None = None,
|
||||
edges: list[dict[str, Any]] | None = None,
|
||||
variables: list[dict[str, Any]] | None = None,
|
||||
execution_config: dict[str, Any] | None = None,
|
||||
triggers: list[dict[str, Any]] | None = None,
|
||||
validate: bool = True
|
||||
) -> WorkflowConfig:
|
||||
"""更新工作流配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
nodes: 节点列表
|
||||
edges: 边列表
|
||||
variables: 变量列表
|
||||
execution_config: 执行配置
|
||||
triggers: 触发器列表
|
||||
validate: 是否验证配置
|
||||
|
||||
Returns:
|
||||
工作流配置
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置不存在或无效时抛出
|
||||
"""
|
||||
# 获取现有配置
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
# 合并配置
|
||||
updated_nodes = nodes if nodes is not None else config.nodes
|
||||
updated_edges = edges if edges is not None else config.edges
|
||||
updated_variables = variables if variables is not None else config.variables
|
||||
updated_execution_config = execution_config if execution_config is not None else config.execution_config
|
||||
updated_triggers = triggers if triggers is not None else config.triggers
|
||||
|
||||
# 构建配置字典
|
||||
config_dict = {
|
||||
"nodes": updated_nodes,
|
||||
"edges": updated_edges,
|
||||
"variables": updated_variables,
|
||||
"execution_config": updated_execution_config,
|
||||
"triggers": updated_triggers
|
||||
}
|
||||
|
||||
# 验证配置
|
||||
if validate:
|
||||
is_valid, errors = validate_workflow_config(config_dict, for_publish=False)
|
||||
if not is_valid:
|
||||
logger.warning(f"工作流配置验证失败: {errors}")
|
||||
raise BusinessException(
|
||||
error_code=BizCode.INVALID_PARAMETER,
|
||||
message=f"工作流配置无效: {'; '.join(errors)}"
|
||||
)
|
||||
|
||||
# 更新配置
|
||||
config = self.config_repo.create_or_update(
|
||||
app_id=app_id,
|
||||
nodes=updated_nodes,
|
||||
edges=updated_edges,
|
||||
variables=updated_variables,
|
||||
execution_config=updated_execution_config,
|
||||
triggers=updated_triggers
|
||||
)
|
||||
|
||||
logger.info(f"更新工作流配置成功: app_id={app_id}, config_id={config.id}")
|
||||
return config
|
||||
|
||||
def delete_workflow_config(self, app_id: uuid.UUID) -> bool:
|
||||
"""删除工作流配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
return False
|
||||
|
||||
self.config_repo.delete(config.id)
|
||||
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
|
||||
return True
|
||||
|
||||
def check_config(self, app_id: uuid.UUID) -> WorkflowConfig:
|
||||
"""检查工作流配置的完整性
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置不完整或不存在时抛出
|
||||
"""
|
||||
|
||||
# 1. 检查多智能体配置是否存在
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
"工作流配置不存在,无法运行",
|
||||
BizCode.CONFIG_MISSING
|
||||
)
|
||||
# validator 现在支持直接接受 Pydantic 模型
|
||||
is_valid, errors = validate_workflow_config(config, for_publish=False)
|
||||
if not is_valid:
|
||||
logger.warning(f"工作流配置验证失败: {errors}")
|
||||
raise BusinessException(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
message=f"工作流配置无效: {'; '.join(errors)}"
|
||||
)
|
||||
return config
|
||||
|
||||
def validate_workflow_config_for_publish(
|
||||
self,
|
||||
app_id: uuid.UUID
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置是否可以发布
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
(is_valid, errors): 是否有效和错误列表
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置不存在时抛出
|
||||
"""
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config,
|
||||
"triggers": config.triggers
|
||||
}
|
||||
|
||||
return validate_workflow_config(config_dict, for_publish=True)
|
||||
|
||||
# ==================== 执行管理 ====================
|
||||
|
||||
def create_execution(
|
||||
self,
|
||||
workflow_config_id: uuid.UUID,
|
||||
app_id: uuid.UUID,
|
||||
trigger_type: str,
|
||||
triggered_by: uuid.UUID | None = None,
|
||||
conversation_id: uuid.UUID | None = None,
|
||||
input_data: dict[str, Any] | None = None
|
||||
) -> WorkflowExecution:
|
||||
"""创建工作流执行记录
|
||||
|
||||
Args:
|
||||
workflow_config_id: 工作流配置 ID
|
||||
app_id: 应用 ID
|
||||
trigger_type: 触发类型
|
||||
triggered_by: 触发用户 ID
|
||||
conversation_id: 会话 ID
|
||||
input_data: 输入数据
|
||||
|
||||
Returns:
|
||||
执行记录
|
||||
"""
|
||||
# 生成执行 ID
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
execution = WorkflowExecution(
|
||||
workflow_config_id=workflow_config_id,
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
execution_id=execution_id,
|
||||
trigger_type=trigger_type,
|
||||
triggered_by=triggered_by,
|
||||
input_data=input_data or {},
|
||||
status="pending"
|
||||
)
|
||||
|
||||
self.db.add(execution)
|
||||
self.db.commit()
|
||||
self.db.refresh(execution)
|
||||
|
||||
logger.info(f"创建工作流执行记录: execution_id={execution_id}")
|
||||
return execution
|
||||
|
||||
def get_execution(self, execution_id: str) -> WorkflowExecution | None:
|
||||
"""获取执行记录
|
||||
|
||||
Args:
|
||||
execution_id: 执行 ID
|
||||
|
||||
Returns:
|
||||
执行记录或 None
|
||||
"""
|
||||
return self.execution_repo.get_by_execution_id(execution_id)
|
||||
|
||||
def get_executions_by_app(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> list[WorkflowExecution]:
|
||||
"""获取应用的执行记录列表
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
limit: 返回数量限制
|
||||
offset: 偏移量
|
||||
|
||||
Returns:
|
||||
执行记录列表
|
||||
"""
|
||||
return self.execution_repo.get_by_app_id(app_id, limit, offset)
|
||||
|
||||
def update_execution_status(
|
||||
self,
|
||||
execution_id: str,
|
||||
status: str,
|
||||
output_data: dict[str, Any] | None = None,
|
||||
error_message: str | None = None,
|
||||
error_node_id: str | None = None
|
||||
) -> WorkflowExecution:
|
||||
"""更新执行状态
|
||||
|
||||
Args:
|
||||
execution_id: 执行 ID
|
||||
status: 状态
|
||||
output_data: 输出数据
|
||||
error_message: 错误信息
|
||||
error_node_id: 出错节点 ID
|
||||
|
||||
Returns:
|
||||
执行记录
|
||||
|
||||
Raises:
|
||||
BusinessException: 执行记录不存在时抛出
|
||||
"""
|
||||
execution = self.get_execution(execution_id)
|
||||
if not execution:
|
||||
raise BusinessException(
|
||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||
message=f"执行记录不存在: execution_id={execution_id}"
|
||||
)
|
||||
|
||||
execution.status = status
|
||||
if output_data is not None:
|
||||
execution.output_data = output_data
|
||||
if error_message is not None:
|
||||
execution.error_message = error_message
|
||||
if error_node_id is not None:
|
||||
execution.error_node_id = error_node_id
|
||||
|
||||
# 如果是完成状态,计算耗时
|
||||
if status in ["completed", "failed", "cancelled", "timeout"]:
|
||||
if not execution.completed_at:
|
||||
execution.completed_at = datetime.datetime.now()
|
||||
elapsed = (execution.completed_at - execution.started_at).total_seconds()
|
||||
execution.elapsed_time = elapsed
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(execution)
|
||||
|
||||
logger.info(f"更新执行状态: execution_id={execution_id}, status={status}")
|
||||
return execution
|
||||
|
||||
def get_execution_statistics(self, app_id: uuid.UUID) -> dict[str, Any]:
|
||||
"""获取执行统计信息
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
统计信息
|
||||
"""
|
||||
total = self.execution_repo.count_by_app_id(app_id)
|
||||
completed = self.execution_repo.count_by_status(app_id, "completed")
|
||||
failed = self.execution_repo.count_by_status(app_id, "failed")
|
||||
running = self.execution_repo.count_by_status(app_id, "running")
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"completed": completed,
|
||||
"failed": failed,
|
||||
"running": running,
|
||||
"success_rate": completed / total if total > 0 else 0
|
||||
}
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
async def run(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
payload: DraftRunRequest,
|
||||
config: WorkflowConfig
|
||||
):
|
||||
"""运行工作流
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
input_data: 输入数据(包含 message 和 variables)
|
||||
triggered_by: 触发用户 ID
|
||||
conversation_id: 会话 ID(可选)
|
||||
stream: 是否流式返回
|
||||
|
||||
Returns:
|
||||
执行结果(非流式)或生成器(流式)
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置不存在或执行失败时抛出
|
||||
"""
|
||||
# 1. 获取工作流配置
|
||||
if not config:
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
code=BizCode.CONFIG_MISSING,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
|
||||
|
||||
# 转换 user_id 为 UUID
|
||||
triggered_by_uuid = None
|
||||
if payload.user_id:
|
||||
try:
|
||||
triggered_by_uuid = uuid.UUID(payload.user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = None
|
||||
if payload.conversation_id:
|
||||
try:
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = self.create_execution(
|
||||
workflow_config_id=config.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=triggered_by_uuid,
|
||||
conversation_id=conversation_id_uuid,
|
||||
input_data=input_data
|
||||
)
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
workflow_config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
from app.models import App
|
||||
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id="",
|
||||
user_id=payload.user_id
|
||||
)
|
||||
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result.get("node_outputs", {})
|
||||
)
|
||||
else:
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=result.get("error")
|
||||
)
|
||||
|
||||
# 返回增强的响应结构
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
"error_message": result.get("error"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"token_usage": result.get("token_usage")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
raise BusinessException(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def run_workflow(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
input_data: dict[str, Any],
|
||||
triggered_by: uuid.UUID,
|
||||
conversation_id: uuid.UUID | None = None,
|
||||
stream: bool = False
|
||||
):
|
||||
"""运行工作流
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
input_data: 输入数据(包含 message 和 variables)
|
||||
triggered_by: 触发用户 ID
|
||||
conversation_id: 会话 ID(可选)
|
||||
stream: 是否流式返回
|
||||
|
||||
Returns:
|
||||
执行结果(非流式)或生成器(流式)
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置不存在或执行失败时抛出
|
||||
"""
|
||||
# 1. 获取工作流配置
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = self.create_execution(
|
||||
workflow_config_id=config.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=triggered_by,
|
||||
conversation_id=conversation_id,
|
||||
input_data=input_data
|
||||
)
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
workflow_config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
from app.models import App
|
||||
app = self.db.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
raise BusinessException(
|
||||
error_code=BizCode.RESOURCE_NOT_FOUND,
|
||||
message=f"应用不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
if stream:
|
||||
# 流式执行
|
||||
return self._run_workflow_stream(
|
||||
workflow_config_dict,
|
||||
input_data,
|
||||
execution.execution_id,
|
||||
str(app.workspace_id),
|
||||
str(triggered_by)
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(app.workspace_id),
|
||||
user_id=str(triggered_by)
|
||||
)
|
||||
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result.get("node_outputs", {})
|
||||
)
|
||||
else:
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=result.get("error")
|
||||
)
|
||||
|
||||
# 返回增强的响应结构
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"error_message": result.get("error"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"token_usage": result.get("token_usage")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
raise BusinessException(
|
||||
error_code=BizCode.INTERNAL_ERROR,
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def _run_workflow_stream(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""运行工作流(流式,内部方法)
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
input_data: 输入数据
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
Yields:
|
||||
流式事件
|
||||
"""
|
||||
from app.core.workflow.executor import execute_workflow_stream
|
||||
|
||||
try:
|
||||
output_data = {}
|
||||
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config,
|
||||
input_data=input_data,
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
):
|
||||
# 转发事件
|
||||
yield event
|
||||
|
||||
# 收集输出数据
|
||||
if event.get("type") == "node_complete":
|
||||
node_data = event.get("data", {})
|
||||
node_outputs = node_data.get("node_outputs", {})
|
||||
output_data.update(node_outputs)
|
||||
|
||||
# 处理完成事件
|
||||
if event.get("type") == "workflow_complete":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"completed",
|
||||
output_data=output_data
|
||||
)
|
||||
|
||||
# 处理错误事件
|
||||
if event.get("type") == "workflow_error":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"failed",
|
||||
error_message=event.get("error")
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
yield {
|
||||
"type": "workflow_error",
|
||||
"execution_id": execution_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
def get_workflow_service(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> WorkflowService:
|
||||
"""获取工作流服务(依赖注入)"""
|
||||
return WorkflowService(db)
|
||||
219
api/app/templates/workflows/customer_service/template.yml
Normal file
219
api/app/templates/workflows/customer_service/template.yml
Normal file
@@ -0,0 +1,219 @@
|
||||
# 智能客服工作流模板
|
||||
id: customer_service_v1
|
||||
name: 智能客服工作流
|
||||
description: 智能客服场景,包含意图识别、知识库查询和回复生成
|
||||
category: customer_service
|
||||
version: "1.0.0"
|
||||
author: RedBear Memory Team
|
||||
tags:
|
||||
- 客服
|
||||
- 意图识别
|
||||
- 知识库
|
||||
- 多步骤
|
||||
|
||||
# 工作流配置
|
||||
nodes:
|
||||
- id: start
|
||||
type: start
|
||||
name: 开始
|
||||
position:
|
||||
x: 100
|
||||
y: 200
|
||||
|
||||
- id: intent_recognition
|
||||
type: llm
|
||||
name: 意图识别
|
||||
config:
|
||||
prompt: |
|
||||
分析用户的问题,识别意图类型。
|
||||
|
||||
用户问题:{{ var.user_message }}
|
||||
|
||||
请从以下类型中选择一个:
|
||||
- product_inquiry: 产品咨询
|
||||
- technical_support: 技术支持
|
||||
- complaint: 投诉建议
|
||||
- other: 其他
|
||||
|
||||
只返回类型名称,不要其他内容。
|
||||
model: gpt-3.5-turbo
|
||||
temperature: 0.3
|
||||
max_tokens: 50
|
||||
position:
|
||||
x: 300
|
||||
y: 200
|
||||
|
||||
- id: intent_router
|
||||
type: condition
|
||||
name: 意图路由
|
||||
position:
|
||||
x: 500
|
||||
y: 200
|
||||
|
||||
- id: product_handler
|
||||
type: llm
|
||||
name: 产品咨询处理
|
||||
config:
|
||||
prompt: |
|
||||
用户咨询产品相关问题。
|
||||
|
||||
问题:{{ var.user_message }}
|
||||
意图:{{ node.intent_recognition.output }}
|
||||
|
||||
请提供专业、友好的产品咨询回复。
|
||||
model: gpt-3.5-turbo
|
||||
temperature: 0.7
|
||||
max_tokens: 500
|
||||
position:
|
||||
x: 700
|
||||
y: 100
|
||||
|
||||
- id: support_handler
|
||||
type: llm
|
||||
name: 技术支持处理
|
||||
config:
|
||||
prompt: |
|
||||
用户需要技术支持。
|
||||
|
||||
问题:{{ var.user_message }}
|
||||
意图:{{ node.intent_recognition.output }}
|
||||
|
||||
请提供详细的技术支持方案。
|
||||
model: gpt-3.5-turbo
|
||||
temperature: 0.5
|
||||
max_tokens: 800
|
||||
position:
|
||||
x: 700
|
||||
y: 200
|
||||
|
||||
- id: complaint_handler
|
||||
type: llm
|
||||
name: 投诉处理
|
||||
config:
|
||||
prompt: |
|
||||
用户提出投诉或建议。
|
||||
|
||||
问题:{{ var.user_message }}
|
||||
意图:{{ node.intent_recognition.output }}
|
||||
|
||||
请以同理心回应,并提供解决方案。
|
||||
model: gpt-3.5-turbo
|
||||
temperature: 0.8
|
||||
max_tokens: 600
|
||||
position:
|
||||
x: 700
|
||||
y: 300
|
||||
|
||||
- id: general_handler
|
||||
type: llm
|
||||
name: 通用处理
|
||||
config:
|
||||
prompt: |
|
||||
用户的问题类型:其他
|
||||
|
||||
问题:{{ var.user_message }}
|
||||
|
||||
请提供友好的回复。
|
||||
model: gpt-3.5-turbo
|
||||
temperature: 0.7
|
||||
max_tokens: 400
|
||||
position:
|
||||
x: 700
|
||||
y: 400
|
||||
|
||||
- id: end
|
||||
type: end
|
||||
name: 结束
|
||||
position:
|
||||
x: 900
|
||||
y: 200
|
||||
|
||||
edges:
|
||||
- source: start
|
||||
target: intent_recognition
|
||||
label: 开始分析
|
||||
|
||||
- source: intent_recognition
|
||||
target: intent_router
|
||||
label: 识别完成
|
||||
|
||||
- source: intent_router
|
||||
target: product_handler
|
||||
condition: "'product_inquiry' in node['intent_recognition']['output']"
|
||||
label: 产品咨询
|
||||
|
||||
- source: intent_router
|
||||
target: support_handler
|
||||
condition: "'technical_support' in node['intent_recognition']['output']"
|
||||
label: 技术支持
|
||||
|
||||
- source: intent_router
|
||||
target: complaint_handler
|
||||
condition: "'complaint' in node['intent_recognition']['output']"
|
||||
label: 投诉建议
|
||||
|
||||
- source: intent_router
|
||||
target: general_handler
|
||||
condition: "True" # 默认路径
|
||||
label: 其他
|
||||
|
||||
- source: product_handler
|
||||
target: end
|
||||
label: 完成
|
||||
|
||||
- source: support_handler
|
||||
target: end
|
||||
label: 完成
|
||||
|
||||
- source: complaint_handler
|
||||
target: end
|
||||
label: 完成
|
||||
|
||||
- source: general_handler
|
||||
target: end
|
||||
label: 完成
|
||||
|
||||
# 变量定义
|
||||
variables:
|
||||
- name: user_message
|
||||
type: string
|
||||
required: true
|
||||
description: 用户的消息
|
||||
default: ""
|
||||
|
||||
- name: user_name
|
||||
type: string
|
||||
required: false
|
||||
description: 用户姓名(可选)
|
||||
default: "客户"
|
||||
|
||||
# 执行配置
|
||||
execution_config:
|
||||
max_execution_time: 120
|
||||
max_iterations: 10
|
||||
|
||||
# 触发器
|
||||
triggers: []
|
||||
|
||||
# 使用示例
|
||||
examples:
|
||||
- name: 产品咨询
|
||||
description: 用户咨询产品功能
|
||||
input:
|
||||
user_message: "你们的产品支持多语言吗?"
|
||||
user_name: "张三"
|
||||
expected_output: "产品功能介绍"
|
||||
|
||||
- name: 技术支持
|
||||
description: 用户遇到技术问题
|
||||
input:
|
||||
user_message: "我无法登录系统,一直显示密码错误"
|
||||
user_name: "李四"
|
||||
expected_output: "技术支持方案"
|
||||
|
||||
- name: 投诉处理
|
||||
description: 用户提出投诉
|
||||
input:
|
||||
user_message: "你们的服务态度太差了,我要投诉"
|
||||
user_name: "王五"
|
||||
expected_output: "同理心回应和解决方案"
|
||||
131
api/app/templates/workflows/data_processing/template.yml
Normal file
131
api/app/templates/workflows/data_processing/template.yml
Normal file
@@ -0,0 +1,131 @@
|
||||
# 数据处理工作流模板
|
||||
id: data_processing_v1
|
||||
name: 数据处理工作流
|
||||
description: 数据提取、转换和分析的完整流程
|
||||
category: data_processing
|
||||
version: "1.0.0"
|
||||
author: RedBear Memory Team
|
||||
tags:
|
||||
- 数据处理
|
||||
- ETL
|
||||
- 分析
|
||||
- Transform
|
||||
|
||||
# 工作流配置
|
||||
nodes:
|
||||
- id: start
|
||||
type: start
|
||||
name: 开始
|
||||
position:
|
||||
x: 100
|
||||
y: 200
|
||||
|
||||
- id: extract_data
|
||||
type: transform
|
||||
name: 数据提取
|
||||
config:
|
||||
expression: |
|
||||
{
|
||||
"raw_text": var['input_text'],
|
||||
"length": len(var['input_text']),
|
||||
"timestamp": sys['execution_id']
|
||||
}
|
||||
position:
|
||||
x: 300
|
||||
y: 200
|
||||
|
||||
- id: analyze_data
|
||||
type: llm
|
||||
name: 数据分析
|
||||
config:
|
||||
prompt: |
|
||||
请分析以下数据:
|
||||
|
||||
原始文本:{{ node.extract_data.raw_text }}
|
||||
文本长度:{{ node.extract_data.length }}
|
||||
|
||||
请提供:
|
||||
1. 主题分类
|
||||
2. 情感分析
|
||||
3. 关键信息提取
|
||||
|
||||
以 JSON 格式返回结果。
|
||||
model: gpt-3.5-turbo
|
||||
temperature: 0.3
|
||||
max_tokens: 500
|
||||
position:
|
||||
x: 500
|
||||
y: 200
|
||||
|
||||
- id: transform_result
|
||||
type: transform
|
||||
name: 结果转换
|
||||
config:
|
||||
expression: |
|
||||
{
|
||||
"original_length": node['extract_data']['length'],
|
||||
"analysis": node['analyze_data']['output'],
|
||||
"processed_at": sys['execution_id'],
|
||||
"status": "completed"
|
||||
}
|
||||
position:
|
||||
x: 700
|
||||
y: 200
|
||||
|
||||
- id: end
|
||||
type: end
|
||||
name: 结束
|
||||
position:
|
||||
x: 900
|
||||
y: 200
|
||||
|
||||
edges:
|
||||
- source: start
|
||||
target: extract_data
|
||||
label: 开始提取
|
||||
|
||||
- source: extract_data
|
||||
target: analyze_data
|
||||
label: 开始分析
|
||||
|
||||
- source: analyze_data
|
||||
target: transform_result
|
||||
label: 转换结果
|
||||
|
||||
- source: transform_result
|
||||
target: end
|
||||
label: 完成
|
||||
|
||||
# 变量定义
|
||||
variables:
|
||||
- name: input_text
|
||||
type: string
|
||||
required: true
|
||||
description: 待处理的文本数据
|
||||
default: ""
|
||||
|
||||
# 执行配置
|
||||
execution_config:
|
||||
max_execution_time: 180
|
||||
max_iterations: 5
|
||||
|
||||
# 触发器
|
||||
triggers: []
|
||||
|
||||
# 使用示例
|
||||
examples:
|
||||
- name: 文本分析
|
||||
description: 分析一段文本
|
||||
input:
|
||||
input_text: "今天天气真好,心情也很愉快。我们公司推出了新产品,市场反响热烈。"
|
||||
expected_output:
|
||||
original_length: 35
|
||||
analysis: "主题:天气和产品,情感:积极"
|
||||
status: "completed"
|
||||
|
||||
- name: 长文本处理
|
||||
description: 处理较长的文本
|
||||
input:
|
||||
input_text: "这是一段很长的文本..."
|
||||
expected_output:
|
||||
status: "completed"
|
||||
99
api/app/templates/workflows/multi_step_qa/template.yml
Normal file
99
api/app/templates/workflows/multi_step_qa/template.yml
Normal file
@@ -0,0 +1,99 @@
|
||||
# 多步骤问答工作流
|
||||
# 演示节点输出参数的使用
|
||||
|
||||
id: multi_step_qa_v1
|
||||
name: 多步骤问答工作流
|
||||
description: 先分析问题,再生成答案,展示节点间的数据传递
|
||||
category: advanced
|
||||
version: "1.0.0"
|
||||
author: RedBear Memory Team
|
||||
tags:
|
||||
- 问答
|
||||
- 多步骤
|
||||
- LLM
|
||||
|
||||
# 工作流配置
|
||||
nodes:
|
||||
- id: start
|
||||
type: start
|
||||
name: 开始
|
||||
position:
|
||||
x: 100
|
||||
y: 100
|
||||
|
||||
- id: analyze_question
|
||||
type: llm
|
||||
name: 分析问题
|
||||
description: 分析用户问题的类型和意图
|
||||
config:
|
||||
model_id: gpt-3.5-turbo
|
||||
temperature: 0.3
|
||||
max_tokens: 500
|
||||
messages:
|
||||
- role: system
|
||||
content: |
|
||||
你是一个问题分析专家。请分析用户的问题,提取以下信息:
|
||||
1. 问题类型(事实性、观点性、操作性等)
|
||||
2. 问题领域(科技、历史、文化等)
|
||||
3. 关键词
|
||||
- role: user
|
||||
content: "{{ sys.message }}"
|
||||
position:
|
||||
x: 300
|
||||
y: 100
|
||||
|
||||
- id: generate_answer
|
||||
type: llm
|
||||
name: 生成答案
|
||||
description: 根据问题分析结果生成详细答案
|
||||
config:
|
||||
model_id: gpt-3.5-turbo
|
||||
temperature: 0.7
|
||||
max_tokens: 1000
|
||||
messages:
|
||||
- role: system
|
||||
content: |
|
||||
你是一个专业的AI助手。根据问题分析结果,生成准确、详细的答案。
|
||||
|
||||
问题分析结果:
|
||||
{{ analyze_question.output }}
|
||||
- role: user
|
||||
content: "{{ sys.message }}"
|
||||
position:
|
||||
x: 500
|
||||
y: 100
|
||||
|
||||
- id: end
|
||||
type: end
|
||||
name: 结束
|
||||
config:
|
||||
output: "{{ generate_answer.output }}"
|
||||
position:
|
||||
x: 700
|
||||
y: 100
|
||||
|
||||
edges:
|
||||
- source: start
|
||||
target: analyze_question
|
||||
label: 开始分析
|
||||
|
||||
- source: analyze_question
|
||||
target: generate_answer
|
||||
label: 生成答案
|
||||
|
||||
- source: generate_answer
|
||||
target: end
|
||||
label: 完成
|
||||
|
||||
# 变量定义
|
||||
variables:
|
||||
- name: user_question
|
||||
type: string
|
||||
required: true
|
||||
description: 用户的问题
|
||||
default: ""
|
||||
|
||||
# 执行配置
|
||||
execution_config:
|
||||
max_execution_time: 120
|
||||
max_iterations: 1
|
||||
100
api/app/templates/workflows/simple_qa/template.yml
Normal file
100
api/app/templates/workflows/simple_qa/template.yml
Normal file
@@ -0,0 +1,100 @@
|
||||
# 简单问答工作流模板
|
||||
id: simple_qa_v1
|
||||
name: 简单问答工作流
|
||||
description: 最基础的问答工作流,适合快速开始
|
||||
category: basic
|
||||
version: "1.0.0"
|
||||
author: RedBear Memory Team
|
||||
tags:
|
||||
- 问答
|
||||
- 基础
|
||||
- LLM
|
||||
|
||||
# 工作流配置
|
||||
nodes:
|
||||
- id: start
|
||||
type: start
|
||||
name: 开始
|
||||
position:
|
||||
x: 100
|
||||
y: 100
|
||||
|
||||
- id: llm_qa
|
||||
type: llm
|
||||
name: LLM 问答
|
||||
config:
|
||||
# 使用 LangChain 标准的消息格式
|
||||
messages:
|
||||
- role: system
|
||||
content: |
|
||||
你是一个专业、友好且乐于助人的 AI 助手。
|
||||
|
||||
你的职责:
|
||||
- 准确理解用户的问题并提供有价值的回答
|
||||
- 保持回答的专业性和准确性
|
||||
- 如果不确定答案,诚实地告知用户
|
||||
- 使用清晰、易懂的语言进行交流
|
||||
|
||||
回答风格:
|
||||
- 简洁明了,直击要点
|
||||
- 必要时提供详细解释和示例
|
||||
- 使用友好、礼貌的语气
|
||||
- 适当使用格式化(如列表、段落)提高可读性
|
||||
|
||||
- role: user
|
||||
content: "{{ sys.message }}"
|
||||
|
||||
model_id: gpt-3.5-turbo
|
||||
temperature: 0.7
|
||||
max_tokens: 1000
|
||||
position:
|
||||
x: 300
|
||||
y: 100
|
||||
|
||||
- id: end
|
||||
type: end
|
||||
name: 结束
|
||||
config:
|
||||
output: "{{ llm_qa.output }}"
|
||||
position:
|
||||
x: 500
|
||||
y: 100
|
||||
|
||||
edges:
|
||||
- source: start
|
||||
target: llm_qa
|
||||
label: 开始处理
|
||||
|
||||
- source: llm_qa
|
||||
target: end
|
||||
label: 完成
|
||||
|
||||
# 变量定义
|
||||
variables:
|
||||
- name: user_question
|
||||
type: string
|
||||
required: true
|
||||
description: 用户的问题
|
||||
default: ""
|
||||
|
||||
# 执行配置
|
||||
execution_config:
|
||||
max_execution_time: 60
|
||||
max_iterations: 1
|
||||
|
||||
# 触发器(可选)
|
||||
triggers: []
|
||||
|
||||
# 使用示例
|
||||
examples:
|
||||
- name: 基础问答
|
||||
description: 询问一个简单的问题
|
||||
input:
|
||||
user_question: "什么是人工智能?"
|
||||
expected_output: "关于人工智能的解释"
|
||||
|
||||
- name: 技术咨询
|
||||
description: 询问技术问题
|
||||
input:
|
||||
user_question: "如何学习 Python 编程?"
|
||||
expected_output: "Python 学习建议"
|
||||
27
api/app/utils/sse_utils.py
Normal file
27
api/app/utils/sse_utils.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Server-Sent Events (SSE) Utility Functions
|
||||
|
||||
Provides shared utilities for formatting and handling SSE messages.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def format_sse_message(event_type: str, data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Format a message in Server-Sent Events (SSE) format.
|
||||
|
||||
Args:
|
||||
event_type: Type of event (stage name, result, error, done)
|
||||
data: Event data dictionary to be serialized as JSON
|
||||
|
||||
Returns:
|
||||
SSE formatted string: "event: <type>\\ndata: <json>\\n\\n"
|
||||
|
||||
Example:
|
||||
>>> format_sse_message("loading", {"message": "Loading..."})
|
||||
'event: loading\\ndata: {"message": "Loading..."}\\n\\n'
|
||||
"""
|
||||
json_data = json.dumps(data, ensure_ascii=False)
|
||||
return f"event: {event_type}\ndata: {json_data}\n\n"
|
||||
@@ -27,12 +27,21 @@ import 'dayjs/locale/en'
|
||||
import 'dayjs/locale/zh-cn'
|
||||
import 'dayjs/plugin/timezone'
|
||||
import 'dayjs/plugin/utc'
|
||||
import { cookieUtils } from './utils/request';
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
function App() {
|
||||
const { t } = useTranslation();
|
||||
const { locale, language, timeZone } = useI18n()
|
||||
useEffect(() => {
|
||||
const authToken = cookieUtils.get('authToken')
|
||||
if (!authToken && !window.location.hash.includes('#/login')) {
|
||||
window.location.href = `/#/login`;
|
||||
}
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
document.title = t('memoryBear')
|
||||
|
||||
@@ -200,7 +200,7 @@ export const deleteFile = async (id: string) => {
|
||||
|
||||
// 获取文档列表
|
||||
export const getDocumentList = async (query: PathQuery) => {
|
||||
const response = await request.get(`${apiPrefix}/documents/${query.kb_id}/${query.parent_id}/documents`, query);
|
||||
const response = await request.get(`${apiPrefix}/documents/${query.kb_id}/documents`, query);
|
||||
return response as KnowledgeBaseDocumentData[];
|
||||
};
|
||||
// 文档详情
|
||||
@@ -213,6 +213,11 @@ export const createDocument = async (data: KnowledgeBaseDocumentData) => {
|
||||
const response = await request.post(`${apiPrefix}/documents/document`, data);
|
||||
return response as KnowledgeBaseDocumentData;
|
||||
};
|
||||
// 自定义文档上传并创建
|
||||
export const createDocumentAndUpload = async ( data: any, params: PathQuery) => {
|
||||
const response = await request.post(`${apiPrefix}/files/customtext`, data, { params } );
|
||||
return response as any;
|
||||
};
|
||||
// 更新文档
|
||||
export const updateDocument = async (id: string, data: KnowledgeBaseDocumentData) => {
|
||||
const response = await request.put(`${apiPrefix}/documents/${id}`, data);
|
||||
@@ -223,9 +228,9 @@ export const deleteDocument = async (id: string) => {
|
||||
const response = await request.delete(`${apiPrefix}/documents/${id}`);
|
||||
return response;
|
||||
};
|
||||
// 文档解析
|
||||
export const parseDocument = async (id: string) => {
|
||||
const response = await request.post(`${apiPrefix}/documents/${id}/chunks`);
|
||||
// 文档解析 / 分块
|
||||
export const parseDocument = async (id: string, data: any) => {
|
||||
const response = await request.post(`${apiPrefix}/documents/${id}/chunks`, data);
|
||||
return response as any;
|
||||
};
|
||||
// 文档分块预览
|
||||
|
||||
@@ -3,6 +3,7 @@ import { Layout, Dropdown, Space, Breadcrumb } from 'antd';
|
||||
import type { MenuProps, BreadcrumbProps } from 'antd';
|
||||
import { UserOutlined, LogoutOutlined, SettingOutlined } from '@ant-design/icons';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLocation } from 'react-router-dom';
|
||||
import { useUser } from '@/store/user';
|
||||
import { useMenu } from '@/store/menu';
|
||||
import styles from './index.module.css'
|
||||
@@ -12,12 +13,35 @@ const { Header } = Layout;
|
||||
|
||||
const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => {
|
||||
const { t } = useTranslation();
|
||||
const location = useLocation();
|
||||
const settingModalRef = useRef<SettingModalRef>(null)
|
||||
const userInfoModalRef = useRef<UserInfoModalRef>(null)
|
||||
|
||||
const { user, logout } = useUser();
|
||||
const { allBreadcrumbs } = useMenu();
|
||||
const breadcrumbs = allBreadcrumbs[source] || [];
|
||||
|
||||
// 根据当前路由动态选择面包屑源
|
||||
const getBreadcrumbSource = () => {
|
||||
const pathname = location.pathname;
|
||||
|
||||
// 知识库列表页面使用默认的 space 面包屑
|
||||
if (pathname === '/knowledge-base') {
|
||||
return 'space';
|
||||
}
|
||||
|
||||
// 知识库详情相关页面使用独立的面包屑
|
||||
if (pathname.includes('/knowledge-base/') && pathname !== '/knowledge-base') {
|
||||
return 'space-detail';
|
||||
}
|
||||
|
||||
// 其他页面使用传入的 source
|
||||
return source;
|
||||
};
|
||||
|
||||
const breadcrumbSource = getBreadcrumbSource();
|
||||
const breadcrumbs = allBreadcrumbs[breadcrumbSource] || [];
|
||||
|
||||
|
||||
|
||||
// 处理退出登录
|
||||
const handleLogout = () => {
|
||||
|
||||
248
web/src/hooks/useBreadcrumbManager.ts
Normal file
248
web/src/hooks/useBreadcrumbManager.ts
Normal file
@@ -0,0 +1,248 @@
|
||||
import { useCallback } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { useMenu } from '@/store/menu';
|
||||
import type { MenuItem } from '@/store/menu';
|
||||
|
||||
export interface BreadcrumbItem {
|
||||
id: string;
|
||||
name: string;
|
||||
type?: 'knowledgeBase' | 'folder' | 'document';
|
||||
}
|
||||
|
||||
export interface BreadcrumbPath {
|
||||
knowledgeBaseFolderPath: BreadcrumbItem[]; // 知识库文件夹路径
|
||||
knowledgeBase?: BreadcrumbItem; // 知识库信息
|
||||
documentFolderPath: BreadcrumbItem[]; // 文档文件夹路径
|
||||
document?: BreadcrumbItem; // 文档信息
|
||||
}
|
||||
|
||||
export interface BreadcrumbOptions {
|
||||
onKnowledgeBaseMenuClick?: () => void;
|
||||
onKnowledgeBaseFolderClick?: (folderId: string, folderPath: BreadcrumbItem[]) => void;
|
||||
// 新增:区分面包屑类型
|
||||
breadcrumbType?: 'list' | 'detail';
|
||||
}
|
||||
|
||||
export const useBreadcrumbManager = (options?: BreadcrumbOptions) => {
|
||||
const { allBreadcrumbs, setCustomBreadcrumbs } = useMenu();
|
||||
const navigate = useNavigate();
|
||||
|
||||
const updateBreadcrumbs = useCallback((breadcrumbPath: BreadcrumbPath) => {
|
||||
const breadcrumbType = options?.breadcrumbType || 'list';
|
||||
|
||||
// 获取基础面包屑,对于详情页面,使用列表页面的基础面包屑作为起点
|
||||
const baseBreadcrumbs = breadcrumbType === 'list'
|
||||
? (allBreadcrumbs['space'] || [])
|
||||
: (allBreadcrumbs['space'] || []); // 详情页面也从 space 获取基础面包屑
|
||||
|
||||
// 只保留知识库菜单项之前的面包屑
|
||||
const knowledgeBaseMenuIndex = baseBreadcrumbs.findIndex(item => item.path === '/knowledge-base');
|
||||
const filteredBaseBreadcrumbs = knowledgeBaseMenuIndex >= 0
|
||||
? baseBreadcrumbs.slice(0, knowledgeBaseMenuIndex + 1)
|
||||
: baseBreadcrumbs;
|
||||
|
||||
// 给"知识库管理"添加点击事件
|
||||
const breadcrumbsWithClick = filteredBaseBreadcrumbs.map((item) => {
|
||||
if (item.path === '/knowledge-base') {
|
||||
return {
|
||||
...item,
|
||||
onClick: (e?: React.MouseEvent) => {
|
||||
e?.preventDefault();
|
||||
e?.stopPropagation();
|
||||
|
||||
if (options?.onKnowledgeBaseMenuClick) {
|
||||
// 如果提供了回调函数,执行回调
|
||||
options.onKnowledgeBaseMenuClick();
|
||||
} else if (breadcrumbType === 'detail') {
|
||||
// 知识库详情页面:没有回调函数时,返回到知识库列表页面
|
||||
navigate('/knowledge-base', {
|
||||
state: {
|
||||
resetToRoot: true,
|
||||
}
|
||||
});
|
||||
}
|
||||
return false;
|
||||
},
|
||||
};
|
||||
}
|
||||
return item;
|
||||
});
|
||||
|
||||
let customBreadcrumbs: MenuItem[] = [...breadcrumbsWithClick];
|
||||
|
||||
if (breadcrumbType === 'list') {
|
||||
// 知识库列表页面:只显示知识库文件夹路径
|
||||
customBreadcrumbs = [
|
||||
...breadcrumbsWithClick,
|
||||
...breadcrumbPath.knowledgeBaseFolderPath.map((folder, index) => ({
|
||||
id: 0,
|
||||
parent: 0,
|
||||
code: null,
|
||||
label: folder.name,
|
||||
i18nKey: null,
|
||||
path: null,
|
||||
enable: true,
|
||||
display: true,
|
||||
level: 0,
|
||||
sort: 0,
|
||||
icon: null,
|
||||
iconActive: null,
|
||||
menuDesc: null,
|
||||
deleted: null,
|
||||
updateTime: 0,
|
||||
new_: null,
|
||||
keepAlive: false,
|
||||
master: null,
|
||||
disposable: false,
|
||||
appSystem: null,
|
||||
subs: [],
|
||||
onClick: (e?: React.MouseEvent) => {
|
||||
e?.preventDefault();
|
||||
e?.stopPropagation();
|
||||
|
||||
// 如果有回调函数,直接调用回调函数来更新状态
|
||||
if (options?.onKnowledgeBaseFolderClick) {
|
||||
options.onKnowledgeBaseFolderClick(folder.id, breadcrumbPath.knowledgeBaseFolderPath.slice(0, index + 1));
|
||||
} else {
|
||||
// 否则使用导航(兜底逻辑)
|
||||
navigate('/knowledge-base', {
|
||||
state: {
|
||||
navigateToFolder: folder.id,
|
||||
folderPath: breadcrumbPath.knowledgeBaseFolderPath.slice(0, index + 1)
|
||||
}
|
||||
});
|
||||
}
|
||||
return false;
|
||||
},
|
||||
})),
|
||||
];
|
||||
} else {
|
||||
// 知识库详情页面:显示知识库名称 + 文档文件夹路径 + 文档名称
|
||||
customBreadcrumbs = [
|
||||
...breadcrumbsWithClick,
|
||||
|
||||
// 添加知识库名称
|
||||
...(breadcrumbPath.knowledgeBase ? [{
|
||||
id: 0,
|
||||
parent: 0,
|
||||
code: null,
|
||||
label: breadcrumbPath.knowledgeBase.name,
|
||||
i18nKey: null,
|
||||
path: null,
|
||||
enable: true,
|
||||
display: true,
|
||||
level: 0,
|
||||
sort: 0,
|
||||
icon: null,
|
||||
iconActive: null,
|
||||
menuDesc: null,
|
||||
deleted: null,
|
||||
updateTime: 0,
|
||||
new_: null,
|
||||
keepAlive: false,
|
||||
master: null,
|
||||
disposable: false,
|
||||
appSystem: null,
|
||||
subs: [],
|
||||
onClick: (e?: React.MouseEvent) => {
|
||||
e?.preventDefault();
|
||||
e?.stopPropagation();
|
||||
// 返回到知识库详情页的根目录
|
||||
const navigationState = {
|
||||
fromKnowledgeBaseList: true,
|
||||
knowledgeBaseFolderPath: breadcrumbPath.knowledgeBaseFolderPath,
|
||||
resetToRoot: true, // 添加重置到根目录的标志
|
||||
refresh: true, // 添加刷新标志
|
||||
timestamp: Date.now(), // 添加时间戳确保状态变化
|
||||
};
|
||||
navigate(`/knowledge-base/${breadcrumbPath.knowledgeBase!.id}/private`, {
|
||||
state: navigationState,
|
||||
replace: true // 使用 replace 避免历史记录堆积
|
||||
});
|
||||
return false;
|
||||
},
|
||||
}] : []),
|
||||
|
||||
// 添加文档文件夹路径
|
||||
...breadcrumbPath.documentFolderPath.map((folder, index) => ({
|
||||
id: 0,
|
||||
parent: 0,
|
||||
code: null,
|
||||
label: folder.name,
|
||||
i18nKey: null,
|
||||
path: null,
|
||||
enable: true,
|
||||
display: true,
|
||||
level: 0,
|
||||
sort: 0,
|
||||
icon: null,
|
||||
iconActive: null,
|
||||
menuDesc: null,
|
||||
deleted: null,
|
||||
updateTime: 0,
|
||||
new_: null,
|
||||
keepAlive: false,
|
||||
master: null,
|
||||
disposable: false,
|
||||
appSystem: null,
|
||||
subs: [],
|
||||
onClick: (e?: React.MouseEvent) => {
|
||||
e?.preventDefault();
|
||||
e?.stopPropagation();
|
||||
// 返回到知识库详情页的对应文件夹
|
||||
const navigationState = {
|
||||
fromKnowledgeBaseList: true,
|
||||
knowledgeBaseFolderPath: breadcrumbPath.knowledgeBaseFolderPath,
|
||||
navigateToDocumentFolder: folder.id,
|
||||
documentFolderPath: breadcrumbPath.documentFolderPath.slice(0, index + 1),
|
||||
refresh: true, // 添加刷新标志
|
||||
timestamp: Date.now(), // 添加时间戳确保状态变化
|
||||
};
|
||||
navigate(`/knowledge-base/${breadcrumbPath.knowledgeBase!.id}/private`, {
|
||||
state: navigationState,
|
||||
replace: true // 使用 replace 避免历史记录堆积
|
||||
});
|
||||
return false;
|
||||
},
|
||||
})),
|
||||
|
||||
// 添加文档名称(如果存在)
|
||||
...(breadcrumbPath.document ? [{
|
||||
id: 0,
|
||||
parent: 0,
|
||||
code: null,
|
||||
label: breadcrumbPath.document.name,
|
||||
i18nKey: null,
|
||||
path: null,
|
||||
enable: true,
|
||||
display: true,
|
||||
level: 0,
|
||||
sort: 0,
|
||||
icon: null,
|
||||
iconActive: null,
|
||||
menuDesc: null,
|
||||
deleted: null,
|
||||
updateTime: 0,
|
||||
new_: null,
|
||||
keepAlive: false,
|
||||
master: null,
|
||||
disposable: false,
|
||||
appSystem: null,
|
||||
subs: [],
|
||||
// 文档名称不可点击
|
||||
}] : []),
|
||||
];
|
||||
}
|
||||
|
||||
// 根据面包屑类型使用不同的键,实现独立的面包屑路径
|
||||
const breadcrumbKey = breadcrumbType === 'list' ? 'space' : 'space-detail';
|
||||
|
||||
|
||||
|
||||
setCustomBreadcrumbs(customBreadcrumbs, breadcrumbKey);
|
||||
}, [setCustomBreadcrumbs, navigate, options?.breadcrumbType, options?.onKnowledgeBaseMenuClick, options?.onKnowledgeBaseFolderClick]);
|
||||
|
||||
return {
|
||||
updateBreadcrumbs,
|
||||
};
|
||||
};
|
||||
@@ -11,8 +11,10 @@ export const checkAuthStatus = (): boolean => {
|
||||
|
||||
// 递归检查路由是否存在于菜单数据中
|
||||
export const checkRoutePermission = (menus: MenuItem[], currentPath: string): boolean => {
|
||||
// 首页默认有权限
|
||||
if (currentPath === '/' || currentPath.includes('knowledge-detail')) return true;
|
||||
// 首页和知识库相关页面默认有权限
|
||||
if (currentPath === '/' || currentPath.includes('knowledge-detail') || currentPath.includes('knowledge-base')) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (const menu of menus) {
|
||||
// 检查当前菜单的path是否匹配
|
||||
@@ -26,6 +28,7 @@ export const checkRoutePermission = (menus: MenuItem[], currentPath: string): bo
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
@@ -52,7 +55,7 @@ export const useRouteGuard = (source: 'space' | 'manage') => {
|
||||
const hasPermission = checkRoutePermission(menus, location.pathname);
|
||||
if (!hasPermission) {
|
||||
// 无权限访问该路由,重定向到无权限页面
|
||||
// navigate('/not-found', { replace: true });
|
||||
// navigate('/no-permission', { replace: true });
|
||||
}
|
||||
}
|
||||
}, [navigate, location.pathname, location.search, location.hash, menus]);
|
||||
|
||||
@@ -479,11 +479,18 @@ export const en = {
|
||||
noDataSets: 'No datasets yet, click the button below or drag files to create.',
|
||||
createEmptyDataSet: '+ Empty Dataset',
|
||||
createImageDataSet: '+ Image Dataset',
|
||||
createContent: 'Create Content',
|
||||
title: 'Title',
|
||||
content: 'Content',
|
||||
pleaseEnterTitle: 'Please enter title',
|
||||
pleaseEnterContent: 'Please enter content',
|
||||
// createImageDataSet: '+ Image Dataset',
|
||||
dragFilesHere: 'Drag files here to upload',
|
||||
createImport: 'Create/Import',
|
||||
textDataSet: 'Text Dataset',
|
||||
imageDataSet: 'Image Dataset',
|
||||
blankDataset: 'Blank Dataset',
|
||||
customTextDataset: 'Custom Text Dataset',
|
||||
text: 'Text',
|
||||
search: 'Search',
|
||||
image: 'Image',
|
||||
|
||||
@@ -110,6 +110,11 @@ export const zh = {
|
||||
noDataSets: '暂无数据集,点击下方按钮或拖拽文件创建。',
|
||||
createEmptyDataSet: '+ 空白数据集',
|
||||
createImageDataSet: '+ 图片数据集',
|
||||
createContent: '创建内容',
|
||||
title: '标题',
|
||||
content: '内容',
|
||||
pleaseEnterTitle: '请输入标题',
|
||||
pleaseEnterContent: '请输入内容',
|
||||
dragFilesHere: '拖拽文件到此处上传',
|
||||
downloadOriginal: '下载原始内容',
|
||||
createImport: '新建/导入',
|
||||
@@ -117,6 +122,7 @@ export const zh = {
|
||||
imageDataSet: '图片数据集',
|
||||
blankDataset: '空白数据集',
|
||||
emptyDataSet: '空白数据集',
|
||||
customTextDataset: '自定义文本数据集',
|
||||
text: '文本',
|
||||
search: '搜索',
|
||||
image: '图片',
|
||||
|
||||
@@ -32,7 +32,7 @@ interface MenuState {
|
||||
allBreadcrumbs: Record<'space' | 'manage' | string, MenuItem[]>;
|
||||
loadMenus: (source: 'space' | 'manage') => void;
|
||||
updateBreadcrumbs: (keyPath: string[], source: 'space' | 'manage') => void;
|
||||
setCustomBreadcrumbs: (breadcrumbs: MenuItem[], source: 'space' | 'manage') => void;
|
||||
setCustomBreadcrumbs: (breadcrumbs: MenuItem[], source: string) => void;
|
||||
}
|
||||
|
||||
const initBreadcrumbs = localStorage.getItem('breadcrumbs') || '[]'
|
||||
|
||||
@@ -8,15 +8,14 @@ import type { UploadFileResponse,KnowledgeBaseDocumentData } from '@/views/Knowl
|
||||
import type { ColumnsType } from 'antd/es/table';
|
||||
import UploadFiles from '@/components/Upload/UploadFiles';
|
||||
import type { UploadRequestOption } from 'rc-upload/lib/interface';
|
||||
import { uploadFile, getDocumentList, previewDocumentChunk, parseDocument, updateDocument, deleteDocument } from '@/api/knowledgeBase';
|
||||
import { uploadFile, getDocumentList, parseDocument, updateDocument, deleteDocument } from '@/api/knowledgeBase';
|
||||
import exitIcon from '@/assets/images/knowledgeBase/exit.png';
|
||||
import { NoData } from '../components/noData';
|
||||
import noDataIcon from '@/assets/images/knowledgeBase/noData.png';
|
||||
|
||||
import SliderInput from '@/components/SliderInput';
|
||||
import DelimiterSelector from '../components/DelimiterSelector';
|
||||
const { confirm } = Modal
|
||||
const { TextArea } = Input;
|
||||
import styles from '../index.module.css';
|
||||
|
||||
const style: React.CSSProperties = {
|
||||
display: 'flex',
|
||||
gap: 16,
|
||||
@@ -71,12 +70,11 @@ const CreateDataset = () => {
|
||||
const initialFileIds = locationState.fileIds ?? (locationState.fileId ? [locationState.fileId] : []);
|
||||
const [current, setCurrent] = useState<number>(stepIndexMap[initialStepKey]);
|
||||
const tableRef = useRef<TableRef>(null);
|
||||
|
||||
|
||||
const [data, setData] = useState<KnowledgeBaseDocumentData[]>([]);
|
||||
const [chunkData, setChunkData] = useState<any[]>([]);
|
||||
const [total, setTotal] = useState<number>(0);
|
||||
const [rechunkFileIds, setRechunkFileIds] = useState<string[]>(initialFileIds);
|
||||
const [curSelectedFileId, setCurSelectedFileId] = useState<number>(-1);
|
||||
const [previewLoading, setPreviewLoading] = useState<boolean>(false);
|
||||
|
||||
const [pollingLoading, setPollingLoading] = useState<boolean>(false);
|
||||
const pollingTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
const [delimiter, setDelimiter] = useState<string | undefined>(undefined);
|
||||
@@ -121,6 +119,7 @@ const CreateDataset = () => {
|
||||
layout_recognize:'DeepDOC',
|
||||
delimiter: delimiter,
|
||||
chunk_token_num: blockSize,
|
||||
auto_question: processingMethod === 'directBlock' ? 0 : 1,
|
||||
}
|
||||
}
|
||||
updateDocument(id, params)
|
||||
@@ -145,7 +144,7 @@ const CreateDataset = () => {
|
||||
});
|
||||
return;
|
||||
}
|
||||
debugger
|
||||
|
||||
|
||||
// 显示确认弹框
|
||||
confirm({
|
||||
@@ -168,7 +167,7 @@ const CreateDataset = () => {
|
||||
const startProcessing = (autoReturnToList: boolean) => {
|
||||
// 触发文档解析
|
||||
rechunkFileIds.map((id) => {
|
||||
parseDocument(id);
|
||||
parseDocument(id, {});
|
||||
});
|
||||
|
||||
// 开启 loading
|
||||
@@ -276,21 +275,7 @@ const CreateDataset = () => {
|
||||
onError?.(error as Error);
|
||||
});
|
||||
};
|
||||
// 点击文件 预览分块
|
||||
const handlePreview = async(item: KnowledgeBaseDocumentData, index: number) => {
|
||||
setCurSelectedFileId(index);
|
||||
setPreviewLoading(true);
|
||||
try{
|
||||
const res = await previewDocumentChunk(knowledgeBaseId ?? '', item.id ?? '');
|
||||
setChunkData(res.items || []);
|
||||
setTotal(res.page.total || 0);
|
||||
console.log('res', res);
|
||||
}catch(error) {
|
||||
console.log('error', error);
|
||||
} finally {
|
||||
setPreviewLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 轮询检查文档处理状态
|
||||
// autoReturn: 是否在所有文档完成时自动返回列表页
|
||||
@@ -346,6 +331,8 @@ const CreateDataset = () => {
|
||||
state: {
|
||||
refresh: true,
|
||||
timestamp: Date.now(), // 添加时间戳确保每次都是新的 state
|
||||
// 保持返回到原来的文档文件夹位置
|
||||
navigateToDocumentFolder: parentId !== knowledgeBaseId ? parentId : undefined,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
@@ -565,8 +552,8 @@ const CreateDataset = () => {
|
||||
{rechunkFileIds.length > 0 ? (
|
||||
<Table
|
||||
ref={tableRef}
|
||||
apiUrl={`/documents/${knowledgeBaseId}/${parentId}/documents`}
|
||||
apiParams={{
|
||||
apiUrl={`/documents/${knowledgeBaseId}/documents`}
|
||||
apiParams={{
|
||||
document_ids: rechunkFileIds.join(','),
|
||||
}}
|
||||
columns={columns}
|
||||
|
||||
@@ -4,11 +4,12 @@
|
||||
* @Author: yujiangping
|
||||
* @Date: 2025-11-15 16:13:47
|
||||
* @LastEditors: yujiangping
|
||||
* @LastEditTime: 2025-11-29 19:46:46
|
||||
* @LastEditTime: 2025-12-12 20:02:05
|
||||
*/
|
||||
import { useEffect, useState, useRef, type FC } from 'react';
|
||||
import { useNavigate, useParams, useLocation } from 'react-router-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useBreadcrumbManager, type BreadcrumbPath } from '@/hooks/useBreadcrumbManager';
|
||||
import { Button, Spin, message, Switch } from 'antd';
|
||||
import { getDocumentDetail, getDocumentChunkList, downloadFile, updateDocument, updateDocumentChunk, createDocumentChunk } from '@/api/knowledgeBase';
|
||||
import type { KnowledgeBaseDocumentData, RecallTestData } from '@/views/KnowledgeBase/types';
|
||||
@@ -25,7 +26,18 @@ const DocumentDetails: FC = () => {
|
||||
const navigate = useNavigate();
|
||||
const { knowledgeBaseId } = useParams<{ knowledgeBaseId: string }>();
|
||||
const location = useLocation();
|
||||
const { documentId, parentId: locationParentId } = location.state as { documentId: string; parentId?: string };
|
||||
const { updateBreadcrumbs } = useBreadcrumbManager({
|
||||
breadcrumbType: 'detail'
|
||||
});
|
||||
const {
|
||||
documentId,
|
||||
parentId: locationParentId,
|
||||
breadcrumbPath
|
||||
} = location.state as {
|
||||
documentId: string;
|
||||
parentId?: string;
|
||||
breadcrumbPath?: BreadcrumbPath;
|
||||
};
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [document, setDocument] = useState<KnowledgeBaseDocumentData | null>(null);
|
||||
const [chunkList, setChunkList] = useState<RecallTestData[]>([]);
|
||||
@@ -44,6 +56,13 @@ const DocumentDetails: FC = () => {
|
||||
}
|
||||
}, [documentId]);
|
||||
|
||||
// 更新面包屑
|
||||
useEffect(() => {
|
||||
if (breadcrumbPath) {
|
||||
updateBreadcrumbs(breadcrumbPath);
|
||||
}
|
||||
}, [breadcrumbPath, updateBreadcrumbs]);
|
||||
|
||||
// 当文档加载完成且 progress === 1 时,加载分块列表
|
||||
useEffect(() => {
|
||||
if (document && document.progress === 1 && !isManualRefreshRef.current) {
|
||||
@@ -179,7 +198,18 @@ const DocumentDetails: FC = () => {
|
||||
};
|
||||
|
||||
const handleBack = () => {
|
||||
if (knowledgeBaseId) {
|
||||
if (knowledgeBaseId && breadcrumbPath) {
|
||||
// 返回到知识库详情页,并传递面包屑信息以恢复状态
|
||||
const navigationState = {
|
||||
fromKnowledgeBaseList: true,
|
||||
knowledgeBaseFolderPath: breadcrumbPath.knowledgeBaseFolderPath,
|
||||
navigateToDocumentFolder: locationParentId,
|
||||
documentFolderPath: breadcrumbPath.documentFolderPath,
|
||||
timestamp: Date.now(), // 添加时间戳确保状态变化
|
||||
};
|
||||
navigate(`/knowledge-base/${knowledgeBaseId}/private`, { state: navigationState });
|
||||
} else if (knowledgeBaseId) {
|
||||
// 降级处理:直接跳转到知识库详情页
|
||||
navigate(`/knowledge-base/${knowledgeBaseId}/private`);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
import { useEffect, useState, useRef, type FC } from 'react';
|
||||
import { useEffect, useState, useRef, useCallback, type FC } from 'react';
|
||||
import { useNavigate, useParams, useLocation } from 'react-router-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Switch, Button, Dropdown, Space, Modal, message } from 'antd';
|
||||
@@ -12,26 +12,29 @@ import { MoreOutlined } from '@ant-design/icons';
|
||||
import folderIcon from '@/assets/images/knowledgeBase/folder.png';
|
||||
import textIcon from '@/assets/images/knowledgeBase/text.png';
|
||||
import editIcon from '@/assets/images/knowledgeBase/edit.png';
|
||||
import blankIcon from '@/assets/images/knowledgeBase/blankDocument.png';
|
||||
import { getKnowledgeBaseDetail, deleteDocument, downloadFile, updateKnowledgeBase } from '@/api/knowledgeBase';
|
||||
import type {
|
||||
CreateModalRef,
|
||||
KnowledgeBaseListItem,
|
||||
RecallTestDrawerRef,
|
||||
CreateFolderModalRef,
|
||||
CreateImageModalRef,
|
||||
ShareModalRef,
|
||||
CreateDatasetModalRef,FolderFormData,
|
||||
KnowledgeBaseDocumentData
|
||||
import {
|
||||
type CreateModalRef,
|
||||
type KnowledgeBaseListItem,
|
||||
type RecallTestDrawerRef,
|
||||
type CreateFolderModalRef,
|
||||
type CreateSetModalRef,
|
||||
type ShareModalRef,
|
||||
type CreateDatasetModalRef,type FolderFormData,
|
||||
type KnowledgeBaseDocumentData,
|
||||
} from '@/views/KnowledgeBase/types';
|
||||
import RecallTestDrawer from '../components/RecallTestDrawer';
|
||||
import CreateFolderModal from '../components/CreateFolderModal';
|
||||
import CreateContentModal from '../components/CreateContentModal';
|
||||
import CreateModal from '../components/CreateModal';
|
||||
import ShareModal from '../components/ShareModal';
|
||||
import CreateDatasetModal from '../components/CreateDatasetModal';
|
||||
import CreateImageDataset from '../components/CreateImageDataset';
|
||||
import FolderTree, { type TreeNodeData } from '../components/FolderTree';
|
||||
import { formatDateTime } from '@/utils/format';
|
||||
import { useMenu } from '@/store/menu';
|
||||
|
||||
import { useBreadcrumbManager, type BreadcrumbItem } from '@/hooks/useBreadcrumbManager';
|
||||
import './Private.css'
|
||||
const { confirm } = Modal
|
||||
// 树节点数据类型
|
||||
@@ -48,7 +51,8 @@ const Private: FC = () => {
|
||||
const [tableApi, setTableApi] = useState<string | undefined>(undefined);
|
||||
const recallTestDrawerRef = useRef<RecallTestDrawerRef>(null);
|
||||
const createFolderModalRef = useRef<CreateFolderModalRef>(null);
|
||||
const createImageDataset = useRef<CreateImageModalRef>(null)
|
||||
const createImageDataset = useRef<CreateSetModalRef>(null)
|
||||
const createContentModalRef = useRef<CreateSetModalRef>(null);
|
||||
const [knowledgeBase, setKnowledgeBase] = useState<KnowledgeBaseListItem | null>(null);
|
||||
const [folder, setFolder] = useState<FolderFormData | null>({
|
||||
kb_id:knowledgeBaseId ?? '',
|
||||
@@ -56,47 +60,47 @@ const Private: FC = () => {
|
||||
});
|
||||
const [query, setQuery] = useState<Record<string, unknown>>({
|
||||
orderby: 'created_at',
|
||||
desc: true,
|
||||
desc: true
|
||||
});
|
||||
const modalRef = useRef<CreateModalRef>(null)
|
||||
const shareModalRef = useRef<ShareModalRef>(null);
|
||||
const datasetModalRef = useRef<CreateDatasetModalRef>(null);
|
||||
const [folderTreeRefreshKey, setFolderTreeRefreshKey] = useState(0);
|
||||
const { allBreadcrumbs, setCustomBreadcrumbs } = useMenu();
|
||||
const [folderPath, setFolderPath] = useState<Array<{ id: string; name: string }>>([]);
|
||||
const [autoExpandPath, setAutoExpandPath] = useState<Array<{ id: string; name: string }>>([]);
|
||||
|
||||
const { updateBreadcrumbs } = useBreadcrumbManager({
|
||||
breadcrumbType: 'detail',
|
||||
// 不提供 onKnowledgeBaseMenuClick,让它使用默认的导航行为(返回列表页面)
|
||||
onKnowledgeBaseFolderClick: useCallback((folderId: string, folderPath: Array<{ id: string; name: string }>) => {
|
||||
// 点击文件夹面包屑时,导航到对应文件夹
|
||||
setParentId(folderId);
|
||||
setFolderPath(folderPath);
|
||||
setSelectedKeys([folderId]);
|
||||
setFolder({
|
||||
kb_id: knowledgeBaseId ?? '',
|
||||
parent_id: folderId
|
||||
});
|
||||
|
||||
// 确保query对象发生变化,触发表格刷新
|
||||
setQuery({
|
||||
orderby: 'created_at',
|
||||
desc: true,
|
||||
parent_id: folderId,
|
||||
_timestamp: Date.now()
|
||||
});
|
||||
|
||||
// 确保API URL正确设置
|
||||
setTableApi(`/documents/${knowledgeBaseId}/documents`);
|
||||
|
||||
// 手动触发表格刷新,确保数据更新
|
||||
setTimeout(() => {
|
||||
tableRef.current?.loadData();
|
||||
}, 100);
|
||||
}, [knowledgeBaseId])
|
||||
});
|
||||
const [folderPath, setFolderPath] = useState<BreadcrumbItem[]>([]);
|
||||
const [selectedKeys, setSelectedKeys] = useState<React.Key[]>([]);
|
||||
useEffect(() => {
|
||||
if (knowledgeBaseId) {
|
||||
let url = `/documents/${knowledgeBaseId}/${parentId}/documents`;
|
||||
setTableApi(url);
|
||||
fetchKnowledgeBaseDetail(knowledgeBaseId);
|
||||
}
|
||||
}, [knowledgeBaseId]);
|
||||
|
||||
// 更新面包屑
|
||||
useEffect(() => {
|
||||
if (knowledgeBase) {
|
||||
updateBreadcrumbs();
|
||||
}
|
||||
}, [knowledgeBase, folderPath]);
|
||||
|
||||
// 监听 tableApi 变化,自动刷新表格数据
|
||||
useEffect(() => {
|
||||
if (tableApi) {
|
||||
tableRef.current?.loadData();
|
||||
}
|
||||
}, [tableApi]);
|
||||
|
||||
// 监听 location state 变化,如果有 refresh 标志则刷新列表
|
||||
useEffect(() => {
|
||||
const state = location.state as { refresh?: boolean; timestamp?: number } | null;
|
||||
if (state?.refresh) {
|
||||
tableRef.current?.loadData();
|
||||
// 清除 state,避免重复刷新
|
||||
navigate(location.pathname, { replace: true, state: {} });
|
||||
}
|
||||
}, [location.state]);
|
||||
|
||||
const [knowledgeBaseFolderPath, setKnowledgeBaseFolderPath] = useState<BreadcrumbItem[]>([]);
|
||||
const fetchKnowledgeBaseDetail = async (id: string) => {
|
||||
setLoading(true);
|
||||
try {
|
||||
@@ -109,110 +113,160 @@ const Private: FC = () => {
|
||||
}
|
||||
};
|
||||
|
||||
// 更新面包屑,包含知识库名称和文件夹路径
|
||||
const updateBreadcrumbs = () => {
|
||||
if (!knowledgeBase) return;
|
||||
|
||||
const baseBreadcrumbs = allBreadcrumbs['space'] || [];
|
||||
// 只保留知识库菜单项之前的面包屑
|
||||
const knowledgeBaseMenuIndex = baseBreadcrumbs.findIndex(item => item.path === '/knowledge-base');
|
||||
const filteredBaseBreadcrumbs = knowledgeBaseMenuIndex >= 0
|
||||
? baseBreadcrumbs.slice(0, knowledgeBaseMenuIndex + 1)
|
||||
: baseBreadcrumbs;
|
||||
|
||||
const customBreadcrumbs = [
|
||||
...filteredBaseBreadcrumbs,
|
||||
{
|
||||
id: 0,
|
||||
parent: 0,
|
||||
code: null,
|
||||
label: knowledgeBase.name,
|
||||
i18nKey: null,
|
||||
path: null,
|
||||
enable: true,
|
||||
display: true,
|
||||
level: 0,
|
||||
sort: 0,
|
||||
icon: null,
|
||||
iconActive: null,
|
||||
menuDesc: null,
|
||||
deleted: null,
|
||||
updateTime: 0,
|
||||
new_: null,
|
||||
keepAlive: false,
|
||||
master: null,
|
||||
disposable: false,
|
||||
appSystem: null,
|
||||
subs: [],
|
||||
onClick: (e?: React.MouseEvent) => {
|
||||
// 阻止默认行为和事件冒泡
|
||||
e?.preventDefault();
|
||||
e?.stopPropagation();
|
||||
// 点击知识库名称,回到根目录
|
||||
setParentId(knowledgeBaseId);
|
||||
setFolder({
|
||||
kb_id: knowledgeBaseId ?? '',
|
||||
parent_id: knowledgeBaseId ?? ''
|
||||
});
|
||||
setTableApi(`/documents/${knowledgeBaseId}/${knowledgeBaseId}/documents`);
|
||||
setFolderPath([]);
|
||||
setSelectedKeys([knowledgeBaseId ?? '']);
|
||||
return false;
|
||||
},
|
||||
},
|
||||
...folderPath.map((folder, index) => ({
|
||||
id: 0,
|
||||
parent: 0,
|
||||
code: null,
|
||||
label: folder.name,
|
||||
i18nKey: null,
|
||||
path: null,
|
||||
enable: true,
|
||||
display: true,
|
||||
level: 0,
|
||||
sort: 0,
|
||||
icon: null,
|
||||
iconActive: null,
|
||||
menuDesc: null,
|
||||
deleted: null,
|
||||
updateTime: 0,
|
||||
new_: null,
|
||||
keepAlive: false,
|
||||
master: null,
|
||||
disposable: false,
|
||||
appSystem: null,
|
||||
subs: [],
|
||||
onClick: (e?: React.MouseEvent) => {
|
||||
// 阻止默认行为和事件冒泡
|
||||
e?.preventDefault();
|
||||
e?.stopPropagation();
|
||||
// 点击文件夹,回到该文件夹层级
|
||||
setParentId(folder.id);
|
||||
setFolder({
|
||||
kb_id: knowledgeBaseId ?? '',
|
||||
parent_id: folder.id
|
||||
});
|
||||
setTableApi(`/documents/${knowledgeBaseId}/${folder.id}/documents`);
|
||||
// 更新文件夹路径,只保留到当前点击的文件夹
|
||||
setFolderPath(folderPath.slice(0, index + 1));
|
||||
setSelectedKeys([folder.id]);
|
||||
return false;
|
||||
},
|
||||
})),
|
||||
];
|
||||
useEffect(() => {
|
||||
if (knowledgeBaseId) {
|
||||
let url = `/documents/${knowledgeBaseId}/documents`;
|
||||
setTableApi(url);
|
||||
fetchKnowledgeBaseDetail(knowledgeBaseId);
|
||||
}
|
||||
}, [knowledgeBaseId]);
|
||||
|
||||
setCustomBreadcrumbs(customBreadcrumbs, 'space');
|
||||
};
|
||||
// 更新面包屑
|
||||
useEffect(() => {
|
||||
if (knowledgeBase) {
|
||||
updateBreadcrumbs({
|
||||
knowledgeBaseFolderPath,
|
||||
knowledgeBase: {
|
||||
id: knowledgeBase.id,
|
||||
name: knowledgeBase.name,
|
||||
type: 'knowledgeBase'
|
||||
},
|
||||
documentFolderPath: folderPath,
|
||||
});
|
||||
}
|
||||
}, [knowledgeBase, knowledgeBaseFolderPath, folderPath, updateBreadcrumbs]);
|
||||
|
||||
// 监听 tableApi 变化,自动刷新表格数据
|
||||
useEffect(() => {
|
||||
if (tableApi) {
|
||||
tableRef.current?.loadData();
|
||||
}
|
||||
}, [tableApi]);
|
||||
|
||||
// 监听 query 变化,确保表格数据更新
|
||||
useEffect(() => {
|
||||
if (tableApi && query._timestamp) {
|
||||
// 当 query 中有 _timestamp 时,说明是通过面包屑或其他方式触发的更新
|
||||
tableRef.current?.loadData();
|
||||
}
|
||||
}, [query._timestamp, tableApi]);
|
||||
|
||||
// 监听 location state 变化
|
||||
useEffect(() => {
|
||||
const state = location.state as {
|
||||
refresh?: boolean;
|
||||
timestamp?: number;
|
||||
fromKnowledgeBaseList?: boolean;
|
||||
knowledgeBaseFolderPath?: BreadcrumbItem[];
|
||||
parentId?: string;
|
||||
navigateToDocumentFolder?: string;
|
||||
documentFolderPath?: BreadcrumbItem[];
|
||||
resetToRoot?: boolean;
|
||||
} | null;
|
||||
|
||||
if (state?.refresh) {
|
||||
tableRef.current?.loadData();
|
||||
// 清除 state,避免重复刷新
|
||||
navigate(location.pathname, { replace: true, state: {} });
|
||||
}
|
||||
|
||||
// 如果是从知识库列表页跳转过来的,设置知识库文件夹路径
|
||||
if (state?.fromKnowledgeBaseList && state?.knowledgeBaseFolderPath) {
|
||||
setKnowledgeBaseFolderPath(state.knowledgeBaseFolderPath);
|
||||
}
|
||||
|
||||
// 如果需要重置到根目录(回到初始状态)
|
||||
if (state?.resetToRoot) {
|
||||
// 重置所有状态到初始状态,和页面初始化保持一致
|
||||
setParentId(knowledgeBaseId);
|
||||
setFolderPath([]);
|
||||
setSelectedKeys([]);
|
||||
setFolder({
|
||||
kb_id: knowledgeBaseId ?? '',
|
||||
parent_id: knowledgeBaseId ?? ''
|
||||
});
|
||||
setQuery({
|
||||
orderby: 'created_at',
|
||||
desc: true,
|
||||
_timestamp: Date.now() // 添加时间戳确保query对象发生变化,触发API调用
|
||||
});
|
||||
|
||||
// 重新设置API URL
|
||||
const rootUrl = `/documents/${knowledgeBaseId}/documents`;
|
||||
setTableApi(rootUrl);
|
||||
|
||||
// 清除自动展开路径
|
||||
setAutoExpandPath([]);
|
||||
|
||||
// 刷新文件夹树(简单的刷新,不需要复杂的重置逻辑)
|
||||
setFolderTreeRefreshKey((prev) => prev + 1);
|
||||
|
||||
// 清除 state,避免重复处理
|
||||
navigate(location.pathname, { replace: true, state: {} });
|
||||
}
|
||||
|
||||
// 如果是从文档详情页返回,恢复文档文件夹路径
|
||||
if (state?.navigateToDocumentFolder && state?.documentFolderPath) {
|
||||
setFolderPath(state.documentFolderPath);
|
||||
setParentId(state.navigateToDocumentFolder);
|
||||
setFolder({
|
||||
kb_id: knowledgeBaseId ?? '',
|
||||
parent_id: state.navigateToDocumentFolder
|
||||
});
|
||||
setQuery(prevQuery => ({
|
||||
...prevQuery,
|
||||
parent_id: state.navigateToDocumentFolder,
|
||||
_timestamp: Date.now()
|
||||
}));
|
||||
setTableApi(`/documents/${knowledgeBaseId}/documents`);
|
||||
setSelectedKeys([state.navigateToDocumentFolder]);
|
||||
|
||||
// 设置自动展开路径,让FolderTree自动展开到对应位置
|
||||
setAutoExpandPath(state.documentFolderPath);
|
||||
|
||||
// 手动触发表格刷新
|
||||
setTimeout(() => {
|
||||
tableRef.current?.loadData();
|
||||
}, 100);
|
||||
|
||||
// 清除自动展开路径,避免重复触发(延迟清除,确保FolderTree处理完成)
|
||||
setTimeout(() => {
|
||||
setAutoExpandPath([]);
|
||||
}, 2000);
|
||||
}
|
||||
}, [location.state, knowledgeBaseId, navigate, location.pathname]);
|
||||
|
||||
// 处理树节点选择
|
||||
const onSelect = (keys: React.Key[]) => {
|
||||
if (!keys.length) return;
|
||||
if (!keys.length) {
|
||||
// 如果没有选中任何节点,回到根目录(初始状态)
|
||||
setParentId(knowledgeBaseId);
|
||||
setFolder({
|
||||
kb_id: knowledgeBaseId ?? '',
|
||||
parent_id: knowledgeBaseId ?? ''
|
||||
});
|
||||
setQuery({
|
||||
orderby: 'created_at',
|
||||
desc: true,
|
||||
_timestamp: Date.now() // 添加时间戳确保query对象发生变化
|
||||
});
|
||||
setSelectedKeys([]);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!folder) return;
|
||||
|
||||
const f = {
|
||||
...folder,
|
||||
parent_id: String(keys[0]),
|
||||
}
|
||||
let url = `/documents/${knowledgeBaseId}/${String(keys[0])}/documents`;
|
||||
setQuery({
|
||||
...query,
|
||||
parent_id: String(keys[0]),
|
||||
_timestamp: Date.now() // 添加时间戳确保query对象发生变化
|
||||
})
|
||||
let url = `/documents/${knowledgeBaseId}/documents`;
|
||||
|
||||
setTableApi(url);
|
||||
setParentId(String(keys[0]))
|
||||
setFolder(f)
|
||||
@@ -253,6 +307,15 @@ const Private: FC = () => {
|
||||
datasetModalRef?.current?.handleOpen(knowledgeBase?.id,folder?.parent_id ?? knowledgeBase?.id ?? '');
|
||||
},
|
||||
},
|
||||
{
|
||||
key: '8',
|
||||
icon: <img src={blankIcon} alt="Custome Text" style={{ width: 16, height: 16 }} />,
|
||||
label: t('knowledgeBase.customTextDataset'),
|
||||
onClick: () => {
|
||||
createContentModalRef?.current?.handleOpen(knowledgeBase?.id ?? '', folder?.parent_id ?? knowledgeBase?.id ?? '');
|
||||
// handleCreate('folder'); // 传入 type: 'folder'
|
||||
},
|
||||
},
|
||||
// 暂时未实现
|
||||
// {
|
||||
// key: '3',
|
||||
@@ -413,6 +476,21 @@ const Private: FC = () => {
|
||||
state: {
|
||||
documentId: document.id,
|
||||
parentId: parentId ?? knowledgeBaseId,
|
||||
// 传递面包屑信息
|
||||
breadcrumbPath: {
|
||||
knowledgeBaseFolderPath,
|
||||
knowledgeBase: {
|
||||
id: knowledgeBase?.id || knowledgeBaseId,
|
||||
name: knowledgeBase?.name || '',
|
||||
type: 'knowledgeBase'
|
||||
},
|
||||
documentFolderPath: folderPath,
|
||||
document: {
|
||||
id: document.id,
|
||||
name: document.file_name || '',
|
||||
type: 'document'
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -486,7 +564,9 @@ const Private: FC = () => {
|
||||
}
|
||||
const refreshDirectoryTree = async () => {
|
||||
// 先刷新知识库详情,确保数据是最新的
|
||||
await fetchKnowledgeBaseDetail(knowledgeBase.id);
|
||||
if (knowledgeBase?.id) {
|
||||
await fetchKnowledgeBaseDetail(knowledgeBase.id);
|
||||
}
|
||||
// 添加短暂延迟,确保后端数据已经完全更新
|
||||
await new Promise(resolve => setTimeout(resolve, 300));
|
||||
// 然后刷新文件夹树
|
||||
@@ -501,6 +581,7 @@ const Private: FC = () => {
|
||||
}
|
||||
const handleRootTreeLoad = (nodes: TreeNodeData[] | null) => {
|
||||
if (!nodes || nodes.length === 0) {
|
||||
// 如果没有节点,设置folder为null(这会隐藏FolderTree)
|
||||
setFolder(null);
|
||||
} else {
|
||||
// 如果有节点且 folder 为 null,重新设置 folder
|
||||
@@ -524,6 +605,7 @@ const Private: FC = () => {
|
||||
}
|
||||
|
||||
const handleRefreshTable = () => {
|
||||
debugger
|
||||
// 刷新表格数据
|
||||
tableRef.current?.loadData();
|
||||
}
|
||||
@@ -545,6 +627,7 @@ const Private: FC = () => {
|
||||
onRootLoad={handleRootTreeLoad}
|
||||
onFolderPathChange={handleFolderPathChange}
|
||||
selectedKeys={selectedKeys}
|
||||
autoExpandPath={autoExpandPath}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
@@ -601,6 +684,10 @@ const Private: FC = () => {
|
||||
ref={createFolderModalRef}
|
||||
refreshTable={refreshDirectoryTree}
|
||||
/>
|
||||
<CreateContentModal
|
||||
ref={createContentModalRef}
|
||||
refreshTable={handleRefreshTable}
|
||||
/>
|
||||
<CreateModal
|
||||
ref={modalRef}
|
||||
refreshTable={handleRefreshTable}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useEffect, useState, useRef, type FC } from 'react';
|
||||
import { useParams } from 'react-router-dom';
|
||||
import { useParams, useLocation } from 'react-router-dom';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { KnowledgeBaseListItem, RecallTestDrawerRef } from '@/views/KnowledgeBase/types';
|
||||
import RecallTest from '../components/RecallTest';
|
||||
@@ -15,17 +15,22 @@ import kbModelIcon from '@/assets/images/knowledgeBase/kb-model.png';
|
||||
import kbHistoryIcon from '@/assets/images/knowledgeBase/kb-history.png';
|
||||
import { getKnowledgeBaseDetail } from '@/api/knowledgeBase';
|
||||
import { formatDateTime } from '@/utils/format';
|
||||
import { useMenu } from '@/store/menu';
|
||||
import { useBreadcrumbManager, type BreadcrumbItem } from '@/hooks/useBreadcrumbManager';
|
||||
|
||||
const Share: FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const params = useParams<{ knowledgeBaseId: string }>();
|
||||
const location = useLocation();
|
||||
const knowledgeBaseId = params.knowledgeBaseId;
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [knowledgeBase, setKnowledgeBase] = useState<KnowledgeBaseListItem | null>(null);
|
||||
const recallTestRef = useRef<RecallTestDrawerRef>(null);
|
||||
const [infoItems, setInfoItems] = useState<InfoItem[]>([]);
|
||||
const { allBreadcrumbs, setCustomBreadcrumbs } = useMenu();
|
||||
const [knowledgeBaseFolderPath, setKnowledgeBaseFolderPath] = useState<BreadcrumbItem[]>([]);
|
||||
|
||||
const { updateBreadcrumbs } = useBreadcrumbManager({
|
||||
breadcrumbType: 'detail'
|
||||
});
|
||||
useEffect(() => {
|
||||
console.log('Share.tsx - useParams result:', params);
|
||||
console.log('Share.tsx - knowledgeBaseId:', knowledgeBaseId);
|
||||
@@ -46,9 +51,30 @@ const Share: FC = () => {
|
||||
// 更新面包屑
|
||||
useEffect(() => {
|
||||
if (knowledgeBase) {
|
||||
updateBreadcrumbs();
|
||||
updateBreadcrumbs({
|
||||
knowledgeBaseFolderPath,
|
||||
knowledgeBase: {
|
||||
id: knowledgeBase.id,
|
||||
name: knowledgeBase.name,
|
||||
type: 'knowledgeBase'
|
||||
},
|
||||
documentFolderPath: [],
|
||||
});
|
||||
}
|
||||
}, [knowledgeBase]);
|
||||
}, [knowledgeBase, knowledgeBaseFolderPath, updateBreadcrumbs]);
|
||||
|
||||
// 监听 location state 变化
|
||||
useEffect(() => {
|
||||
const state = location.state as {
|
||||
fromKnowledgeBaseList?: boolean;
|
||||
knowledgeBaseFolderPath?: BreadcrumbItem[];
|
||||
} | null;
|
||||
|
||||
// 如果是从知识库列表页跳转过来的,设置知识库文件夹路径
|
||||
if (state?.fromKnowledgeBaseList && state?.knowledgeBaseFolderPath) {
|
||||
setKnowledgeBaseFolderPath(state.knowledgeBaseFolderPath);
|
||||
}
|
||||
}, [location.state]);
|
||||
const formatInfoItems = (data: KnowledgeBaseListItem): InfoItem[] => {
|
||||
const items: InfoItem[] = [
|
||||
{
|
||||
@@ -112,46 +138,7 @@ const Share: FC = () => {
|
||||
});
|
||||
};
|
||||
|
||||
// 更新面包屑,包含知识库名称
|
||||
const updateBreadcrumbs = () => {
|
||||
if (!knowledgeBase) return;
|
||||
|
||||
const baseBreadcrumbs = allBreadcrumbs['space'] || [];
|
||||
// 只保留知识库菜单项之前的面包屑
|
||||
const knowledgeBaseMenuIndex = baseBreadcrumbs.findIndex(item => item.path === '/knowledge-base');
|
||||
const filteredBaseBreadcrumbs = knowledgeBaseMenuIndex >= 0
|
||||
? baseBreadcrumbs.slice(0, knowledgeBaseMenuIndex + 1)
|
||||
: baseBreadcrumbs;
|
||||
|
||||
const customBreadcrumbs = [
|
||||
...filteredBaseBreadcrumbs,
|
||||
{
|
||||
id: 0,
|
||||
parent: 0,
|
||||
code: null,
|
||||
label: knowledgeBase.name,
|
||||
i18nKey: null,
|
||||
path: null,
|
||||
enable: true,
|
||||
display: true,
|
||||
level: 0,
|
||||
sort: 0,
|
||||
icon: null,
|
||||
iconActive: null,
|
||||
menuDesc: null,
|
||||
deleted: null,
|
||||
updateTime: 0,
|
||||
new_: null,
|
||||
keepAlive: false,
|
||||
master: null,
|
||||
disposable: false,
|
||||
appSystem: null,
|
||||
subs: [],
|
||||
},
|
||||
];
|
||||
|
||||
setCustomBreadcrumbs(customBreadcrumbs, 'space');
|
||||
};
|
||||
|
||||
// const handleBack = () => {
|
||||
// navigate('/knowledge-base');
|
||||
|
||||
117
web/src/views/KnowledgeBase/components/CreateContentModal.tsx
Normal file
117
web/src/views/KnowledgeBase/components/CreateContentModal.tsx
Normal file
@@ -0,0 +1,117 @@
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
|
||||
import { Form, Input } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import RbModal from '@/components/RbModal';
|
||||
import { createDocumentAndUpload } from '@/api/knowledgeBase'
|
||||
import type { CreateSetModalRef,CreateSetMoealRefProps } from '../types'
|
||||
interface ContentFormData {
|
||||
title: string;
|
||||
content: string;
|
||||
}
|
||||
|
||||
const CreateContentModal = forwardRef<CreateSetModalRef, CreateSetMoealRefProps>(
|
||||
({ refreshTable }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [form] = Form.useForm<ContentFormData>();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [kbId, setKbId] = useState<string>('');
|
||||
const [parentId, setParentId] = useState<string>('');
|
||||
|
||||
const handleClose = () => {
|
||||
form.resetFields();
|
||||
setLoading(false);
|
||||
setVisible(false);
|
||||
setKbId('');
|
||||
setParentId('');
|
||||
};
|
||||
|
||||
const handleOpen = (kb_id: string, parent_id: string) => {
|
||||
setKbId(kb_id);
|
||||
setParentId(parent_id);
|
||||
form.resetFields();
|
||||
setVisible(true);
|
||||
};
|
||||
|
||||
const handleSave = async () => {
|
||||
try {
|
||||
const values = await form.validateFields();
|
||||
setLoading(true);
|
||||
|
||||
// TODO: 这里需要调用相应的API来保存内容
|
||||
const params = {
|
||||
// ...values,
|
||||
kb_id: kbId,
|
||||
parent_id: parentId,
|
||||
};
|
||||
|
||||
|
||||
const response = await createDocumentAndUpload(values, params)
|
||||
if(response){
|
||||
handleChunking(response.kb_id,parentId,response.id)
|
||||
}
|
||||
handleClose();
|
||||
} catch (err) {
|
||||
console.error('创建内容失败:', err);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
const handleChunking = (kb_id: string, parent_id: string, file_id: string) => {
|
||||
if (!kb_id) return;
|
||||
const targetFileId = file_id
|
||||
navigate(`/knowledge-base/${kb_id}/create-dataset`, {
|
||||
state: {
|
||||
source: 'local',
|
||||
knowledgeBaseId: kb_id,
|
||||
parentId: parent_id ?? kb_id,
|
||||
startStep: 'parameterSettings',
|
||||
fileId: targetFileId,
|
||||
},
|
||||
});
|
||||
}
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
}));
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={t('knowledgeBase.createContent')}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText={t('common.create')}
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
width={600}
|
||||
>
|
||||
<Form form={form} layout="vertical">
|
||||
<Form.Item
|
||||
name="title"
|
||||
label={t('knowledgeBase.title')}
|
||||
rules={[{ required: true, message: t('knowledgeBase.pleaseEnterTitle') }]}
|
||||
>
|
||||
<Input placeholder={t('knowledgeBase.pleaseEnterTitle')} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="content"
|
||||
label={t('knowledgeBase.content')}
|
||||
rules={[{ required: true, message: t('knowledgeBase.pleaseEnterContent') }]}
|
||||
>
|
||||
<Input.TextArea
|
||||
placeholder={t('knowledgeBase.pleaseEnterContent')}
|
||||
rows={8}
|
||||
showCount
|
||||
maxLength={5000}
|
||||
/>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</RbModal>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
export default CreateContentModal;
|
||||
@@ -0,0 +1,34 @@
|
||||
import { useRef } from 'react';
|
||||
import { Button } from 'antd';
|
||||
import CreateContentModal from './CreateContentModal';
|
||||
import type { CreateContentModalRef } from '../types';
|
||||
|
||||
// 使用示例组件
|
||||
const CreateContentModalExample = () => {
|
||||
const createContentModalRef = useRef<CreateContentModalRef>(null);
|
||||
|
||||
const handleOpenModal = () => {
|
||||
// 打开弹窗,传入知识库ID和父级ID
|
||||
createContentModalRef.current?.handleOpen('kb_123', 'parent_456');
|
||||
};
|
||||
|
||||
const handleRefreshTable = () => {
|
||||
console.log('刷新表格数据');
|
||||
// 这里可以添加刷新表格的逻辑
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Button type="primary" onClick={handleOpenModal}>
|
||||
创建内容
|
||||
</Button>
|
||||
|
||||
<CreateContentModal
|
||||
ref={createContentModalRef}
|
||||
refreshTable={handleRefreshTable}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default CreateContentModalExample;
|
||||
@@ -2,7 +2,7 @@ import { forwardRef, useImperativeHandle, useState, useRef } from 'react';
|
||||
import { Form, Input } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { UploadFile } from 'antd';
|
||||
import type { CreateImageModalRef, CreateImageMoealRefProps,UploadFileResponse } from '@/views/KnowledgeBase/types';
|
||||
import type { CreateSetModalRef, CreateSetMoealRefProps, UploadFileResponse } from '@/views/KnowledgeBase/types';
|
||||
import type { UploadRequestOption } from 'rc-upload/lib/interface';
|
||||
import RbModal from '@/components/RbModal';
|
||||
import UploadFiles from '@/components/Upload/UploadFiles';
|
||||
@@ -13,7 +13,7 @@ interface ImageDatasetFormData {
|
||||
images: UploadFile[];
|
||||
}
|
||||
|
||||
const CreateImageDataset = forwardRef<CreateImageModalRef, CreateImageMoealRefProps>(
|
||||
const CreateImageDataset = forwardRef<CreateSetModalRef, CreateSetMoealRefProps>(
|
||||
({ refreshTable }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const [visible, setVisible] = useState(false);
|
||||
|
||||
@@ -60,6 +60,8 @@ interface FolderTreeProps {
|
||||
onRootLoad?: (nodes: TreeNodeData[] | null) => void;
|
||||
onFolderPathChange?: (path: Array<{ id: string; name: string }>) => void;
|
||||
selectedKeys?: React.Key[];
|
||||
// 新增:自动展开到指定路径
|
||||
autoExpandPath?: Array<{ id: string; name: string }>;
|
||||
}
|
||||
|
||||
const renderIcon = (icon?: string) => {
|
||||
@@ -275,8 +277,11 @@ const FolderTree: FC<FolderTreeProps> = ({
|
||||
onRootLoad,
|
||||
onFolderPathChange,
|
||||
selectedKeys,
|
||||
autoExpandPath,
|
||||
}) => {
|
||||
const [treeData, setTreeData] = useState<TreeNodeData[]>([]);
|
||||
const [expandedKeys, setExpandedKeys] = useState<React.Key[]>([]);
|
||||
const [autoExpandInProgress, setAutoExpandInProgress] = useState(false);
|
||||
|
||||
// 更新树节点数据的辅助函数
|
||||
const updateTreeData = (nodes: TreeNodeData[], key: Key, children: TreeNodeData[]): TreeNodeData[] => {
|
||||
@@ -370,6 +375,109 @@ const FolderTree: FC<FolderTreeProps> = ({
|
||||
return null;
|
||||
};
|
||||
|
||||
// 查找节点的辅助函数
|
||||
const findNodeInTree = (nodes: TreeNodeData[], key: string): TreeNodeData | null => {
|
||||
for (const node of nodes) {
|
||||
if (String(node.key) === key) {
|
||||
return node;
|
||||
}
|
||||
if (node.children) {
|
||||
const found = findNodeInTree(node.children, key);
|
||||
if (found) return found;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
// 渐进式自动展开到指定路径
|
||||
useEffect(() => {
|
||||
if (!autoExpandPath || autoExpandPath.length === 0 || autoExpandInProgress || treeData.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const expandToPath = async () => {
|
||||
setAutoExpandInProgress(true);
|
||||
|
||||
try {
|
||||
const keysToExpand: React.Key[] = [];
|
||||
let currentTreeData = treeData;
|
||||
|
||||
// 逐级展开,从第一级开始(跳过根节点,因为根节点已经加载)
|
||||
for (let i = 0; i < autoExpandPath.length - 1; i++) {
|
||||
const nodeKey = autoExpandPath[i].id;
|
||||
keysToExpand.push(nodeKey);
|
||||
|
||||
// 查找当前节点
|
||||
const targetNode = findNodeInTree(currentTreeData, nodeKey);
|
||||
|
||||
if (targetNode && targetNode.children === undefined) {
|
||||
// 如果子节点未加载,先加载
|
||||
try {
|
||||
console.log(`自动展开:加载节点 ${nodeKey} 的子节点`);
|
||||
const children = await buildTreeNodes(knowledgeBaseId, nodeKey);
|
||||
|
||||
// 更新树数据
|
||||
setTreeData((prevData) => {
|
||||
const newData = updateTreeData(prevData, nodeKey, children);
|
||||
currentTreeData = newData; // 更新当前引用
|
||||
return newData;
|
||||
});
|
||||
|
||||
// 等待状态更新完成
|
||||
await new Promise(resolve => setTimeout(resolve, 150));
|
||||
|
||||
} catch (error) {
|
||||
console.error(`自动展开时加载节点 ${nodeKey} 失败:`, error);
|
||||
// 加载失败时停止展开
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置展开的节点
|
||||
setExpandedKeys(keysToExpand);
|
||||
|
||||
// 选中最后一个节点(目标文件夹)
|
||||
const targetKey = autoExpandPath[autoExpandPath.length - 1]?.id;
|
||||
if (targetKey) {
|
||||
console.log(`自动展开:选中目标节点 ${targetKey}`);
|
||||
// 延迟选中,确保展开动画完成
|
||||
setTimeout(() => {
|
||||
if (onSelect) {
|
||||
onSelect([targetKey], {
|
||||
selected: true,
|
||||
selectedNodes: [],
|
||||
node: {} as any,
|
||||
event: 'select',
|
||||
nativeEvent: new MouseEvent('click')
|
||||
});
|
||||
}
|
||||
}, 200);
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('自动展开路径失败:', error);
|
||||
} finally {
|
||||
// 延迟重置标志,确保展开过程完全完成
|
||||
setTimeout(() => {
|
||||
setAutoExpandInProgress(false);
|
||||
}, 500);
|
||||
}
|
||||
};
|
||||
|
||||
// 延迟执行,确保树数据已经加载完成
|
||||
const timer = setTimeout(expandToPath, 300);
|
||||
return () => clearTimeout(timer);
|
||||
}, [autoExpandPath, treeData.length, knowledgeBaseId, onSelect, autoExpandInProgress]);
|
||||
|
||||
// 处理展开事件
|
||||
const handleExpand: TreeProps['onExpand'] = (expandedKeys, info) => {
|
||||
setExpandedKeys(expandedKeys);
|
||||
if (onExpand) {
|
||||
onExpand(expandedKeys, info);
|
||||
}
|
||||
};
|
||||
|
||||
// 处理选择事件,计算并传递路径
|
||||
const handleSelect: TreeProps['onSelect'] = (selectedKeys, info) => {
|
||||
if (selectedKeys.length > 0) {
|
||||
@@ -391,11 +499,13 @@ const FolderTree: FC<FolderTreeProps> = ({
|
||||
|
||||
return (
|
||||
<DirectoryTree
|
||||
key={refreshKey} // 添加key确保refreshKey变化时重新渲染整个组件
|
||||
multiple={multiple}
|
||||
className={className}
|
||||
style={style}
|
||||
onSelect={handleSelect}
|
||||
onExpand={onExpand}
|
||||
onExpand={handleExpand}
|
||||
expandedKeys={expandedKeys}
|
||||
loadData={onLoadData}
|
||||
treeData={treeNodes}
|
||||
selectedKeys={selectedKeys}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { useEffect, useState, useRef, useMemo, type FC } from 'react';
|
||||
import { useEffect, useState, useRef, useMemo, useCallback, type FC } from 'react';
|
||||
import { Row, Col, Button, Dropdown, Modal, message, Tooltip } from 'antd'
|
||||
import type { MenuProps } from 'antd';
|
||||
import { EllipsisOutlined } from '@ant-design/icons';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import { useNavigate, useLocation } from 'react-router-dom';
|
||||
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import clsx from 'clsx';
|
||||
@@ -18,7 +18,8 @@ import Empty from '@/components/Empty'
|
||||
import { getKnowledgeBaseList, getModelList, getModelTypeList, deleteKnowledgeBase, getKnowledgeBaseTypeList } from '@/api/knowledgeBase'
|
||||
const { confirm } = Modal;
|
||||
import InfiniteScroll from 'react-infinite-scroll-component';
|
||||
import { useMenu } from '@/store/menu';
|
||||
|
||||
import { useBreadcrumbManager, type BreadcrumbItem } from '@/hooks/useBreadcrumbManager';
|
||||
|
||||
type ModelMenuInfo = {
|
||||
menu: NonNullable<MenuProps['items']>;
|
||||
@@ -28,6 +29,7 @@ type ModelMenuInfo = {
|
||||
const KnowledgeBaseManagement: FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const navigate = useNavigate();
|
||||
const location = useLocation();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [data, setData] = useState<KnowledgeBaseListItem[]>([])
|
||||
const [page, setPage] = useState(1)
|
||||
@@ -42,10 +44,29 @@ const KnowledgeBaseManagement: FC = () => {
|
||||
const modelListCache = useRef<Record<string, string>>({});
|
||||
const modalRef = useRef<CreateModalRef>(null)
|
||||
const [messageApi, contextHolder] = message.useMessage();
|
||||
const processedStateRef = useRef<any>(null);
|
||||
|
||||
// 使用 menu store 管理面包屑
|
||||
const { allBreadcrumbs, setCustomBreadcrumbs } = useMenu();
|
||||
const [folderPath, setFolderPath] = useState<Array<{ id: string; name: string }>>([]);
|
||||
// 使用面包屑管理 Hook
|
||||
const { updateBreadcrumbs } = useBreadcrumbManager({
|
||||
breadcrumbType: 'list',
|
||||
onKnowledgeBaseMenuClick: useCallback(() => {
|
||||
// 返回根目录
|
||||
setFolderPath([]);
|
||||
setQuery((prev) => ({
|
||||
...prev,
|
||||
parent_id: undefined,
|
||||
}));
|
||||
}, []),
|
||||
onKnowledgeBaseFolderClick: useCallback((folderId: string, folderPath: Array<{ id: string; name: string }>) => {
|
||||
// 直接更新文件夹路径和查询状态
|
||||
setFolderPath(folderPath);
|
||||
setQuery((prev) => ({
|
||||
...prev,
|
||||
parent_id: folderId,
|
||||
}));
|
||||
}, [])
|
||||
});
|
||||
const [folderPath, setFolderPath] = useState<BreadcrumbItem[]>([]);
|
||||
|
||||
|
||||
// 生成下拉菜单项(根据当前 item)
|
||||
@@ -134,7 +155,7 @@ const KnowledgeBaseManagement: FC = () => {
|
||||
handleCreate(type);
|
||||
},
|
||||
}));
|
||||
}, [knowledgeBaseTypes, t, folderPath, query]);
|
||||
}, [knowledgeBaseTypes, t]);
|
||||
const typeToFieldKey = (type: string) => {
|
||||
const normalized = (type || '').toLowerCase();
|
||||
switch (normalized) {
|
||||
@@ -371,90 +392,72 @@ const KnowledgeBaseManagement: FC = () => {
|
||||
}));
|
||||
return;
|
||||
}
|
||||
|
||||
// 根据权限类型跳转到不同的详情页
|
||||
if (knowledgeBase.permission_id === 'Private' || knowledgeBase.permission_id === 'private') {
|
||||
navigate(`/knowledge-base/${knowledgeBase.id}/private`)
|
||||
// 跳转时传递当前的文件夹路径信息
|
||||
const navigationState = {
|
||||
fromKnowledgeBaseList: true,
|
||||
knowledgeBaseFolderPath: folderPath,
|
||||
parentId: query.parent_id,
|
||||
timestamp: Date.now(), // 添加时间戳确保每次跳转状态都不同
|
||||
};
|
||||
const targetPath = knowledgeBase.permission_id === 'Private' || knowledgeBase.permission_id === 'private'
|
||||
? `/knowledge-base/${knowledgeBase.id}/private`
|
||||
: `/knowledge-base/${knowledgeBase.id}/share`;
|
||||
|
||||
// 检查是否是相同路径跳转
|
||||
const currentPath = location.pathname;
|
||||
|
||||
if (currentPath === targetPath) {
|
||||
// 如果是相同路径,使用replace并强制刷新状态
|
||||
navigate(targetPath, {
|
||||
state: navigationState,
|
||||
replace: true
|
||||
});
|
||||
} else {
|
||||
navigate(`/knowledge-base/${knowledgeBase.id}/share`)
|
||||
// 不同路径,正常跳转
|
||||
navigate(targetPath, { state: navigationState });
|
||||
}
|
||||
}
|
||||
// 更新面包屑的函数
|
||||
const updateBreadcrumbs = () => {
|
||||
const baseBreadcrumbs = allBreadcrumbs['space'] || [];
|
||||
// 只保留知识库菜单项之前的面包屑
|
||||
const knowledgeBaseMenuIndex = baseBreadcrumbs.findIndex(item => item.path === '/knowledge-base');
|
||||
const filteredBaseBreadcrumbs = knowledgeBaseMenuIndex >= 0
|
||||
? baseBreadcrumbs.slice(0, knowledgeBaseMenuIndex + 1)
|
||||
: baseBreadcrumbs;
|
||||
|
||||
// 给"知识库管理"添加点击事件,返回根目录
|
||||
const breadcrumbsWithClick = filteredBaseBreadcrumbs.map((item) => {
|
||||
if (item.path === '/knowledge-base') {
|
||||
return {
|
||||
...item,
|
||||
onClick: (e?: React.MouseEvent) => {
|
||||
e?.preventDefault();
|
||||
e?.stopPropagation();
|
||||
// 返回根目录
|
||||
setFolderPath([]);
|
||||
setQuery((prev) => ({
|
||||
...prev,
|
||||
parent_id: undefined,
|
||||
}));
|
||||
return false;
|
||||
},
|
||||
};
|
||||
}
|
||||
return item;
|
||||
});
|
||||
|
||||
const customBreadcrumbs = [
|
||||
...breadcrumbsWithClick,
|
||||
...folderPath.map((folder, index) => ({
|
||||
id: 0,
|
||||
parent: 0,
|
||||
code: null,
|
||||
label: folder.name,
|
||||
i18nKey: null,
|
||||
path: null,
|
||||
enable: true,
|
||||
display: true,
|
||||
level: 0,
|
||||
sort: 0,
|
||||
icon: null,
|
||||
iconActive: null,
|
||||
menuDesc: null,
|
||||
deleted: null,
|
||||
updateTime: 0,
|
||||
new_: null,
|
||||
keepAlive: false,
|
||||
master: null,
|
||||
disposable: false,
|
||||
appSystem: null,
|
||||
subs: [],
|
||||
onClick: (e?: React.MouseEvent) => {
|
||||
e?.preventDefault();
|
||||
e?.stopPropagation();
|
||||
// 点击文件夹,回到该文件夹层级
|
||||
const newFolderPath = folderPath.slice(0, index + 1);
|
||||
setFolderPath(newFolderPath);
|
||||
setQuery((prev) => ({
|
||||
...prev,
|
||||
parent_id: folder.id,
|
||||
}));
|
||||
return false;
|
||||
},
|
||||
})),
|
||||
];
|
||||
|
||||
setCustomBreadcrumbs(customBreadcrumbs, 'space');
|
||||
};
|
||||
|
||||
// 更新面包屑
|
||||
useEffect(() => {
|
||||
updateBreadcrumbs();
|
||||
}, [folderPath]);
|
||||
updateBreadcrumbs({
|
||||
knowledgeBaseFolderPath: folderPath,
|
||||
documentFolderPath: [],
|
||||
});
|
||||
}, [folderPath, updateBreadcrumbs]);
|
||||
|
||||
// 处理从详情页返回的导航
|
||||
useEffect(() => {
|
||||
const state = location.state as {
|
||||
navigateToFolder?: string;
|
||||
folderPath?: Array<{ id: string; name: string }>;
|
||||
resetToRoot?: boolean;
|
||||
} | null;
|
||||
|
||||
// 避免重复处理相同的状态
|
||||
if (state && state !== processedStateRef.current) {
|
||||
processedStateRef.current = state;
|
||||
|
||||
if (state.resetToRoot) {
|
||||
// 重置到根目录
|
||||
setFolderPath([]);
|
||||
setQuery((prev) => ({
|
||||
...prev,
|
||||
parent_id: undefined,
|
||||
}));
|
||||
} else if (state?.navigateToFolder && state?.folderPath) {
|
||||
// 恢复文件夹路径和查询状态
|
||||
setFolderPath(state.folderPath);
|
||||
setQuery((prev) => ({
|
||||
...prev,
|
||||
parent_id: state.navigateToFolder,
|
||||
}));
|
||||
}
|
||||
|
||||
// 不清除 state,避免干扰后续导航
|
||||
// 使用 processedStateRef 来避免重复处理相同的 state
|
||||
}
|
||||
}, [location.state, navigate]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchModelTypes();
|
||||
@@ -465,7 +468,7 @@ const KnowledgeBaseManagement: FC = () => {
|
||||
if (modelTypes.length) {
|
||||
fetchData(1, false);
|
||||
}
|
||||
}, [modelTypes, query])
|
||||
}, [modelTypes, query.parent_id, query.keywords, query.orderby, query.desc])
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
@@ -146,11 +146,19 @@ export interface CreateFolderModalRefProps{
|
||||
refreshTable?: () => void;
|
||||
}
|
||||
|
||||
//他建图片数据集
|
||||
export interface CreateImageModalRef{
|
||||
handleOpen: (kb_id:string,parent_id:string) => void;
|
||||
//创建图片数据集 / 创建自定义文本数据集
|
||||
export interface CreateSetModalRef{
|
||||
handleOpen: (kb_id:string, parent_id:string) => void;
|
||||
}
|
||||
export interface CreateImageMoealRefProps{
|
||||
export interface CreateSetMoealRefProps{
|
||||
refreshTable?: () => void;
|
||||
}
|
||||
|
||||
// 创建内容
|
||||
export interface CreateContentModalRef {
|
||||
handleOpen: (kb_id: string, parent_id: string) => void;
|
||||
}
|
||||
export interface CreateContentModalRefProps {
|
||||
refreshTable?: () => void;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user