合并 feature/20251219_yjp 分支到 web 分支

冲突解决策略:
- web/src/views/KnowledgeBase/ 文件夹下的所有冲突以 feature/20251219_yjp 分支为主
- 其他冲突(如 vite.config.ts)以 web 分支为主

主要更改:
- 保留了 feature 分支中的知识库相关功能和组件
- 保持了 web 分支的配置和其他功能
- 添加了自定义文本数据集创建功能
- 更新了知识库管理界面
This commit is contained in:
yujiangping
2025-12-17 15:25:20 +08:00
83 changed files with 8786 additions and 1098 deletions

3
.gitignore vendored
View File

@@ -20,7 +20,8 @@ examples/
.idea
# Temporary outputs
**/.DS_Store
app/core/memory/agent/.DS_Store
app/core/memory/src/utils/.DS_Store
time.log
celerybeat-schedule.db
search_results.json

View File

@@ -27,6 +27,7 @@ from . import (
release_share_controller,
public_share_controller,
multi_agent_controller,
workflow_controller,
)
# 创建管理端 API 路由器
@@ -56,5 +57,6 @@ manager_router.include_router(release_share_controller.router)
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(multi_agent_controller.router)
manager_router.include_router(workflow_controller.router)
__all__ = ["manager_router"]

View File

@@ -1,7 +1,6 @@
"""API Key 管理接口 - 基于 JWT 认证"""
import uuid
from typing import Optional
from datetime import datetime
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
@@ -14,6 +13,7 @@ from app.core.response_utils import success
from app.schemas import api_key_schema
from app.schemas.response_schema import ApiResponse
from app.services.api_key_service import ApiKeyService
from app.core.api_key_utils import timestamp_to_datetime
from app.core.logging_config import get_api_logger
from app.core.exceptions import (
BusinessException,
@@ -41,18 +41,14 @@ def create_api_key(
workspace_id = current_user.current_workspace_id
# 创建 API Key
api_key_obj, api_key = ApiKeyService.create_api_key(
api_key_obj = ApiKeyService.create_api_key(
db,
workspace_id=workspace_id,
user_id=current_user.id,
data=data
)
# 返回包含明文 Key 的响应
response_data = api_key_schema.ApiKeyResponse(
**api_key_obj.__dict__,
api_key=api_key
)
response_data = api_key_schema.ApiKeyResponse.model_validate(api_key_obj)
return success(data=response_data, msg="API Key 创建成功")
except BusinessException:
@@ -223,13 +219,9 @@ def regenerate_api_key(
"""
try:
workspace_id = current_user.current_workspace_id
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
api_key_obj = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
# 返回包含明文 Key 的响应
response_data = api_key_schema.ApiKeyResponse(
**api_key_obj.__dict__,
api_key=api_key
)
response_data = api_key_schema.ApiKeyResponse.model_validate(api_key_obj)
logger.info("API Key 重新生成成功", extra={
"api_key_id": str(api_key_id),
@@ -283,8 +275,8 @@ def get_api_key_stats(
@cur_workspace_access_guard()
def get_api_key_logs(
api_key_id: uuid.UUID,
start_date: Optional[datetime] = Query(None, description="开始日期"),
end_date: Optional[datetime] = Query(None, description="结束日期"),
start_date: Optional[int] = Query(None, description="开始日期时间戳"),
end_date: Optional[int] = Query(None, description="结束日期时间戳"),
status_code: Optional[int] = Query(None, description="HTTP状态码过滤"),
endpoint: Optional[str] = Query(None, description="端点路径过滤"),
page: int = Query(1, ge=1, description="页码"),
@@ -302,14 +294,17 @@ def get_api_key_logs(
try:
workspace_id = current_user.current_workspace_id
start_datetime = timestamp_to_datetime(start_date) if start_date else None
end_datetime = timestamp_to_datetime(end_date) if end_date else None
# 验证日期范围
if start_date and end_date and start_date > end_date:
if start_datetime and end_datetime and start_datetime > end_datetime:
logger.warning("开始日期晚于结束日期", extra={
"api_key_id": str(api_key_id),
"workspace_id": str(workspace_id),
"user_id": str(current_user.id),
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat()
"start_date": start_datetime.isoformat(),
"end_date": end_datetime.isoformat()
})
raise BusinessException("开始日期不能晚于结束日期", BizCode.INVALID_PARAMETER)
@@ -325,8 +320,8 @@ def get_api_key_logs(
# 构建过滤条件
filters = {
"start_date": start_date,
"end_date": end_date,
"start_date": start_datetime,
"end_date": end_datetime,
"status_code": status_code,
"endpoint": endpoint
}

View File

@@ -1,22 +1,26 @@
import uuid
from typing import Optional
from fastapi import APIRouter, Depends
from typing import Optional, Annotated
from fastapi import APIRouter, Depends, Path
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
from app.models.app_model import AppType, App
from app.repositories import knowledge_repository
from app.schemas import app_schema
from app.schemas.response_schema import PageData, PageMeta
from app.schemas.workflow_schema import WorkflowConfigUpdate
from app.services import app_service, workspace_service
from app.services.app_service import AppService
from app.services.agent_config_helper import enrich_agent_config
from app.dependencies import get_current_user, cur_workspace_access_guard, workspace_access_guard
from fastapi.responses import StreamingResponse
from app.models.app_model import AppType
from app.core.error_codes import BizCode
from app.services.app_service import AppService
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
from app.services.workflow_service import WorkflowService, get_workflow_service
router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger()
@@ -48,7 +52,7 @@ def list_apps(
current_user=Depends(get_current_user),
):
"""列出应用
- 默认包含本工作空间的应用和分享给本工作空间的应用
- 设置 include_shared=false 可以只查看本工作空间的应用
"""
@@ -63,8 +67,8 @@ def list_apps(
include_shared=include_shared,
page=page,
pagesize=pagesize,
)
)
# 使用 AppService 的转换方法来设置 is_shared 字段
service = app_service.AppService(db)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
@@ -79,14 +83,14 @@ def get_app(
current_user=Depends(get_current_user),
):
"""获取应用详细信息
- 支持获取本工作空间的应用
- 支持获取分享给本工作空间的应用
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
app = service.get_app(app_id, workspace_id)
# 转换为 Schema 并设置 is_shared 字段
app_schema_obj = service._convert_to_schema(app, workspace_id)
return success(data=app_schema_obj)
@@ -113,7 +117,7 @@ def delete_app(
current_user=Depends(get_current_user),
):
"""删除应用
会级联删除:
- Agent 配置
- 发布版本
@@ -128,9 +132,9 @@ def delete_app(
"workspace_id": str(workspace_id)
}
)
app_service.delete_app(db, app_id=app_id, workspace_id=workspace_id)
return success(msg="应用删除成功")
@@ -143,7 +147,7 @@ def copy_app(
current_user=Depends(get_current_user),
):
"""复制应用(包括基础信息和配置)
- 复制应用的基础信息(名称、描述、图标等)
- 复制 Agent 配置(如果是 agent 类型)
- 新应用默认为草稿状态
@@ -159,7 +163,7 @@ def copy_app(
"new_name": new_name
}
)
service = AppService(db)
new_app = service.copy_app(
app_id=app_id,
@@ -167,7 +171,7 @@ def copy_app(
workspace_id=workspace_id,
new_name=new_name
)
return success(data=app_schema.App.model_validate(new_app), msg="应用复制成功")
@@ -209,9 +213,9 @@ def publish_app(
):
workspace_id = current_user.current_workspace_id
release = app_service.publish(
db,
app_id=app_id,
publisher_id=current_user.id,
db,
app_id=app_id,
publisher_id=current_user.id,
workspace_id=workspace_id,
version_name = payload.version_name,
release_notes=payload.release_notes
@@ -268,13 +272,13 @@ def share_app(
current_user=Depends(get_current_user),
):
"""分享应用到其他工作空间
- 只能分享自己工作空间的应用
- 不能分享到自己的工作空间
- 同一个应用不能重复分享到同一个工作空间
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
shares = service.share_app(
app_id=app_id,
@@ -282,7 +286,7 @@ def share_app(
user_id=current_user.id,
workspace_id=workspace_id
)
data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data, msg=f"应用已分享到 {len(shares)} 个工作空间")
@@ -296,18 +300,18 @@ def unshare_app(
current_user=Depends(get_current_user),
):
"""取消应用分享
- 只能取消自己工作空间应用的分享
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
service.unshare_app(
app_id=app_id,
target_workspace_id=target_workspace_id,
workspace_id=workspace_id
)
return success(msg="应用分享已取消")
@@ -319,17 +323,17 @@ def list_app_shares(
current_user=Depends(get_current_user),
):
"""列出应用的所有分享记录
- 只能查看自己工作空间应用的分享记录
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
shares = service.list_app_shares(
app_id=app_id,
workspace_id=workspace_id
)
data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data)
@@ -340,10 +344,11 @@ async def draft_run(
payload: app_schema.DraftRunRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
workflow_service: Annotated[WorkflowService, Depends(get_workflow_service)] = None
):
"""
试运行 Agent使用当前的草稿配置未发布的配置
- 不需要发布应用即可测试
- 使用当前的 AgentConfig 配置
- 支持流式和非流式返回
@@ -367,33 +372,44 @@ async def draft_run(
)
if knowledge: user_rag_memory_id = str(knowledge.id)
# 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService
from app.services.multi_agent_service import MultiAgentService
from app.models import AgentConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
from app.services.draft_run_service import DraftRunService
service = AppService(db)
draft_service = DraftRunService(db)
# 1. 验证应用
app = service._get_app_or_404(app_id)
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT:
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
if app.type != AppType.AGENT and app.type != AppType.MULTI_AGENT and app.type != AppType.WORKFLOW:
raise BusinessException("只有 Agent , Workflow 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
service._validate_app_accessible(app, workspace_id)
# 处理会话ID创建或验证
conversation_id = await draft_service._ensure_conversation(
conversation_id=payload.conversation_id,
app_id=app_id,
workspace_id=workspace_id,
user_id=payload.user_id
)
payload.conversation_id = conversation_id
if app.type == AppType.AGENT:
service._check_agent_config(app_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
@@ -401,12 +417,12 @@ async def draft_run(
if not model_config:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
async for event in draft_service.run_stream(
agent_config=agent_cfg,
model_config=model_config,
@@ -419,7 +435,7 @@ async def draft_run(
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
@@ -429,7 +445,7 @@ async def draft_run(
"X-Accel-Buffering": "no"
}
)
# 非流式返回
logger.debug(
"开始非流式试运行",
@@ -440,7 +456,7 @@ async def draft_run(
"has_variables": bool(payload.variables)
}
)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run(
@@ -454,7 +470,7 @@ async def draft_run(
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
logger.debug(
"试运行返回结果",
extra={
@@ -462,7 +478,7 @@ async def draft_run(
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict"
}
)
# 验证结果
try:
validated_result = app_schema.DraftRunResponse.model_validate(result)
@@ -481,10 +497,10 @@ async def draft_run(
elif app.type == AppType.MULTI_AGENT:
# 1. 检查多智能体配置完整性
service._check_multi_agent_config(app_id)
# 2. 构建多智能体运行请求
from app.schemas.multi_agent_schema import MultiAgentRunRequest
multi_agent_request = MultiAgentRunRequest(
message=payload.message,
conversation_id=payload.conversation_id,
@@ -492,7 +508,7 @@ async def draft_run(
variables=payload.variables or {},
use_llm_routing=True # 默认启用 LLM 路由
)
# 3. 流式返回
if payload.stream:
logger.debug(
@@ -503,11 +519,11 @@ async def draft_run(
"has_conversation_id": bool(payload.conversation_id)
}
)
async def event_generator():
"""多智能体流式事件生成器"""
multiservice = MultiAgentService(db)
# 调用多智能体服务的流式方法
async for event in multiservice.run_stream(
app_id=app_id,
@@ -517,7 +533,7 @@ async def draft_run(
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
@@ -527,7 +543,7 @@ async def draft_run(
"X-Accel-Buffering": "no"
}
)
# 4. 非流式返回
logger.debug(
"开始多智能体非流式试运行",
@@ -537,10 +553,10 @@ async def draft_run(
"has_conversation_id": bool(payload.conversation_id)
}
)
multiservice = MultiAgentService(db)
result = await multiservice.run(app_id, multi_agent_request)
logger.debug(
"多智能体试运行返回结果",
extra={
@@ -548,12 +564,71 @@ async def draft_run(
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="多 Agent 任务执行成功"
)
elif app.type == AppType.WORKFLOW: #工作流
config = workflow_service.check_config(app_id)
# 3. 流式返回
if payload.stream:
logger.debug(
"开始多智能体流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id)
}
)
async def event_generator():
"""多智能体流式事件生成器"""
multiservice = MultiAgentService(db)
# 调用多智能体服务的流式方法
async for event in multiservice.run_stream(
app_id=app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 4. 非流式返回
logger.debug(
"开始非流式试运行",
extra={
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id)
}
)
result = await workflow_service.run(app_id, payload,config)
logger.debug(
"工作流试运行返回结果",
extra={
"result_type": str(type(result)),
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="工作流任务执行成功"
)
@@ -567,21 +642,21 @@ async def draft_run_compare(
):
"""
多模型对比试运行
- 支持对比 1-5 个模型
- 可以是不同的模型,也可以是同一模型的不同参数配置
- 通过 model_parameters 覆盖默认参数
- 支持并行或串行执行(非流式)
- 支持流式返回(串行执行)
- 返回每个模型的运行结果和性能对比
使用场景:
1. 对比不同模型的效果GPT-4 vs Claude vs Gemini
2. 调优模型参数(不同 temperature 的效果对比)
3. 性能和成本分析
"""
workspace_id = current_user.current_workspace_id
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
@@ -597,7 +672,7 @@ async def draft_run_compare(
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
logger.info(
"多模型对比试运行",
extra={
@@ -607,13 +682,13 @@ async def draft_run_compare(
"stream": payload.stream
}
)
# 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService
from app.models import ModelConfig
service = AppService(db)
# 1. 验证应用和权限
app = service._get_app_or_404(app_id)
if app.type != "agent":
@@ -621,7 +696,7 @@ async def draft_run_compare(
from app.core.error_codes import BizCode
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
service._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
from sqlalchemy import select
from app.models import AgentConfig
@@ -631,7 +706,7 @@ async def draft_run_compare(
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 验证所有模型配置
model_configs = []
for model_item in payload.models:
@@ -639,12 +714,12 @@ async def draft_run_compare(
if not model_config:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
"model_config": model_config,
"parameters": merged_parameters,
@@ -652,7 +727,7 @@ async def draft_run_compare(
"model_config_id": model_item.model_config_id,
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
})
# 流式返回
if payload.stream:
async def event_generator():
@@ -674,7 +749,7 @@ async def draft_run_compare(
timeout=payload.timeout or 60
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
@@ -684,7 +759,7 @@ async def draft_run_compare(
"X-Accel-Buffering": "no"
}
)
# 非流式返回
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
@@ -703,7 +778,7 @@ async def draft_run_compare(
parallel=payload.parallel,
timeout=payload.timeout or 60
)
logger.info(
"多模型对比完成",
extra={
@@ -712,5 +787,36 @@ async def draft_run_compare(
"failed": result["failed_count"]
}
)
return success(data=app_schema.DraftRunCompareResponse(**result))
@router.get("/{app_id}/workflow")
@cur_workspace_access_guard()
async def get_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
):
"""获取工作流配置
获取应用的工作流配置详情。
"""
workspace_id = current_user.current_workspace_id
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
# 配置总是存在(不存在时返回默认模板)
return success(data=WorkflowConfigSchema.model_validate(cfg))
@router.put("/{app_id}/workflow", summary="更新 Workflow 配置")
@cur_workspace_access_guard()
async def update_workflow_config(
app_id: uuid.UUID,
payload: WorkflowConfigUpdate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
):
workspace_id = current_user.current_workspace_id
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=WorkflowConfigSchema.model_validate(cfg))

View File

@@ -29,10 +29,10 @@ router = APIRouter(
)
@router.get("/{kb_id}/{parent_id}/documents", response_model=ApiResponse)
@router.get("/{kb_id}/documents", response_model=ApiResponse)
async def get_documents(
kb_id: uuid.UUID,
parent_id: uuid.UUID,
parent_id: Optional[uuid.UUID] = Query(None, description="parent folder id when type is Folder"),
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: created_at,updated_at"),

View File

@@ -1,8 +1,9 @@
from typing import Optional
from typing import Optional, Union
import os
import uuid
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, Query, UploadFile
from fastapi import APIRouter, Depends, UploadFile
from fastapi.responses import StreamingResponse
from app.db import get_db
@@ -322,36 +323,24 @@ def read_all_config(
return fail(BizCode.INTERNAL_ERROR, "查询所有配置失败", str(e))
@router.post("/pilot_run", response_model=ApiResponse) # 试运行:触发执行主管线,使用 POST 更为合理
@router.post("/pilot_run", response_model=None)
async def pilot_run(
payload: ConfigPilotRun,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> StreamingResponse:
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
# 先尝试从数据库加载配置
try:
config_loaded = reload_configuration_from_database(str(payload.config_id))
if not config_loaded:
api_logger.error(f"Failed to load configuration for config_id: {payload.config_id}")
return fail(BizCode.INTERNAL_ERROR, "配置加载失败", f"无法加载 config_id={payload.config_id} 的配置")
api_logger.info(f"Configuration loaded successfully for config_id: {payload.config_id}")
except Exception as e:
api_logger.error(f"Exception while loading configuration: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e))
try:
svc = DataConfigService(db)
result = await svc.pilot_run(payload)
return success(data=result, msg="试运行完成")
except ValueError as e:
# 捕获参数验证错误
api_logger.error(f"Pilot run parameter validation failed: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, "参数验证失败", str(e))
except Exception as e:
api_logger.error(f"Pilot run failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "试运行失败", str(e))
svc = DataConfigService(db)
return StreamingResponse(
svc.pilot_run_stream(payload),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
"""
以下为搜索与分析接口,直接挂载到同一 router统一响应为 ApiResponse。

View File

@@ -1,10 +1,13 @@
"""App 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Depends
import uuid
from fastapi import APIRouter, Depends, Request, Body
from sqlalchemy.orm import Session
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.core.api_key_auth import require_api_key
from app.schemas.api_key_schema import ApiKeyAuth
router = APIRouter(prefix="/apps", tags=["V1 - App API"])
logger = get_business_logger()
@@ -14,3 +17,30 @@ logger = get_business_logger()
async def list_apps():
"""列出可访问的应用(占位)"""
return success(data=[], msg="App API - Coming Soon")
# /v1/apps/{resource_id}/chat
@router.post("/{resource_id}/chat")
@require_api_key(scopes=["app"])
async def chat_with_agent_demo(
resource_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(..., description="聊天消息内容"),
):
"""
Agent 聊天接口demo
scopes: 所需的权限范围列表["app", "rag", "memory"]
Args:
resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
message: 请求参数
request: 声明请求
api_key_auth: 包含验证后的API Key 信息
db: db_session
"""
logger.info(f"API Key Auth: {api_key_auth}")
logger.info(f"Resource ID: {resource_id}")
logger.info(f"Message: {message}")
return success(data={"received": True}, msg="消息已接收")

View File

@@ -0,0 +1,587 @@
"""
工作流 API 控制器
"""
import logging
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Path, Query
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.user_model import User
from app.models.app_model import App
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.schemas.workflow_schema import (
WorkflowConfigCreate,
WorkflowConfigUpdate,
WorkflowConfig,
WorkflowValidationResponse,
WorkflowExecution,
WorkflowNodeExecution,
WorkflowExecutionRequest,
WorkflowExecutionResponse
)
from app.core.response_utils import success, fail
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/apps", tags=["workflow"])
# ==================== 工作流配置管理 ====================
@router.post("/{app_id}/workflow")
@cur_workspace_access_guard()
async def create_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""创建工作流配置
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证应用类型
if app.type != "workflow":
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"应用类型必须为 workflow当前为 {app.type}"
)
# 创建工作流配置
workflow_config = service.create_workflow_config(
app_id=app_id,
nodes=[node.model_dump() for node in config.nodes],
edges=[edge.model_dump() for edge in config.edges],
variables=[var.model_dump() for var in config.variables],
execution_config=config.execution_config.model_dump(),
triggers=[trigger.model_dump() for trigger in config.triggers],
validate=True # 进行基础验证
)
return success(
data=WorkflowConfig.model_validate(workflow_config),
msg="工作流配置创建成功"
)
except BusinessException as e:
logger.warning(f"创建工作流配置失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"创建工作流配置失败: {str(e)}"
)
#
# @router.get("/{app_id}/workflow")
# async def get_workflow_config(
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
# db: Annotated[Session, Depends(get_db)],
# current_user: Annotated[User, Depends(get_current_user)]
#
# ):
# """获取工作流配置
#
# 获取应用的工作流配置详情。
# """
# try:
# # 验证应用是否存在且属于当前工作空间
# app = db.query(App).filter(
# App.id == app_id,
# App.workspace_id == current_user.current_workspace_id,
# App.is_active == True
# ).first()
#
# if not app:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="应用不存在或无权访问"
# )
#
# # 获取工作流配置
# service = WorkflowService(db)
# workflow_config = service.get_workflow_config(app_id)
#
# if not workflow_config:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="工作流配置不存在"
# )
#
# return success(
# data=WorkflowConfig.model_validate(workflow_config)
# )
#
# except Exception as e:
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
# return fail(
# code=BizCode.INTERNAL_ERROR,
# msg=f"获取工作流配置失败: {str(e)}"
# )
# @router.put("/{app_id}/workflow")
# async def update_workflow_config(
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
# config: WorkflowConfigUpdate,
# db: Annotated[Session, Depends(get_db)],
# current_user: Annotated[User, Depends(get_current_user)],
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
# ):
# """更新工作流配置
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
# """
# try:
# # 验证应用是否存在且属于当前工作空间
# app = db.query(App).filter(
# App.id == app_id,
# App.workspace_id == current_user.current_workspace_id,
# App.is_active == True
# ).first()
# if not app:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="应用不存在或无权访问"
# )
# # 更新工作流配置
# workflow_config = service.update_workflow_config(
# app_id=app_id,
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
# validate=True
# )
# return success(
# data=WorkflowConfig.model_validate(workflow_config),
# msg="工作流配置更新成功"
# )
# except BusinessException as e:
# logger.warning(f"更新工作流配置失败: {e.message}")
# return fail(code=e.error_code, msg=e.message)
# except Exception as e:
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
# return fail(
# code=BizCode.INTERNAL_ERROR,
# msg=f"更新工作流配置失败: {str(e)}"
# )
@router.delete("/{app_id}/workflow")
async def delete_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""删除工作流配置
删除应用的工作流配置。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 删除工作流配置
deleted = service.delete_workflow_config(app_id)
if not deleted:
return fail(
code=BizCode.NOT_FOUND,
msg="工作流配置不存在"
)
return success(msg="工作流配置删除成功")
except Exception as e:
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"删除工作流配置失败: {str(e)}"
)
@router.post("/{app_id}/workflow/validate")
async def validate_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
):
"""验证工作流配置
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证工作流配置
if for_publish:
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
else:
workflow_config = service.get_workflow_config(app_id)
if not workflow_config:
return fail(
code=BizCode.NOT_FOUND,
msg="工作流配置不存在"
)
from app.core.workflow.validator import validate_workflow_config as validate_config
config_dict = {
"nodes": workflow_config.nodes,
"edges": workflow_config.edges,
"variables": workflow_config.variables,
"execution_config": workflow_config.execution_config,
"triggers": workflow_config.triggers
}
is_valid, errors = validate_config(config_dict, for_publish=False)
return success(
data=WorkflowValidationResponse(
is_valid=is_valid,
errors=errors,
warnings=[]
)
)
except BusinessException as e:
logger.warning(f"验证工作流配置失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"验证工作流配置失败: {str(e)}"
)
# ==================== 工作流执行管理 ====================
@router.get("/{app_id}/workflow/executions")
async def get_workflow_executions(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0
):
"""获取工作流执行记录列表
获取应用的工作流执行历史记录。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 获取执行记录
executions = service.get_executions_by_app(app_id, limit, offset)
# 获取统计信息
statistics = service.get_execution_statistics(app_id)
return success(
data={
"executions": [WorkflowExecution.model_validate(e) for e in executions],
"statistics": statistics,
"pagination": {
"limit": limit,
"offset": offset,
"total": statistics["total"]
}
}
)
except Exception as e:
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"获取工作流执行记录失败: {str(e)}"
)
@router.get("/workflow/executions/{execution_id}")
async def get_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""获取工作流执行详情
获取单个工作流执行的详细信息,包括所有节点的执行记录。
"""
try:
# 获取执行记录
execution = service.get_execution(execution_id)
if not execution:
return fail(
code=BizCode.NOT_FOUND,
msg="执行记录不存在"
)
# 验证应用是否属于当前工作空间
app = db.query(App).filter(
App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="无权访问该执行记录"
)
# 获取节点执行记录
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
return success(
data={
"execution": WorkflowExecution.model_validate(execution),
"node_executions": [
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
]
}
)
except Exception as e:
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"获取工作流执行详情失败: {str(e)}"
)
# ==================== 工作流执行 ====================
@router.post("/{app_id}/workflow/run")
async def run_workflow(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""执行工作流
执行工作流并返回结果。支持流式和非流式两种模式。
**非流式模式**:等待工作流执行完成后返回完整结果。
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证应用类型
if app.type != "workflow":
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"应用类型必须为 workflow当前为 {app.type}"
)
# 准备输入数据
input_data = {
"message": request.message or "",
"variables": request.variables
}
# 执行工作流
if request.stream:
# 流式执行
from fastapi.responses import StreamingResponse
import json
async def event_generator():
"""生成 SSE 事件"""
try:
async for event in service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
):
# 转换为 SSE 格式
yield f"data: {json.dumps(event)}\n\n"
except Exception as e:
logger.error(f"流式执行异常: {e}", exc_info=True)
error_event = {
"type": "error",
"error": str(e)
}
yield f"data: {json.dumps(error_event)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream"
)
else:
# 非流式执行
result = await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=False
)
return success(
data=WorkflowExecutionResponse(
execution_id=result["execution_id"],
status=result["status"],
output=result.get("output"),
output_data=result.get("output_data"),
error_message=result.get("error_message"),
elapsed_time=result.get("elapsed_time"),
token_usage=result.get("token_usage")
),
msg="工作流执行完成"
)
except BusinessException as e:
logger.warning(f"执行工作流失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"执行工作流异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"执行工作流失败: {str(e)}"
)
@router.post("/workflow/executions/{execution_id}/cancel")
async def cancel_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""取消工作流执行
取消正在运行的工作流执行。
**注意**:当前版本仅更新状态为 cancelled实际的执行取消功能待实现。
"""
try:
# 获取执行记录
execution = service.get_execution(execution_id)
if not execution:
return fail(
code=BizCode.NOT_FOUND,
msg="执行记录不存在"
)
# 验证应用是否属于当前工作空间
app = db.query(App).filter(
App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="无权访问该执行记录"
)
# 检查执行状态
if execution.status not in ["pending", "running"]:
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"无法取消状态为 {execution.status} 的执行"
)
# 更新状态为 cancelled
service.update_execution_status(execution_id, "cancelled")
return success(msg="工作流执行已取消")
except BusinessException as e:
logger.warning(f"取消工作流执行失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"取消工作流执行失败: {str(e)}"
)

View File

@@ -1,10 +1,12 @@
import asyncio
import time
import uuid
from functools import wraps
from typing import Optional, List
from datetime import datetime
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from app.core.api_key_utils import add_rate_limit_headers
@@ -22,21 +24,17 @@ logger = get_api_logger()
def require_api_key(
scopes: Optional[List[str]] = None,
resource_type: Optional[str] = None
scopes: Optional[List[str]] = None
):
"""
API Key 鉴权装饰器
Args:
scopes: 所需的权限范围列表["app:all",
"rag:search", "rag:upload", "rag:delete",
"memory:read", "memory:write", "memory:delete", "memory:search"]
resource_type: 所需的资源类型("Agent", "Cluster", "Workflow", "Knowledge", "Memory_Engine")
scopes: 所需的权限范围列表[“app”, "rag", "memory"]
Usage:
@router.get("/app/{resource_id}/chat")
@require_api_key(scopes=["app:all"], resource_type="Agent")
@require_api_key(scopes=["app"])
def chat_with_app(
resource_id: uuid.UUID,
api_key_auth: ApiKeyAuth = Depends(),
@@ -113,31 +111,25 @@ def require_api_key(
context={"required_scopes": scopes, "missing_scopes": missing_scopes}
)
if resource_type:
resource_id = kwargs.get("resource_id")
if resource_id and not ApiKeyAuthService.check_resource(
api_key_obj,
resource_type,
resource_id
):
logger.warning("API Key 资源访问被拒绝", extra={
"api_key_id": str(api_key_obj.id),
"required_resource_type": resource_type,
resource_id = kwargs.get("resource_id")
if resource_id and not ApiKeyAuthService.check_resource(
api_key_obj,
resource_id
):
logger.warning("API Key 资源访问被拒绝", extra={
"api_key_id": str(api_key_obj.id),
"required_resource_id": str(resource_id),
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
"endpoint": str(request.url)
})
return BusinessException(
"API Key 未授权访问该资源",
BizCode.API_KEY_INVALID_RESOURCE,
context={
"required_resource_id": str(resource_id),
"bound_resource_type": api_key_obj.resource_type,
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None,
"endpoint": str(request.url)
})
return BusinessException(
"API Key 未授权访问该资源",
BizCode.API_KEY_INVALID_RESOURCE,
context={
"required_resource_type": resource_type,
"required_resource_id": str(resource_id),
"bound_resource_type": api_key_obj.resource_type,
"bound_resource_id": str(api_key_obj.resource_id) if api_key_obj.resource_id else None
}
)
"bound_resource_id": str(api_key_obj.resource_id)
}
)
kwargs["api_key_auth"] = ApiKeyAuth(
api_key_id=api_key_obj.id,
@@ -145,14 +137,17 @@ def require_api_key(
type=api_key_obj.type,
scopes=api_key_obj.scopes,
resource_id=api_key_obj.resource_id,
resource_type=api_key_obj.resource_type
)
start_time = time.perf_counter()
response = await func(*args, **kwargs)
end_time = time.perf_counter()
response_time = (end_time - start_time) * 1000
if not isinstance(response, Response):
response = JSONResponse(content=response)
response = add_rate_limit_headers(response, rate_headers)
asyncio.create_task(log_api_key_usage(
db, api_key_obj.id, request, response
db, api_key_obj.id, request, response, response_time
))
return response
@@ -204,7 +199,8 @@ async def log_api_key_usage(
db: Session,
api_key_id: uuid.UUID,
request: Request,
response: Response
response: Response,
response_time: float
):
"""记录 API Key 使用日志"""
try:
@@ -216,8 +212,8 @@ async def log_api_key_usage(
"ip_address": request.client.host if request.client else None,
"user_agent": request.headers.get("User-Agent"),
"status_code": response.status_code if hasattr(response, "status_code") else None,
"response_time": None, # 需要在 middleware 中计算
"tokens_used": None, # 需要从响应中提取
"response_time": round(response_time),
"tokens_used": None,
"created_at": datetime.now()
}

View File

@@ -1,33 +1,14 @@
"""API Key 工具函数"""
import secrets
import hashlib
from typing import Optional
from typing import Optional, Union
from datetime import datetime
from app.schemas.api_key_schema import ApiKeyType
from fastapi import Response
from fastapi.responses import JSONResponse
class ResourceType:
"""资源类型常量"""
AGENT = "Agent"
CLUSTER = "Cluster"
WORKFLOW = "Workflow"
KNOWLEDGE = "Knowledge"
MEMORY_ENGINE = "Memory_Engine"
@classmethod
def get_all_types(cls) -> list[str]:
"""获取所有支持的资源类型"""
return [cls.AGENT, cls.CLUSTER, cls.WORKFLOW, cls.KNOWLEDGE, cls.MEMORY_ENGINE]
@classmethod
def is_valid_type(cls, resource_type: str) -> bool:
"""验证资源类型是否有效"""
return resource_type in cls.get_all_types()
def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
def generate_api_key(key_type: ApiKeyType) -> str:
"""
生成 API Key
@@ -39,102 +20,17 @@ def generate_api_key(key_type: ApiKeyType) -> tuple[str, str, str]:
"""
# 前缀映射
prefix_map = {
ApiKeyType.APP: "sk-app-",
ApiKeyType.RAG: "sk-rag-",
ApiKeyType.MEMORY: "sk-mem-",
ApiKeyType.AGENT: "sk-agent-",
ApiKeyType.CLUSTER: "sk-multi_agent-",
ApiKeyType.WORKFLOW: "sk-workflow-",
ApiKeyType.SERVICE: "sk-service-"
}
prefix = prefix_map[key_type]
random_string = secrets.token_urlsafe(32)[:32] # 32 字符
api_key = f"{prefix}{random_string}"
# 生成哈希值存储
key_hash = hash_api_key(api_key)
return api_key, key_hash, prefix
def hash_api_key(api_key: str) -> str:
"""对 API Key 进行哈希
Args:
api_key: API Key 明文
Returns:
str: 哈希值
"""
return hashlib.sha256(api_key.encode()).hexdigest()
def verify_api_key(api_key: str, key_hash: str) -> bool:
"""
验证 API Key
Args:
api_key: API Key 明文
key_hash: 存储的哈希值
Returns:
bool: 是否匹配
"""
computed_hash = hash_api_key(api_key)
return secrets.compare_digest(computed_hash, key_hash)
def validate_resource_binding(
resource_type: Optional[str],
resource_id: Optional[str]
) -> tuple[bool, str]:
"""
验证资源绑定的有效性
Args:
resource_type: 资源类型
resource_id: 资源ID
Returns:
tuple: (是否有效, 错误信息)
"""
# 如果都为空,表示不绑定资源,这是有效的
if not resource_type and not resource_id:
return True, ""
# 如果只有一个为空,这是无效的
if not resource_type or not resource_id:
return False, "resource_type 和 resource_id 必须同时提供或同时为空"
# 验证资源类型是否支持
if not ResourceType.is_valid_type(resource_type):
valid_types = ", ".join(ResourceType.get_all_types())
return False, f"不支持的资源类型 '{resource_type}',支持的类型:{valid_types}"
return True, ""
def get_resource_scope_mapping() -> dict[str, list[str]]:
"""
获取资源类型与权限范围的映射关系
Returns:
dict: 资源类型到推荐权限范围的映射
"""
return {
ResourceType.AGENT: [
"app:all"
],
ResourceType.CLUSTER: [
"app:all"
],
ResourceType.WORKFLOW: [
"app:all"
],
ResourceType.KNOWLEDGE: [
"rag:search", "rag:upload", "rag:delete"
],
ResourceType.MEMORY_ENGINE: [
"memory:read", "memory:write", "memory:delete", "memory:search"
]
}
return api_key
def add_rate_limit_headers(response, headers: dict):
@@ -151,3 +47,21 @@ def add_rate_limit_headers(response, headers: dict):
return response
def timestamp_to_datetime(timestamp: Optional[Union[int, float]]) -> Optional[datetime]:
"""将时间戳转换为datetime对象"""
if timestamp is None:
return None
# 处理毫秒级时间戳
if timestamp > 1e10:
timestamp = timestamp / 1000
return datetime.fromtimestamp(timestamp)
def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
"""将datetime对象转换为时间戳毫秒"""
if dt is None:
return None
return int(dt.timestamp() * 1000)

View File

@@ -59,6 +59,7 @@ class BizCode(IntEnum):
EMBED_NOT_ALLOWED = 6009
PERMISSION_DENIED = 6010
INVALID_CONVERSATION = 6011
CONFIG_MISSING = 6012
# 模型7xxx
MODEL_CONFIG_INVALID = 7001
@@ -96,7 +97,7 @@ HTTP_MAPPING = {
BizCode.TOKEN_INVALID: 401,
BizCode.TOKEN_EXPIRED: 401,
BizCode.TOKEN_BLACKLISTED: 401,
BizCode.FORBIDDEN: 403,
BizCode.FORBIDDEN: 403,
BizCode.TENANT_NOT_FOUND: 404,
BizCode.WORKSPACE_NO_ACCESS: 403,
BizCode.NOT_FOUND: 404,
@@ -151,4 +152,4 @@ HTTP_MAPPING = {
BizCode.DB_ERROR: 500,
BizCode.SERVICE_UNAVAILABLE: 503,
BizCode.RATE_LIMITED: 429,
}
}

View File

@@ -106,6 +106,8 @@ async def write(content: str, user_id: str, apply_id: str, group_id: str, ref_id
all_statement_chunk_edges,
all_statement_entity_edges,
all_entity_entity_edges,
all_dedup_details,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
log_time("Extraction Pipeline", time.time() - step_start, log_file)

View File

@@ -1,5 +1,5 @@
{
"selections": {
"config_id": "1"
"config_id": ""
}
}

View File

@@ -21,7 +21,7 @@ os.environ["LANGCHAIN_TRACING"] = "false"
import asyncio
import time
from datetime import datetime
from typing import Optional
from typing import Optional, Callable, Awaitable
from dotenv import load_dotenv
# 导入重构后的模块
@@ -50,7 +50,11 @@ logger = get_memory_logger(__name__)
async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
async def main(
dialogue_text: Optional[str] = None,
is_pilot_run: bool = False,
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None
):
"""
记忆系统主流程 - 重构版本
@@ -61,6 +65,12 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
is_pilot_run: 是否为试运行模式
- True: 试运行模式,不保存到 Neo4j
- False: 正常运行模式,保存到 Neo4j
progress_callback: 可选的进度回调函数
- 类型: Callable[[str, str, Optional[dict]], Awaitable[None]]
- 参数1 (stage): 当前处理阶段标识符
- 参数2 (message): 人类可读的进度消息
- 参数3 (data): 可选的附加数据字典,包含详细的进度信息或结果
- 在管线关键点调用以报告进度和结果数据
工作流程:
1. 初始化客户端和配置
@@ -141,6 +151,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
metadata={"source": "pilot_run", "input_type": "frontend_text"}
)
# 进度回调:开始预处理文本
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
# 对前端传入的对话进行分块处理
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
@@ -148,6 +162,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
llm_client=llm_client,
)
logger.info(f"Processed frontend dialogue text: {len(messages)} messages")
# 进度回调:输出每个分块的结果
if progress_callback:
for dialog in chunked_dialogs:
for i, chunk in enumerate(dialog.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dialog.id,
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 进度回调:预处理文本完成
preprocessing_summary = {
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
else:
# 正常运行模式:从 testdata.json 文件加载
logger.warning("[MAIN] ✗ Falling back to testdata.json (dialogue_text not provided or empty)")
@@ -159,6 +194,10 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
if not os.path.exists(test_data_path):
raise FileNotFoundError(f"Test data file not found: {test_data_path}")
# 进度回调:开始预处理文本
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
chunked_dialogs = await get_chunked_dialogs_with_preprocessing(
chunker_strategy=config_defs.SELECTED_CHUNKER_STRATEGY,
group_id=config_defs.SELECTED_GROUP_ID,
@@ -170,6 +209,27 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
skip_cleaning=True,
)
logger.info(f"Loaded {len(chunked_dialogs)} dialogues from testdata.json")
# 进度回调:输出每个分块的结果
if progress_callback:
for dialog in chunked_dialogs:
for i, chunk in enumerate(dialog.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dialog.id,
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 进度回调:预处理文本完成
preprocessing_summary = {
"total_chunks": sum(len(dialog.chunks) for dialog in chunked_dialogs),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": config_defs.SELECTED_CHUNKER_STRATEGY
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
@@ -188,6 +248,7 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
embedder_client=embedder_client,
connector=neo4j_connector,
config=config,
progress_callback=progress_callback, # 传递进度回调
)
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
@@ -196,6 +257,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
logger.info("Running extraction pipeline...")
step_start = time.time()
# 进度回调:正在知识抽取
if progress_callback:
await progress_callback("knowledge_extraction", "正在知识抽取...")
extraction_result = await orchestrator.run(
dialog_data_list=chunked_dialogs,
is_pilot_run=is_pilot_run, # 传递试运行模式标志
@@ -216,6 +282,11 @@ async def main(dialogue_text: Optional[str] = None, is_pilot_run: bool = False):
) = extraction_result
log_time("Extraction Pipeline", time.time() - step_start, log_file)
# 进度回调:生成结果
if progress_callback:
await progress_callback("generating_results", "正在生成结果...")
# 步骤 5: 保存结果或输出结果
if is_pilot_run:

View File

@@ -2,7 +2,7 @@
去重功能函数
"""
from app.core.memory.models.variate_config import DedupConfig
from typing import List, Dict, Tuple
from typing import List, Dict, Tuple, Any
from app.core.memory.models.graph_models import(
StatementEntityEdge,
EntityEntityEdge,
@@ -895,7 +895,12 @@ async def deduplicate_entities_and_edges(
report_append: bool = False,
report_stage_notes: List[str] | None = None,
dedup_config: DedupConfig | None = None,
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
) -> Tuple[
List[ExtractedEntityNode],
List[StatementEntityEdge],
List[EntityEntityEdge],
Dict[str, Any] # 新增:返回详细的去重消歧记录
]:
"""
主流程依次执行精确匹配、模糊匹配与可选LLM 决策融合,随后对边做重定向与去重。之后再处理边,是关系去重和消歧
返回:去重后的实体、语句→实体边、实体↔实体边。
@@ -981,8 +986,18 @@ async def deduplicate_entities_and_edges(
append=report_append,
stage_notes=report_stage_notes,
)
# 构建详细的去重消歧记录(用于内存访问,避免解析日志文件)
dedup_details = {
"exact_merge_map": exact_merge_map,
"fuzzy_merge_records": fuzzy_merge_records,
"llm_decision_records": local_llm_records,
"disamb_records": disamb_records,
"id_redirect": id_redirect,
"blocked_pairs": blocked_pairs,
}
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values())
return deduped_entities, list(stmt_ent_map.values()), list(ent_ent_map.values()), dedup_details
# 独立模块:去重融合报告写入(与实体/边的计算解耦)
def _write_dedup_fusion_report(

View File

@@ -39,6 +39,7 @@ async def dedup_layers_and_merge_and_return(
List[StatementChunkEdge],
List[StatementEntityEdge],
List[EntityEntityEdge],
dict, # 新增:返回去重详情
]:
"""
执行两层实体去重与融合:
@@ -62,7 +63,7 @@ async def dedup_layers_and_merge_and_return(
break
# 第一层去重消歧
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
entity_nodes,
statement_entity_edges,
entity_entity_edges,
@@ -103,4 +104,5 @@ async def dedup_layers_and_merge_and_return(
statement_chunk_edges,
fused_statement_entity_edges,
fused_entity_entity_edges,
dedup_details, # 返回去重详情
)

View File

@@ -12,13 +12,14 @@
5. 提供错误处理和日志记录
6. 支持试运行模式(不写入数据库)
作者:Memory Refactoring Team
作者:
日期2025-11-21
"""
import asyncio
import logging
from typing import List, Dict, Any, Tuple, Optional
import os
from typing import List, Dict, Any, Tuple, Optional, Callable, Awaitable
from datetime import datetime
from app.core.memory.models.message_models import DialogData
@@ -94,6 +95,7 @@ class ExtractionOrchestrator:
embedder_client: OpenAIEmbedderClient,
connector: Neo4jConnector,
config: Optional[ExtractionPipelineConfig] = None,
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
):
"""
初始化流水线编排器
@@ -103,12 +105,21 @@ class ExtractionOrchestrator:
embedder_client: 嵌入模型客户端
connector: Neo4j 连接器
config: 流水线配置,如果为 None 则使用默认配置
progress_callback: 进度回调函数
- 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None]
- 在管线关键点调用以报告进度和结果数据
"""
self.llm_client = llm_client
self.embedder_client = embedder_client
self.connector = connector
self.config = config or ExtractionPipelineConfig()
self.is_pilot_run = False # 默认非试运行模式
self.progress_callback = progress_callback # 保存进度回调函数
# 保存去重消歧的详细记录(内存中的数据结构)
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
self.dedup_disamb_records: List[Dict[str, Any]] = [] # 实体消歧记录
self.id_redirect_map: Dict[str, str] = {} # ID重定向映射
# 初始化各个提取器
self.statement_extractor = StatementExtractor(
@@ -160,6 +171,13 @@ class ExtractionOrchestrator:
# 步骤 1: 陈述句提取
logger.info("步骤 1/6: 陈述句提取(全局分块级并行)")
dialog_data_list = await self._extract_statements(dialog_data_list)
# 收集陈述句内容和统计数量
all_statements_list = []
for dialog in dialog_data_list:
for chunk in dialog.chunks:
all_statements_list.extend(chunk.statements)
total_statements = len(all_statements_list)
# 步骤 2: 并行执行三元组提取、时间信息提取和基础嵌入生成
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取和嵌入生成")
@@ -170,11 +188,90 @@ class ExtractionOrchestrator:
chunk_embedding_maps,
dialog_embeddings,
) = await self._parallel_extract_and_embed(dialog_data_list)
# 收集实体和三元组内容,并统计数量
all_entities_list = []
all_triplets_list = []
for triplet_map in triplet_maps:
for triplet_info in triplet_map.values():
if triplet_info:
all_entities_list.extend(triplet_info.entities)
all_triplets_list.extend(triplet_info.triplets)
total_entities = len(all_entities_list)
total_triplets = len(all_triplets_list)
total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps)
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
logger.info("步骤 3/6: 生成实体嵌入")
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
# 进度回调:按三个阶段分别输出知识抽取结果
if self.progress_callback:
# 第一阶段:陈述句提取结果
for i, stmt in enumerate(all_statements_list[:10]): # 只输出前10个陈述句
stmt_result = {
"extraction_type": "statement",
"statement_index": i + 1,
"statement": stmt.statement,
"statement_id": stmt.id
}
await self.progress_callback("knowledge_extraction_result", "陈述句提取完成", stmt_result)
# 第二阶段:三元组提取结果
for i, triplet in enumerate(all_triplets_list[:10]): # 只输出前10个三元组
triplet_result = {
"extraction_type": "triplet",
"triplet_index": i + 1,
"subject": triplet.subject_name,
"predicate": triplet.predicate,
"object": triplet.object_name
}
await self.progress_callback("knowledge_extraction_result", "三元组提取完成", triplet_result)
# 第三阶段:时间提取结果
if total_temporal > 0:
# 收集时间信息
temporal_results = []
for dialog in dialog_data_list:
for chunk in dialog.chunks:
for statement in chunk.statements:
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
temporal_results.append({
"statement_id": statement.id,
"statement": statement.statement,
"valid_at": statement.temporal_validity.valid_at,
"invalid_at": statement.temporal_validity.invalid_at
})
# 输出时间提取结果
for i, temporal_result in enumerate(temporal_results[:5]): # 只输出前5个时间提取结果
time_result = {
"extraction_type": "temporal",
"temporal_index": i + 1,
"statement": temporal_result["statement"],
"valid_at": temporal_result["valid_at"],
"invalid_at": temporal_result["invalid_at"]
}
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
else:
# 如果没有时间信息,也发送一个时间提取完成的消息
time_result = {
"extraction_type": "temporal",
"temporal_index": 0,
"message": "未发现时间信息"
}
await self.progress_callback("knowledge_extraction_result", "时间提取完成", time_result)
# 进度回调:知识抽取完成,传递知识抽取的统计信息
extraction_stats = {
"statements_count": total_statements,
"entities_count": total_entities,
"triplets_count": total_triplets,
"temporal_ranges_count": total_temporal,
}
await self.progress_callback("knowledge_extraction_complete", "知识抽取完成", extraction_stats)
# 步骤 4: 将提取的数据赋值到语句
logger.info("步骤 4/6: 数据赋值")
dialog_data_list = await self._assign_extracted_data(
@@ -218,6 +315,8 @@ class ExtractionOrchestrator:
dialog_data_list,
)
logger.info(f"知识提取流水线运行完成({mode_str}")
return result
@@ -732,6 +831,10 @@ class ExtractionOrchestrator:
包含所有节点和边的元组
"""
logger.info("开始创建节点和边")
# 进度回调:正在创建节点和边
if self.progress_callback:
await self.progress_callback("creating_nodes_edges", "正在创建节点和边...")
dialogue_nodes = []
chunk_nodes = []
@@ -904,6 +1007,23 @@ class ExtractionOrchestrator:
f"陈述句-实体边: {len(statement_entity_edges)}, "
f"实体-实体边: {len(entity_entity_edges)}"
)
# 进度回调:只输出关系创建结果
if self.progress_callback:
# 输出关系创建结果
await self._output_relationship_creation_results(entity_entity_edges, entity_nodes)
# 进度回调:创建节点和边完成,传递结果统计
nodes_edges_stats = {
"dialogue_nodes_count": len(dialogue_nodes),
"chunk_nodes_count": len(chunk_nodes),
"statement_nodes_count": len(statement_nodes),
"entity_nodes_count": len(entity_nodes),
"statement_chunk_edges_count": len(statement_chunk_edges),
"statement_entity_edges_count": len(statement_entity_edges),
"entity_entity_edges_count": len(entity_entity_edges),
}
await self.progress_callback("creating_nodes_edges_complete", "创建节点和边完成", nodes_edges_stats)
return (
dialogue_nodes,
@@ -950,6 +1070,11 @@ class ExtractionOrchestrator:
- 第三个元组:去重后的 (实体节点列表, 陈述句-实体边列表, 实体-实体边列表)
"""
logger.info("开始两阶段实体去重和消歧")
# 进度回调:正在去重消歧
if self.progress_callback:
await self.progress_callback("deduplication", "正在去重消歧...")
logger.info(
f"去重前: {len(entity_nodes)} 个实体节点, "
f"{len(statement_entity_edges)} 条陈述句-实体边, "
@@ -963,7 +1088,7 @@ class ExtractionOrchestrator:
# 只执行第一层去重
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import deduplicate_entities_and_edges
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges = await deduplicate_entities_and_edges(
dedup_entity_nodes, dedup_statement_entity_edges, dedup_entity_entity_edges, dedup_details = await deduplicate_entities_and_edges(
entity_nodes,
statement_entity_edges,
entity_entity_edges,
@@ -972,6 +1097,9 @@ class ExtractionOrchestrator:
dedup_config=self.config.deduplication,
)
# 保存去重消歧的详细记录到实例变量
self._save_dedup_details(dedup_details, entity_nodes, dedup_entity_nodes)
result_tuple = (
dialogue_nodes,
chunk_nodes,
@@ -1009,7 +1137,11 @@ class ExtractionOrchestrator:
_,
final_statement_entity_edges,
final_entity_entity_edges,
dedup_details,
) = result_tuple
# 保存去重消歧的详细记录到实例变量
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
logger.info(
f"去重后: {len(final_entity_nodes)} 个实体节点, "
@@ -1021,6 +1153,46 @@ class ExtractionOrchestrator:
f"陈述句-实体边减少 {len(statement_entity_edges) - len(final_statement_entity_edges)}, "
f"实体-实体边减少 {len(entity_entity_edges) - len(final_entity_entity_edges)}"
)
# 进度回调:输出去重消歧的具体结果
if self.progress_callback:
# 分析实体合并情况
merge_info = await self._analyze_entity_merges(entity_nodes, final_entity_nodes)
# 输出去重合并的实体示例
for i, merge_detail in enumerate(merge_info[:5]): # 输出前5个去重结果
dedup_result = {
"result_type": "entity_merge",
"merged_entity_name": merge_detail["main_entity_name"],
"merged_count": merge_detail["merged_count"],
"message": f"{merge_detail['main_entity_name']}合并{merge_detail['merged_count']}个:相似实体已合并"
}
await self.progress_callback("dedup_disambiguation_result", "实体去重完成", dedup_result)
# 分析实体消歧情况
disamb_info = await self._analyze_entity_disambiguation(entity_nodes, final_entity_nodes)
# 输出实体消歧的结果
for i, disamb_detail in enumerate(disamb_info[:5]): # 输出前5个消歧结果
disamb_result = {
"result_type": "entity_disambiguation",
"disambiguated_entity_name": disamb_detail["entity_name"],
"disambiguation_type": disamb_detail["disamb_type"],
"confidence": disamb_detail.get("confidence", "unknown"),
"reason": disamb_detail.get("reason", ""),
"message": f"{disamb_detail['entity_name']}消歧完成:{disamb_detail['disamb_type']}"
}
await self.progress_callback("dedup_disambiguation_result", "实体消歧完成", disamb_result)
# 进度回调:去重消歧完成,传递去重和消歧的具体效果
await self._send_dedup_progress_callback(
len(entity_nodes), len(final_entity_nodes),
len(statement_entity_edges), len(final_statement_entity_edges),
len(entity_entity_edges), len(final_entity_entity_edges)
)
# 写入提取结果汇总(试运行和正式模式都需要生成)
try:
@@ -1041,6 +1213,378 @@ class ExtractionOrchestrator:
logger.error(f"两阶段去重失败: {e}", exc_info=True)
raise
def _save_dedup_details(
self,
dedup_details: Dict[str, Any],
original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode]
):
"""
保存去重消歧的详细记录到实例变量(基于内存数据结构)
Args:
dedup_details: 去重函数返回的详细记录
original_entities: 去重前的实体列表
final_entities: 去重后的实体列表
"""
try:
# 保存ID重定向映射
self.id_redirect_map = dedup_details.get("id_redirect", {})
# 处理精确匹配的合并记录
exact_merge_map = dedup_details.get("exact_merge_map", {})
for key, info in exact_merge_map.items():
merged_ids = info.get("merged_ids", set())
if merged_ids:
self.dedup_merge_records.append({
"type": "精确匹配",
"canonical_id": info.get("canonical_id"),
"entity_name": info.get("name"),
"entity_type": info.get("entity_type"),
"merged_count": len(merged_ids),
"merged_ids": list(merged_ids)
})
# 处理模糊匹配的合并记录
fuzzy_merge_records = dedup_details.get("fuzzy_merge_records", [])
for record in fuzzy_merge_records:
# 解析模糊匹配记录字符串
# 格式: "[模糊] 规范实体 id (group|name|type) <- 合并实体 id (group|name|type) | s_name=0.xxx, ..."
try:
import re
match = re.search(r"规范实体 (\S+) \(([^|]+)\|([^|]+)\|([^)]+)\) <- 合并实体 (\S+)", record)
if match:
self.dedup_merge_records.append({
"type": "模糊匹配",
"canonical_id": match.group(1),
"entity_name": match.group(3),
"entity_type": match.group(4),
"merged_count": 1,
"merged_ids": [match.group(5)]
})
except Exception as e:
logger.debug(f"解析模糊匹配记录失败: {record}, 错误: {e}")
# 处理LLM去重的合并记录
llm_decision_records = dedup_details.get("llm_decision_records", [])
for record in llm_decision_records:
if "[LLM去重]" in str(record):
try:
import re
# 格式: "[LLM去重] 同名类型相似 name1type1|name2type2 | conf=0.xx | reason=..."
match = re.search(r"同名类型相似 ([^]+)([^]+)\|([^]+)([^]+)", record)
if match:
self.dedup_merge_records.append({
"type": "LLM去重",
"entity_name": match.group(1),
"entity_type": f"{match.group(2)}|{match.group(4)}",
"merged_count": 1,
"merged_ids": []
})
except Exception as e:
logger.debug(f"解析LLM去重记录失败: {record}, 错误: {e}")
# 处理消歧记录
disamb_records = dedup_details.get("disamb_records", [])
for record in disamb_records:
if "[DISAMB阻断]" in str(record):
try:
import re
# 格式: "[DISAMB阻断] name1type1|name2type2 | conf=0.xx | reason=..."
content = str(record).replace("[DISAMB阻断]", "").strip()
match = re.search(r"([^]+)([^]+)\|([^]+)([^]+)", content)
if match:
entity1_name = match.group(1).strip()
entity1_type = match.group(2)
entity2_name = match.group(3).strip()
entity2_type = match.group(4)
# 提取置信度和原因
conf_match = re.search(r"conf=([0-9.]+)", str(record))
confidence = conf_match.group(1) if conf_match else "unknown"
reason_match = re.search(r"reason=([^|]+)", str(record))
reason = reason_match.group(1).strip() if reason_match else ""
self.dedup_disamb_records.append({
"entity_name": entity1_name,
"disamb_type": f"消歧阻断:{entity1_type} vs {entity2_type}",
"confidence": confidence,
"reason": reason[:100] + "..." if len(reason) > 100 else reason
})
except Exception as e:
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
except Exception as e:
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
async def _analyze_entity_merges(
self,
original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]:
"""
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
Args:
original_entities: 去重前的实体列表
final_entities: 去重后的实体列表
Returns:
合并详情列表,每个元素包含主实体名称和合并数量
"""
try:
# 直接使用保存的合并记录
if self.dedup_merge_records:
# 按合并数量排序,返回前几个
sorted_records = sorted(
self.dedup_merge_records,
key=lambda x: x.get("merged_count", 0),
reverse=True
)
merge_info = []
for record in sorted_records:
merge_info.append({
"main_entity_name": record.get("entity_name", "未知实体"),
"merged_count": record.get("merged_count", 1)
})
return merge_info
# 如果没有保存的记录,返回空列表
logger.info("未找到实体合并记录")
return []
except Exception as e:
logger.error(f"分析实体合并情况失败: {e}", exc_info=True)
return []
async def _analyze_entity_disambiguation(
self,
original_entities: List[ExtractedEntityNode],
final_entities: List[ExtractedEntityNode]
) -> List[Dict[str, Any]]:
"""
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
Args:
original_entities: 去重前的实体列表
final_entities: 去重后的实体列表
Returns:
消歧详情列表,每个元素包含实体名称和消歧类型
"""
try:
# 直接使用保存的消歧记录
if self.dedup_disamb_records:
return self.dedup_disamb_records
# 如果没有保存的记录,返回空列表
logger.info("未找到实体消歧记录")
return []
except Exception as e:
logger.error(f"分析实体消歧情况失败: {e}", exc_info=True)
return []
def _get_entity_type_display_name(self, entity_type: str) -> str:
"""
获取实体类型的中文显示名称
Args:
entity_type: 英文实体类型
Returns:
中文显示名称
"""
type_mapping = {
"Person": "人物实体节点",
"Organization": "组织实体节点",
"ORG": "组织实体节点",
"Location": "地点实体节点",
"LOC": "地点实体节点",
"Event": "事件实体节点",
"Concept": "概念实体节点",
"Time": "时间实体节点",
"Position": "职位实体节点",
"WorkRole": "职业实体节点",
"System": "系统实体节点",
"Policy": "政策实体节点",
"HistoricalPeriod": "历史时期实体节点",
"HistoricalState": "历史国家实体节点",
"HistoricalEvent": "历史事件实体节点",
"EconomicFactor": "经济因素实体节点",
"Condition": "条件实体节点",
"Numeric": "数值实体节点"
}
return type_mapping.get(entity_type, f"{entity_type}实体节点")
async def _output_relationship_creation_results(
self,
entity_entity_edges: List[EntityEntityEdge],
entity_nodes: List[ExtractedEntityNode]
):
"""
输出关系创建结果
Args:
entity_entity_edges: 实体-实体边列表
entity_nodes: 实体节点列表
"""
try:
# 创建实体ID到名称的映射
entity_id_to_name = {node.id: node.name for node in entity_nodes}
# 输出关系创建结果
for i, edge in enumerate(entity_entity_edges[:10]): # 只输出前10个关系
source_name = entity_id_to_name.get(edge.source, f"Entity_{edge.source}")
target_name = entity_id_to_name.get(edge.target, f"Entity_{edge.target}")
relation_type = edge.relation_type
relationship_result = {
"result_type": "relationship_creation",
"relationship_index": i + 1,
"source_entity": source_name,
"relation_type": relation_type,
"target_entity": target_name,
"relationship_text": f"{source_name} -[{relation_type}]-> {target_name}"
}
await self.progress_callback("creating_nodes_edges_result", "关系创建", relationship_result)
except Exception as e:
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
async def _send_dedup_progress_callback(
self,
original_entities: int,
final_entities: int,
original_stmt_edges: int,
final_stmt_edges: int,
original_ent_edges: int,
final_ent_edges: int,
):
"""
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
Args:
original_entities: 去重前实体数量
final_entities: 去重后实体数量
original_stmt_edges: 去重前陈述句-实体边数量
final_stmt_edges: 去重后陈述句-实体边数量
original_ent_edges: 去重前实体-实体边数量
final_ent_edges: 去重后实体-实体边数量
"""
try:
# 解析去重消歧报告文件,获取具体的去重和消歧效果
dedup_details = await self._parse_dedup_report()
# 计算去重效果统计
entities_reduced = original_entities - final_entities
stmt_edges_reduced = original_stmt_edges - final_stmt_edges
ent_edges_reduced = original_ent_edges - final_ent_edges
# 构建进度回调数据
dedup_stats = {
"entities": {
"original_count": original_entities,
"final_count": final_entities,
"reduced_count": entities_reduced,
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0,
},
"statement_entity_edges": {
"original_count": original_stmt_edges,
"final_count": final_stmt_edges,
"reduced_count": stmt_edges_reduced,
},
"entity_entity_edges": {
"original_count": original_ent_edges,
"final_count": final_ent_edges,
"reduced_count": ent_edges_reduced,
},
"dedup_examples": dedup_details.get("dedup_examples", []),
"disamb_examples": dedup_details.get("disamb_examples", []),
"summary": {
"total_merges": dedup_details.get("total_merges", 0),
"total_disambiguations": dedup_details.get("total_disambiguations", 0),
}
}
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", dedup_stats)
except Exception as e:
logger.error(f"发送去重消歧进度回调失败: {e}", exc_info=True)
# 即使解析失败,也发送基本的统计信息
try:
basic_stats = {
"entities": {
"original_count": original_entities,
"final_count": final_entities,
"reduced_count": original_entities - final_entities,
},
"summary": f"实体去重合并{original_entities - final_entities}"
}
await self.progress_callback("dedup_disambiguation_complete", "去重消歧完成", basic_stats)
except Exception as e2:
logger.error(f"发送基本去重统计失败: {e2}", exc_info=True)
async def _parse_dedup_report(self) -> Dict[str, Any]:
"""
获取去重消歧报告,直接使用内存中的记录(不再解析日志文件)
Returns:
包含去重和消歧详细信息的字典
"""
try:
# 直接使用保存的记录构建报告
dedup_examples = []
disamb_examples = []
total_merges = 0
total_disambiguations = 0
# 处理合并记录
for record in self.dedup_merge_records:
merge_count = record.get("merged_count", 0)
total_merges += merge_count
dedup_examples.append({
"type": record.get("type", "未知"),
"entity_name": record.get("entity_name", "未知实体"),
"entity_type": record.get("entity_type", "未知类型"),
"merge_count": merge_count,
"description": f"{record.get('entity_name', '未知实体')}实体去重合并{merge_count}"
})
# 处理消歧记录
for record in self.dedup_disamb_records:
total_disambiguations += 1
# 从消歧类型中提取实体类型信息
disamb_type = record.get("disamb_type", "")
entity_name = record.get("entity_name", "未知实体")
disamb_examples.append({
"entity1_name": entity_name,
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知",
"entity2_name": entity_name,
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
"description": f"{entity_name},消歧区分成功"
})
return {
"dedup_examples": dedup_examples[:5], # 只返回前5个示例
"disamb_examples": disamb_examples[:5], # 只返回前5个示例
"total_merges": total_merges,
"total_disambiguations": total_disambiguations,
}
except Exception as e:
logger.error(f"获取去重报告失败: {e}", exc_info=True)
return {"dedup_examples": [], "disamb_examples": [], "total_merges": 0, "total_disambiguations": 0}
# ============================================================================
# 数据加载和预处理函数

View File

@@ -0,0 +1,436 @@
"""
工作流执行器
基于 LangGraph 的工作流执行引擎。
"""
import logging
import uuid
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.expression_evaluator import evaluate_condition
from app.models.workflow_model import WorkflowExecution, WorkflowNodeExecution
from app.db import get_db
logger = logging.getLogger(__name__)
class WorkflowExecutor:
"""工作流执行器
负责将工作流配置转换为 LangGraph 并执行。
"""
def __init__(
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""初始化执行器
Args:
workflow_config: 工作流配置
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
"""
self.workflow_config = workflow_config
self.execution_id = execution_id
self.workspace_id = workspace_id
self.user_id = user_id
self.nodes = workflow_config.get("nodes", [])
self.edges = workflow_config.get("edges", [])
self.execution_config = workflow_config.get("execution_config", {})
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
"""准备初始状态(注入系统变量和会话变量)
变量命名空间:
- sys.xxx - 系统变量execution_id, workspace_id, user_id, message, input_variables 等)
- conv.xxx - 会话变量(跨多轮对话保持)
- node_id.xxx - 节点输出(执行时动态生成)
Args:
input_data: 输入数据
Returns:
初始化的工作流状态
"""
user_message = input_data.get("message") or ""
conversation_vars = input_data.get("conversation_vars") or {}
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
# 构建分层的变量结构
variables = {
"sys": {
"message": user_message, # 用户消息
"conversation_id": input_data.get("conversation_id"), # 会话 ID
"execution_id": self.execution_id, # 执行 ID
"workspace_id": self.workspace_id, # 工作空间 ID
"user_id": self.user_id, # 用户 ID
"input_variables": input_variables, # 自定义输入变量(给 Start 节点使用)
},
"conv": conversation_vars # 会话级变量(跨多轮对话保持)
}
return {
"messages": [HumanMessage(content=user_message)],
"variables": variables,
"node_outputs": {},
"runtime_vars": {}, # 运行时节点变量(简化版,供快速访问)
"execution_id": self.execution_id,
"workspace_id": self.workspace_id,
"user_id": self.user_id,
"error": None,
"error_node": None
}
def build_graph(self) -> StateGraph:
"""构建 LangGraph
Returns:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
# 2. 添加所有节点(包括 start 和 end
start_node_id = None
end_node_ids = []
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
# 记录 start 和 end 节点 ID
if node_type == "start":
start_node_id = node_id
elif node_type == "end":
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_instance:
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type})")
# 3. 添加边
# 从 START 连接到 start 节点
if start_node_id:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == start_node_id:
# 但要连接 start 到下一个节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# 处理到 end 节点的边
if target in end_node_ids:
# 连接到 end 节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":
continue
if condition:
# 条件边
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
workflow.add_conditional_edges(source, router)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in end_node_ids:
workflow.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
# 4. 编译图
graph = workflow.compile()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph
async def execute(
self,
input_data: dict[str, Any]
) -> dict[str, Any]:
"""执行工作流(非流式)
Args:
input_data: 输入数据,包含 message 和 variables
Returns:
执行结果,包含 status, output, node_outputs, elapsed_time, token_usage
"""
logger.info(f"开始执行工作流: execution_id={self.execution_id}")
# 记录开始时间
start_time = datetime.datetime.now()
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 执行工作流
try:
result = await graph.ainvoke(initial_state)
# 计算耗时
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
# 提取节点输出(现在包含 start 和 end 节点)
node_outputs = result.get("node_outputs", {})
# 提取最终输出(从最后一个非 start/end 节点)
final_output = self._extract_final_output(node_outputs)
# 聚合 token 使用情况
token_usage = self._aggregate_token_usage(node_outputs)
# 提取 conversation_id从 start 节点输出)
conversation_id = None
for node_id, node_output in node_outputs.items():
if node_output.get("node_type") == "start":
conversation_id = node_output.get("output", {}).get("conversation_id")
break
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
return {
"status": "completed",
"output": final_output,
"node_outputs": node_outputs,
"messages": result.get("messages", []),
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": result.get("error")
}
except Exception as e:
# 计算耗时(即使失败也记录)
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True)
return {
"status": "failed",
"error": str(e),
"output": None,
"node_outputs": {},
"elapsed_time": elapsed_time,
"token_usage": None
}
async def execute_stream(
self,
input_data: dict[str, Any]
):
"""执行工作流(流式)
Args:
input_data: 输入数据
Yields:
流式事件
"""
logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}")
# 1. 构建图
graph = self.build_graph()
# 2. 初始化状态(自动注入系统变量)
initial_state = self._prepare_initial_state(input_data)
# 3. 流式执行工作流
try:
# 使用 astream 获取节点级别的更新
async for event in graph.astream(initial_state, stream_mode="updates"):
for node_name, state_update in event.items():
yield {
"type": "node_complete",
"node": node_name,
"data": state_update,
"execution_id": self.execution_id
}
logger.info(f"工作流执行完成(流式): execution_id={self.execution_id}")
# 发送完成事件
yield {
"type": "workflow_complete",
"execution_id": self.execution_id
}
except Exception as e:
logger.error(f"工作流执行失败(流式): execution_id={self.execution_id}, error={e}", exc_info=True)
yield {
"type": "workflow_error",
"execution_id": self.execution_id,
"error": str(e)
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
优先级:
1. 最后一个执行的非 start/end 节点的 output
2. 如果没有节点输出,返回 None
Args:
node_outputs: 所有节点的输出
Returns:
最终输出字符串或 None
"""
if not node_outputs:
return None
# 获取最后一个节点的输出
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
if last_node_output and isinstance(last_node_output, dict):
return last_node_output.get("output")
return None
def _aggregate_token_usage(self, node_outputs: dict[str, Any]) -> dict[str, int] | None:
"""聚合所有节点的 token 使用情况
Args:
node_outputs: 所有节点的输出
Returns:
聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z}
如果没有 token 使用信息,返回 None
"""
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
has_token_info = False
for node_output in node_outputs.values():
if isinstance(node_output, dict):
token_usage = node_output.get("token_usage")
if token_usage and isinstance(token_usage, dict):
has_token_info = True
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
total_completion_tokens += token_usage.get("completion_tokens", 0)
total_tokens += token_usage.get("total_tokens", 0)
if not has_token_info:
return None
return {
"prompt_tokens": total_prompt_tokens,
"completion_tokens": total_completion_tokens,
"total_tokens": total_tokens
}
async def execute_workflow(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
) -> dict[str, Any]:
"""执行工作流(便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Returns:
执行结果
"""
executor = WorkflowExecutor(
workflow_config=workflow_config,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
)
return await executor.execute(input_data)
async def execute_workflow_stream(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""执行工作流(流式,便捷函数)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Yields:
流式事件
"""
executor = WorkflowExecutor(
workflow_config=workflow_config,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
)
async for event in executor.execute_stream(input_data):
yield event

View File

@@ -0,0 +1,195 @@
"""
安全的表达式求值器
使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。
"""
import logging
from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
logger = logging.getLogger(__name__)
class ExpressionEvaluator:
"""安全的表达式求值器"""
# 保留的命名空间
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@staticmethod
def evaluate(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""安全地评估表达式
Args:
expression: 表达式字符串,如 "{{var.score}} > 0.8"
variables: 用户定义的变量
node_outputs: 节点输出结果
system_vars: 系统变量
Returns:
表达式求值结果
Raises:
ValueError: 表达式无效或求值失败
Examples:
>>> evaluator = ExpressionEvaluator()
>>> evaluator.evaluate(
... "var.score > 0.8",
... {"score": 0.9},
... {},
... {}
... )
True
>>> evaluator.evaluate(
... "node.intent.output == '售前咨询'",
... {},
... {"intent": {"output": "售前咨询"}},
... {}
... )
True
"""
# 移除 Jinja2 模板语法的花括号(如果存在)
expression = expression.strip()
if expression.startswith("{{") and expression.endswith("}}"):
expression = expression[2:-2].strip()
# 构建命名空间上下文
context = {
"var": variables, # 用户变量
"node": node_outputs, # 节点输出
"sys": system_vars or {}, # 系统变量
}
# 为了向后兼容,也支持直接访问(但会在日志中警告)
context.update(variables)
context["nodes"] = node_outputs
try:
# simpleeval 只支持安全的操作:
# - 算术运算: +, -, *, /, //, %, **
# - 比较运算: ==, !=, <, <=, >, >=
# - 逻辑运算: and, or, not
# - 成员运算: in, not in
# - 属性访问: obj.attr
# - 字典/列表访问: obj["key"], obj[0]
# 不支持:函数调用、导入、赋值等危险操作
result = simple_eval(expression, names=context)
return result
except NameNotDefined as e:
logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
except InvalidExpression as e:
logger.error(f"表达式语法无效: {expression}, 错误: {e}")
raise ValueError(f"表达式语法无效: {e}")
except SyntaxError as e:
logger.error(f"表达式语法错误: {expression}, 错误: {e}")
raise ValueError(f"表达式语法错误: {e}")
except Exception as e:
logger.error(f"表达式求值异常: {expression}, 错误: {e}")
raise ValueError(f"表达式求值失败: {e}")
@staticmethod
def evaluate_bool(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估布尔表达式(用于条件判断)
Args:
expression: 布尔表达式
variables: 用户变量
node_outputs: 节点输出
system_vars: 系统变量
Returns:
布尔值结果
Examples:
>>> ExpressionEvaluator.evaluate_bool(
... "var.count >= 10 and var.status == 'active'",
... {"count": 15, "status": "active"},
... {},
... {}
... )
True
"""
result = ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
)
return bool(result)
@staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]:
"""验证变量名是否合法
Args:
variables: 变量定义列表
Returns:
错误列表,如果为空则验证通过
Examples:
>>> ExpressionEvaluator.validate_variable_names([
... {"name": "user_input"},
... {"name": "var"} # 保留字
... ])
["变量名 'var' 是保留的命名空间,请使用其他名称"]
"""
errors = []
for var in variables:
var_name = var.get("name", "")
# 检查是否为保留命名空间
if var_name in ExpressionEvaluator.RESERVED_NAMESPACES:
errors.append(
f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称"
)
# 检查是否为有效的 Python 标识符
if not var_name.isidentifier():
errors.append(
f"变量名 '{var_name}' 不是有效的标识符"
)
return errors
# 便捷函数
def evaluate_expression(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""评估表达式(便捷函数)"""
return ExpressionEvaluator.evaluate(
expression, variables, node_outputs, system_vars
)
def evaluate_condition(
expression: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""评估条件表达式(便捷函数)"""
return ExpressionEvaluator.evaluate_bool(
expression, variables, node_outputs, system_vars
)

View File

@@ -0,0 +1,24 @@
"""
工作流节点实现
提供各种类型的节点实现,用于工作流执行。
"""
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.node_factory import NodeFactory
__all__ = [
"BaseNode",
"WorkflowState",
"LLMNode",
"AgentNode",
"TransformNode",
"StartNode",
"EndNode",
"NodeFactory",
]

View File

@@ -0,0 +1,6 @@
"""Agent 节点"""
from app.core.workflow.nodes.agent.node import AgentNode
from app.core.workflow.nodes.agent.config import AgentNodeConfig
__all__ = ["AgentNode", "AgentNodeConfig"]

View File

@@ -0,0 +1,71 @@
"""Agent 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class AgentNodeConfig(BaseNodeConfig):
"""Agent 节点配置
调用已配置的 Agent 执行任务。
"""
agent_id: str = Field(
...,
description="Agent 配置 ID"
)
message: str = Field(
default="{{ sys.message }}",
description="发送给 Agent 的消息,支持模板变量"
)
conversation_id: str | None = Field(
default=None,
description="会话 ID用于多轮对话"
)
variables: dict[str, str] | None = Field(
default=None,
description="传递给 Agent 的变量"
)
timeout: int = Field(
default=300,
ge=1,
le=3600,
description="超时时间(秒)"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="Agent 的回复内容"
),
VariableDefinition(
name="conversation_id",
type=VariableType.STRING,
description="会话 ID"
),
VariableDefinition(
name="token_usage",
type=VariableType.OBJECT,
description="Token 使用情况"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"example": {
"agent_id": "uuid-here",
"message": "{{ sys.message }}",
"timeout": 300,
"description": "调用客服 Agent"
}
}

View File

@@ -0,0 +1,152 @@
"""
Agent 节点实现
调用已发布的 Agent 应用。
"""
import logging
from typing import Any
from langchain_core.messages import AIMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.services.draft_run_service import DraftRunService
from app.models import AppRelease
from app.db import get_db
logger = logging.getLogger(__name__)
class AgentNode(BaseNode):
"""Agent 节点
支持流式和非流式输出。
配置示例:
{
"type": "agent",
"config": {
"agent_id": "uuid", # Agent 的 release_id
"message": "{{var.user_input}}"
}
}
"""
def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]:
"""准备 Agent公共逻辑
Args:
state: 工作流状态
Returns:
(draft_service, release, message): 服务实例、发布配置、消息
"""
# 1. 渲染消息
message_template = self.config.get("message", "")
message = self._render_template(message_template, state)
# 2. 获取 Agent 配置
agent_id = self.config.get("agent_id")
if not agent_id:
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
db = next(get_db())
release = db.query(AppRelease).filter(
AppRelease.id == agent_id
).first()
if not release:
raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = DraftRunService(db)
return draft_service, release, message
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""非流式执行
Args:
state: 工作流状态
Returns:
状态更新字典
"""
draft_service, release, message = self._prepare_agent(state)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
# 执行 Agent非流式
result = await draft_service.run(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
)
response = result.get("response", "")
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(response)}")
return {
"messages": [AIMessage(content=response)],
"node_outputs": {
self.node_id: {
"output": response,
"status": "completed",
"meta_data": result.get("meta_data", {})
}
}
}
async def execute_stream(self, state: WorkflowState):
"""流式执行
Args:
state: 工作流状态
Yields:
流式事件字典
"""
draft_service, release, message = self._prepare_agent(state)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
# 累积完整响应
full_response = ""
# 执行 Agent流式
async for chunk in draft_service.run_stream(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=state.get("workspace_id"),
user_id=state.get("user_id"),
variables=state.get("variables", {})
):
# 提取内容
content = chunk.get("content", "")
full_response += content
# 流式返回每个 chunk
yield {
"type": "chunk",
"node_id": self.node_id,
"content": content,
"full_content": full_response,
"meta_data": chunk.get("meta_data", {})
}
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
# 最后返回完整结果
yield {
"type": "complete",
"messages": [AIMessage(content=full_response)],
"node_outputs": {
self.node_id: {
"output": full_response,
"status": "completed"
}
}
}

View File

@@ -0,0 +1,109 @@
"""节点配置基类
定义所有节点配置的通用字段和数据结构。
"""
from enum import StrEnum
from pydantic import BaseModel, Field
class VariableType(StrEnum):
"""变量类型枚举"""
STRING = "string"
NUMBER = "number"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
class VariableDefinition(BaseModel):
"""变量定义
定义工作流或节点的输入/输出变量。
这是一个通用的数据结构,可以在多个地方使用。
"""
name: str = Field(
...,
description="变量名称"
)
type: VariableType = Field(
default=VariableType.STRING,
description="变量类型"
)
required: bool = Field(
default=False,
description="是否必需"
)
default: str | int | float | bool | list | dict | None = Field(
default=None,
description="默认值"
)
description: str | None = Field(
default=None,
description="变量描述"
)
class Config:
json_schema_extra = {
"examples": [
{
"name": "language",
"type": "string",
"required": False,
"default": "zh-CN",
"description": "语言设置"
},
{
"name": "max_length",
"type": "number",
"required": False,
"default": 1000,
"description": "最大长度"
},
{
"name": "enable_search",
"type": "boolean",
"required": True,
"description": "是否启用搜索"
}
]
}
class BaseNodeConfig(BaseModel):
"""节点配置基类
所有节点配置都应该继承此基类。
通用字段:
- name: 节点名称(显示名称)
- description: 节点描述
- tags: 节点标签(用于分类和搜索)
"""
name: str | None = Field(
default=None,
description="节点名称(显示名称),如果不设置则使用节点 ID"
)
description: str | None = Field(
default=None,
description="节点描述,说明节点的作用"
)
tags: list[str] = Field(
default_factory=list,
description="节点标签,用于分类和搜索"
)
class Config:
"""Pydantic 配置"""
# 允许额外字段(向后兼容)
extra = "allow"

View File

@@ -0,0 +1,556 @@
"""
工作流节点基类
定义节点的基本接口和通用功能。
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from typing import Any, TypedDict, Annotated
from operator import add
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class WorkflowState(TypedDict):
"""工作流状态
在节点间传递的状态对象,包含消息、变量、节点输出等信息。
"""
# 消息列表(追加模式)
messages: Annotated[list[AnyMessage], add]
# 输入变量(从配置的 variables 传入)
variables: dict[str, Any]
# 节点输出(存储每个节点的执行结果,用于变量引用)
# 使用自定义合并函数,将新的节点输出合并到现有字典中
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 运行时节点变量(简化版,只存储业务数据,供节点间快速访问)
# 格式:{node_id: business_result}
runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
# 执行上下文
execution_id: str
workspace_id: str
user_id: str
# 错误信息(用于错误边)
error: str | None
error_node: str | None
class BaseNode(ABC):
"""节点基类
所有节点类型都应该继承此基类,实现 execute 方法。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
self.node_config = node_config
self.workflow_config = workflow_config
self.node_id = node_config["id"]
self.node_type = node_config["type"]
self.node_name = node_config.get("name", self.node_id)
# 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {}
@abstractmethod
async def execute(self, state: WorkflowState) -> Any:
"""执行节点业务逻辑(非流式)
节点只需要返回业务结果,不需要关心输出格式、时间统计等。
BaseNode 会自动包装成标准格式。
Args:
state: 工作流状态
Returns:
业务结果(任意类型)
Examples:
>>> # LLM 节点
>>> return "这是 AI 的回复"
>>> # Transform 节点
>>> return {"processed_data": [...]}
>>> # Start/End 节点
>>> return {"message": "开始", "conversation_id": "xxx"}
"""
pass
async def execute_stream(self, state: WorkflowState):
"""执行节点业务逻辑(流式)
子类可以重写此方法以支持流式输出。
默认实现:执行非流式方法并一次性返回。
节点需要:
1. yield 中间结果(如文本片段)
2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result}
Args:
state: 工作流状态
Yields:
业务数据chunk或完成标记
Examples:
>>> # 流式 LLM 节点
>>> full_response = ""
>>> async for chunk in llm.astream(prompt):
... full_response += chunk
... yield chunk # yield 文本片段
>>>
>>> # 最后 yield 完成标记
>>> yield {"__final__": True, "result": AIMessage(content=full_response)}
"""
result = await self.execute(state)
# 默认实现:直接 yield 完成标记
yield {"__final__": True, "result": result}
def supports_streaming(self) -> bool:
"""节点是否支持流式输出
Returns:
是否支持流式输出
"""
# 检查子类是否重写了 execute_stream 方法
return self.execute_stream.__func__ != BaseNode.execute_stream.__func__
def get_timeout(self) -> int:
"""获取超时时间(秒)
Returns:
超时时间
"""
return 60
# return self.error_handling.get("timeout", 60)
async def run(self, state: WorkflowState) -> dict[str, Any]:
"""执行节点(带错误处理和输出包装,非流式)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute() 方法
3. 将业务结果包装成标准输出格式
4. 错误处理
Args:
state: 工作流状态
Returns:
标准化的状态更新字典
"""
import time
start_time = time.time()
try:
timeout = self.get_timeout()
# 调用节点的业务逻辑
business_result = await asyncio.wait_for(
self.execute(state),
timeout=timeout
)
elapsed_time = time.time() - start_time
# 提取处理后的输出(调用子类的 _extract_output
extracted_output = self._extract_output(business_result)
# 包装成标准输出格式
wrapped_output = self._wrap_output(business_result, elapsed_time, state)
# 将提取后的输出存储到运行时变量中(供后续节点快速访问)
# 如果提取后的输出是字典,拆包存储;否则存储为 output 字段
if isinstance(extracted_output, dict):
runtime_var = extracted_output
else:
runtime_var = {"output": extracted_output}
# 返回包装后的输出和运行时变量
return {
**wrapped_output,
"runtime_vars": {
self.node_id: runtime_var
}
}
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
return self._wrap_error(str(e), elapsed_time, state)
async def run_stream(self, state: WorkflowState):
"""执行节点(带错误处理和输出包装,流式)
这个方法由 Executor 调用,负责:
1. 时间统计
2. 调用节点的 execute_stream() 方法
3. 将业务数据包装成标准输出格式
4. 错误处理
Args:
state: 工作流状态
Yields:
标准化的流式事件
"""
import time
start_time = time.time()
try:
timeout = self.get_timeout()
# 累积完整结果(用于最后的包装)
chunks = []
final_result = None
# 使用异步生成器包装,支持超时
async def stream_with_timeout():
nonlocal final_result
loop_start = asyncio.get_event_loop().time()
async for item in self.execute_stream(state):
# 检查超时
if asyncio.get_event_loop().time() - loop_start > timeout:
raise TimeoutError()
# 检查是否是完成标记
if isinstance(item, dict) and item.get("__final__"):
final_result = item["result"]
elif isinstance(item, str):
# 字符串是 chunk
chunks.append(item)
yield {
"type": "chunk",
"node_id": self.node_id,
"content": item,
"full_content": "".join(chunks)
}
else:
# 其他类型也当作 chunk 处理
chunks.append(str(item))
yield {
"type": "chunk",
"node_id": self.node_id,
"content": str(item),
"full_content": "".join(chunks)
}
async for chunk_event in stream_with_timeout():
yield chunk_event
elapsed_time = time.time() - start_time
# 包装最终结果
final_output = self._wrap_output(final_result, elapsed_time, state)
yield {
"type": "complete",
**final_output
}
except TimeoutError:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)")
yield {
"type": "error",
**self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state)
}
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True)
yield {
"type": "error",
**self._wrap_error(str(e), elapsed_time, state)
}
def _wrap_output(
self,
business_result: Any,
elapsed_time: float,
state: WorkflowState
) -> dict[str, Any]:
"""将业务结果包装成标准输出格式
Args:
business_result: 节点返回的业务结果
elapsed_time: 执行耗时
state: 工作流状态
Returns:
标准化的状态更新字典
"""
# 提取输入数据(用于记录)
input_data = self._extract_input(state)
# 提取 token 使用情况(如果有)
token_usage = self._extract_token_usage(business_result)
# 提取实际输出(去除元数据)
output = self._extract_output(business_result)
# 构建标准节点输出
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
"node_name": self.node_name,
"status": "completed",
"input": input_data,
"output": output,
"elapsed_time": elapsed_time,
"token_usage": token_usage,
"error": None
}
return {
"node_outputs": {
self.node_id: node_output
}
}
def _wrap_error(
self,
error_message: str,
elapsed_time: float,
state: WorkflowState
) -> dict[str, Any]:
"""将错误包装成标准输出格式
Args:
error_message: 错误信息
elapsed_time: 执行耗时
state: 工作流状态
Returns:
标准化的状态更新字典
"""
# 查找错误边
error_edge = self._find_error_edge()
# 提取输入数据
input_data = self._extract_input(state)
# 构建错误输出
node_output = {
"node_id": self.node_id,
"node_type": self.node_type,
"node_name": self.node_name,
"status": "failed",
"input": input_data,
"output": None,
"elapsed_time": elapsed_time,
"token_usage": None,
"error": error_message
}
if error_edge:
# 有错误边:记录错误并继续
logger.warning(
f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}"
)
return {
"node_outputs": {
self.node_id: node_output
},
"error": error_message,
"error_node": self.node_id
}
else:
# 无错误边:抛出异常停止工作流
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取节点输入数据(用于记录)
子类可以重写此方法来自定义输入记录。
Args:
state: 工作流状态
Returns:
输入数据字典
"""
# 默认返回配置
return {"config": self.config}
def _extract_output(self, business_result: Any) -> Any:
"""从业务结果中提取实际输出
子类可以重写此方法来自定义输出提取。
Args:
business_result: 业务结果
Returns:
实际输出
"""
# 默认直接返回业务结果
return business_result
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从业务结果中提取 token 使用情况
子类可以重写此方法来提取 token 信息。
Args:
business_result: 业务结果
Returns:
token 使用情况或 None
"""
# 默认返回 None
return None
def _find_error_edge(self) -> dict[str, Any] | None:
"""查找错误边
Returns:
错误边配置或 None
"""
for edge in self.workflow_config.get("edges", []):
if edge.get("source") == self.node_id and edge.get("type") == "error":
return edge
return None
def _render_template(self, template: str, state: WorkflowState | None) -> str:
"""渲染模板
支持的变量命名空间:
- sys.xxx: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.xxx: 会话变量(跨多轮对话保持)
- node_id.xxx: 节点输出
Args:
template: 模板字符串
state: 工作流状态
Returns:
渲染后的字符串
"""
from app.core.workflow.template_renderer import render_template
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
return render_template(
template=template,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
)
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
"""评估条件表达式
支持的变量命名空间:
- sys.xxx: 系统变量
- conv.xxx: 会话变量
- node_id.xxx: 节点输出
Args:
expression: 条件表达式
state: 工作流状态
Returns:
布尔值结果
"""
from app.core.workflow.expression_evaluator import evaluate_condition
# 处理 state 为 None 的情况
if state is None:
state = {}
# 使用变量池获取变量
pool = VariablePool(state)
return evaluate_condition(
expression=expression,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
)
def get_variable_pool(self, state: WorkflowState) -> VariablePool:
"""获取变量池实例
VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。
Args:
state: 工作流状态
Returns:
VariablePool 实例
Examples:
>>> pool = self.get_variable_pool(state)
>>> message = pool.get("sys.message")
>>> llm_output = pool.get("llm_qa.output")
"""
return VariablePool(state)
def get_variable(
self,
selector: list[str] | str,
state: WorkflowState,
default: Any = None
) -> Any:
"""获取变量值(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
default: 默认值
Returns:
变量值
Examples:
>>> message = self.get_variable("sys.message", state)
>>> output = self.get_variable(["llm_qa", "output"], state)
>>> custom = self.get_variable("var.custom", state, default="默认值")
"""
pool = VariablePool(state)
return pool.get(selector, default=default)
def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool:
"""检查变量是否存在(便捷方法)
Args:
selector: 变量选择器
state: 工作流状态
Returns:
变量是否存在
Examples:
>>> if self.has_variable("llm_qa.output", state):
... output = self.get_variable("llm_qa.output", state)
"""
pool = VariablePool(state)
return pool.has(selector)

View File

@@ -0,0 +1,29 @@
"""节点配置类统一导出
所有节点的配置类都在这里导出,方便使用。
"""
from app.core.workflow.nodes.base_config import (
BaseNodeConfig,
VariableDefinition,
VariableType,
)
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.end.config import EndNodeConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
__all__ = [
# 基础类
"BaseNodeConfig",
"VariableDefinition",
"VariableType",
# 节点配置
"StartNodeConfig",
"EndNodeConfig",
"LLMNodeConfig",
"MessageConfig",
"AgentNodeConfig",
"TransformNodeConfig",
]

View File

@@ -0,0 +1,6 @@
"""End 节点"""
from app.core.workflow.nodes.end.node import EndNode
from app.core.workflow.nodes.end.config import EndNodeConfig
__all__ = ["EndNode", "EndNodeConfig"]

View File

@@ -0,0 +1,37 @@
"""End 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class EndNodeConfig(BaseNodeConfig):
"""End 节点配置
End 节点负责输出工作流的最终结果。
"""
output: str = Field(
default="工作流已完成",
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="工作流的最终输出"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"example": {
"output": "{{ llm_qa.output }}",
"description": "输出 LLM 的回答"
}
}

View File

@@ -0,0 +1,53 @@
"""
End 节点实现
工作流的结束节点,输出最终结果。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class EndNode(BaseNode):
"""End 节点
工作流的结束节点,根据配置的模板输出最终结果。
"""
async def execute(self, state: WorkflowState) -> str:
"""执行 end 节点业务逻辑
Args:
state: 工作流状态
Returns:
最终输出字符串
"""
logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板
output_template = self.config.get("output")
pool = self.get_variable_pool(state)
print("="*20)
print( pool.get("start.test"))
print("="*20)
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)
else:
output = "工作流已完成"
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
print("="*20)
print(output)
print("="*20)
return output

View File

@@ -0,0 +1,15 @@
from enum import StrEnum
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
IF_ELSE = "if-else"
CODE = "code"
TRANSFORM = "transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
AGENT = "agent"

View File

@@ -0,0 +1,6 @@
"""LLM 节点"""
from app.core.workflow.nodes.llm.node import LLMNode
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
__all__ = ["LLMNode", "LLMNodeConfig", "MessageConfig"]

View File

@@ -0,0 +1,141 @@
"""LLM 节点配置"""
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class MessageConfig(BaseModel):
"""消息配置"""
role: str = Field(
...,
description="消息角色system, user, assistant"
)
content: str = Field(
...,
description="消息内容,支持模板变量,如:{{ sys.message }}"
)
@field_validator("role")
@classmethod
def validate_role(cls, v: str) -> str:
"""验证角色"""
allowed_roles = ["system", "user", "human", "assistant", "ai"]
if v.lower() not in allowed_roles:
raise ValueError(f"角色必须是以下之一: {', '.join(allowed_roles)}")
return v.lower()
class LLMNodeConfig(BaseNodeConfig):
"""LLM 节点配置
支持两种配置方式:
1. 简单模式:使用 prompt 字段
2. 消息模式:使用 messages 字段(推荐)
"""
model_id: str = Field(
...,
description="模型配置 ID"
)
# 简单模式
prompt: str | None = Field(
default=None,
description="提示词模板(简单模式),支持变量引用"
)
# 消息模式(推荐)
messages: list[MessageConfig] | None = Field(
default=None,
description="消息列表(消息模式),支持多轮对话"
)
# 模型参数
temperature: float | None = Field(
default=0.7,
ge=0.0,
le=2.0,
description="温度参数,控制输出的随机性"
)
max_tokens: int | None = Field(
default=1000,
ge=1,
le=32000,
description="最大生成 token 数"
)
top_p: float | None = Field(
default=None,
ge=0.0,
le=1.0,
description="Top-p 采样参数"
)
frequency_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="频率惩罚"
)
presence_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="存在惩罚"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="output",
type=VariableType.STRING,
description="LLM 生成的文本输出"
),
VariableDefinition(
name="token_usage",
type=VariableType.OBJECT,
description="Token 使用情况"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
@field_validator("messages", "prompt")
@classmethod
def validate_input_mode(cls, v, info):
"""验证输入模式prompt 和 messages 至少有一个"""
# 这个验证在 model_validator 中更合适
return v
class Config:
json_schema_extra = {
"examples": [
{
"model_id": "uuid-here",
"prompt": "请回答:{{ sys.message }}",
"temperature": 0.7,
"max_tokens": 1000
},
{
"model_id": "uuid-here",
"messages": [
{
"role": "system",
"content": "你是一个专业的 AI 助手"
},
{
"role": "user",
"content": "{{ sys.message }}"
}
],
"temperature": 0.7,
"max_tokens": 1000
}
]
}

View File

@@ -0,0 +1,247 @@
"""
LLM 节点实现
调用 LLM 模型进行文本生成。
"""
import logging
from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models import ModelConfig
from app.db import get_db, get_db_context
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
class LLMNode(BaseNode):
"""LLM 节点
支持流式和非流式输出,使用 LangChain 标准的消息格式。
配置示例(支持多种消息格式):
1. 简单文本格式:
{
"type": "llm",
"config": {
"model_id": "uuid",
"prompt": "请分析:{{sys.message}}",
"temperature": 0.7,
"max_tokens": 1000
}
}
2. LangChain 消息格式(推荐):
{
"type": "llm",
"config": {
"model_id": "uuid",
"messages": [
{
"role": "system",
"content": "你是一个专业的 AI 助手。"
},
{
"role": "user",
"content": "{{sys.message}}"
}
],
"temperature": 0.7,
"max_tokens": 1000
}
}
支持的角色类型:
- system: 系统消息SystemMessage
- user/human: 用户消息HumanMessage
- ai/assistant: AI 消息AIMessage
"""
def _prepare_llm(self, state: WorkflowState) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
state: 工作流状态
Returns:
(llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串
"""
# 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages")
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "")
content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象
if role == "system":
messages.append(SystemMessage(content=content))
elif role in ["user", "human"]:
messages.append(HumanMessage(content=content))
elif role in ["ai", "assistant"]:
messages.append(AIMessage(content=content))
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content))
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
prompt_template = self.config.get("prompt", "")
prompt_or_messages = self._render_template(prompt_template, state)
# 2. 获取模型配置
model_id = self.config.get("model_id")
if not model_id:
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
# 3. 在 with 块内完成所有数据库操作和数据提取
with get_db_context() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
),
type=model_type
)
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
"""非流式执行 LLM 调用
Args:
state: 工作流状态
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
if hasattr(response, 'content'):
content = response.content
else:
content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据
return response if isinstance(response, AIMessage) else AIMessage(content=content)
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
_, prompt_or_messages = self._prepare_llm(state)
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [
{"role": msg.__class__.__name__.replace("Message", "").lower(), "content": msg.content}
for msg in prompt_or_messages
] if isinstance(prompt_or_messages, list) else None,
"config": {
"model_id": self.config.get("model_id"),
"temperature": self.config.get("temperature"),
"max_tokens": self.config.get("max_tokens")
}
}
def _extract_output(self, business_result: Any) -> str:
"""从 AIMessage 中提取文本内容"""
if isinstance(business_result, AIMessage):
return business_result.content
return str(business_result)
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从 AIMessage 中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
usage = business_result.response_metadata.get('token_usage')
if usage:
return {
"prompt_tokens": usage.get('prompt_tokens', 0),
"completion_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0)
}
return None
async def execute_stream(self, state: WorkflowState):
"""流式执行 LLM 调用
Args:
state: 工作流状态
Yields:
文本片段chunk或完成标记
"""
llm, prompt_or_messages = self._prepare_llm(state)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
# 累积完整响应
full_response = ""
last_chunk = None
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
# 提取内容
if hasattr(chunk, 'content'):
content = chunk.content
else:
content = str(chunk)
full_response += content
last_chunk = chunk
# 流式返回每个文本片段
yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):
final_message = AIMessage(
content=full_response,
response_metadata=last_chunk.response_metadata if hasattr(last_chunk, 'response_metadata') else {}
)
else:
final_message = AIMessage(content=full_response)
# yield 完成标记
yield {"__final__": True, "result": final_message}

View File

@@ -0,0 +1,93 @@
"""
节点工厂
根据节点类型创建相应的节点实例。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.end import EndNode
logger = logging.getLogger(__name__)
class NodeFactory:
"""节点工厂
使用工厂模式创建节点实例,便于扩展和维护。
"""
# 节点类型注册表
_node_types: dict[str, type[BaseNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
}
@classmethod
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
"""注册新的节点类型
Args:
node_type: 节点类型名称
node_class: 节点类
Examples:
>>> class CustomNode(BaseNode):
... async def execute(self, state):
... return {"node_outputs": {self.node_id: {"output": "custom"}}}
>>> NodeFactory.register_node_type("custom", CustomNode)
"""
cls._node_types[node_type] = node_class
logger.info(f"注册节点类型: {node_type} -> {node_class.__name__}")
@classmethod
def create_node(
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
) -> BaseNode | None:
"""创建节点实例
Args:
node_config: 节点配置
workflow_config: 工作流配置
Returns:
节点实例或 None对于不支持的节点类型
Raises:
ValueError: 不支持的节点类型
"""
node_type = node_config.get("type")
# 跳过条件节点(由 LangGraph 处理)
if node_type == "condition":
return None
# 获取节点类
node_class = cls._node_types.get(node_type)
if not node_class:
raise ValueError(f"不支持的节点类型: {node_type}")
# 创建节点实例
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config)
@classmethod
def get_supported_types(cls) -> list[str]:
"""获取支持的节点类型列表
Returns:
节点类型列表
"""
return list(cls._node_types.keys())

View File

@@ -0,0 +1,6 @@
"""Start 节点"""
from app.core.workflow.nodes.start.node import StartNode
from app.core.workflow.nodes.start.config import StartNodeConfig
__all__ = ["StartNode", "StartNodeConfig"]

View File

@@ -0,0 +1,87 @@
"""Start 节点配置"""
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class StartNodeConfig(BaseNodeConfig):
"""Start 节点配置
Start 节点的作用:
1. 标记工作流的起点
2. 定义自定义输入变量(会作为节点输出,通过 start_node_id.variable_name 访问)
3. 输出系统变量和会话变量
"""
# 自定义输入变量定义
variables: list[VariableDefinition] = Field(
default_factory=list,
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="message",
type=VariableType.STRING,
description="用户输入的消息"
),
VariableDefinition(
name="conversation_vars",
type=VariableType.OBJECT,
description="会话级变量"
),
VariableDefinition(
name="execution_id",
type=VariableType.STRING,
description="执行 ID"
),
VariableDefinition(
name="conversation_id",
type=VariableType.STRING,
description="会话 ID"
),
VariableDefinition(
name="workspace_id",
type=VariableType.STRING,
description="工作空间 ID"
),
VariableDefinition(
name="user_id",
type=VariableType.STRING,
description="用户 ID"
)
],
description="输出变量定义(自动生成,通常不需要修改)"
)
class Config:
json_schema_extra = {
"examples": [
{
"description": "工作流开始节点",
"variables": []
},
{
"description": "带自定义变量的开始节点",
"variables": [
{
"name": "language",
"type": "string",
"required": False,
"default": "zh-CN",
"description": "语言设置"
},
{
"name": "max_length",
"type": "number",
"required": False,
"default": 1000,
"description": "最大长度"
}
]
}
]
}

View File

@@ -0,0 +1,136 @@
"""
Start 节点实现
工作流的起始节点,定义输入变量并输出系统参数。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.start.config import StartNodeConfig
logger = logging.getLogger(__name__)
class StartNode(BaseNode):
"""Start 节点
工作流的起始节点,负责:
1. 定义工作流的输入变量(通过配置)
2. 输出系统变量sys.*
3. 输出会话变量conv.*
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化 Start 节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置
self.typed_config = StartNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行 start 节点业务逻辑
Start 节点输出系统变量、会话变量和自定义变量。
Args:
state: 工作流状态
Returns:
包含系统参数、会话变量和自定义变量的字典
"""
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 创建变量池实例(在方法内复用)
pool = self.get_variable_pool(state)
# 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(pool)
# 返回业务数据(包含自定义变量)
result = {
"message": pool.get("sys.message"),
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"workspace_id": pool.get("sys.workspace_id"),
"user_id": pool.get("sys.user_id"),
**custom_vars # 自定义变量作为节点输出的一部分
}
logger.info(
f"节点 {self.node_id} (Start) 执行完成,"
f"输出了 {len(custom_vars)} 个自定义变量"
)
return result
def _process_custom_variables(self, pool) -> dict[str, Any]:
"""处理自定义变量
从输入数据中提取自定义变量,应用默认值和验证。
Args:
pool: 变量池实例
Returns:
处理后的自定义变量字典
Raises:
ValueError: 缺少必需变量
"""
# 获取输入数据中的自定义变量
input_variables = pool.get("sys.input_variables", default={})
processed = {}
# 遍历配置的变量定义
for var_def in self.typed_config.variables:
var_name = var_def.name
# 检查变量是否存在
if var_name in input_variables:
# 使用用户提供的值
processed[var_name] = input_variables[var_name]
elif var_def.required:
# 必需变量缺失
raise ValueError(
f"缺少必需的输入变量: {var_name}"
+ (f" ({var_def.description})" if var_def.description else "")
)
elif var_def.default is not None:
# 使用默认值
processed[var_name] = var_def.default
logger.debug(
f"变量 '{var_name}' 使用默认值: {var_def.default}"
)
return processed
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)
Args:
state: 工作流状态
Returns:
输入数据字典
"""
pool = self.get_variable_pool(state)
return {
"execution_id": pool.get("sys.execution_id"),
"conversation_id": pool.get("sys.conversation_id"),
"message": pool.get("sys.message"),
"conversation_vars": pool.get_all_conversation_vars()
}

View File

@@ -0,0 +1,6 @@
"""Transform 节点"""
from app.core.workflow.nodes.transform.node import TransformNode
from app.core.workflow.nodes.transform.config import TransformNodeConfig
__all__ = ["TransformNode", "TransformNodeConfig"]

View File

@@ -0,0 +1,80 @@
"""Transform 节点配置"""
from typing import Literal
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
class TransformNodeConfig(BaseNodeConfig):
"""Transform 节点配置
用于数据转换和处理。
"""
transform_type: Literal["template", "code", "json"] = Field(
default="template",
description="转换类型template(模板), code(代码), json(JSON处理)"
)
# 模板模式
template: str | None = Field(
default=None,
description="转换模板,支持变量引用"
)
# 代码模式
code: str | None = Field(
default=None,
description="Python 代码,用于数据转换"
)
# JSON 模式
json_path: str | None = Field(
default=None,
description="JSON 路径表达式"
)
# 输入变量
inputs: dict[str, str] | None = Field(
default=None,
description="输入变量映射key 为变量名value 为变量选择器"
)
# 输出变量
output_key: str = Field(
default="result",
description="输出变量的键名"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
VariableDefinition(
name="result",
type=VariableType.STRING,
description="转换后的结果"
)
],
description="输出变量定义(根据 output_key 动态生成)"
)
class Config:
json_schema_extra = {
"examples": [
{
"transform_type": "template",
"template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}",
"output_key": "formatted_result"
},
{
"transform_type": "code",
"code": "result = input_text.upper()",
"inputs": {
"input_text": "{{ sys.message }}"
},
"output_key": "uppercase_text"
}
]
}

View File

@@ -0,0 +1,60 @@
"""
Transform 节点实现
数据转换节点,用于处理和转换数据。
"""
import logging
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
logger = logging.getLogger(__name__)
class TransformNode(BaseNode):
"""数据转换节点
配置示例:
{
"type": "transform",
"config": {
"mapping": {
"output_field": "{{node.previous.output}}",
"processed": "{{var.input | upper}}"
}
}
}
"""
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行数据转换
Args:
state: 工作流状态
Returns:
状态更新字典
"""
logger.info(f"节点 {self.node_id} 开始执行数据转换")
# 获取映射配置
mapping = self.config.get("mapping", {})
# 执行数据转换
transformed_data = {}
for target_key, source_template in mapping.items():
# 渲染模板获取值
value = self._render_template(str(source_template), state)
transformed_data[target_key] = value
logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}")
return {
"node_outputs": {
self.node_id: {
"output": transformed_data,
"status": "completed"
}
}
}

View File

@@ -0,0 +1,170 @@
"""
工作流模板加载器
从文件系统加载预定义的工作流模板
"""
import os
import yaml
from pathlib import Path
from typing import Optional
class TemplateLoader:
"""工作流模板加载器"""
def __init__(self, templates_dir: str = "app/templates/workflows"):
"""初始化模板加载器
Args:
templates_dir: 模板目录路径
"""
self.templates_dir = Path(templates_dir)
if not self.templates_dir.exists():
raise ValueError(f"模板目录不存在: {templates_dir}")
def list_templates(self) -> list[dict]:
"""列出所有可用的模板
Returns:
模板列表,每个模板包含 id, name, description 等信息
"""
templates = []
# 遍历模板目录
for template_dir in self.templates_dir.iterdir():
if not template_dir.is_dir():
continue
# 检查是否有 template.yml 文件
template_file = template_dir / "template.yml"
if not template_file.exists():
continue
try:
# 读取模板配置
with open(template_file, 'r', encoding='utf-8') as f:
template_data = yaml.safe_load(f)
# 提取模板信息
templates.append({
"id": template_dir.name,
"name": template_data.get("name", template_dir.name),
"description": template_data.get("description", ""),
"category": template_data.get("category", "general"),
"tags": template_data.get("tags", []),
"author": template_data.get("author", ""),
"version": template_data.get("version", "1.0.0")
})
except Exception as e:
print(f"加载模板 {template_dir.name} 失败: {e}")
continue
return templates
def load_template(self, template_id: str) -> Optional[dict]:
"""加载指定的模板
Args:
template_id: 模板 ID目录名
Returns:
模板配置字典,如果模板不存在则返回 None
"""
template_dir = self.templates_dir / template_id
template_file = template_dir / "template.yml"
if not template_file.exists():
return None
try:
with open(template_file, 'r', encoding='utf-8') as f:
template_data = yaml.safe_load(f)
# 返回工作流配置部分
return {
"name": template_data.get("name", template_id),
"description": template_data.get("description", ""),
"nodes": template_data.get("nodes", []),
"edges": template_data.get("edges", []),
"variables": template_data.get("variables", []),
"execution_config": template_data.get("execution_config", {}),
"triggers": template_data.get("triggers", [])
}
except Exception as e:
print(f"加载模板 {template_id} 失败: {e}")
return None
def get_template_readme(self, template_id: str) -> Optional[str]:
"""获取模板的 README 文档
Args:
template_id: 模板 ID
Returns:
README 内容,如果不存在则返回 None
"""
template_dir = self.templates_dir / template_id
readme_file = template_dir / "README.md"
if not readme_file.exists():
return None
try:
with open(readme_file, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
print(f"读取模板 {template_id} 的 README 失败: {e}")
return None
# 全局模板加载器实例
_template_loader: Optional[TemplateLoader] = None
def get_template_loader() -> TemplateLoader:
"""获取全局模板加载器实例
Returns:
TemplateLoader 实例
"""
global _template_loader
if _template_loader is None:
_template_loader = TemplateLoader()
return _template_loader
def list_workflow_templates() -> list[dict]:
"""列出所有工作流模板
Returns:
模板列表
"""
loader = get_template_loader()
return loader.list_templates()
def load_workflow_template(template_id: str) -> Optional[dict]:
"""加载工作流模板
Args:
template_id: 模板 ID
Returns:
模板配置,如果不存在则返回 None
"""
loader = get_template_loader()
return loader.load_template(template_id)
def get_workflow_template_readme(template_id: str) -> Optional[str]:
"""获取工作流模板的 README
Args:
template_id: 模板 ID
Returns:
README 内容,如果不存在则返回 None
"""
loader = get_template_loader()
return loader.get_template_readme(template_id)

View File

@@ -0,0 +1,170 @@
"""
模板渲染器
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
"""
import logging
from typing import Any
from jinja2 import Template, TemplateSyntaxError, UndefinedError, Environment, StrictUndefined
logger = logging.getLogger(__name__)
class TemplateRenderer:
"""模板渲染器"""
def __init__(self, strict: bool = True):
"""初始化渲染器
Args:
strict: 是否使用严格模式(未定义变量会抛出异常)
"""
self.env = Environment(
undefined=StrictUndefined if strict else None,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
)
def render(
self,
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
"""渲染模板
Args:
template: 模板字符串
variables: 用户定义的变量
node_outputs: 节点输出结果
system_vars: 系统变量
Returns:
渲染后的字符串
Raises:
ValueError: 模板语法错误或变量未定义
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.render(
... "Hello {{var.name}}!",
... {"name": "World"},
... {},
... {}
... )
'Hello World!'
>>> renderer.render(
... "分析结果: {{node.analyze.output}}",
... {},
... {"analyze": {"output": "正面情绪"}},
... {}
... )
'分析结果: 正面情绪'
"""
# 构建命名空间上下文
context = {
"var": variables, # 用户变量:{{var.user_input}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": system_vars or {}, # 系统变量:{{sys.execution_id}}
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
context.update(node_outputs)
# 为了向后兼容,也支持直接访问用户变量
context.update(variables)
context["nodes"] = node_outputs # 旧语法兼容
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)
except TemplateSyntaxError as e:
logger.error(f"模板语法错误: {template}, 错误: {e}")
raise ValueError(f"模板语法错误: {e}")
except UndefinedError as e:
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
except Exception as e:
logger.error(f"模板渲染异常: {template}, 错误: {e}")
raise ValueError(f"模板渲染失败: {e}")
def validate(self, template: str) -> list[str]:
"""验证模板语法
Args:
template: 模板字符串
Returns:
错误列表,如果为空则验证通过
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.validate("Hello {{var.name}}!")
[]
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
['模板语法错误: ...']
"""
errors = []
try:
self.env.from_string(template)
except TemplateSyntaxError as e:
errors.append(f"模板语法错误: {e}")
except Exception as e:
errors.append(f"模板验证失败: {e}")
return errors
# 全局渲染器实例(严格模式)
_default_renderer = TemplateRenderer(strict=True)
def render_template(
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
"""渲染模板(便捷函数)
Args:
template: 模板字符串
variables: 用户变量
node_outputs: 节点输出
system_vars: 系统变量
Returns:
渲染后的字符串
Examples:
>>> render_template(
... "请分析: {{var.text}}",
... {"text": "这是一段文本"},
... {},
... {}
... )
'请分析: 这是一段文本'
"""
return _default_renderer.render(template, variables, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:
"""验证模板语法(便捷函数)
Args:
template: 模板字符串
Returns:
错误列表
"""
return _default_renderer.validate(template)

View File

@@ -0,0 +1,277 @@
"""
工作流配置验证器
验证工作流配置的有效性,确保配置符合规范。
"""
import logging
from typing import Any, Union
logger = logging.getLogger(__name__)
class WorkflowValidator:
"""工作流配置验证器"""
@staticmethod
def validate(workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
"""验证工作流配置
Args:
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
Returns:
(is_valid, errors): 是否有效和错误列表
Examples:
>>> config = {
... "nodes": [
... {"id": "start", "type": "start"},
... {"id": "end", "type": "end"}
... ],
... "edges": [
... {"source": "start", "target": "end"}
... ]
... }
>>> is_valid, errors = WorkflowValidator.validate(config)
>>> is_valid
True
"""
errors = []
# 支持字典和 Pydantic 模型
if isinstance(workflow_config, dict):
nodes = workflow_config.get("nodes", [])
edges = workflow_config.get("edges", [])
variables = workflow_config.get("variables", [])
else:
# Pydantic 模型
nodes = getattr(workflow_config, "nodes", [])
edges = getattr(workflow_config, "edges", [])
variables = getattr(workflow_config, "variables", [])
# 1. 验证 start 节点(有且只有一个)
start_nodes = [n for n in nodes if n.get("type") == "start"]
if len(start_nodes) == 0:
errors.append("工作流必须有一个 start 节点")
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == "end"]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes]
if len(node_ids) != len(set(node_ids)):
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
# 4. 验证节点必须有 id 和 type
for i, node in enumerate(nodes):
if not node.get("id"):
errors.append(f"节点 #{i} 缺少 id 字段")
if not node.get("type"):
errors.append(f"节点 #{i} (id={node.get('id', 'unknown')}) 缺少 type 字段")
# 5. 验证边的有效性
node_id_set = set(node_ids)
for i, edge in enumerate(edges):
source = edge.get("source")
target = edge.get("target")
if not source:
errors.append(f"边 #{i} 缺少 source 字段")
elif source not in node_id_set:
errors.append(f"边 #{i} 的 source 节点不存在: {source}")
if not target:
errors.append(f"边 #{i} 缺少 target 字段")
elif target not in node_id_set:
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
)
# 8. 验证变量名
from app.core.workflow.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
return len(errors) == 0, errors
@staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点
Args:
start_id: 起始节点 ID
edges: 边列表
Returns:
可达节点 ID 集合
"""
reachable = {start_id}
queue = [start_id]
while queue:
current = queue.pop(0)
for edge in edges:
if edge.get("source") == current:
target = edge.get("target")
if target and target not in reachable:
reachable.add(target)
queue.append(target)
return reachable
@staticmethod
def _has_cycle(nodes: list[dict], edges: list[dict]) -> tuple[bool, list[str]]:
"""检测是否存在循环依赖DFS
Args:
nodes: 节点列表
edges: 边列表
Returns:
(has_cycle, cycle_path): 是否有循环和循环路径
"""
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {}
for edge in edges:
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
# 跳过错误边
if edge_type == "error":
continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target:
if source not in graph:
graph[source] = []
graph[source].append(target)
# DFS 检测环
visited = set()
rec_stack = set()
path = []
cycle_path = []
def dfs(node: str) -> bool:
"""DFS 检测环,返回是否找到环"""
visited.add(node)
rec_stack.add(node)
path.append(node)
for neighbor in graph.get(node, []):
if neighbor not in visited:
if dfs(neighbor):
return True
elif neighbor in rec_stack:
# 找到环,记录环路径
cycle_start = path.index(neighbor)
cycle_path.extend([*path[cycle_start:], neighbor])
return True
rec_stack.remove(node)
path.pop()
return False
# 检查所有节点
for node_id in graph:
if node_id not in visited:
if dfs(node_id):
return True, cycle_path
return False, []
@staticmethod
def validate_for_publish(workflow_config: dict[str, Any]) -> tuple[bool, list[str]]:
"""验证工作流配置是否可以发布(更严格的验证)
Args:
workflow_config: 工作流配置
Returns:
(is_valid, errors): 是否有效和错误列表
"""
# 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config)
if not is_valid:
return False, errors
# 额外的发布验证
nodes = workflow_config.get("nodes", [])
# 1. 验证所有节点都有名称
for node in nodes:
if node.get("type") not in ["start", "end"] and not node.get("name"):
errors.append(
f"节点 {node.get('id')} 缺少名称(发布时必须提供)"
)
# 2. 验证所有非 start/end 节点都有配置
for node in nodes:
node_type = node.get("type")
if node_type not in ["start", "end"]:
config = node.get("config")
if not config or not isinstance(config, dict):
errors.append(
f"节点 {node.get('id')} 缺少配置(发布时必须提供)"
)
# 3. 验证必填变量
variables = workflow_config.get("variables", [])
required_vars = [v for v in variables if v.get("required")]
if required_vars:
# 这里只是提示,实际执行时会检查
logger.info(
f"工作流包含 {len(required_vars)} 个必填变量: "
f"{[v.get('name') for v in required_vars]}"
)
return len(errors) == 0, errors
def validate_workflow_config(
workflow_config: dict[str, Any],
for_publish: bool = False
) -> tuple[bool, list[str]]:
"""验证工作流配置(便捷函数)
Args:
workflow_config: 工作流配置
for_publish: 是否为发布验证(更严格)
Returns:
(is_valid, errors): 是否有效和错误列表
"""
if for_publish:
return WorkflowValidator.validate_for_publish(workflow_config)
else:
return WorkflowValidator.validate(workflow_config)

View File

@@ -0,0 +1,293 @@
"""
变量池 (Variable Pool)
工作流执行的数据中心,管理所有变量的存储和访问。
变量类型:
1. 系统变量 (sys.*) - 系统内置变量execution_id, workspace_id, user_id, message 等)
2. 节点输出 (node_id.*) - 节点执行结果
3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
"""
import logging
from typing import Any
logger = logging.getLogger(__name__)
class VariableSelector:
"""变量选择器
用于引用变量的路径表示。
Examples:
>>> selector = VariableSelector(["sys", "message"])
>>> selector = VariableSelector(["node_A", "output"])
>>> selector = VariableSelector.from_string("sys.message")
"""
def __init__(self, path: list[str]):
"""初始化变量选择器
Args:
path: 变量路径,如 ["sys", "message"] 或 ["node_A", "output"]
"""
if not path or len(path) < 1:
raise ValueError("变量路径不能为空")
self.path = path
self.namespace = path[0] # sys, var, 或 node_id
self.key = path[1] if len(path) > 1 else None
@classmethod
def from_string(cls, selector_str: str) -> "VariableSelector":
"""从字符串创建选择器
Args:
selector_str: 选择器字符串,如 "sys.message""node_A.output"
Returns:
VariableSelector 实例
Examples:
>>> selector = VariableSelector.from_string("sys.message")
>>> selector = VariableSelector.from_string("llm_qa.output")
"""
path = selector_str.split(".")
return cls(path)
def __str__(self) -> str:
return ".".join(self.path)
def __repr__(self) -> str:
return f"VariableSelector({self.path})"
class VariablePool:
"""变量池
管理工作流执行过程中的所有变量。
变量命名空间:
- sys.*: 系统变量message, execution_id, workspace_id, user_id, conversation_id
- conv.*: 会话变量(跨多轮对话保持的变量)
- <node_id>.*: 节点输出
Examples:
>>> pool = VariablePool(state)
>>> pool.get(["sys", "message"])
"用户的问题"
>>> pool.get(["llm_qa", "output"])
"AI 的回答"
>>> pool.set(["conv", "user_name"], "张三")
"""
def __init__(self, state: dict[str, Any]):
"""初始化变量池
Args:
state: 工作流状态LangGraph State
"""
self.state = state
def get(self, selector: list[str] | str, default: Any = None) -> Any:
"""获取变量值
Args:
selector: 变量选择器,可以是列表或字符串
default: 默认值(变量不存在时返回)
Returns:
变量值
Examples:
>>> pool.get(["sys", "message"])
>>> pool.get("sys.message")
>>> pool.get(["llm_qa", "output"])
>>> pool.get("llm_qa.output")
Raises:
KeyError: 变量不存在且未提供默认值
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
if not selector or len(selector) < 1:
raise ValueError("变量选择器不能为空")
namespace = selector[0]
try:
# 系统变量
if namespace == "sys":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("sys", {})
return self.state.get("variables", {}).get("sys", {}).get(key, default)
# 会话变量
elif namespace == "conv":
key = selector[1] if len(selector) > 1 else None
if not key:
return self.state.get("variables", {}).get("conv", {})
return self.state.get("variables", {}).get("conv", {}).get(key, default)
# 节点输出(从 runtime_vars 读取)
else:
node_id = namespace
runtime_vars = self.state.get("runtime_vars", {})
if node_id not in runtime_vars:
if default is not None:
return default
raise KeyError(f"节点 '{node_id}' 的输出不存在")
node_var = runtime_vars[node_id]
# 如果只有节点 ID返回整个变量
if len(selector) == 1:
return node_var
# 获取特定字段
# 支持嵌套访问,如 node_id.field.subfield
result = node_var
for k in selector[1:]:
if isinstance(result, dict):
result = result.get(k)
if result is None:
if default is not None:
return default
raise KeyError(f"字段 '{'.'.join(selector)}' 不存在")
else:
if default is not None:
return default
raise KeyError(f"无法访问 '{'.'.join(selector)}'")
return result
except KeyError:
if default is not None:
return default
raise
def set(self, selector: list[str] | str, value: Any):
"""设置变量值
Args:
selector: 变量选择器
value: 变量值
Examples:
>>> pool.set(["conv", "user_name"], "张三")
>>> pool.set("conv.user_name", "张三")
Note:
- 只能设置会话变量 (conv.*)
- 系统变量和节点输出是只读的
"""
# 转换为 VariableSelector
if isinstance(selector, str):
selector = VariableSelector.from_string(selector).path
if not selector or len(selector) < 2:
raise ValueError("变量选择器必须包含命名空间和键名")
namespace = selector[0]
if namespace != "conv":
raise ValueError("只能设置会话变量 (conv.*)")
key = selector[1]
# 确保 variables 结构存在
if "variables" not in self.state:
self.state["variables"] = {"sys": {}, "conv": {}}
if "conv" not in self.state["variables"]:
self.state["variables"]["conv"] = {}
# 设置值
self.state["variables"]["conv"][key] = value
logger.debug(f"设置变量: {'.'.join(selector)} = {value}")
def has(self, selector: list[str] | str) -> bool:
"""检查变量是否存在
Args:
selector: 变量选择器
Returns:
变量是否存在
Examples:
>>> pool.has(["sys", "message"])
True
>>> pool.has("llm_qa.output")
False
"""
try:
self.get(selector)
return True
except KeyError:
return False
def get_all_system_vars(self) -> dict[str, Any]:
"""获取所有系统变量
Returns:
系统变量字典
"""
return self.state.get("variables", {}).get("sys", {})
def get_all_conversation_vars(self) -> dict[str, Any]:
"""获取所有会话变量
Returns:
会话变量字典
"""
return self.state.get("variables", {}).get("conv", {})
def get_all_node_outputs(self) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
Returns:
节点输出字典,键为节点 ID
"""
return self.state.get("runtime_vars", {})
def get_node_output(self, node_id: str) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量)
Args:
node_id: 节点 ID
Returns:
节点输出或 None
"""
return self.state.get("runtime_vars", {}).get(node_id)
def to_dict(self) -> dict[str, Any]:
"""导出为字典
Returns:
包含所有变量的字典
"""
return {
"system": self.get_all_system_vars(),
"conversation": self.get_all_conversation_vars(),
"nodes": self.get_all_node_outputs() # 从 runtime_vars 读取
}
def __repr__(self) -> str:
sys_vars = self.get_all_system_vars()
conv_vars = self.get_all_conversation_vars()
runtime_vars = self.get_all_node_outputs()
return (
f"VariablePool(\n"
f" system_vars={len(sys_vars)},\n"
f" conversation_vars={len(conv_vars)},\n"
f" runtime_vars={len(runtime_vars)}\n"
f")"
)

View File

@@ -1,10 +1,9 @@
import os
import subprocess
from dotenv import load_dotenv
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from app.core.config import settings
from contextlib import asynccontextmanager
from fastapi.responses import JSONResponse
from app.core.response_utils import fail
from app.core.logging_config import LoggingConfig, get_logger
@@ -38,9 +37,13 @@ router = APIRouter(prefix="/memory", tags=["Memory"])
# 管理端 API (JWT 认证)
from app.controllers import manager_router
# 服务端 API (API Key 认证)
from app.controllers.service import service_router
from app.core.config import settings
from app.core.error_codes import BizCode, HTTP_MAPPING
from app.core.exceptions import BusinessException
from app.core.logging_config import LoggingConfig, get_logger
from app.core.response_utils import fail
# Initialize logging system
LoggingConfig.setup_logging()
@@ -414,5 +417,4 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -15,9 +15,11 @@ from .end_user_model import EndUser
from .appshare_model import AppShare
from .release_share_model import ReleaseShare
from .conversation_model import Conversation, Message
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType, ResourceType
from .api_key_model import ApiKey, ApiKeyLog, ApiKeyType
from .data_config_model import DataConfig
from .multi_agent_model import MultiAgentConfig, AgentInvocation
from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution
from .retrieval_info import RetrievalInfo
__all__ = [
"Tenants",
@@ -46,8 +48,11 @@ __all__ = [
"ApiKey",
"ApiKeyLog",
"ApiKeyType",
"ResourceType",
"DataConfig",
"MultiAgentConfig",
"AgentInvocation"
"AgentInvocation",
"WorkflowConfig",
"WorkflowExecution",
"WorkflowNodeExecution",
"RetrievalInfo"
]

View File

@@ -2,7 +2,7 @@
import datetime
import uuid
from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text, Enum
from sqlalchemy import Column, String, Boolean, DateTime, Integer, ForeignKey, Text
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
from enum import StrEnum
@@ -12,18 +12,10 @@ from app.db import Base
class ApiKeyType(StrEnum):
"""API Key 类型"""
APP = "app" # 应用 API Key
RAG = "rag" # RAG API Key
MEMORY = "memory" # Memory API Key
class ResourceType(StrEnum):
"""资源类型枚举"""
AGENT = "Agent" # 智能体
CLUSTER = "Cluster" # 集群
WORKFLOW = "Workflow" # 工作流
KNOWLEDGE = "Knowledge" # 知识库
MEMORY_ENGINE = "Memory_Engine" # 记忆引擎
AGENT = "agent" # 智能体
CLUSTER = "multi_agent" # 集群
WORKFLOW = "workflow" # 工作流
SERVICE = "service" # 服务
class ApiKey(Base):
@@ -35,18 +27,16 @@ class ApiKey(Base):
# 基本信息
name = Column(String(255), nullable=False, comment="API Key 名称")
description = Column(Text, comment="描述")
key_prefix = Column(String(20), nullable=False, comment="Key 前缀")
key_hash = Column(String(255), nullable=False, unique=True, index=True, comment="Key 哈希值")
api_key = Column(String(255), nullable=False, unique=True, index=True, comment="API Key 明文")
# 类型和权限
type = Column(String(50), nullable=False, index=True, comment="API Key 类型")
scopes = Column(JSONB, nullable=False, default=list, comment="权限范围列表")
scopes = Column(JSONB, default=list, comment="权限范围列表")
# 关联资源
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False,
index=True, comment="所属工作空间")
resource_id = Column(UUID(as_uuid=True), index=True, comment="关联资源ID")
resource_type = Column(String(50), comment="资源类型")
# 限制和配额
rate_limit = Column(Integer, default=10, comment="QPS限制请求/秒)")

View File

@@ -86,6 +86,14 @@ class App(Base):
uselist=False,
cascade="all, delete-orphan",
)
# 一对一:工作流配置(仅当 type=workflow 时有效)
workflow_config = relationship(
"WorkflowConfig",
back_populates="app",
uselist=False,
cascade="all, delete-orphan",
)
# 发布版本关联
current_release = relationship("AppRelease", foreign_keys=[current_release_id])

View File

@@ -61,7 +61,7 @@ class ModelConfig(Base):
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
# 关联关系
api_keys = relationship("ModelApiKey", back_populates="model_config", cascade="all, delete-orphan")

View File

@@ -0,0 +1,196 @@
"""
工作流相关数据模型
"""
import datetime
import uuid
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, ForeignKey, Text
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.orm import relationship
from app.db import Base
class WorkflowConfig(Base):
"""工作流配置表"""
__tablename__ = "workflow_configs"
# 主键
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
# 关联应用(一对一)
app_id = Column(
UUID(as_uuid=True),
ForeignKey("apps.id", ondelete="CASCADE"),
nullable=False,
unique=True,
index=True
)
# 节点和边的定义JSON 格式)
nodes = Column(JSONB, nullable=False, default=list)
edges = Column(JSONB, nullable=False, default=list)
# 全局变量定义
variables = Column(JSONB, default=list)
# 执行配置
execution_config = Column(JSONB, nullable=False, default=dict)
# 触发器配置(可选)
triggers = Column(JSONB, default=list)
# 状态
is_active = Column(Boolean, nullable=False, default=True)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
updated_at = Column(
DateTime,
nullable=False,
default=datetime.datetime.now,
onupdate=datetime.datetime.now
)
# 关系
app = relationship("App", back_populates="workflow_config")
executions = relationship(
"WorkflowExecution",
back_populates="workflow_config",
cascade="all, delete-orphan"
)
def __repr__(self):
return f"<WorkflowConfig(id={self.id}, app_id={self.app_id})>"
class WorkflowExecution(Base):
"""工作流执行记录表"""
__tablename__ = "workflow_executions"
# 主键
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
# 关联信息
workflow_config_id = Column(
UUID(as_uuid=True),
ForeignKey("workflow_configs.id", ondelete="CASCADE"),
nullable=False,
index=True
)
app_id = Column(
UUID(as_uuid=True),
ForeignKey("apps.id", ondelete="CASCADE"),
nullable=False,
index=True
)
conversation_id = Column(
UUID(as_uuid=True),
ForeignKey("conversations.id", ondelete="SET NULL"),
nullable=True,
index=True
)
# 执行信息
execution_id = Column(String(100), nullable=False, unique=True, index=True)
trigger_type = Column(String(20), nullable=False) # manual, schedule, webhook, event
triggered_by = Column(
UUID(as_uuid=True),
ForeignKey("users.id"),
nullable=True
)
# 输入输出
input_data = Column(JSONB)
output_data = Column(JSONB)
context = Column(JSONB, default=dict)
# 状态
status = Column(String(20), nullable=False, default="pending", index=True)
# 可选值pending, running, completed, failed, cancelled, timeout
error_message = Column(Text)
error_node_id = Column(String(100))
# 性能指标
started_at = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
completed_at = Column(DateTime)
elapsed_time = Column(Float) # 耗时(秒)
# 资源使用
token_usage = Column(JSONB)
# 元数据(使用 meta_data 避免与 SQLAlchemy 保留字 metadata 冲突)
meta_data = Column(JSONB, default=dict)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
# 关系
workflow_config = relationship("WorkflowConfig", back_populates="executions")
app = relationship("App")
conversation = relationship("Conversation")
triggered_by_user = relationship("User", foreign_keys=[triggered_by])
node_executions = relationship(
"WorkflowNodeExecution",
back_populates="execution",
cascade="all, delete-orphan",
order_by="WorkflowNodeExecution.execution_order"
)
def __repr__(self):
return f"<WorkflowExecution(id={self.id}, execution_id={self.execution_id}, status={self.status})>"
class WorkflowNodeExecution(Base):
"""工作流节点执行记录表"""
__tablename__ = "workflow_node_executions"
# 主键
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
# 关联执行
execution_id = Column(
UUID(as_uuid=True),
ForeignKey("workflow_executions.id", ondelete="CASCADE"),
nullable=False,
index=True
)
# 节点信息
node_id = Column(String(100), nullable=False, index=True)
node_type = Column(String(20), nullable=False)
node_name = Column(String(100))
# 执行顺序
execution_order = Column(Integer, nullable=False)
retry_count = Column(Integer, nullable=False, default=0)
# 输入输出
input_data = Column(JSONB)
output_data = Column(JSONB)
# 状态
status = Column(String(20), nullable=False, default="pending", index=True)
# 可选值pending, running, completed, failed, skipped, cached
error_message = Column(Text)
# 性能指标
started_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
completed_at = Column(DateTime)
elapsed_time = Column(Float) # 耗时(秒)
# 资源使用(针对 LLM 节点)
token_usage = Column(JSONB)
# 缓存信息
cache_hit = Column(Boolean, default=False)
cache_key = Column(String(255))
# 元数据(使用 meta_data 避免与 SQLAlchemy 保留字 metadata 冲突)
meta_data = Column(JSONB, default=dict)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
# 关系
execution = relationship("WorkflowExecution", back_populates="node_executions")
def __repr__(self):
return f"<WorkflowNodeExecution(id={self.id}, node_id={self.node_id}, status={self.status})>"

View File

@@ -27,9 +27,9 @@ class ApiKeyRepository:
return db.get(ApiKey, api_key_id)
@staticmethod
def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]:
"""根据哈希值获取 API Key"""
stmt = select(ApiKey).where(ApiKey.key_hash == key_hash)
def get_by_api_key(db: Session, api_key: str) -> Optional[ApiKey]:
"""根据 API Key 获取 API Key"""
stmt = select(ApiKey).where(ApiKey.api_key == api_key)
return db.scalars(stmt).first()
@staticmethod
@@ -63,11 +63,15 @@ class ApiKeyRepository:
@staticmethod
def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey | None:
"""更新 API Key"""
allow_none_fields = {"description", "quota_limit", "expires_at"}
api_key = db.get(ApiKey, api_key_id)
if api_key:
for key, value in update_data.items():
if value is not None:
if key in allow_none_fields:
setattr(api_key, key, value)
else:
if value is not None:
setattr(api_key, key, value)
db.flush()
return api_key
@@ -122,6 +126,7 @@ class ApiKeyRepository:
"quota_used": api_key.quota_used,
"quota_limit": api_key.quota_limit,
"last_used_at": api_key.last_used_at,
"rate_limit": api_key.rate_limit,
"avg_response_time": float(avg_response_time) if avg_response_time else None
}

View File

@@ -14,7 +14,7 @@ class AppRepository:
def __init__(self, db: Session):
self.db = db
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> List[App]:
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]:
"""根据工作空间ID查询应用"""
try:
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all()
@@ -24,7 +24,19 @@ class AppRepository:
db_logger.error(f"查询工作空间 {workspace_id} 下应用时出错: {str(e)}")
raise
def get_apps_by_id(self, app_id: uuid.UUID) -> App:
try:
app = self.db.query(App).filter(App.id == app_id, App.is_active == True).first()
return app
except Exception as e:
raise
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
"""根据工作空间ID查询应用"""
repo = AppRepository(db)
return repo.get_apps_by_workspace_id(workspace_id)
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
"""根据工作空间ID查询应用"""
repo = AppRepository(db)
return repo.get_apps_by_id(app_id)

View File

@@ -0,0 +1,247 @@
"""
工作流数据访问层
"""
import uuid
from typing import Any, Annotated
from sqlalchemy.orm import Session
from sqlalchemy import desc
from fastapi import Depends
from app.models.workflow_model import (
WorkflowConfig,
WorkflowExecution,
WorkflowNodeExecution
)
from app.db import get_db
class WorkflowConfigRepository:
"""工作流配置仓储"""
def __init__(self, db: Session):
self.db = db
def get_by_app_id(self, app_id: uuid.UUID) -> WorkflowConfig | None:
"""根据应用 ID 获取工作流配置
Args:
app_id: 应用 ID
Returns:
工作流配置或 None
"""
return self.db.query(WorkflowConfig).filter(
WorkflowConfig.app_id == app_id,
WorkflowConfig.is_active == True
).first()
def create_or_update(
self,
app_id: uuid.UUID,
nodes: list[dict[str, Any]],
edges: list[dict[str, Any]],
variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None
) -> WorkflowConfig:
"""创建或更新工作流配置
Args:
app_id: 应用 ID
nodes: 节点列表
edges: 边列表
variables: 变量列表
execution_config: 执行配置
triggers: 触发器列表
Returns:
工作流配置
"""
# 查找现有配置
existing = self.get_by_app_id(app_id)
if existing:
# 更新现有配置
existing.nodes = nodes
existing.edges = edges
if variables is not None:
existing.variables = variables
if execution_config is not None:
existing.execution_config = execution_config
if triggers is not None:
existing.triggers = triggers
self.db.commit()
self.db.refresh(existing)
return existing
else:
# 创建新配置
config = WorkflowConfig(
app_id=app_id,
nodes=nodes,
edges=edges,
variables=variables or [],
execution_config=execution_config or {},
triggers=triggers or []
)
self.db.add(config)
self.db.commit()
self.db.refresh(config)
return config
class WorkflowExecutionRepository:
"""工作流执行记录仓储"""
def __init__(self, db: Session):
self.db = db
def get_by_execution_id(self, execution_id: str) -> WorkflowExecution | None:
"""根据执行 ID 获取执行记录
Args:
execution_id: 执行 ID
Returns:
执行记录或 None
"""
return self.db.query(WorkflowExecution).filter(
WorkflowExecution.execution_id == execution_id
).first()
def get_by_app_id(
self,
app_id: uuid.UUID,
limit: int = 50,
offset: int = 0
) -> list[WorkflowExecution]:
"""根据应用 ID 获取执行记录列表
Args:
app_id: 应用 ID
limit: 返回数量限制
offset: 偏移量
Returns:
执行记录列表
"""
return self.db.query(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id
).order_by(
desc(WorkflowExecution.started_at)
).limit(limit).offset(offset).all()
def get_by_conversation_id(
self,
conversation_id: uuid.UUID
) -> list[WorkflowExecution]:
"""根据会话 ID 获取执行记录列表
Args:
conversation_id: 会话 ID
Returns:
执行记录列表
"""
return self.db.query(WorkflowExecution).filter(
WorkflowExecution.conversation_id == conversation_id
).order_by(
desc(WorkflowExecution.started_at)
).all()
def count_by_app_id(self, app_id: uuid.UUID) -> int:
"""统计应用的执行次数
Args:
app_id: 应用 ID
Returns:
执行次数
"""
return self.db.query(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id
).count()
def count_by_status(self, app_id: uuid.UUID, status: str) -> int:
"""统计指定状态的执行次数
Args:
app_id: 应用 ID
status: 状态
Returns:
执行次数
"""
return self.db.query(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id,
WorkflowExecution.status == status
).count()
class WorkflowNodeExecutionRepository:
"""工作流节点执行记录仓储"""
def __init__(self, db: Session):
self.db = db
def get_by_execution_id(
self,
execution_id: uuid.UUID
) -> list[WorkflowNodeExecution]:
"""根据执行 ID 获取节点执行记录列表
Args:
execution_id: 执行 ID
Returns:
节点执行记录列表(按执行顺序排序)
"""
return self.db.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id
).order_by(
WorkflowNodeExecution.execution_order
).all()
def get_by_node_id(
self,
execution_id: uuid.UUID,
node_id: str
) -> list[WorkflowNodeExecution]:
"""根据节点 ID 获取节点执行记录(可能有多次重试)
Args:
execution_id: 执行 ID
node_id: 节点 ID
Returns:
节点执行记录列表
"""
return self.db.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id,
WorkflowNodeExecution.node_id == node_id
).order_by(
WorkflowNodeExecution.retry_count
).all()
# ==================== 依赖注入函数 ====================
def get_workflow_config_repository(
db: Annotated[Session, Depends(get_db)]
) -> WorkflowConfigRepository:
"""获取工作流配置仓储(依赖注入)"""
return WorkflowConfigRepository(db)
def get_workflow_execution_repository(
db: Annotated[Session, Depends(get_db)]
) -> WorkflowExecutionRepository:
"""获取工作流执行记录仓储(依赖注入)"""
return WorkflowExecutionRepository(db)
def get_workflow_node_execution_repository(
db: Annotated[Session, Depends(get_db)]
) -> WorkflowNodeExecutionRepository:
"""获取工作流节点执行记录仓储(依赖注入)"""
return WorkflowNodeExecutionRepository(db)

View File

@@ -1,11 +1,11 @@
"""API Key Schema"""
import datetime
import uuid
from pydantic import BaseModel, Field, ConfigDict
from pydantic.v1 import validator
from pydantic import BaseModel, Field, ConfigDict, field_validator, field_serializer, computed_field
from typing import Optional, List
from app.models.api_key_model import ApiKeyType, ResourceType
from app.models.api_key_model import ApiKeyType
from app.core.api_key_utils import timestamp_to_datetime, datetime_to_timestamp
class ApiKeyCreate(BaseModel):
@@ -15,20 +15,34 @@ class ApiKeyCreate(BaseModel):
type: ApiKeyType = Field(..., description="API Key 类型")
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
resource_type: Optional[ResourceType] = Field(None, description="资源类型")
rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制请求/秒)")
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
@validator('scopes')
@computed_field
@property
def is_expired(self) -> bool:
"""检查API Key是否已过期"""
if not self.expires_at:
return False
return datetime.datetime.now() > self.expires_at
@field_validator('expires_at', mode='before')
@classmethod
def parse_expires_at(cls, v):
"""将时间戳转换为datetime"""
if isinstance(v, (int, float)):
return timestamp_to_datetime(v)
return v
@field_validator('scopes')
@classmethod
def validate_scopes(cls, v):
"""验证权限范围格式"""
valid_scopes = [
"app:all",
"rag:search", "rag:upload", "rag:delete",
"memory:read", "memory:write", "memory:delete", "memory:search"
]
if v is None:
return []
valid_scopes = ["app", "rag", "memory"]
for scope in v:
if scope not in valid_scopes:
raise ValueError(f"无效范围: {scope}")
@@ -46,14 +60,29 @@ class ApiKeyUpdate(BaseModel):
is_active: Optional[bool] = Field(None, description="是否激活")
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
@validator('scopes')
@computed_field
@property
def is_expired(self) -> bool:
"""检查API Key是否已过期"""
if not self.expires_at:
return False
return datetime.datetime.now() > self.expires_at
@field_validator('expires_at', mode='before')
@classmethod
def parse_expires_at(cls, v):
"""将时间戳转换为datetime"""
if isinstance(v, (int, float)):
return timestamp_to_datetime(v)
return v
@field_validator('scopes')
@classmethod
def validate_scopes(cls, v):
"""验证权限范围格式"""
valid_scopes = {
'app:all',
'rag:search', 'rag:upload', 'rag:delete',
'memory:read', 'memory:write', 'memory:delete', 'memory:search'
}
if v is None:
return v
valid_scopes = ["app", "rag", "memory"]
for scope in v:
if scope not in valid_scopes:
raise ValueError(f"无效范围: {scope}")
@@ -67,18 +96,31 @@ class ApiKeyResponse(BaseModel):
id: uuid.UUID
name: str
description: Optional[str]
api_key: str = Field(..., description="API Key 明文(仅创建时返回)")
key_prefix: str
api_key: str
type: str
scopes: List[str]
resource_id: Optional[uuid.UUID]
resource_type: Optional[str]
rate_limit: int
daily_request_limit: int
quota_limit: Optional[int]
is_active: bool
expires_at: Optional[datetime.datetime]
created_at: datetime.datetime
@computed_field
@property
def is_expired(self) -> bool:
"""检查API Key是否已过期"""
if not self.expires_at:
return False
return datetime.datetime.now() > self.expires_at
@field_serializer('expires_at', 'created_at')
@classmethod
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
class ApiKey(BaseModel):
"""API Key 信息(不包含明文 Key"""
@@ -87,11 +129,10 @@ class ApiKey(BaseModel):
id: uuid.UUID
name: str
description: Optional[str]
key_prefix: str
api_key: str
type: str
scopes: List[str]
resource_id: Optional[uuid.UUID]
resource_type: Optional[str]
rate_limit: int
daily_request_limit: int
quota_limit: Optional[int]
@@ -105,6 +146,20 @@ class ApiKey(BaseModel):
created_at: datetime.datetime
updated_at: datetime.datetime
@computed_field
@property
def is_expired(self) -> bool:
"""检查API Key是否已过期"""
if not self.expires_at:
return False
return datetime.datetime.now() > self.expires_at
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
@classmethod
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
class ApiKeyStats(BaseModel):
"""API Key 使用统计"""
@@ -115,6 +170,12 @@ class ApiKeyStats(BaseModel):
last_used_at: Optional[datetime.datetime] = Field(None, description="最后使用时间")
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
@field_serializer('last_used_at')
@classmethod
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
class ApiKeyQuery(BaseModel):
"""API Key 查询参数"""
@@ -132,7 +193,6 @@ class ApiKeyAuth(BaseModel):
type: str
scopes: List[str]
resource_id: Optional[uuid.UUID]
resource_type: Optional[str]
class ApiKeyLog(BaseModel):
@@ -157,3 +217,9 @@ class ApiKeyLog(BaseModel):
# 时间信息
created_at: datetime.datetime
@field_serializer('created_at')
@classmethod
def serialize_datetime(cls, v: datetime.datetime) -> int:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)

View File

@@ -46,6 +46,7 @@ class ConflictResultSchema(BaseModel):
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
@model_validator(mode="before")
@classmethod
def _normalize_data(cls, v):
if isinstance(v, dict):
d = v.get("data")
@@ -60,6 +61,7 @@ class ConflictSchema(BaseModel):
conflict_memory: Optional[BaseDataSchema] = Field(None, description="The conflict memory data.")
@model_validator(mode="before")
@classmethod
def _normalize_data(cls, v):
if isinstance(v, dict):
d = v.get("data")
@@ -88,6 +90,7 @@ class ReflexionResultSchema(BaseModel):
resolved: Optional[ResolvedSchema] = Field(None, description="The resolved memory data.")
@model_validator(mode="before")
@classmethod
def _normalize_resolved(cls, v):
if isinstance(v, dict):
conflict = v.get("conflict")
@@ -311,7 +314,7 @@ class ApiResponse(BaseModel): # 通用API响应模型
def _now_ms() -> int:
return int(round(time.time() * 1000))
return round(time.time() * 1000)
def ok(msg: str = "OK", data: Optional[Any] = None, time: Optional[int] = None) -> ApiResponse:

View File

@@ -0,0 +1,215 @@
"""
工作流相关的 Pydantic Schema
"""
import datetime
import uuid
from typing import Any
from pydantic import BaseModel, Field, ConfigDict, field_serializer
# ==================== 节点和边定义 ====================
class NodeConfig(BaseModel):
"""节点配置"""
model_config = ConfigDict(extra="allow") # 允许额外字段
class NodeDefinition(BaseModel):
"""节点定义"""
id: str = Field(..., description="节点唯一标识")
type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code")
name: str | None = Field(None, description="节点名称")
description: str | None = Field(None, description="节点描述")
config: dict[str, Any] = Field(default_factory=dict, description="节点配置")
position: dict[str, float] | None = Field(None, description="节点位置 {x, y}")
error_handling: dict[str, Any] | None = Field(None, description="错误处理配置")
cache: dict[str, Any] | None = Field(None, description="缓存配置")
class EdgeDefinition(BaseModel):
"""边定义"""
id: str | None = Field(None, description="边唯一标识(可选)")
source: str = Field(..., description="源节点 ID")
target: str = Field(..., description="目标节点 ID")
type: str | None = Field(None, description="边类型: normal, error")
condition: str | None = Field(None, description="条件表达式(条件边)")
label: str | None = Field(None, description="边标签")
class VariableDefinition(BaseModel):
"""变量定义"""
name: str = Field(..., description="变量名称")
type: str = Field(default="string", description="变量类型: string, number, boolean, object, array")
required: bool = Field(default=False, description="是否必填")
default: Any = Field(None, description="默认值")
description: str | None = Field(None, description="变量描述")
class ExecutionConfig(BaseModel):
"""执行配置"""
max_iterations: int = Field(default=100, ge=1, le=1000, description="最大迭代次数")
timeout: int = Field(default=600, ge=10, le=3600, description="全局超时时间(秒)")
enable_cache: bool = Field(default=True, description="是否启用节点缓存")
parallel_limit: int = Field(default=5, ge=1, le=20, description="并行执行限制")
class TriggerConfig(BaseModel):
"""触发器配置"""
type: str = Field(..., description="触发器类型: schedule, webhook, event")
config: dict[str, Any] = Field(default_factory=dict, description="触发器配置")
# ==================== 工作流配置 ====================
class WorkflowConfigCreate(BaseModel):
"""创建工作流配置"""
nodes: list[NodeDefinition] = Field(default_factory=list, description="节点列表")
edges: list[EdgeDefinition] = Field(default_factory=list, description="边列表")
variables: list[VariableDefinition] = Field(default_factory=list, description="变量列表")
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
triggers: list[TriggerConfig] = Field(default_factory=list, description="触发器列表")
class WorkflowConfigUpdate(BaseModel):
"""更新工作流配置"""
nodes: list[NodeDefinition] | None = None
edges: list[EdgeDefinition] | None = None
variables: list[VariableDefinition] | None = None
execution_config: ExecutionConfig | None = None
triggers: list[TriggerConfig] | None = None
class WorkflowConfig(BaseModel):
"""工作流配置输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
app_id: uuid.UUID
nodes: list[dict[str, Any]]
edges: list[dict[str, Any]]
variables: list[dict[str, Any]]
execution_config: dict[str, Any]
triggers: list[dict[str, Any]]
is_active: bool
created_at: datetime.datetime
updated_at: datetime.datetime
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
# ==================== 工作流执行 ====================
class WorkflowExecutionRequest(BaseModel):
"""工作流执行请求"""
message: str | None = Field(None, description="用户消息(可选)")
variables: dict[str, Any] = Field(default_factory=dict, description="输入变量")
conversation_id: str | None = Field(None, description="会话 ID用于关联对话")
stream: bool = Field(default=False, description="是否流式返回")
class WorkflowExecutionResponse(BaseModel):
"""工作流执行响应(非流式)"""
execution_id: str = Field(..., description="执行 ID")
status: str = Field(..., description="执行状态")
output: str | None = Field(None, description="最终输出(字符串,便于快速访问)")
output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据")
error_message: str | None = Field(None, description="错误信息")
elapsed_time: float | None = Field(None, description="耗时(秒)")
token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}")
class WorkflowExecutionStreamChunk(BaseModel):
"""工作流执行流式响应块"""
type: str = Field(..., description="事件类型: node_start, token, node_complete, error_redirect, workflow_complete")
execution_id: str = Field(..., description="执行 ID")
data: dict[str, Any] = Field(default_factory=dict, description="事件数据")
class WorkflowExecution(BaseModel):
"""工作流执行记录输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
workflow_config_id: uuid.UUID
app_id: uuid.UUID
conversation_id: uuid.UUID | None
execution_id: str
trigger_type: str
triggered_by: uuid.UUID | None
input_data: dict[str, Any] | None
output_data: dict[str, Any] | None
context: dict[str, Any]
status: str
error_message: str | None
error_node_id: str | None
started_at: datetime.datetime
completed_at: datetime.datetime | None
elapsed_time: float | None
token_usage: dict[str, Any] | None
meta_data: dict[str, Any]
created_at: datetime.datetime
@field_serializer("started_at", when_used="json")
def _serialize_started_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("completed_at", when_used="json")
def _serialize_completed_at(self, dt: datetime.datetime | None):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
class WorkflowNodeExecution(BaseModel):
"""工作流节点执行记录输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
execution_id: uuid.UUID
node_id: str
node_type: str
node_name: str | None
execution_order: int
retry_count: int
input_data: dict[str, Any] | None
output_data: dict[str, Any] | None
status: str
error_message: str | None
started_at: datetime.datetime
completed_at: datetime.datetime | None
elapsed_time: float | None
token_usage: dict[str, Any] | None
cache_hit: bool
cache_key: str | None
meta_data: dict[str, Any]
created_at: datetime.datetime
@field_serializer("started_at", when_used="json")
def _serialize_started_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("completed_at", when_used="json")
def _serialize_completed_at(self, dt: datetime.datetime | None):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
# ==================== 验证响应 ====================
class WorkflowValidationResponse(BaseModel):
"""工作流验证响应"""
is_valid: bool = Field(..., description="是否有效")
errors: list[str] = Field(default_factory=list, description="错误列表")
warnings: list[str] = Field(default_factory=list, description="警告列表")

View File

@@ -13,7 +13,7 @@ from app.models.api_key_model import ApiKey
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
from app.schemas import api_key_schema
from app.schemas.response_schema import PageData, PageMeta
from app.core.api_key_utils import generate_api_key, hash_api_key, validate_resource_binding
from app.core.api_key_utils import generate_api_key
from app.core.exceptions import (
BusinessException,
)
@@ -33,48 +33,39 @@ class ApiKeyService:
workspace_id: uuid.UUID,
user_id: uuid.UUID,
data: api_key_schema.ApiKeyCreate
) -> Tuple[ApiKey, str]:
) -> ApiKey:
"""
创建 API Key
Returns:
Tuple[ApiKey, str]: (API Key 对象, API Key 明文)
ApiKey: API Key 对象
"""
try:
# 验证资源绑定
if data.resource_type or data.resource_id:
is_valid, error_msg = validate_resource_binding(
data.resource_type, str(data.resource_id) if data.resource_id else None
)
if not is_valid:
raise BusinessException(error_msg, BizCode.API_KEY_INVALID_RESOURCE)
existing = db.scalar(
select(ApiKey).where(
ApiKey.workspace_id == workspace_id,
ApiKey.resource_id == data.resource_id,
ApiKey.name == data.name,
ApiKey.is_active
)
)
if existing:
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 生成 API Key
api_key, key_hash, key_prefix = generate_api_key(data.type)
api_key = generate_api_key(data.type)
# 创建数据
api_key_data = {
"id": uuid.uuid4(),
"name": data.name,
"description": data.description,
"key_prefix": key_prefix,
"key_hash": key_hash,
"api_key": api_key,
"type": data.type,
"scopes": data.scopes,
"workspace_id": workspace_id,
"resource_id": data.resource_id,
"resource_type": data.resource_type,
"rate_limit": data.rate_limit or 10,
"daily_request_limit": data.daily_request_limit or 10000,
"rate_limit": data.rate_limit,
"daily_request_limit": data.daily_request_limit,
"quota_limit": data.quota_limit,
"expires_at": data.expires_at,
"created_by": user_id,
@@ -90,7 +81,7 @@ class ApiKeyService:
"type": data.type
})
return api_key_obj, api_key
return api_key_obj
except Exception as e:
db.rollback()
@@ -152,13 +143,14 @@ class ApiKeyService:
existing = db.scalar(
select(ApiKey).where(
ApiKey.workspace_id == workspace_id,
ApiKey.resource_id == data.resource_id,
ApiKey.name == data.name,
ApiKey.is_active,
ApiKey.id != api_key_id
)
)
if existing:
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
update_data = data.model_dump(exclude_unset=True)
ApiKeyRepository.update(db, api_key_id, update_data)
@@ -188,7 +180,7 @@ class ApiKeyService:
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
) -> Tuple[ApiKey, str]:
) -> ApiKey:
"""重新生成 API Key"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
@@ -197,18 +189,17 @@ class ApiKeyService:
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
# 生成新的 API Key
new_api_key, key_hash, key_prefix = generate_api_key(api_key_schema.ApiKeyType(api_key.type))
new_api_key = generate_api_key(api_key.type)
# 更新
ApiKeyRepository.update(db, api_key_id, {
"key_hash": key_hash,
"key_prefix": key_prefix
"api_key": new_api_key
})
db.commit()
db.refresh(api_key)
logger.info("API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
return api_key, new_api_key
return api_key
@staticmethod
def get_stats(
@@ -330,7 +321,6 @@ class RateLimiterService:
"X-RateLimit-Reset": str(qps_info["reset"])
}
# Check daily requests
daily_ok, daily_info = await self.check_daily_requests(
api_key.id,
api_key.daily_request_limit
@@ -342,7 +332,6 @@ class RateLimiterService:
"X-RateLimit-Reset": str(daily_info["reset"])
}
# All checks passed
headers = {
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
@@ -363,13 +352,12 @@ class ApiKeyAuthService:
验证API Key 有效性
检查:
1. Key hash 是否存在
1. API Key 是否存在
2. is_active 是否为true
3. expires_at 是否未过期
4. quota 是否未超限
"""
key_hash = hash_api_key(api_key)
api_key_obj = ApiKeyRepository.get_by_hash(db, key_hash)
api_key_obj = ApiKeyRepository.get_by_api_key(db, api_key)
if not api_key_obj:
return None
@@ -393,14 +381,7 @@ class ApiKeyAuthService:
@staticmethod
def check_resource(
api_key: ApiKey,
resource_type: str,
resource_id: uuid.UUID
) -> bool:
"""检查资源绑定"""
if not api_key.resource_id:
return True
return (
api_key.resource_type == resource_type and
api_key.resource_id == resource_id
)
return api_key.resource_id == resource_id

File diff suppressed because it is too large Load Diff

View File

@@ -4,9 +4,12 @@ Memory Storage Service
Handles business logic for memory storage operations.
"""
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, AsyncGenerator
import os
import json
import asyncio
import time
from datetime import datetime
from sqlalchemy.orm import Session
from dotenv import load_dotenv
@@ -14,6 +17,7 @@ from dotenv import load_dotenv
from app.models.user_model import User
from app.models.end_user_model import EndUser
from app.core.logging_config import get_logger
from app.utils.sse_utils import format_sse_message
from app.schemas.memory_storage_schema import (
ConfigFilter,
ConfigPilotRun,
@@ -225,101 +229,175 @@ class DataConfigService: # 数据配置服务类PostgreSQL
return self._convert_timestamps_to_format(data_list)
async def pilot_run(self, payload: ConfigPilotRun) -> Dict[str, Any]:
async def pilot_run_stream(self, payload: ConfigPilotRun) -> AsyncGenerator[str, None]:
"""
选择策略与内存覆写与同步版保持一致:优先 payload.config_id其次 dbrun.json两者皆无时报错。
支持 dialogue_text 参数用于试运行模式。
流式执行试运行,产生 SSE 格式的进度事件
Args:
payload: 试运行配置和对话文本
Yields:
SSE 格式的字符串,包含以下事件类型:
- 各种阶段名称: 进度更新 (如 starting, knowledge_extraction_complete 等)
- result: 最终结果
- error: 错误信息
- done: 完成标记
Raises:
ValueError: 当配置无效或参数缺失时
RuntimeError: 当管线执行失败时
"""
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
try:
# 发出初始进度事件
yield format_sse_message("starting", {
"message": "开始试运行...",
"time": int(time.time() * 1000)
})
# 步骤 1: 配置加载和验证(复用现有逻辑)
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
cid: Optional[str] = payload_cid if payload_cid else None
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
cid: Optional[str] = payload_cid if payload_cid else None
if not cid and os.path.isfile(dbrun_path):
try:
with open(dbrun_path, "r", encoding="utf-8") as f:
dbrun = json.load(f)
if isinstance(dbrun, dict):
sel = dbrun.get("selections", {})
if isinstance(sel, dict):
fallback_cid = str(sel.get("config_id") or "").strip()
cid = fallback_cid or None
except Exception:
cid = None
if not cid and os.path.isfile(dbrun_path):
try:
with open(dbrun_path, "r", encoding="utf-8") as f:
dbrun = json.load(f)
if isinstance(dbrun, dict):
sel = dbrun.get("selections", {})
if isinstance(sel, dict):
fallback_cid = str(sel.get("config_id") or "").strip()
cid = fallback_cid or None
except Exception:
cid = None
if not cid:
raise ValueError("未提供 payload.config_id且 dbrun.json 未设置 selections.config_id禁止启动试运行")
if not cid:
raise ValueError("未提供 payload.config_id且 dbrun.json 未设置 selections.config_id禁止启动试运行")
# 验证 dialogue_text 必须提供
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN_STREAM] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
if not dialogue_text:
raise ValueError("试运行模式必须提供 dialogue_text 参数")
# 验证 dialogue_text 必须提供
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
if not dialogue_text:
raise ValueError("试运行模式必须提供 dialogue_text 参数")
# 应用内存覆写并刷新常量
from app.core.memory.utils.config.definitions import reload_configuration_from_database
ok_override = reload_configuration_from_database(cid)
if not ok_override:
raise RuntimeError("运行时覆写失败config_id 无效或刷新常量失败")
# 应用内存覆写并刷新常量(在导入主管线前)
# 注意:仅在内存中覆写配置,不修改 runtime.json 文
from app.core.memory.utils.config.definitions import reload_configuration_from_database
ok_override = reload_configuration_from_database(cid)
if not ok_override:
raise RuntimeError("运行时覆写失败config_id 无效或刷新常量失败")
# 导入并 await 主管线(使用当前 ASGI 事件循环)
from app.core.memory.main import main as pipeline_main
from app.core.memory.utils.self_reflexion_utils import reflexion
logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
logger.info("[PILOT_RUN] pipeline_main completed")
# 调用自我反思
# data = [
# {
# "data": {
# "id": "1",
# "statement": "张明现在在谷歌工作。",
# "group_id": "1",
# "chunk_id": "10",
# "created_at": "2023-01-01",
# "expired_at": "2023-01-02",
# "valid_at": "2023-01-01",
# "invalid_at": "2023-01-02",
# "entity_ids": []
# },
# "conflict": True,
# "conflict_memory": {
# "id": "1",
# "statement": "张明现在在清华大学当讲师。",
# "group_id": "1",
# "chunk_id": "1",
# "created_at": "2019-12-01T19:15:05.213210",
# "expired_at": None,
# "valid_at": None,
# "invalid_at": None,
# "entity_ids": []
# }
# }
# ]
from app.core.memory.utils.config.get_example_data import get_example_data
data = get_example_data()
reflexion_result = await reflexion(data)
# 读取输出,使用全局配置路径
from app.core.config import settings
result_path = settings.get_memory_output_path("extracted_result.json")
if not os.path.isfile(result_path):
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
with open(result_path, "r", encoding="utf-8") as rf:
extracted_result = json.load(rf)
extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None
return {
"config_id": cid,
"time_log": os.path.join(project_root, "time.log"),
"extracted_result": extracted_result,
}
# 步骤 2: 创建进度回调函数捕获管线进度
# 使用队列在回调和生成器之间传递进度事
progress_queue: asyncio.Queue = asyncio.Queue()
async def progress_callback(stage: str, message: str, data: Optional[Dict[str, Any]] = None) -> None:
"""
进度回调函数,将进度事件放入队列
Args:
stage: 阶段标识
message: 进度消息
data: 可选的结果数据(用于传递节点执行结果)
"""
await progress_queue.put((stage, message, data))
# 步骤 3: 在后台任务中执行管线
async def run_pipeline():
"""在后台执行管线并捕获异常"""
try:
from app.core.memory.main import main as pipeline_main
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
await pipeline_main(
dialogue_text=dialogue_text,
is_pilot_run=True,
progress_callback=progress_callback
)
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
# 标记管线完成
await progress_queue.put(("__PIPELINE_COMPLETE__", "", None))
except Exception as e:
# 将异常放入队列
await progress_queue.put(("__PIPELINE_ERROR__", str(e), None))
# 启动后台任务
pipeline_task = asyncio.create_task(run_pipeline())
# 步骤 4: 从队列中读取进度事件并发出
while True:
try:
# 等待进度事件,设置超时以检测客户端断开
stage, message, data = await asyncio.wait_for(
progress_queue.get(),
timeout=0.5
)
# 检查特殊标记
if stage == "__PIPELINE_COMPLETE__":
break
elif stage == "__PIPELINE_ERROR__":
raise RuntimeError(message)
# 构建进度事件数据
progress_data = {
"message": message,
"time": int(time.time() * 1000)
}
# 如果有结果数据,添加到事件中
if data:
progress_data["data"] = data
# 发出进度事件,使用 stage 作为事件类型
yield format_sse_message(stage, progress_data)
except TimeoutError:
# 超时,继续等待(这允许检测客户端断开)
continue
# 等待管线任务完成
await pipeline_task
# 步骤 5: 读取提取结果
from app.core.config import settings
result_path = settings.get_memory_output_path("extracted_result.json")
if not os.path.isfile(result_path):
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
with open(result_path, "r", encoding="utf-8") as rf:
extracted_result = json.load(rf)
# 步骤 6: 发出结果事件
result_data = {
"config_id": cid,
"time_log": os.path.join(project_root, "logs", "time.log"),
"extracted_result": extracted_result,
}
yield format_sse_message("result", result_data)
# 步骤 7: 发出完成事件
yield format_sse_message("done", {
"message": "试运行完成",
"time": int(time.time() * 1000)
})
except asyncio.CancelledError:
# 客户端断开连接
logger.info("[PILOT_RUN_STREAM] Client disconnected during streaming")
raise
except Exception as e:
# 发出错误事件
logger.error(f"[PILOT_RUN_STREAM] Error during streaming: {e}", exc_info=True)
yield format_sse_message("error", {
"code": 5000,
"message": "试运行失败",
"error": str(e),
"time": int(time.time() * 1000)
})
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------

View File

@@ -0,0 +1,731 @@
"""
工作流服务层
"""
import logging
import uuid
import datetime
from typing import Any, Annotated
from sqlalchemy.orm import Session
from fastapi import Depends
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
from app.repositories.workflow_repository import (
WorkflowConfigRepository,
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository,
get_workflow_config_repository,
get_workflow_execution_repository,
get_workflow_node_execution_repository
)
from app.core.workflow.validator import validate_workflow_config
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.db import get_db
from app.schemas import DraftRunRequest
logger = logging.getLogger(__name__)
class WorkflowService:
"""工作流服务"""
def __init__(self, db: Session):
self.db = db
self.config_repo = WorkflowConfigRepository(db)
self.execution_repo = WorkflowExecutionRepository(db)
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
# ==================== 配置管理 ====================
def create_workflow_config(
self,
app_id: uuid.UUID,
nodes: list[dict[str, Any]],
edges: list[dict[str, Any]],
variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None,
validate: bool = True
) -> WorkflowConfig:
"""创建工作流配置
Args:
app_id: 应用 ID
nodes: 节点列表
edges: 边列表
variables: 变量列表
execution_config: 执行配置
triggers: 触发器列表
validate: 是否验证配置
Returns:
工作流配置
Raises:
BusinessException: 配置无效时抛出
"""
# 构建配置字典
config_dict = {
"nodes": nodes,
"edges": edges,
"variables": variables or [],
"execution_config": execution_config or {},
"triggers": triggers or []
}
# 验证配置
if validate:
is_valid, errors = validate_workflow_config(config_dict, for_publish=False)
if not is_valid:
logger.warning(f"工作流配置验证失败: {errors}")
raise BusinessException(
error_code=BizCode.INVALID_PARAMETER,
message=f"工作流配置无效: {'; '.join(errors)}"
)
# 创建或更新配置
config = self.config_repo.create_or_update(
app_id=app_id,
nodes=nodes,
edges=edges,
variables=variables,
execution_config=execution_config,
triggers=triggers
)
logger.info(f"创建工作流配置成功: app_id={app_id}, config_id={config.id}")
return config
def get_workflow_config(self, app_id: uuid.UUID) -> WorkflowConfig | None:
"""获取工作流配置
Args:
app_id: 应用 ID
Returns:
工作流配置或 None
"""
return self.config_repo.get_by_app_id(app_id)
def update_workflow_config(
self,
app_id: uuid.UUID,
nodes: list[dict[str, Any]] | None = None,
edges: list[dict[str, Any]] | None = None,
variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None,
validate: bool = True
) -> WorkflowConfig:
"""更新工作流配置
Args:
app_id: 应用 ID
nodes: 节点列表
edges: 边列表
variables: 变量列表
execution_config: 执行配置
triggers: 触发器列表
validate: 是否验证配置
Returns:
工作流配置
Raises:
BusinessException: 配置不存在或无效时抛出
"""
# 获取现有配置
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
# 合并配置
updated_nodes = nodes if nodes is not None else config.nodes
updated_edges = edges if edges is not None else config.edges
updated_variables = variables if variables is not None else config.variables
updated_execution_config = execution_config if execution_config is not None else config.execution_config
updated_triggers = triggers if triggers is not None else config.triggers
# 构建配置字典
config_dict = {
"nodes": updated_nodes,
"edges": updated_edges,
"variables": updated_variables,
"execution_config": updated_execution_config,
"triggers": updated_triggers
}
# 验证配置
if validate:
is_valid, errors = validate_workflow_config(config_dict, for_publish=False)
if not is_valid:
logger.warning(f"工作流配置验证失败: {errors}")
raise BusinessException(
error_code=BizCode.INVALID_PARAMETER,
message=f"工作流配置无效: {'; '.join(errors)}"
)
# 更新配置
config = self.config_repo.create_or_update(
app_id=app_id,
nodes=updated_nodes,
edges=updated_edges,
variables=updated_variables,
execution_config=updated_execution_config,
triggers=updated_triggers
)
logger.info(f"更新工作流配置成功: app_id={app_id}, config_id={config.id}")
return config
def delete_workflow_config(self, app_id: uuid.UUID) -> bool:
"""删除工作流配置
Args:
app_id: 应用 ID
Returns:
是否删除成功
"""
config = self.get_workflow_config(app_id)
if not config:
return False
self.config_repo.delete(config.id)
logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}")
return True
def check_config(self, app_id: uuid.UUID) -> WorkflowConfig:
"""检查工作流配置的完整性
Args:
app_id: 应用 ID
Raises:
BusinessException: 配置不完整或不存在时抛出
"""
# 1. 检查多智能体配置是否存在
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
"工作流配置不存在,无法运行",
BizCode.CONFIG_MISSING
)
# validator 现在支持直接接受 Pydantic 模型
is_valid, errors = validate_workflow_config(config, for_publish=False)
if not is_valid:
logger.warning(f"工作流配置验证失败: {errors}")
raise BusinessException(
code=BizCode.INVALID_PARAMETER,
message=f"工作流配置无效: {'; '.join(errors)}"
)
return config
def validate_workflow_config_for_publish(
self,
app_id: uuid.UUID
) -> tuple[bool, list[str]]:
"""验证工作流配置是否可以发布
Args:
app_id: 应用 ID
Returns:
(is_valid, errors): 是否有效和错误列表
Raises:
BusinessException: 配置不存在时抛出
"""
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config,
"triggers": config.triggers
}
return validate_workflow_config(config_dict, for_publish=True)
# ==================== 执行管理 ====================
def create_execution(
self,
workflow_config_id: uuid.UUID,
app_id: uuid.UUID,
trigger_type: str,
triggered_by: uuid.UUID | None = None,
conversation_id: uuid.UUID | None = None,
input_data: dict[str, Any] | None = None
) -> WorkflowExecution:
"""创建工作流执行记录
Args:
workflow_config_id: 工作流配置 ID
app_id: 应用 ID
trigger_type: 触发类型
triggered_by: 触发用户 ID
conversation_id: 会话 ID
input_data: 输入数据
Returns:
执行记录
"""
# 生成执行 ID
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
execution = WorkflowExecution(
workflow_config_id=workflow_config_id,
app_id=app_id,
conversation_id=conversation_id,
execution_id=execution_id,
trigger_type=trigger_type,
triggered_by=triggered_by,
input_data=input_data or {},
status="pending"
)
self.db.add(execution)
self.db.commit()
self.db.refresh(execution)
logger.info(f"创建工作流执行记录: execution_id={execution_id}")
return execution
def get_execution(self, execution_id: str) -> WorkflowExecution | None:
"""获取执行记录
Args:
execution_id: 执行 ID
Returns:
执行记录或 None
"""
return self.execution_repo.get_by_execution_id(execution_id)
def get_executions_by_app(
self,
app_id: uuid.UUID,
limit: int = 50,
offset: int = 0
) -> list[WorkflowExecution]:
"""获取应用的执行记录列表
Args:
app_id: 应用 ID
limit: 返回数量限制
offset: 偏移量
Returns:
执行记录列表
"""
return self.execution_repo.get_by_app_id(app_id, limit, offset)
def update_execution_status(
self,
execution_id: str,
status: str,
output_data: dict[str, Any] | None = None,
error_message: str | None = None,
error_node_id: str | None = None
) -> WorkflowExecution:
"""更新执行状态
Args:
execution_id: 执行 ID
status: 状态
output_data: 输出数据
error_message: 错误信息
error_node_id: 出错节点 ID
Returns:
执行记录
Raises:
BusinessException: 执行记录不存在时抛出
"""
execution = self.get_execution(execution_id)
if not execution:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
message=f"执行记录不存在: execution_id={execution_id}"
)
execution.status = status
if output_data is not None:
execution.output_data = output_data
if error_message is not None:
execution.error_message = error_message
if error_node_id is not None:
execution.error_node_id = error_node_id
# 如果是完成状态,计算耗时
if status in ["completed", "failed", "cancelled", "timeout"]:
if not execution.completed_at:
execution.completed_at = datetime.datetime.now()
elapsed = (execution.completed_at - execution.started_at).total_seconds()
execution.elapsed_time = elapsed
self.db.commit()
self.db.refresh(execution)
logger.info(f"更新执行状态: execution_id={execution_id}, status={status}")
return execution
def get_execution_statistics(self, app_id: uuid.UUID) -> dict[str, Any]:
"""获取执行统计信息
Args:
app_id: 应用 ID
Returns:
统计信息
"""
total = self.execution_repo.count_by_app_id(app_id)
completed = self.execution_repo.count_by_status(app_id, "completed")
failed = self.execution_repo.count_by_status(app_id, "failed")
running = self.execution_repo.count_by_status(app_id, "running")
return {
"total": total,
"completed": completed,
"failed": failed,
"running": running,
"success_rate": completed / total if total > 0 else 0
}
# ==================== 工作流执行 ====================
async def run(
self,
app_id: uuid.UUID,
payload: DraftRunRequest,
config: WorkflowConfig
):
"""运行工作流
Args:
app_id: 应用 ID
input_data: 输入数据(包含 message 和 variables
triggered_by: 触发用户 ID
conversation_id: 会话 ID可选
stream: 是否流式返回
Returns:
执行结果(非流式)或生成器(流式)
Raises:
BusinessException: 配置不存在或执行失败时抛出
"""
# 1. 获取工作流配置
if not config:
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables, "conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID
triggered_by_uuid = None
if payload.user_id:
try:
triggered_by_uuid = uuid.UUID(payload.user_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
# 转换 conversation_id 为 UUID
conversation_id_uuid = None
if payload.conversation_id:
try:
conversation_id_uuid = uuid.UUID(payload.conversation_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
# 2. 创建执行记录
execution = self.create_execution(
workflow_config_id=config.id,
app_id=app_id,
trigger_type="manual",
triggered_by=triggered_by_uuid,
conversation_id=conversation_id_uuid,
input_data=input_data
)
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
from app.models import App
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
try:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id="",
user_id=payload.user_id
)
# 更新执行结果
if result.get("status") == "completed":
self.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {})
)
else:
self.update_execution_status(
execution.execution_id,
"failed",
error_message=result.get("error")
)
# 返回增强的响应结构
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage")
}
except Exception as e:
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
self.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
raise BusinessException(
code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
async def run_workflow(
self,
app_id: uuid.UUID,
input_data: dict[str, Any],
triggered_by: uuid.UUID,
conversation_id: uuid.UUID | None = None,
stream: bool = False
):
"""运行工作流
Args:
app_id: 应用 ID
input_data: 输入数据(包含 message 和 variables
triggered_by: 触发用户 ID
conversation_id: 会话 ID可选
stream: 是否流式返回
Returns:
执行结果(非流式)或生成器(流式)
Raises:
BusinessException: 配置不存在或执行失败时抛出
"""
# 1. 获取工作流配置
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
# 2. 创建执行记录
execution = self.create_execution(
workflow_config_id=config.id,
app_id=app_id,
trigger_type="manual",
triggered_by=triggered_by,
conversation_id=conversation_id,
input_data=input_data
)
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
from app.models import App
app = self.db.query(App).filter(App.id == app_id).first()
if not app:
raise BusinessException(
error_code=BizCode.RESOURCE_NOT_FOUND,
message=f"应用不存在: app_id={app_id}"
)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
try:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
if stream:
# 流式执行
return self._run_workflow_stream(
workflow_config_dict,
input_data,
execution.execution_id,
str(app.workspace_id),
str(triggered_by)
)
else:
# 非流式执行
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(app.workspace_id),
user_id=str(triggered_by)
)
# 更新执行结果
if result.get("status") == "completed":
self.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {})
)
else:
self.update_execution_status(
execution.execution_id,
"failed",
error_message=result.get("error")
)
# 返回增强的响应结构
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage")
}
except Exception as e:
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
self.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
raise BusinessException(
error_code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
async def _run_workflow_stream(
self,
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""运行工作流(流式,内部方法)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Yields:
流式事件
"""
from app.core.workflow.executor import execute_workflow_stream
try:
output_data = {}
async for event in execute_workflow_stream(
workflow_config=workflow_config,
input_data=input_data,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
):
# 转发事件
yield event
# 收集输出数据
if event.get("type") == "node_complete":
node_data = event.get("data", {})
node_outputs = node_data.get("node_outputs", {})
output_data.update(node_outputs)
# 处理完成事件
if event.get("type") == "workflow_complete":
self.update_execution_status(
execution_id,
"completed",
output_data=output_data
)
# 处理错误事件
if event.get("type") == "workflow_error":
self.update_execution_status(
execution_id,
"failed",
error_message=event.get("error")
)
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
self.update_execution_status(
execution_id,
"failed",
error_message=str(e)
)
yield {
"type": "workflow_error",
"execution_id": execution_id,
"error": str(e)
}
# ==================== 依赖注入函数 ====================
def get_workflow_service(
db: Annotated[Session, Depends(get_db)]
) -> WorkflowService:
"""获取工作流服务(依赖注入)"""
return WorkflowService(db)

View File

@@ -0,0 +1,219 @@
# 智能客服工作流模板
id: customer_service_v1
name: 智能客服工作流
description: 智能客服场景,包含意图识别、知识库查询和回复生成
category: customer_service
version: "1.0.0"
author: RedBear Memory Team
tags:
- 客服
- 意图识别
- 知识库
- 多步骤
# 工作流配置
nodes:
- id: start
type: start
name: 开始
position:
x: 100
y: 200
- id: intent_recognition
type: llm
name: 意图识别
config:
prompt: |
分析用户的问题,识别意图类型。
用户问题:{{ var.user_message }}
请从以下类型中选择一个:
- product_inquiry: 产品咨询
- technical_support: 技术支持
- complaint: 投诉建议
- other: 其他
只返回类型名称,不要其他内容。
model: gpt-3.5-turbo
temperature: 0.3
max_tokens: 50
position:
x: 300
y: 200
- id: intent_router
type: condition
name: 意图路由
position:
x: 500
y: 200
- id: product_handler
type: llm
name: 产品咨询处理
config:
prompt: |
用户咨询产品相关问题。
问题:{{ var.user_message }}
意图:{{ node.intent_recognition.output }}
请提供专业、友好的产品咨询回复。
model: gpt-3.5-turbo
temperature: 0.7
max_tokens: 500
position:
x: 700
y: 100
- id: support_handler
type: llm
name: 技术支持处理
config:
prompt: |
用户需要技术支持。
问题:{{ var.user_message }}
意图:{{ node.intent_recognition.output }}
请提供详细的技术支持方案。
model: gpt-3.5-turbo
temperature: 0.5
max_tokens: 800
position:
x: 700
y: 200
- id: complaint_handler
type: llm
name: 投诉处理
config:
prompt: |
用户提出投诉或建议。
问题:{{ var.user_message }}
意图:{{ node.intent_recognition.output }}
请以同理心回应,并提供解决方案。
model: gpt-3.5-turbo
temperature: 0.8
max_tokens: 600
position:
x: 700
y: 300
- id: general_handler
type: llm
name: 通用处理
config:
prompt: |
用户的问题类型:其他
问题:{{ var.user_message }}
请提供友好的回复。
model: gpt-3.5-turbo
temperature: 0.7
max_tokens: 400
position:
x: 700
y: 400
- id: end
type: end
name: 结束
position:
x: 900
y: 200
edges:
- source: start
target: intent_recognition
label: 开始分析
- source: intent_recognition
target: intent_router
label: 识别完成
- source: intent_router
target: product_handler
condition: "'product_inquiry' in node['intent_recognition']['output']"
label: 产品咨询
- source: intent_router
target: support_handler
condition: "'technical_support' in node['intent_recognition']['output']"
label: 技术支持
- source: intent_router
target: complaint_handler
condition: "'complaint' in node['intent_recognition']['output']"
label: 投诉建议
- source: intent_router
target: general_handler
condition: "True" # 默认路径
label: 其他
- source: product_handler
target: end
label: 完成
- source: support_handler
target: end
label: 完成
- source: complaint_handler
target: end
label: 完成
- source: general_handler
target: end
label: 完成
# 变量定义
variables:
- name: user_message
type: string
required: true
description: 用户的消息
default: ""
- name: user_name
type: string
required: false
description: 用户姓名(可选)
default: "客户"
# 执行配置
execution_config:
max_execution_time: 120
max_iterations: 10
# 触发器
triggers: []
# 使用示例
examples:
- name: 产品咨询
description: 用户咨询产品功能
input:
user_message: "你们的产品支持多语言吗?"
user_name: "张三"
expected_output: "产品功能介绍"
- name: 技术支持
description: 用户遇到技术问题
input:
user_message: "我无法登录系统,一直显示密码错误"
user_name: "李四"
expected_output: "技术支持方案"
- name: 投诉处理
description: 用户提出投诉
input:
user_message: "你们的服务态度太差了,我要投诉"
user_name: "王五"
expected_output: "同理心回应和解决方案"

View File

@@ -0,0 +1,131 @@
# 数据处理工作流模板
id: data_processing_v1
name: 数据处理工作流
description: 数据提取、转换和分析的完整流程
category: data_processing
version: "1.0.0"
author: RedBear Memory Team
tags:
- 数据处理
- ETL
- 分析
- Transform
# 工作流配置
nodes:
- id: start
type: start
name: 开始
position:
x: 100
y: 200
- id: extract_data
type: transform
name: 数据提取
config:
expression: |
{
"raw_text": var['input_text'],
"length": len(var['input_text']),
"timestamp": sys['execution_id']
}
position:
x: 300
y: 200
- id: analyze_data
type: llm
name: 数据分析
config:
prompt: |
请分析以下数据:
原始文本:{{ node.extract_data.raw_text }}
文本长度:{{ node.extract_data.length }}
请提供:
1. 主题分类
2. 情感分析
3. 关键信息提取
以 JSON 格式返回结果。
model: gpt-3.5-turbo
temperature: 0.3
max_tokens: 500
position:
x: 500
y: 200
- id: transform_result
type: transform
name: 结果转换
config:
expression: |
{
"original_length": node['extract_data']['length'],
"analysis": node['analyze_data']['output'],
"processed_at": sys['execution_id'],
"status": "completed"
}
position:
x: 700
y: 200
- id: end
type: end
name: 结束
position:
x: 900
y: 200
edges:
- source: start
target: extract_data
label: 开始提取
- source: extract_data
target: analyze_data
label: 开始分析
- source: analyze_data
target: transform_result
label: 转换结果
- source: transform_result
target: end
label: 完成
# 变量定义
variables:
- name: input_text
type: string
required: true
description: 待处理的文本数据
default: ""
# 执行配置
execution_config:
max_execution_time: 180
max_iterations: 5
# 触发器
triggers: []
# 使用示例
examples:
- name: 文本分析
description: 分析一段文本
input:
input_text: "今天天气真好,心情也很愉快。我们公司推出了新产品,市场反响热烈。"
expected_output:
original_length: 35
analysis: "主题:天气和产品,情感:积极"
status: "completed"
- name: 长文本处理
description: 处理较长的文本
input:
input_text: "这是一段很长的文本..."
expected_output:
status: "completed"

View File

@@ -0,0 +1,99 @@
# 多步骤问答工作流
# 演示节点输出参数的使用
id: multi_step_qa_v1
name: 多步骤问答工作流
description: 先分析问题,再生成答案,展示节点间的数据传递
category: advanced
version: "1.0.0"
author: RedBear Memory Team
tags:
- 问答
- 多步骤
- LLM
# 工作流配置
nodes:
- id: start
type: start
name: 开始
position:
x: 100
y: 100
- id: analyze_question
type: llm
name: 分析问题
description: 分析用户问题的类型和意图
config:
model_id: gpt-3.5-turbo
temperature: 0.3
max_tokens: 500
messages:
- role: system
content: |
你是一个问题分析专家。请分析用户的问题,提取以下信息:
1. 问题类型(事实性、观点性、操作性等)
2. 问题领域(科技、历史、文化等)
3. 关键词
- role: user
content: "{{ sys.message }}"
position:
x: 300
y: 100
- id: generate_answer
type: llm
name: 生成答案
description: 根据问题分析结果生成详细答案
config:
model_id: gpt-3.5-turbo
temperature: 0.7
max_tokens: 1000
messages:
- role: system
content: |
你是一个专业的AI助手。根据问题分析结果生成准确、详细的答案。
问题分析结果:
{{ analyze_question.output }}
- role: user
content: "{{ sys.message }}"
position:
x: 500
y: 100
- id: end
type: end
name: 结束
config:
output: "{{ generate_answer.output }}"
position:
x: 700
y: 100
edges:
- source: start
target: analyze_question
label: 开始分析
- source: analyze_question
target: generate_answer
label: 生成答案
- source: generate_answer
target: end
label: 完成
# 变量定义
variables:
- name: user_question
type: string
required: true
description: 用户的问题
default: ""
# 执行配置
execution_config:
max_execution_time: 120
max_iterations: 1

View File

@@ -0,0 +1,100 @@
# 简单问答工作流模板
id: simple_qa_v1
name: 简单问答工作流
description: 最基础的问答工作流,适合快速开始
category: basic
version: "1.0.0"
author: RedBear Memory Team
tags:
- 问答
- 基础
- LLM
# 工作流配置
nodes:
- id: start
type: start
name: 开始
position:
x: 100
y: 100
- id: llm_qa
type: llm
name: LLM 问答
config:
# 使用 LangChain 标准的消息格式
messages:
- role: system
content: |
你是一个专业、友好且乐于助人的 AI 助手。
你的职责:
- 准确理解用户的问题并提供有价值的回答
- 保持回答的专业性和准确性
- 如果不确定答案,诚实地告知用户
- 使用清晰、易懂的语言进行交流
回答风格:
- 简洁明了,直击要点
- 必要时提供详细解释和示例
- 使用友好、礼貌的语气
- 适当使用格式化(如列表、段落)提高可读性
- role: user
content: "{{ sys.message }}"
model_id: gpt-3.5-turbo
temperature: 0.7
max_tokens: 1000
position:
x: 300
y: 100
- id: end
type: end
name: 结束
config:
output: "{{ llm_qa.output }}"
position:
x: 500
y: 100
edges:
- source: start
target: llm_qa
label: 开始处理
- source: llm_qa
target: end
label: 完成
# 变量定义
variables:
- name: user_question
type: string
required: true
description: 用户的问题
default: ""
# 执行配置
execution_config:
max_execution_time: 60
max_iterations: 1
# 触发器(可选)
triggers: []
# 使用示例
examples:
- name: 基础问答
description: 询问一个简单的问题
input:
user_question: "什么是人工智能?"
expected_output: "关于人工智能的解释"
- name: 技术咨询
description: 询问技术问题
input:
user_question: "如何学习 Python 编程?"
expected_output: "Python 学习建议"

View File

@@ -0,0 +1,27 @@
"""
Server-Sent Events (SSE) Utility Functions
Provides shared utilities for formatting and handling SSE messages.
"""
import json
from typing import Dict, Any
def format_sse_message(event_type: str, data: Dict[str, Any]) -> str:
"""
Format a message in Server-Sent Events (SSE) format.
Args:
event_type: Type of event (stage name, result, error, done)
data: Event data dictionary to be serialized as JSON
Returns:
SSE formatted string: "event: <type>\\ndata: <json>\\n\\n"
Example:
>>> format_sse_message("loading", {"message": "Loading..."})
'event: loading\\ndata: {"message": "Loading..."}\\n\\n'
"""
json_data = json.dumps(data, ensure_ascii=False)
return f"event: {event_type}\ndata: {json_data}\n\n"

View File

@@ -27,12 +27,21 @@ import 'dayjs/locale/en'
import 'dayjs/locale/zh-cn'
import 'dayjs/plugin/timezone'
import 'dayjs/plugin/utc'
import { cookieUtils } from './utils/request';
function App() {
const { t } = useTranslation();
const { locale, language, timeZone } = useI18n()
useEffect(() => {
const authToken = cookieUtils.get('authToken')
if (!authToken && !window.location.hash.includes('#/login')) {
window.location.href = `/#/login`;
}
}, [])
useEffect(() => {
document.title = t('memoryBear')

View File

@@ -200,7 +200,7 @@ export const deleteFile = async (id: string) => {
// 获取文档列表
export const getDocumentList = async (query: PathQuery) => {
const response = await request.get(`${apiPrefix}/documents/${query.kb_id}/${query.parent_id}/documents`, query);
const response = await request.get(`${apiPrefix}/documents/${query.kb_id}/documents`, query);
return response as KnowledgeBaseDocumentData[];
};
// 文档详情
@@ -213,6 +213,11 @@ export const createDocument = async (data: KnowledgeBaseDocumentData) => {
const response = await request.post(`${apiPrefix}/documents/document`, data);
return response as KnowledgeBaseDocumentData;
};
// 自定义文档上传并创建
export const createDocumentAndUpload = async ( data: any, params: PathQuery) => {
const response = await request.post(`${apiPrefix}/files/customtext`, data, { params } );
return response as any;
};
// 更新文档
export const updateDocument = async (id: string, data: KnowledgeBaseDocumentData) => {
const response = await request.put(`${apiPrefix}/documents/${id}`, data);
@@ -223,9 +228,9 @@ export const deleteDocument = async (id: string) => {
const response = await request.delete(`${apiPrefix}/documents/${id}`);
return response;
};
// 文档解析
export const parseDocument = async (id: string) => {
const response = await request.post(`${apiPrefix}/documents/${id}/chunks`);
// 文档解析 / 分块
export const parseDocument = async (id: string, data: any) => {
const response = await request.post(`${apiPrefix}/documents/${id}/chunks`, data);
return response as any;
};
// 文档分块预览

View File

@@ -3,6 +3,7 @@ import { Layout, Dropdown, Space, Breadcrumb } from 'antd';
import type { MenuProps, BreadcrumbProps } from 'antd';
import { UserOutlined, LogoutOutlined, SettingOutlined } from '@ant-design/icons';
import { useTranslation } from 'react-i18next';
import { useLocation } from 'react-router-dom';
import { useUser } from '@/store/user';
import { useMenu } from '@/store/menu';
import styles from './index.module.css'
@@ -12,12 +13,35 @@ const { Header } = Layout;
const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => {
const { t } = useTranslation();
const location = useLocation();
const settingModalRef = useRef<SettingModalRef>(null)
const userInfoModalRef = useRef<UserInfoModalRef>(null)
const { user, logout } = useUser();
const { allBreadcrumbs } = useMenu();
const breadcrumbs = allBreadcrumbs[source] || [];
// 根据当前路由动态选择面包屑源
const getBreadcrumbSource = () => {
const pathname = location.pathname;
// 知识库列表页面使用默认的 space 面包屑
if (pathname === '/knowledge-base') {
return 'space';
}
// 知识库详情相关页面使用独立的面包屑
if (pathname.includes('/knowledge-base/') && pathname !== '/knowledge-base') {
return 'space-detail';
}
// 其他页面使用传入的 source
return source;
};
const breadcrumbSource = getBreadcrumbSource();
const breadcrumbs = allBreadcrumbs[breadcrumbSource] || [];
// 处理退出登录
const handleLogout = () => {

View File

@@ -0,0 +1,248 @@
import { useCallback } from 'react';
import { useNavigate } from 'react-router-dom';
import { useMenu } from '@/store/menu';
import type { MenuItem } from '@/store/menu';
export interface BreadcrumbItem {
id: string;
name: string;
type?: 'knowledgeBase' | 'folder' | 'document';
}
export interface BreadcrumbPath {
knowledgeBaseFolderPath: BreadcrumbItem[]; // 知识库文件夹路径
knowledgeBase?: BreadcrumbItem; // 知识库信息
documentFolderPath: BreadcrumbItem[]; // 文档文件夹路径
document?: BreadcrumbItem; // 文档信息
}
export interface BreadcrumbOptions {
onKnowledgeBaseMenuClick?: () => void;
onKnowledgeBaseFolderClick?: (folderId: string, folderPath: BreadcrumbItem[]) => void;
// 新增:区分面包屑类型
breadcrumbType?: 'list' | 'detail';
}
export const useBreadcrumbManager = (options?: BreadcrumbOptions) => {
const { allBreadcrumbs, setCustomBreadcrumbs } = useMenu();
const navigate = useNavigate();
const updateBreadcrumbs = useCallback((breadcrumbPath: BreadcrumbPath) => {
const breadcrumbType = options?.breadcrumbType || 'list';
// 获取基础面包屑,对于详情页面,使用列表页面的基础面包屑作为起点
const baseBreadcrumbs = breadcrumbType === 'list'
? (allBreadcrumbs['space'] || [])
: (allBreadcrumbs['space'] || []); // 详情页面也从 space 获取基础面包屑
// 只保留知识库菜单项之前的面包屑
const knowledgeBaseMenuIndex = baseBreadcrumbs.findIndex(item => item.path === '/knowledge-base');
const filteredBaseBreadcrumbs = knowledgeBaseMenuIndex >= 0
? baseBreadcrumbs.slice(0, knowledgeBaseMenuIndex + 1)
: baseBreadcrumbs;
// 给"知识库管理"添加点击事件
const breadcrumbsWithClick = filteredBaseBreadcrumbs.map((item) => {
if (item.path === '/knowledge-base') {
return {
...item,
onClick: (e?: React.MouseEvent) => {
e?.preventDefault();
e?.stopPropagation();
if (options?.onKnowledgeBaseMenuClick) {
// 如果提供了回调函数,执行回调
options.onKnowledgeBaseMenuClick();
} else if (breadcrumbType === 'detail') {
// 知识库详情页面:没有回调函数时,返回到知识库列表页面
navigate('/knowledge-base', {
state: {
resetToRoot: true,
}
});
}
return false;
},
};
}
return item;
});
let customBreadcrumbs: MenuItem[] = [...breadcrumbsWithClick];
if (breadcrumbType === 'list') {
// 知识库列表页面:只显示知识库文件夹路径
customBreadcrumbs = [
...breadcrumbsWithClick,
...breadcrumbPath.knowledgeBaseFolderPath.map((folder, index) => ({
id: 0,
parent: 0,
code: null,
label: folder.name,
i18nKey: null,
path: null,
enable: true,
display: true,
level: 0,
sort: 0,
icon: null,
iconActive: null,
menuDesc: null,
deleted: null,
updateTime: 0,
new_: null,
keepAlive: false,
master: null,
disposable: false,
appSystem: null,
subs: [],
onClick: (e?: React.MouseEvent) => {
e?.preventDefault();
e?.stopPropagation();
// 如果有回调函数,直接调用回调函数来更新状态
if (options?.onKnowledgeBaseFolderClick) {
options.onKnowledgeBaseFolderClick(folder.id, breadcrumbPath.knowledgeBaseFolderPath.slice(0, index + 1));
} else {
// 否则使用导航(兜底逻辑)
navigate('/knowledge-base', {
state: {
navigateToFolder: folder.id,
folderPath: breadcrumbPath.knowledgeBaseFolderPath.slice(0, index + 1)
}
});
}
return false;
},
})),
];
} else {
// 知识库详情页面:显示知识库名称 + 文档文件夹路径 + 文档名称
customBreadcrumbs = [
...breadcrumbsWithClick,
// 添加知识库名称
...(breadcrumbPath.knowledgeBase ? [{
id: 0,
parent: 0,
code: null,
label: breadcrumbPath.knowledgeBase.name,
i18nKey: null,
path: null,
enable: true,
display: true,
level: 0,
sort: 0,
icon: null,
iconActive: null,
menuDesc: null,
deleted: null,
updateTime: 0,
new_: null,
keepAlive: false,
master: null,
disposable: false,
appSystem: null,
subs: [],
onClick: (e?: React.MouseEvent) => {
e?.preventDefault();
e?.stopPropagation();
// 返回到知识库详情页的根目录
const navigationState = {
fromKnowledgeBaseList: true,
knowledgeBaseFolderPath: breadcrumbPath.knowledgeBaseFolderPath,
resetToRoot: true, // 添加重置到根目录的标志
refresh: true, // 添加刷新标志
timestamp: Date.now(), // 添加时间戳确保状态变化
};
navigate(`/knowledge-base/${breadcrumbPath.knowledgeBase!.id}/private`, {
state: navigationState,
replace: true // 使用 replace 避免历史记录堆积
});
return false;
},
}] : []),
// 添加文档文件夹路径
...breadcrumbPath.documentFolderPath.map((folder, index) => ({
id: 0,
parent: 0,
code: null,
label: folder.name,
i18nKey: null,
path: null,
enable: true,
display: true,
level: 0,
sort: 0,
icon: null,
iconActive: null,
menuDesc: null,
deleted: null,
updateTime: 0,
new_: null,
keepAlive: false,
master: null,
disposable: false,
appSystem: null,
subs: [],
onClick: (e?: React.MouseEvent) => {
e?.preventDefault();
e?.stopPropagation();
// 返回到知识库详情页的对应文件夹
const navigationState = {
fromKnowledgeBaseList: true,
knowledgeBaseFolderPath: breadcrumbPath.knowledgeBaseFolderPath,
navigateToDocumentFolder: folder.id,
documentFolderPath: breadcrumbPath.documentFolderPath.slice(0, index + 1),
refresh: true, // 添加刷新标志
timestamp: Date.now(), // 添加时间戳确保状态变化
};
navigate(`/knowledge-base/${breadcrumbPath.knowledgeBase!.id}/private`, {
state: navigationState,
replace: true // 使用 replace 避免历史记录堆积
});
return false;
},
})),
// 添加文档名称(如果存在)
...(breadcrumbPath.document ? [{
id: 0,
parent: 0,
code: null,
label: breadcrumbPath.document.name,
i18nKey: null,
path: null,
enable: true,
display: true,
level: 0,
sort: 0,
icon: null,
iconActive: null,
menuDesc: null,
deleted: null,
updateTime: 0,
new_: null,
keepAlive: false,
master: null,
disposable: false,
appSystem: null,
subs: [],
// 文档名称不可点击
}] : []),
];
}
// 根据面包屑类型使用不同的键,实现独立的面包屑路径
const breadcrumbKey = breadcrumbType === 'list' ? 'space' : 'space-detail';
setCustomBreadcrumbs(customBreadcrumbs, breadcrumbKey);
}, [setCustomBreadcrumbs, navigate, options?.breadcrumbType, options?.onKnowledgeBaseMenuClick, options?.onKnowledgeBaseFolderClick]);
return {
updateBreadcrumbs,
};
};

View File

@@ -11,8 +11,10 @@ export const checkAuthStatus = (): boolean => {
// 递归检查路由是否存在于菜单数据中
export const checkRoutePermission = (menus: MenuItem[], currentPath: string): boolean => {
// 首页默认有权限
if (currentPath === '/' || currentPath.includes('knowledge-detail')) return true;
// 首页和知识库相关页面默认有权限
if (currentPath === '/' || currentPath.includes('knowledge-detail') || currentPath.includes('knowledge-base')) {
return true;
}
for (const menu of menus) {
// 检查当前菜单的path是否匹配
@@ -26,6 +28,7 @@ export const checkRoutePermission = (menus: MenuItem[], currentPath: string): bo
}
}
}
return false;
};
@@ -52,7 +55,7 @@ export const useRouteGuard = (source: 'space' | 'manage') => {
const hasPermission = checkRoutePermission(menus, location.pathname);
if (!hasPermission) {
// 无权限访问该路由,重定向到无权限页面
// navigate('/not-found', { replace: true });
// navigate('/no-permission', { replace: true });
}
}
}, [navigate, location.pathname, location.search, location.hash, menus]);

View File

@@ -479,11 +479,18 @@ export const en = {
noDataSets: 'No datasets yet, click the button below or drag files to create.',
createEmptyDataSet: '+ Empty Dataset',
createImageDataSet: '+ Image Dataset',
createContent: 'Create Content',
title: 'Title',
content: 'Content',
pleaseEnterTitle: 'Please enter title',
pleaseEnterContent: 'Please enter content',
// createImageDataSet: '+ Image Dataset',
dragFilesHere: 'Drag files here to upload',
createImport: 'Create/Import',
textDataSet: 'Text Dataset',
imageDataSet: 'Image Dataset',
blankDataset: 'Blank Dataset',
customTextDataset: 'Custom Text Dataset',
text: 'Text',
search: 'Search',
image: 'Image',

View File

@@ -110,6 +110,11 @@ export const zh = {
noDataSets: '暂无数据集,点击下方按钮或拖拽文件创建。',
createEmptyDataSet: '+ 空白数据集',
createImageDataSet: '+ 图片数据集',
createContent: '创建内容',
title: '标题',
content: '内容',
pleaseEnterTitle: '请输入标题',
pleaseEnterContent: '请输入内容',
dragFilesHere: '拖拽文件到此处上传',
downloadOriginal: '下载原始内容',
createImport: '新建/导入',
@@ -117,6 +122,7 @@ export const zh = {
imageDataSet: '图片数据集',
blankDataset: '空白数据集',
emptyDataSet: '空白数据集',
customTextDataset: '自定义文本数据集',
text: '文本',
search: '搜索',
image: '图片',

View File

@@ -32,7 +32,7 @@ interface MenuState {
allBreadcrumbs: Record<'space' | 'manage' | string, MenuItem[]>;
loadMenus: (source: 'space' | 'manage') => void;
updateBreadcrumbs: (keyPath: string[], source: 'space' | 'manage') => void;
setCustomBreadcrumbs: (breadcrumbs: MenuItem[], source: 'space' | 'manage') => void;
setCustomBreadcrumbs: (breadcrumbs: MenuItem[], source: string) => void;
}
const initBreadcrumbs = localStorage.getItem('breadcrumbs') || '[]'

View File

@@ -8,15 +8,14 @@ import type { UploadFileResponse,KnowledgeBaseDocumentData } from '@/views/Knowl
import type { ColumnsType } from 'antd/es/table';
import UploadFiles from '@/components/Upload/UploadFiles';
import type { UploadRequestOption } from 'rc-upload/lib/interface';
import { uploadFile, getDocumentList, previewDocumentChunk, parseDocument, updateDocument, deleteDocument } from '@/api/knowledgeBase';
import { uploadFile, getDocumentList, parseDocument, updateDocument, deleteDocument } from '@/api/knowledgeBase';
import exitIcon from '@/assets/images/knowledgeBase/exit.png';
import { NoData } from '../components/noData';
import noDataIcon from '@/assets/images/knowledgeBase/noData.png';
import SliderInput from '@/components/SliderInput';
import DelimiterSelector from '../components/DelimiterSelector';
const { confirm } = Modal
const { TextArea } = Input;
import styles from '../index.module.css';
const style: React.CSSProperties = {
display: 'flex',
gap: 16,
@@ -71,12 +70,11 @@ const CreateDataset = () => {
const initialFileIds = locationState.fileIds ?? (locationState.fileId ? [locationState.fileId] : []);
const [current, setCurrent] = useState<number>(stepIndexMap[initialStepKey]);
const tableRef = useRef<TableRef>(null);
const [data, setData] = useState<KnowledgeBaseDocumentData[]>([]);
const [chunkData, setChunkData] = useState<any[]>([]);
const [total, setTotal] = useState<number>(0);
const [rechunkFileIds, setRechunkFileIds] = useState<string[]>(initialFileIds);
const [curSelectedFileId, setCurSelectedFileId] = useState<number>(-1);
const [previewLoading, setPreviewLoading] = useState<boolean>(false);
const [pollingLoading, setPollingLoading] = useState<boolean>(false);
const pollingTimerRef = useRef<ReturnType<typeof setInterval> | null>(null);
const [delimiter, setDelimiter] = useState<string | undefined>(undefined);
@@ -121,6 +119,7 @@ const CreateDataset = () => {
layout_recognize:'DeepDOC',
delimiter: delimiter,
chunk_token_num: blockSize,
auto_question: processingMethod === 'directBlock' ? 0 : 1,
}
}
updateDocument(id, params)
@@ -145,7 +144,7 @@ const CreateDataset = () => {
});
return;
}
debugger
// 显示确认弹框
confirm({
@@ -168,7 +167,7 @@ const CreateDataset = () => {
const startProcessing = (autoReturnToList: boolean) => {
// 触发文档解析
rechunkFileIds.map((id) => {
parseDocument(id);
parseDocument(id, {});
});
// 开启 loading
@@ -276,21 +275,7 @@ const CreateDataset = () => {
onError?.(error as Error);
});
};
// 点击文件 预览分块
const handlePreview = async(item: KnowledgeBaseDocumentData, index: number) => {
setCurSelectedFileId(index);
setPreviewLoading(true);
try{
const res = await previewDocumentChunk(knowledgeBaseId ?? '', item.id ?? '');
setChunkData(res.items || []);
setTotal(res.page.total || 0);
console.log('res', res);
}catch(error) {
console.log('error', error);
} finally {
setPreviewLoading(false);
}
}
// 轮询检查文档处理状态
// autoReturn: 是否在所有文档完成时自动返回列表页
@@ -346,6 +331,8 @@ const CreateDataset = () => {
state: {
refresh: true,
timestamp: Date.now(), // 添加时间戳确保每次都是新的 state
// 保持返回到原来的文档文件夹位置
navigateToDocumentFolder: parentId !== knowledgeBaseId ? parentId : undefined,
},
});
} else {
@@ -565,8 +552,8 @@ const CreateDataset = () => {
{rechunkFileIds.length > 0 ? (
<Table
ref={tableRef}
apiUrl={`/documents/${knowledgeBaseId}/${parentId}/documents`}
apiParams={{
apiUrl={`/documents/${knowledgeBaseId}/documents`}
apiParams={{
document_ids: rechunkFileIds.join(','),
}}
columns={columns}

View File

@@ -4,11 +4,12 @@
* @Author: yujiangping
* @Date: 2025-11-15 16:13:47
* @LastEditors: yujiangping
* @LastEditTime: 2025-11-29 19:46:46
* @LastEditTime: 2025-12-12 20:02:05
*/
import { useEffect, useState, useRef, type FC } from 'react';
import { useNavigate, useParams, useLocation } from 'react-router-dom';
import { useTranslation } from 'react-i18next';
import { useBreadcrumbManager, type BreadcrumbPath } from '@/hooks/useBreadcrumbManager';
import { Button, Spin, message, Switch } from 'antd';
import { getDocumentDetail, getDocumentChunkList, downloadFile, updateDocument, updateDocumentChunk, createDocumentChunk } from '@/api/knowledgeBase';
import type { KnowledgeBaseDocumentData, RecallTestData } from '@/views/KnowledgeBase/types';
@@ -25,7 +26,18 @@ const DocumentDetails: FC = () => {
const navigate = useNavigate();
const { knowledgeBaseId } = useParams<{ knowledgeBaseId: string }>();
const location = useLocation();
const { documentId, parentId: locationParentId } = location.state as { documentId: string; parentId?: string };
const { updateBreadcrumbs } = useBreadcrumbManager({
breadcrumbType: 'detail'
});
const {
documentId,
parentId: locationParentId,
breadcrumbPath
} = location.state as {
documentId: string;
parentId?: string;
breadcrumbPath?: BreadcrumbPath;
};
const [loading, setLoading] = useState(false);
const [document, setDocument] = useState<KnowledgeBaseDocumentData | null>(null);
const [chunkList, setChunkList] = useState<RecallTestData[]>([]);
@@ -44,6 +56,13 @@ const DocumentDetails: FC = () => {
}
}, [documentId]);
// 更新面包屑
useEffect(() => {
if (breadcrumbPath) {
updateBreadcrumbs(breadcrumbPath);
}
}, [breadcrumbPath, updateBreadcrumbs]);
// 当文档加载完成且 progress === 1 时,加载分块列表
useEffect(() => {
if (document && document.progress === 1 && !isManualRefreshRef.current) {
@@ -179,7 +198,18 @@ const DocumentDetails: FC = () => {
};
const handleBack = () => {
if (knowledgeBaseId) {
if (knowledgeBaseId && breadcrumbPath) {
// 返回到知识库详情页,并传递面包屑信息以恢复状态
const navigationState = {
fromKnowledgeBaseList: true,
knowledgeBaseFolderPath: breadcrumbPath.knowledgeBaseFolderPath,
navigateToDocumentFolder: locationParentId,
documentFolderPath: breadcrumbPath.documentFolderPath,
timestamp: Date.now(), // 添加时间戳确保状态变化
};
navigate(`/knowledge-base/${knowledgeBaseId}/private`, { state: navigationState });
} else if (knowledgeBaseId) {
// 降级处理:直接跳转到知识库详情页
navigate(`/knowledge-base/${knowledgeBaseId}/private`);
}
};

View File

@@ -1,5 +1,5 @@
import { useEffect, useState, useRef, type FC } from 'react';
import { useEffect, useState, useRef, useCallback, type FC } from 'react';
import { useNavigate, useParams, useLocation } from 'react-router-dom';
import { useTranslation } from 'react-i18next';
import { Switch, Button, Dropdown, Space, Modal, message } from 'antd';
@@ -12,26 +12,29 @@ import { MoreOutlined } from '@ant-design/icons';
import folderIcon from '@/assets/images/knowledgeBase/folder.png';
import textIcon from '@/assets/images/knowledgeBase/text.png';
import editIcon from '@/assets/images/knowledgeBase/edit.png';
import blankIcon from '@/assets/images/knowledgeBase/blankDocument.png';
import { getKnowledgeBaseDetail, deleteDocument, downloadFile, updateKnowledgeBase } from '@/api/knowledgeBase';
import type {
CreateModalRef,
KnowledgeBaseListItem,
RecallTestDrawerRef,
CreateFolderModalRef,
CreateImageModalRef,
ShareModalRef,
CreateDatasetModalRef,FolderFormData,
KnowledgeBaseDocumentData
import {
type CreateModalRef,
type KnowledgeBaseListItem,
type RecallTestDrawerRef,
type CreateFolderModalRef,
type CreateSetModalRef,
type ShareModalRef,
type CreateDatasetModalRef,type FolderFormData,
type KnowledgeBaseDocumentData,
} from '@/views/KnowledgeBase/types';
import RecallTestDrawer from '../components/RecallTestDrawer';
import CreateFolderModal from '../components/CreateFolderModal';
import CreateContentModal from '../components/CreateContentModal';
import CreateModal from '../components/CreateModal';
import ShareModal from '../components/ShareModal';
import CreateDatasetModal from '../components/CreateDatasetModal';
import CreateImageDataset from '../components/CreateImageDataset';
import FolderTree, { type TreeNodeData } from '../components/FolderTree';
import { formatDateTime } from '@/utils/format';
import { useMenu } from '@/store/menu';
import { useBreadcrumbManager, type BreadcrumbItem } from '@/hooks/useBreadcrumbManager';
import './Private.css'
const { confirm } = Modal
// 树节点数据类型
@@ -48,7 +51,8 @@ const Private: FC = () => {
const [tableApi, setTableApi] = useState<string | undefined>(undefined);
const recallTestDrawerRef = useRef<RecallTestDrawerRef>(null);
const createFolderModalRef = useRef<CreateFolderModalRef>(null);
const createImageDataset = useRef<CreateImageModalRef>(null)
const createImageDataset = useRef<CreateSetModalRef>(null)
const createContentModalRef = useRef<CreateSetModalRef>(null);
const [knowledgeBase, setKnowledgeBase] = useState<KnowledgeBaseListItem | null>(null);
const [folder, setFolder] = useState<FolderFormData | null>({
kb_id:knowledgeBaseId ?? '',
@@ -56,47 +60,47 @@ const Private: FC = () => {
});
const [query, setQuery] = useState<Record<string, unknown>>({
orderby: 'created_at',
desc: true,
desc: true
});
const modalRef = useRef<CreateModalRef>(null)
const shareModalRef = useRef<ShareModalRef>(null);
const datasetModalRef = useRef<CreateDatasetModalRef>(null);
const [folderTreeRefreshKey, setFolderTreeRefreshKey] = useState(0);
const { allBreadcrumbs, setCustomBreadcrumbs } = useMenu();
const [folderPath, setFolderPath] = useState<Array<{ id: string; name: string }>>([]);
const [autoExpandPath, setAutoExpandPath] = useState<Array<{ id: string; name: string }>>([]);
const { updateBreadcrumbs } = useBreadcrumbManager({
breadcrumbType: 'detail',
// 不提供 onKnowledgeBaseMenuClick让它使用默认的导航行为返回列表页面
onKnowledgeBaseFolderClick: useCallback((folderId: string, folderPath: Array<{ id: string; name: string }>) => {
// 点击文件夹面包屑时,导航到对应文件夹
setParentId(folderId);
setFolderPath(folderPath);
setSelectedKeys([folderId]);
setFolder({
kb_id: knowledgeBaseId ?? '',
parent_id: folderId
});
// 确保query对象发生变化触发表格刷新
setQuery({
orderby: 'created_at',
desc: true,
parent_id: folderId,
_timestamp: Date.now()
});
// 确保API URL正确设置
setTableApi(`/documents/${knowledgeBaseId}/documents`);
// 手动触发表格刷新,确保数据更新
setTimeout(() => {
tableRef.current?.loadData();
}, 100);
}, [knowledgeBaseId])
});
const [folderPath, setFolderPath] = useState<BreadcrumbItem[]>([]);
const [selectedKeys, setSelectedKeys] = useState<React.Key[]>([]);
useEffect(() => {
if (knowledgeBaseId) {
let url = `/documents/${knowledgeBaseId}/${parentId}/documents`;
setTableApi(url);
fetchKnowledgeBaseDetail(knowledgeBaseId);
}
}, [knowledgeBaseId]);
// 更新面包屑
useEffect(() => {
if (knowledgeBase) {
updateBreadcrumbs();
}
}, [knowledgeBase, folderPath]);
// 监听 tableApi 变化,自动刷新表格数据
useEffect(() => {
if (tableApi) {
tableRef.current?.loadData();
}
}, [tableApi]);
// 监听 location state 变化,如果有 refresh 标志则刷新列表
useEffect(() => {
const state = location.state as { refresh?: boolean; timestamp?: number } | null;
if (state?.refresh) {
tableRef.current?.loadData();
// 清除 state避免重复刷新
navigate(location.pathname, { replace: true, state: {} });
}
}, [location.state]);
const [knowledgeBaseFolderPath, setKnowledgeBaseFolderPath] = useState<BreadcrumbItem[]>([]);
const fetchKnowledgeBaseDetail = async (id: string) => {
setLoading(true);
try {
@@ -109,110 +113,160 @@ const Private: FC = () => {
}
};
// 更新面包屑,包含知识库名称和文件夹路径
const updateBreadcrumbs = () => {
if (!knowledgeBase) return;
const baseBreadcrumbs = allBreadcrumbs['space'] || [];
// 只保留知识库菜单项之前的面包屑
const knowledgeBaseMenuIndex = baseBreadcrumbs.findIndex(item => item.path === '/knowledge-base');
const filteredBaseBreadcrumbs = knowledgeBaseMenuIndex >= 0
? baseBreadcrumbs.slice(0, knowledgeBaseMenuIndex + 1)
: baseBreadcrumbs;
const customBreadcrumbs = [
...filteredBaseBreadcrumbs,
{
id: 0,
parent: 0,
code: null,
label: knowledgeBase.name,
i18nKey: null,
path: null,
enable: true,
display: true,
level: 0,
sort: 0,
icon: null,
iconActive: null,
menuDesc: null,
deleted: null,
updateTime: 0,
new_: null,
keepAlive: false,
master: null,
disposable: false,
appSystem: null,
subs: [],
onClick: (e?: React.MouseEvent) => {
// 阻止默认行为和事件冒泡
e?.preventDefault();
e?.stopPropagation();
// 点击知识库名称,回到根目录
setParentId(knowledgeBaseId);
setFolder({
kb_id: knowledgeBaseId ?? '',
parent_id: knowledgeBaseId ?? ''
});
setTableApi(`/documents/${knowledgeBaseId}/${knowledgeBaseId}/documents`);
setFolderPath([]);
setSelectedKeys([knowledgeBaseId ?? '']);
return false;
},
},
...folderPath.map((folder, index) => ({
id: 0,
parent: 0,
code: null,
label: folder.name,
i18nKey: null,
path: null,
enable: true,
display: true,
level: 0,
sort: 0,
icon: null,
iconActive: null,
menuDesc: null,
deleted: null,
updateTime: 0,
new_: null,
keepAlive: false,
master: null,
disposable: false,
appSystem: null,
subs: [],
onClick: (e?: React.MouseEvent) => {
// 阻止默认行为和事件冒泡
e?.preventDefault();
e?.stopPropagation();
// 点击文件夹,回到该文件夹层级
setParentId(folder.id);
setFolder({
kb_id: knowledgeBaseId ?? '',
parent_id: folder.id
});
setTableApi(`/documents/${knowledgeBaseId}/${folder.id}/documents`);
// 更新文件夹路径,只保留到当前点击的文件夹
setFolderPath(folderPath.slice(0, index + 1));
setSelectedKeys([folder.id]);
return false;
},
})),
];
useEffect(() => {
if (knowledgeBaseId) {
let url = `/documents/${knowledgeBaseId}/documents`;
setTableApi(url);
fetchKnowledgeBaseDetail(knowledgeBaseId);
}
}, [knowledgeBaseId]);
setCustomBreadcrumbs(customBreadcrumbs, 'space');
};
// 更新面包屑
useEffect(() => {
if (knowledgeBase) {
updateBreadcrumbs({
knowledgeBaseFolderPath,
knowledgeBase: {
id: knowledgeBase.id,
name: knowledgeBase.name,
type: 'knowledgeBase'
},
documentFolderPath: folderPath,
});
}
}, [knowledgeBase, knowledgeBaseFolderPath, folderPath, updateBreadcrumbs]);
// 监听 tableApi 变化,自动刷新表格数据
useEffect(() => {
if (tableApi) {
tableRef.current?.loadData();
}
}, [tableApi]);
// 监听 query 变化,确保表格数据更新
useEffect(() => {
if (tableApi && query._timestamp) {
// 当 query 中有 _timestamp 时,说明是通过面包屑或其他方式触发的更新
tableRef.current?.loadData();
}
}, [query._timestamp, tableApi]);
// 监听 location state 变化
useEffect(() => {
const state = location.state as {
refresh?: boolean;
timestamp?: number;
fromKnowledgeBaseList?: boolean;
knowledgeBaseFolderPath?: BreadcrumbItem[];
parentId?: string;
navigateToDocumentFolder?: string;
documentFolderPath?: BreadcrumbItem[];
resetToRoot?: boolean;
} | null;
if (state?.refresh) {
tableRef.current?.loadData();
// 清除 state避免重复刷新
navigate(location.pathname, { replace: true, state: {} });
}
// 如果是从知识库列表页跳转过来的,设置知识库文件夹路径
if (state?.fromKnowledgeBaseList && state?.knowledgeBaseFolderPath) {
setKnowledgeBaseFolderPath(state.knowledgeBaseFolderPath);
}
// 如果需要重置到根目录(回到初始状态)
if (state?.resetToRoot) {
// 重置所有状态到初始状态,和页面初始化保持一致
setParentId(knowledgeBaseId);
setFolderPath([]);
setSelectedKeys([]);
setFolder({
kb_id: knowledgeBaseId ?? '',
parent_id: knowledgeBaseId ?? ''
});
setQuery({
orderby: 'created_at',
desc: true,
_timestamp: Date.now() // 添加时间戳确保query对象发生变化触发API调用
});
// 重新设置API URL
const rootUrl = `/documents/${knowledgeBaseId}/documents`;
setTableApi(rootUrl);
// 清除自动展开路径
setAutoExpandPath([]);
// 刷新文件夹树(简单的刷新,不需要复杂的重置逻辑)
setFolderTreeRefreshKey((prev) => prev + 1);
// 清除 state避免重复处理
navigate(location.pathname, { replace: true, state: {} });
}
// 如果是从文档详情页返回,恢复文档文件夹路径
if (state?.navigateToDocumentFolder && state?.documentFolderPath) {
setFolderPath(state.documentFolderPath);
setParentId(state.navigateToDocumentFolder);
setFolder({
kb_id: knowledgeBaseId ?? '',
parent_id: state.navigateToDocumentFolder
});
setQuery(prevQuery => ({
...prevQuery,
parent_id: state.navigateToDocumentFolder,
_timestamp: Date.now()
}));
setTableApi(`/documents/${knowledgeBaseId}/documents`);
setSelectedKeys([state.navigateToDocumentFolder]);
// 设置自动展开路径让FolderTree自动展开到对应位置
setAutoExpandPath(state.documentFolderPath);
// 手动触发表格刷新
setTimeout(() => {
tableRef.current?.loadData();
}, 100);
// 清除自动展开路径避免重复触发延迟清除确保FolderTree处理完成
setTimeout(() => {
setAutoExpandPath([]);
}, 2000);
}
}, [location.state, knowledgeBaseId, navigate, location.pathname]);
// 处理树节点选择
const onSelect = (keys: React.Key[]) => {
if (!keys.length) return;
if (!keys.length) {
// 如果没有选中任何节点,回到根目录(初始状态)
setParentId(knowledgeBaseId);
setFolder({
kb_id: knowledgeBaseId ?? '',
parent_id: knowledgeBaseId ?? ''
});
setQuery({
orderby: 'created_at',
desc: true,
_timestamp: Date.now() // 添加时间戳确保query对象发生变化
});
setSelectedKeys([]);
return;
}
if (!folder) return;
const f = {
...folder,
parent_id: String(keys[0]),
}
let url = `/documents/${knowledgeBaseId}/${String(keys[0])}/documents`;
setQuery({
...query,
parent_id: String(keys[0]),
_timestamp: Date.now() // 添加时间戳确保query对象发生变化
})
let url = `/documents/${knowledgeBaseId}/documents`;
setTableApi(url);
setParentId(String(keys[0]))
setFolder(f)
@@ -253,6 +307,15 @@ const Private: FC = () => {
datasetModalRef?.current?.handleOpen(knowledgeBase?.id,folder?.parent_id ?? knowledgeBase?.id ?? '');
},
},
{
key: '8',
icon: <img src={blankIcon} alt="Custome Text" style={{ width: 16, height: 16 }} />,
label: t('knowledgeBase.customTextDataset'),
onClick: () => {
createContentModalRef?.current?.handleOpen(knowledgeBase?.id ?? '', folder?.parent_id ?? knowledgeBase?.id ?? '');
// handleCreate('folder'); // 传入 type: 'folder'
},
},
// 暂时未实现
// {
// key: '3',
@@ -413,6 +476,21 @@ const Private: FC = () => {
state: {
documentId: document.id,
parentId: parentId ?? knowledgeBaseId,
// 传递面包屑信息
breadcrumbPath: {
knowledgeBaseFolderPath,
knowledgeBase: {
id: knowledgeBase?.id || knowledgeBaseId,
name: knowledgeBase?.name || '',
type: 'knowledgeBase'
},
documentFolderPath: folderPath,
document: {
id: document.id,
name: document.file_name || '',
type: 'document'
}
}
},
});
}
@@ -486,7 +564,9 @@ const Private: FC = () => {
}
const refreshDirectoryTree = async () => {
// 先刷新知识库详情,确保数据是最新的
await fetchKnowledgeBaseDetail(knowledgeBase.id);
if (knowledgeBase?.id) {
await fetchKnowledgeBaseDetail(knowledgeBase.id);
}
// 添加短暂延迟,确保后端数据已经完全更新
await new Promise(resolve => setTimeout(resolve, 300));
// 然后刷新文件夹树
@@ -501,6 +581,7 @@ const Private: FC = () => {
}
const handleRootTreeLoad = (nodes: TreeNodeData[] | null) => {
if (!nodes || nodes.length === 0) {
// 如果没有节点设置folder为null这会隐藏FolderTree
setFolder(null);
} else {
// 如果有节点且 folder 为 null重新设置 folder
@@ -524,6 +605,7 @@ const Private: FC = () => {
}
const handleRefreshTable = () => {
debugger
// 刷新表格数据
tableRef.current?.loadData();
}
@@ -545,6 +627,7 @@ const Private: FC = () => {
onRootLoad={handleRootTreeLoad}
onFolderPathChange={handleFolderPathChange}
selectedKeys={selectedKeys}
autoExpandPath={autoExpandPath}
/>
</div>
)}
@@ -601,6 +684,10 @@ const Private: FC = () => {
ref={createFolderModalRef}
refreshTable={refreshDirectoryTree}
/>
<CreateContentModal
ref={createContentModalRef}
refreshTable={handleRefreshTable}
/>
<CreateModal
ref={modalRef}
refreshTable={handleRefreshTable}

View File

@@ -1,5 +1,5 @@
import { useEffect, useState, useRef, type FC } from 'react';
import { useParams } from 'react-router-dom';
import { useParams, useLocation } from 'react-router-dom';
import { useTranslation } from 'react-i18next';
import type { KnowledgeBaseListItem, RecallTestDrawerRef } from '@/views/KnowledgeBase/types';
import RecallTest from '../components/RecallTest';
@@ -15,17 +15,22 @@ import kbModelIcon from '@/assets/images/knowledgeBase/kb-model.png';
import kbHistoryIcon from '@/assets/images/knowledgeBase/kb-history.png';
import { getKnowledgeBaseDetail } from '@/api/knowledgeBase';
import { formatDateTime } from '@/utils/format';
import { useMenu } from '@/store/menu';
import { useBreadcrumbManager, type BreadcrumbItem } from '@/hooks/useBreadcrumbManager';
const Share: FC = () => {
const { t } = useTranslation();
const params = useParams<{ knowledgeBaseId: string }>();
const location = useLocation();
const knowledgeBaseId = params.knowledgeBaseId;
const [loading, setLoading] = useState(false);
const [knowledgeBase, setKnowledgeBase] = useState<KnowledgeBaseListItem | null>(null);
const recallTestRef = useRef<RecallTestDrawerRef>(null);
const [infoItems, setInfoItems] = useState<InfoItem[]>([]);
const { allBreadcrumbs, setCustomBreadcrumbs } = useMenu();
const [knowledgeBaseFolderPath, setKnowledgeBaseFolderPath] = useState<BreadcrumbItem[]>([]);
const { updateBreadcrumbs } = useBreadcrumbManager({
breadcrumbType: 'detail'
});
useEffect(() => {
console.log('Share.tsx - useParams result:', params);
console.log('Share.tsx - knowledgeBaseId:', knowledgeBaseId);
@@ -46,9 +51,30 @@ const Share: FC = () => {
// 更新面包屑
useEffect(() => {
if (knowledgeBase) {
updateBreadcrumbs();
updateBreadcrumbs({
knowledgeBaseFolderPath,
knowledgeBase: {
id: knowledgeBase.id,
name: knowledgeBase.name,
type: 'knowledgeBase'
},
documentFolderPath: [],
});
}
}, [knowledgeBase]);
}, [knowledgeBase, knowledgeBaseFolderPath, updateBreadcrumbs]);
// 监听 location state 变化
useEffect(() => {
const state = location.state as {
fromKnowledgeBaseList?: boolean;
knowledgeBaseFolderPath?: BreadcrumbItem[];
} | null;
// 如果是从知识库列表页跳转过来的,设置知识库文件夹路径
if (state?.fromKnowledgeBaseList && state?.knowledgeBaseFolderPath) {
setKnowledgeBaseFolderPath(state.knowledgeBaseFolderPath);
}
}, [location.state]);
const formatInfoItems = (data: KnowledgeBaseListItem): InfoItem[] => {
const items: InfoItem[] = [
{
@@ -112,46 +138,7 @@ const Share: FC = () => {
});
};
// 更新面包屑,包含知识库名称
const updateBreadcrumbs = () => {
if (!knowledgeBase) return;
const baseBreadcrumbs = allBreadcrumbs['space'] || [];
// 只保留知识库菜单项之前的面包屑
const knowledgeBaseMenuIndex = baseBreadcrumbs.findIndex(item => item.path === '/knowledge-base');
const filteredBaseBreadcrumbs = knowledgeBaseMenuIndex >= 0
? baseBreadcrumbs.slice(0, knowledgeBaseMenuIndex + 1)
: baseBreadcrumbs;
const customBreadcrumbs = [
...filteredBaseBreadcrumbs,
{
id: 0,
parent: 0,
code: null,
label: knowledgeBase.name,
i18nKey: null,
path: null,
enable: true,
display: true,
level: 0,
sort: 0,
icon: null,
iconActive: null,
menuDesc: null,
deleted: null,
updateTime: 0,
new_: null,
keepAlive: false,
master: null,
disposable: false,
appSystem: null,
subs: [],
},
];
setCustomBreadcrumbs(customBreadcrumbs, 'space');
};
// const handleBack = () => {
// navigate('/knowledge-base');

View File

@@ -0,0 +1,117 @@
import { forwardRef, useImperativeHandle, useState } from 'react';
import { useNavigate } from 'react-router-dom';
import { Form, Input } from 'antd';
import { useTranslation } from 'react-i18next';
import RbModal from '@/components/RbModal';
import { createDocumentAndUpload } from '@/api/knowledgeBase'
import type { CreateSetModalRef,CreateSetMoealRefProps } from '../types'
interface ContentFormData {
title: string;
content: string;
}
const CreateContentModal = forwardRef<CreateSetModalRef, CreateSetMoealRefProps>(
({ refreshTable }, ref) => {
const { t } = useTranslation();
const navigate = useNavigate();
const [visible, setVisible] = useState(false);
const [form] = Form.useForm<ContentFormData>();
const [loading, setLoading] = useState(false);
const [kbId, setKbId] = useState<string>('');
const [parentId, setParentId] = useState<string>('');
const handleClose = () => {
form.resetFields();
setLoading(false);
setVisible(false);
setKbId('');
setParentId('');
};
const handleOpen = (kb_id: string, parent_id: string) => {
setKbId(kb_id);
setParentId(parent_id);
form.resetFields();
setVisible(true);
};
const handleSave = async () => {
try {
const values = await form.validateFields();
setLoading(true);
// TODO: 这里需要调用相应的API来保存内容
const params = {
// ...values,
kb_id: kbId,
parent_id: parentId,
};
const response = await createDocumentAndUpload(values, params)
if(response){
handleChunking(response.kb_id,parentId,response.id)
}
handleClose();
} catch (err) {
console.error('创建内容失败:', err);
} finally {
setLoading(false);
}
};
const handleChunking = (kb_id: string, parent_id: string, file_id: string) => {
if (!kb_id) return;
const targetFileId = file_id
navigate(`/knowledge-base/${kb_id}/create-dataset`, {
state: {
source: 'local',
knowledgeBaseId: kb_id,
parentId: parent_id ?? kb_id,
startStep: 'parameterSettings',
fileId: targetFileId,
},
});
}
useImperativeHandle(ref, () => ({
handleOpen,
}));
return (
<RbModal
title={t('knowledgeBase.createContent')}
open={visible}
onCancel={handleClose}
okText={t('common.create')}
onOk={handleSave}
confirmLoading={loading}
width={600}
>
<Form form={form} layout="vertical">
<Form.Item
name="title"
label={t('knowledgeBase.title')}
rules={[{ required: true, message: t('knowledgeBase.pleaseEnterTitle') }]}
>
<Input placeholder={t('knowledgeBase.pleaseEnterTitle')} />
</Form.Item>
<Form.Item
name="content"
label={t('knowledgeBase.content')}
rules={[{ required: true, message: t('knowledgeBase.pleaseEnterContent') }]}
>
<Input.TextArea
placeholder={t('knowledgeBase.pleaseEnterContent')}
rows={8}
showCount
maxLength={5000}
/>
</Form.Item>
</Form>
</RbModal>
);
}
);
export default CreateContentModal;

View File

@@ -0,0 +1,34 @@
import { useRef } from 'react';
import { Button } from 'antd';
import CreateContentModal from './CreateContentModal';
import type { CreateContentModalRef } from '../types';
// 使用示例组件
const CreateContentModalExample = () => {
const createContentModalRef = useRef<CreateContentModalRef>(null);
const handleOpenModal = () => {
// 打开弹窗传入知识库ID和父级ID
createContentModalRef.current?.handleOpen('kb_123', 'parent_456');
};
const handleRefreshTable = () => {
console.log('刷新表格数据');
// 这里可以添加刷新表格的逻辑
};
return (
<div>
<Button type="primary" onClick={handleOpenModal}>
</Button>
<CreateContentModal
ref={createContentModalRef}
refreshTable={handleRefreshTable}
/>
</div>
);
};
export default CreateContentModalExample;

View File

@@ -2,7 +2,7 @@ import { forwardRef, useImperativeHandle, useState, useRef } from 'react';
import { Form, Input } from 'antd';
import { useTranslation } from 'react-i18next';
import type { UploadFile } from 'antd';
import type { CreateImageModalRef, CreateImageMoealRefProps,UploadFileResponse } from '@/views/KnowledgeBase/types';
import type { CreateSetModalRef, CreateSetMoealRefProps, UploadFileResponse } from '@/views/KnowledgeBase/types';
import type { UploadRequestOption } from 'rc-upload/lib/interface';
import RbModal from '@/components/RbModal';
import UploadFiles from '@/components/Upload/UploadFiles';
@@ -13,7 +13,7 @@ interface ImageDatasetFormData {
images: UploadFile[];
}
const CreateImageDataset = forwardRef<CreateImageModalRef, CreateImageMoealRefProps>(
const CreateImageDataset = forwardRef<CreateSetModalRef, CreateSetMoealRefProps>(
({ refreshTable }, ref) => {
const { t } = useTranslation();
const [visible, setVisible] = useState(false);

View File

@@ -60,6 +60,8 @@ interface FolderTreeProps {
onRootLoad?: (nodes: TreeNodeData[] | null) => void;
onFolderPathChange?: (path: Array<{ id: string; name: string }>) => void;
selectedKeys?: React.Key[];
// 新增:自动展开到指定路径
autoExpandPath?: Array<{ id: string; name: string }>;
}
const renderIcon = (icon?: string) => {
@@ -275,8 +277,11 @@ const FolderTree: FC<FolderTreeProps> = ({
onRootLoad,
onFolderPathChange,
selectedKeys,
autoExpandPath,
}) => {
const [treeData, setTreeData] = useState<TreeNodeData[]>([]);
const [expandedKeys, setExpandedKeys] = useState<React.Key[]>([]);
const [autoExpandInProgress, setAutoExpandInProgress] = useState(false);
// 更新树节点数据的辅助函数
const updateTreeData = (nodes: TreeNodeData[], key: Key, children: TreeNodeData[]): TreeNodeData[] => {
@@ -370,6 +375,109 @@ const FolderTree: FC<FolderTreeProps> = ({
return null;
};
// 查找节点的辅助函数
const findNodeInTree = (nodes: TreeNodeData[], key: string): TreeNodeData | null => {
for (const node of nodes) {
if (String(node.key) === key) {
return node;
}
if (node.children) {
const found = findNodeInTree(node.children, key);
if (found) return found;
}
}
return null;
};
// 渐进式自动展开到指定路径
useEffect(() => {
if (!autoExpandPath || autoExpandPath.length === 0 || autoExpandInProgress || treeData.length === 0) {
return;
}
const expandToPath = async () => {
setAutoExpandInProgress(true);
try {
const keysToExpand: React.Key[] = [];
let currentTreeData = treeData;
// 逐级展开,从第一级开始(跳过根节点,因为根节点已经加载)
for (let i = 0; i < autoExpandPath.length - 1; i++) {
const nodeKey = autoExpandPath[i].id;
keysToExpand.push(nodeKey);
// 查找当前节点
const targetNode = findNodeInTree(currentTreeData, nodeKey);
if (targetNode && targetNode.children === undefined) {
// 如果子节点未加载,先加载
try {
console.log(`自动展开:加载节点 ${nodeKey} 的子节点`);
const children = await buildTreeNodes(knowledgeBaseId, nodeKey);
// 更新树数据
setTreeData((prevData) => {
const newData = updateTreeData(prevData, nodeKey, children);
currentTreeData = newData; // 更新当前引用
return newData;
});
// 等待状态更新完成
await new Promise(resolve => setTimeout(resolve, 150));
} catch (error) {
console.error(`自动展开时加载节点 ${nodeKey} 失败:`, error);
// 加载失败时停止展开
break;
}
}
}
// 设置展开的节点
setExpandedKeys(keysToExpand);
// 选中最后一个节点(目标文件夹)
const targetKey = autoExpandPath[autoExpandPath.length - 1]?.id;
if (targetKey) {
console.log(`自动展开:选中目标节点 ${targetKey}`);
// 延迟选中,确保展开动画完成
setTimeout(() => {
if (onSelect) {
onSelect([targetKey], {
selected: true,
selectedNodes: [],
node: {} as any,
event: 'select',
nativeEvent: new MouseEvent('click')
});
}
}, 200);
}
} catch (error) {
console.error('自动展开路径失败:', error);
} finally {
// 延迟重置标志,确保展开过程完全完成
setTimeout(() => {
setAutoExpandInProgress(false);
}, 500);
}
};
// 延迟执行,确保树数据已经加载完成
const timer = setTimeout(expandToPath, 300);
return () => clearTimeout(timer);
}, [autoExpandPath, treeData.length, knowledgeBaseId, onSelect, autoExpandInProgress]);
// 处理展开事件
const handleExpand: TreeProps['onExpand'] = (expandedKeys, info) => {
setExpandedKeys(expandedKeys);
if (onExpand) {
onExpand(expandedKeys, info);
}
};
// 处理选择事件,计算并传递路径
const handleSelect: TreeProps['onSelect'] = (selectedKeys, info) => {
if (selectedKeys.length > 0) {
@@ -391,11 +499,13 @@ const FolderTree: FC<FolderTreeProps> = ({
return (
<DirectoryTree
key={refreshKey} // 添加key确保refreshKey变化时重新渲染整个组件
multiple={multiple}
className={className}
style={style}
onSelect={handleSelect}
onExpand={onExpand}
onExpand={handleExpand}
expandedKeys={expandedKeys}
loadData={onLoadData}
treeData={treeNodes}
selectedKeys={selectedKeys}

View File

@@ -1,8 +1,8 @@
import { useEffect, useState, useRef, useMemo, type FC } from 'react';
import { useEffect, useState, useRef, useMemo, useCallback, type FC } from 'react';
import { Row, Col, Button, Dropdown, Modal, message, Tooltip } from 'antd'
import type { MenuProps } from 'antd';
import { EllipsisOutlined } from '@ant-design/icons';
import { useNavigate } from 'react-router-dom';
import { useNavigate, useLocation } from 'react-router-dom';
import { useTranslation } from 'react-i18next';
import clsx from 'clsx';
@@ -18,7 +18,8 @@ import Empty from '@/components/Empty'
import { getKnowledgeBaseList, getModelList, getModelTypeList, deleteKnowledgeBase, getKnowledgeBaseTypeList } from '@/api/knowledgeBase'
const { confirm } = Modal;
import InfiniteScroll from 'react-infinite-scroll-component';
import { useMenu } from '@/store/menu';
import { useBreadcrumbManager, type BreadcrumbItem } from '@/hooks/useBreadcrumbManager';
type ModelMenuInfo = {
menu: NonNullable<MenuProps['items']>;
@@ -28,6 +29,7 @@ type ModelMenuInfo = {
const KnowledgeBaseManagement: FC = () => {
const { t } = useTranslation();
const navigate = useNavigate();
const location = useLocation();
const [loading, setLoading] = useState(false);
const [data, setData] = useState<KnowledgeBaseListItem[]>([])
const [page, setPage] = useState(1)
@@ -42,10 +44,29 @@ const KnowledgeBaseManagement: FC = () => {
const modelListCache = useRef<Record<string, string>>({});
const modalRef = useRef<CreateModalRef>(null)
const [messageApi, contextHolder] = message.useMessage();
const processedStateRef = useRef<any>(null);
// 使用 menu store 管理面包屑
const { allBreadcrumbs, setCustomBreadcrumbs } = useMenu();
const [folderPath, setFolderPath] = useState<Array<{ id: string; name: string }>>([]);
// 使用面包屑管理 Hook
const { updateBreadcrumbs } = useBreadcrumbManager({
breadcrumbType: 'list',
onKnowledgeBaseMenuClick: useCallback(() => {
// 返回根目录
setFolderPath([]);
setQuery((prev) => ({
...prev,
parent_id: undefined,
}));
}, []),
onKnowledgeBaseFolderClick: useCallback((folderId: string, folderPath: Array<{ id: string; name: string }>) => {
// 直接更新文件夹路径和查询状态
setFolderPath(folderPath);
setQuery((prev) => ({
...prev,
parent_id: folderId,
}));
}, [])
});
const [folderPath, setFolderPath] = useState<BreadcrumbItem[]>([]);
// 生成下拉菜单项(根据当前 item
@@ -134,7 +155,7 @@ const KnowledgeBaseManagement: FC = () => {
handleCreate(type);
},
}));
}, [knowledgeBaseTypes, t, folderPath, query]);
}, [knowledgeBaseTypes, t]);
const typeToFieldKey = (type: string) => {
const normalized = (type || '').toLowerCase();
switch (normalized) {
@@ -371,90 +392,72 @@ const KnowledgeBaseManagement: FC = () => {
}));
return;
}
// 根据权限类型跳转到不同的详情页
if (knowledgeBase.permission_id === 'Private' || knowledgeBase.permission_id === 'private') {
navigate(`/knowledge-base/${knowledgeBase.id}/private`)
// 跳转时传递当前的文件夹路径信息
const navigationState = {
fromKnowledgeBaseList: true,
knowledgeBaseFolderPath: folderPath,
parentId: query.parent_id,
timestamp: Date.now(), // 添加时间戳确保每次跳转状态都不同
};
const targetPath = knowledgeBase.permission_id === 'Private' || knowledgeBase.permission_id === 'private'
? `/knowledge-base/${knowledgeBase.id}/private`
: `/knowledge-base/${knowledgeBase.id}/share`;
// 检查是否是相同路径跳转
const currentPath = location.pathname;
if (currentPath === targetPath) {
// 如果是相同路径使用replace并强制刷新状态
navigate(targetPath, {
state: navigationState,
replace: true
});
} else {
navigate(`/knowledge-base/${knowledgeBase.id}/share`)
// 不同路径,正常跳转
navigate(targetPath, { state: navigationState });
}
}
// 更新面包屑的函数
const updateBreadcrumbs = () => {
const baseBreadcrumbs = allBreadcrumbs['space'] || [];
// 只保留知识库菜单项之前的面包屑
const knowledgeBaseMenuIndex = baseBreadcrumbs.findIndex(item => item.path === '/knowledge-base');
const filteredBaseBreadcrumbs = knowledgeBaseMenuIndex >= 0
? baseBreadcrumbs.slice(0, knowledgeBaseMenuIndex + 1)
: baseBreadcrumbs;
// 给"知识库管理"添加点击事件,返回根目录
const breadcrumbsWithClick = filteredBaseBreadcrumbs.map((item) => {
if (item.path === '/knowledge-base') {
return {
...item,
onClick: (e?: React.MouseEvent) => {
e?.preventDefault();
e?.stopPropagation();
// 返回根目录
setFolderPath([]);
setQuery((prev) => ({
...prev,
parent_id: undefined,
}));
return false;
},
};
}
return item;
});
const customBreadcrumbs = [
...breadcrumbsWithClick,
...folderPath.map((folder, index) => ({
id: 0,
parent: 0,
code: null,
label: folder.name,
i18nKey: null,
path: null,
enable: true,
display: true,
level: 0,
sort: 0,
icon: null,
iconActive: null,
menuDesc: null,
deleted: null,
updateTime: 0,
new_: null,
keepAlive: false,
master: null,
disposable: false,
appSystem: null,
subs: [],
onClick: (e?: React.MouseEvent) => {
e?.preventDefault();
e?.stopPropagation();
// 点击文件夹,回到该文件夹层级
const newFolderPath = folderPath.slice(0, index + 1);
setFolderPath(newFolderPath);
setQuery((prev) => ({
...prev,
parent_id: folder.id,
}));
return false;
},
})),
];
setCustomBreadcrumbs(customBreadcrumbs, 'space');
};
// 更新面包屑
useEffect(() => {
updateBreadcrumbs();
}, [folderPath]);
updateBreadcrumbs({
knowledgeBaseFolderPath: folderPath,
documentFolderPath: [],
});
}, [folderPath, updateBreadcrumbs]);
// 处理从详情页返回的导航
useEffect(() => {
const state = location.state as {
navigateToFolder?: string;
folderPath?: Array<{ id: string; name: string }>;
resetToRoot?: boolean;
} | null;
// 避免重复处理相同的状态
if (state && state !== processedStateRef.current) {
processedStateRef.current = state;
if (state.resetToRoot) {
// 重置到根目录
setFolderPath([]);
setQuery((prev) => ({
...prev,
parent_id: undefined,
}));
} else if (state?.navigateToFolder && state?.folderPath) {
// 恢复文件夹路径和查询状态
setFolderPath(state.folderPath);
setQuery((prev) => ({
...prev,
parent_id: state.navigateToFolder,
}));
}
// 不清除 state避免干扰后续导航
// 使用 processedStateRef 来避免重复处理相同的 state
}
}, [location.state, navigate]);
useEffect(() => {
fetchModelTypes();
@@ -465,7 +468,7 @@ const KnowledgeBaseManagement: FC = () => {
if (modelTypes.length) {
fetchData(1, false);
}
}, [modelTypes, query])
}, [modelTypes, query.parent_id, query.keywords, query.orderby, query.desc])
return (
<>

View File

@@ -146,11 +146,19 @@ export interface CreateFolderModalRefProps{
refreshTable?: () => void;
}
//建图片数据集
export interface CreateImageModalRef{
handleOpen: (kb_id:string,parent_id:string) => void;
//建图片数据集 / 创建自定义文本数据集
export interface CreateSetModalRef{
handleOpen: (kb_id:string, parent_id:string) => void;
}
export interface CreateImageMoealRefProps{
export interface CreateSetMoealRefProps{
refreshTable?: () => void;
}
// 创建内容
export interface CreateContentModalRef {
handleOpen: (kb_id: string, parent_id: string) => void;
}
export interface CreateContentModalRefProps {
refreshTable?: () => void;
}