diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 3e7db8cb..002547f6 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -7,6 +7,10 @@ from celery import Celery from app.core.config import settings +# macOS fork() safety - must be set before any Celery initialization +if platform.system() == 'Darwin': + os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') + # 创建 Celery 应用实例 # broker: 任务队列(使用 Redis DB 0) # backend: 结果存储(使用 Redis DB 10) @@ -64,6 +68,11 @@ celery_app.conf.update( 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, + # Long-term storage tasks → memory_tasks queue (batched write strategies) + 'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'}, + # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index c4a2f984..5831586c 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -41,9 +41,9 @@ from . import ( upload_controller, user_controller, user_memory_controllers, - workflow_controller, workspace_controller, ontology_controller, + skill_controller ) # 创建管理端 API 路由器 @@ -76,7 +76,6 @@ manager_router.include_router(release_share_controller.router) manager_router.include_router(public_share_controller.router) # 公开路由(无需认证) manager_router.include_router(memory_dashboard_controller.router) manager_router.include_router(multi_agent_controller.router) -manager_router.include_router(workflow_controller.router) manager_router.include_router(emotion_controller.router) manager_router.include_router(emotion_config_controller.router) manager_router.include_router(prompt_optimizer_controller.router) @@ -90,5 +89,6 @@ manager_router.include_router(memory_perceptual_controller.router) manager_router.include_router(memory_working_controller.router) manager_router.include_router(file_storage_controller.router) manager_router.include_router(ontology_controller.router) +manager_router.include_router(skill_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 71e6e7ca..cf40e99a 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -22,6 +22,7 @@ from app.services import app_service, workspace_service from app.services.agent_config_helper import enrich_agent_config from app.services.app_service import AppService from app.services.workflow_service import WorkflowService, get_workflow_service +from app.services.app_statistics_service import AppStatisticsService router = APIRouter(prefix="/apps", tags=["Apps"]) logger = get_business_logger() @@ -904,8 +905,6 @@ def get_app_statistics( - total_tokens: 总token消耗 """ workspace_id = current_user.current_workspace_id - - from app.services.app_statistics_service import AppStatisticsService stats_service = AppStatisticsService(db) result = stats_service.get_app_statistics( @@ -916,3 +915,36 @@ def get_app_statistics( ) return success(data=result) + + +@router.get("/workspace/api-statistics", summary="工作空间API调用统计") +@cur_workspace_access_guard() +def get_workspace_api_statistics( + start_date: int, + end_date: int, + db: Session = Depends(get_db), + current_user=Depends(get_current_user), +): + """获取工作空间API调用统计 + + Args: + start_date: 开始时间戳(毫秒) + end_date: 结束时间戳(毫秒) + + Returns: + 每日统计数据列表,每项包含: + - date: 日期 + - total_calls: 当日总调用次数 + - app_calls: 当日应用调用次数 + - service_calls: 当日服务调用次数 + """ + workspace_id = current_user.current_workspace_id + stats_service = AppStatisticsService(db) + + result = stats_service.get_workspace_api_statistics( + workspace_id=workspace_id, + start_date=start_date, + end_date=end_date + ) + + return success(data=result) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 94e3118c..f36aa6c5 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -116,14 +116,6 @@ def _get_ontology_service( detail=f"找不到指定的LLM模型: {llm_id}" ) - # 检查是否为组合模型 - if hasattr(model_config, 'is_composite') and model_config.is_composite: - logger.error(f"Model {llm_id} is a composite model, which is not supported for ontology extraction") - raise HTTPException( - status_code=400, - detail="本体提取不支持使用组合模型,请选择单个模型" - ) - # 验证模型配置了API密钥 if not model_config.api_keys: logger.error(f"Model {llm_id} has no API key configuration") diff --git a/api/app/controllers/prompt_optimizer_controller.py b/api/app/controllers/prompt_optimizer_controller.py index 61195deb..80f14cd3 100644 --- a/api/app/controllers/prompt_optimizer_controller.py +++ b/api/app/controllers/prompt_optimizer_controller.py @@ -120,7 +120,8 @@ async def get_prompt_opt( session_id=session_id, user_id=current_user.id, current_prompt=data.current_prompt, - user_require=data.message + user_require=data.message, + skill=data.skill ): # chunk 是 prompt 的增量内容 yield f"event:message\ndata: {json.dumps(chunk)}\n\n" diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 536dffd9..9435fc9b 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -587,7 +587,8 @@ async def chat( user_rag_memory_id=user_rag_memory_id, app_id=release.app_id, workspace_id=workspace_id, - release_id=release.id + release_id=release.id, + public=True ): event_type = event.get("event", "message") event_data = event.get("data", {}) diff --git a/api/app/controllers/skill_controller.py b/api/app/controllers/skill_controller.py new file mode 100644 index 00000000..2308307b --- /dev/null +++ b/api/app/controllers/skill_controller.py @@ -0,0 +1,90 @@ +"""Skill Controller - 技能市场管理""" +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session +from typing import Optional +import uuid + +from app.db import get_db +from app.dependencies import get_current_user, cur_workspace_access_guard +from app.models import User +from app.schemas import skill_schema +from app.schemas.response_schema import PageData, PageMeta +from app.services.skill_service import SkillService +from app.core.response_utils import success + +router = APIRouter(prefix="/skills", tags=["Skills"]) + + +@router.post("", summary="创建技能") +@cur_workspace_access_guard() +def create_skill( + data: skill_schema.SkillCreate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """创建技能 - 可以关联现有工具(内置、MCP、自定义)""" + tenant_id = current_user.tenant_id + skill = SkillService.create_skill(db, data, tenant_id) + return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功") + + +@router.get("", summary="技能列表") +@cur_workspace_access_guard() +def list_skills( + search: Optional[str] = Query(None, description="搜索关键词"), + is_active: Optional[bool] = Query(None, description="是否激活"), + is_public: Optional[bool] = Query(None, description="是否公开"), + page: int = Query(1, ge=1, description="页码"), + pagesize: int = Query(10, ge=1, le=100, description="每页数量"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """技能市场列表 - 包含本工作空间和公开的技能""" + tenant_id = current_user.tenant_id + skills, total = SkillService.list_skills( + db, tenant_id, search, is_active, is_public, page, pagesize + ) + + items = [skill_schema.Skill.model_validate(s) for s in skills] + meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) + return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功") + + +@router.get("/{skill_id}", summary="获取技能详情") +@cur_workspace_access_guard() +def get_skill( + skill_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """获取技能详情""" + tenant_id = current_user.tenant_id + skill = SkillService.get_skill(db, skill_id, tenant_id) + return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功") + + +@router.put("/{skill_id}", summary="更新技能") +@cur_workspace_access_guard() +def update_skill( + skill_id: uuid.UUID, + data: skill_schema.SkillUpdate, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """更新技能""" + tenant_id = current_user.tenant_id + skill = SkillService.update_skill(db, skill_id, data, tenant_id) + return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功") + + +@router.delete("/{skill_id}", summary="删除技能") +@cur_workspace_access_guard() +def delete_skill( + skill_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """删除技能""" + tenant_id = current_user.tenant_id + SkillService.delete_skill(db, skill_id, tenant_id) + return success(msg="技能删除成功") diff --git a/api/app/controllers/workflow_controller.py b/api/app/controllers/workflow_controller.py deleted file mode 100644 index 8a15f717..00000000 --- a/api/app/controllers/workflow_controller.py +++ /dev/null @@ -1,610 +0,0 @@ -""" -工作流 API 控制器 -""" - -import logging -import uuid -from typing import Annotated - -from fastapi import APIRouter, Depends, Path, Query -from sqlalchemy.orm import Session - -from app.db import get_db -from app.dependencies import get_current_user, cur_workspace_access_guard - -from app.models.user_model import User -from app.models.app_model import App -from app.services.workflow_service import WorkflowService, get_workflow_service -from app.schemas.workflow_schema import ( - WorkflowConfigCreate, - WorkflowConfigUpdate, - WorkflowConfig, - WorkflowValidationResponse, - WorkflowExecution, - WorkflowNodeExecution, - WorkflowExecutionRequest, - WorkflowExecutionResponse -) -from app.core.response_utils import success, fail -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/apps", tags=["workflow"]) - - -# ==================== 工作流配置管理 ==================== - -@router.post("/{app_id}/workflow") -@cur_workspace_access_guard() -async def create_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - config: WorkflowConfigCreate, - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] -): - """创建工作流配置 - - 创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。 - """ - try: - # 验证应用是否存在且属于当前工作空间 - app = db.query(App).filter( - App.id == app_id, - App.workspace_id == current_user.current_workspace_id, - App.is_active.is_(True) - ).first() - - if not app: - return fail( - code=BizCode.NOT_FOUND, - msg="应用不存在或无权访问" - ) - - # 验证应用类型 - if app.type != "workflow": - return fail( - code=BizCode.INVALID_PARAMETER, - msg=f"应用类型必须为 workflow,当前为 {app.type}" - ) - - # 创建工作流配置 - workflow_config = service.create_workflow_config( - app_id=app_id, - nodes=[node.model_dump() for node in config.nodes], - edges=[edge.model_dump() for edge in config.edges], - variables=[var.model_dump() for var in config.variables], - execution_config=config.execution_config.model_dump(), - triggers=[trigger.model_dump() for trigger in config.triggers], - validate=True # 进行基础验证 - ) - - return success( - data=WorkflowConfig.model_validate(workflow_config), - msg="工作流配置创建成功" - ) - - except BusinessException as e: - logger.warning(f"创建工作流配置失败: {e.message}") - return fail(code=e.error_code, msg=e.message) - except Exception as e: - logger.error(f"创建工作流配置异常: {e}", exc_info=True) - return fail( - code=BizCode.INTERNAL_ERROR, - msg=f"创建工作流配置失败: {str(e)}" - ) - - -# -# @router.get("/{app_id}/workflow") -# async def get_workflow_config( -# app_id: Annotated[uuid.UUID, Path(description="应用 ID")], -# db: Annotated[Session, Depends(get_db)], -# current_user: Annotated[User, Depends(get_current_user)] -# -# ): -# """获取工作流配置 -# -# 获取应用的工作流配置详情。 -# """ -# try: -# # 验证应用是否存在且属于当前工作空间 -# app = db.query(App).filter( -# App.id == app_id, -# App.workspace_id == current_user.current_workspace_id, -# App.is_active == True -# ).first() -# -# if not app: -# return fail( -# code=BizCode.NOT_FOUND, -# msg="应用不存在或无权访问" -# ) -# -# # 获取工作流配置 -# service = WorkflowService(db) -# workflow_config = service.get_workflow_config(app_id) -# -# if not workflow_config: -# return fail( -# code=BizCode.NOT_FOUND, -# msg="工作流配置不存在" -# ) -# -# return success( -# data=WorkflowConfig.model_validate(workflow_config) -# ) -# -# except Exception as e: -# logger.error(f"获取工作流配置异常: {e}", exc_info=True) -# return fail( -# code=BizCode.INTERNAL_ERROR, -# msg=f"获取工作流配置失败: {str(e)}" -# ) - - -# @router.put("/{app_id}/workflow") -# async def update_workflow_config( -# app_id: Annotated[uuid.UUID, Path(description="应用 ID")], -# config: WorkflowConfigUpdate, -# db: Annotated[Session, Depends(get_db)], -# current_user: Annotated[User, Depends(get_current_user)], -# service: Annotated[WorkflowService, Depends(get_workflow_service)] -# ): -# """更新工作流配置 - -# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。 -# """ -# try: -# # 验证应用是否存在且属于当前工作空间 -# app = db.query(App).filter( -# App.id == app_id, -# App.workspace_id == current_user.current_workspace_id, -# App.is_active == True -# ).first() - -# if not app: -# return fail( -# code=BizCode.NOT_FOUND, -# msg="应用不存在或无权访问" -# ) - -# # 更新工作流配置 -# workflow_config = service.update_workflow_config( -# app_id=app_id, -# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None, -# edges=[edge.model_dump() for edge in config.edges] if config.edges else None, -# variables=[var.model_dump() for var in config.variables] if config.variables else None, -# execution_config=config.execution_config.model_dump() if config.execution_config else None, -# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None, -# validate=True -# ) - -# return success( -# data=WorkflowConfig.model_validate(workflow_config), -# msg="工作流配置更新成功" -# ) - -# except BusinessException as e: -# logger.warning(f"更新工作流配置失败: {e.message}") -# return fail(code=e.error_code, msg=e.message) -# except Exception as e: -# logger.error(f"更新工作流配置异常: {e}", exc_info=True) -# return fail( -# code=BizCode.INTERNAL_ERROR, -# msg=f"更新工作流配置失败: {str(e)}" -# ) - - -@router.delete("/{app_id}/workflow") -async def delete_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] -): - """删除工作流配置 - - 删除应用的工作流配置。 - """ - try: - # 验证应用是否存在且属于当前工作空间 - app = db.query(App).filter( - App.id == app_id, - App.workspace_id == current_user.current_workspace_id, - App.is_active.is_(True) - ).first() - - if not app: - return fail( - code=BizCode.NOT_FOUND, - msg="应用不存在或无权访问" - ) - - # 删除工作流配置 - deleted = service.delete_workflow_config(app_id) - - if not deleted: - return fail( - code=BizCode.NOT_FOUND, - msg="工作流配置不存在" - ) - - return success(msg="工作流配置删除成功") - - except Exception as e: - logger.error(f"删除工作流配置异常: {e}", exc_info=True) - return fail( - code=BizCode.INTERNAL_ERROR, - msg=f"删除工作流配置失败: {str(e)}" - ) - - -@router.post("/{app_id}/workflow/validate") -async def validate_workflow_config( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)], - for_publish: Annotated[bool, Query(description="是否为发布验证")] = False -): - """验证工作流配置 - - 验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。 - """ - try: - # 验证应用是否存在且属于当前工作空间 - app = db.query(App).filter( - App.id == app_id, - App.workspace_id == current_user.current_workspace_id, - App.is_active.is_(True) - ).first() - - if not app: - return fail( - code=BizCode.NOT_FOUND, - msg="应用不存在或无权访问" - ) - - # 验证工作流配置 - - if for_publish: - is_valid, errors = service.validate_workflow_config_for_publish(app_id) - else: - workflow_config = service.get_workflow_config(app_id) - if not workflow_config: - return fail( - code=BizCode.NOT_FOUND, - msg="工作流配置不存在" - ) - - from app.core.workflow.validator import validate_workflow_config as validate_config - config_dict = { - "nodes": workflow_config.nodes, - "edges": workflow_config.edges, - "variables": workflow_config.variables, - "execution_config": workflow_config.execution_config, - "triggers": workflow_config.triggers - } - is_valid, errors = validate_config(config_dict, for_publish=False) - - return success( - data=WorkflowValidationResponse( - is_valid=is_valid, - errors=errors, - warnings=[] - ) - ) - - except BusinessException as e: - logger.warning(f"验证工作流配置失败: {e.message}") - return fail(code=e.error_code, msg=e.message) - except Exception as e: - logger.error(f"验证工作流配置异常: {e}", exc_info=True) - return fail( - code=BizCode.INTERNAL_ERROR, - msg=f"验证工作流配置失败: {str(e)}" - ) - - -# ==================== 工作流执行管理 ==================== - -@router.get("/{app_id}/workflow/executions") -async def get_workflow_executions( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)], - limit: Annotated[int, Query(ge=1, le=100)] = 50, - offset: Annotated[int, Query(ge=0)] = 0 -): - """获取工作流执行记录列表 - - 获取应用的工作流执行历史记录。 - """ - try: - # 验证应用是否存在且属于当前工作空间 - app = db.query(App).filter( - App.id == app_id, - App.workspace_id == current_user.current_workspace_id, - App.is_active.is_(True) - ).first() - - if not app: - return fail( - code=BizCode.NOT_FOUND, - msg="应用不存在或无权访问" - ) - - # 获取执行记录 - executions = service.get_executions_by_app(app_id, limit, offset) - - # 获取统计信息 - statistics = service.get_execution_statistics(app_id) - - return success( - data={ - "executions": [WorkflowExecution.model_validate(e) for e in executions], - "statistics": statistics, - "pagination": { - "limit": limit, - "offset": offset, - "total": statistics["total"] - } - } - ) - - except Exception as e: - logger.error(f"获取工作流执行记录异常: {e}", exc_info=True) - return fail( - code=BizCode.INTERNAL_ERROR, - msg=f"获取工作流执行记录失败: {str(e)}" - ) - - -@router.get("/workflow/executions/{execution_id}") -async def get_workflow_execution( - execution_id: Annotated[str, Path(description="执行 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] -): - """获取工作流执行详情 - - 获取单个工作流执行的详细信息,包括所有节点的执行记录。 - """ - try: - # 获取执行记录 - execution = service.get_execution(execution_id) - - if not execution: - return fail( - code=BizCode.NOT_FOUND, - msg="执行记录不存在" - ) - - # 验证应用是否属于当前工作空间 - app = db.query(App).filter( - App.id == execution.app_id, - App.workspace_id == current_user.current_workspace_id, - App.is_active.is_(True) - ).first() - - if not app: - return fail( - code=BizCode.NOT_FOUND, - msg="无权访问该执行记录" - ) - - # 获取节点执行记录 - node_executions = service.node_execution_repo.get_by_execution_id(execution.id) - - return success( - data={ - "execution": WorkflowExecution.model_validate(execution), - "node_executions": [ - WorkflowNodeExecution.model_validate(ne) for ne in node_executions - ] - } - ) - - except Exception as e: - logger.error(f"获取工作流执行详情异常: {e}", exc_info=True) - return fail( - code=BizCode.INTERNAL_ERROR, - msg=f"获取工作流执行详情失败: {str(e)}" - ) - - -# ==================== 工作流执行 ==================== -@router.post("/{app_id}/workflow/run") -async def run_workflow( - app_id: Annotated[uuid.UUID, Path(description="应用 ID")], - request: WorkflowExecutionRequest, - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] -): - """执行工作流 - - 执行工作流并返回结果。支持流式和非流式两种模式。 - - **非流式模式**:等待工作流执行完成后返回完整结果。 - - **流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。 - """ - try: - # 验证应用是否存在且属于当前工作空间 - app = db.query(App).filter( - App.id == app_id, - App.workspace_id == current_user.current_workspace_id, - App.is_active.is_(True) - ).first() - - if not app: - return fail( - code=BizCode.NOT_FOUND, - msg="应用不存在或无权访问" - ) - - # 验证应用类型 - if app.type != "workflow": - return fail( - code=BizCode.INVALID_PARAMETER, - msg=f"应用类型必须为 workflow,当前为 {app.type}" - ) - - # 准备输入数据 - input_data = { - "message": request.message or "", - "variables": request.variables - } - - # 执行工作流 - - if request.stream: - # 流式执行 - from fastapi.responses import StreamingResponse - import json - - async def event_generator(): - """生成 SSE 事件 - - SSE 格式: - event: - data: - - 支持的事件类型: - - workflow_start: 工作流开始 - - workflow_end: 工作流结束 - - node_start: 节点开始执行 - - node_end: 节点执行完成 - - node_chunk: 中间节点的流式输出 - - message: 最终消息的流式输出(End 节点及其相邻节点) - """ - try: - async for event in await service.run_workflow( - app_id=app_id, - input_data=input_data, - triggered_by=current_user.id, - conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, - stream=True - ): - # 提取事件类型和数据 - event_type = event.get("event", "message") - event_data = event.get("data", {}) - - # 转换为标准 SSE 格式(字符串) - # event: - # data: - sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n" - yield sse_message - - except Exception as e: - logger.error(f"流式执行异常: {e}", exc_info=True) - # 发送错误事件 - sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n" - yield sse_error - - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no" # 禁用 nginx 缓冲 - } - ) - else: - # 非流式执行 - result = await service.run_workflow( - app_id=app_id, - input_data=input_data, - triggered_by=current_user.id, - conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None, - stream=False - ) - - return success( - data=WorkflowExecutionResponse( - execution_id=result["execution_id"], - status=result["status"], - output=result.get("output"), - output_data=result.get("output_data"), - error_message=result.get("error_message"), - elapsed_time=result.get("elapsed_time"), - token_usage=result.get("token_usage") - ), - msg="工作流执行完成" - ) - - except BusinessException as e: - logger.warning(f"执行工作流失败: {e.message}") - return fail(code=e.error_code, msg=e.message) - except Exception as e: - logger.error(f"执行工作流异常: {e}", exc_info=True) - return fail( - code=BizCode.INTERNAL_ERROR, - msg=f"执行工作流失败: {str(e)}" - ) - - -@router.post("/workflow/executions/{execution_id}/cancel") -async def cancel_workflow_execution( - execution_id: Annotated[str, Path(description="执行 ID")], - db: Annotated[Session, Depends(get_db)], - current_user: Annotated[User, Depends(get_current_user)], - service: Annotated[WorkflowService, Depends(get_workflow_service)] -): - """取消工作流执行 - - 取消正在运行的工作流执行。 - - **注意**:当前版本仅更新状态为 cancelled,实际的执行取消功能待实现。 - """ - try: - # 获取执行记录 - execution = service.get_execution(execution_id) - - if not execution: - return fail( - code=BizCode.NOT_FOUND, - msg="执行记录不存在" - ) - - # 验证应用是否属于当前工作空间 - app = db.query(App).filter( - App.id == execution.app_id, - App.workspace_id == current_user.current_workspace_id, - App.is_active.is_(True) - ).first() - - if not app: - return fail( - code=BizCode.NOT_FOUND, - msg="无权访问该执行记录" - ) - - # 检查执行状态 - if execution.status not in ["pending", "running"]: - return fail( - code=BizCode.INVALID_PARAMETER, - msg=f"无法取消状态为 {execution.status} 的执行" - ) - - # 更新状态为 cancelled - service.update_execution_status(execution_id, "cancelled") - - return success(msg="工作流执行已取消") - - except BusinessException as e: - logger.warning(f"取消工作流执行失败: {e.message}") - return fail(code=e.code, msg=e.message) - except Exception as e: - logger.error(f"取消工作流执行异常: {e}", exc_info=True) - return fail( - code=BizCode.INTERNAL_ERROR, - msg=f"取消工作流执行失败: {str(e)}" - ) diff --git a/api/app/core/agent/agent_middleware.py b/api/app/core/agent/agent_middleware.py new file mode 100644 index 00000000..735423c9 --- /dev/null +++ b/api/app/core/agent/agent_middleware.py @@ -0,0 +1,162 @@ +"""Agent Middleware - 动态技能过滤""" +import uuid +from typing import List, Dict, Any, Optional +from langchain_core.runnables import RunnablePassthrough + +from app.services.skill_service import SkillService +from app.repositories.skill_repository import SkillRepository + + +class AgentMiddleware: + """Agent 中间件 - 用于动态过滤和加载技能""" + + def __init__(self, skills: Optional[dict] = None): + """ + 初始化中间件 + + Args: + skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]} + """ + self.skills = skills or {} + self.enabled = self.skills.get('enabled', False) + self.all_skills = self.skills.get('all_skills', False) + self.skill_ids = self.skills.get('skill_ids', []) + + @staticmethod + def filter_tools( + tools: List, + message: str = "", + skill_configs: Dict[str, Any] = None, + tool_to_skill_map: Dict[str, str] = None + ) -> tuple[List, List[str]]: + """ + 根据消息内容和技能配置动态过滤工具 + + Args: + tools: 所有可用工具列表 + message: 用户消息(可用于智能过滤) + skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}} + tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id} + + Returns: + (过滤后的工具列表, 激活的技能ID列表) + """ + if not tools: + return [], [] + + # 如果没有技能配置,返回所有工具 + if not skill_configs: + return tools, [] + + # 基于关键词匹配激活技能 + activated_skill_ids = [] + message_lower = message.lower() + + for skill_id, config in skill_configs.items(): + if not config.get('enabled', True): + continue + + keywords = config.get('keywords', []) + # 如果没有关键词限制,或消息包含关键词,则激活该技能 + if not keywords or any(kw.lower() in message_lower for kw in keywords): + activated_skill_ids.append(skill_id) + + # 如果没有工具映射关系,返回所有工具 + if not tool_to_skill_map: + return tools, activated_skill_ids + + # 根据激活的技能过滤工具 + filtered_tools = [] + for tool in tools: + tool_name = getattr(tool, 'name', str(id(tool))) + # 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留 + if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids: + filtered_tools.append(tool) + + return filtered_tools, activated_skill_ids + + def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]: + """ + 加载技能关联的工具 + + Args: + db: 数据库会话 + tenant_id: 租户id + base_tools: 基础工具列表 + + Returns: + (工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id}) + """ + + tools_dict = {} + tool_to_skill_map = {} # 工具名称到技能ID的映射 + + if base_tools: + for tool in base_tools: + tool_name = getattr(tool, 'name', str(id(tool))) + tools_dict[tool_name] = tool + # base_tools 不属于任何 skill,不加入映射 + + skill_configs = {} + skill_ids_to_load = [] + + # 如果启用技能且 all_skills 为 True,加载租户下所有激活的技能 + if self.enabled and self.all_skills: + skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000) + skill_ids_to_load = [str(skill.id) for skill in skills] + elif self.enabled and self.skill_ids: + skill_ids_to_load = self.skill_ids + + if skill_ids_to_load: + for skill_id in skill_ids_to_load: + try: + skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id) + if skill and skill.is_active: + # 保存技能配置(包含prompt) + config = skill.config or {} + config['prompt'] = skill.prompt + config['name'] = skill.name + skill_configs[skill_id] = config + except Exception: + continue + + # 加载技能工具并获取映射关系 + skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id) + + # 只添加不冲突的 skill_tools + for tool in skill_tools: + tool_name = getattr(tool, 'name', str(id(tool))) + if tool_name not in tools_dict: + tools_dict[tool_name] = tool + # 复制映射关系 + if tool_name in skill_tool_map: + tool_to_skill_map[tool_name] = skill_tool_map[tool_name] + + return list(tools_dict.values()), skill_configs, tool_to_skill_map + + @staticmethod + def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str: + """ + 根据激活的技能ID获取对应的提示词 + + Args: + activated_skill_ids: 被激活的技能ID列表 + skill_configs: 技能配置字典 + + Returns: + 合并后的提示词 + """ + prompts = [] + for skill_id in activated_skill_ids: + config = skill_configs.get(skill_id, {}) + prompt = config.get('prompt') + name = config.get('name', 'Skill') + if prompt: + prompts.append(f"# {name}\n{prompt}") + + return "\n\n".join(prompts) if prompts else "" + + @staticmethod + def create_runnable(): + """创建可运行的中间件""" + return RunnablePassthrough() diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index b2dc8416..40cf068e 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -289,10 +289,9 @@ class LangChainAgent: return content_parts - return messages - async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): db = next(get_db()) + #TODO: 魔法数字 scope=6 try: @@ -302,6 +301,12 @@ class LangChainAgent: from app.core.memory.agent.utils.redis_tool import write_store result = write_store.get_session_by_userid(end_user_id) + + # Handle case where no session exists in Redis (returns False) + if not result or result is False: + logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update") + return + if type=="chunk" or type=="aggregate": data = await format_parsing(result, "dict") chunk_data = data[:scope] @@ -309,7 +314,14 @@ class LangChainAgent: repo.upsert(end_user_id, chunk_data) logger.info(f'写入短长期:') else: + # TODO: This branch handles type="time" strategy, currently unused. + # Will be activated when time-based long-term storage is implemented. + # TODO: 魔法数字 - extract 5 to a constant long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + # Handle case where no session exists in Redis (returns False or empty) + if not long_time_data or long_time_data is False: + logger.debug(f"No recent sessions in Redis for user {end_user_id}") + return long_messages = await messages_parse(long_time_data) repo.upsert(end_user_id, long_messages) logger.info(f'写入短长期:') @@ -509,13 +521,13 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: long_term_messages=await agent_chat_messages(message_chat,content) - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. + # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j + # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. + # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - '''长期''' - if actual_config_id: - await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") - else: - logger.warning(f"Skipping term_memory_save: no memory config available for end_user {end_user_id}") + # Batched long-term memory storage (Redis buffer + Neo4j when window full) + await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") response = { "content": content, "model": self.model_name, @@ -698,13 +710,14 @@ class LangChainAgent: yield total_tokens break if memory_flag: - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. + # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j + # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. + # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. long_term_messages = await agent_chat_messages(message_chat, full_content) await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) - if actual_config_id: - await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") - else: - logger.warning(f"Skipping term_memory_save: no memory config available for end_user {end_user_id}") + # Batched long-term memory storage (Redis buffer + Neo4j when window full) + await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) diff --git a/api/app/core/config.py b/api/app/core/config.py index 0de957c7..bf721af9 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -215,6 +215,9 @@ class Settings: # official environment system version SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1") + # model square loading + LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true" + # workflow config WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600)) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index d6fbbb38..e9de02b6 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config): for node_name, node_data in update_event.items(): if 'save_neo4j' == node_name: massages = node_data + # TODO:删除 massagesstatus = massages.get('write_result')['status'] contents = massages.get('write_result') print(contents) @@ -60,6 +61,7 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): scope:窗口大小 ''' scope=scope + redis_messages = [] is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: is_end_user_id = count_store.get_sessions_count(end_user_id)[0] @@ -91,6 +93,9 @@ async def memory_long_term_storage(end_user_id,memory_config,time): memory_config: 内存配置对象 ''' long_time_data = write_store.find_user_recent_sessions(end_user_id, time) + # Handle case where no session exists in Redis (returns False or empty) + if not long_time_data or long_time_data is False: + return format_messages = await chat_data_format(long_time_data) if format_messages!=[]: await write_messages(end_user_id, format_messages, memory_config) @@ -108,8 +113,9 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config try: # 1. 获取历史会话数据(使用新方法) result = write_store.get_all_sessions_by_end_user_id(end_user_id) - history = await format_parsing(result) - if not result: + + # Handle case where no session exists in Redis (returns False or empty) + if not result or result is False: history = [] else: history = await format_parsing(result) diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index d0e8a45d..9b858f47 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,24 +1,21 @@ - import asyncio -import json import sys import warnings from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph -from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse -from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node -from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) if sys.platform.startswith("win"): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + @asynccontextmanager async def make_write_graph(): """ @@ -39,29 +36,59 @@ async def make_write_graph(): graph = workflow.compile() yield graph -async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): - from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment - from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format - from app.core.memory.agent.utils.redis_tool import write_store - write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) - # 获取数据库会话 - db_session = next(get_db()) - config_service = MemoryConfigService(db_session) - memory_config = config_service.load_memory_config( - config_id=memory_config, # 改为整数 - service_name="MemoryAgentService" + + +async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '', + end_user_id: str = '', scope: int = 6): + """Dispatch long-term memory storage to Celery background tasks. + + Args: + long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate' + langchain_messages: List of messages to store + memory_config: Memory configuration ID (string) + end_user_id: End user identifier + scope: Window size for 'chunk' strategy (default: 6) + """ + from app.tasks import ( + long_term_storage_window_task, + # TODO: Uncomment when implemented + # long_term_storage_time_task, + # long_term_storage_aggregate_task, ) - if long_term_type=='chunk': - '''方案一:对话窗口6轮对话''' - await window_dialogue(end_user_id,langchain_messages,memory_config,scope) - if long_term_type=='time': - """时间""" - await memory_long_term_storage(end_user_id, memory_config,5) - if long_term_type=='aggregate': + from app.core.logging_config import get_logger - """方案三:聚合判断""" - await aggregate_judgment(end_user_id, langchain_messages, memory_config) + logger = get_logger(__name__) + # Convert config to string if needed + config_id = str(memory_config) if memory_config else '' + + if long_term_type == 'chunk': + # Strategy 1: Window-based batching (6 rounds of dialogue) + logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}") + long_term_storage_window_task.delay( + end_user_id=end_user_id, + langchain_messages=langchain_messages, + config_id=config_id, + scope=scope + ) + # TODO: Uncomment when time-based strategy is fully implemented + # elif long_term_type == 'time': + # # Strategy 2: Time-based retrieval + # logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}") + # long_term_storage_time_task.delay( + # end_user_id=end_user_id, + # config_id=config_id, + # time_window=5 + # ) + # TODO: Uncomment when aggregate strategy is fully implemented + # elif long_term_type == 'aggregate': + # # Strategy 3: Aggregate judgment (deduplication) + # logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}") + # long_term_storage_aggregate_task.delay( + # end_user_id=end_user_id, + # langchain_messages=langchain_messages, + # config_id=config_id + # ) # async def main(): # """主函数 - 运行工作流""" diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index e135d980..76a28156 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -174,4 +174,4 @@ async def write( f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") logger.info("=== Pipeline Complete ===") - logger.info(f"Total execution time: {total_time:.2f} seconds") + logger.info(f"Total execution time: {total_time:.2f} seconds") \ No newline at end of file diff --git a/api/app/core/models/scripts/bedrock_models.yaml b/api/app/core/models/scripts/bedrock_models.yaml index 453aaa13..e5b91d1c 100644 --- a/api/app/core/models/scripts/bedrock_models.yaml +++ b/api/app/core/models/scripts/bedrock_models.yaml @@ -1,5 +1,4 @@ provider: bedrock -enabled: false models: - name: ai21 type: llm diff --git a/api/app/core/models/scripts/dashscope_models.yaml b/api/app/core/models/scripts/dashscope_models.yaml index bcdb467e..df538e72 100644 --- a/api/app/core/models/scripts/dashscope_models.yaml +++ b/api/app/core/models/scripts/dashscope_models.yaml @@ -1,5 +1,4 @@ provider: dashscope -enabled: false models: - name: deepseek-r1-distill-qwen-14b type: llm diff --git a/api/app/core/models/scripts/loader.py b/api/app/core/models/scripts/loader.py index 6469656c..a14d3268 100644 --- a/api/app/core/models/scripts/loader.py +++ b/api/app/core/models/scripts/loader.py @@ -1,11 +1,11 @@ """模型配置加载器 - 用于将预定义模型批量导入到数据库""" -import os from pathlib import Path from typing import Callable import yaml from sqlalchemy.orm import Session + from app.models.models_model import ModelBase, ModelProvider @@ -19,31 +19,9 @@ def _load_yaml_config(provider: ModelProvider) -> list[dict]: with open(config_file, 'r', encoding='utf-8') as f: data = yaml.safe_load(f) - - # 检查是否需要加载(默认为 true) - if not data.get('enabled', True): - return [] - return data.get('models', []) -def _disable_yaml_config(provider: ModelProvider) -> None: - """将YAML文件的enabled标志设置为false""" - config_dir = Path(__file__).parent - config_file = config_dir / f"{provider.value}_models.yaml" - - if not config_file.exists(): - return - - with open(config_file, 'r', encoding='utf-8') as f: - data = yaml.safe_load(f) - - data['enabled'] = False - - with open(config_file, 'w', encoding='utf-8') as f: - yaml.dump(data, f, allow_unicode=True, sort_keys=False) - - def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict: """ 加载模型配置到数据库 @@ -75,8 +53,7 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) if not silent: print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...") - - # provider_success = 0 + for model_data in models: try: # 检查模型是否已存在 @@ -93,7 +70,6 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) if not silent: print(f"更新成功: {model_data['name']}") result["success"] += 1 - # provider_success += 1 else: # 创建新模型 model = ModelBase(**model_data) @@ -102,17 +78,12 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) if not silent: print(f"添加成功: {model_data['name']}") result["success"] += 1 - # provider_success += 1 except Exception as e: db.rollback() if not silent: print(f"添加失败: {model_data['name']} - {str(e)}") result["failed"] += 1 - - # 如果该供应商的模型全部加载成功,将enabled设置为false - # if provider_success == len(models): - _disable_yaml_config(provider) return result diff --git a/api/app/core/models/scripts/openai_models.yaml b/api/app/core/models/scripts/openai_models.yaml index 5a416264..68c63ee2 100644 --- a/api/app/core/models/scripts/openai_models.yaml +++ b/api/app/core/models/scripts/openai_models.yaml @@ -1,5 +1,4 @@ provider: openai -enabled: false models: - name: chatgpt-4o-latest type: llm diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index b7abf659..537058a0 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -11,19 +11,20 @@ from typing import Any from langchain_core.runnables import RunnableConfig from langgraph.graph.state import CompiledStateGraph -from app.core.workflow.expression_evaluator import evaluate_expression from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig from app.core.workflow.nodes import WorkflowState -from app.core.workflow.nodes.base_config import VariableType from app.core.workflow.nodes.enums import NodeType +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) class WorkflowExecutor: - """工作流执行器 + """Workflow Executor. - 负责将工作流配置转换为 LangGraph 并执行。 + Converts workflow configuration into a LangGraph and executes it, + supporting both synchronous and streaming execution modes. """ def __init__( @@ -31,15 +32,29 @@ class WorkflowExecutor: workflow_config: dict[str, Any], execution_id: str, workspace_id: str, - user_id: str + user_id: str, ): - """初始化执行器 + """Initialize Workflow Executor. + + Converts a workflow configuration into an executor instance that can + run the workflow in both streaming and non-streaming modes. Args: - workflow_config: 工作流配置 - execution_id: 执行 ID - workspace_id: 工作空间 ID - user_id: 用户 ID + workflow_config (dict): The workflow configuration dictionary. + execution_id (str): Unique identifier for this workflow execution. + workspace_id (str): Workspace or project ID. + user_id (str): User ID executing the workflow. + + Attributes: + self.nodes (list): List of node definitions from workflow_config. + self.edges (list): List of edge definitions from workflow_config. + self.execution_config (dict): Optional execution parameters from workflow_config. + self.start_node_id (str | None): ID of the Start node, set after graph build. + self.end_outputs (dict[str, StreamOutputConfig]): End node output configs. + self.activate_end (str | None): Currently active End node ID for streaming outputs. + self.variable_pool (VariablePool | None): Variable pool instance. + self.graph (CompiledStateGraph | None): Compiled workflow graph. + self.checkpoint_config (RunnableConfig): Config for LangGraph checkpointing. """ self.workflow_config = workflow_config self.execution_id = execution_id @@ -52,73 +67,108 @@ class WorkflowExecutor: self.start_node_id = None self.end_outputs: dict[str, StreamOutputConfig] = {} self.activate_end: str | None = None + self.variable_pool: VariablePool | None = None + self.graph: CompiledStateGraph | None = None self.checkpoint_config = RunnableConfig( configurable={ "thread_id": uuid.uuid4(), } ) - def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: - """准备初始状态(注入系统变量和会话变量) + async def __init_variable_pool(self, input_data: dict[str, Any]): + """Initialize the variable pool with system, conversation, and input variables. - 变量命名空间: - - sys.xxx - 系统变量(execution_id, workspace_id, user_id, message, input_variables 等) - - conv.xxx - 会话变量(跨多轮对话保持) - - node_id.xxx - 节点输出(执行时动态生成) + This method populates the VariablePool instance with: + - Conversation-level variables (`conv` namespace) from workflow config or provided values. + - System variables (`sys` namespace) such as message, files, conversation_id, execution_id, workspace_id, user_id, and input_variables. Args: - input_data: 输入数据 - - Returns: - 初始化的工作流状态 + input_data (dict): Input data for workflow execution, may contain: + - "message": user message (str) + - "file": list of user-uploaded files + - "conv": existing conversation variables (dict) + - "variables": custom variables for the Start node (dict) + - "conversation_id": conversation identifier """ user_message = input_data.get("message") or "" - conversation_messages = input_data.get("conv_messages") or [] + user_files = input_data.get("files") or [] - # 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value) config_variables_list = self.workflow_config.get("variables") or [] - conversation_vars = {} + conv_vars = input_data.get("conv", {}) + + # Initialize conversation variables (conv namespace) for var_def in config_variables_list: - if isinstance(var_def, dict): - var_name = var_def.get("name") - var_default = var_def.get("default") - if var_name: - if var_default: - conversation_vars[var_name] = var_default - else: - var_type = var_def.get("type") - match var_type: - case VariableType.STRING: - conversation_vars[var_name] = "" - case VariableType.NUMBER: - conversation_vars[var_name] = 0 - case VariableType.OBJECT: - conversation_vars[var_name] = {} - case VariableType.BOOLEAN: - conversation_vars[var_name] = False - case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING: - conversation_vars[var_name] = [] - input_variables = input_data.get("variables") or {} # Start 节点的自定义变量 - conversation_vars = conversation_vars | input_data.get("conv", {}) - # 构建分层的变量结构 - variables = { - "sys": { - "message": user_message, # 用户消息 - "conversation_id": input_data.get("conversation_id"), # 会话 ID - "execution_id": self.execution_id, # 执行 ID - "workspace_id": self.workspace_id, # 工作空间 ID - "user_id": self.user_id, # 用户 ID - "input_variables": input_variables, # 自定义输入变量(给 Start 节点使用) - }, - "conv": conversation_vars # 会话级变量(跨多轮对话保持) + var_name = var_def.get("name") + var_default = conv_vars.get(var_name, var_def.get("default")) + var_type = var_def.get("type") + if var_name: + if var_default: + var_value = var_default + else: + var_value = DEFAULT_VALUE(var_type) + await self.variable_pool.new( + namespace="conv", + key=var_name, + value=var_value, + var_type=var_type, + mut=True + ) + + # Initialize system variables (sys namespace) + input_variables = input_data.get("variables") or {} + sys_vars = { + "message": (user_message, VariableType.STRING), + "conversation_id": (input_data.get("conversation_id"), VariableType.STRING), + "execution_id": (self.execution_id, VariableType.STRING), + "workspace_id": (self.workspace_id, VariableType.STRING), + "user_id": (self.user_id, VariableType.STRING), + "input_variables": (input_variables, VariableType.OBJECT), + "files": (user_files, VariableType.ARRAY_FILE) } + for key, var_def in sys_vars.items(): + value = var_def[0] + var_type = var_def[1] + await self.variable_pool.new( + namespace='sys', + key=key, + value=value, + var_type=var_type, + mut=False + ) + + def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState: + """Generate the initial workflow state for execution. + + This method prepares the runtime state dictionary with system variables, + conversation variables, node outputs, loop tracking, and activation flags. + + Args: + input_data (dict): The input payload for workflow execution. + Expected keys: + - "conv_messages" (list, optional): Historical conversation messages + to include in the workflow state. + + Returns: + WorkflowState: A dictionary representing the initialized workflow state + with the following keys: + - "messages": List of conversation messages + - "node_outputs": Empty dict to store outputs of executed nodes + - "execution_id": Current workflow execution ID + - "workspace_id": Current workspace ID + - "user_id": ID of the user triggering execution + - "error": None initially, will store error message if a node fails + - "error_node": None initially, will store ID of node that caused error + - "cycle_nodes": List of node IDs that are of type LOOP or ITERATION + - "looping": Integer flag indicating loop execution state (0 = not looping) + - "activate": Dict mapping node IDs to activation status; initially + only the start node is active + """ + conversation_messages = input_data.get("conv_messages") or [] return { "messages": conversation_messages, - "variables": variables, "node_outputs": {}, - "runtime_vars": {}, # 运行时节点变量(简化版,供快速访问) "execution_id": self.execution_id, "workspace_id": self.workspace_id, "user_id": self.user_id, @@ -136,18 +186,47 @@ class WorkflowExecutor: } def _build_final_output(self, result, elapsed_time, final_output): + """Construct the final standardized output of the workflow execution. + + This method aggregates node outputs, token usage, conversation and system + variables, messages, and other metadata into a consistent dictionary + structure suitable for returning from workflow execution. + + Args: + result (dict): The runtime state returned by the workflow graph execution. + Expected keys include: + - "node_outputs" (dict): Outputs of executed nodes. + - "messages" (list): Conversation messages exchanged during execution. + - "error" (str, optional): Error message if any node failed. + elapsed_time (float): Total execution time in seconds. + final_output (Any): The aggregated or final output content of the workflow + (e.g., combined messages from all End nodes). + + Returns: + dict: A dictionary containing the final workflow execution result with keys: + - "status": Execution status ("completed") + - "output": Aggregated final output content + - "variables": Namespace dictionary with: + - "conv": Conversation variables + - "sys": System variables + - "node_outputs": Outputs from all executed nodes + - "messages": Conversation messages exchanged + - "conversation_id": ID of the current conversation + - "elapsed_time": Total execution time in seconds + - "token_usage": Aggregated token usage across nodes (if available) + - "error": Error message if any occurred during execution + """ node_outputs = result.get("node_outputs", {}) token_usage = self._aggregate_token_usage(node_outputs) - conversation_id = None - for node_id, node_output in node_outputs.items(): - if node_output.get("node_type") == "start": - conversation_id = node_output.get("output", {}).get("conversation_id") - break + conversation_id = self.variable_pool.get_value("sys.conversation_id") return { "status": "completed", "output": final_output, - "variables": result.get("variables", {}), + "variables": { + "conv": self.variable_pool.get_all_conversation_vars(), + "sys": self.variable_pool.get_all_system_vars() + }, "node_outputs": node_outputs, "messages": result.get("messages", []), "conversation_id": conversation_id, @@ -163,7 +242,7 @@ class WorkflowExecutor: Iterates over all End nodes in `self.end_outputs` and calls `update_activate` on each, which may: - Activate variable segments that depend on the completed node/scope. - - Activate the entire End node output if all control conditions are met. + - Activate the entire End node output if any control conditions are met. If any End node becomes active and `self.activate_end` is not yet set, this node will be marked as the currently active End node. @@ -197,18 +276,11 @@ class WorkflowExecutor: """ for node_id in data.keys(): if activate.get(node_id): - node_output_status = ( - data[node_id] - .get('runtime_vars', {}) - .get(node_id) - .get("output") - ) + node_output_status = self.variable_pool.get_value(f"{node_id}.output", default=None, strict=False) self._update_scope_activate(node_id, status=node_output_status) async def _emit_active_chunks( self, - node_outputs: dict, - variables: dict, force=False ): """ @@ -231,8 +303,6 @@ class WorkflowExecutor: and reset `activate_end` to None. Args: - node_outputs (dict): Current runtime node outputs, used for variable evaluation. - variables (dict): Current runtime variables, used for variable evaluation. force (bool, default=False): If True, process segments even if `activate=False`. Yields: @@ -260,14 +330,9 @@ class WorkflowExecutor: else: # Variable segment: evaluate and transform try: - chunk = evaluate_expression( - current_segment.literal, - variables=variables, - node_outputs=node_outputs - ) - chunk = self._trans_output_string(chunk) + chunk = self.variable_pool.get_literal(current_segment.literal) final_chunk += chunk - except ValueError: + except KeyError: # Log failed evaluation but continue streaming logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}") @@ -287,63 +352,339 @@ class WorkflowExecutor: self.end_outputs.pop(self.activate_end) self.activate_end = None - @staticmethod - def _trans_output_string(content): - if isinstance(content, str): - return content - elif isinstance(content, list): - return "\n".join(content) - else: - return str(content) + async def _handle_updates_event(self, data): + """ + Handle workflow state update events ("updates") and stream active End node outputs. + + Steps: + 1. Retrieve the current graph state. + 2. Extract node activation information from the state. + 3. Update the activation status of all End nodes. + 4. While there is an active End node: + - Call _emit_active_chunks() to yield all currently active output segments. + - After all segments are processed, update activate_end if there are remaining End nodes. + 5. Log a debug message indicating state update received. + + Args: + data (dict): The latest node state updates. + + Yields: + dict: Streamed output event, each chunk in the format: + {"event": "message", "data": {"chunk": ...}} + """ + # Get the latest workflow state + state = self.graph.get_state(config=self.checkpoint_config).values + activate = state.get("activate", {}) + + # Update End node activation based on the new state + self._update_stream_output_status(activate, data) + wait = False + while self.activate_end and not wait: + async for msg_event in self._emit_active_chunks(): + yield msg_event + + if self.activate_end: + wait = True + else: + self._update_stream_output_status(activate, data) + + logger.debug(f"[UPDATES] Received state update from nodes: {list(data.keys())} " + f"- execution_id: {self.execution_id}") + + async def _handle_node_chunk_event(self, data): + """ + Handle streaming chunk events from individual nodes ("node_chunk"). + + This method processes output segments for the currently active End node. + If the segment depends on the provided node_id: + - If the node has finished execution (`done=True`), advance the cursor. + - If all segments are processed, deactivate the End node. + - Otherwise, yield the current chunk as a streaming message. + + Args: + data (dict): Node chunk event data, expected keys: + - "node_id": ID of the node producing this chunk + - "chunk": Chunk of output text + - "done": Boolean indicating whether the node finished producing output + + Yields: + dict: Streaming message event in the format: + {"event": "message", "data": {"chunk": ...}} + """ + node_id = data.get("node_id") + if self.activate_end: + end_info = self.end_outputs.get(self.activate_end) + if not end_info or end_info.cursor >= len(end_info.outputs): + return + current_output = end_info.outputs[end_info.cursor] + if current_output.is_variable and current_output.depends_on_scope(node_id): + if data.get("done"): + end_info.cursor += 1 + if end_info.cursor >= len(end_info.outputs): + self.end_outputs.pop(self.activate_end) + self.activate_end = None + else: + yield { + "event": "message", + "data": { + "chunk": data.get("chunk") + } + } + + async def _handle_node_error_event(self, data): + """ + Handle node error events ("node_error") during workflow execution. + + This method streams an error event for a node that has failed. The event + contains the node ID, status, input data, elapsed time, and error message. + + Args: + data (dict): Node error event data, expected keys: + - "node_id": ID of the node that failed + - "input_data": The input data that caused the error + - "elapsed_time": Execution time before the error occurred + - "error": Error message or exception string + + Yields: + dict: Node error event in the format: + { + "event": "node_error", + "data": { + "node_id": str, + "status": "failed", + "input": ..., + "elapsed_time": float, + "output": None, + "error": str + } + } + """ + node_id = data.get("node_id") + yield { + "event": "node_error", + "data": { + "node_id": node_id, + "status": "failed", + "input": data.get("input_data"), + "elapsed_time": data.get("elapsed_time"), + "output": None, + "error": data.get("error") + } + } + + async def _handle_debug_event(self, data, input_data): + """ + Handle debug events ("debug") related to node execution status. + + This method streams debug events for nodes, including when a node starts + execution ("node_start") and when it completes execution ("node_end"). + It filters out nodes with names starting with "nop" as no-operation nodes. + + Args: + data (dict): Debug event data, expected keys: + - "type": Event type ("task" for start, "task_result" for completion) + - "payload": Node-related information, including: + - "name": Node name / ID + - "input": Node input data (for "task" type) + - "result": Node execution result (for "task_result" type) + - "timestamp": ISO timestamp string of the event + input_data (dict): Original workflow input data (used to get conversation_id) + + Yields: + dict: Node debug event in one of the following formats: + 1. Node start: + { + "event": "node_start", + "data": { + "node_id": str, + "conversation_id": str, + "execution_id": str, + "timestamp": int (ms) + } + } + 2. Node end: + { + "event": "node_end", + "data": { + "node_id": str, + "conversation_id": str, + "execution_id": str, + "timestamp": int (ms), + "input": dict, + "output": Any, + "elapsed_time": float + } + } + """ + event_type = data.get("type") + payload = data.get("payload", {}) + node_name = payload.get("name") + + # Skip no-operation nodes + if node_name and node_name.startswith("nop"): + return + + if event_type == "task": + # Node starts execution + inputv = payload.get("input", {}) + if not inputv.get("activate", {}).get(node_name): + return + conversation_id = input_data.get("conversation_id") + logger.info(f"[NODE-START] Node '{node_name}' execution started - execution_id: {self.execution_id}") + + yield { + "event": "node_start", + "data": { + "node_id": node_name, + "conversation_id": conversation_id, + "execution_id": self.execution_id, + "timestamp": int(datetime.datetime.fromisoformat( + data.get("timestamp") + ).timestamp() * 1000), + } + } + elif event_type == "task_result": + # Node execution completed + result = payload.get("result", {}) + if not result.get("activate", {}).get(node_name): + return + + conversation_id = input_data.get("conversation_id") + logger.info(f"[NODE-END] Node '{node_name}' execution completed - execution_id: {self.execution_id}") + + yield { + "event": "node_end", + "data": { + "node_id": node_name, + "conversation_id": conversation_id, + "execution_id": self.execution_id, + "timestamp": int(datetime.datetime.fromisoformat( + data.get("timestamp") + ).timestamp() * 1000), + "input": result.get("node_outputs", {}).get(node_name, {}).get("input"), + "output": result.get("node_outputs", {}).get(node_name, {}).get("output"), + "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"), + "token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage") + } + } + + async def _flush_remaining_chunk(self): + """ + Flush and yield all remaining output segments from active End nodes. + + This method ensures that any remaining chunks of output, which may not have + been emitted during normal streaming due to activation conditions, are fully + processed. It is typically called at the end of a workflow to guarantee + that all output is delivered. + + Behavior: + 1. Filter `end_outputs` to only keep End nodes that are still active. + 2. While there is an active End node (`self.activate_end`): + - Call `_emit_active_chunks(force=True)` to emit all segments regardless + of their activation state. + - If the current End node finishes, move to the next active End node + if any remain. + + Yields: + dict: Streamed output events in the format: + {"event": "message", "data": {"chunk": ...}} + """ + # Keep only active End nodes + self.end_outputs = { + node_id: node_info + for node_id, node_info in self.end_outputs.items() + if node_info.activate + } + + if self.end_outputs or self.activate_end: + while self.activate_end: + # Force emit all remaining chunks of the active End node + async for msg_event in self._emit_active_chunks(force=True): + yield msg_event + + # Move to next active End node if current one is done + if not self.activate_end and self.end_outputs: + self.activate_end = list(self.end_outputs.keys())[0] def build_graph(self, stream=False) -> CompiledStateGraph: - """构建 LangGraph + """ + Build the workflow graph using LangGraph. + + This method initializes a GraphBuilder with the workflow configuration, + builds the compiled state graph, and sets up the executor's key attributes: + - `start_node_id`: the ID of the start node in the workflow + - `end_outputs`: mapping of End nodes and their output configurations + - `variable_pool`: pool containing workflow variables + - `graph`: the compiled state graph ready for execution + + Args: + stream (bool, optional): Whether to enable streaming mode. Defaults to False. Returns: - 编译后的状态图 + CompiledStateGraph: The compiled and ready-to-run state graph. """ - logger.info(f"开始构建工作流图: execution_id={self.execution_id}") + logger.info(f"Starting workflow graph build: execution_id={self.execution_id}") builder = GraphBuilder( self.workflow_config, stream=stream, ) self.start_node_id = builder.start_node_id self.end_outputs = builder.end_node_map - graph = builder.build() - logger.info(f"工作流图构建完成: execution_id={self.execution_id}") + self.variable_pool = builder.variable_pool + self.graph = builder.build() + logger.info(f"Workflow graph build completed: execution_id={self.execution_id}") - return graph + return self.graph async def execute( self, input_data: dict[str, Any] ) -> dict[str, Any]: - """执行工作流(非流式) + """ + Execute the workflow in non-streaming (batch) mode. + + Steps: + 1. Build the workflow graph. + 2. Initialize the variable pool and inject system variables. + 3. Prepare the initial workflow state. + 4. Invoke the compiled graph and collect outputs. + 5. Aggregate outputs, messages, and token usage. Args: - input_data: 输入数据,包含 message 和 variables + input_data (dict): Input data including 'message' and 'variables'. Returns: - 执行结果,包含 status, output, node_outputs, elapsed_time, token_usage + dict: Execution result containing: + - status: "completed" or "failed" + - output: aggregated output string from all End nodes + - variables: current conversation and system variables + - node_outputs: all node outputs + - messages: list of messages including user and assistant content + - elapsed_time: workflow execution time in seconds + - token_usage: aggregated token usage if available + - error: error message if any """ - logger.info(f"开始执行工作流: execution_id={self.execution_id}") + logger.info(f"Starting workflow execution: execution_id={self.execution_id}") - # 记录开始时间 start_time = datetime.datetime.now() - # 1. 构建图 + # Build the workflow graph graph = self.build_graph() - # 2. 初始化状态(自动注入系统变量) + # Initialize the variable pool with input data + await self.__init_variable_pool(input_data) initial_state = self._prepare_initial_state(input_data) - # 3. 执行工作流 + # Execute the workflow try: - result = await graph.ainvoke(initial_state, config=self.checkpoint_config) + + # Aggregate output from all End nodes full_content = '' for end_id in self.end_outputs.keys(): - full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '') + full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) + + # Append messages for user and assistant result["messages"].extend( [ { @@ -356,20 +697,19 @@ class WorkflowExecutor: } ] ) - # 计算耗时 + # Calculate elapsed time end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s") + logger.info(f"Workflow execution completed: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s") return self._build_final_output(result, elapsed_time, full_content) except Exception as e: - # 计算耗时(即使失败也记录) end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True) + logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True) return { "status": "failed", "error": str(e), @@ -383,48 +723,52 @@ class WorkflowExecutor: self, input_data: dict[str, Any] ): - """执行工作流(流式) + """ + Execute the workflow in streaming mode. - 使用多个 stream_mode 来获取: - 1. "updates" - 节点的 state 更新和流式 chunk - 2. "debug" - 节点执行的详细信息(开始/完成时间) - 3. "custom" - 自定义流式数据(chunks) + Supports multiple streaming modes: + 1. "updates" - Node state updates and streaming chunks. + 2. "debug" - Detailed node execution info (start/end). + 3. "custom" - Custom streaming chunks from nodes. Args: - input_data: 输入数据 + input_data (dict): Input data including 'message', 'variables', etc. Yields: - 流式事件,格式: - { - "event": "workflow_start" | "workflow_end" | "node_start" | "node_end" | "node_chunk" | "message", - "data": {...} - } + dict: Streaming events in the format: + { + "event": "workflow_start" | "workflow_end" | "node_start" | + "node_end" | "node_chunk" | "message", + "data": {...} + } """ - logger.info(f"开始执行工作流(流式): execution_id={self.execution_id}") + logger.info(f"Starting workflow execution (streaming): execution_id={self.execution_id}") - # 记录开始时间 start_time = datetime.datetime.now() - # 发送 workflow_start 事件 yield { "event": "workflow_start", "data": { "execution_id": self.execution_id, "workspace_id": self.workspace_id, + "conversation_id": input_data.get("conversation_id"), "timestamp": int(start_time.timestamp() * 1000) } } - # 1. 构建图 + # Build the workflow graph in streaming mode graph = self.build_graph(stream=True) - # 2. 初始化状态(自动注入系统变量) + # Initialize the variable pool and system variables + await self.__init_variable_pool(input_data) initial_state = self._prepare_initial_state(input_data) - # 3. Execute workflow + + try: - chunk_count = 0 full_content = '' self._update_scope_activate("sys") + + # Execute the workflow with streaming async for event in graph.astream( initial_state, stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode @@ -442,153 +786,37 @@ class WorkflowExecutor: if mode == "custom": # Handle custom streaming events (chunks from nodes via stream writer) - chunk_count += 1 event_type = data.get("type", "node_chunk") # "message" or "node_chunk" if event_type == "node_chunk": - node_id = data.get("node_id") - if self.activate_end: - end_info = self.end_outputs.get(self.activate_end) - if not end_info or end_info.cursor >= len(end_info.outputs): - continue - current_output = end_info.outputs[end_info.cursor] - if current_output.is_variable and current_output.depends_on_scope(node_id): - if data.get("done"): - end_info.cursor += 1 - if end_info.cursor >= len(end_info.outputs): - self.end_outputs.pop(self.activate_end) - self.activate_end = None - else: - full_content += data.get("chunk") - yield { - "event": "message", - "data": { - "chunk": data.get("chunk") - } - } - logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}" - f"- execution_id: {self.execution_id}") - - elif event_type == "node_error": - yield { - "event": event_type, # "message" or "node_chunk" - "data": { - "node_id": data.get("node_id"), - "status": "failed", - "input": data.get("input_data"), - "elapsed_time": data.get("elapsed_time"), - "output": None, - "error": data.get("error") - } - } - - elif mode == "debug": - # Handle debug information (node execution status) - event_type = data.get("type") - payload = data.get("payload", {}) - node_name = payload.get("name") - - if node_name and node_name.startswith("nop"): - continue - - if event_type == "task": - # Node starts execution - inputv = payload.get("input", {}) - if not inputv.get("activate", {}).get(node_name): - continue - conversation_id = input_data.get("conversation_id") - logger.info(f"[NODE-START] Node starts execution: {node_name} " - f"- execution_id: {self.execution_id}") - yield { - "event": "node_start", - "data": { - "node_id": node_name, - "conversation_id": conversation_id, - "execution_id": self.execution_id, - "timestamp": int(datetime.datetime.fromisoformat( - data.get("timestamp") - ).timestamp() * 1000), - } - } - elif event_type == "task_result": - # Node execution completed - result = payload.get("result", {}) - if not result.get("activate", {}).get(node_name): - continue - - conversation_id = input_data.get("conversation_id") - logger.info(f"[NODE-END] Node execution completed: {node_name} " - f"- execution_id: {self.execution_id}") - - yield { - "event": "node_end", - "data": { - "node_id": node_name, - "conversation_id": conversation_id, - "execution_id": self.execution_id, - "timestamp": int(datetime.datetime.fromisoformat( - data.get("timestamp") - ).timestamp() * 1000), - "input": result.get("node_outputs", {}).get(node_name, {}).get("input"), - "output": result.get("node_outputs", {}).get(node_name, {}).get("output"), - "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"), - } - } - - elif mode == "updates": - # Handle state updates - store final state - state = graph.get_state(config=self.checkpoint_config).values - node_outputs = state.get("runtime_vars", {}) - variables = state.get("variables", {}) - activate = state.get("activate", {}) - for _, node_data in data.items(): - node_outputs |= node_data.get("runtime_vars", {}) - variables |= node_data.get("variables", {}) - - self._update_stream_output_status(activate, data) - wait = False - while self.activate_end and not wait: - async for msg_event in self._emit_active_chunks( - node_outputs=node_outputs, - variables=variables - ): - full_content += msg_event["data"]['chunk'] + async for msg_event in self._handle_node_chunk_event(data): + full_content += data.get("chunk") yield msg_event - if self.activate_end: - wait = True - else: - self._update_stream_output_status(activate, data) + elif event_type == "node_error": + async for error_event in self._handle_node_error_event(data): + yield error_event + elif mode == "debug": + async for debug_event in self._handle_debug_event(data, input_data): + yield debug_event + + elif mode == "updates": logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} " f"- execution_id: {self.execution_id}") - - result = graph.get_state(self.checkpoint_config).values - node_outputs = result.get("runtime_vars", {}) - variables = result.get("variables", {}) - self.end_outputs = { - node_id: node_info - for node_id, node_info in self.end_outputs.items() - if node_info.activate - } - - if self.end_outputs or self.activate_end: - while self.activate_end: - async for msg_event in self._emit_active_chunks( - node_outputs=node_outputs, - variables=variables, - force=True - ): + async for msg_event in self._handle_updates_event(data): full_content += msg_event["data"]['chunk'] yield msg_event - if not self.activate_end and self.end_outputs: - self.activate_end = list(self.end_outputs.keys())[0] + # Flush any remaining chunks + async for msg_event in self._flush_remaining_chunk(): + full_content += msg_event["data"]['chunk'] + yield msg_event - # 计算耗时 + result = graph.get_state(self.checkpoint_config).values end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - result = graph.get_state(self.checkpoint_config).values - logger.info(result) + + # Append messages for user and assistant result["messages"].extend( [ { @@ -603,23 +831,20 @@ class WorkflowExecutor: ) logger.info( f"Workflow execution completed (streaming), " - f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}" + f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}" ) - # 发送 workflow_end 事件 yield { "event": "workflow_end", "data": self._build_final_output(result, elapsed_time, full_content) } except Exception as e: - # 计算耗时(即使失败也记录) end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() - logger.error(f"工作流执行失败: execution_id={self.execution_id}, error={e}", exc_info=True) + logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True) - # 发送 workflow_end 事件(失败) yield { "event": "workflow_end", "data": { @@ -633,14 +858,20 @@ class WorkflowExecutor: @staticmethod def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None: - """聚合所有节点的 token 使用情况 + """ + Aggregate token usage statistics across all nodes. Args: - node_outputs: 所有节点的输出 + node_outputs (dict): A dictionary of all node outputs. Returns: - 聚合的 token 使用情况 {"prompt_tokens": x, "completion_tokens": y, "total_tokens": z} - 如果没有 token 使用信息,返回 None + dict | None: Aggregated token usage in the format: + { + "prompt_tokens": int, + "completion_tokens": int, + "total_tokens": int + } + Returns None if no token usage information is available. """ total_prompt_tokens = 0 total_completion_tokens = 0 @@ -673,17 +904,18 @@ async def execute_workflow( workspace_id: str, user_id: str ) -> dict[str, Any]: - """执行工作流(便捷函数) + """ + Execute a workflow (convenience function, non-streaming). Args: - workflow_config: 工作流配置 - input_data: 输入数据 - execution_id: 执行 ID - workspace_id: 工作空间 ID - user_id: 用户 ID + workflow_config (dict): The workflow configuration. + input_data (dict): Input data for the workflow. + execution_id (str): Execution ID. + workspace_id (str): Workspace ID. + user_id (str): User ID. Returns: - 执行结果 + dict: Workflow execution result. """ executor = WorkflowExecutor( workflow_config=workflow_config, @@ -701,17 +933,18 @@ async def execute_workflow_stream( workspace_id: str, user_id: str ): - """执行工作流(流式,便捷函数) + """ + Execute a workflow in streaming mode (convenience function). Args: - workflow_config: 工作流配置 - input_data: 输入数据 - execution_id: 执行 ID - workspace_id: 工作空间 ID - user_id: 用户 ID + workflow_config (dict): The workflow configuration. + input_data (dict): Input data for the workflow. + execution_id (str): Execution ID. + workspace_id (str): Workspace ID. + user_id (str): User ID. Yields: - 流式事件 + dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end. """ executor = WorkflowExecutor( workflow_config=workflow_config, diff --git a/api/app/core/workflow/expression_evaluator.py b/api/app/core/workflow/expression_evaluator.py index 1a8b101e..26f0c41c 100644 --- a/api/app/core/workflow/expression_evaluator.py +++ b/api/app/core/workflow/expression_evaluator.py @@ -1,9 +1,3 @@ -""" -安全的表达式求值器 - -使用 simpleeval 库提供安全的表达式评估,避免代码注入攻击。 -""" - import logging import re from typing import Any @@ -14,160 +8,119 @@ logger = logging.getLogger(__name__) class ExpressionEvaluator: - """安全的表达式求值器""" + """Safe expression evaluator for workflow variables and node outputs.""" - # 保留的命名空间 + # Reserved namespaces RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} @staticmethod def evaluate( expression: str, - variables: dict[str, Any], + conv_vars: dict[str, Any], node_outputs: dict[str, Any], system_vars: dict[str, Any] | None = None ) -> Any: - """安全地评估表达式 - - Args: - expression: 表达式字符串,如 "{{var.score}} > 0.8" - variables: 用户定义的变量 - node_outputs: 节点输出结果 - system_vars: 系统变量 - - Returns: - 表达式求值结果 - - Raises: - ValueError: 表达式无效或求值失败 - - Examples: - >>> evaluator = ExpressionEvaluator() - >>> evaluator.evaluate( - ... "var.score > 0.8", - ... {"score": 0.9}, - ... {}, - ... {} - ... ) - True - - >>> evaluator.evaluate( - ... "node.intent.output == '售前咨询'", - ... {}, - ... {"intent": {"output": "售前咨询"}}, - ... {} - ... ) - True """ - # 移除 Jinja2 模板语法的花括号(如果存在) + Safely evaluate an expression using workflow variables. + + Args: + expression (str): The expression string, e.g., "var.score > 0.8" + conv_vars (dict): Conversation-level variables + node_outputs (dict): Outputs from workflow nodes + system_vars (dict, optional): System variables + + Returns: + Any: Result of the evaluated expression + + Raises: + ValueError: If the expression is invalid or evaluation fails + """ + # Remove Jinja2-style brackets if present expression = expression.strip() - # "{{system.message}} == {{ user.messge }}" -> "system.message == user.message" pattern = r"\{\{\s*(.*?)\s*\}\}" expression = re.sub(pattern, r"\1", expression).strip() - # 构建命名空间上下文 + # Build context for evaluation context = { - "var": variables, # 用户变量 - "node": node_outputs, # 节点输出 - "sys": system_vars or {}, # 系统变量 + "conv": conv_vars, # conversation variables + "node": node_outputs, # node outputs + "sys": system_vars or {}, # system variables } - - # 为了向后兼容,也支持直接访问(但会在日志中警告) - context.update(variables) + + context.update(conv_vars) context["nodes"] = node_outputs context.update(node_outputs) try: - # simpleeval 只支持安全的操作: - # - 算术运算: +, -, *, /, //, %, ** - # - 比较运算: ==, !=, <, <=, >, >= - # - 逻辑运算: and, or, not - # - 成员运算: in, not in - # - 属性访问: obj.attr - # - 字典/列表访问: obj["key"], obj[0] - # 不支持:函数调用、导入、赋值等危险操作 + # simpleeval supports safe operations: + # arithmetic, comparisons, logical ops, attribute/dict/list access result = simple_eval(expression, names=context) return result except NameNotDefined as e: - logger.error(f"表达式中引用了未定义的变量: {expression}, 错误: {e}") - raise ValueError(f"未定义的变量: {e}") + logger.error(f"Undefined variable in expression: {expression}, error: {e}") + raise ValueError(f"Undefined variable: {e}") except InvalidExpression as e: - logger.error(f"表达式语法无效: {expression}, 错误: {e}") - raise ValueError(f"表达式语法无效: {e}") + logger.error(f"Invalid expression syntax: {expression}, error: {e}") + raise ValueError(f"Invalid expression syntax: {e}") except SyntaxError as e: - logger.error(f"表达式语法错误: {expression}, 错误: {e}") - raise ValueError(f"表达式语法错误: {e}") + logger.error(f"Syntax error in expression: {expression}, error: {e}") + raise ValueError(f"Syntax error: {e}") except Exception as e: - logger.error(f"表达式求值异常: {expression}, 错误: {e}") - raise ValueError(f"表达式求值失败: {e}") + logger.error(f"Expression evaluation failed: {expression}, error: {e}") + raise ValueError(f"Expression evaluation failed: {e}") @staticmethod def evaluate_bool( expression: str, - variables: dict[str, Any], + conv_var: dict[str, Any], node_outputs: dict[str, Any], system_vars: dict[str, Any] | None = None ) -> bool: - """评估布尔表达式(用于条件判断) - + """ + Evaluate a boolean expression (for conditions). + Args: - expression: 布尔表达式 - variables: 用户变量 - node_outputs: 节点输出 - system_vars: 系统变量 - + expression (str): Boolean expression + conv_var (dict): Conversation variables + node_outputs (dict): Node outputs + system_vars (dict, optional): System variables + Returns: - 布尔值结果 - - Examples: - >>> ExpressionEvaluator.evaluate_bool( - ... "var.count >= 10 and var.status == 'active'", - ... {"count": 15, "status": "active"}, - ... {}, - ... {} - ... ) - True + bool: Boolean result """ result = ExpressionEvaluator.evaluate( - expression, variables, node_outputs, system_vars + expression, conv_var, node_outputs, system_vars ) return bool(result) @staticmethod def validate_variable_names(variables: list[dict]) -> list[str]: - """验证变量名是否合法 - + """ + Validate variable names for legality. + Args: - variables: 变量定义列表 - + variables (list[dict]): List of variable definitions + Returns: - 错误列表,如果为空则验证通过 - - Examples: - >>> ExpressionEvaluator.validate_variable_names([ - ... {"name": "user_input"}, - ... {"name": "var"} # 保留字 - ... ]) - ["变量名 'var' 是保留的命名空间,请使用其他名称"] + list[str]: List of error messages. Empty if all names are valid. """ errors = [] for var in variables: var_name = var.get("name", "") - - # 检查是否为保留命名空间 + if var_name in ExpressionEvaluator.RESERVED_NAMESPACES: errors.append( - f"变量名 '{var_name}' 是保留的命名空间,请使用其他名称" + f"Variable name '{var_name}' is a reserved namespace, please use another name" ) - - # 检查是否为有效的 Python 标识符 + if not var_name.isidentifier(): errors.append( - f"变量名 '{var_name}' 不是有效的标识符" + f"Variable name '{var_name}' is not a valid Python identifier" ) return errors @@ -176,23 +129,23 @@ class ExpressionEvaluator: # 便捷函数 def evaluate_expression( expression: str, - variables: dict[str, Any], + conv_var: dict[str, Any], node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + system_vars: dict[str, Any] ) -> Any: - """评估表达式(便捷函数)""" + """Evaluate an expression (convenience function).""" return ExpressionEvaluator.evaluate( - expression, variables, node_outputs, system_vars + expression, conv_var, node_outputs, system_vars ) def evaluate_condition( expression: str, - variables: dict[str, Any], + conv_var: dict[str, Any], node_outputs: dict[str, Any], system_vars: dict[str, Any] | None = None ) -> bool: - """评估条件表达式(便捷函数)""" + """Evaluate a boolean condition expression (convenience function).""" return ExpressionEvaluator.evaluate_bool( - expression, variables, node_outputs, system_vars + expression, conv_var, node_outputs, system_vars ) diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/graph_builder.py index b1d43e08..46a594d7 100644 --- a/api/app/core/workflow/graph_builder.py +++ b/api/app/core/workflow/graph_builder.py @@ -14,9 +14,14 @@ from pydantic import BaseModel, Field from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) +SCOPE_PATTERN = re.compile( + r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}" +) + class OutputContent(BaseModel): """ @@ -53,6 +58,12 @@ class OutputContent(BaseModel): ) ) + _SCOPE: str | None = None + + def get_scope(self) -> str: + self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0] + return self._SCOPE + def depends_on_scope(self, scope: str) -> bool: """ Check if this segment depends on a given scope. @@ -63,8 +74,9 @@ class OutputContent(BaseModel): Returns: bool: True if this segment references the given scope. """ - pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}" - return bool(re.search(pattern, self.literal)) + if self._SCOPE: + return self._SCOPE == scope + return self.get_scope() == scope class StreamOutputConfig(BaseModel): @@ -167,6 +179,7 @@ class GraphBuilder: workflow_config: dict[str, Any], stream: bool = False, subgraph: bool = False, + variable_pool: VariablePool | None = None ): self.workflow_config = workflow_config @@ -180,6 +193,10 @@ class GraphBuilder: self._find_upstream_branch_node = lru_cache( maxsize=len(self.nodes) * 2 )(self._find_upstream_branch_node) + if variable_pool: + self.variable_pool = variable_pool + else: + self.variable_pool = VariablePool() self.graph = StateGraph(WorkflowState) self.add_nodes() @@ -452,9 +469,9 @@ class GraphBuilder: if self.stream: # Stream mode: create an async generator function # LangGraph collects all yielded values; the last yielded dictionary is merged into the state - def make_stream_func(inst): + def make_stream_func(inst, variable_pool=self.variable_pool): async def node_func(state: WorkflowState): - async for item in inst.run_stream(state): + async for item in inst.run_stream(state, variable_pool): yield item return node_func @@ -462,9 +479,9 @@ class GraphBuilder: self.graph.add_node(node_id, make_stream_func(node_instance)) else: # Non-stream mode: create an async function - def make_func(inst): + def make_func(inst, variable_pool=self.variable_pool): async def node_func(state: WorkflowState): - return await inst.run(state) + return await inst.run(state, variable_pool) return node_func @@ -567,27 +584,28 @@ class GraphBuilder: for target in branch_info["target"]: waiting_edges[target].append(branch_info["node"]["name"]) - def router_fn(state: WorkflowState) -> list[Send]: + def router_fn(state: WorkflowState, variable_pool: VariablePool = self.variable_pool) -> list[Send]: branch_activate = [] new_state = state.copy() new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate - + node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False) for label, branch in unique_branch.items(): - if evaluate_condition( + if node_output and evaluate_condition( branch["condition"], - state.get("variables", {}), - state.get("runtime_vars", {}), - { - "execution_id": state.get("execution_id"), - "workspace_id": state.get("workspace_id"), - "user_id": state.get("user_id") - } + {}, + {src: node_output}, + {} ): logger.debug(f"Conditional routing {src}: selected branch {label}") new_state["activate"][branch["node"]["name"]] = True + branch_activate.append( + Send( + branch['node']['name'], + new_state + ) + ) continue new_state["activate"][branch["node"]["name"]] = False - for label, branch in unique_branch.items(): branch_activate.append( Send( branch['node']['name'], diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py index 926f86e4..1f2eb15b 100644 --- a/api/app/core/workflow/nodes/__init__.py +++ b/api/app/core/workflow/nodes/__init__.py @@ -15,7 +15,6 @@ from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode from app.core.workflow.nodes.start import StartNode -from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.tool import ToolNode @@ -25,7 +24,6 @@ __all__ = [ "WorkflowState", "LLMNode", "AgentNode", - "TransformNode", "IfElseNode", "StartNode", "EndNode", diff --git a/api/app/core/workflow/nodes/agent/config.py b/api/app/core/workflow/nodes/agent/config.py index 413ce606..4d428a4b 100644 --- a/api/app/core/workflow/nodes/agent/config.py +++ b/api/app/core/workflow/nodes/agent/config.py @@ -2,7 +2,8 @@ from pydantic import Field -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition +from app.core.workflow.variable.base_variable import VariableType class AgentNodeConfig(BaseNodeConfig): diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py index e4525d88..0818749c 100644 --- a/api/app/core/workflow/nodes/agent/node.py +++ b/api/app/core/workflow/nodes/agent/node.py @@ -2,6 +2,7 @@ Agent 节点实现 调用已发布的 Agent 应用。 +# TODO """ import logging @@ -9,6 +10,8 @@ from typing import Any from langchain_core.messages import AIMessage from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.services.draft_run_service import DraftRunService from app.models import AppRelease from app.db import get_db @@ -30,19 +33,22 @@ class AgentNode(BaseNode): } } """ - - def _prepare_agent(self, state: WorkflowState) -> tuple[DraftRunService, AppRelease, str]: + + def _output_types(self) -> dict[str, VariableType]: + return {"output": VariableType.STRING} + + def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]: """准备 Agent(公共逻辑) Args: - state: 工作流状态 + variable_pool: 变量池 Returns: (draft_service, release, message): 服务实例、发布配置、消息 """ # 1. 渲染消息 message_template = self.config.get("message", "") - message = self._render_template(message_template, state) + message = self._render_template(message_template, variable_pool) # 2. 获取 Agent 配置 agent_id = self.config.get("agent_id") @@ -61,16 +67,17 @@ class AgentNode(BaseNode): return draft_service, release, message - async def execute(self, state: WorkflowState) -> dict[str, Any]: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """非流式执行 Args: state: 工作流状态 + variable_pool: 变量池 Returns: 状态更新字典 """ - draft_service, release, message = self._prepare_agent(state) + draft_service, release, message = self._prepare_agent(variable_pool) logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)") @@ -79,9 +86,9 @@ class AgentNode(BaseNode): agent_config=release.config, model_config=None, message=message, - workspace_id=state.get("workspace_id"), + workspace_id=variable_pool.get_value("sys.workspace_id"), user_id=state.get("user_id"), - variables=state.get("variables", {}) + variables=variable_pool.get_all_conversation_vars() ) response = result.get("response", "") @@ -99,16 +106,17 @@ class AgentNode(BaseNode): } } - async def execute_stream(self, state: WorkflowState): + async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): """流式执行 Args: state: 工作流状态 + variable_pool: 变量池 Yields: 流式事件字典 """ - draft_service, release, message = self._prepare_agent(state) + draft_service, release, message = self._prepare_agent(variable_pool) logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)") @@ -120,9 +128,9 @@ class AgentNode(BaseNode): agent_config=release.config, model_config=None, message=message, - workspace_id=state.get("workspace_id"), + workspace_id=variable_pool.get_value("sys.workspace_id"), user_id=state.get("user_id"), - variables=state.get("variables", {}) + variables=variable_pool.get_all_conversation_vars() ): # 提取内容 content = chunk.get("content", "") diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index 6f2583b4..e1bb6e9d 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -6,6 +6,7 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import AssignmentOperator from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -17,13 +18,17 @@ class AssignerNode(BaseNode): self.variable_updater = True self.typed_config: AssignerNodeConfig | None = None - async def execute(self, state: WorkflowState) -> Any: + def _output_types(self) -> dict[str, VariableType]: + return {} + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the assignment operation defined by this node. Args: state: The current workflow state, including conversation variables, node outputs, and system variables. + variable_pool: variable pool Returns: None or the result of the assignment operation. @@ -31,60 +36,57 @@ class AssignerNode(BaseNode): # Initialize a variable pool for accessing conversation, node, and system variables self.typed_config = AssignerNodeConfig(**self.config) logger.info(f"节点 {self.node_id} 开始执行") - pool = VariablePool(state) + pattern = r"\{\{\s*(.*?)\s*\}\}" + for assignment in self.typed_config.assignments: # Get the target variable selector (e.g., "conv.test") variable_selector = assignment.variable_selector - if isinstance(variable_selector, str): - # Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"] - pattern = r"\{\{\s*(.*?)\s*\}\}" - expression = re.sub(pattern, r"\1", variable_selector).strip() - variable_selector = expression.split('.') + namespace = re.sub(pattern, r"\1", variable_selector).split('.')[0] # Only conversation variables ('conv') are allowed - if variable_selector[0] != 'conv' and variable_selector[0] not in state["cycle_nodes"]: - raise ValueError("Only conversation or cycle variables can be assigned.") + if namespace != 'conv' and namespace not in state["cycle_nodes"]: + raise ValueError(f"Only conversation or cycle variables can be assigned. - {variable_selector}") # Get the value or expression to assign value = assignment.value logger.debug(f"left:{variable_selector}, right: {value}") - pattern = r"\{\{\s*(.*?)\s*\}\}" + if isinstance(value, str): expression = re.match(pattern, value) if expression: expression = expression.group(1) expression = re.sub(pattern, r"\1", expression).strip() - value = self.get_variable(expression, state) + value = self.get_variable(expression, variable_pool, default=value, strict=False) # Select the appropriate assignment operator instance based on the target variable type operator: AssignmentOperatorInstance = AssignmentOperatorResolver.resolve_by_value( - pool.get(variable_selector) + variable_pool.get_value(variable_selector) )( - pool, variable_selector, value + variable_pool, variable_selector, value ) # Execute the configured assignment operation match assignment.operation: case AssignmentOperator.COVER: - operator.assign() + await operator.assign() case AssignmentOperator.ASSIGN: - operator.assign() + await operator.assign() case AssignmentOperator.CLEAR: - operator.clear() + await operator.clear() case AssignmentOperator.ADD: - operator.add() + await operator.add() case AssignmentOperator.SUBTRACT: - operator.subtract() + await operator.subtract() case AssignmentOperator.MULTIPLY: - operator.multiply() + await operator.multiply() case AssignmentOperator.DIVIDE: - operator.divide() + await operator.divide() case AssignmentOperator.APPEND: - operator.append() + await operator.append() case AssignmentOperator.REMOVE_FIRST: - operator.remove_first() + await operator.remove_first() case AssignmentOperator.REMOVE_LAST: - operator.remove_last() + await operator.remove_last() case _: raise ValueError(f"Invalid Operator: {assignment.operation}") logger.info(f"Node {self.node_id}: execution completed") diff --git a/api/app/core/workflow/nodes/base_config.py b/api/app/core/workflow/nodes/base_config.py index a6b33928..973e120d 100644 --- a/api/app/core/workflow/nodes/base_config.py +++ b/api/app/core/workflow/nodes/base_config.py @@ -3,79 +3,13 @@ 定义所有节点配置的通用字段和数据结构。 """ -from enum import StrEnum -from typing import Any +from pydantic import BaseModel, Field -from pydantic import BaseModel, Field, ConfigDict +from app.core.workflow.variable.base_variable import VariableType VARIABLE_PATTERN = r"\{\{\s*(.*?)\s*\}\}" -class VariableType(StrEnum): - """变量类型枚举""" - - STRING = "string" - NUMBER = "number" - BOOLEAN = "boolean" - OBJECT = "object" - - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_BOOLEAN = "array[boolean]" - ARRAY_OBJECT = "array[object]" - - -class TypedVariable(BaseModel): - """ - TODO: 强类型限制 - Strongly typed variable that validates value on assignment. - """ - - value: Any = Field(..., description="Variable value") - type: VariableType = Field(..., description="Declared type of the variable") - - model_config = ConfigDict( - validate_assignment=True - ) - - def __setattr__(self, name, value): - if name == "value": - self._validate_value(value) - if name == "type": - raise RuntimeError("Cannot modify variable type at runtime") - super().__setattr__(name, value) - - def _validate_value(self, v: Any): - t = self.type - match t: - case VariableType.STRING: - if not isinstance(v, str): - raise TypeError("Variable value does not match type STRING") - case VariableType.BOOLEAN: - if not isinstance(v, bool): - raise TypeError("Variable value does not match type BOOLEAN") - case VariableType.NUMBER: - if not isinstance(v, (int, float)): - raise TypeError("Variable value does not match type NUMBER") - case VariableType.OBJECT: - if not isinstance(v, dict): - raise TypeError("Variable value does not match type OBJECT") - case VariableType.ARRAY_STRING: - if not isinstance(v, list) or not all(isinstance(i, str) for i in v): - raise TypeError("Variable value does not match type ARRAY_STRING") - case VariableType.ARRAY_NUMBER: - if not isinstance(v, list) or not all(isinstance(i, (int, float)) for i in v): - raise TypeError("Variable value does not match type ARRAY_NUMBER") - case VariableType.ARRAY_BOOLEAN: - if not isinstance(v, list) or not all(isinstance(i, bool) for i in v): - raise TypeError("Variable value does not match type ARRAY_BOOLEAN") - case VariableType.ARRAY_OBJECT: - if not isinstance(v, list) or not all(isinstance(i, dict) for i in v): - raise TypeError("Variable value does not match type ARRAY_OBJECT") - case _: - raise TypeError(f"Unknown variable type: {t}") - - class VariableDefinition(BaseModel): """变量定义 diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 4dcdf2bb..107567e1 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -1,12 +1,7 @@ -""" -工作流节点基类 - -定义节点的基本接口和通用功能。 -""" - import asyncio import logging from abc import ABC, abstractmethod +from functools import cached_property from typing import Any, AsyncGenerator from langgraph.config import get_stream_writer @@ -14,7 +9,9 @@ from typing_extensions import TypedDict, Annotated from app.core.config import settings from app.core.workflow.nodes.enums import BRANCH_NODES +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable_pool import VariablePool +from app.services.multimodal_service import PROVIDER_STRATEGIES logger = logging.getLogger(__name__) @@ -42,22 +39,10 @@ class WorkflowState(TypedDict): cycle_nodes: list looping: Annotated[int, merge_looping_state] - # Input variables (passed from configured variables) - # Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx) - variables: Annotated[dict[str, Any], lambda x, y: { - **x, - **{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v - for k, v in y.items()} - }] - # Node outputs (stores execution results of each node for variable references) # Uses a custom merge function to combine new node outputs into the existing dictionary node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}] - # Runtime node variables (simplified version, stores business data for fast access between nodes) - # Format: {node_id: business_result} - runtime_vars: Annotated[dict[str, Any], lambda x, y: {**x, **y}] - # Execution context execution_id: str workspace_id: str @@ -72,17 +57,17 @@ class WorkflowState(TypedDict): class BaseNode(ABC): - """节点基类 - - 所有节点类型都应该继承此基类,实现 execute 方法。 + """Base class for workflow nodes. + + All node types should inherit from this class and implement the `execute` method. """ def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): - """初始化节点 - + """Initialize the node. + Args: - node_config: 节点配置 - workflow_config: 工作流配置 + node_config: Configuration of the node. + workflow_config: Configuration of the workflow. """ self.node_config = node_config self.workflow_config = workflow_config @@ -94,7 +79,27 @@ class BaseNode(ABC): self.config = node_config.get("config") or {} self.error_handling = node_config.get("error_handling") or {} - self.variable_updater = False + self.variable_change_able = False + + @cached_property + def output_types(self) -> dict[str, VariableType]: + """Returns the output variable types of the node. + + This property is cached to avoid recomputation. + """ + return self._output_types() + + @abstractmethod + def _output_types(self) -> dict[str, VariableType]: + """Defines output variable types for the node. + + Subclasses must override this method to declare the variables + produced by the node and their corresponding types. + + Returns: + A mapping from output variable names to ``VariableType``. + """ + return {} def check_activate(self, state: WorkflowState): """Check if the current node is activated in the workflow state. @@ -136,92 +141,84 @@ class BaseNode(ABC): } @abstractmethod - async def execute(self, state: WorkflowState) -> Any: - """执行节点业务逻辑(非流式) - - 节点只需要返回业务结果,不需要关心输出格式、时间统计等。 - BaseNode 会自动包装成标准格式。 - + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: + """Executes the node business logic (non-streaming). + + The node implementation should only return the business result. + It does not need to handle output formatting, timing, or statistics. + The ``BaseNode`` will automatically wrap the result into a standard + response format. + Args: - state: 工作流状态 - + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + Returns: - 业务结果(任意类型) - - Examples: - >>> # LLM 节点 - >>> "这是 AI 的回复" - - >>> # Transform 节点 - >>> {"processed_data": [...]} - - >>> # Start/End 节点 - >>> {"message": "开始", "conversation_id": "xxx"} + The business result produced by the node. The return value can be + of any type. """ pass - async def execute_stream(self, state: WorkflowState): - """执行节点业务逻辑(流式) - - 子类可以重写此方法以支持流式输出。 - 默认实现:执行非流式方法并一次性返回。 - - 节点需要: - 1. yield 中间结果(如文本片段) - 2. 最后 yield 一个特殊的完成标记:{"__final__": True, "result": final_result} - - Args: - state: 工作流状态 - - Yields: - 业务数据(chunk)或完成标记 - - Examples: - # 流式 LLM 节点 - full_response = "" - async for chunk in llm.astream(prompt): - full_response += chunk - yield chunk # yield 文本片段 + async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): + """Executes the node business logic in streaming mode. - # 最后 yield 完成标记 - yield {"__final__": True, "result": AIMessage(content=full_response)} + Subclasses may override this method to support streaming output. + The default implementation executes the non-streaming method and + yields a single final result. + + For streaming execution, a node implementation should: + 1. Yield intermediate results (e.g. text chunks). + 2. Yield a final completion marker in the following format: + ``{"__final__": True, "result": final_result}``. + + Args: + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + + Yields: + Business data chunks or a final completion marker. """ - result = await self.execute(state) - # 默认实现:直接 yield 完成标记 + result = await self.execute(state, variable_pool) + # Default implementation: yield a single final completion marker. yield {"__final__": True, "result": result} def supports_streaming(self) -> bool: - """节点是否支持流式输出 - + """Returns whether the node supports streaming output. + + A node is considered to support streaming if its class overrides + the ``execute_stream`` method. If the default implementation from + ``BaseNode`` is used, streaming is not supported. + Returns: - 是否支持流式输出 + True if the node supports streaming output, False otherwise. """ - # 检查子类是否重写了 execute_stream 方法 + # Check whether the subclass overrides the execute_stream method. return self.__class__.execute_stream is not BaseNode.execute_stream - def get_timeout(self) -> int: - """获取超时时间(秒) - + @staticmethod + def get_timeout() -> int: + """Returns the execution timeout in seconds. + Returns: - 超时时间 + The timeout duration, in seconds. """ return settings.WORKFLOW_NODE_TIMEOUT - # return self.error_handling.get("timeout", 60) - async def run(self, state: WorkflowState) -> dict[str, Any]: - """执行节点(带错误处理和输出包装,非流式) - - 这个方法由 Executor 调用,负责: - 1. 时间统计 - 2. 调用节点的 execute() 方法 - 3. 将业务结果包装成标准输出格式 - 4. 错误处理 - + async def run(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + """Runs the node with error handling and output wrapping (non-streaming). + + This method is invoked by the Executor and is responsible for: + 1. Execution time measurement. + 2. Invoking the node's ``execute()`` method. + 3. Wrapping the business result into a standardized output format. + 4. Handling execution errors. + Args: - state: 工作流状态 - + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + Returns: - 标准化的状态更新字典 + A standardized state update dictionary. """ if not self.check_activate(state): return self.trans_activate(state) @@ -233,70 +230,78 @@ class BaseNode(ABC): timeout = self.get_timeout() try: - # 调用节点的业务逻辑 + # Invoke the node business logic. business_result = await asyncio.wait_for( - self.execute(state), + self.execute(state, variable_pool), timeout=timeout ) elapsed_time = time.time() - start_time - # 提取处理后的输出(调用子类的 _extract_output) + # Extract processed outputs using subclass-defined logic. extracted_output = self._extract_output(business_result) - # 包装成标准输出格式 - wrapped_output = self._wrap_output(business_result, elapsed_time, state) + # Wrap the business result into the standard output format. + wrapped_output = self._wrap_output(business_result, elapsed_time, state, variable_pool) - # 将提取后的输出存储到运行时变量中(供后续节点快速访问) - # 如果提取后的输出是字典,拆包存储;否则存储为 output 字段 - if isinstance(extracted_output, dict): - runtime_var = extracted_output - else: - runtime_var = {"output": extracted_output} + # Store extracted outputs as runtime variables for downstream nodes. + if extracted_output is not None: + runtime_vars = extracted_output + if not isinstance(extracted_output, dict): + runtime_vars = {"output": extracted_output} + for k, v in runtime_vars.items(): + await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able) - # 返回包装后的输出和运行时变量 + # Return the wrapped output along with activation state updates. return { **wrapped_output, - "messages": state["messages"], - "runtime_vars": { - self.node_id: runtime_var - }, "looping": state["looping"] } | self.trans_activate(state) except TimeoutError: elapsed_time = time.time() - start_time - logger.error(f"节点 {self.node_id} 执行超时({timeout}秒)") - return self._wrap_error(f"节点执行超时({timeout}秒)", elapsed_time, state) + logger.error( + f"Node {self.node_id} execution timed out ({timeout} seconds)." + ) + return self._wrap_error( + f"Node execution timed out ({timeout} seconds).", + elapsed_time, + state, + variable_pool, + ) except Exception as e: elapsed_time = time.time() - start_time - logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) - return self._wrap_error(str(e), elapsed_time, state) + logger.error( + f"Node {self.node_id} execution failed: {e}", + exc_info=True, + ) + return self._wrap_error(str(e), elapsed_time, state, variable_pool) + + async def run_stream( + self, state: WorkflowState, + variable_pool: VariablePool + ) -> AsyncGenerator[dict[str, Any], Any]: + """Executes the node with error handling and output wrapping (streaming). - async def run_stream(self, state: WorkflowState) -> AsyncGenerator[dict[str, Any], Any]: - """Execute node with error handling and output wrapping (streaming) - This method is called by the Executor and is responsible for: - 1. Time tracking - 2. Calling the node's execute_stream() method - 3. Using LangGraph's stream writer to send chunks - 4. Updating streaming buffer in state for downstream nodes - 5. Wrapping business data into standard output format - 6. Error handling - - Special handling for End nodes: - - End nodes don't send chunks via writer (prefix and LLM content already sent) - - End nodes only yield suffix for final result assembly - + 1. Tracking execution time. + 2. Calling the node's ``execute_stream()`` method. + 3. Sending streaming chunks via LangGraph's stream writer. + 4. Updating activation-related state for downstream nodes. + 5. Wrapping business data into a standardized output format. + 6. Handling execution errors. + Args: - state: Workflow state - + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + Yields: - State updates with streaming buffer and final result + Incremental state updates, including activation state changes and + the final wrapped result. """ if not self.check_activate(state): yield self.trans_activate(state) - logger.info(f"jump node: {self.node_id}") + logger.debug(f"jump node: {self.node_id}") return import time @@ -317,7 +322,7 @@ class BaseNode(ABC): # Stream chunks in real-time loop_start = asyncio.get_event_loop().time() - async for item in self.execute_stream(state): + async for item in self.execute_stream(state, variable_pool): # Check timeout if asyncio.get_event_loop().time() - loop_start > timeout: raise TimeoutError() @@ -332,7 +337,7 @@ class BaseNode(ABC): chunks.append(content) # Send chunks for all nodes (including End nodes for suffix) - logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...") + logger.debug(f"Node {self.node_id} sent chunk #{chunk_count}: {content[:50]}...") # 1. Send via stream writer (for real-time client updates) writer({ @@ -344,27 +349,26 @@ class BaseNode(ABC): elapsed_time = time.time() - start_time - logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}") + logger.info(f"Node {self.node_id} streaming execution finished, " + f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}") # Extract processed output (call subclass's _extract_output) extracted_output = self._extract_output(final_result) # Wrap final result - final_output = self._wrap_output(final_result, elapsed_time, state) + final_output = self._wrap_output(final_result, elapsed_time, state, variable_pool) # Store extracted output in runtime variables (for quick access by subsequent nodes) - if isinstance(extracted_output, dict): - runtime_var = extracted_output - else: - runtime_var = {"output": extracted_output} + if extracted_output is not None: + runtime_vars = extracted_output + if not isinstance(extracted_output, dict): + runtime_vars = {"output": extracted_output} + for k, v in runtime_vars.items(): + await variable_pool.new(self.node_id, k, v, self.output_types[k], mut=self.variable_change_able) # Build complete state update (including node_outputs, runtime_vars, and final streaming buffer) state_update = { **final_output, - "messages": state["messages"], - "runtime_vars": { - self.node_id: runtime_var - }, "looping": state["looping"] } @@ -374,41 +378,49 @@ class BaseNode(ABC): except TimeoutError: elapsed_time = time.time() - start_time - logger.error(f"节点 {self.node_id} 执行超时 ({timeout}s)") - error_output = self._wrap_error(f"节点执行超时 ({timeout}s)", elapsed_time, state) + logger.error(f"Node {self.node_id} execution timed out ({timeout}s)") + error_output = self._wrap_error( + f"Node execution timed out ({timeout}s)", + elapsed_time, + state, + variable_pool + ) yield error_output except Exception as e: elapsed_time = time.time() - start_time - logger.error(f"节点 {self.node_id} 执行失败: {e}", exc_info=True) - error_output = self._wrap_error(str(e), elapsed_time, state) + logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True) + error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool) yield error_output def _wrap_output( self, business_result: Any, elapsed_time: float, - state: WorkflowState + state: WorkflowState, + variable_pool: VariablePool ) -> dict[str, Any]: - """将业务结果包装成标准输出格式 - - Args: - business_result: 节点返回的业务结果 - elapsed_time: 执行耗时 - state: 工作流状态 - - Returns: - 标准化的状态更新字典 - """ - # 提取输入数据(用于记录) - input_data = self._extract_input(state) + """Wraps the business result into a standardized node output format. - # 提取 token 使用情况(如果有) + Args: + business_result: The result returned by the node's business logic. + elapsed_time: Time elapsed during node execution (in seconds). + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + + Returns: + A dictionary representing the standardized state update for this node, + including node outputs, input, output, elapsed time, token usage, and status. + """ + # Extract input data (for logging or audit purposes) + input_data = self._extract_input(state, variable_pool) + + # Extract token usage information (if applicable) token_usage = self._extract_token_usage(business_result) - # 提取实际输出(去除元数据) + # Extract actual output (strip any metadata) output = self._extract_output(business_result) - # 构建标准节点输出 + # Construct standardized node output node_output = { "node_id": self.node_id, "node_type": self.node_type, @@ -423,8 +435,6 @@ class BaseNode(ABC): final_output = { "node_outputs": {self.node_id: node_output}, } - if self.variable_updater: - final_output = final_output | {"variables": state["variables"]} return final_output @@ -432,25 +442,33 @@ class BaseNode(ABC): self, error_message: str, elapsed_time: float, - state: WorkflowState + state: WorkflowState, + variable_pool: VariablePool ) -> dict[str, Any]: - """将错误包装成标准输出格式 - + """Wraps an error into a standardized node output format. + + This method handles both cases: + - If an error edge is defined, the workflow can continue to the error handling node. + - If no error edge exists, the workflow is stopped by raising an exception. + Args: - error_message: 错误信息 - elapsed_time: 执行耗时 - state: 工作流状态 - + error_message: The error message describing the failure. + elapsed_time: Time elapsed during node execution (in seconds). + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + Returns: - 标准化的状态更新字典 + A dictionary representing the standardized state update for this node + when an error edge exists. If no error edge exists, this method + raises an exception to stop the workflow. """ - # 查找错误边 + # Check if the node has an error edge defined error_edge = self._find_error_edge() - # 提取输入数据 - input_data = self._extract_input(state) + # Extract input data (for logging or audit purposes) + input_data = self._extract_input(state, variable_pool) - # 构建错误输出 + # Construct the standardized node output for the error node_output = { "node_id": self.node_id, "node_type": self.node_type, @@ -464,9 +482,9 @@ class BaseNode(ABC): } if error_edge: - # 有错误边:记录错误并继续 + # If an error edge exists, log a warning and continue to error node logger.warning( - f"节点 {self.node_id} 执行失败,跳转到错误处理节点: {error_edge['target']}" + f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" ) return { "node_outputs": { @@ -476,198 +494,188 @@ class BaseNode(ABC): "error_node": self.node_id } else: + # If no error edge, send the error via stream writer and stop the workflow writer = get_stream_writer() writer({ "type": "node_error", **node_output }) - # 无错误边:抛出异常停止工作流 - logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}") - raise Exception(f"节点 {self.node_id} 执行失败: {error_message}") + logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") + raise Exception(f"Node {self.node_id} execution failed: {error_message}") + + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: + """Extracts the input data for this node (used for logging or audit). + + Subclasses may override this method to customize what input data + should be recorded. - def _extract_input(self, state: WorkflowState) -> dict[str, Any]: - """提取节点输入数据(用于记录) - - 子类可以重写此方法来自定义输入记录。 - Args: - state: 工作流状态 - + state: The current workflow state. + variable_pool: The variable pool used for reading and writing variables. + Returns: - 输入数据字典 + A dictionary containing the node's input data. """ - # 默认返回配置 + # Default implementation returns the node configuration return {"config": self.config} def _extract_output(self, business_result: Any) -> Any: - """从业务结果中提取实际输出 - - 子类可以重写此方法来自定义输出提取。 - + """Extracts the actual output from the business result. + + Subclasses may override this method to customize how the node's + output is extracted. + Args: - business_result: 业务结果 - + business_result: The result returned by the node's business logic. + Returns: - 实际输出 + The actual output extracted from the business result. """ - # 默认直接返回业务结果 + # Default implementation returns the business result directly return business_result def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None: - """从业务结果中提取 token 使用情况 - - 子类可以重写此方法来提取 token 信息。 - + """Extracts token usage information from the business result. + + Subclasses may override this method to extract token usage statistics + (e.g., for LLM nodes). + Args: - business_result: 业务结果 - + business_result: The result returned by the node's business logic. + Returns: - token 使用情况或 None + A dictionary mapping token types to counts, or None if not applicable. """ - # 默认返回 None + # Default implementation returns None return None def _find_error_edge(self) -> dict[str, Any] | None: - """查找错误边 - + """Finds the error edge for this node, if any. + + An error edge is used to redirect workflow execution when this node + fails. + Returns: - 错误边配置或 None + A dictionary representing the error edge configuration if it exists, + or None if no error edge is defined. """ for edge in self.workflow_config.get("edges", []): if edge.get("source") == self.node_id and edge.get("type") == "error": return edge return None - def _render_template(self, template: str, state: WorkflowState | None, strict: bool = True) -> str: - """渲染模板 - - 支持的变量命名空间: - - sys.xxx: 系统变量(message, execution_id, workspace_id, user_id, conversation_id) - - conv.xxx: 会话变量(跨多轮对话保持) - - node_id.xxx: 节点输出 - + @staticmethod + def _render_template(template: str, variable_pool: VariablePool, strict: bool = True) -> str: + """Renders a template string using the provided variable pool. + + Supported variable namespaces: + - sys.xxx: System variables (e.g., message, execution_id, workspace_id, + user_id, conversation_id) + - conv.xxx: Conversation variables (persist across multiple turns) + - node_id.xxx: Node outputs + Args: - template: 模板字符串 - state: 工作流状态 - + template: The template string to render. + variable_pool: The variable pool containing system, conversation, and + node variables. + strict: If True, missing variables will raise an error; if False, + missing variables are ignored. + Returns: - 渲染后的字符串 + The rendered string with all variables substituted. """ from app.core.workflow.template_renderer import render_template - # 处理 state 为 None 的情况 - if state is None: - state = {} - - # 使用变量池获取变量 - pool = VariablePool(state) - - # 构建完整的 variables 结构 - variables = { - "sys": pool.get_all_system_vars(), - "conv": pool.get_all_conversation_vars() - } - return render_template( template=template, - variables=variables, - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars(), + conv_vars=variable_pool.get_all_conversation_vars(), + node_outputs=variable_pool.get_all_node_outputs(), + system_vars=variable_pool.get_all_system_vars(), strict=strict ) - def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool: - """评估条件表达式 - - 支持的变量命名空间: - - sys.xxx: 系统变量 - - conv.xxx: 会话变量 - - node_id.xxx: 节点输出 - + @staticmethod + def _evaluate_condition(expression: str, variable_pool: VariablePool) -> bool: + """Evaluates a conditional expression using the provided variable pool. + + Supported variable namespaces: + - sys.xxx: System variables + - conv.xxx: Conversation variables + - node_id.xxx: Node outputs + Args: - expression: 条件表达式 - state: 工作流状态 - + expression: The conditional expression to evaluate. + variable_pool: The variable pool containing system, conversation, and + node variables. + Returns: - 布尔值结果 + The boolean result of evaluating the expression. """ from app.core.workflow.expression_evaluator import evaluate_condition - # 处理 state 为 None 的情况 - if state is None: - state = {} - - # 使用变量池获取变量 - pool = VariablePool(state) - - # 构建完整的 variables 结构(包含 sys 和 conv) - variables = { - "sys": pool.get_all_system_vars(), - "conv": pool.get_all_conversation_vars() - } - return evaluate_condition( expression=expression, - variables=variables, - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars() + conv_var=variable_pool.get_all_conversation_vars(), + node_outputs=variable_pool.get_all_node_outputs(), + system_vars=variable_pool.get_all_system_vars() ) - def get_variable_pool(self, state: WorkflowState) -> VariablePool: - """获取变量池实例 - - VariablePool 是轻量级包装器,只持有 state 的引用,创建成本极低。 - - Args: - state: 工作流状态 - - Returns: - VariablePool 实例 - - Examples: - >>> pool = self.get_variable_pool(state) - >>> message = pool.get("sys.message") - >>> llm_output = pool.get("llm_qa.output") - """ - return VariablePool(state) - + @staticmethod def get_variable( - self, - selector: list[str] | str, - state: WorkflowState, - default: Any = None + selector: str, + variable_pool: VariablePool, + default: Any = None, + strict: bool = True ) -> Any: - """获取变量值(便捷方法) - - Args: - selector: 变量选择器 - state: 工作流状态 - default: 默认值 - - Returns: - 变量值 - - Examples: - >>> message = self.get_variable("sys.message", state) - >>> output = self.get_variable(["llm_qa", "output"], state) - >>> custom = self.get_variable("var.custom", state, default="默认值") - """ - pool = VariablePool(state) - return pool.get(selector, default=default) + """Retrieves a variable value from the variable pool (convenience method). - def has_variable(self, selector: list[str] | str, state: WorkflowState) -> bool: - """检查变量是否存在(便捷方法) - Args: - selector: 变量选择器 - state: 工作流状态 - + selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx). + variable_pool: The variable pool from which to fetch the value. + default: The default value to return if the variable does not exist. + strict: If True, raise an error when the variable is missing; if False, return the default. + Returns: - 变量是否存在 - - Examples: - >>> if self.has_variable("llm_qa.output", state): - ... output = self.get_variable("llm_qa.output", state) + The value of the selected variable, or the default if not found and strict is False. """ - pool = VariablePool(state) - return pool.has(selector) + return variable_pool.get_value(selector, default, strict=strict) + + @staticmethod + def has_variable(selector: str, variable_pool: VariablePool) -> bool: + """Checks whether a variable exists in the variable pool (convenience method). + + Args: + selector: The variable selector (can be namespaced, e.g., sys.xxx, conv.xxx, node_id.xxx). + variable_pool: The variable pool to check. + + Returns: + True if the variable exists in the pool, False otherwise. + """ + return variable_pool.has(selector) + + @staticmethod + async def process_message(provider, content, enable_file=False) -> dict | str | None: + if isinstance(content, str): + if enable_file: + return {"text": content} + return content + elif isinstance(content, dict): + trans_tool = PROVIDER_STRATEGIES[provider]() + result = await trans_tool.format_image(content["url"]) + return result + raise TypeError('Unexpect input value type') + + @staticmethod + def process_model_output(content) -> str: + result = "" + if isinstance(content, list): + for msg in content: + if isinstance(msg, dict): + result += msg.get("text") + elif isinstance(msg, str): + result += msg + elif isinstance(content, dict): + result = content.get("text") + elif isinstance(content, str): + return content + return result diff --git a/api/app/core/workflow/nodes/breaker/node.py b/api/app/core/workflow/nodes/breaker/node.py index f00015d1..8b772d6a 100644 --- a/api/app/core/workflow/nodes/breaker/node.py +++ b/api/app/core/workflow/nodes/breaker/node.py @@ -2,6 +2,8 @@ import logging from typing import Any from app.core.workflow.nodes import BaseNode, WorkflowState +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -14,15 +16,19 @@ class BreakNode(BaseNode): to False, signaling the outer loop runtime to terminate further iterations. """ - async def execute(self, state: WorkflowState) -> Any: + def _output_types(self) -> dict[str, VariableType]: + return {} + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the break node. Args: state: Current workflow state, including loop control flags. + variable_pool: Pool of variables for the workflow. Effects: - - Sets 'looping' in the state to False to stop the loop. + - Sets 'looping' in the state too False to stop the loop. - Logs the action for debugging purposes. Returns: diff --git a/api/app/core/workflow/nodes/code/config.py b/api/app/core/workflow/nodes/code/config.py index 8af13f12..e17e841f 100644 --- a/api/app/core/workflow/nodes/code/config.py +++ b/api/app/core/workflow/nodes/code/config.py @@ -1,7 +1,8 @@ from typing import Literal from pydantic import Field, BaseModel -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.core.workflow.variable.base_variable import VariableType class InputVariable(BaseModel): @@ -44,7 +45,7 @@ class CodeNodeConfig(BaseNodeConfig): description="code content" ) - language: Literal['python3', 'nodejs'] = Field( + language: Literal['python3', 'javascript'] = Field( ..., description="language" ) diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index 892708f2..f6176edf 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -2,15 +2,17 @@ import base64 import json import logging import re +import urllib.parse from string import Template from textwrap import dedent from typing import Any - +import urllib.parse import httpx from app.core.workflow.nodes import BaseNode, WorkflowState -from app.core.workflow.nodes.base_config import VariableType from app.core.workflow.nodes.code.config import CodeNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -52,6 +54,12 @@ class CodeNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: CodeNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + output_dict = {} + for output in self.typed_config.output_variables: + output_dict[output.name] = output.type + return output_dict + def extract_result(self, content: str): match = re.search(r'<>(.*?)<>', content, re.DOTALL) if match: @@ -92,15 +100,16 @@ class CodeNode(BaseNode): else: raise RuntimeError("The output of main must be a dictionary") - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: self.typed_config = CodeNodeConfig(**self.config) input_variable_dict = {} for input_variable in self.typed_config.input_variables: - input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, state) + input_variable_dict[input_variable.name] = self.get_variable(input_variable.variable, variable_pool) code = base64.b64decode( self.typed_config.code ).decode("utf-8") + code = urllib.parse.unquote(code, encoding='utf-8') input_variable_dict = base64.b64encode( json.dumps(input_variable_dict).encode("utf-8") @@ -110,7 +119,7 @@ class CodeNode(BaseNode): code=code, inputs_variable=input_variable_dict, ) - elif self.typed_config.language == 'nodejs': + elif self.typed_config.language == 'javascript': final_script = NODEJS_SCRIPT_TEMPLATE.substitute( code=code, inputs_variable=input_variable_dict, diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index d73754f6..e4e418fe 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -8,7 +8,6 @@ from app.core.workflow.nodes.assigner.config import AssignerNodeConfig from app.core.workflow.nodes.base_config import ( BaseNodeConfig, VariableDefinition, - VariableType, ) from app.core.workflow.nodes.code.config import CodeNodeConfig from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig @@ -23,21 +22,18 @@ from app.core.workflow.nodes.parameter_extractor.config import ParameterExtracto from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.tool.config import ToolNodeConfig -from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig __all__ = [ # 基础类 "BaseNodeConfig", "VariableDefinition", - "VariableType", # 节点配置 "StartNodeConfig", "EndNodeConfig", "LLMNodeConfig", "MessageConfig", "AgentNodeConfig", - "TransformNodeConfig", "IfElseNodeConfig", "KnowledgeRetrievalNodeConfig", "AssignerNodeConfig", diff --git a/api/app/core/workflow/nodes/cycle_graph/config.py b/api/app/core/workflow/nodes/cycle_graph/config.py index 445ddd9a..52aca1d9 100644 --- a/api/app/core/workflow/nodes/cycle_graph/config.py +++ b/api/app/core/workflow/nodes/cycle_graph/config.py @@ -2,7 +2,8 @@ from typing import Any from pydantic import Field, BaseModel, field_validator -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType @@ -127,4 +128,9 @@ class IterationNodeConfig(BaseNodeConfig): description="Output of the loop iteration" ) + output_type: VariableType = Field( + default=None, + description="Data type of the loop iteration output" + ) + diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index cd63d233..762da847 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -7,6 +7,7 @@ from langgraph.graph.state import CompiledStateGraph from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.cycle_graph import IterationNodeConfig +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -28,6 +29,8 @@ class IterationRuntime: node_id: str, config: dict[str, Any], state: WorkflowState, + variable_pool: VariablePool, + child_variable_pool: VariablePool, ): """ Initialize the iteration runtime. @@ -44,11 +47,13 @@ class IterationRuntime: self.node_id = node_id self.typed_config = IterationNodeConfig(**config) self.looping = True + self.variable_pool = variable_pool + self.child_variable_pool = child_variable_pool self.output_value = None self.result: list = [] - def _init_iteration_state(self, item, idx): + async def _init_iteration_state(self, item, idx): """ Initialize a per-iteration copy of the workflow state. @@ -62,10 +67,9 @@ class IterationRuntime: loopstate = WorkflowState( **self.state ) - loopstate["runtime_vars"][self.node_id] = { - "item": item, - "index": idx, - } + self.child_variable_pool.copy(self.variable_pool) + await self.child_variable_pool.new(self.node_id, "item", item, VariableType.type_map(item), mut=True) + await self.child_variable_pool.new(self.node_id, "index", item, VariableType.type_map(item), mut=True) loopstate["node_outputs"][self.node_id] = { "item": item, "index": idx, @@ -74,6 +78,11 @@ class IterationRuntime: loopstate["activate"][self.start_id] = True return loopstate + def merge_conv_vars(self): + self.variable_pool.get_all_conversation_vars().update( + self.child_variable_pool.get_all_conversation_vars() + ) + async def run_task(self, item, idx): """ Execute a single iteration asynchronously. @@ -82,8 +91,8 @@ class IterationRuntime: item: The input element for this iteration. idx: The index of this iteration. """ - result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) - output = VariablePool(result).get(self.output_value) + result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) + output = self.child_variable_pool.get_value(self.output_value) if isinstance(output, list) and self.typed_config.flatten: self.result.extend(output) else: @@ -125,7 +134,7 @@ class IterationRuntime: input_expression = re.sub(pattern, r"\1", self.typed_config.input).strip() self.output_value = re.sub(pattern, r"\1", self.typed_config.output).strip() - array_obj = VariablePool(self.state).get(input_expression) + array_obj = self.variable_pool.get_value(input_expression) if not isinstance(array_obj, list): raise RuntimeError("Cannot iterate over a non-list variable") child_state = [] @@ -137,14 +146,16 @@ class IterationRuntime: logger.info(f"Iteration node {self.node_id}: running, concurrency {len(tasks)}") idx += self.typed_config.parallel_count child_state.extend(await asyncio.gather(*tasks)) + self.merge_conv_vars() else: # Execute iterations sequentially while idx < len(array_obj) and self.looping: logger.info(f"Iteration node {self.node_id}: running") item = array_obj[idx] - result = await self.graph.ainvoke(self._init_iteration_state(item, idx)) + result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) child_state.append(result) - output = VariablePool(result).get(self.output_value) + output = self.child_variable_pool.get_value(self.output_value) + self.merge_conv_vars() if isinstance(output, list) and self.typed_config.flatten: self.result.extend(output) else: diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index 6a15891f..7204a642 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -31,6 +31,8 @@ class LoopRuntime: node_id: str, config: dict[str, Any], state: WorkflowState, + variable_pool: VariablePool, + child_variable_pool: VariablePool ): """ Initialize the loop runtime executor. @@ -40,6 +42,8 @@ class LoopRuntime: node_id: The unique identifier of the loop node in the workflow. config: Raw configuration dictionary for the loop node. state: The current workflow state before entering the loop. + variable_pool: A VariablePool instance for accessing and modifying workflow variables. + child_variable_pool: A VariablePool instance for managing child node outputs. """ self.start_id = start_id self.graph = graph @@ -47,8 +51,10 @@ class LoopRuntime: self.node_id = node_id self.typed_config = LoopNodeConfig(**config) self.looping = True + self.variable_pool = variable_pool + self.child_variable_pool = child_variable_pool - def _init_loop_state(self): + async def _init_loop_state(self): """ Initialize workflow state for loop execution. @@ -62,33 +68,35 @@ class LoopRuntime: Returns: WorkflowState: A prepared workflow state used for loop execution. """ - pool = VariablePool(self.state) # 循环变量 - self.state["runtime_vars"][self.node_id] = { - variable.name: evaluate_expression( - expression=variable.value, - variables=pool.get_all_conversation_vars(), - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars(), - ) - if variable.input_type == ValueInputType.VARIABLE - else TypeTransformer.transform(variable.value, variable.type) - for variable in self.typed_config.cycle_vars - } - self.state["node_outputs"][self.node_id] = { - variable.name: evaluate_expression( - expression=variable.value, - variables=pool.get_all_conversation_vars(), - node_outputs=pool.get_all_node_outputs(), - system_vars=pool.get_all_system_vars(), - ) - if variable.input_type == ValueInputType.VARIABLE - else TypeTransformer.transform(variable.value, variable.type) - for variable in self.typed_config.cycle_vars - } + self.child_variable_pool.copy(self.variable_pool) + + for variable in self.typed_config.cycle_vars: + if variable.input_type == ValueInputType.VARIABLE: + value = evaluate_expression( + expression=variable.value, + conv_var=self.variable_pool.get_all_conversation_vars(), + node_outputs=self.variable_pool.get_all_node_outputs(), + system_vars=self.variable_pool.get_all_system_vars(), + ) + else: + value = TypeTransformer.transform(variable.value, variable.type) + await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True) loopstate = WorkflowState( **self.state ) + loopstate["node_outputs"][self.node_id] = { + variable.name: evaluate_expression( + expression=variable.value, + conv_var=self.variable_pool.get_all_conversation_vars(), + node_outputs=self.variable_pool.get_all_node_outputs(), + system_vars=self.variable_pool.get_all_system_vars(), + ) + if variable.input_type == ValueInputType.VARIABLE + else TypeTransformer.transform(variable.value, variable.type) + for variable in self.typed_config.cycle_vars + } + loopstate["looping"] = 1 loopstate["activate"][self.start_id] = True return loopstate @@ -134,7 +142,12 @@ class LoopRuntime: case _: raise ValueError(f"Invalid condition: {operator}") - def evaluate_conditional(self, state) -> bool: + def merge_conv_vars(self): + self.variable_pool.variables["conv"].update( + self.child_variable_pool.variables.get("conv", {}) + ) + + def evaluate_conditional(self) -> bool: """ Evaluate the loop continuation condition at runtime. @@ -143,18 +156,15 @@ class LoopRuntime: - Evaluates each comparison expression immediately - Combines results using the configured logical operator (AND / OR) - Args: - state: The current workflow state during loop execution. - Returns: bool: True if the loop should continue, False otherwise. """ conditions = [] for expression in self.typed_config.condition.expressions: - left_value = VariablePool(state).get(expression.left) + left_value = self.child_variable_pool.get_value(expression.left) evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( - VariablePool(state), + self.child_variable_pool, expression.left, expression.right, expression.input_type @@ -177,16 +187,18 @@ class LoopRuntime: Returns: dict[str, Any]: The final runtime variables of this loop node. """ - loopstate = self._init_loop_state() + loopstate = await self._init_loop_state() loop_time = self.typed_config.max_loop child_state = [] - while self.evaluate_conditional(loopstate) and self.looping and loop_time > 0: + while not self.evaluate_conditional() and self.looping and loop_time > 0: logger.info(f"loop node {self.node_id}: running") result = await self.graph.ainvoke(loopstate) child_state.append(result) + + self.merge_conv_vars() if result["looping"] == 2: self.looping = False loop_time -= 1 logger.info(f"loop node {self.node_id}: execution completed") - return loopstate["runtime_vars"][self.node_id] | {"__child_state": child_state} + return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state} diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 82782658..6908cb73 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -6,9 +6,12 @@ from langgraph.graph.state import CompiledStateGraph from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime from app.core.workflow.nodes.enums import NodeType +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -35,9 +38,41 @@ class CycleGraphNode(BaseNode): self.start_node_id = None # ID of the start node within the cycle self.graph: StateGraph | CompiledStateGraph | None = None + self.child_variable_pool: VariablePool | None = None self.build_graph() self.iteration_flag = True + def _output_types(self) -> dict[str, VariableType]: + outputs = {"__child_state": VariableType.ARRAY_OBJECT} + if self.node_type == NodeType.LOOP: + # Loop node outputs the final state of the loop + config = LoopNodeConfig(**self.config) + for var_def in config.cycle_vars: + outputs[var_def.name] = var_def.type + return outputs + elif self.node_type == NodeType.ITERATION: + # Iteration node outputs the processed collection + config = IterationNodeConfig(**self.config) + if not config.output_type: + outputs['output'] = VariableType.ANY + return outputs + if config.output_type in [ + VariableType.ARRAY_FILE, + VariableType.ARRAY_STRING, + VariableType.NUMBER, + VariableType.ARRAY_OBJECT, + VariableType.BOOLEAN + ]: + if config.flatten: + outputs['output'] = config.output_type + else: + outputs['output'] = VariableType.ARRAY_STRING + else: + outputs['output'] = VariableType(f"array[{config.output_type}]") + return outputs + else: + raise KeyError(f"Valid Cycle Node Type - {self.node_type}") + def pure_cycle_graph(self) -> tuple[list, list]: """ Extract cycle-scoped nodes and internal edges from the workflow configuration. @@ -103,17 +138,20 @@ class CycleGraphNode(BaseNode): """ from app.core.workflow.graph_builder import GraphBuilder self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() + self.child_variable_pool = VariablePool() builder = GraphBuilder( { "nodes": self.cycle_nodes, "edges": self.cycle_edges, }, - subgraph=True + subgraph=True, + variable_pool=self.child_variable_pool ) self.start_node_id = builder.start_node_id self.graph = builder.build() + self.child_variable_pool = builder.variable_pool - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the cycle node at runtime. @@ -123,6 +161,7 @@ class CycleGraphNode(BaseNode): Args: state: The current workflow state when entering the cycle node. + variable_pool: Variable Pool Returns: Any: The runtime result produced by the loop or iteration executor. @@ -137,6 +176,8 @@ class CycleGraphNode(BaseNode): node_id=self.node_id, config=self.config, state=state, + variable_pool=variable_pool, + child_variable_pool=self.child_variable_pool, ).run() if self.node_type == NodeType.ITERATION: return await IterationRuntime( @@ -145,5 +186,7 @@ class CycleGraphNode(BaseNode): node_id=self.node_id, config=self.config, state=state, + variable_pool=variable_pool, + child_variable_pool=self.child_variable_pool ).run() raise RuntimeError("Unknown cycle node type") diff --git a/api/app/core/workflow/nodes/end/config.py b/api/app/core/workflow/nodes/end/config.py index 50e84a36..f534dfb5 100644 --- a/api/app/core/workflow/nodes/end/config.py +++ b/api/app/core/workflow/nodes/end/config.py @@ -2,7 +2,8 @@ from pydantic import Field -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition +from app.core.workflow.variable.base_variable import VariableType class EndNodeConfig(BaseNodeConfig): diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 3a5153a9..a13a8153 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -7,6 +7,8 @@ End 节点实现 import logging from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -17,12 +19,18 @@ class EndNode(BaseNode): 工作流的结束节点,根据配置的模板输出最终结果。 支持实时流式输出:如果模板引用了上游节点的输出,会实时监听其流式缓冲区。 """ + def _output_types(self) -> dict[str, VariableType]: + """声明此节点的输出类型""" + return { + "output": VariableType.STRING + } - async def execute(self, state: WorkflowState) -> str: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> str: """执行 end 节点业务逻辑 Args: state: 工作流状态 + variable_pool: 变量池 Returns: 最终输出字符串 @@ -34,7 +42,7 @@ class EndNode(BaseNode): # 如果配置了输出模板,使用模板渲染;否则使用默认输出 if output_template: - output = self._render_template(output_template, state, strict=False) + output = self._render_template(output_template, variable_pool, strict=False) else: output = "" diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index aaf49a11..6ad1c6a8 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -9,7 +9,6 @@ class NodeType(StrEnum): KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" IF_ELSE = "if-else" CODE = "code" - TRANSFORM = "transform" QUESTION_CLASSIFIER = "question-classifier" HTTP_REQUEST = "http-request" TOOL = "tool" diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 141cba79..64fdfcb9 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -10,6 +10,8 @@ from httpx import AsyncClient, Response, Timeout from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__file__) @@ -34,6 +36,14 @@ class HttpRequestNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: HttpRequestNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + return { + "body": VariableType.STRING, + "status_code": VariableType.NUMBER, + "headers": VariableType.OBJECT, + "output": VariableType.STRING + } + def _build_timeout(self) -> Timeout: """ Build httpx Timeout configuration. @@ -50,7 +60,7 @@ class HttpRequestNode(BaseNode): ) return timeout - def _build_auth(self, state: WorkflowState) -> dict[str, str]: + def _build_auth(self, variable_pool: VariablePool) -> dict[str, str]: """ Build authentication-related HTTP headers. @@ -58,12 +68,12 @@ class HttpRequestNode(BaseNode): the current workflow runtime state. Args: - state: Current workflow runtime state. + variable_pool: Variable Pool Returns: A dictionary of HTTP headers used for authentication. """ - api_key = self._render_template(self.typed_config.auth.api_key, state) + api_key = self._render_template(self.typed_config.auth.api_key, variable_pool) match self.typed_config.auth.auth_type: case HttpAuthType.NONE: return {} @@ -82,7 +92,7 @@ class HttpRequestNode(BaseNode): case _: raise RuntimeError(f"Auth type not supported: {self.typed_config.auth.auth_type}") - def _build_header(self, state: WorkflowState) -> dict[str, str]: + def _build_header(self, variable_pool: VariablePool) -> dict[str, str]: """ Build HTTP request headers. @@ -90,10 +100,10 @@ class HttpRequestNode(BaseNode): """ headers = {} for key, value in self.typed_config.headers.items(): - headers[self._render_template(key, state)] = self._render_template(value, state) + headers[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool) return headers - def _build_params(self, state: WorkflowState) -> dict[str, str]: + def _build_params(self, variable_pool: VariablePool) -> dict[str, str]: """ Build URL query parameters. @@ -101,10 +111,10 @@ class HttpRequestNode(BaseNode): """ params = {} for key, value in self.typed_config.params.items(): - params[self._render_template(key, state)] = self._render_template(value, state) + params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool) return params - def _build_content(self, state) -> dict[str, Any]: + def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]: """ Build HTTP request body arguments for httpx request methods. @@ -120,13 +130,13 @@ class HttpRequestNode(BaseNode): return {} case HttpContentType.JSON: content["json"] = json.loads(self._render_template( - self.typed_config.body.data, state + self.typed_config.body.data, variable_pool )) case HttpContentType.FROM_DATA: data = {} for item in self.typed_config.body.data: if item.type == "text": - data[self._render_template(item.key, state)] = self._render_template(item.value, state) + data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool) elif item.type == "file": # TODO: File support (Feature) pass @@ -136,11 +146,11 @@ class HttpRequestNode(BaseNode): pass case HttpContentType.WWW_FORM: content["data"] = json.loads(self._render_template( - json.dumps(self.typed_config.body.data), state + json.dumps(self.typed_config.body.data), variable_pool )) case HttpContentType.RAW: - content["content"] = self._render_template(self.typed_config.body.data, state) + content["content"] = self._render_template(self.typed_config.body.data, variable_pool) case _: raise RuntimeError(f"Content type not supported: {self.typed_config.body.content_type}") return content @@ -165,7 +175,7 @@ class HttpRequestNode(BaseNode): case _: raise RuntimeError(f"HttpRequest method not supported: {self.typed_config.method}") - async def execute(self, state: WorkflowState) -> dict | str: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict | str: """ Execute the HTTP request node. @@ -176,6 +186,7 @@ class HttpRequestNode(BaseNode): Args: state: Current workflow runtime state. + variable_pool: Variable Pool Returns: - dict: Serialized HttpRequestNodeOutput on success @@ -185,8 +196,8 @@ class HttpRequestNode(BaseNode): async with httpx.AsyncClient( verify=self.typed_config.verify_ssl, timeout=self._build_timeout(), - headers=self._build_header(state) | self._build_auth(state), - params=self._build_params(state), + headers=self._build_header(variable_pool) | self._build_auth(variable_pool), + params=self._build_params(variable_pool), follow_redirects=True ) as client: retries = self.typed_config.retry.max_attempts @@ -194,8 +205,8 @@ class HttpRequestNode(BaseNode): try: request_func = self._get_client_method(client) resp = await request_func( - url=self._render_template(self.typed_config.url, state), - **self._build_content(state) + url=self._render_template(self.typed_config.url, variable_pool), + **self._build_content(variable_pool) ) resp.raise_for_status() logger.info(f"Node {self.node_id}: HTTP request succeeded") diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index cf5a1499..3c6d0e36 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -6,6 +6,8 @@ from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -15,6 +17,11 @@ class IfElseNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: IfElseNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + return { + "output": VariableType.STRING + } + @staticmethod def _evaluate(operator, instance: CompareOperatorInstance) -> Any: match operator: @@ -45,7 +52,7 @@ class IfElseNode(BaseNode): case _: raise ValueError(f"Invalid condition: {operator}") - def evaluate_conditional_edge_expressions(self, state) -> list[bool]: + def evaluate_conditional_edge_expressions(self, variable_pool: VariablePool) -> list[bool]: """ Build conditional edge expressions for the If-Else node. @@ -72,11 +79,11 @@ class IfElseNode(BaseNode): pattern = r"\{\{\s*(.*?)\s*\}\}" left_string = re.sub(pattern, r"\1", expression.left).strip() try: - left_value = self.get_variable(left_string, state) + left_value = self.get_variable(left_string, variable_pool) except KeyError: left_value = None evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( - self.get_variable_pool(state), + variable_pool, expression.left, expression.right, expression.input_type @@ -95,7 +102,7 @@ class IfElseNode(BaseNode): return conditions - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the conditional branching logic of the node. @@ -105,13 +112,13 @@ class IfElseNode(BaseNode): Args: state (WorkflowState): The current workflow state, containing variables, messages, node outputs, etc. + variable_pool: Variable Pool Returns: str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions. """ self.typed_config = IfElseNodeConfig(**self.config) - expressions = self.evaluate_conditional_edge_expressions(state) - # TODO: 变量类型及文本类型解析 + expressions = self.evaluate_conditional_edge_expressions(variable_pool) for i in range(len(expressions)): if expressions[i]: logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}") diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py index 822f1918..240b003b 100644 --- a/api/app/core/workflow/nodes/jinja_render/node.py +++ b/api/app/core/workflow/nodes/jinja_render/node.py @@ -5,6 +5,8 @@ from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig from app.core.workflow.template_renderer import TemplateRenderer +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -14,7 +16,12 @@ class JinjaRenderNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: JinjaRenderNodeConfig | None = None - async def execute(self, state: WorkflowState) -> Any: + def _output_types(self) -> dict[str, VariableType]: + return { + "output": VariableType.STRING + } + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the node: render the Jinja2 template with mapped variables. @@ -24,6 +31,7 @@ class JinjaRenderNode(BaseNode): Args: state (WorkflowState): Current workflow state containing variables, node outputs, and runtime variables. + variable_pool: Variable Pool Returns: dict[str, Any]: Node output dictionary containing the rendered result @@ -40,7 +48,7 @@ class JinjaRenderNode(BaseNode): context = {} for variable in self.typed_config.mapping: try: - context[variable.name] = self.get_variable(variable.value, state) + context[variable.name] = self.get_variable(variable.value, variable_pool) except Exception: logger.info(f"variable not found, var: {variable.value}") continue diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 997135f3..1e146721 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -8,6 +8,8 @@ from app.core.models import RedBearRerank, RedBearModelConfig from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.db import get_db_read from app.models import knowledge_model, knowledgeshare_model, ModelType from app.repositories import knowledge_repository, knowledgeshare_repository @@ -22,6 +24,11 @@ class KnowledgeRetrievalNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: KnowledgeRetrievalNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + return { + "output": VariableType.ARRAY_STRING + } + @staticmethod def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType): """ @@ -149,7 +156,7 @@ class KnowledgeRetrievalNode(BaseNode): ) return reranker - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the knowledge retrieval workflow node. @@ -163,6 +170,7 @@ class KnowledgeRetrievalNode(BaseNode): Args: state (WorkflowState): Current workflow execution state. + variable_pool: Variable Pool Returns: Any: List of retrieved knowledge chunks (dict format). @@ -171,7 +179,7 @@ class KnowledgeRetrievalNode(BaseNode): RuntimeError: If no valid knowledge base is found or access is denied. """ self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) - query = self._render_template(self.typed_config.query, state) + query = self._render_template(self.typed_config.query, variable_pool) with get_db_read() as db: knowledge_bases = self.typed_config.knowledge_bases existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases]) diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py index 265724f3..1229450f 100644 --- a/api/app/core/workflow/nodes/llm/config.py +++ b/api/app/core/workflow/nodes/llm/config.py @@ -4,7 +4,8 @@ from typing import Any from pydantic import BaseModel, Field, field_validator -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition +from app.core.workflow.variable.base_variable import VariableType class MessageConfig(BaseModel): @@ -70,6 +71,16 @@ class LLMNodeConfig(BaseNodeConfig): description="对话上下文窗口" ) + vision: bool = Field( + default=False, + description="是否启用视觉模型" + ) + + vision_input: str = Field( + default=None, + description="视觉输入" + ) + # 简单模式 prompt: str | None = Field( default=None, diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index f315b238..4393e1ed 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -15,6 +15,8 @@ from app.core.exceptions import BusinessException from app.core.models import RedBearLLM, RedBearModelConfig from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.llm.config import LLMNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.db import get_db_context from app.models import ModelType from app.services.model_service import ModelConfigService @@ -66,59 +68,34 @@ class LLMNode(BaseNode): - ai/assistant: AI 消息(AIMessage) """ + def _output_types(self) -> dict[str, VariableType]: + return {"output": VariableType.STRING} + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config: LLMNodeConfig | None = None - def _render_context(self, message, state): - context = f"{self._render_template(self.typed_config.context, state)}" + def _render_context(self, message: str, variable_pool: VariablePool): + context = f"{self._render_template(self.typed_config.context, variable_pool)}" return re.sub(r"{{context}}", context, message) - def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]: + async def _prepare_llm( + self, + state: WorkflowState, + variable_pool: VariablePool, + stream: bool = False + ) -> RedBearLLM: """准备 LLM 实例(公共逻辑) Args: - state: 工作流状态 + variable_pool: 变量池 Returns: (llm, messages_or_prompt): LLM 实例和消息列表或 prompt 字符串 """ - - # 1. 处理消息格式(优先使用 messages) self.typed_config = LLMNodeConfig(**self.config) - messages_config = self.typed_config.messages - if messages_config: - # 使用 LangChain 消息格式 - messages = [] - for msg_config in messages_config: - role = msg_config.role.lower() - content_template = msg_config.content - content_template = self._render_context(content_template, state) - content = self._render_template(content_template, state) - - # 根据角色创建对应的消息对象 - if role == "system": - messages.append({"role": "system", "content": content}) - elif role in ["user", "human"]: - messages.append({"role": "user", "content": content}) - elif role in ["ai", "assistant"]: - messages.append({"role": "assistant", "content": content}) - else: - logger.warning(f"未知的消息角色: {role},默认使用 user") - messages.append({"role": "user", "content": content}) - - if self.typed_config.memory.enable: - # if self.typed_config.memory.enable_window: - messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:] - prompt_or_messages = messages - else: - # 使用简单的 prompt 格式(向后兼容) - prompt_template = self.config.get("prompt", "") - prompt_or_messages = self._render_template(prompt_template, state) - - # 2. 获取模型配置 - model_id = self.config.get("model_id") + model_id = self.typed_config.model_id if not model_id: raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置") @@ -157,27 +134,82 @@ class LLMNode(BaseNode): logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}") - return llm, prompt_or_messages + messages_config = self.typed_config.messages - async def execute(self, state: WorkflowState) -> AIMessage: + if messages_config: + # 使用 LangChain 消息格式 + messages = [] + for msg_config in messages_config: + role = msg_config.role.lower() + content_template = msg_config.content + content_template = self._render_context(content_template, variable_pool) + content = self._render_template(content_template, variable_pool) + + # 根据角色创建对应的消息对象 + if role == "system": + messages.append({ + "role": "system", + "content": content + }) + elif role in ["user", "human"]: + messages.append({ + "role": "user", + "content": content + }) + elif role in ["ai", "assistant"]: + messages.append({ + "role": "assistant", + "content": content + }) + else: + logger.warning(f"未知的消息角色: {role},默认使用 user") + messages.append({ + "role": "user", + "content": content + }) + + if self.typed_config.vision_input and self.typed_config.vision: + file_content = [] + files = variable_pool.get_value(self.typed_config.vision_input) + for file in files: + content = await self.process_message(provider, file, self.typed_config.vision) + if content: + file_content.append(content) + if messages and messages[-1]["role"] == 'user': + messages[-1]['content'] = [messages[-1]["content"]] + file_content + else: + messages.append({"role": "user", "content": file_content}) + + if self.typed_config.memory.enable: + messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:] + self.messages = messages + else: + # 使用简单的 prompt 格式(向后兼容) + prompt_template = self.config.get("prompt", "") + self.messages = self._render_template(prompt_template, variable_pool) + + return llm + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage: """非流式执行 LLM 调用 Args: state: 工作流状态 + variable_pool: 变量池 Returns: LLM 响应消息 """ # self.typed_config = LLMNodeConfig(**self.config) - llm, prompt_or_messages = self._prepare_llm(state, True) + llm = await self._prepare_llm(state, variable_pool, False) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") # 调用 LLM(支持字符串或消息列表) - response = await llm.ainvoke(prompt_or_messages) + response = await llm.ainvoke(self.messages) # 提取内容 if hasattr(response, 'content'): - content = response.content + content = self.process_model_output(response.content) else: content = str(response) @@ -186,16 +218,15 @@ class LLMNode(BaseNode): # 返回 AIMessage(包含响应元数据) return response if isinstance(response, AIMessage) else AIMessage(content=content) - def _extract_input(self, state: WorkflowState) -> dict[str, Any]: + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """提取输入数据(用于记录)""" - _, prompt_or_messages = self._prepare_llm(state) return { - "prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None, + "prompt": self.messages if isinstance(self.messages, str) else None, "messages": [ {"role": msg.get("role"), "content": msg.get("content", "")} - for msg in prompt_or_messages - ] if isinstance(prompt_or_messages, list) else None, + for msg in self.messages + ] if isinstance(self.messages, list) else None, "config": { "model_id": self.config.get("model_id"), "temperature": self.config.get("temperature"), @@ -215,24 +246,25 @@ class LLMNode(BaseNode): usage = business_result.response_metadata.get('token_usage') if usage: return { - "prompt_tokens": usage.get('prompt_tokens', 0), - "completion_tokens": usage.get('completion_tokens', 0), + "prompt_tokens": usage.get('input_tokens', 0), + "completion_tokens": usage.get('output_tokens', 0), "total_tokens": usage.get('total_tokens', 0) } return None - async def execute_stream(self, state: WorkflowState): + async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): """流式执行 LLM 调用 Args: state: 工作流状态 + variable_pool: 变量池 Yields: 文本片段(chunk)或完成标记 """ self.typed_config = LLMNodeConfig(**self.config) - llm, prompt_or_messages = self._prepare_llm(state, True) + llm = await self._prepare_llm(state, variable_pool, True) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)") logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}") @@ -243,10 +275,10 @@ class LLMNode(BaseNode): # 调用 LLM(流式,支持字符串或消息列表) last_meta_data = {} - async for chunk in llm.astream(prompt_or_messages, stream_usage=True): + async for chunk in llm.astream(self.messages, stream_usage=True): # 提取内容 if hasattr(chunk, 'content'): - content = chunk.content + content = self.process_model_output(chunk.content) else: content = str(chunk) if hasattr(chunk, 'response_metadata'): diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 13860bec..ddbe4b99 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -3,6 +3,8 @@ from typing import Any from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.db import get_db_read from app.services.memory_agent_service import MemoryAgentService from app.tasks import write_message_task @@ -13,17 +15,23 @@ class MemoryReadNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: MemoryReadNodeConfig | None = None - async def execute(self, state: WorkflowState) -> Any: + def _output_types(self) -> dict[str, VariableType]: + return { + "answer": VariableType.STRING, + "intermediate_outputs": VariableType.ARRAY_OBJECT + } + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: self.typed_config = MemoryReadNodeConfig(**self.config) with get_db_read() as db: - end_user_id = self.get_variable("sys.user_id", state) + end_user_id = self.get_variable("sys.user_id", variable_pool) if not end_user_id: raise RuntimeError("End user id is required") return await MemoryAgentService().read_memory( end_user_id=end_user_id, - message=self._render_template(self.typed_config.message, state), + message=self._render_template(self.typed_config.message, variable_pool), config_id=self.typed_config.config_id, search_switch=self.typed_config.search_switch, history=[], @@ -38,16 +46,19 @@ class MemoryWriteNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: MemoryWriteNodeConfig | None = None - async def execute(self, state: WorkflowState) -> Any: + def _output_types(self) -> dict[str, VariableType]: + return {"output": VariableType.STRING} + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: self.typed_config = MemoryWriteNodeConfig(**self.config) - end_user_id = self.get_variable("sys.user_id", state) + end_user_id = self.get_variable("sys.user_id", variable_pool) if not end_user_id: raise RuntimeError("End user id is required") write_message_task.delay( end_user_id, - self._render_template(self.typed_config.message, state), + self._render_template(self.typed_config.message, variable_pool), str(self.typed_config.config_id), "neo4j", "" diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index fb2fe00f..00120ca0 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -22,7 +22,6 @@ from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.memory import MemoryReadNode, MemoryWriteNode from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.start import StartNode -from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.breaker import BreakNode @@ -37,7 +36,6 @@ WorkflowNode = Union[ LLMNode, IfElseNode, AgentNode, - TransformNode, AssignerNode, HttpRequestNode, KnowledgeRetrievalNode, @@ -67,7 +65,6 @@ class NodeFactory: NodeType.END: EndNode, NodeType.LLM: LLMNode, NodeType.AGENT: AgentNode, - NodeType.TRANSFORM: TransformNode, NodeType.IF_ELSE: IfElseNode, NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, NodeType.ASSIGNER: AssignerNode, diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index ad38284a..251d6a79 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -1,9 +1,9 @@ import json import re from abc import ABC -from typing import Union, Type, NoReturn +from typing import Union, Type, NoReturn, Any -from app.core.workflow.nodes.base_config import VariableType +from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.nodes.enums import ValueInputType from app.core.workflow.variable_pool import VariablePool @@ -69,7 +69,7 @@ class TypeTransformer: class OperatorBase(ABC): - def __init__(self, pool: VariablePool, left_selector, right): + def __init__(self, pool: VariablePool, left_selector: str, right: Any): self.pool = pool self.left_selector = left_selector self.right = right @@ -77,7 +77,7 @@ class OperatorBase(ABC): self.type_limit: type[str, int, dict, list] = None def check(self, no_right=False): - left = self.pool.get(self.left_selector) + left = self.pool.get_value(self.left_selector) if not isinstance(left, self.type_limit): raise TypeError(f"The variable to be operated on must be of {self.type_limit} type") @@ -92,13 +92,13 @@ class StringOperator(OperatorBase): super().__init__(pool, left_selector, right) self.type_limit = str - def assign(self) -> None: + async def assign(self) -> None: self.check() - self.pool.set(self.left_selector, self.right) + await self.pool.set(self.left_selector, self.right) - def clear(self) -> None: + async def clear(self) -> None: self.check(no_right=True) - self.pool.set(self.left_selector, '') + await self.pool.set(self.left_selector, '') class NumberOperator(OperatorBase): @@ -106,33 +106,33 @@ class NumberOperator(OperatorBase): super().__init__(pool, left_selector, right) self.type_limit = (float, int) - def assign(self) -> None: + async def assign(self) -> None: self.check() - self.pool.set(self.left_selector, self.right) + await self.pool.set(self.left_selector, self.right) - def clear(self) -> None: + async def clear(self) -> None: self.check(no_right=True) - self.pool.set(self.left_selector, 0) + await self.pool.set(self.left_selector, 0) - def add(self) -> None: + async def add(self) -> None: self.check() - origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin + self.right) + origin = self.pool.get_value(self.left_selector) + await self.pool.set(self.left_selector, origin + self.right) - def subtract(self) -> None: + async def subtract(self) -> None: self.check() - origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin - self.right) + origin = self.pool.get_value(self.left_selector) + await self.pool.set(self.left_selector, origin - self.right) - def multiply(self) -> None: + async def multiply(self) -> None: self.check() - origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin * self.right) + origin = self.pool.get_value(self.left_selector) + await self.pool.set(self.left_selector, origin * self.right) - def divide(self) -> None: + async def divide(self) -> None: self.check() - origin = self.pool.get(self.left_selector) - self.pool.set(self.left_selector, origin / self.right) + origin = self.pool.get_value(self.left_selector) + await self.pool.set(self.left_selector, origin / self.right) class BooleanOperator(OperatorBase): @@ -140,13 +140,13 @@ class BooleanOperator(OperatorBase): super().__init__(pool, left_selector, right) self.type_limit = bool - def assign(self) -> None: + async def assign(self) -> None: self.check() - self.pool.set(self.left_selector, self.right) + await self.pool.set(self.left_selector, self.right) - def clear(self) -> None: + async def clear(self) -> None: self.check(no_right=True) - self.pool.set(self.left_selector, False) + await self.pool.set(self.left_selector, False) class ArrayOperator(OperatorBase): @@ -154,38 +154,37 @@ class ArrayOperator(OperatorBase): super().__init__(pool, left_selector, right) self.type_limit = list - def assign(self) -> None: + async def assign(self) -> None: self.check() - self.pool.set(self.left_selector, self.right) + await self.pool.set(self.left_selector, self.right) - def clear(self) -> None: + async def clear(self) -> None: self.check(no_right=True) - self.pool.set(self.left_selector, list()) + await self.pool.set(self.left_selector, list()) - def append(self) -> None: + async def append(self) -> None: self.check(no_right=True) - # TODO:require type limit in list - origin = self.pool.get(self.left_selector) + origin = self.pool.get_value(self.left_selector) origin.append(self.right) - self.pool.set(self.left_selector, origin) + await self.pool.set(self.left_selector, origin) - def extend(self) -> None: + async def extend(self) -> None: self.check(no_right=True) - origin = self.pool.get(self.left_selector) + origin = self.pool.get_value(self.left_selector) origin.extend(self.right) - self.pool.set(self.left_selector, origin) + await self.pool.set(self.left_selector, origin) - def remove_last(self) -> None: + async def remove_last(self) -> None: self.check(no_right=True) - origin = self.pool.get(self.left_selector) + origin = self.pool.get_value(self.left_selector) origin.pop() - self.pool.set(self.left_selector, origin) + await self.pool.set(self.left_selector, origin) - def remove_first(self) -> None: + async def remove_first(self) -> None: self.check(no_right=True) - origin = self.pool.get(self.left_selector) + origin = self.pool.get_value(self.left_selector) origin.pop(0) - self.pool.set(self.left_selector, origin) + await self.pool.set(self.left_selector, origin) class ObjectOperator(OperatorBase): @@ -193,13 +192,13 @@ class ObjectOperator(OperatorBase): super().__init__(pool, left_selector, right) self.type_limit = dict - def assign(self) -> None: + async def assign(self) -> None: self.check() - self.pool.set(self.left_selector, self.right) + await self.pool.set(self.left_selector, self.right) - def clear(self) -> None: + async def clear(self) -> None: self.check(no_right=True) - self.pool.set(self.left_selector, dict()) + await self.pool.set(self.left_selector, dict()) class AssignmentOperatorResolver: @@ -245,7 +244,7 @@ class ConditionBase(ABC): self.right_selector = right_selector self.input_type = input_type - self.left_value = self.pool.get(self.left_selector) + self.left_value = self.pool.get_value(self.left_selector) self.right_value = self.resolve_right_literal_value() self.type_limit = getattr(self, "type_limit", None) @@ -254,7 +253,7 @@ class ConditionBase(ABC): if self.input_type == ValueInputType.VARIABLE: pattern = r"\{\{\s*(.*?)\s*\}\}" right_expression = re.sub(pattern, r"\1", self.right_selector).strip() - return self.pool.get(right_expression) + return self.pool.get_value(right_expression) elif self.input_type == ValueInputType.CONSTANT: return self.right_selector raise RuntimeError("Unsupported variable type") diff --git a/api/app/core/workflow/nodes/parameter_extractor/config.py b/api/app/core/workflow/nodes/parameter_extractor/config.py index cfbd9c14..a0b9c032 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/config.py +++ b/api/app/core/workflow/nodes/parameter_extractor/config.py @@ -1,7 +1,7 @@ import uuid +from enum import StrEnum from pydantic import Field, BaseModel -from enum import StrEnum from app.core.workflow.nodes.base_config import BaseNodeConfig diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index ec58d96c..475c54fe 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -12,6 +12,8 @@ from app.core.models import RedBearLLM, RedBearModelConfig from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.db import get_db_read from app.models import ModelType from app.services.model_service import ModelConfigService @@ -24,6 +26,12 @@ class ParameterExtractorNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: ParameterExtractorNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + outputs = {} + for param in self.typed_config.params: + outputs[param.name] = param.type + return outputs + @staticmethod def _get_prompt(): """ @@ -120,7 +128,7 @@ class ParameterExtractorNode(BaseNode): field_type[param.name] = f'{param.type}, required:{str(param.required)}' return field_type - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Main execution function for this node. @@ -138,6 +146,7 @@ class ParameterExtractorNode(BaseNode): Args: state (WorkflowState): Current state of the workflow, used for template rendering. + variable_pool (VariablePool): Used for accessing and setting variables during execution. Returns: dict[str, Any]: Dictionary containing extracted parameters under the "output" key. @@ -153,7 +162,7 @@ class ParameterExtractorNode(BaseNode): rendered_user_prompt = user_prompt_teplate.render( field_descriptions=str(self._get_field_desc()), field_type=str(self._get_field_type()), - text_input=self._render_template(self.typed_config.text, state) + text_input=self._render_template(self.typed_config.text, variable_pool) ) messages = [ @@ -162,7 +171,7 @@ class ParameterExtractorNode(BaseNode): ] if self.typed_config.prompt: messages.extend([ - ("user", self._render_template(self.typed_config.prompt, state)), + ("user", self._render_template(self.typed_config.prompt, variable_pool)), ("user", rendered_user_prompt), ]) else: diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 6df410cb..d7496f12 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -6,6 +6,8 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie from app.core.models import RedBearLLM, RedBearModelConfig from app.core.exceptions import BusinessException from app.core.error_codes import BizCode +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.db import get_db_read from app.models import ModelType from app.services.model_service import ModelConfigService @@ -24,6 +26,12 @@ class QuestionClassifierNode(BaseNode): self.typed_config: QuestionClassifierNodeConfig | None = None self.category_to_case_map = {} + def _output_types(self) -> dict[str, VariableType]: + return { + "class_name": VariableType.STRING, + "output": VariableType.STRING + } + def _get_llm_instance(self) -> RedBearLLM: """获取LLM实例""" with get_db_read() as db: @@ -65,7 +73,7 @@ class QuestionClassifierNode(BaseNode): category_map[category_name] = case_tag return category_map - async def execute(self, state: WorkflowState) -> dict: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict: """执行问题分类""" self.typed_config = QuestionClassifierNodeConfig(**self.config) self.category_to_case_map = self._build_category_case_map() @@ -102,7 +110,7 @@ class QuestionClassifierNode(BaseNode): categories=", ".join(category_names), supplement_prompt=supplement_prompt ), - state + variable_pool ) messages = [ diff --git a/api/app/core/workflow/nodes/start/config.py b/api/app/core/workflow/nodes/start/config.py index 1544f89f..98390bf7 100644 --- a/api/app/core/workflow/nodes/start/config.py +++ b/api/app/core/workflow/nodes/start/config.py @@ -2,7 +2,8 @@ from pydantic import Field -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType +from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition +from app.core.workflow.variable.base_variable import VariableType class StartNodeConfig(BaseNodeConfig): diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index 69560422..db66bc65 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -7,9 +7,10 @@ Start 节点实现 import logging from typing import Any -from app.core.workflow.nodes.base_config import VariableType +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.start.config import StartNodeConfig +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -36,14 +37,25 @@ class StartNode(BaseNode): # 解析并验证配置 self.typed_config: StartNodeConfig | None = None + self.output_var_types = {} - async def execute(self, state: WorkflowState) -> dict[str, Any]: + def _output_types(self) -> dict[str, VariableType]: + return self.output_var_types | { + "message": VariableType.STRING, + "execution_id": VariableType.STRING, + "conversation_id": VariableType.STRING, + "workspace_id": VariableType.STRING, + "user_id": VariableType.STRING, + } + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """执行 start 节点业务逻辑 Start 节点输出系统变量、会话变量和自定义变量。 Args: state: 工作流状态 + variable_pool: 变量池 Returns: 包含系统参数、会话变量和自定义变量的字典 @@ -51,19 +63,16 @@ class StartNode(BaseNode): self.typed_config = StartNodeConfig(**self.config) logger.info(f"节点 {self.node_id} (Start) 开始执行") - # 创建变量池实例(在方法内复用) - pool = self.get_variable_pool(state) - # 处理自定义变量(传入 pool 避免重复创建) - custom_vars = self._process_custom_variables(pool) + custom_vars = self._process_custom_variables(variable_pool) # 返回业务数据(包含自定义变量) result = { - "message": pool.get("sys.message"), - "execution_id": pool.get("sys.execution_id"), - "conversation_id": pool.get("sys.conversation_id"), - "workspace_id": pool.get("sys.workspace_id"), - "user_id": pool.get("sys.user_id"), + "message": variable_pool.get_value("sys.message"), + "execution_id": variable_pool.get_value("sys.execution_id"), + "conversation_id": variable_pool.get_value("sys.conversation_id"), + "workspace_id": variable_pool.get_value("sys.workspace_id"), + "user_id": variable_pool.get_value("sys.user_id"), **custom_vars # 自定义变量作为节点输出的一部分 } @@ -74,7 +83,7 @@ class StartNode(BaseNode): return result - def _process_custom_variables(self, pool) -> dict[str, Any]: + def _process_custom_variables(self, pool: VariablePool) -> dict[str, Any]: """处理自定义变量 从输入数据中提取自定义变量,应用默认值和验证。 @@ -89,13 +98,14 @@ class StartNode(BaseNode): ValueError: 缺少必需变量 """ # 获取输入数据中的自定义变量 - input_variables = pool.get("sys.input_variables", default={}) + input_variables = pool.get_value("sys.input_variables", default={}, strict=False) processed = {} # 遍历配置的变量定义 for var_def in self.typed_config.variables: var_name = var_def.name + var_type = var_def.type # 检查变量是否存在 if var_name in input_variables: @@ -116,21 +126,12 @@ class StartNode(BaseNode): f"变量 '{var_name}' 使用默认值: {var_def.default}" ) else: - match var_def.type: - case VariableType.STRING: - processed[var_name] = "" - case VariableType.NUMBER: - processed[var_name] = 0 - case VariableType.OBJECT: - processed[var_name] = {} - case VariableType.BOOLEAN: - processed[var_name] = False - case VariableType.ARRAY_NUMBER | VariableType.ARRAY_OBJECT | VariableType.ARRAY_BOOLEAN | VariableType.ARRAY_STRING: - processed[var_name] = [] + processed[var_name] = DEFAULT_VALUE(var_type) + self.output_var_types[var_name] = var_type return processed - def _extract_input(self, state: WorkflowState) -> dict[str, Any]: + def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """提取输入数据(用于记录) Args: @@ -139,11 +140,9 @@ class StartNode(BaseNode): Returns: 输入数据字典 """ - pool = self.get_variable_pool(state) - return { - "execution_id": pool.get("sys.execution_id"), - "conversation_id": pool.get("sys.conversation_id"), - "message": pool.get("sys.message"), - "conversation_vars": pool.get_all_conversation_vars() + "execution_id": variable_pool.get_value("sys.execution_id"), + "conversation_id": variable_pool.get_value("sys.conversation_id"), + "message": variable_pool.get_value("sys.message"), + "conversation_vars": variable_pool.get_all_conversation_vars() } diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index aba96303..adc55d87 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -6,6 +6,8 @@ from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.tool.config import ToolNodeConfig +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable_pool import VariablePool from app.services.tool_service import ToolService from app.db import get_db_read @@ -21,13 +23,20 @@ class ToolNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: ToolNodeConfig | None = None - async def execute(self, state: WorkflowState) -> dict[str, Any]: + def _output_types(self) -> dict[str, VariableType]: + return { + "data": VariableType.STRING, + "error_code": VariableType.STRING, + "execution_time": VariableType.NUMBER + } + + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """执行工具""" self.typed_config = ToolNodeConfig(**self.config) # 获取租户ID和用户ID - tenant_id = self.get_variable("sys.tenant_id", state) - user_id = self.get_variable("sys.user_id", state) - workspace_id = self.get_variable("sys.workspace_id", state) + tenant_id = self.get_variable("sys.tenant_id", variable_pool, strict=False) + user_id = self.get_variable("sys.user_id", variable_pool) + workspace_id = self.get_variable("sys.workspace_id", variable_pool) # 如果没有租户ID,尝试从工作流ID获取 if not tenant_id: @@ -48,7 +57,7 @@ class ToolNode(BaseNode): for param_name, param_template in self.typed_config.tool_parameters.items(): if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template): try: - rendered_value = self._render_template(param_template, state) + rendered_value = self._render_template(param_template, variable_pool) except Exception as e: raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e else: diff --git a/api/app/core/workflow/nodes/transform/__init__.py b/api/app/core/workflow/nodes/transform/__init__.py deleted file mode 100644 index 384b818c..00000000 --- a/api/app/core/workflow/nodes/transform/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Transform 节点""" - -from app.core.workflow.nodes.transform.node import TransformNode -from app.core.workflow.nodes.transform.config import TransformNodeConfig - -__all__ = ["TransformNode", "TransformNodeConfig"] diff --git a/api/app/core/workflow/nodes/transform/config.py b/api/app/core/workflow/nodes/transform/config.py deleted file mode 100644 index 47d2a6ac..00000000 --- a/api/app/core/workflow/nodes/transform/config.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Transform 节点配置""" - -from typing import Literal - -from pydantic import Field - -from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType - - -class TransformNodeConfig(BaseNodeConfig): - """Transform 节点配置 - - 用于数据转换和处理。 - """ - - transform_type: Literal["template", "code", "json"] = Field( - default="template", - description="转换类型:template(模板), code(代码), json(JSON处理)" - ) - - # 模板模式 - template: str | None = Field( - default=None, - description="转换模板,支持变量引用" - ) - - # 代码模式 - code: str | None = Field( - default=None, - description="Python 代码,用于数据转换" - ) - - # JSON 模式 - json_path: str | None = Field( - default=None, - description="JSON 路径表达式" - ) - - # 输入变量 - inputs: dict[str, str] | None = Field( - default=None, - description="输入变量映射,key 为变量名,value 为变量选择器" - ) - - # 输出变量 - output_key: str = Field( - default="result", - description="输出变量的键名" - ) - - # 输出变量定义 - output_variables: list[VariableDefinition] = Field( - default_factory=lambda: [ - VariableDefinition( - name="result", - type=VariableType.STRING, - description="转换后的结果" - ) - ], - description="输出变量定义(根据 output_key 动态生成)" - ) - - class Config: - json_schema_extra = { - "examples": [ - { - "transform_type": "template", - "template": "用户问题:{{ sys.message }}\n回答:{{ llm_qa.output }}", - "output_key": "formatted_result" - }, - { - "transform_type": "code", - "code": "result = input_text.upper()", - "inputs": { - "input_text": "{{ sys.message }}" - }, - "output_key": "uppercase_text" - } - ] - } diff --git a/api/app/core/workflow/nodes/transform/node.py b/api/app/core/workflow/nodes/transform/node.py deleted file mode 100644 index 4211c510..00000000 --- a/api/app/core/workflow/nodes/transform/node.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Transform 节点实现 - -数据转换节点,用于处理和转换数据。 -""" - -import logging -from typing import Any - -from app.core.workflow.nodes.base_node import BaseNode, WorkflowState - -logger = logging.getLogger(__name__) - - -class TransformNode(BaseNode): - """数据转换节点 - - 配置示例: - { - "type": "transform", - "config": { - "mapping": { - "output_field": "{{node.previous.output}}", - "processed": "{{var.input | upper}}" - } - } - } - """ - - async def execute(self, state: WorkflowState) -> dict[str, Any]: - """执行数据转换 - - Args: - state: 工作流状态 - - Returns: - 状态更新字典 - """ - logger.info(f"节点 {self.node_id} 开始执行数据转换") - - # 获取映射配置 - mapping = self.config.get("mapping", {}) - - # 执行数据转换 - transformed_data = {} - for target_key, source_template in mapping.items(): - # 渲染模板获取值 - value = self._render_template(str(source_template), state) - transformed_data[target_key] = value - - logger.info(f"节点 {self.node_id} 数据转换完成,输出字段: {list(transformed_data.keys())}") - - return { - "node_outputs": { - self.node_id: { - "output": transformed_data, - "status": "completed" - } - } - } diff --git a/api/app/core/workflow/nodes/variable_aggregator/config.py b/api/app/core/workflow/nodes/variable_aggregator/config.py index ac1419a4..7fe63be1 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/config.py +++ b/api/app/core/workflow/nodes/variable_aggregator/config.py @@ -1,6 +1,7 @@ from pydantic import Field, field_validator from app.core.workflow.nodes.base_config import BaseNodeConfig +from app.core.workflow.variable.base_variable import VariableType class VariableAggregatorNodeConfig(BaseNodeConfig): @@ -14,6 +15,11 @@ class VariableAggregatorNodeConfig(BaseNodeConfig): description="需要被聚合的变量" ) + group_type: dict[str, VariableType] = Field( + default=None, + description="每个分组的变量类型" + ) + @field_validator("group_variables") @classmethod def group_variables_validator(cls, v, info): diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py index 5bff8e33..56ab4cfb 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/node.py +++ b/api/app/core/workflow/nodes/variable_aggregator/node.py @@ -5,6 +5,8 @@ from typing import Any from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE +from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -14,6 +16,17 @@ class VariableAggregatorNode(BaseNode): super().__init__(node_config, workflow_config) self.typed_config: VariableAggregatorNodeConfig | None = None + def _output_types(self) -> dict[str, VariableType]: + config = VariableAggregatorNodeConfig(**self.config) + output = {} + if not config.group_type: + for group_name in config.group_variables.keys(): + output[group_name] = VariableType.ANY + return output + for var_type in config.group_type: + output[var_type] = config.group_type[var_type] + return output + @staticmethod def _get_express(variable_string: str) -> Any: """ @@ -29,7 +42,7 @@ class VariableAggregatorNode(BaseNode): expression = re.sub(pattern, r"\1", variable_string).strip() return expression - async def execute(self, state: WorkflowState) -> Any: + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the variable aggregation logic. @@ -45,7 +58,7 @@ class VariableAggregatorNode(BaseNode): for variable in self.typed_config.group_variables: var_express = self._get_express(variable) try: - value = self.get_variable(var_express, state) + value = self.get_variable(var_express, variable_pool) except Exception as e: logger.warning(f"Failed to get variable '{var_express}': {e}") continue @@ -55,7 +68,9 @@ class VariableAggregatorNode(BaseNode): return value logger.info("No variable found in non-group mode; returning empty string.") - return "" + if not self.typed_config.group_type: + return "" + return DEFAULT_VALUE(self.typed_config.group_type["output"]) # -------------------------- # Group mode @@ -65,7 +80,7 @@ class VariableAggregatorNode(BaseNode): for variable in variables: var_express = self._get_express(variable) try: - value = self.get_variable(var_express, state) + value = self.get_variable(var_express, variable_pool) except Exception as e: logger.warning(f"Failed to get variable '{var_express}' in group '{group_name}': {e}") continue @@ -74,7 +89,10 @@ class VariableAggregatorNode(BaseNode): result[group_name] = value break else: - result[group_name] = "" + if not self.typed_config.group_type: + result[group_name] = "" + else: + result[group_name] = DEFAULT_VALUE(self.typed_config.group_type[group_name]) logger.info(f"No variable found for group '{group_name}'; set empty string.") logger.info(f"Node: {self.node_id} variable aggregation result: {result}") return result diff --git a/api/app/core/workflow/template_renderer.py b/api/app/core/workflow/template_renderer.py index c2d7f255..9e2a28e8 100644 --- a/api/app/core/workflow/template_renderer.py +++ b/api/app/core/workflow/template_renderer.py @@ -43,7 +43,7 @@ class TemplateRenderer: def render( self, template: str, - variables: dict[str, Any], + conv_vars: dict[str, Any], node_outputs: dict[str, Any], system_vars: dict[str, Any] | None = None ) -> str: @@ -51,7 +51,7 @@ class TemplateRenderer: Args: template: 模板字符串 - variables: 用户定义的变量 + conv_vars: 会话变量 node_outputs: 节点输出结果 system_vars: 系统变量 @@ -80,20 +80,11 @@ class TemplateRenderer: '分析结果: 正面情绪' """ # 构建命名空间上下文 - # variables 的结构:{"sys": {...}, "conv": {...}} - sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {} - conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {} - if self.strict: - context = defaultdict(dict) - context["conv"] = conv_vars - context["node"] = node_outputs - context["sys"] = {**(system_vars or {}), **sys_vars} - else: - context = { - "conv": conv_vars, # 会话变量:{{conv.user_name}} - "node": node_outputs, # 节点输出:{{node.node_1.output}} - "sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源) - } + context = { + "conv": conv_vars, # 会话变量:{{conv.user_name}} + "node": node_outputs, # 节点输出:{{node.node_1.output}} + "sys": system_vars, # 系统变量:{{sys.execution_id}} + } # 支持直接通过节点ID访问节点输出:{{llm_qa.output}} # 将所有节点输出添加到顶层上下文 @@ -157,9 +148,9 @@ _default_renderer = TemplateRenderer(strict=True) def render_template( template: str, - variables: dict[str, Any], + conv_vars: dict[str, Any], node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None, + system_vars: dict[str, Any], strict: bool = True ) -> str: """渲染模板(便捷函数) @@ -167,7 +158,7 @@ def render_template( Args: strict: 严格模式 template: 模板字符串 - variables: 用户变量 + conv_vars: 会话变量 node_outputs: 节点输出 system_vars: 系统变量 @@ -184,7 +175,7 @@ def render_template( '请分析: 这是一段文本' """ renderer = TemplateRenderer(strict=strict) - return renderer.render(template, variables, node_outputs, system_vars) + return renderer.render(template, conv_vars, node_outputs, system_vars) def validate_template(template: str) -> list[str]: diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 6daf415d..96fc35ad 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -5,10 +5,13 @@ """ import logging -from typing import Any, Union +from typing import Any, Union, TYPE_CHECKING from app.core.workflow.nodes.enums import NodeType +if TYPE_CHECKING: + from app.schemas.workflow_schema import WorkflowConfig + logger = logging.getLogger(__name__) @@ -64,7 +67,7 @@ class WorkflowValidator: return cycle_nodes, cycle_edges @classmethod - def get_subgraph(cls, workflow_config: Union[dict[str, Any], Any]) -> list: + def get_subgraph(cls, workflow_config: Union[dict[str, Any], "WorkflowConfig"]) -> list: if not isinstance(workflow_config, dict): workflow_config = { "nodes": workflow_config.nodes, @@ -331,7 +334,7 @@ class WorkflowValidator: def validate_workflow_config( - workflow_config: dict[str, Any], + workflow_config: Union[dict[str, Any], 'WorkflowConfig'], for_publish: bool = False ) -> tuple[bool, list[str]]: """验证工作流配置(便捷函数) diff --git a/web/src/views/MemoryConversation/types.ts b/api/app/core/workflow/variable/__init__.py similarity index 100% rename from web/src/views/MemoryConversation/types.ts rename to api/app/core/workflow/variable/__init__.py diff --git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py new file mode 100644 index 00000000..6a2e84d2 --- /dev/null +++ b/api/app/core/workflow/variable/base_variable.py @@ -0,0 +1,170 @@ +from enum import StrEnum +from abc import abstractmethod, ABC +from typing import Any + +from pydantic import BaseModel + +from app.schemas import FileType + + +class VariableType(StrEnum): + """Enumeration of supported variable types in the workflow.""" + + STRING = "string" + NUMBER = "number" + BOOLEAN = "boolean" + OBJECT = "object" + FILE = "file" + + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_BOOLEAN = "array[boolean]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILE = "array[file]" + + NESTED_ARRAY = "array_nest" + + ANY = 'any' + + @classmethod + def type_map(cls, var: Any) -> "VariableType": + """Maps a Python value to a corresponding VariableType. + + Args: + var: The Python value to map. + + Returns: + The VariableType corresponding to the input value. + + Raises: + TypeError: If the type of the input value is not supported. + """ + var_type = type(var) + if isinstance(var_type, str): + return cls.STRING + elif isinstance(var_type, (int, float)): + return cls.NUMBER + elif isinstance(var_type, bool): + return cls.BOOLEAN + elif isinstance(var_type, FileObject) or (isinstance(var, dict) and var.get('__file')): + return cls.FILE + elif isinstance(var_type, dict): + return cls.OBJECT + elif isinstance(var_type, list): + if len(var) == 0: + return cls.ARRAY_STRING + else: + child_type = type(var[0]) + if child_type == str: + return cls.ARRAY_STRING + elif child_type == int or child_type == float: + return cls.ARRAY_NUMBER + elif child_type == bool: + return cls.ARRAY_BOOLEAN + elif child_type == dict: + return cls.ARRAY_OBJECT + elif child_type == list: + return cls.NESTED_ARRAY + else: + raise TypeError(f"Unsupported array child type - {child_type}") + raise TypeError(f"Unsupported type - {var_type}") + + +def DEFAULT_VALUE(var_type: VariableType) -> Any: + """Returns the default value for a given VariableType. + + Args: + var_type: The variable type for which to get the default value. + + Returns: + The default Python value corresponding to the VariableType. + + Raises: + TypeError: If the VariableType is invalid. + """ + match var_type: + case VariableType.STRING: + return "" + case VariableType.NUMBER: + return 0 + case VariableType.BOOLEAN: + return False + case VariableType.OBJECT: + return {} + case VariableType.FILE: + return None + case VariableType.ARRAY_STRING: + return [] + case VariableType.ARRAY_NUMBER: + return [] + case VariableType.ARRAY_BOOLEAN: + return [] + case VariableType.ARRAY_OBJECT: + return [] + case VariableType.ARRAY_FILE: + return [] + case _: + raise TypeError(f"Invalid type - {type}") + + +class FileObject(BaseModel): + type: FileType + url: str + __file: bool + + +class BaseVariable(ABC): + """Abstract base class for all workflow variables. + + Subclasses must implement validation and serialization methods. + """ + type = None + + def __init__(self, value: Any): + """Initializes a variable instance. + + Args: + value: The initial value for the variable. + + Attributes: + self.value: The validated value stored in the variable. + self.literal: A string representation of the variable. + """ + self.value = self.valid_value(value) + self.literal = self.to_literal() + + @abstractmethod + def valid_value(self, value) -> Any: + """Validates or converts a value to the correct type for the variable. + + Args: + value: The value to validate. + + Returns: + The validated or converted value. + + Raises: + TypeError: If the value is invalid. + """ + pass + + @abstractmethod + def to_literal(self) -> str: + """Converts the variable value to a string literal representation. + + Returns: + A string representing the variable's value. + """ + pass + + def get_value(self) -> Any: + """Returns the current value of the variable.""" + return self.value + + def set(self, value): + """Sets the variable to a new value after validation. + + Args: + value: The new value to assign to the variable. + """ + self.value = self.valid_value(value) diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py new file mode 100644 index 00000000..7a39835c --- /dev/null +++ b/api/app/core/workflow/variable/variable_objects.py @@ -0,0 +1,174 @@ +from typing import Any, TypeVar, Type, Generic + +from deprecated import deprecated + +from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType + +T = TypeVar("T", bound=BaseVariable) + + +class StringVariable(BaseVariable): + type = 'str' + + def valid_value(self, value) -> str: + if not isinstance(value, str): + raise TypeError(f"Value must be a string - {type(value)}:{value}") + return value + + def to_literal(self) -> str: + return self.value + + +class NumberVariable(BaseVariable): + type = 'number' + + def valid_value(self, value) -> int | float: + if not isinstance(value, (int, float)): + raise TypeError(f"Value must be a number - {type(value)}:{value}") + return value + + def to_literal(self) -> str: + return str(self.value) + + +class BooleanVariable(BaseVariable): + type = 'boolean' + + def valid_value(self, value) -> bool: + if not isinstance(value, bool): + raise TypeError(f"Value must be a boolean - {type(value)}:{value}") + return value + + def to_literal(self) -> str: + return str(self.value).lower() + + +class DictVariable(BaseVariable): + type = 'object' + + def valid_value(self, value) -> dict: + if not isinstance(value, dict): + raise TypeError(f"Value must be a dict - {type(value)}:{value}") + return value + + def to_literal(self) -> str: + return str(self.value) + + +class FileVariable(BaseVariable): + type = 'file' + + def valid_value(self, value) -> FileObject: + + if isinstance(value, dict): + if not value.get("__file"): + raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") + return FileObject( + **{ + "type": str(value.get('type')), + "url": value.get('url'), + "__file": True + } + ) + if isinstance(value, FileObject): + return value + raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") + + def to_literal(self) -> str: + return f'{"!"if self.value.type == FileType.IMAGE else ""}[file]({self.value.url})' + + def get_value(self) -> Any: + return self.value.model_dump() + + +class ArrayObject(BaseVariable, Generic[T]): + type = 'array' + + def __init__(self, child_type: Type[T], value: list[Any]): + if not issubclass(child_type, BaseVariable): + raise TypeError("child_type must be a subclass of BaseVariable") + self.child_type = child_type + super().__init__(value) + + def valid_value(self, value: list[Any]) -> list[T]: + if not isinstance(value, list): + raise TypeError(f"Value must be a list - {type(value)}:{value}") + final_value = [] + for v in value: + try: + final_value.append(self.child_type(v)) + except: + raise TypeError(f"All elements must be of type {self.child_type.type}") + return final_value + + def to_literal(self) -> str: + return "\n".join([v.to_literal() for v in self.value]) + + def get_value(self) -> Any: + return [v.get_value() for v in self.value] + + +class NestedArrayObject(BaseVariable): + type = 'array_nest' + + def valid_value(self, value: list[T]) -> list[T]: + if not isinstance(value, list): + raise TypeError(f"Value must be a list - {type(value)}:{value}") + final_value = [] + for v in value: + if not isinstance(v, ArrayObject): + raise TypeError("All elements must be of type list") + final_value.append(v) + return final_value + + def to_literal(self) -> str: + return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value]) + + def get_value(self) -> Any: + return [[item.get_value() for item in row] for row in self.value] + + +@deprecated( + reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.", + category=RuntimeWarning +) +class AnyObject(BaseVariable): + type = 'any' + + def valid_value(self, value: Any) -> Any: + return value + + def to_literal(self) -> str: + return str(self.value) + + +def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]: + """简化 ArrayObject 创建,不需要重复写类型""" + + return ArrayObject(child_type, value) + + +def create_variable_instance(var_type: VariableType, value: Any) -> T: + match var_type: + case VariableType.STRING: + return StringVariable(value) + case VariableType.NUMBER: + return NumberVariable(value) + case VariableType.BOOLEAN: + return BooleanVariable(value) + case VariableType.OBJECT: + return DictVariable(value) + case VariableType.ARRAY_STRING: + return make_array(StringVariable, value) + case VariableType.ARRAY_NUMBER: + return make_array(NumberVariable, value) + case VariableType.ARRAY_BOOLEAN: + return make_array(BooleanVariable, value) + case VariableType.ARRAY_OBJECT: + return make_array(DictVariable, value) + case VariableType.ARRAY_FILE: + return make_array(FileVariable, value) + case VariableType.ANY: + return AnyObject(value) + case _: + raise TypeError(f"Invalid type - {var_type}") diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/variable_pool.py index 7d4b0609..96495ce8 100644 --- a/api/app/core/workflow/variable_pool.py +++ b/api/app/core/workflow/variable_pool.py @@ -11,10 +11,15 @@ import logging import re -from typing import Any, TYPE_CHECKING +from asyncio import Lock +from collections import defaultdict +from copy import deepcopy +from typing import Any, Generic -if TYPE_CHECKING: - from app.core.workflow.nodes import WorkflowState +from pydantic import BaseModel + +from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable.variable_objects import T, create_variable_instance logger = logging.getLogger(__name__) @@ -23,11 +28,6 @@ class VariableSelector: """变量选择器 用于引用变量的路径表示。 - - Examples: - >>> selector = VariableSelector(["sys", "message"]) - >>> selector = VariableSelector(["node_A", "output"]) - >>> selector = VariableSelector.from_string("sys.message") """ def __init__(self, path: list[str]): @@ -52,10 +52,6 @@ class VariableSelector: Returns: VariableSelector 实例 - - Examples: - >>> selector = VariableSelector.from_string("sys.message") - >>> selector = VariableSelector.from_string("llm_qa.output") """ path = selector_str.split(".") return cls(path) @@ -67,160 +63,212 @@ class VariableSelector: return f"VariableSelector({self.path})" +class VariableStruct(BaseModel, Generic[T]): + """A typed variable struct. + + Represents a runtime variable with an associated logical type and + a concrete value object. + + This class bridges the static type system (via generics) and the + runtime type system (via ``VariableType``). + + Attributes: + type: + Logical variable type descriptor used for runtime validation, + serialization, and workflow type checking. + instance: + The concrete variable object. The actual Python type is + represented by the generic parameter ``T`` (e.g. StringVariable, + NumberVariable, ArrayObject[StringVariable]). + mut: + Whether the variable is mutable. + """ + type: VariableType + instance: T + mut: bool + + model_config = { + "arbitrary_types_allowed": True + } + + class VariablePool: - """变量池 - - 管理工作流执行过程中的所有变量。 - - 变量命名空间: - - sys.*: 系统变量(message, execution_id, workspace_id, user_id, conversation_id) - - conv.*: 会话变量(跨多轮对话保持的变量) - - .*: 节点输出 - - Examples: - >>> pool = VariablePool(state) - >>> pool.get(["sys", "message"]) - "用户的问题" - >>> pool.get(["llm_qa", "output"]) - "AI 的回答" - >>> pool.set(["conv", "user_name"], "张三") + """Variable pool. + + Manages all variables during workflow execution, including storage, + namespacing, and concurrency control. + + Variable namespace conventions: + - ``sys.*``: + System variables (e.g. message, execution_id, workspace_id, + user_id, conversation_id). + - ``conv.*``: + Conversation-level variables that persist across multiple turns. + - ``.*``: + Variables produced by workflow nodes. """ - def __init__(self, state: "WorkflowState"): - """初始化变量池 - - Args: - state: 工作流状态(LangGraph State) - """ - self.state = state + def __init__(self): + """Initialize the variable pool. + + Attributes: + self.locks: + A per-key lock table used for fine-grained concurrency control. + + self.variables: + Storage for all variables managed by the pool. + """ + self.locks = defaultdict(Lock) + self.variables: dict[str, dict[str, VariableStruct[Any]]] = {} + + @staticmethod + def transform_selector(selector): + pattern = r"\{\{\s*(.*?)\s*\}\}" + variable_literal = re.sub(pattern, r"\1", selector).strip() + selector = VariableSelector.from_string(variable_literal).path + if len(selector) != 2: + raise ValueError(f"Selector not valid - {selector}") + return selector + + def _get_variable_struct( + self, + selector: str + ) -> VariableStruct[T] | None: + """Retrieve a variable struct from the variable pool. - def get(self, selector: list[str] | str, default: Any = None) -> Any: - """获取变量值 - Args: - selector: 变量选择器,可以是列表或字符串 - default: 默认值(变量不存在时返回) - + selector: + Variable selector, either: + - A string variable literal (e.g. "{{ sys.message }}") + Returns: - 变量值 - - Examples: - >>> pool.get(["sys", "message"]) - >>> pool.get("sys.message") - >>> pool.get(["llm_qa", "output"]) - >>> pool.get("llm_qa.output") - - Raises: - KeyError: 变量不存在且未提供默认值 + The variable's struct if it exists; otherwise returns None. """ - # 转换为 VariableSelector - if isinstance(selector, str): - pattern = r"\{\{\s*(.*?)\s*\}\}" - variable_literal = re.sub(pattern, r"\1", selector).strip() - selector = VariableSelector.from_string(variable_literal).path - - if not selector or len(selector) < 1: - raise ValueError("变量选择器不能为空") + selector = self.transform_selector(selector) namespace = selector[0] + variable_name = selector[1] - try: - # 系统变量 - if namespace == "sys": - key = selector[1] if len(selector) > 1 else None - if not key: - return self.state.get("variables", {}).get("sys", {}) - return self.state.get("variables", {}).get("sys", {}).get(key, default) + namespace_variables = self.variables.get(namespace) + if namespace_variables is None: + return None - # 会话变量 - elif namespace == "conv": - key = selector[1] if len(selector) > 1 else None - if not key: - return self.state.get("variables", {}).get("conv", {}) - return self.state.get("variables", {}).get("conv", {}).get(key, default) + var_instance = namespace_variables.get(variable_name) + if var_instance is None: + return None + return var_instance - # 节点输出(从 runtime_vars 读取) - else: - node_id = namespace - runtime_vars = self.state.get("runtime_vars", {}) + def get_value( + self, + selector: str, + default: Any = None, + strict: bool = True, + ) -> Any: + """Retrieve a variable value from the variable pool. - if node_id not in runtime_vars: - if default is not None: - return default - raise KeyError(f"节点 '{node_id}' 的输出不存在") + Args: + selector: + Variable selector, either: + - A list of path components (e.g. ["sys", "message"]) + - A string variable literal (e.g. "{{ sys.message }}") + default: + The value to return if the variable does not exist. + strict: + If True, raises KeyError when the variable does not exist. - node_var = runtime_vars[node_id] + Returns: + The variable's value if it exists; otherwise returns `default`. - # 如果只有节点 ID,返回整个变量 - if len(selector) == 1: - return node_var + Raises: + KeyError: If strict is True and the variable does not exist. + """ + variable_struct = self._get_variable_struct(selector) + if variable_struct is None: + if strict: + raise KeyError(f"{selector} not exist") + return default - # 获取特定字段 - # 支持嵌套访问,如 node_id.field.subfield - result = node_var - for k in selector[1:]: - if isinstance(result, dict): - result = result.get(k) - if result is None: - if default is not None: - return default - raise KeyError(f"字段 '{'.'.join(selector)}' 不存在") - else: - if default is not None: - return default - raise KeyError(f"无法访问 '{'.'.join(selector)}'") + return variable_struct.instance.get_value() - return result + def get_literal( + self, + selector: str, + default: Any = None, + strict: bool = True, + ) -> Any: + """Retrieve a variable value from the variable pool. - except KeyError: - if default is not None: - return default - raise + Args: + selector: + Variable selector, either: + - A list of path components (e.g. ["sys", "message"]) + - A string variable literal (e.g. "{{ sys.message }}") + default: + The value to return if the variable does not exist. + strict: + If True, raises KeyError when the variable does not exist. - def set(self, selector: list[str] | str, value: Any): + Returns: + The variable's value if it exists; otherwise returns `default`. + + Raises: + KeyError: If strict is True and the variable does not exist. + """ + variable_struct = self._get_variable_struct(selector) + if variable_struct is None: + if strict: + raise KeyError(f"{selector} not exist") + return default + + return variable_struct.instance.to_literal() + + async def set( + self, + selector: str, + value: Any + ): """设置变量值 Args: selector: 变量选择器 value: 变量值 - - Examples: - >>> pool.set(["conv", "user_name"], "张三") - >>> pool.set("conv.user_name", "张三") - + Note: - 只能设置会话变量 (conv.*) - 系统变量和节点输出是只读的 """ - # 转换为 VariableSelector - if isinstance(selector, str): - selector = VariableSelector.from_string(selector).path + variable_struct = self._get_variable_struct(selector) + if variable_struct is None: + raise KeyError(f"Variable {selector} is not defined") + if not variable_struct.mut: + raise KeyError(f"{selector} cannot be modified") + async with self.locks[selector]: + variable_struct.instance.set(value) - if not selector or len(selector) < 2: - raise ValueError("变量选择器必须包含命名空间和键名") + async def new( + self, + namespace: str, + key: str, + value: Any, + var_type: VariableType, + mut: bool + ): + if self.has(f"{namespace}.{key}"): + try: + await self.set(f"{namespace}.{key}", value) + except KeyError: + pass + instance = create_variable_instance(var_type, value) + variable_struct = VariableStruct(type=var_type, instance=instance, mut=mut) + namespace_variable = self.variables.get(namespace) + if namespace_variable is None: + self.variables[namespace] = { + key: variable_struct + } + else: + self.variables[namespace][key] = variable_struct - namespace = selector[0] - - if namespace != "conv" and namespace not in self.state["cycle_nodes"]: - raise ValueError("Only conversation or cycle variables can be assigned.") - - key = selector[1] - - # 确保 variables 结构存在 - if "variables" not in self.state: - self.state["variables"] = {"sys": {}, "conv": {}} - if namespace == "conv": - if "conv" not in self.state["variables"]: - self.state["variables"]["conv"] = {} - - # 设置值 - self.state["variables"]["conv"][key] = value - elif namespace in self.state["cycle_nodes"]: - self.state["runtime_vars"][namespace][key] = value - - logger.debug(f"设置变量: {'.'.join(selector)} = {value}") - - def has(self, selector: list[str] | str) -> bool: + def has(self, selector: str) -> bool: """检查变量是否存在 Args: @@ -228,18 +276,8 @@ class VariablePool: Returns: 变量是否存在 - - Examples: - >>> pool.has(["sys", "message"]) - True - >>> pool.has("llm_qa.output") - False """ - try: - self.get(selector) - return True - except KeyError: - return False + return self._get_variable_struct(selector) is not None def get_all_system_vars(self) -> dict[str, Any]: """获取所有系统变量 @@ -247,7 +285,8 @@ class VariablePool: Returns: 系统变量字典 """ - return self.state.get("variables", {}).get("sys", {}) + sys_namespace = self.variables.get("sys", {}) + return {k: v.instance.get_value() for k, v in sys_namespace.items()} def get_all_conversation_vars(self) -> dict[str, Any]: """获取所有会话变量 @@ -255,7 +294,8 @@ class VariablePool: Returns: 会话变量字典 """ - return self.state.get("variables", {}).get("conv", {}) + conv_namespace = self.variables.get("conv", {}) + return {k: v.instance.get_value() for k, v in conv_namespace.items()} def get_all_node_outputs(self) -> dict[str, Any]: """获取所有节点输出(运行时变量) @@ -263,18 +303,37 @@ class VariablePool: Returns: 节点输出字典,键为节点 ID """ - return self.state.get("runtime_vars", {}) + runtime_vars = { + namespace: { + k: v.instance.get_value() + for k, v in vars_dict.items() + } + for namespace, vars_dict in self.variables.items() + if namespace not in ("sys", "conv") + } + return runtime_vars - def get_node_output(self, node_id: str) -> dict[str, Any] | None: + def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None: """获取指定节点的输出(运行时变量) Args: node_id: 节点 ID + defalut: 默认值 + strict: 是否严格模式 Returns: 节点输出或 None """ - return self.state.get("runtime_vars", {}).get(node_id) + node_namespace = self.variables.get(node_id) + if node_namespace: + return {k: v.instance.get_value() for k, v in node_namespace.items()} + if strict: + raise KeyError(f"node {node_id} output not exist") + else: + return defalut + + def copy(self, pool: 'VariablePool'): + self.variables = deepcopy(pool.variables) def to_dict(self) -> dict[str, Any]: """导出为字典 diff --git a/api/app/main.py b/api/app/main.py index 7e16d2c0..af5ed796 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -50,13 +50,16 @@ async def lifespan(app: FastAPI): logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)") # 加载预定义模型 - logger.info("开始加载预定义模型...") - try: - with get_db_context() as db: - result = load_models(db, silent=True) - logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个") - except Exception as e: - logger.warning(f"加载预定义模型时出错: {str(e)}") + if settings.LOAD_MODEL: + logger.info("开始加载预定义模型...") + try: + with get_db_context() as db: + result = load_models(db, silent=True) + logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个") + except Exception as e: + logger.warning(f"加载预定义模型时出错: {str(e)}") + else: + logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") logger.info("应用程序启动完成") yield @@ -77,10 +80,14 @@ default_origins = [ ] allowed_origins = list({o for o in (default_origins + settings.CORS_ORIGINS) if o}) +# 如果 CORS_ORIGINS 包含 "*",则允许所有来源 +if "*" in settings.CORS_ORIGINS: + allowed_origins = ["*"] + app.add_middleware( CORSMiddleware, allow_origins=allowed_origins, - allow_credentials=True, + allow_credentials=True if "*" not in allowed_origins else False, # 允许所有来源时不能使用 credentials allow_methods=["*"], allow_headers=["*"], ) diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 984212de..daf03841 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -28,6 +28,7 @@ from .tool_model import ( ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus ) from .memory_perceptual_model import MemoryPerceptualModel +from .skill_model import Skill from .ontology_scene import OntologyScene from .ontology_class import OntologyClass from .ontology_scene import OntologyScene @@ -84,5 +85,6 @@ __all__ = [ "ExecutionStatus", "MemoryPerceptualModel", "ModelBase", - "LoadBalanceStrategy" + "LoadBalanceStrategy", + "Skill" ] diff --git a/api/app/models/agent_app_config_model.py b/api/app/models/agent_app_config_model.py index 96752c8e..cc2e0686 100644 --- a/api/app/models/agent_app_config_model.py +++ b/api/app/models/agent_app_config_model.py @@ -29,7 +29,8 @@ class AgentConfig(Base): knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置") memory = Column(JSON, nullable=True, comment="记忆配置") variables = Column(JSON, default=list, nullable=True, comment="变量配置") - tools = Column(JSON, default=dict, nullable=True, comment="工具配置") + tools = Column(JSON, default=list, nullable=True, comment="工具配置") + skills = Column(JSON, default=dict, nullable=True, comment="技能配置") # 多 Agent 相关字段 agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone") diff --git a/api/app/models/skill_model.py b/api/app/models/skill_model.py new file mode 100644 index 00000000..97fdeb03 --- /dev/null +++ b/api/app/models/skill_model.py @@ -0,0 +1,37 @@ +"""Skill 模型定义""" +import datetime +import uuid +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey +from sqlalchemy.dialects.postgresql import UUID, JSON + +from app.db import Base + + +class Skill(Base): + """技能模型 - 可以关联工具(内置、MCP、自定义)""" + __tablename__ = "skills" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True) + name = Column(String, nullable=False, comment="技能名称") + description = Column(Text, comment="技能描述") + tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True, comment="租户ID") + + # 关联的工具 + tools = Column(JSON, default=list, comment="关联的工具列表") + + # 技能配置 + config = Column(JSON, default=dict, comment="技能配置") + + # 专属提示词 + prompt = Column(Text, comment="技能专属提示词") + + # 状态 + is_active = Column(Boolean, default=True, nullable=False, comment="是否激活") + is_public = Column(Boolean, default=False, nullable=False, comment="是否公开到市场") + + # 时间戳 + created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间") + + def __repr__(self): + return f"" diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 568c262f..22972669 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -235,6 +235,8 @@ class MemoryConfigRepository: llm_id=params.llm_id, embedding_id=params.embedding_id, rerank_id=params.rerank_id, + reflection_model_id=params.reflection_model_id, + emotion_model_id=params.emotion_model_id, ) db.add(db_config) db.flush() # 获取自增ID但不提交事务 diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index 3d66964a..f323b30c 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -583,7 +583,7 @@ class ModelApiKeyRepository: db_api_key.usage_count = str(current_count + 1) db_api_key.last_used_at = func.now() - db.commit() + db.flush() db_logger.debug(f"API Key使用统计更新成功: api_key_id={api_key_id}") return True diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 1575315f..fc32ca9a 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -207,4 +207,4 @@ async def save_dialog_and_statements_to_neo4j( except Exception as e: print(f"Neo4j integration error: {e}") print("Continuing without database storage...") - return False + return False \ No newline at end of file diff --git a/api/app/repositories/skill_repository.py b/api/app/repositories/skill_repository.py new file mode 100644 index 00000000..6eeb7e08 --- /dev/null +++ b/api/app/repositories/skill_repository.py @@ -0,0 +1,111 @@ +"""Skill Repository""" +from typing import List, Optional, Tuple, Any +from sqlalchemy.orm import Session +from sqlalchemy import and_, or_ +import uuid + +from app.models.skill_model import Skill +from app.schemas.skill_schema import SkillCreate, SkillUpdate + + +class SkillRepository: + """Skill 数据访问层""" + + @staticmethod + def create(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill: + """创建技能""" + skill = Skill( + **data.model_dump(), + tenant_id=tenant_id + ) + db.add(skill) + db.flush() + return skill + + @staticmethod + def get_by_id(db: Session, skill_id: uuid.UUID, tenant_id: Optional[uuid.UUID] = None) -> Optional[Skill]: + """根据ID获取技能""" + query = db.query(Skill).filter(Skill.id == skill_id) + if tenant_id: + query = query.filter( + or_( + Skill.tenant_id == tenant_id, + Skill.is_public == True + ) + ) + return query.first() + + @staticmethod + def list_skills( + db: Session, + tenant_id: uuid.UUID, + search: Optional[str] = None, + is_active: Optional[bool] = None, + is_public: Optional[bool] = None, + page: int = 1, + pagesize: int = 10 + ) -> tuple[list[type[Skill]], int]: + """列出技能""" + filters = [ + or_( + Skill.tenant_id == tenant_id, + Skill.is_public == True + ) + ] + + if search: + filters.append( + or_( + Skill.name.ilike(f"%{search}%"), + # Skill.description.ilike(f"%{search}%") + ) + ) + + if is_active is not None: + filters.append(Skill.is_active == is_active) + + if is_public is not None: + filters.append(Skill.is_public == is_public) + + query = db.query(Skill).filter(and_(*filters)) + total = query.count() + + skills = query.order_by(Skill.created_at.desc()).offset( + (page - 1) * pagesize + ).limit(pagesize).all() + + return skills, total + + @staticmethod + def update(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Optional[Skill]: + """更新技能""" + skill = db.query(Skill).filter( + Skill.id == skill_id, + Skill.tenant_id == tenant_id + ).first() + + if not skill: + return None + + update_data = data.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(skill, key, value) + + db.flush() + return skill + + @staticmethod + def delete(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: + """删除技能""" + skill = db.query(Skill).filter( + Skill.id == skill_id, + Skill.tenant_id == tenant_id + ).first() + + if not skill: + return False + + # db.delete(skill) + skill.is_active = False + db.flush() + return True diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 26d9b246..2f94b69d 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -1,14 +1,14 @@ import datetime import uuid from typing import Optional, Any, List, Dict, Union -from enum import Enum +from enum import Enum, StrEnum from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator # ---------- Multimodal File Support ---------- -class FileType(str, Enum): +class FileType(StrEnum): """文件类型枚举""" IMAGE = "image" DOCUMENT = "document" @@ -82,6 +82,12 @@ class ToolConfig(BaseModel): tool_id: Optional[str] = Field(default=None, description="工具ID") operation: Optional[str] = Field(default=None, description="工具特定配置") +class SkillConfig(BaseModel): + """技能配置""" + enabled: bool = Field(default=True, description="是否启用该技能") + skill_ids: Optional[list[str]] = Field(default=list, description="技能ID列表") + all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能") + class ToolOldConfig(BaseModel): """工具配置""" @@ -92,7 +98,7 @@ class ToolOldConfig(BaseModel): class MemoryConfig(BaseModel): """记忆配置""" enabled: bool = Field(default=True, description="是否启用对话历史记忆") - memory_content: Optional[str] = Field(default=None, description="选择记忆的内容类型") + memory_config_id: Optional[str] = Field(default=None, description="选择记忆的内容类型") max_history: int = Field(default=10, ge=0, le=100, description="最大保留的历史对话轮数") @@ -156,6 +162,9 @@ class AgentConfigCreate(BaseModel): description="Agent 可用的工具列表" ) + # 技能配置 + skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表") + class AppCreate(BaseModel): name: str @@ -207,6 +216,9 @@ class AgentConfigUpdate(BaseModel): # 工具配置 tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表") + + # 技能配置 + skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表") # ---------- Output Schemas ---------- @@ -266,6 +278,8 @@ class AgentConfig(BaseModel): # 工具配置 tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = [] + skills: Optional[SkillConfig] = {} + is_active: bool created_at: datetime.datetime updated_at: datetime.datetime diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 5e22d70f..11cacda0 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -236,6 +236,8 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") + reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致") + emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致") class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) diff --git a/api/app/schemas/prompt_optimizer_schema.py b/api/app/schemas/prompt_optimizer_schema.py index 08a11317..96a46742 100644 --- a/api/app/schemas/prompt_optimizer_schema.py +++ b/api/app/schemas/prompt_optimizer_schema.py @@ -21,6 +21,11 @@ class PromptOptMessage(BaseModel): description="currently optimized prompt" ) + skill: bool = Field( + default=False, + description="Enable variable output" + ) + class PromptSaveRequest(BaseModel): session_id: UUID = Field( diff --git a/api/app/schemas/skill_schema.py b/api/app/schemas/skill_schema.py new file mode 100644 index 00000000..f002308e --- /dev/null +++ b/api/app/schemas/skill_schema.py @@ -0,0 +1,64 @@ +"""Skill Schema 定义""" +from typing import Optional, List, Dict, Any, Union +from pydantic import BaseModel, Field, field_serializer +import uuid +from datetime import datetime + + +class SkillBase(BaseModel): + """Skill 基础 Schema""" + name: str = Field(..., description="技能名称") + description: Optional[str] = Field(None, description="技能描述") + tools: List[Dict[str, str]] = Field(default_factory=list, description="工具对象列表: [{\"tool_id\": \"xxx\", \"operation\": \"yyy\"}]") + config: Dict[str, Any] = Field(default_factory=dict, description="技能配置") + prompt: Optional[str] = Field(None, description="技能专属提示词") + is_active: bool = Field(True, description="是否激活") + is_public: bool = Field(False, description="是否公开到市场") + + +class SkillCreate(SkillBase): + """创建 Skill""" + pass + + +class SkillUpdate(BaseModel): + """更新 Skill""" + name: Optional[str] = None + description: Optional[str] = None + tools: Optional[List[Dict[str, str]]] = None + config: Optional[Dict[str, Any]] = None + prompt: Optional[str] = None + is_active: Optional[bool] = None + is_public: Optional[bool] = None + + +class Skill(BaseModel): + """Skill 响应 Schema""" + id: uuid.UUID + tenant_id: uuid.UUID + name: str + description: Optional[str] = None + tools: Union[List[Dict[str, Any]], List[Dict[str, str]]] = Field(default_factory=list, description="工具列表,可以是简单格式或包含工具详情") + config: Dict[str, Any] = Field(default_factory=dict) + prompt: Optional[str] = None + is_active: bool = True + is_public: bool = False + created_at: datetime + updated_at: datetime + + @field_serializer('created_at', 'updated_at') + def serialize_datetime_to_timestamp(self, value: datetime) -> int: + """(毫秒级)时间戳""" + return int(value.timestamp() * 1000) + + class Config: + from_attributes = True + + +class SkillQuery(BaseModel): + """Skill 查询参数""" + search: Optional[str] = None + is_active: Optional[bool] = None + is_public: Optional[bool] = None + page: int = Field(1, ge=1) + pagesize: int = Field(10, ge=1, le=100) diff --git a/api/app/services/agent_config_converter.py b/api/app/services/agent_config_converter.py index 094aade8..fbc75f4c 100644 --- a/api/app/services/agent_config_converter.py +++ b/api/app/services/agent_config_converter.py @@ -9,7 +9,7 @@ from app.schemas.app_schema import ( VariableDefinition, ToolConfig, AgentConfigCreate, - AgentConfigUpdate, ToolOldConfig, + AgentConfigUpdate, ToolOldConfig, SkillConfig, ) @@ -48,6 +48,9 @@ class AgentConfigConverter: # 5. 工具配置 if hasattr(config, 'tools') and config.tools: result["tools"] = [tool.model_dump() for tool in config.tools] + + if hasattr(config, "skills") and config.skills: + result["skills"] = config.skills.model_dump() return result @@ -58,6 +61,7 @@ class AgentConfigConverter: memory: Optional[Dict[str, Any]], variables: Optional[list], tools: Optional[Union[list, Dict[str, Any]]], + skills: Optional[dict] ) -> Dict[str, Any]: """ 将数据库存储格式转换为 Pydantic 对象 @@ -68,6 +72,7 @@ class AgentConfigConverter: memory: 记忆配置 variables: 变量配置 tools: 工具配置 + skills: 技能列表 Returns: 包含 Pydantic 对象的字典 @@ -78,6 +83,7 @@ class AgentConfigConverter: "memory": MemoryConfig(enabled=True), "variables": [], "tools": [], + "skills": SkillConfig(enabled=False, all_skills=False, skill_ids=[]) } # 1. 解析模型参数配置 @@ -117,5 +123,10 @@ class AgentConfigConverter: name: ToolOldConfig(**tool_data) for name, tool_data in tools.items() } + + if skills: + result["skills"] = SkillConfig(**skills) + else: + result["skills"] = SkillConfig(enabled=False, all_skills=False, skill_ids=[]) return result diff --git a/api/app/services/agent_config_helper.py b/api/app/services/agent_config_helper.py index ae195913..08d28424 100644 --- a/api/app/services/agent_config_helper.py +++ b/api/app/services/agent_config_helper.py @@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig: memory=agent_cfg.memory, variables=agent_cfg.variables, tools=agent_cfg.tools, + skills=agent_cfg.skills ) # 将解析后的字段添加到对象上(用于序列化) @@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig: agent_cfg.memory = parsed["memory"] agent_cfg.variables = parsed["variables"] agent_cfg.tools = parsed["tools"] + agent_cfg.skills = parsed["skills"] return agent_cfg diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 1d9ab4a8..5e989150 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -8,6 +8,7 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List from fastapi import Depends from sqlalchemy.orm import Session +from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.langchain_agent import LangChainAgent from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -63,7 +64,7 @@ class AppChatService: # 获取模型配置ID model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id) + api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) # 处理系统提示词(支持变量替换) system_prompt = config.system_prompt if variables: @@ -79,21 +80,55 @@ class AppChatService: # 获取工具服务 tool_service = ToolService(self.db) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): for tool_config in config.tools: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, workspace_id)) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: if tool_instance.name == "baidu_search_tool" and not web_search: continue # 转换为LangChain工具 langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) + elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): + web_tools = config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search: + if web_search_enable: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + + # 加载技能关联的工具 + if hasattr(config, 'skills') and config.skills: + skills = config.skills + skill_enable = skills.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, + tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" # 添加知识库检索工具 knowledge_retrieval = config.knowledge_retrieval @@ -113,22 +148,6 @@ class AppChatService: memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - # 获取模型参数 model_parameters = config.model_parameters @@ -192,6 +211,8 @@ class AppChatService: } ) + ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) + elapsed_time = time.time() - start_time return { @@ -230,7 +251,7 @@ class AppChatService: # 获取模型配置ID model_config_id = config.default_model_config_id - api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id) + api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) # 处理系统提示词(支持变量替换) system_prompt = config.system_prompt if variables: @@ -246,20 +267,54 @@ class AppChatService: # 获取工具服务 tool_service = ToolService(self.db) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): for tool_config in config.tools: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, workspace_id)) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: if tool_instance.name == "baidu_search_tool" and not web_search: continue # 转换为LangChain工具 langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) + elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): + web_tools = config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search: + if web_search_enable: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + + # 加载技能关联的工具 + if hasattr(config, 'skills') and config.skills: + skills = config.skills + skill_enable = skills.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, + tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" # 添加知识库检索工具 knowledge_retrieval = config.knowledge_retrieval @@ -279,22 +334,6 @@ class AppChatService: memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - # 获取模型参数 model_parameters = config.model_parameters @@ -374,6 +413,8 @@ class AppChatService: } ) + ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) + # 发送结束事件 end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" @@ -618,6 +659,7 @@ class AppChatService: memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, + public=False ) -> AsyncGenerator[dict, None]: """聊天(流式)""" @@ -634,7 +676,8 @@ class AppChatService: payload=payload, config=config, workspace_id=workspace_id, - release_id=release_id + release_id=release_id, + public=public ): yield event diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 71d4d0b7..42c4fe4f 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -313,6 +313,7 @@ class AppService: memory=storage_data.get("memory"), variables=storage_data.get("variables", []), tools=storage_data.get("tools", []), + skills=storage_data.get("skills", {}), is_active=True, created_at=now, updated_at=now, @@ -916,6 +917,7 @@ class AppService: agent_cfg.variables = storage_data.get("variables", []) # if data.tools is not None: agent_cfg.tools = storage_data.get("tools", []) + agent_cfg.skills = storage_data.get("skills", {}) agent_cfg.updated_at = now @@ -1003,11 +1005,12 @@ class AppService: }, memory={ "enabled": True, - "memory_content": None, + "memory_config_id": None, "max_history": 10 }, variables=[], tools=[], + skills=[], is_active=True, created_at=now, updated_at=now, @@ -1403,6 +1406,7 @@ class AppService: "memory": agent_cfg.memory, "variables": agent_cfg.variables or [], "tools": agent_cfg.tools or [], + "skills": agent_cfg.skills or {}, } # config = AgentConfigConverter.from_storage_format(agent_cfg) default_model_config_id = agent_cfg.default_model_config_id diff --git a/api/app/services/app_statistics_service.py b/api/app/services/app_statistics_service.py index c164924a..9eefd343 100644 --- a/api/app/services/app_statistics_service.py +++ b/api/app/services/app_statistics_service.py @@ -1,15 +1,13 @@ """应用统计服务""" from datetime import datetime, timedelta -from typing import Dict, Any, List +from typing import Dict, Any import uuid from sqlalchemy import func, and_, cast, Date from sqlalchemy.orm import Session from app.models.conversation_model import Conversation, Message from app.models.end_user_model import EndUser -from app.models.api_key_model import ApiKey, ApiKeyLog -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode +from app.models.api_key_model import ApiKey, ApiKeyLog, ApiKeyType class AppStatisticsService: @@ -146,7 +144,6 @@ class AppStatisticsService: end_dt: datetime ) -> Dict[str, Any]: """获取Token消耗统计(从Message的meta_data中提取)""" - from sqlalchemy import text # 查询所有相关消息的token使用情况 # meta_data中可能包含: {"usage": {"total_tokens": 100}} 或 {"tokens": 100} @@ -187,7 +184,80 @@ class AppStatisticsService: daily_tokens[date_str] = 0 daily_tokens[date_str] += int(tokens) - daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0] - total = sum(row["tokens"] for row in daily_data) + daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0] + total = sum(row["count"] for row in daily_data) return {"daily": daily_data, "total": total} + + def get_workspace_api_statistics( + self, + workspace_id: uuid.UUID, + start_date: int, + end_date: int + ) -> list[Any]: + """获取工作空间API调用统计 + + Args: + workspace_id: 工作空间ID + start_date: 开始时间戳(毫秒) + end_date: 结束时间戳(毫秒) + + Returns: + 每日统计数据列表 + """ + # 将毫秒时间戳转换为 datetime + start_time = datetime.fromtimestamp(start_date / 1000) + end_time = datetime.fromtimestamp(end_date / 1000) + + # 应用类型(agent, multi_agent, workflow) + app_types = [ApiKeyType.AGENT, ApiKeyType.CLUSTER, ApiKeyType.WORKFLOW] + + # 每日应用类型调用次数 + daily_app_calls = self.db.query( + cast(ApiKeyLog.created_at, Date).label('date'), + func.count(ApiKeyLog.id).label('count') + ).join( + ApiKey, ApiKeyLog.api_key_id == ApiKey.id + ).filter( + and_( + ApiKey.workspace_id == workspace_id, + ApiKey.type.in_(app_types), + ApiKeyLog.created_at >= start_time, + ApiKeyLog.created_at <= end_time + ) + ).group_by(cast(ApiKeyLog.created_at, Date)).all() + + # 每日服务类型调用次数 + daily_service_calls = self.db.query( + cast(ApiKeyLog.created_at, Date).label('date'), + func.count(ApiKeyLog.id).label('count') + ).join( + ApiKey, ApiKeyLog.api_key_id == ApiKey.id + ).filter( + and_( + ApiKey.workspace_id == workspace_id, + ApiKey.type == ApiKeyType.SERVICE, + ApiKeyLog.created_at >= start_time, + ApiKeyLog.created_at <= end_time + ) + ).group_by(cast(ApiKeyLog.created_at, Date)).all() + + # 构建每日数据 + app_calls_dict = {str(row.date): row.count for row in daily_app_calls} + service_calls_dict = {str(row.date): row.count for row in daily_service_calls} + + # 合并所有日期 + all_dates = sorted(set(app_calls_dict.keys()) | set(service_calls_dict.keys())) + + daily_data = [] + for date in all_dates: + app_count = app_calls_dict.get(date, 0) + service_count = service_calls_dict.get(date, 0) + daily_data.append({ + "date": date, + "total_calls": app_count + service_count, + "app_calls": app_count, + "service_calls": service_count + }) + + return daily_data diff --git a/api/app/services/collaborative_orchestrator.py b/api/app/services/collaborative_orchestrator.py index f01b7e01..00a731de 100644 --- a/api/app/services/collaborative_orchestrator.py +++ b/api/app/services/collaborative_orchestrator.py @@ -24,6 +24,7 @@ from app.core.error_codes import BizCode from app.core.models import RedBearLLM from app.core.models.base import RedBearModelConfig from app.models import ModelType +from app.services.model_service import ModelApiKeyService logger = get_business_logger() @@ -357,6 +358,8 @@ class CollaborativeOrchestrator: "usage": response.get("usage", {"total_tokens": 0}), "is_final_answer": True } + + ModelApiKeyService.record_api_key_usage(self.db, agent_config.get("api_key_id")) # 检查是否有工具调用(handoff) tool_calls = response.get("tool_calls", []) @@ -427,7 +430,7 @@ class CollaborativeOrchestrator: ) # 获取 API Key - api_key_config = ModelApiKeyService.get_a_api_key(self.db, model_config_id) + api_key_config = ModelApiKeyService.get_available_api_key(self.db, model_config_id) if not api_key_config: raise BusinessException( f"Agent 模型没有可用的 API Key: {agent_id}", @@ -442,7 +445,8 @@ class CollaborativeOrchestrator: "provider": api_key_config.provider, "api_key": api_key_config.api_key, "api_base": api_key_config.api_base, - "model_parameters": config_data.get("model_parameters", {}) + "model_parameters": config_data.get("model_parameters", {}), + "api_key_id": api_key_config.id } except ValueError: diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index edad0123..31662769 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -10,6 +10,11 @@ import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional +from langchain.tools import tool +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.exceptions import BusinessException @@ -24,12 +29,11 @@ from app.services import task_service from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger +from app.services.model_service import ModelApiKeyService from app.services.tool_service import ToolService from app.services.multimodal_service import MultimodalService -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy import select -from sqlalchemy.orm import Session +from app.core.agent.agent_middleware import AgentMiddleware + logger = get_business_logger() class KnowledgeRetrievalInput(BaseModel): @@ -59,7 +63,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str 长期记忆工具 """ # search_switch = memory_config.get("search_switch", "2") - config_id= memory_config.get("memory_content") or memory_config.get("memory_config",None) + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None) logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}") @tool(args_schema=LongTermMemoryInput) def long_term_memory(question: str) -> str: @@ -310,6 +315,7 @@ class DraftRunService: tools = [] tool_service = ToolService(self.db) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): @@ -320,9 +326,7 @@ class DraftRunService: print(f"tool_config:{tool_config}") if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, str(workspace_id))) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: if tool_instance.name == "baidu_search_tool" and not web_search: continue @@ -345,6 +349,25 @@ class DraftRunService: } ) + # 加载技能关联的工具 + if hasattr(agent_config, 'skills') and agent_config.skills: + skills = agent_config.skills + skill_enable = skills.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" + # 添加知识库检索工具 if agent_config.knowledge_retrieval: kb_config = agent_config.knowledge_retrieval @@ -433,7 +456,8 @@ class DraftRunService: ) memory_config_= agent_config.memory - config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None) + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None) # 8. 调用 Agent(支持多模态) result = await agent.chat( @@ -450,6 +474,8 @@ class DraftRunService: elapsed_time = time.time() - start_time + ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) + # 9. 保存会话消息 if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): await self._save_conversation_message( @@ -558,6 +584,7 @@ class DraftRunService: tools = [] tool_service = ToolService(self.db) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): @@ -567,9 +594,7 @@ class DraftRunService: # print(f"tool_config:{tool_config}") if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, str(workspace_id))) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool_instance: if tool_instance.name == "baidu_search_tool" and not web_search: continue @@ -592,6 +617,25 @@ class DraftRunService: } ) + # 加载技能关联的工具 + if hasattr(agent_config, 'skills') and agent_config.skills: + skills = agent_config.skills + skill_enable = skills.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + # 应用动态过滤 + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + active_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + system_prompt = f"{system_prompt}\n\n{active_prompts}" + # 添加知识库检索工具 if agent_config.knowledge_retrieval: @@ -628,7 +672,6 @@ class DraftRunService: } ) - # 4. 创建 LangChain Agent agent = LangChainAgent( model_name=api_key_config["model_name"], @@ -677,7 +720,8 @@ class DraftRunService: }) memory_config_ = agent_config.memory - config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None) + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None) # 9. 流式调用 Agent(支持多模态) full_content = "" @@ -704,6 +748,8 @@ class DraftRunService: elapsed_time = time.time() - start_time + ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) + if sub_agent: yield self._format_sse_event("sub_usage", { "total_tokens": total_tokens @@ -770,7 +816,7 @@ class DraftRunService: Raises: BusinessException: 当没有可用的 API Key 时 """ - api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id) + # api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id) # stmt = ( # select(ModelApiKey).join( # ModelConfig, ModelApiKey.model_configs @@ -784,7 +830,8 @@ class DraftRunService: # ) # # api_key = self.db.scalars(stmt).first() - api_key = api_keys[0] if api_keys else None + # api_key = api_keys[0] if api_keys else None + api_key = ModelApiKeyService.get_available_api_key(self.db, model_config_id) if not api_key: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) @@ -793,7 +840,8 @@ class DraftRunService: "model_name": api_key.model_name, "provider": api_key.provider, "api_key": api_key.api_key, - "api_base": api_key.api_base + "api_base": api_key.api_base, + "api_key_id": api_key.id } async def _ensure_conversation( @@ -1051,7 +1099,7 @@ class DraftRunService: except Exception as e: # 对于多 Agent 应用,没有直接的 AgentConfig 是正常的 - logger.debug("获取配置快照失败(可能是多 Agent 应用)", extra={"error": str(e)}) + logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)}) return {} def _replace_variables( diff --git a/api/app/services/handoffs_service.py b/api/app/services/handoffs_service.py index 10e4d646..e490eea4 100644 --- a/api/app/services/handoffs_service.py +++ b/api/app/services/handoffs_service.py @@ -537,7 +537,7 @@ def convert_multi_agent_config_to_handoffs( # 获取该 Agent 的模型配置 if release.default_model_config_id: - model_api_key = ModelApiKeyService.get_a_api_key(db, release.default_model_config_id) + model_api_key = ModelApiKeyService.get_available_api_key(db, release.default_model_config_id) if model_api_key: model_config = RedBearModelConfig( model_name=model_api_key.model_name, @@ -551,6 +551,7 @@ def convert_multi_agent_config_to_handoffs( } ) logger.debug(f"Agent {agent_name} 使用模型: {model_api_key.model_name}") + ModelApiKeyService.record_api_key_usage(db, model_api_key.id) else: logger.warning(f"Agent {agent_name} 模型配置无效: {release.default_model_config_id}") else: diff --git a/api/app/services/llm_router.py b/api/app/services/llm_router.py index 9e102ac3..e56ad5aa 100644 --- a/api/app/services/llm_router.py +++ b/api/app/services/llm_router.py @@ -382,6 +382,7 @@ class LLMRouter: from app.core.models import RedBearLLM from app.core.models.base import RedBearModelConfig from app.models import ModelApiKey, ModelType + from app.services.model_service import ModelApiKeyService # 获取 API Key 配置(通过关联关系) # api_key_config = self.db.query(ModelApiKey).join( @@ -389,8 +390,9 @@ class LLMRouter: # ).filter(ModelConfig.id == self.routing_model_config.id, # ModelApiKey.is_active == True # ).first() - api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id) - api_key_config = api_keys[0] if api_keys else None + # api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id) + # api_key_config = api_keys[0] if api_keys else None + api_key_config = ModelApiKeyService.get_available_api_key(self.db, self.routing_model_config.id) if not api_key_config: raise Exception("路由模型没有可用的 API Key") @@ -424,7 +426,6 @@ class LLMRouter: # 调用模型 response = await llm.ainvoke(prompt) - from app.services.model_service import ModelApiKeyService ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) # 提取响应内容 diff --git a/api/app/services/master_agent_router.py b/api/app/services/master_agent_router.py index 87fdb22c..3cf3ecc3 100644 --- a/api/app/services/master_agent_router.py +++ b/api/app/services/master_agent_router.py @@ -349,7 +349,7 @@ class MasterAgentRouter: from app.models import ModelApiKey, ModelType # 获取 API Key 配置 - api_key_config = ModelApiKeyService.get_a_api_key(self.db, self.master_model_config.id) + api_key_config = ModelApiKeyService.get_available_api_key(self.db, self.master_model_config.id) if not api_key_config: raise Exception("Master Agent 模型没有可用的 API Key") @@ -400,6 +400,7 @@ class MasterAgentRouter: # 调用模型 response = await llm.ainvoke(prompt) + ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) # 提取响应内容 if hasattr(response, 'content'): diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index fed5109f..f480a1da 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -1194,7 +1194,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An workspace_id=app.workspace_id ) - memory_config_id = str(memory_config.config_id) if memory_config else None + memory_obj = config.get('memory', {}) + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None result = { "end_user_id": str(end_user_id), @@ -1284,7 +1286,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) if release: config = release.config or {} memory_obj = config.get('memory', {}) - memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None if memory_config_id: # 判断是否为UUID格式 if len(str(memory_config_id))>=5: @@ -1330,7 +1333,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 从 config 中提取 memory_config_id config = release.config or {} memory_obj = config.get('memory', {}) - memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None # 获取配置名称(使用字符串形式的ID进行查找,兼容新旧格式) memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 06a94060..6fa8b228 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -53,7 +53,10 @@ def get_workspace_end_users( workspace_id: uuid.UUID, current_user: User ) -> List[EndUser]: - """获取工作空间的所有宿主(优化版本:减少数据库查询次数)""" + """获取工作空间的所有宿主(优化版本:减少数据库查询次数) + + 返回结果按 updated_at 从新到旧排序(NULL 值排在最后) + """ business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}") try: @@ -68,9 +71,14 @@ def get_workspace_end_users( app_ids = [app.id for app in apps_orm] # 批量查询所有 end_users(一次查询而非循环查询) + # 按 updated_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性 from app.models.end_user_model import EndUser as EndUserModel + from sqlalchemy import desc, nullslast end_users_orm = db.query(EndUserModel).filter( EndUserModel.app_id.in_(app_ids) + ).order_by( + nullslast(desc(EndUserModel.updated_at)), + desc(EndUserModel.id) ).all() # 转换为 Pydantic 模型(只在需要时转换) diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index e025c1b3..0e542ff0 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -108,13 +108,14 @@ class WorkspaceAppService: app_info["releases"].append(release_info) def _extract_memory_content(self, config: Any) -> str: - """Extract memory_comtent from config""" + """Extract memory_config_id from config (兼容新旧字段名)""" if not config or not isinstance(config, dict): return None memory_obj = config.get('memory') if memory_obj and isinstance(memory_obj, dict): - return memory_obj.get('memory_content') + # 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content + return memory_obj.get('memory_config_id') or memory_obj.get('memory_content') return None diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 82baef9f..b7079e62 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) if not params.rerank_id: params.rerank_id = configs.get('rerank') + # reflection_model_id 和 emotion_model_id 默认与 llm_id 一致 + if not params.reflection_model_id: + params.reflection_model_id = params.llm_id + if not params.emotion_model_id: + params.emotion_model_id = params.llm_id + config = MemoryConfigRepository.create(self.db, params) self.db.commit() return {"affected": 1, "config_id": config.config_id} @@ -203,6 +209,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "end_user_id": config.end_user_id, "config_id_old": config_id_old, "apply_id": config.apply_id, + "scene_id": config.scene_id, "llm_id": config.llm_id, "embedding_id": config.embedding_id, "rerank_id": config.rerank_id, diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index dee6cd1d..d382b1b1 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -6,7 +6,7 @@ import math import time import asyncio -from app.models.models_model import ModelConfig, ModelApiKey, ModelType +from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository from app.schemas import model_schema from app.schemas.model_schema import ( @@ -633,19 +633,31 @@ class ModelApiKeyService: @staticmethod def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]: - """获取可用的API Key(按优先级和负载均衡)""" - api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active=True) + """获取可用的API Key(根据负载均衡策略)""" + model_config = ModelConfigRepository.get_by_id(db, model_config_id) + if not model_config: + return None + + api_keys = [key for key in model_config.api_keys if key.is_active] if not api_keys: return None - return min(api_keys, key=lambda x: int(x.usage_count or "0")) + + # 如果是轮询策略,按使用次数最少,次数相同则选最早使用的 + if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: + return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) + + # 否则返回第一个 + return api_keys[0] @staticmethod - def record_api_key_usage(db: Session, api_key_id: uuid.UUID) -> bool: + def record_api_key_usage(db: Session, api_key_id: uuid.UUID | None) -> bool: """记录API Key使用""" - success = ModelApiKeyRepository.update_usage(db, api_key_id) - if success: - db.commit() - return success + if api_key_id: + success = ModelApiKeyRepository.update_usage(db, api_key_id) + if success: + db.commit() + return success + return False @staticmethod def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey: diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index b28bafbf..d1aa46d1 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -14,6 +14,7 @@ from app.services.conversation_state_manager import ConversationStateManager from app.core.exceptions import BusinessException, ResourceNotFoundException from app.core.error_codes import BizCode from app.core.logging_config import get_business_logger +from app.services.model_service import ModelApiKeyService logger = get_business_logger() @@ -2569,8 +2570,9 @@ class MultiAgentOrchestrator: # ModelConfig.id == default_model_config_id, # ModelApiKey.is_active.is_(True) # ).first() - api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id) - api_key_config = api_keys[0] if api_keys else None + # api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id) + # api_key_config = api_keys[0] if api_keys else None + api_key_config = ModelApiKeyService.get_available_api_key(self.db, default_model_config_id) if not api_key_config: logger.warning("Master Agent 没有可用的 API Key,使用简单整合") @@ -2601,6 +2603,8 @@ class MultiAgentOrchestrator: # 调用模型进行整合 response = await llm.ainvoke(merge_prompt) + ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) + # 提取响应内容 if hasattr(response, 'content'): merged_response = response.content @@ -2730,8 +2734,9 @@ class MultiAgentOrchestrator: # ModelConfig.id == default_model_config_id, # ModelApiKey.is_active.is_(True) # ).first() - api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id) - api_key_config = api_keys[0] if api_keys else None + # api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id) + # api_key_config = api_keys[0] if api_keys else None + api_key_config = ModelApiKeyService.get_available_api_key(self.db, default_model_config_id) if not api_key_config: logger.warning("Master Agent 没有可用的 API Key,使用简单整合") @@ -2790,6 +2795,8 @@ class MultiAgentOrchestrator: logger.debug(f"收到流式 chunk #{chunk_count}: {content[:30]}...") yield self._format_sse_event("message", {"content": content}) + ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) + logger.info(f"Master Agent 流式整合完成,共 {chunk_count} 个 chunks") except AttributeError as e: diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index a460a7ba..02636c27 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -23,7 +23,7 @@ logger = get_business_logger() class ImageFormatStrategy(Protocol): """图片格式策略接口""" - + async def format_image(self, url: str) -> Dict[str, Any]: """将图片 URL 转换为特定 provider 的格式""" ... @@ -31,7 +31,7 @@ class ImageFormatStrategy(Protocol): class DashScopeImageStrategy: """通义千问图片格式策略""" - + async def format_image(self, url: str) -> Dict[str, Any]: """通义千问格式: {"type": "image", "image": "url"}""" return { @@ -42,7 +42,7 @@ class DashScopeImageStrategy: class BedrockImageStrategy: """Bedrock/Anthropic 图片格式策略""" - + async def format_image(self, url: str) -> Dict[str, Any]: """ Bedrock/Anthropic 格式: base64 编码 @@ -51,17 +51,17 @@ class BedrockImageStrategy: import httpx import base64 from mimetypes import guess_type - + logger.info(f"下载并编码图片: {url}") - + # 下载图片 async with httpx.AsyncClient(timeout=30.0) as client: response = await client.get(url) response.raise_for_status() - + # 获取图片数据 image_data = response.content - + # 确定 media type content_type = response.headers.get("content-type") if content_type and content_type.startswith("image/"): @@ -69,12 +69,12 @@ class BedrockImageStrategy: else: guessed_type, _ = guess_type(url) media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" - + # 转换为 base64 base64_data = base64.b64encode(image_data).decode("utf-8") - + logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") - + return { "type": "image", "source": { @@ -87,7 +87,7 @@ class BedrockImageStrategy: class OpenAIImageStrategy: """OpenAI 图片格式策略""" - + async def format_image(self, url: str) -> Dict[str, Any]: """OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}""" return { @@ -109,7 +109,7 @@ PROVIDER_STRATEGIES = { class MultimodalService: """多模态文件处理服务""" - + def __init__(self, db: Session, provider: str = "dashscope"): """ 初始化多模态服务 @@ -120,10 +120,10 @@ class MultimodalService: """ self.db = db self.provider = provider.lower() - + async def process_files( - self, - files: Optional[List[FileInput]] + self, + files: Optional[List[FileInput]] ) -> List[Dict[str, Any]]: """ 处理文件列表,返回 LLM 可用的格式 @@ -136,7 +136,7 @@ class MultimodalService: """ if not files: return [] - + result = [] for idx, file in enumerate(files): try: @@ -168,10 +168,10 @@ class MultimodalService: "type": "text", "text": f"[文件处理失败: {str(e)}]" }) - + logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result - + async def _process_image(self, file: FileInput) -> Dict[str, Any]: """ 处理图片文件 @@ -184,14 +184,10 @@ class MultimodalService: - Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} - 通义千问: {"type": "image", "image": "url"} """ - if file.transfer_method == TransferMethod.REMOTE_URL: - url = file.url - else: - # 本地文件,获取访问 URL - url = await self._get_file_url(file.upload_file_id) - + url = await self.get_file_url(file) + logger.debug(f"处理图片: {url}, provider={self.provider}") - + # 根据 provider 返回不同格式 if self.provider in ["bedrock", "anthropic"]: # Anthropic/Bedrock 只支持 base64 格式,需要下载并转换 @@ -223,7 +219,7 @@ class MultimodalService: "type": "image", "image": url } - + async def _download_and_encode_image(self, url: str) -> tuple[str, str]: """ 下载图片并转换为 base64 @@ -237,15 +233,15 @@ class MultimodalService: import httpx import base64 from mimetypes import guess_type - + # 下载图片 async with httpx.AsyncClient(timeout=30.0) as client: response = await client.get(url) response.raise_for_status() - + # 获取图片数据 image_data = response.content - + # 确定 media type content_type = response.headers.get("content-type") if content_type and content_type.startswith("image/"): @@ -254,14 +250,14 @@ class MultimodalService: # 从 URL 推断 guessed_type, _ = guess_type(url) media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg" - + # 转换为 base64 base64_data = base64.b64encode(image_data).decode("utf-8") - + logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}") - + return base64_data, media_type - + async def _process_document(self, file: FileInput) -> Dict[str, Any]: """ 处理文档文件(PDF、Word 等) @@ -284,14 +280,14 @@ class MultimodalService: generic_file = self.db.query(GenericFile).filter( GenericFile.id == file.upload_file_id ).first() - + file_name = generic_file.file_name if generic_file else "unknown" - + return { "type": "text", "text": f"\n{text}\n" } - + async def _process_audio(self, file: FileInput) -> Dict[str, Any]: """ 处理音频文件 @@ -307,7 +303,7 @@ class MultimodalService: "type": "text", "text": "[音频文件,暂不支持处理]" } - + async def _process_video(self, file: FileInput) -> Dict[str, Any]: """ 处理视频文件 @@ -323,13 +319,13 @@ class MultimodalService: "type": "text", "text": "[视频文件,暂不支持处理]" } - - async def _get_file_url(self, file_id: uuid.UUID) -> str: + + async def get_file_url(self, file: FileInput) -> str: """ 获取文件的访问 URL Args: - file_id: 文件ID + file: File Input Struct Returns: str: 文件访问 URL @@ -337,26 +333,31 @@ class MultimodalService: Raises: BusinessException: 文件不存在 """ - generic_file = self.db.query(GenericFile).filter( - GenericFile.id == file_id, - GenericFile.status == "active" - ).first() - - if not generic_file: - raise BusinessException( - f"文件不存在或已删除: {file_id}", - BizCode.NOT_FOUND - ) - - # 如果有 access_url,直接返回 - if generic_file.access_url: - return generic_file.access_url - - # 否则,根据 storage_path 生成 URL - # TODO: 根据实际存储方式生成 URL(本地存储、OSS 等) - # 这里暂时返回一个占位 URL - return f"/api/files/{file_id}/download" - + if file.transfer_method == TransferMethod.REMOTE_URL: + return file.url + else: + # 本地文件,获取访问 URL + file_id = file.upload_file_id + generic_file = self.db.query(GenericFile).filter( + GenericFile.id == file.upload_file_id, + GenericFile.status == "active" + ).first() + + if not generic_file: + raise BusinessException( + f"文件不存在或已删除: {file.upload_file_id}", + BizCode.NOT_FOUND + ) + + # 如果有 access_url,直接返回 + if generic_file.access_url: + return generic_file.access_url + + # 否则,根据 storage_path 生成 URL + # TODO: 根据实际存储方式生成 URL(本地存储、OSS 等) + # 这里暂时返回一个占位 URL + return f"/api/files/{file_id}/download" + async def _extract_document_text(self, file_id: uuid.UUID) -> str: """ 提取文档文本内容 @@ -371,20 +372,20 @@ class MultimodalService: GenericFile.id == file_id, GenericFile.status == "active" ).first() - + if not generic_file: raise BusinessException( f"文件不存在或已删除: {file_id}", BizCode.NOT_FOUND ) - + # TODO: 根据文件类型提取文本 # - PDF: 使用 PyPDF2 或 pdfplumber # - Word: 使用 python-docx # - TXT/MD: 直接读取 - + file_ext = generic_file.file_ext.lower() - + if file_ext in ['.txt', '.md', '.markdown']: return await self._read_text_file(generic_file.storage_path) elif file_ext == '.pdf': @@ -393,7 +394,7 @@ class MultimodalService: return await self._extract_word_text(generic_file.storage_path) else: return f"[不支持的文档格式: {file_ext}]" - + async def _read_text_file(self, storage_path: str) -> str: """读取纯文本文件""" try: @@ -402,7 +403,7 @@ class MultimodalService: except Exception as e: logger.error(f"读取文本文件失败: {e}") return f"[文件读取失败: {str(e)}]" - + async def _extract_pdf_text(self, storage_path: str) -> str: """提取 PDF 文本""" try: @@ -412,7 +413,7 @@ class MultimodalService: except Exception as e: logger.error(f"提取 PDF 文本失败: {e}") return f"[PDF 提取失败: {str(e)}]" - + async def _extract_word_text(self, storage_path: str) -> str: """提取 Word 文档文本""" try: diff --git a/api/app/services/prompt/prompt_optimizer_system.jinja2 b/api/app/services/prompt/prompt_optimizer_system.jinja2 index b9060f68..39a4ba68 100644 --- a/api/app/services/prompt/prompt_optimizer_system.jinja2 +++ b/api/app/services/prompt/prompt_optimizer_system.jinja2 @@ -1,4 +1,3 @@ -{% raw %} Role: AI Prompt Optimization Expert Profile @@ -12,11 +11,11 @@ Skills Core Optimization Skills Requirement Analysis: Accurately understand the relationship between the user’s current needs and the original prompt. Structural Reconstruction: Transform vague requirements into clear, block-structured instructions. -Variable Handling: Identify and standardize dynamic variables in prompts. +{% if skill != true %}Variable Handling: Identify and standardize dynamic variables in prompts.{% endif %} Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs. Auxiliary Generation Skills -Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined. +{% if skill != true %}Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.{% endif %} Language Consistency: Maintain consistency between label language and user input language. Executability Verification: Ensure optimized prompts can be directly used in AI tools. Format Standardization: Strictly adhere to specified output format requirements. @@ -25,30 +24,30 @@ Rules Basic Principles Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements. Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements. -Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints +{% if skill != true %}Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints{% endif %} Language Rule: All label languages must fully match the user input language. Behavior Guidelines Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity. Readability Guideline: Ensure optimized prompts have good readability and logical flow. -Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed. -Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label. +{% if skill != true %}{% raw %}Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed. +Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %} Constraints Output Constraint: Must output in JSON format including the fields "prompt" and "desc". Content Constraint: Must not include any explanations, analyses, or additional comments. Language Constraint: Must use clear and concise language. -Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.). +{% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %} Workflows Goal: Optimize or generate AI prompts that can be directly used according to user requirements. Step 1: Receive the user’s current requirement description {{user_require}} and the original prompt {{original_prompt}}. Step 2: Analyze requirements, identify conflicts, and prioritize current requirements. -Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined. +{% if skill != true %}Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined. Step 4: Generate a JSON output containing the optimized prompt and its description. +{% else %}Step 3: Generate a JSON output containing the optimized prompt and its description.{% endif %} Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description. Initialization -As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows. -{% endraw %} \ No newline at end of file +As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows. \ No newline at end of file diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 2c0b57ac..99edcc0e 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -23,6 +23,7 @@ from app.repositories.prompt_optimizer_repository import ( PromptReleaseRepository ) from app.schemas.prompt_optimizer_schema import OptimizePromptResult +from app.services.model_service import ModelApiKeyService logger = get_business_logger() @@ -128,7 +129,8 @@ class PromptOptimizerService: session_id: uuid.UUID, user_id: uuid.UUID, current_prompt: str, - user_require: str + user_require: str, + skill: bool = False ) -> AsyncGenerator[dict[str, str | Any], Any]: """ Optimize a user-provided prompt using a configured prompt optimizer LLM. @@ -157,6 +159,7 @@ class PromptOptimizerService: user_id (uuid.UUID): Identifier of the user associated with the session. current_prompt (str): Original prompt to optimize. user_require (str): User's requirements or instructions for optimization. + skill(bool): Is skill required Returns: OptimizePromptResult: An object containing: @@ -174,8 +177,9 @@ class PromptOptimizerService: logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}") # Create LLM instance - api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id) - api_config: ModelApiKey = api_keys[0] if api_keys else None + # api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id) + # api_config: ModelApiKey = api_keys[0] if api_keys else None + api_config: ModelApiKey = ModelApiKeyService.get_available_api_key(self.db, model_config.id) llm = RedBearLLM(RedBearModelConfig( model_name=api_config.model_name, provider=api_config.provider, @@ -186,7 +190,7 @@ class PromptOptimizerService: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() - rendered_system_message = Template(opt_system_prompt).render() + rendered_system_message = Template(opt_system_prompt).render(skill=skill) with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f: opt_user_prompt = f.read() @@ -250,6 +254,7 @@ class PromptOptimizerService: optim_result = json_repair.repair_json(buffer, return_objects=True) # prompt = optim_result.get("prompt") desc = optim_result.get("desc") + ModelApiKeyService.record_api_key_usage(self.db, api_config.id) self.create_message( tenant_id=tenant_id, session_id=session_id, diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index a92c2649..c7b81999 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -10,6 +10,7 @@ from app.services.memory_konwledges_server import write_rag from app.models import ReleaseShare, AppRelease, Conversation from app.services.conversation_service import ConversationService from app.services.draft_run_service import create_web_search_tool +from app.services.model_service import ModelApiKeyService from app.services.release_share_service import ReleaseShareService from app.core.exceptions import BusinessException, ResourceNotFoundException from app.core.error_codes import BizCode @@ -178,8 +179,9 @@ class SharedChatService: # .limit(1) # ) # api_key_obj = self.db.scalars(stmt).first() - api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id) - api_key_obj = api_keys[0] if api_keys else None + # api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id) + # api_key_obj = api_keys[0] if api_keys else None + api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) if not api_key_obj: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) @@ -309,6 +311,8 @@ class SharedChatService: elapsed_time = time.time() - start_time + ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) + return { "conversation_id": conversation.id, @@ -349,7 +353,8 @@ class SharedChatService: if variables is None: variables = {} - memory_config = {"enabled": memory, "memory_content": "17", "max_history": 10} + # 兼容新旧字段名:使用 memory_config_id + memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10} try: # 获取发布版本和配置 @@ -383,8 +388,9 @@ class SharedChatService: # .limit(1) # ) # api_key_obj = self.db.scalars(stmt).first() - api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id) - api_key_obj = api_keys[0] if api_keys else None + # api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id) + # api_key_obj = api_keys[0] if api_keys else None + api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) if not api_key_obj: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) @@ -513,7 +519,8 @@ class SharedChatService: } ) - + ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) + # 发送结束事件 end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" diff --git a/api/app/services/skill_service.py b/api/app/services/skill_service.py new file mode 100644 index 00000000..ea21b2ad --- /dev/null +++ b/api/app/services/skill_service.py @@ -0,0 +1,133 @@ +"""Skill Service""" +import uuid +from typing import List + +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from app.repositories.skill_repository import SkillRepository +from app.schemas.skill_schema import SkillCreate, SkillUpdate +from app.models.skill_model import Skill +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.services.tool_service import ToolService + + +class SkillService: + """Skill 业务逻辑层""" + + @staticmethod + def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill: + """创建技能""" + # 检查同名技能 + existing = db.query(Skill).filter( + Skill.tenant_id == tenant_id, + Skill.name == data.name + ).first() + if existing: + raise BusinessException(f"技能名称'{data.name}'已存在", BizCode.DUPLICATE_NAME) + + skill = SkillRepository.create(db, data, tenant_id) + db.commit() + db.refresh(skill) + return skill + + @staticmethod + def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill: + """获取技能""" + try: + skill = SkillRepository.get_by_id(db, skill_id, tenant_id) + if not skill: + raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND) + + # 填充工具详情 + tool_service = ToolService(db) + enriched_tools = [] + for tool_config in skill.tools: + tool_id = tool_config.get("tool_id") + if tool_id: + tool_info = tool_service.get_tool_info(tool_id, tenant_id) + if tool_info: + enriched_tools.append({ + "tool_id": tool_id, + "operation": tool_config.get("operation"), + "tool_info": tool_info + }) + skill.tools = enriched_tools + + return skill + except (BusinessException, SQLAlchemyError) as e: + db.rollback() + raise e + + @staticmethod + def list_skills( + db: Session, + tenant_id: uuid.UUID, + search: str = None, + is_active: bool = None, + is_public: bool = None, + page: int = 1, + pagesize: int = 10 + ) -> tuple[list[type[Skill]], int]: + """列出技能""" + return SkillRepository.list_skills( + db, tenant_id, search, is_active, is_public, page, pagesize + ) + + @staticmethod + def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill: + """更新技能""" + try: + skill = SkillRepository.update(db, skill_id, data, tenant_id) + if not skill: + raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND) + db.commit() + db.refresh(skill) + return skill + except (BusinessException, SQLAlchemyError) as e: + db.rollback() + raise e + + @staticmethod + def delete_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool: + """删除技能""" + try: + success = SkillRepository.delete(db, skill_id, tenant_id) + if not success: + raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND) + db.commit() + return True + except (BusinessException, SQLAlchemyError) as e: + db.rollback() + raise e + + @staticmethod + def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]: + """加载技能关联的工具 + + Returns: + (tools, tool_to_skill_map) - 工具列表和工具到技能的映射 + """ + tools = [] + tool_to_skill_map = {} # {tool_name: skill_id} + tool_service = ToolService(db) + + for skill_id in skill_ids: + try: + skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id) + if skill and skill.is_active: + # 加载技能关联的工具 + for tool_config in skill.tools: + tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + if tool: + langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + # 建立工具到技能的映射 + tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool))) + tool_to_skill_map[tool_name] = skill_id + except Exception as e: + print(f"加载技能 {skill_id} 的工具时出错: {e}") + continue + + return tools, tool_to_skill_map diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 2958f4f9..1c7c7304 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -4,9 +4,8 @@ import datetime import logging import uuid -from typing import Any, Annotated, AsyncGenerator, Optional +from typing import Any, Annotated, Optional -from deprecated import deprecated from fastapi import Depends from sqlalchemy.orm import Session @@ -23,6 +22,7 @@ from app.repositories.workflow_repository import ( from app.schemas import DraftRunRequest from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str +from app.services.multimodal_service import MultimodalService logger = logging.getLogger(__name__) @@ -36,6 +36,7 @@ class WorkflowService: self.execution_repo = WorkflowExecutionRepository(db) self.node_execution_repo = WorkflowNodeExecutionRepository(db) self.conversation_service = ConversationService(db) + self.multimodal_service = MultimodalService(db) # ==================== 配置管理 ==================== @@ -445,24 +446,22 @@ class WorkflowService: code=BizCode.CONFIG_MISSING, message=f"工作流配置不存在: app_id={app_id}" ) - input_data = {"message": payload.message, "variables": payload.variables, - "conversation_id": payload.conversation_id} + files = [] + if payload.files: + for file in payload.files: + files.append( + { + "type": file.type, + "url": await self.multimodal_service.get_file_url(file), + "__file": True + } + ) - # 转换 user_id 为 UUID - triggered_by_uuid = None - if payload.user_id: - try: - triggered_by_uuid = uuid.UUID(payload.user_id) - except (ValueError, AttributeError): - logger.warning(f"无效的 user_id 格式: {payload.user_id}") + input_data = {"message": payload.message, "variables": payload.variables, + "conversation_id": payload.conversation_id, "files": files} # 转换 conversation_id 为 UUID - conversation_id_uuid = None - if payload.conversation_id: - try: - conversation_id_uuid = uuid.UUID(payload.conversation_id) - except (ValueError, AttributeError): - logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}") + conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None # 2. 创建执行记录 execution = self.create_execution( @@ -544,10 +543,10 @@ class WorkflowService: return { "execution_id": execution.execution_id, "status": result.get("status"), - "variables": result.get("variables"), - "messages": result.get("messages"), + # "variables": result.get("variables"), + # "messages": result.get("messages"), "output": result.get("output"), # 最终输出(字符串) - "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) + # "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID "error_message": result.get("error"), "elapsed_time": result.get("elapsed_time"), @@ -566,6 +565,41 @@ class WorkflowService: message=f"工作流执行失败: {str(e)}" ) + @staticmethod + def _map_public_event(event: dict) -> dict | None: + event_type = event.get("event") + payload = event.get("data") + match event_type: + case "workflow_start": + return { + "event": "start", + "data": { + "conversation_id": payload.get("conversation_id"), + } + } + case "workflow_end": + return { + "event": "end", + "data": { + "elapsed_time": payload.get("elapsed_time"), + "message_length": len(payload.get("output", "")) + } + } + case "node_start" | "node_end" | "node_error": + return None + case _: + return event + + def _emit(self, public: bool, internal_event: dict): + """ + decide + """ + if public: + mapped = self._map_public_event(internal_event) + else: + mapped = internal_event + return mapped + async def run_stream( self, app_id: uuid.UUID, @@ -573,6 +607,7 @@ class WorkflowService: config: WorkflowConfig, workspace_id: uuid.UUID, release_id: Optional[uuid.UUID] = None, + public: bool = False ): """运行工作流(流式) @@ -582,6 +617,7 @@ class WorkflowService: app_id: 应用 ID payload: 请求对象(包含 message, variables, conversation_id 等) config: 存储类型(可选) + public: 是否发布 Yields: SSE 格式的流式事件 @@ -597,24 +633,23 @@ class WorkflowService: code=BizCode.CONFIG_MISSING, message=f"工作流配置不存在: app_id={app_id}" ) - input_data = {"message": payload.message, "variables": payload.variables, - "conversation_id": payload.conversation_id} - # 转换 user_id 为 UUID - triggered_by_uuid = None - if payload.user_id: - try: - triggered_by_uuid = uuid.UUID(payload.user_id) - except (ValueError, AttributeError): - logger.warning(f"无效的 user_id 格式: {payload.user_id}") + files = [] + if payload.files: + for file in payload.files: + files.append( + { + "type": file.type, + "url": await self.multimodal_service.get_file_url(file), + "__file": True + } + ) + + input_data = {"message": payload.message, "variables": payload.variables, + "conversation_id": payload.conversation_id, "files": files} # 转换 conversation_id 为 UUID - conversation_id_uuid = None - if payload.conversation_id: - try: - conversation_id_uuid = uuid.UUID(payload.conversation_id) - except (ValueError, AttributeError): - logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}") + conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None # 2. 创建执行记录 execution = self.create_execution( @@ -661,7 +696,7 @@ class WorkflowService: input_data=input_data, execution_id=execution.execution_id, workspace_id=str(workspace_id), - user_id=payload.user_id + user_id=payload.user_id, ): if event.get("event") == "workflow_end": @@ -692,7 +727,9 @@ class WorkflowService: ) else: logger.error(f"unexpect workflow run status, status: {status}") - yield event + event = self._emit(public, event) + if event: + yield event except Exception as e: logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) @@ -710,134 +747,6 @@ class WorkflowService: } } - @deprecated(reason="This method is deprecated. " - "Please use WorkflowService.run / run_stream instead.") - async def run_workflow( - self, - app_id: uuid.UUID, - input_data: dict[str, Any], - triggered_by: uuid.UUID, - conversation_id: uuid.UUID | None = None, - stream: bool = False - ) -> AsyncGenerator | dict: - """运行工作流 - - Args: - app_id: 应用 ID - input_data: 输入数据(包含 message 和 variables) - triggered_by: 触发用户 ID - conversation_id: 会话 ID(可选) - stream: 是否流式返回 - - Returns: - 执行结果(非流式)或生成器(流式) - - Raises: - BusinessException: 配置不存在或执行失败时抛出 - """ - # 1. 获取工作流配置 - config = self.get_workflow_config(app_id) - if not config: - raise BusinessException( - code=BizCode.NOT_FOUND, - message=f"工作流配置不存在: app_id={app_id}" - ) - - # 2. 创建执行记录 - execution = self.create_execution( - workflow_config_id=config.id, - app_id=app_id, - trigger_type="manual", - triggered_by=triggered_by, - conversation_id=conversation_id, - input_data=input_data - ) - - # 3. 构建工作流配置字典 - workflow_config_dict = { - "nodes": config.nodes, - "edges": config.edges, - "variables": config.variables, - "execution_config": config.execution_config - } - - # 4. 获取工作空间 ID(从 app 获取) - from app.models import App - app = self.db.query(App).filter( - App.id == app_id, - App.is_active.is_(True) - ).first() - if not app: - raise BusinessException( - code=BizCode.NOT_FOUND, - message=f"应用不存在: app_id={app_id}" - ) - - # 5. 执行工作流 - from app.core.workflow.executor import execute_workflow - - try: - # 更新状态为运行中 - self.update_execution_status(execution.execution_id, "running") - - if stream: - # 流式执行 - return self._run_workflow_stream( - workflow_config_dict, - input_data, - execution.execution_id, - str(app.workspace_id), - str(triggered_by) - ) - else: - # 非流式执行 - result = await execute_workflow( - workflow_config=workflow_config_dict, - input_data=input_data, - execution_id=execution.execution_id, - workspace_id=str(app.workspace_id), - user_id=str(triggered_by) - ) - - # 更新执行结果 - if result.get("status") == "completed": - token_usage = result.get("data").get("token_usage", {}) or {} - self.update_execution_status( - execution.execution_id, - "completed", - output_data=result.get("node_outputs", {}), - token_usage=token_usage.get("total_tokens", None) - ) - else: - self.update_execution_status( - execution.execution_id, - "failed", - error_message=result.get("error") - ) - - # 返回增强的响应结构 - return { - "execution_id": execution.execution_id, - "status": result.get("status"), - "output": result.get("output"), # 最终输出(字符串) - "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) - "error_message": result.get("error"), - "elapsed_time": result.get("elapsed_time"), - "token_usage": result.get("token_usage") - } - - except Exception as e: - logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True) - self.update_execution_status( - execution.execution_id, - "failed", - error_message=str(e) - ) - raise BusinessException( - code=BizCode.INTERNAL_ERROR, - message=f"工作流执行失败: {str(e)}" - ) - def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]: """清理事件数据,移除不可序列化的对象 @@ -869,72 +778,6 @@ class WorkflowService: return clean_value(event) - @deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.") - async def _run_workflow_stream( - self, - workflow_config: dict[str, Any], - input_data: dict[str, Any], - execution_id: str, - workspace_id: str, - user_id: str): - """运行工作流(流式,内部方法) - - Args: - workflow_config: 工作流配置 - input_data: 输入数据 - execution_id: 执行 ID - workspace_id: 工作空间 ID - user_id: 用户 ID - - Yields: - 流式事件(格式:{"event": "", "data": {...}}) - """ - from app.core.workflow.executor import execute_workflow_stream - - try: - async for event in execute_workflow_stream( - workflow_config=workflow_config, - input_data=input_data, - execution_id=execution_id, - workspace_id=workspace_id, - user_id=user_id - ): - # 直接转发事件(executor 已经返回正确格式) - if event.get("event") == "workflow_end": - token_usage = event.get("data").get("token_usage", {}) or {} - status = event.get("data", {}).get("status") - if status == "completed": - self.update_execution_status( - execution_id, - "completed", - output_data=event.get("data"), - token_usage=token_usage.get("total_tokens", None) - ) - elif status == "failed": - self.update_execution_status( - execution_id, - "failed", - output_data=event.get("data") - ) - else: - logger.error(f"unexpect workflow run status, status: {status}") - yield event - - except Exception as e: - logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True) - self.update_execution_status( - execution_id, - "failed", - error_message=str(e) - ) - yield { - "event": "error", - "data": { - "execution_id": execution_id, - "error": str(e) - } - } - # ==================== 依赖注入函数 ==================== diff --git a/api/app/tasks.py b/api/app/tasks.py index 247cba76..a46a3a7b 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1069,6 +1069,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") except Exception as e: + db.rollback() # Rollback failed transaction to allow next query api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") all_reflection_results.append({ "workspace_id": str(workspace_id), @@ -1207,3 +1208,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di return result finally: loop.close() + + +# ============================================================================= +# Long-term Memory Storage Tasks (Batched Write Strategies) +# ============================================================================= + +@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True) +def long_term_storage_window_task( + self, + end_user_id: str, + langchain_messages: List[Dict[str, Any]], + config_id: str, + scope: int = 6 +) -> Dict[str, Any]: + """Celery task for window-based long-term memory storage. + + Accumulates messages in Redis buffer until window size (scope) is reached, + then writes batched messages to Neo4j. + + Args: + end_user_id: End user identifier + langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}] + config_id: Memory configuration ID + scope: Window size (number of messages before triggering write) + + Returns: + Dict containing task status and metadata + """ + from app.core.logging_config import get_logger + logger = get_logger(__name__) + + logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}") + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue + from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format + from app.core.memory.agent.utils.redis_tool import write_store + from app.services.memory_config_service import MemoryConfigService + + db = next(get_db()) + try: + # Save to Redis buffer first + write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) + + # Load memory config + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="LongTermStorageTask" + ) + + # Execute window-based dialogue storage + await window_dialogue(end_user_id, langchain_messages, memory_config, scope) + + return {"status": "SUCCESS", "strategy": "window", "scope": scope} + finally: + db.close() + + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete(_run()) + elapsed_time = time.time() - start_time + + logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s") + + return { + **result, + "end_user_id": end_user_id, + "config_id": config_id, + "elapsed_time": elapsed_time, + "task_id": self.request.id + } + except Exception as e: + elapsed_time = time.time() - start_time + logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True) + + return { + "status": "FAILURE", + "strategy": "window", + "error": str(e), + "end_user_id": end_user_id, + "config_id": config_id, + "elapsed_time": elapsed_time, + "task_id": self.request.id + } + + +# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True) +# def long_term_storage_time_task( +# self, +# end_user_id: str, +# config_id: str, +# time_window: int = 5 +# ) -> Dict[str, Any]: +# """Celery task for time-based long-term memory storage. + +# Retrieves recent sessions from Redis within time window and writes to Neo4j. + +# Args: +# end_user_id: End user identifier +# config_id: Memory configuration ID +# time_window: Time window in minutes for retrieving recent sessions + +# Returns: +# Dict containing task status and metadata +# """ +# from app.core.logging_config import get_logger +# logger = get_logger(__name__) + +# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}") +# start_time = time.time() + +# async def _run() -> Dict[str, Any]: +# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage +# from app.services.memory_config_service import MemoryConfigService + +# db = next(get_db()) +# try: +# # Load memory config +# config_service = MemoryConfigService(db) +# memory_config = config_service.load_memory_config( +# config_id=config_id, +# service_name="LongTermStorageTask" +# ) + +# # Execute time-based storage +# await memory_long_term_storage(end_user_id, memory_config, time_window) + +# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window} +# finally: +# db.close() + +# try: +# import nest_asyncio +# nest_asyncio.apply() +# except ImportError: +# pass + +# try: +# loop = asyncio.get_event_loop() +# if loop.is_closed(): +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# except RuntimeError: +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) + +# try: +# result = loop.run_until_complete(_run()) +# elapsed_time = time.time() - start_time + +# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s") + +# return { +# **result, +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } +# except Exception as e: +# elapsed_time = time.time() - start_time +# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True) + +# return { +# "status": "FAILURE", +# "strategy": "time", +# "error": str(e), +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } + + +# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True) +# def long_term_storage_aggregate_task( +# self, +# end_user_id: str, +# langchain_messages: List[Dict[str, Any]], +# config_id: str +# ) -> Dict[str, Any]: +# """Celery task for aggregate-based long-term memory storage. + +# Uses LLM to determine if new messages describe the same event as history. +# Only writes to Neo4j if messages represent new information (not duplicates). + +# Args: +# end_user_id: End user identifier +# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}] +# config_id: Memory configuration ID + +# Returns: +# Dict containing task status, is_same_event flag, and metadata +# """ +# from app.core.logging_config import get_logger +# logger = get_logger(__name__) + +# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}") +# start_time = time.time() + +# async def _run() -> Dict[str, Any]: +# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment +# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format +# from app.core.memory.agent.utils.redis_tool import write_store +# from app.services.memory_config_service import MemoryConfigService + +# db = next(get_db()) +# try: +# # Save to Redis buffer first +# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) + +# # Load memory config +# config_service = MemoryConfigService(db) +# memory_config = config_service.load_memory_config( +# config_id=config_id, +# service_name="LongTermStorageTask" +# ) + +# # Execute aggregate judgment +# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config) + +# return { +# "status": "SUCCESS", +# "strategy": "aggregate", +# "is_same_event": result.get("is_same_event", False), +# "wrote_to_neo4j": not result.get("is_same_event", False) +# } +# finally: +# db.close() + +# try: +# import nest_asyncio +# nest_asyncio.apply() +# except ImportError: +# pass + +# try: +# loop = asyncio.get_event_loop() +# if loop.is_closed(): +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# except RuntimeError: +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) + +# try: +# result = loop.run_until_complete(_run()) +# elapsed_time = time.time() - start_time + +# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s") + +# return { +# **result, +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } +# except Exception as e: +# elapsed_time = time.time() - start_time +# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True) + +# return { +# "status": "FAILURE", +# "strategy": "aggregate", +# "error": str(e), +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index 06549989..afa18417 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -99,7 +99,8 @@ def agent_config_4_app_release(release: AppRelease) -> AgentConfig: knowledge_retrieval=config_dict.get("knowledge_retrieval"), memory=config_dict.get("memory"), variables=config_dict.get("variables", []), - tools=config_dict.get("tools", {}), + tools=config_dict.get("tools", []), + skills=config_dict.get("skills", {}) ) return agent_config diff --git a/api/migrations/versions/9b28b66cf8e8_202602041811.py b/api/migrations/versions/9b28b66cf8e8_202602041811.py new file mode 100644 index 00000000..c35b753f --- /dev/null +++ b/api/migrations/versions/9b28b66cf8e8_202602041811.py @@ -0,0 +1,30 @@ +"""202602041811 + +Revision ID: 9b28b66cf8e8 +Revises: e7c7afa249d1 +Create Date: 2026-02-04 18:12:12.454402 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '9b28b66cf8e8' +down_revision: Union[str, None] = 'e7c7afa249d1' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('agent_configs', 'skill_ids', new_column_name='skills', comment='技能配置') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('agent_configs', 'skills', new_column_name='skill_ids', comment='关联的技能ID列表') + # ### end Alembic commands ### diff --git a/api/migrations/versions/e7c7afa249d1_202602041355.py b/api/migrations/versions/e7c7afa249d1_202602041355.py new file mode 100644 index 00000000..0559d5b4 --- /dev/null +++ b/api/migrations/versions/e7c7afa249d1_202602041355.py @@ -0,0 +1,50 @@ +"""202602041355 + +Revision ID: e7c7afa249d1 +Revises: 9def72f79398 +Create Date: 2026-02-04 13:55:22.284981 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'e7c7afa249d1' +down_revision: Union[str, None] = '9def72f79398' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('skills', + sa.Column('id', sa.UUID(), nullable=False), + sa.Column('name', sa.String(), nullable=False, comment='技能名称'), + sa.Column('description', sa.Text(), nullable=True, comment='技能描述'), + sa.Column('tenant_id', sa.UUID(), nullable=False, comment='租户ID'), + sa.Column('tools', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='关联的工具列表'), + sa.Column('config', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='技能配置'), + sa.Column('prompt', sa.Text(), nullable=True, comment='技能专属提示词'), + sa.Column('is_active', sa.Boolean(), nullable=False, comment='是否激活'), + sa.Column('is_public', sa.Boolean(), nullable=False, comment='是否公开到市场'), + sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=True, comment='更新时间'), + sa.ForeignKeyConstraint(['tenant_id'], ['tenants.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_skills_id'), 'skills', ['id'], unique=False) + op.create_index(op.f('ix_skills_tenant_id'), 'skills', ['tenant_id'], unique=False) + op.add_column('agent_configs', sa.Column('skill_ids', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='关联的技能ID列表')) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('agent_configs', 'skill_ids') + op.drop_index(op.f('ix_skills_tenant_id'), table_name='skills') + op.drop_index(op.f('ix_skills_id'), table_name='skills') + op.drop_table('skills') + # ### end Alembic commands ### diff --git a/sandbox/app/controllers/sandbox_controller.py b/sandbox/app/controllers/sandbox_controller.py index c5cce40c..f9bc3fc0 100644 --- a/sandbox/app/controllers/sandbox_controller.py +++ b/sandbox/app/controllers/sandbox_controller.py @@ -33,7 +33,7 @@ async def run_code(request: RunCodeRequest): """Execute code in sandbox""" if request.language == "python3": return await run_python_code(request.code, request.preload, request.options) - elif request.language == "nodejs": + elif request.language == "javascript": return await run_nodejs_code(request.code, request.preload, request.options) else: return error_response(-400, "unsupported language") diff --git a/web/src/App.tsx b/web/src/App.tsx index 032338a3..b3d0708c 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -37,7 +37,7 @@ function App() { const { checkJump } = useUser(); useEffect(() => { const authToken = cookieUtils.get('authToken') - if (!authToken && !window.location.hash.includes('#/login') && !window.location.hash.includes('#/conversation/')) { + if (!authToken && !window.location.hash.includes('#/login') && !window.location.hash.includes('#/conversation/') && !window.location.hash.includes('#/jump')) { window.location.href = `/#/login`; } else { checkJump() diff --git a/web/src/api/skill.ts b/web/src/api/skill.ts new file mode 100644 index 00000000..47e77d86 --- /dev/null +++ b/web/src/api/skill.ts @@ -0,0 +1,30 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-05 10:28:44 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-05 10:28:44 + */ +import { request } from '@/utils/request' +import type { SkillFormData } from '@/views/Skills/types' + +// Get skill list +export const getSkillListUrl = '/skills' +export const getSkillList = (data?: any) => { + return request.get(getSkillListUrl, data) +} +// Get skill details +export const getSkillDetail = (skill_id: string, data?: any) => { + return request.get(`/skills/${skill_id}`, data) +} +// Create skill +export const createSkill = (data: SkillFormData) => { + return request.post('/skills', data) +} +// Update skill +export const updateSkill = (skill_id: string, data: SkillFormData) => { + return request.put(`/skills/${skill_id}`, data) +} +// Delete skill +export const deleteSkill = (skill_id: string) => { + return request.delete(`/skills/${skill_id}`) +} \ No newline at end of file diff --git a/web/src/assets/images/menu/skills.svg b/web/src/assets/images/menu/skills.svg new file mode 100644 index 00000000..ac121d1e --- /dev/null +++ b/web/src/assets/images/menu/skills.svg @@ -0,0 +1,14 @@ + + + 技能点 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menu/skills_active.svg b/web/src/assets/images/menu/skills_active.svg new file mode 100644 index 00000000..789b5586 --- /dev/null +++ b/web/src/assets/images/menu/skills_active.svg @@ -0,0 +1,14 @@ + + + 技能点备份 + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/Layout/BasicAuthLayout.tsx b/web/src/components/Layout/BasicAuthLayout.tsx new file mode 100644 index 00000000..a73f6c69 --- /dev/null +++ b/web/src/components/Layout/BasicAuthLayout.tsx @@ -0,0 +1,45 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-02 15:12:42 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 14:06:28 + */ +/** + * BasicLayout Component + * + * A minimal layout wrapper that provides: + * - User information initialization + * - Storage type initialization + * - Simple container for child routes without navigation UI + * + * Used for pages that don't require sidebar/header (e.g., login, public pages). + * + * @component + */ + +import { Outlet } from 'react-router-dom'; +import { useEffect, type FC } from 'react'; + +import { useUser } from '@/store/user'; + +/** + * Basic layout component for pages without navigation UI. + * Fetches user info and storage type on mount, then renders child routes. + */ +const BasicLayout: FC = () => { + const { getUserInfo } = useUser(); + + // Fetch user information and storage type on component mount + useEffect(() => { + getUserInfo(); + }, [getUserInfo]); + + return ( +
+ {/* Render child routes without additional UI */} + +
+ ) +}; + +export default BasicLayout; \ No newline at end of file diff --git a/web/src/components/PageScrollList/index.tsx b/web/src/components/PageScrollList/index.tsx index 49173a68..a877a9c7 100644 --- a/web/src/components/PageScrollList/index.tsx +++ b/web/src/components/PageScrollList/index.tsx @@ -142,7 +142,7 @@ const PageScrollList = forwardRef(>({ dataLength={data.length} next={loadMoreData} hasMore={hasMore} - loader={needLoading ? : undefined} + loader={loading && needLoading ? : false} // endMessage={It is all, nothing more 🤐} scrollableTarget="scrollableDiv" className='rb:h-full!' diff --git a/web/src/components/RbCard/Card.tsx b/web/src/components/RbCard/Card.tsx index 85b569df..896dc201 100644 --- a/web/src/components/RbCard/Card.tsx +++ b/web/src/components/RbCard/Card.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-02 15:21:14 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-02 15:21:14 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 13:49:05 */ /** * RbCard Component @@ -18,7 +18,7 @@ */ import { type FC, type ReactNode } from 'react' -import { Card, Tooltip } from 'antd'; +import { Card, Tooltip, Flex } from 'antd'; import clsx from 'clsx'; /** Props interface for RbCard component */ @@ -51,6 +51,7 @@ interface RbCardProps { className?: string; /** Click handler */ onClick?: () => void; + variant?: 'borderL'; } /** Custom card component with flexible styling and header options */ @@ -68,6 +69,7 @@ const RbCard: FC = ({ bgColor = '#FBFDFF', height = 'auto', className, + variant, ...props }) => { /** Calculate body padding based on header type and avatar presence */ @@ -82,7 +84,45 @@ const RbCard: FC = ({ : (headerType === 'border' && !avatarUrl && !avatar) || headerType === 'borderBL' ? 'rb:p-[16px_16px_20px_16px]!' : '' - + + if (variant === 'borderL') { + return ( +
+ + +
+ {typeof title === 'function' ? title() : title ? +
+ {avatarUrl + ? + : avatar ? avatar : null + } +
+
{title}
+ {subTitle &&
{subTitle}
} +
+
: null + } +
+ {subTitle &&
{subTitle}
} +
+ {extra} +
+
+ {children} +
+
+ ) + } return ( = ({ }, headerClassName, ), - body: bodyClassNames ? bodyClassNames : children ? bodyClassName : 'rb:p-[0]!', + body: bodyClassNames ? bodyClassNames : children ? bodyClassName : 'rb:p-0!', }} style={{ background: bgColor, diff --git a/web/src/components/SiderMenu/index.tsx b/web/src/components/SiderMenu/index.tsx index 82ea8c6e..21202aa0 100644 --- a/web/src/components/SiderMenu/index.tsx +++ b/web/src/components/SiderMenu/index.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-02 15:25:31 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-02 15:25:31 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 13:49:16 */ /** * SiderMenu Component @@ -67,6 +67,8 @@ import ontologyIcon from '@/assets/images/menu/ontology.svg' import ontologyActiveIcon from '@/assets/images/menu/ontology_active.svg' import promptIcon from '@/assets/images/menu/prompt.svg' import promptActiveIcon from '@/assets/images/menu/prompt_active.svg' +import skillsIcon from '@/assets/images/menu/skills.svg' +import skillsActiveIcon from '@/assets/images/menu/skills_active.svg' /** Icon path mapping table for menu items (normal and active states) */ const iconPathMap: Record = { @@ -102,6 +104,8 @@ const iconPathMap: Record = { 'ontologyActive': ontologyActiveIcon, 'prompt': promptIcon, 'promptActive': promptActiveIcon, + 'skills': skillsIcon, + 'skillsActive': skillsActiveIcon, }; const { Sider } = Layout; diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index d3c788df..fe0fbc37 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -115,6 +115,7 @@ export const en = { spaceConfig: 'Space Configuration', ontology: 'Ontology Engineering', prompt: 'Prompt Engineering', + skills: 'Skill Library', }, dashboard: { total_models: 'Available Models', @@ -1248,6 +1249,22 @@ export const en = { daily_new_users: 'Daily New Users', daily_api_calls: 'Daily API Calls', daily_tokens: 'Token Consumption', + + skill: 'Skill Configuration', + skillTitle: 'Configure Agent skills and matching modes', + skillHelp: 'Help Center', + addSkill: 'Add Skill', + dynamicBindingSkill: 'Dynamic Optional Skills', + dynamicBindingSkill_subTitle: 'Skill pool that Agent can automatically match based on tasks', + dynamicBindingSkill_empty: 'No dynamic skills configured yet, click the button above to add or enable "Allow access to all skills"', + chooseSkill: 'Choose Skill', + allSkill: 'Allow access to all skills', + allSkillIntro: 'Access to all skills enabled, Agent will automatically match optimal skills based on tasks', + executeProcessPreview: 'Execution Process Preview', + receiveTask: 'Receive Task', + analyTask: 'Analyze Task Intent', + dynamicMatchSkill: 'Dynamic Match Skill', + executeTask: 'Execute Task', }, userMemory: { userMemory: 'User Memory', @@ -1882,6 +1899,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re enable_window: 'Memory Window', inner: 'Built-in', messagesPlaceholder: 'Write prompts here, type "{" to insert variables, type "insert" to insert', + vision: 'Vision', }, start: { variables: 'Input Fields', @@ -2154,6 +2172,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re resilience: 'Resilience', suggestions: 'Personalized Suggestions', suggestionLoading: 'Your personalized suggestions are being generated', + item: 'item', }, reflectionEngine: { reflectionEngineConfig: 'Reflection Engine Configuration', @@ -2493,5 +2512,28 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re initialInput: 'Original Input', saveTitle: 'Title', }, + skills: { + searchPlaceholder: 'Search skills', + create: 'Add Skill', + mainfest: 'Define Encapsulation Container', + name: 'Skill Name', + description: 'Brief Description', + descriptionPlaceholder: 'Describe the intent and purpose of the skill...', + keywords: 'Keywords', + promptConfiguration: 'Inject Experience Logic', + aiPrompt: 'AI Experience Refinement', + prompt_type: 'System Instructions / Expert Knowledge', + promptPlaceholder: 'Enter system instructions or expert knowledge...', + save: 'Save', + AIPromptAssistant: 'AI Experience Refinement', + model: 'Model', + promptChatEmpty: 'No conversation content available', + you: 'You', + ai: 'AI Assistant', + promptChatPlaceholder: 'Describe your requirements...', + conversationOptimizationPrompt: 'Refined Content', + apply: 'Apply', + tools: 'Tools', + }, }, }; diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 4e1e52e4..7fc8b652 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -114,6 +114,7 @@ export const zh = { spaceConfig: '空间配置', ontology: '本体工程', prompt: '提示词工程', + skills: '技能库', }, knowledgeBase: { home: '首页', @@ -667,6 +668,22 @@ export const zh = { daily_new_users: '新增用户数', daily_api_calls: '调用次数', daily_tokens: 'Token消耗', + + skill: '技能配置', + skillTitle: '配置 Agent 可使用的技能及匹配模式', + skillHelp: '帮助中心', + addSkill: '添加技能', + dynamicBindingSkill: '动态可选技能', + dynamicBindingSkill_subTitle: 'Agent 可根据任务自动匹配的技能池', + dynamicBindingSkill_empty: '暂未配置动态技能,点击上方按钮添加或开启"允许访问所有技能"', + chooseSkill: '选择技能', + allSkill: '允许访问所有技能', + allSkillIntro: '已开启访问所有技能,Agent 将根据任务自动匹配最优技能', + executeProcessPreview: '执行流程预览', + receiveTask: '收到任务', + analyTask: '分析任务意图', + dynamicMatchSkill: '动态匹配技能', + executeTask: '执行任务', }, role: { roleManagement: '角色管理', @@ -1970,6 +1987,7 @@ export const zh = { enable_window: '记忆窗口', inner: '内置', messagesPlaceholder: '在此处编写提示,输入“{”插入变量,输入“insert”插入', + vision: '视觉', }, start: { variables: '输入字段', @@ -2243,6 +2261,7 @@ export const zh = { resilience: '恢复力', suggestions: '个性化建议', suggestionLoading: '您的个性化建议正在生成中', + item: '个', }, reflectionEngine: { reflectionEngineConfig: '反思引擎配置', @@ -2582,5 +2601,28 @@ export const zh = { initialInput: '原始输入', saveTitle: '标题', }, + skills: { + searchPlaceholder: '搜索技能', + create: '添加技能', + mainfest: '定义封装容器', + name: '技能名称', + description: '简要描述', + descriptionPlaceholder: '描述技能的意图和用途...', + keywords: '关键词', + promptConfiguration: '注入经验逻辑', + aiPrompt: 'AI 经验提炼', + prompt_type: '系统指令 / 专家知识', + promptPlaceholder: '输入系统指令或专家知识...', + save: '保存', + AIPromptAssistant: 'AI 经验提炼', + model: '模型', + promptChatEmpty: '目前没有对话内容', + you: '你', + ai: 'AI 助手', + promptChatPlaceholder: '描述你的需求...', + conversationOptimizationPrompt: '提炼内容', + apply: '应用', + tools: '工具', + }, }, } \ No newline at end of file diff --git a/web/src/routes/index.tsx b/web/src/routes/index.tsx index 09479f59..42e0106a 100644 --- a/web/src/routes/index.tsx +++ b/web/src/routes/index.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-02 16:33:11 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-02 16:33:11 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 18:11:34 */ /** * Route Configuration @@ -23,7 +23,6 @@ import { createHashRouter, createRoutesFromElements, Route } from 'react-router- /** Import route configuration JSON */ import routesConfig from './routes.json'; - /** Recursively collect all element names from routes */ function collectElements(routes: RouteConfig[]): Set { const elements = new Set(); @@ -52,6 +51,7 @@ const componentMap: Record>> = BasicLayout: lazy(() => import('@/components/Layout/BasicLayout')), LoginLayout: lazy(() => import('@/components/Layout/LoginLayout')), NoAuthLayout: lazy(() => import('@/components/Layout/NoAuthLayout')), + BasicAuthLayout: lazy(() => import('@/components/Layout/BasicAuthLayout')), /** View components */ Index: lazy(() => import('@/views/Index')), Home: lazy(() => import('@/views/Home')), @@ -88,6 +88,9 @@ const componentMap: Record>> = Ontology: lazy(() => import('@/views/Ontology')), OntologyDetail: lazy(() => import('@/views/Ontology/pages/Detail')), Prompt: lazy(() => import('@/views/Prompt')), + Skills: lazy(() => import('@/views/Skills')), + SkillConfig: lazy(() => import('@/views/Skills/pages/SkillConfig')), + Jump: lazy(() => import('@/views/JumpPage')), Login: lazy(() => import('@/views/Login')), InviteRegister: lazy(() => import('@/views/InviteRegister')), NoPermission: lazy(() => import('@/views/NoPermission')), diff --git a/web/src/routes/routes.json b/web/src/routes/routes.json index b02ebddf..ea137bd4 100644 --- a/web/src/routes/routes.json +++ b/web/src/routes/routes.json @@ -10,6 +10,7 @@ { "path": "/pricing", "element": "Pricing" }, { "path": "/order-pay", "element": "OrderPayment" }, { "path": "/orders", "element": "OrderHistory" }, + { "path": "/skills", "element": "Skills" }, { "path": "/no-permission", "element": "NoPermission" } ] }, @@ -50,10 +51,18 @@ { "path": "/ontology/:id", "element": "OntologyDetail" } ] }, + { + "element": "BasicAuthLayout", + "children": [ + { "path": "/skills/add", "element": "SkillConfig" }, + { "path": "/skills/config/:id", "element": "SkillConfig" } + ] + }, { "element": "NoAuthLayout", "children": [ - { "path": "/conversation/:token", "element": "Conversation" } + { "path": "/conversation/:token", "element": "Conversation" }, + { "path": "/jump", "element": "Jump" } ] }, { diff --git a/web/src/store/menu.json b/web/src/store/menu.json index d264e061..4f53ab50 100644 --- a/web/src/store/menu.json +++ b/web/src/store/menu.json @@ -52,6 +52,21 @@ "sort": 0, "subs": [] }, + { + "id": 8, + "parent": 0, + "code": "skills", + "label": "技能库", + "i18nKey": "menu.skills", + "path": "/skills", + "enable": true, + "display": true, + "level": 1, + "sort": 0, + "icon": null, + "iconActive": null, + "subs": null + }, { "id": 6, "parent": 0, diff --git a/web/src/store/user.ts b/web/src/store/user.ts index 505cb768..c9231d9c 100644 --- a/web/src/store/user.ts +++ b/web/src/store/user.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-02 16:33:54 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-02 16:33:54 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 18:30:10 */ /** * User Store @@ -59,7 +59,8 @@ export interface UserState { export const whitePage = [ '/conversation', '/login', - '/invite-register' + '/invite-register', + 'jump' ] /** User store */ diff --git a/web/src/styles/index.css b/web/src/styles/index.css index 53670dab..bbbe9cd9 100644 --- a/web/src/styles/index.css +++ b/web/src/styles/index.css @@ -180,7 +180,4 @@ body { .x6-node foreignObject > body { min-height: 100%; max-height: 100%; -} -#scrollableDiv .infinite-scroll-component__outerdiv { - height: 100%; } \ No newline at end of file diff --git a/web/src/views/ApiKeyManagement/components/ApiKeyDetailModal.tsx b/web/src/views/ApiKeyManagement/components/ApiKeyDetailModal.tsx index 2899a306..f9e1df51 100644 --- a/web/src/views/ApiKeyManagement/components/ApiKeyDetailModal.tsx +++ b/web/src/views/ApiKeyManagement/components/ApiKeyDetailModal.tsx @@ -1,7 +1,14 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 15:52:44 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 10:00:02 + */ import { forwardRef, useImperativeHandle, useState } from 'react'; import { Switch, Button, Tooltip } from 'antd'; import clsx from 'clsx'; import { useTranslation } from 'react-i18next'; + import type { ApiKey, ApiKeyModalRef } from '../types'; import RbModal from '@/components/RbModal' import { getApiKey } from '@/api/apiKey'; @@ -9,16 +16,29 @@ import { formatDateTime } from '@/utils/format' import Tag from '@/components/Tag' import { maskApiKeys } from '@/utils/apiKeyReplacer'; +/** + * Modal component for viewing API key details + * Displays read-only information about an API key + */ const ApiKeyDetailModal = forwardRef void }>(({ handleCopy }, ref) => { + // Hooks const { t } = useTranslation(); + + // State const [visible, setVisible] = useState(false); const [data, setData] = useState({} as ApiKey) - // 封装取消方法,添加关闭弹窗逻辑 + /** + * Close the modal + */ const handleClose = () => { setVisible(false); }; + /** + * Open modal and fetch API key details + * @param apiKey - API key item to view + */ const handleOpen = (apiKey?: ApiKey) => { if (apiKey?.id) { getApiKey(apiKey.id) @@ -29,7 +49,9 @@ const ApiKeyDetailModal = forwardRef ({ handleOpen, handleClose @@ -84,7 +106,6 @@ const ApiKeyDetailModal = forwardRef - {/* 高级设置 */} {data.expires_at && <>
{t('apiKey.advancedSettings')}
diff --git a/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx b/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx index f0bf4e11..9395df43 100644 --- a/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx +++ b/web/src/views/ApiKeyManagement/components/ApiKeyModal.tsx @@ -1,28 +1,48 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 15:52:47 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 10:00:01 + */ import { forwardRef, useImperativeHandle, useState } from 'react'; import { Form, Input, Switch, App, DatePicker } from 'antd'; import { useTranslation } from 'react-i18next'; +import dayjs from 'dayjs' + import type { ApiKey, ApiKeyModalRef } from '../types'; import RbModal from '@/components/RbModal' -import dayjs from 'dayjs' import { createApiKey, updateApiKey } from '@/api/apiKey'; const FormItem = Form.Item; +/** + * Props for ApiKeyModal component + */ interface CreateModalProps { + /** Callback to refresh parent list after save */ refresh: () => void; } +/** + * Modal component for creating or editing API keys + * Handles API key configuration including permissions and expiration + */ const ApiKeyModal = forwardRef(({ refresh, }, ref) => { + // Hooks const { t } = useTranslation(); const { message } = App.useApp(); - const [visible, setVisible] = useState(false); const [form] = Form.useForm(); + + // State + const [visible, setVisible] = useState(false); const [loading, setLoading] = useState(false); const [editVo, setEditVo] = useState(null); - // 封装取消方法,添加关闭弹窗逻辑 + /** + * Close modal and reset form state + */ const handleClose = () => { setVisible(false); form.resetFields(); @@ -30,10 +50,14 @@ const ApiKeyModal = forwardRef(({ setEditVo(null); }; + /** + * Open modal for creating or editing + * @param apiKey - Optional API key data for edit mode + */ const handleOpen = (apiKey?: ApiKey) => { if (apiKey?.id) { const { scopes = [], expires_at } = apiKey - // 编辑模式,填充表单 + // Edit mode - populate form with existing data form.setFieldsValue({ name: apiKey.name, description: apiKey.description, @@ -46,7 +70,10 @@ const ApiKeyModal = forwardRef(({ setVisible(true); }; - // 封装保存方法,添加提交逻辑 + /** + * Validate and submit form data + * Creates new API key or updates existing one + */ const handleSave = async () => { form.validateFields() .then((values) => { @@ -59,7 +86,7 @@ const ApiKeyModal = forwardRef(({ if (rag) { scopes.push('rag') } - // 准备新的/更新的API Key数据 + // Prepare new/updated API key data const apiKeyData = { ...rest, scopes, @@ -78,7 +105,9 @@ const ApiKeyModal = forwardRef(({ }) } - // 暴露给父组件的方法 + /** + * Expose methods to parent component via ref + */ useImperativeHandle(ref, () => ({ handleOpen, handleClose @@ -133,7 +162,6 @@ const ApiKeyModal = forwardRef(({ - {/* 高级设置 */}
{t('apiKey.advancedSettings')}
{ + // Hooks const { t } = useTranslation(); const { modal, message } = App.useApp(); + + // Refs const apiKeyModalRef = useRef(null); const apiKeyDetailModalRef = useRef(null) const scrollListRef = useRef(null) + /** + * Refresh the API key list + */ const refresh = () => { scrollListRef.current?.refresh(); } + /** + * Open modal to create or edit API key + * @param item - Optional API key item for edit mode + */ const handleEdit = (item?: ApiKey) => { apiKeyModalRef.current?.handleOpen(item); } + + /** + * Open modal to view API key details + * @param item - API key item to view + */ const handleView = (item: ApiKey) => { apiKeyDetailModalRef.current?.handleOpen(item); } + + /** + * Delete API key with confirmation + * @param item - API key item to delete + */ const handleDelete = (item: ApiKey) => { modal.confirm({ title: t('common.confirmDeleteDesc', { name: item.name }), @@ -46,6 +77,10 @@ const ApiKeyManagement: React.FC = () => { } }) } + /** + * Copy content to clipboard + * @param content - Content to copy + */ const handleCopy = (content: string) => { copy(content) message.success(t('common.copySuccess')) diff --git a/web/src/views/ApiKeyManagement/types.ts b/web/src/views/ApiKeyManagement/types.ts index 2df67193..4ea1de0d 100644 --- a/web/src/views/ApiKeyManagement/types.ts +++ b/web/src/views/ApiKeyManagement/types.ts @@ -1,39 +1,76 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 15:52:53 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 15:52:53 + */ import type { Dayjs } from 'dayjs' import { maskApiKeys } from '@/utils/apiKeyReplacer' +/** + * API Key data structure + */ export interface ApiKey { + /** Unique identifier */ id: string; + /** API key name */ name: string; + /** Optional description */ description?: string; + /** API key type */ type: 'agent' | 'multi_agent' | 'workflow' | 'service'; - scopes?: string[]; // 'memory' | 'rag' | 'app' + /** Permission scopes: 'memory' | 'rag' | 'app' */ + scopes?: string[]; + /** The actual API key string */ api_key: string; + /** Whether the key is active */ is_active: boolean; + /** Whether the key has expired */ is_expired: boolean; + /** Creation timestamp */ created_at: number; + /** Expiration timestamp or Dayjs object */ expires_at?: number | Dayjs; + /** Memory engine permission flag */ memory?: boolean; + /** RAG/Knowledge base permission flag */ rag?: boolean; - + /** Last update timestamp */ updated_at: string; + /** Queries per second limit */ qps_limit?: number; + /** Daily request limit */ daily_request_limit?: number; + /** Rate limit */ rate_limit?: number; + /** Total number of requests made */ total_requests: number; + /** Quota used */ quota_used: number; + /** Quota limit */ quota_limit: number; } +/** + * Ref methods exposed by API Key modal components + */ export interface ApiKeyModalRef { + /** + * Open the modal + * @param apiKey - Optional API key data for edit mode + */ handleOpen: (apiKey?: ApiKey) => void; + /** Close the modal */ handleClose: () => void; } /** - * 获取掩码后的API密钥 + * Get masked API key for display + * @param apiKey - The API key to mask + * @returns Masked API key string */ export const getMaskedApiKey = (apiKey: string): string => { return maskApiKeys(apiKey) diff --git a/web/src/views/ApiParameters/index.tsx b/web/src/views/ApiParameters/index.tsx deleted file mode 100644 index b5ad6b4a..00000000 --- a/web/src/views/ApiParameters/index.tsx +++ /dev/null @@ -1,75 +0,0 @@ -import { useTranslation } from 'react-i18next'; -import { type FC, useEffect, useState } from 'react'; -import { Row, Col, Skeleton } from 'antd' -import CodeBlock from '@/components/Markdown/CodeBlock'; -import { getMemoryApi } from '@/api/memory'; -import RbCard from '@/components/RbCard/Card'; -import type { - Data, - Section -} from './types'; -import Empty from '@/components/Empty' - - -const ApiParameters: FC = () => { - const { t } = useTranslation(); - const [loading, setLoading] = useState(false) - // const [data, setData] = useState(null) - const [apiData, setApiData] = useState([]) - - useEffect(() => { - getApiData() - }, []) - const getApiData = () => { - setLoading(true) - getMemoryApi().then((res) => { - const resp = res as Data || {} - // setData(resp) - setApiData(resp.sections || []) - }) - .finally(() => setLoading(false)) - } - - return ( -
-

{t('api.pageTitle')}

-

{t('api.pageSubTitle')}

- - {loading - ? - : apiData.length === 0 - ? - : - {apiData.map((api, index) => ( - - - <> -
- {api.method} - {api.path} -
- {api.desc &&<> -
{t('api.desc')}
-
{api.desc}
- } - - {typeof api.input === 'string' && api.input !== '无' && <> -
{t('api.input')}
- - } - {typeof api.output === 'string' && api.output !== '无' && <> -
{t('api.output')}
- - } - -
- - ))} -
- } -
- ); -}; -export default ApiParameters; \ No newline at end of file diff --git a/web/src/views/ApiParameters/types.ts b/web/src/views/ApiParameters/types.ts deleted file mode 100644 index 56d516be..00000000 --- a/web/src/views/ApiParameters/types.ts +++ /dev/null @@ -1,22 +0,0 @@ -export interface Section { - name: string; - path: string; - method: string; - input: string; - output: string; - desc: string; -} -export interface Data { - title: string; - meta: { - search_switch: { - value: string; - desc: string; - }[]; - status_code: { - code: string; - desc: string; - }[]; - } - sections: Section[] -} \ No newline at end of file diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 0e9e8b44..0bfd4ba7 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -1,8 +1,15 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:29:21 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 20:16:45 + */ import { type FC, type ReactNode, useEffect, useRef, useState, forwardRef, useImperativeHandle } from 'react'; import clsx from 'clsx' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom'; import { Row, Col, Space, Form, Input, Switch, Button, App, Spin } from 'antd' + import Chat from './components/Chat' import RbCard from '@/components/RbCard/Card' import Card from './components/Card' @@ -32,7 +39,14 @@ import aiPrompt from '@/assets/images/application/aiPrompt.png' import AiPromptModal from './components/AiPromptModal' import ToolList from './components/ToolList/ToolList' import ChatVariableConfigModal from './components/ChatVariableConfigModal'; +import SkillList from './components/Skill' +import type { Skill } from '@/views/Skills/types' +/** + * Description wrapper component + * @param desc - Description text + * @param className - Additional CSS classes + */ const DescWrapper: FC<{desc: string, className?: string}> = ({desc, className}) => { return (
@@ -40,6 +54,12 @@ const DescWrapper: FC<{desc: string, className?: string}> = ({desc, className})
) } +/** + * Label wrapper component + * @param title - Label title + * @param className - Additional CSS classes + * @param children - Child elements + */ const LabelWrapper: FC<{title: string, className?: string; children?: ReactNode}> = ({title, className, children}) => { return (
@@ -48,6 +68,13 @@ const LabelWrapper: FC<{title: string, className?: string; children?: ReactNode}
) } +/** + * Switch wrapper component with label and description + * @param title - Switch title + * @param desc - Optional description + * @param name - Form field name + * @param needTransition - Whether to translate text + */ const SwitchWrapper: FC<{ title: string, desc?: string, name: string | string[]; needTransition?: boolean; }> = ({ title, desc, name, needTransition = true }) => { const { t } = useTranslation(); return ( @@ -65,6 +92,13 @@ const SwitchWrapper: FC<{ title: string, desc?: string, name: string | string[]; ) } +/** + * Select wrapper component with label and description + * @param title - Select title + * @param desc - Description text + * @param name - Form field name + * @param url - API URL for options + */ const SelectWrapper: FC<{ title: string, desc: string, name: string | string[], url: string }> = ({ title, desc, name, url }) => { const { t } = useTranslation(); return ( @@ -88,6 +122,10 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string | string[], ) } +/** + * Agent configuration component + * Manages single agent configuration including prompts, knowledge, memory, variables, and tools + */ const Agent = forwardRef((_props, ref) => { const { t } = useTranslation() const { id } = useParams(); @@ -103,7 +141,7 @@ const Agent = forwardRef((_props, ref) => { const [isSave, setIsSave] = useState(false) const initialized = useRef(false) - // 初始化完成标记 + // Initialization flag useEffect(() => { if (data) { initialized.current = true @@ -121,10 +159,15 @@ const Agent = forwardRef((_props, ref) => { getData() }, []) + /** + * Fetch agent configuration data + */ const getData = () => { setLoading(true) getApplicationConfig(id as string).then(res => { const response = res as Config + const { skills } = response + let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : [] let allTools = Array.isArray(response.tools) ? response.tools : [] const memoryContent = response.memory?.memory_content const parsedMemoryContent = memoryContent === null || memoryContent === '' @@ -136,6 +179,10 @@ const Agent = forwardRef((_props, ref) => { memory: { ...response.memory, memory_content: parsedMemoryContent + }, + skills: { + ...skills, + skill_ids: allSkills } }) setData({ @@ -147,6 +194,11 @@ const Agent = forwardRef((_props, ref) => { }) } + /** + * Refresh configuration after model changes + * @param vo - Model configuration + * @param type - Source type (model or chat) + */ const refresh = (vo: ModelConfig, type: Source) => { if (type === 'model') { const { default_model_config_id, ...rest } = vo @@ -188,20 +240,30 @@ const Agent = forwardRef((_props, ref) => { } } + /** + * Open model configuration modal + */ const handleModelConfig = () => { modelConfigModalRef.current?.handleOpen('model') } + /** + * Clear all debugging chat sessions + */ const handleClearDebugging = () => { setChatList([]) } - // 保存Agent配置 + /** + * Save agent configuration + * @param flag - Whether to show success message + * @returns Promise that resolves when save is complete + */ const handleSave = (flag = true) => { if (!isSave || !data) return Promise.resolve() - const { memory, knowledge_retrieval, tools, ...rest } = values + const { memory, knowledge_retrieval, tools, skills, ...rest } = values const { knowledge_bases = [], ...knowledgeRest } = knowledge_retrieval || {} const { memory_content } = memory || {} - // 从原数据中获取memory的其他必要属性 + // Get other necessary properties of memory from original data const originalMemory = data.memory || ({} as MemoryConfig) const params: Config = { @@ -224,7 +286,11 @@ const Agent = forwardRef((_props, ref) => { tool_id: vo.tool_id, operation: vo.operation, enabled: vo.enabled - })) + })), + skills: { + ...skills, + skill_ids: (skills?.skill_ids as Skill[])?.map(vo => vo.id) + } } return new Promise((resolve, reject) => { @@ -240,6 +306,9 @@ const Agent = forwardRef((_props, ref) => { }) }) } + /** + * Fetch available models list + */ const getModels = () => { getModelList({ type: 'llm,chat', pagesize: 100, page: 1, is_active: true }) .then(res => { @@ -247,6 +316,9 @@ const Agent = forwardRef((_props, ref) => { setModelList(response.items) }) } + /** + * Add new model for debugging + */ const handleAddModel = () => { modelConfigModalRef.current?.handleOpen('chat') } @@ -268,9 +340,16 @@ const Agent = forwardRef((_props, ref) => { })) const aiPromptModalRef = useRef(null) + /** + * Open AI prompt generation modal + */ const handlePrompt = () => { aiPromptModalRef.current?.handleOpen() } + /** + * Update prompt and extract variables + * @param value - New prompt value + */ const updatePrompt = (value: string) => { form.setFieldValue('system_prompt', value) const variables = value.match(/\{\{([^}]+)\}\}/g)?.map(match => match.slice(2, -2)) || [] @@ -285,15 +364,26 @@ const Agent = forwardRef((_props, ref) => { updateVariableList(newVariableList) } + /** + * Update variable list + * @param list - New variable list + */ const updateVariableList = (list: Variable[]) => { form.setFieldValue('variables', [...list]) setChatVariables([...list]) } const chatVariableConfigModalRef = useRef(null) const [chatVariables, setChatVariables] = useState([]) + /** + * Open chat variable configuration modal + */ const handleOpenVariableConfig = () => { chatVariableConfigModalRef.current?.handleOpen(chatVariables) } + /** + * Save chat variable configuration + * @param values - Variable values + */ const handleSaveChatVariable = (values: Variable[]) => { setChatVariables(values) } @@ -347,7 +437,7 @@ const Agent = forwardRef((_props, ref) => { - {/* 记忆配置 */} + {/* Memory Configuration */} @@ -360,12 +450,16 @@ const Agent = forwardRef((_props, ref) => { - + - - {/* 工具配置 */} - + + + + + + {/* Tool Configuration */} + diff --git a/web/src/views/ApplicationConfig/Api.tsx b/web/src/views/ApplicationConfig/Api.tsx index ab33ba19..c4b0fefb 100644 --- a/web/src/views/ApplicationConfig/Api.tsx +++ b/web/src/views/ApplicationConfig/Api.tsx @@ -1,3 +1,9 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:29:29 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:29:29 + */ import { type FC, useState, useRef, useEffect } from 'react'; import clsx from 'clsx'; import { useTranslation } from 'react-i18next'; @@ -14,6 +20,11 @@ import Tag from '@/components/Tag' import { getApiKeyList, getApiKeyStats, deleteApiKey } from '@/api/apiKey'; import { maskApiKeys } from '@/utils/apiKeyReplacer' +/** + * API configuration page component + * Manages API endpoints and API keys for the application + * @param application - Current application data + */ const Api: FC<{ application: Application | null }> = ({ application }) => { const { t } = useTranslation(); const activeMethods = ['POST']; @@ -23,6 +34,10 @@ const Api: FC<{ application: Application | null }> = ({ application }) => { const apiKeyConfigModalRef = useRef(null); const [apiKeyList, setApiKeyList] = useState([]) + /** + * Copy content to clipboard + * @param content - Content to copy + */ const handleCopy = (content: string) => { copy(content) message.success(t('common.copySuccess')) @@ -31,6 +46,9 @@ const Api: FC<{ application: Application | null }> = ({ application }) => { useEffect(() => { getApiList() }, []) + /** + * Fetch API key list for the application + */ const getApiList = () => { if (!application) { return @@ -48,6 +66,10 @@ const Api: FC<{ application: Application | null }> = ({ application }) => { getAllStats([...list]) }) } + /** + * Fetch statistics for all API keys + * @param list - List of API keys + */ const getAllStats = (list: ApiKey[]) => { const allList: ApiKey[] = [] list.forEach(async item => { @@ -66,12 +88,23 @@ const Api: FC<{ application: Application | null }> = ({ application }) => { }) } + /** + * Open modal to add new API key + */ const handleAdd = () => { apiKeyModalRef.current?.handleOpen() } + /** + * Open modal to edit API key + * @param vo - API key to edit + */ const handleEdit = (vo: ApiKey) => { apiKeyConfigModalRef.current?.handleOpen(vo) } + /** + * Delete API key with confirmation + * @param vo - API key to delete + */ const handleDelete = (vo: ApiKey) => { modal.confirm({ title: t('common.confirmDeleteDesc', { name: vo.name }), @@ -89,7 +122,7 @@ const Api: FC<{ application: Application | null }> = ({ application }) => { }) } - // 计算total_requests总数 + // Calculate total requests across all API keys const totalRequests = apiKeyList.reduce((total, item) => total + item.total_requests, 0); return (
@@ -129,7 +162,7 @@ const Api: FC<{ application: Application | null }> = ({ application }) => { } >
{t('application.apiKeySubTitle')}
- {/* 总览数据 */} + {/* Overview Data */} @@ -138,7 +171,7 @@ const Api: FC<{ application: Application | null }> = ({ application }) => { - {/* API Key 列表 */} + {/* API Key List */} {apiKeyList.sort((a, b) => b.created_at - a.created_at).map(item => (
diff --git a/web/src/views/ApplicationConfig/Cluster.tsx b/web/src/views/ApplicationConfig/Cluster.tsx index aa4a5d98..2688eaae 100644 --- a/web/src/views/ApplicationConfig/Cluster.tsx +++ b/web/src/views/ApplicationConfig/Cluster.tsx @@ -1,8 +1,15 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:29:33 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:29:33 + */ import { useEffect, useState, useRef, forwardRef, useImperativeHandle } from 'react' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom'; -import Card from './components/Card' import { Form, Space, Row, Col, Button, Flex, App, Select } from 'antd' + +import Card from './components/Card' import Tag, { type TagProps } from './components/Tag' import CustomSelect from '@/components/CustomSelect'; import { getMultiAgentConfig, saveMultiAgentConfig, getApplicationList } from '@/api/application'; @@ -26,6 +33,10 @@ import type { Application } from '@/views/ApplicationManagement/types' const tagColors = ['processing', 'warning', 'default'] const MAX_LENGTH = 5; +/** + * Multi-agent cluster configuration component + * Manages multi-agent orchestration, sub-agents, and collaboration modes + */ const Cluster = forwardRef((_props, ref) => { const { t } = useTranslation() const { message } = App.useApp() @@ -41,6 +52,11 @@ const Cluster = forwardRef((_props, ref) => { }, ]) + /** + * Save cluster configuration + * @param flag - Whether to show success message + * @returns Promise that resolves when save is complete + */ const handleSave = (flag = true) => { if (!data) return Promise.resolve() if (!values.default_model_config_id && values.orchestration_mode === 'supervisor') { @@ -80,6 +96,9 @@ const Cluster = forwardRef((_props, ref) => { getData() }, [id]) + /** + * Fetch cluster configuration data + */ const getData = () => { if (!id) { return @@ -113,9 +132,17 @@ const Cluster = forwardRef((_props, ref) => { } }) } + /** + * Open sub-agent modal for add or edit + * @param agent - Optional agent data for edit mode + */ const handleSubAgentModal = (agent?: SubAgentItem) => { subAgentModalRef.current?.handleOpen(agent) } + /** + * Refresh sub-agents list after add or edit + * @param agent - Agent data to add or update + */ const refreshSubAgents = (agent: SubAgentItem) => { const index = subAgents.findIndex(item => item.agent_id === agent.agent_id) const newSubAgents = [...subAgents] @@ -130,6 +157,10 @@ const Cluster = forwardRef((_props, ref) => { setSubAgents(newSubAgents) } } + /** + * Delete sub-agent from list + * @param agent - Agent to delete + */ const handleDeleteSubAgent = (agent: SubAgentItem) => { setSubAgents(prev => prev.filter(item => item.agent_id !== agent.agent_id)) } @@ -138,9 +169,16 @@ const Cluster = forwardRef((_props, ref) => { })) const modelConfigModalRef = useRef(null) + /** + * Open model configuration modal + */ const handleEditModelConfig = () => { modelConfigModalRef.current?.handleOpen('multi_agent', values.model_parameters) } + /** + * Save model configuration + * @param values - Model parameters + */ const handleSaveModelConfig = (values: Config['model_parameters']) => { form.setFieldsValue({ model_parameters: values diff --git a/web/src/views/ApplicationConfig/ReleasePage.tsx b/web/src/views/ApplicationConfig/ReleasePage.tsx index ae550d36..63f6df71 100644 --- a/web/src/views/ApplicationConfig/ReleasePage.tsx +++ b/web/src/views/ApplicationConfig/ReleasePage.tsx @@ -1,7 +1,14 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:29:41 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:29:41 + */ import { type FC, useState, useEffect, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import clsx from 'clsx'; import { Button, Space, Input, Form, App } from 'antd'; + import Tag, { type TagProps } from './components/Tag' import RbCard from '@/components/RbCard/Card' import { getReleaseList, rollbackRelease } from '@/api/application' @@ -12,12 +19,21 @@ import type { Application } from '@/views/ApplicationManagement/types' import Empty from '@/components/Empty' import { formatDateTime } from '@/utils/format'; import Markdown from '@/components/Markdown' +/** + * Tag color mapping for release versions + */ const tagColors: Record = { current: 'processing', rolledBack: 'warning', history: 'default', } +/** + * Release page component + * Manages application version releases, rollbacks, and version history + * @param data - Application data + * @param refresh - Function to refresh application data + */ const ReleasePage: FC<{data: Application; refresh: () => void}> = ({data, refresh}) => { const { t } = useTranslation(); const { message } = App.useApp() @@ -30,6 +46,9 @@ const ReleasePage: FC<{data: Application; refresh: () => void}> = ({data, refres getData() }, [data.id]) + /** + * Fetch release list data + */ const getData = () => { refresh() getReleaseList(data.id).then(res => { @@ -38,6 +57,9 @@ const ReleasePage: FC<{data: Application; refresh: () => void}> = ({data, refres setSelectedVersion(response?.[0]) }) } + /** + * Rollback to selected version + */ const handleRollback = () => { if (!selectedVersion) return rollbackRelease(data.id, selectedVersion.version).then(() => { @@ -124,7 +146,7 @@ const ReleasePage: FC<{data: Application; refresh: () => void}> = ({data, refres
- {/* 日志 */} + {/* Logs */} {selectedVersion && ( diff --git a/web/src/views/ApplicationConfig/Statistics.tsx b/web/src/views/ApplicationConfig/Statistics.tsx index 8a76ab06..0c2c4b54 100644 --- a/web/src/views/ApplicationConfig/Statistics.tsx +++ b/web/src/views/ApplicationConfig/Statistics.tsx @@ -1,3 +1,9 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:29:45 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:29:45 + */ import { type FC, useState, useEffect } from 'react'; import { Row, Col, Flex, DatePicker } from 'antd'; import type { Dayjs } from 'dayjs' @@ -10,12 +16,21 @@ import { getAppStatistics } from '@/api/application'; import LineCard from './components/LineCard' import type { StatisticsData, StatisticsItem } from './types' +/** + * Mapping of daily statistics keys to total statistics keys + */ const TotalObj: Record = { daily_conversations: 'total_conversations', daily_new_users: 'total_new_users', daily_api_calls: 'total_api_calls', daily_tokens: 'total_tokens', } + +/** + * Statistics page component + * Displays application usage statistics with charts and date range filtering + * @param application - Application data + */ const Statistics: FC<{ application: Application | null }> = ({ application }) => { const [data, setData] = useState({ daily_conversations: [], @@ -35,6 +50,9 @@ const Statistics: FC<{ application: Application | null }> = ({ application }) => useEffect(() => { getData() }, [application, query]) + /** + * Fetch statistics data + */ const getData = () => { if (!application?.id) { return @@ -49,6 +67,10 @@ const Statistics: FC<{ application: Application | null }> = ({ application }) => setData(res as StatisticsData) }) } + /** + * Handle date range change + * @param date - Selected date range + */ const handleChange = (date: [Dayjs | null, Dayjs | null] | null) => { if (!date || !date[0] || !date[1]) return setQuery({ diff --git a/web/src/views/ApplicationConfig/components/AiPromptModal.tsx b/web/src/views/ApplicationConfig/components/AiPromptModal.tsx index 198460eb..6a4a50b1 100644 --- a/web/src/views/ApplicationConfig/components/AiPromptModal.tsx +++ b/web/src/views/ApplicationConfig/components/AiPromptModal.tsx @@ -1,3 +1,15 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:26:44 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-05 10:31:12 + */ +/** + * AI Prompt Assistant Modal + * Provides an interactive chat interface to help users optimize their prompts using AI + * Features model selection, chat history, and variable insertion + */ + import { forwardRef, useImperativeHandle, useState, useRef } from 'react'; import { Button, Form, Input, App, Row, Col } from 'antd'; import { useTranslation } from 'react-i18next'; @@ -19,14 +31,25 @@ import AiPromptVariableModal from './AiPromptVariableModal' import { type SSEMessage } from '@/utils/stream' import Editor from './Editor' +/** + * Component props + */ interface AiPromptModalProps { + /** Callback to refresh prompt with optimized value */ refresh: (value: string) => void; - defaultModel: ModelListItem | null; + /** Default model to pre-select */ + defaultModel?: ModelListItem | null; + source?: 'app' | 'skills' } +/** + * AI Prompt Assistant Modal Component + * Helps users create and optimize prompts through AI-powered conversation + */ const AiPromptModal = forwardRef(({ refresh, defaultModel, + source = 'application' }, ref) => { const { t } = useTranslation(); const { message } = App.useApp() @@ -42,7 +65,7 @@ const AiPromptModal = forwardRef(({ const values = Form.useWatch([], form) - // 封装取消方法,添加关闭弹窗逻辑 + /** Close modal and reset state */ const handleClose = () => { setVisible(false); setLoading(false) @@ -54,6 +77,7 @@ const AiPromptModal = forwardRef(({ }) }; + /** Open modal and create new prompt session */ const handleOpen = () => { createPromptSessions() .then(res => { @@ -66,14 +90,15 @@ const AiPromptModal = forwardRef(({ setVisible(true); }) }; + /** Send user message and get AI response */ const handleSend = () => { if (!promptSession) return if (!values.model_id) { - message.warning(t('common.selectPlaceholder', { title: t('application.model') })) + message.warning(t('common.selectPlaceholder', { title: t(`${source}.model`) })) return } if (!values.message) { - message.warning(t('application.promptChatPlaceholder')) + message.warning(t(`${source}.promptChatPlaceholder`)) return } const messageContent = values.message @@ -115,33 +140,31 @@ const AiPromptModal = forwardRef(({ break; case 'end': setLoading(false) - // 流结束时同步表单值 + // Sync form value when stream ends form.setFieldsValue({ current_prompt: currentPromptValueRef.current }) break } }) }; - updatePromptMessages(promptSession, values, handleStreamMessage) - // .then(res => { - // const response = res as { prompt: string; desc: string; variables: string[] } - // form.setFieldsValue({ current_prompt: response.prompt }) - // setChatList(prev => { - // return [...prev, { role: 'assistant', content: response.desc }] - // }) - // setVariables(response.variables) - // }) + updatePromptMessages(promptSession, { + ...values, + skill: source === 'skills' + }, handleStreamMessage) .finally(() => { setLoading(false) }) } + /** Copy current prompt to clipboard */ const handleCopy = () => { if (!values.current_prompt || values?.current_prompt?.trim() === '') return copy(values.current_prompt) message.success(t('common.copySuccess')) } + /** Open variable selection modal */ const handleAdd = () => { aiPromptVariableModalRef.current?.handleOpen() } + /** Insert variable into prompt editor */ const handleVariableApply = (value: string) => { if (editorRef.current?.insertText) { editorRef.current.insertText(value) @@ -149,6 +172,7 @@ const AiPromptModal = forwardRef(({ form.setFieldValue('current_prompt', (values.current_prompt || '') + value) } } + /** Apply optimized prompt and close modal */ const handleApply = () => { if (!values.current_prompt) { return @@ -157,7 +181,7 @@ const AiPromptModal = forwardRef(({ handleClose() } - // 暴露给父组件的方法 + /** Expose methods to parent component */ useImperativeHandle(ref, () => ({ handleOpen, })); @@ -165,7 +189,7 @@ const AiPromptModal = forwardRef(({ console.log(values) return ( (({
@@ -192,18 +216,18 @@ const AiPromptModal = forwardRef(({ } + empty={} data={chatList || []} streamLoading={false} labelPosition="top" - labelFormat={(item) => item.role === 'user' ? t('application.you') : t('application.ai')} + labelFormat={(item) => item.role === 'user' ? t(`${source}.you`) : t(`${source}.ai`)} />
@@ -215,12 +239,12 @@ const AiPromptModal = forwardRef(({
- - - - - + + + {source === 'application' && + + } (({
- +
diff --git a/web/src/views/ApplicationConfig/components/AiPromptVariableModal.tsx b/web/src/views/ApplicationConfig/components/AiPromptVariableModal.tsx index 61847d9a..6e8e617a 100644 --- a/web/src/views/ApplicationConfig/components/AiPromptVariableModal.tsx +++ b/web/src/views/ApplicationConfig/components/AiPromptVariableModal.tsx @@ -1,25 +1,42 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:27:14 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:27:14 + */ +/** + * AI Prompt Variable Modal + * Allows users to insert variables into AI-generated prompts + * Supports autocomplete with existing variables + */ + import { forwardRef, useEffect, useImperativeHandle, useState } from 'react'; -import { Form, Input, App, Select, AutoComplete, type AutoCompleteProps } from 'antd'; +import { Form, AutoComplete, type AutoCompleteProps } from 'antd'; import { useTranslation } from 'react-i18next'; -import type { Application } from '@/views/ApplicationManagement/types' import type { AiPromptVariableModalRef } from '../types' -import { createApiKey } from '@/api/apiKey'; import RbModal from '@/components/RbModal' const FormItem = Form.Item; +/** + * Component props + */ interface AiPromptVariableModalProps { + /** Callback to insert variable into prompt */ refresh: (value: string) => void; + /** List of available variables */ variables: string[]; } +/** + * Variable selection modal for AI prompt assistant + */ const AiPromptVariableModal = forwardRef(({ refresh, variables }, ref) => { const { t } = useTranslation(); - const { message } = App.useApp(); const [visible, setVisible] = useState(false); const [form] = Form.useForm(); const [loading, setLoading] = useState(false) @@ -31,6 +48,7 @@ const AiPromptVariableModal = forwardRef { const filterKeys = variables?.filter(key => key.includes(value)) @@ -47,18 +65,19 @@ const AiPromptVariableModal = forwardRef { setVisible(false); form.resetFields(); setLoading(false) }; + /** Open modal */ const handleOpen = () => { setVisible(true); form.resetFields(); }; - // 封装保存方法,添加提交逻辑 + /** Apply selected variable */ const handleSave = () => { const variableName = form.getFieldValue('variableName') @@ -68,7 +87,7 @@ const AiPromptVariableModal = forwardRef ({ handleOpen, handleClose diff --git a/web/src/views/ApplicationConfig/components/ApiKeyConfigModal.tsx b/web/src/views/ApplicationConfig/components/ApiKeyConfigModal.tsx index 1b4f3f6e..f4751c88 100644 --- a/web/src/views/ApplicationConfig/components/ApiKeyConfigModal.tsx +++ b/web/src/views/ApplicationConfig/components/ApiKeyConfigModal.tsx @@ -1,3 +1,14 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:27:22 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:27:22 + */ +/** + * API Key Configuration Modal + * Allows configuring rate limits and daily usage limits for API keys + */ + import { forwardRef, useImperativeHandle, useState } from 'react'; import { Form, Slider } from 'antd'; import { useTranslation } from 'react-i18next'; @@ -7,10 +18,17 @@ import RbModal from '@/components/RbModal' import { updateApiKey } from '@/api/apiKey'; import type { ApiKey } from '@/views/ApiKeyManagement/types' +/** + * Component props + */ interface ApiKeyConfigModalProps { + /** Callback to refresh API key list */ refresh: () => void; } -const ApiKeyConfigModal = forwardRef(({ + +/** + * Modal for configuring API key limits + */const ApiKeyConfigModal = forwardRef(({ refresh }, ref) => { const { t } = useTranslation(); @@ -20,7 +38,7 @@ const ApiKeyConfigModal = forwardRef([], form) const [editVo, setEditVo] = useState(null) - // 封装取消方法,添加关闭弹窗逻辑 + /** Close modal and reset state */ const handleClose = () => { form.resetFields(); setLoading(false) @@ -28,6 +46,7 @@ const ApiKeyConfigModal = forwardRef { setVisible(true); setEditVo(apiKey) @@ -36,7 +55,7 @@ const ApiKeyConfigModal = forwardRef { if (!editVo?.id) return form.validateFields() @@ -52,7 +71,7 @@ const ApiKeyConfigModal = forwardRef ({ handleOpen, handleClose @@ -73,7 +92,7 @@ const ApiKeyConfigModal = forwardRef - {/* QPS 限制(每秒请求数) */} + {/* QPS limit (requests per second) */} <>
{t(`application.qpsLimit`)}({t('application.qpsLimitTip')}) @@ -98,7 +117,7 @@ const ApiKeyConfigModal = forwardRef
- {/* 日调用量限制 */} + {/* Daily usage limit */} <>
{t(`application.dailyUsageLimit`)} diff --git a/web/src/views/ApplicationConfig/components/ApiKeyModal.tsx b/web/src/views/ApplicationConfig/components/ApiKeyModal.tsx index 54740436..b43f0e4a 100644 --- a/web/src/views/ApplicationConfig/components/ApiKeyModal.tsx +++ b/web/src/views/ApplicationConfig/components/ApiKeyModal.tsx @@ -1,3 +1,14 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:27:25 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:27:25 + */ +/** + * API Key Creation Modal + * Allows creating new API keys for application access + */ + import { forwardRef, useImperativeHandle, useState } from 'react'; import { Form, Input, App } from 'antd'; import { useTranslation } from 'react-i18next'; @@ -9,11 +20,19 @@ import RbModal from '@/components/RbModal' const FormItem = Form.Item; +/** + * Component props + */ interface ApiKeyModalProps { + /** Callback to refresh API key list */ refresh: () => void; + /** Application data */ application?: Application | null; } +/** + * Modal for creating new API keys + */ const ApiKeyModal = forwardRef(({ refresh, application @@ -24,18 +43,19 @@ const ApiKeyModal = forwardRef(({ const [form] = Form.useForm(); const [loading, setLoading] = useState(false) - // 封装取消方法,添加关闭弹窗逻辑 + /** Close modal and reset form */ const handleClose = () => { setVisible(false); form.resetFields(); setLoading(false) }; + /** Open modal */ const handleOpen = () => { setVisible(true); form.resetFields(); }; - // 封装保存方法,添加提交逻辑 + /** Create new API key */ const handleSave = () => { if (!application) return form.validateFields() @@ -58,7 +78,7 @@ const ApiKeyModal = forwardRef(({ }) } - // 暴露给父组件的方法 + /** Expose methods to parent component */ useImperativeHandle(ref, () => ({ handleOpen, handleClose @@ -78,7 +98,7 @@ const ApiKeyModal = forwardRef(({ layout="vertical" scrollToFirstError={{ behavior: 'instant', block: 'end', focus: true }} > - {/* Key 名称 */} + {/* Key name */} (({ > - {/* 描述 */} + {/* Description */} = ({ title, subTitle, children, extra, + variant }) => { return ( {children} diff --git a/web/src/views/ApplicationConfig/components/Chat.tsx b/web/src/views/ApplicationConfig/components/Chat.tsx index bd826ba1..716f3cc0 100644 --- a/web/src/views/ApplicationConfig/components/Chat.tsx +++ b/web/src/views/ApplicationConfig/components/Chat.tsx @@ -1,7 +1,20 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:27:39 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:27:39 + */ +/** + * Chat debugging component for application testing + * Supports both single agent and multi-agent cluster modes + * Provides real-time streaming responses and conversation history + */ + import { type FC, useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; import clsx from 'clsx' import { Input, Form } from 'antd' + import ChatIcon from '@/assets/images/application/chat.png' import ChatSendIcon from '@/assets/images/application/chatSend.svg' import DebuggingEmpty from '@/assets/images/application/debuggingEmpty.png' @@ -12,13 +25,26 @@ import ChatContent from '@/components/Chat/ChatContent' import type { ChatItem } from '@/components/Chat/types' import { type SSEMessage } from '@/utils/stream' +/** + * Component props + */ interface ChatProps { + /** List of chat configurations for comparison */ chatList: ChatData[]; + /** Application configuration data */ data: Config; + /** Update chat list state */ updateChatList: React.Dispatch>; + /** Save configuration before running */ handleSave: (flag?: boolean) => Promise; + /** Source type: multi-agent cluster or single agent */ source?: 'multi_agent' | 'agent'; } + +/** + * Chat debugging component + * Allows testing application with different model configurations side-by-side + */ const Chat: FC = ({ chatList, data, updateChatList, handleSave, source = 'agent' }) => { const { t } = useTranslation(); const [form] = Form.useForm<{ message: string }>() @@ -31,6 +57,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc setIsCluster(source === 'multi_agent') }, [source]) + /** Add user message to all chat lists */ const addUserMessage = (message: string) => { const newUserMessage: ChatItem = { role: 'user', @@ -42,6 +69,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc list: [...(item.list || []), newUserMessage] }))) } + /** Add empty assistant message placeholder */ const addAssistantMessage = () => { const assistantMessage: ChatItem = { role: 'assistant', @@ -65,6 +93,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }))) } } + /** Update assistant message with streaming content */ const updateAssistantMessage = (content?: string, model_config_id?: string, conversation_id?: string) => { if (!content || !model_config_id) return updateChatList(prev => { @@ -92,6 +121,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc return prev; }) } + /** Update assistant message when error occurs */ const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => { if (message_length > 0 || !model_config_id) return @@ -120,6 +150,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc return prev }) } + /** Send message for agent comparison mode */ const handleSend = () => { if (loading) return setLoading(true) @@ -176,6 +207,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }) } + /** Add assistant message for cluster mode */ const addClusterAssistantMessage = () => { const assistantMessage: ChatItem = { role: 'assistant', @@ -187,6 +219,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc list: [...(item.list || []), assistantMessage] }))) } + /** Update cluster assistant message with content */ const updateClusterAssistantMessage = (content?: string) => { if (!content) return updateChatList(prev => { @@ -209,6 +242,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc return [...modelChatList] }) } + /** Update cluster message when error occurs */ const updateClusterErrorAssistantMessage = (message_length: number) => { if (message_length > 0) return @@ -232,6 +266,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc return [...modelChatList] }) } + /** Send message for cluster mode */ const handleClusterSend = () => { if (loading) return setLoading(true) @@ -291,6 +326,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }) } + /** Delete chat configuration from list */ const handleDelete = (index: number) => { updateChatList(chatList.filter((_, voIndex) => voIndex !== index)) } diff --git a/web/src/views/ApplicationConfig/components/ChatVariableConfigModal.tsx b/web/src/views/ApplicationConfig/components/ChatVariableConfigModal.tsx index abf33cb5..a16840cc 100644 --- a/web/src/views/ApplicationConfig/components/ChatVariableConfigModal.tsx +++ b/web/src/views/ApplicationConfig/components/ChatVariableConfigModal.tsx @@ -1,3 +1,14 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:27:44 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:27:44 + */ +/** + * Chat Variable Configuration Modal + * Allows users to configure variable values before starting a chat session + */ + import { forwardRef, useImperativeHandle, useState } from 'react'; import { Form, Input, InputNumber } from 'antd'; import { useTranslation } from 'react-i18next'; @@ -6,10 +17,17 @@ import type { ChatVariableConfigModalRef } from '../types' import type { Variable } from './VariableList/types' import RbModal from '@/components/RbModal' +/** + * Component props + */ interface VariableEditModalProps { + /** Callback to update variables */ refresh: (values: Variable[]) => void; } +/** + * Modal for configuring chat variables + */ const ChatVariableConfigModal = forwardRef(({ refresh, }, ref) => { @@ -19,20 +37,21 @@ const ChatVariableConfigModal = forwardRef([]) - // 封装取消方法,添加关闭弹窗逻辑 + /** Close modal and reset form */ const handleClose = () => { setVisible(false); form.resetFields(); setLoading(false) }; + /** Open modal with variable list */ const handleOpen = (values: Variable[]) => { console.log('values', values) setVisible(true); form.setFieldsValue({variables: values}) setInitialValues([...values]) }; - // 封装保存方法,添加提交逻辑 + /** Save variable configuration */ const handleSave = () => { form.validateFields().then((values) => { refresh([ @@ -42,7 +61,7 @@ const ChatVariableConfigModal = forwardRef ({ handleOpen, handleClose diff --git a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx index db1e0fa5..374c87e8 100644 --- a/web/src/views/ApplicationConfig/components/ConfigHeader.tsx +++ b/web/src/views/ApplicationConfig/components/ConfigHeader.tsx @@ -1,8 +1,15 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:27:52 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:27:52 + */ import { type FC, useRef } from 'react'; import { useNavigate, useParams } from 'react-router-dom'; import { Layout, Tabs, Dropdown, Button, Flex } from 'antd'; import type { MenuProps } from 'antd'; import { useTranslation } from 'react-i18next'; + import styles from '../index.module.css' import logoutIcon from '@/assets/images/logout.svg' import editIcon from '@/assets/images/edit_hover.svg' @@ -17,21 +24,43 @@ import CopyModal from './CopyModal' const { Header } = Layout; +/** + * Tab keys for application configuration + */ const tabKeys = ['arrangement', 'api', 'release', 'statistics'] + +/** + * Menu icon mapping + */ const menuIcons: Record = { edit: editIcon, copy: copyIcon, export: exportIcon, delete: deleteIcon } + +/** + * Props for ConfigHeader component + */ interface ConfigHeaderProps { + /** Application data */ application?: Application; + /** Active tab key */ activeTab: string; + /** Tab change handler */ handleChangeTab: (key: string) => void; + /** Refresh application data */ refresh: () => void; + /** Workflow component ref */ workflowRef: React.RefObject + /** App component ref (Agent/Cluster/Workflow) */ appRef?: React.RefObject } + +/** + * Configuration header component + * Displays application name, tabs, and action buttons + */ const ConfigHeader: FC = ({ application, activeTab, handleChangeTab, refresh, workflowRef, @@ -42,12 +71,18 @@ const ConfigHeader: FC = ({ const applicationModalRef = useRef(null); const copyModalRef = useRef(null); + /** + * Format tab items for display + */ const formatTabItems = () => { return tabKeys.map(key => ({ key, label: t(`application.${key}`), })) } + /** + * Format dropdown menu items + */ const formatMenuItems = () => { const items = ['edit', 'copy', 'export', 'delete'].map(key => ({ key, @@ -59,6 +94,9 @@ const ConfigHeader: FC = ({ onClick: handleClick } } + /** + * Handle menu item click + */ const handleClick: MenuProps['onClick'] = ({ key }) => { switch (key) { case 'edit': @@ -74,6 +112,9 @@ const ConfigHeader: FC = ({ break; } } + /** + * Delete application with confirmation + */ const handleDelete = () => { if (!id) { return @@ -86,21 +127,36 @@ const ConfigHeader: FC = ({ console.error('Failed to delete application'); }); } + /** + * Navigate to application list + */ const goToApplication = () => { navigate('/application', { replace: true }) } + /** + * Save workflow configuration + */ const save = () => { workflowRef.current?.handleSave() } + /** + * Run workflow + */ const run = () => { workflowRef.current?.handleSave(false) .then(() => { workflowRef.current?.handleRun() }) } + /** + * Clear workflow canvas + */ const clear = () => { workflowRef?.current?.graphRef?.current?.clearCells() } + /** + * Add variable to workflow + */ const addvariable = () => { workflowRef?.current?.addVariable() } diff --git a/web/src/views/ApplicationConfig/components/CopyModal.tsx b/web/src/views/ApplicationConfig/components/CopyModal.tsx index 0b83e65a..7eaf1497 100644 --- a/web/src/views/ApplicationConfig/components/CopyModal.tsx +++ b/web/src/views/ApplicationConfig/components/CopyModal.tsx @@ -1,3 +1,14 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:27:56 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:27:56 + */ +/** + * Copy Application Modal + * Allows users to duplicate an existing application with a new name + */ + import { forwardRef, useImperativeHandle, useState } from 'react'; import { Form, Input } from 'antd'; import { useTranslation } from 'react-i18next'; @@ -10,10 +21,17 @@ import type { Application } from '@/views/ApplicationManagement/types' const FormItem = Form.Item; +/** + * Component props + */ interface CopyModalProps { + /** Application data to copy */ data: Application } +/** + * Modal for copying applications + */ const CopyModal = forwardRef(({ data }, ref) => { @@ -23,17 +41,18 @@ const CopyModal = forwardRef(({ const [form] = Form.useForm(); const [loading, setLoading] = useState(false) - // 封装取消方法,添加关闭弹窗逻辑 + /** Close modal and reset form */ const handleClose = () => { setVisible(false); form.resetFields(); setLoading(false) }; + /** Open modal */ const handleOpen = () => { setVisible(true); }; - // 封装保存方法,添加提交逻辑 + /** Copy application with new name */ const handleSave = () => { setVisible(false); setLoading(true) @@ -48,7 +67,7 @@ const CopyModal = forwardRef(({ }) } - // 暴露给父组件的方法 + /** Expose methods to parent component */ useImperativeHandle(ref, () => ({ handleOpen, handleClose @@ -68,7 +87,7 @@ const CopyModal = forwardRef(({ form={form} layout="vertical" > - {/* 应用名 */} + {/* Application name */} void; + /** Append text to the end of content */ appendText: (text: string) => void; + /** Clear all editor content */ clear: () => void; + /** Scroll editor to bottom */ scrollToBottom: () => void; } +/** + * Editor component props + */ interface LexicalEditorProps { + /** Additional CSS class names */ className?: string; + /** Placeholder text when editor is empty */ placeholder?: string; + /** Initial editor value */ value?: string; + /** Callback when content changes */ onChange?: (value: string) => void; + /** Editor height in pixels */ height?: number; + disabled?: boolean; } +/** + * Lexical editor theme configuration + */ const theme = { paragraph: 'editor-paragraph', text: { @@ -33,14 +65,25 @@ const theme = { }, }; +/** + * Editor content component with Lexical context + */ const EditorContent = forwardRef(({ className = '', value, - placeholder = "请输入内容...", + placeholder = "Please enter content...", onChange, + disabled }, ref) => { const [editor] = useLexicalComposerContext(); + /** + * Expose editor methods to parent component + * - insertText: Insert at cursor position + * - appendText: Append to end of content + * - clear: Clear all content + * - scrollToBottom: Scroll to bottom + */ useImperativeHandle(ref, () => ({ insertText: (text: string) => { editor.update(() => { @@ -92,7 +135,11 @@ const EditorContent = forwardRef(({ } placeholder={ @@ -105,15 +152,21 @@ const EditorContent = forwardRef(({ +
); }); +/** + * Main editor wrapper component + * Initializes Lexical composer with configuration + */ const Editor = forwardRef((props, ref) => { const initialConfig = { namespace: 'Editor', theme, nodes: [], + editable: !props.disabled, onError: (error: Error) => { console.error(error); }, diff --git a/web/src/views/ApplicationConfig/components/Editor/plugin/EditablePlugin.tsx b/web/src/views/ApplicationConfig/components/Editor/plugin/EditablePlugin.tsx new file mode 100644 index 00000000..6c237f01 --- /dev/null +++ b/web/src/views/ApplicationConfig/components/Editor/plugin/EditablePlugin.tsx @@ -0,0 +1,48 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-04 11:20:49 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-04 11:20:49 + */ +import { useEffect } from 'react'; +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; + +/** + * Props for the EditablePlugin component + */ +interface EditablePluginProps { + /** Whether the editor should be disabled (read-only mode) */ + disabled?: boolean; +} + +/** + * EditablePlugin - A Lexical editor plugin that controls the editable state of the editor + * + * This plugin allows you to dynamically toggle between editable and read-only modes. + * When disabled is true, the editor becomes read-only and users cannot modify content. + * When disabled is false or undefined, the editor is fully editable. + * + * @param {EditablePluginProps} props - Component props + * @param {boolean} [props.disabled] - Controls whether the editor is in read-only mode + * @returns {null} This plugin doesn't render any UI elements + * + * @example + * ```tsx + * + * + * + * ``` + */ +export default function EditablePlugin({ disabled }: EditablePluginProps) { + // Get the editor instance from Lexical composer context + const [editor] = useLexicalComposerContext(); + + // Update editor's editable state whenever the disabled prop changes + useEffect(() => { + // Set editor to editable when disabled is false, read-only when disabled is true + editor.setEditable(!disabled); + }, [editor, disabled]); + + // This plugin doesn't render any UI, it only manages editor state + return null; +} diff --git a/web/src/views/ApplicationConfig/components/Editor/plugin/InitialValuePlugin.tsx b/web/src/views/ApplicationConfig/components/Editor/plugin/InitialValuePlugin.tsx index da373023..dc84074a 100644 --- a/web/src/views/ApplicationConfig/components/Editor/plugin/InitialValuePlugin.tsx +++ b/web/src/views/ApplicationConfig/components/Editor/plugin/InitialValuePlugin.tsx @@ -1,20 +1,34 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:24:59 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:24:59 + */ +/** + * Initial Value Plugin + * Sets the initial content of the Lexical editor + * Only updates when the value prop changes + */ + import { type FC, useEffect, useRef } from 'react'; import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical'; import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -// 设置初始值的插件 +/** + * Plugin to set initial editor value + */ const InitialValuePlugin: FC<{ value?: string }> = ({ value }) => { const [editor] = useLexicalComposerContext(); const lastValueRef = useRef(undefined); useEffect(() => { - // 只有当value真正发生变化时才更新 + // Only update when value actually changes if (lastValueRef.current !== value) { editor.update(() => { const root = $getRoot(); const currentText = root.getTextContent(); - // 如果当前内容和新值相同,则不更新 + // If current content matches new value, don't update if (currentText === (value || '')) { return; } @@ -26,7 +40,7 @@ const InitialValuePlugin: FC<{ value?: string }> = ({ value }) => { paragraph.append(textNode); root.append(paragraph); } else { - // 当value为undefined或空时,创建一个空段落 + // When value is undefined or empty, create an empty paragraph const paragraph = $createParagraphNode(); root.append(paragraph); } diff --git a/web/src/views/ApplicationConfig/components/Editor/plugin/InsertTextPlugin.tsx b/web/src/views/ApplicationConfig/components/Editor/plugin/InsertTextPlugin.tsx index ca75c393..38c068aa 100644 --- a/web/src/views/ApplicationConfig/components/Editor/plugin/InsertTextPlugin.tsx +++ b/web/src/views/ApplicationConfig/components/Editor/plugin/InsertTextPlugin.tsx @@ -1,10 +1,22 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:25:05 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:25:05 + */ +/** + * Insert Text Plugin + * Provides functionality to insert text at the current cursor position + */ + import { forwardRef, useImperativeHandle } from 'react'; import { $getSelection } from 'lexical'; import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -import type { EditorRef } from '../index' -// 插入文本的插件 -const InsertTextPlugin = forwardRef((_, ref) => { +/** + * Plugin to insert text at cursor position + */ +const InsertTextPlugin = forwardRef<{ insertText: (text: string) => void; }>((_, ref) => { const [editor] = useLexicalComposerContext(); useImperativeHandle(ref, () => ({ diff --git a/web/src/views/ApplicationConfig/components/Editor/plugin/LineBreakPlugin.tsx b/web/src/views/ApplicationConfig/components/Editor/plugin/LineBreakPlugin.tsx index 63d1ffc4..225ba322 100644 --- a/web/src/views/ApplicationConfig/components/Editor/plugin/LineBreakPlugin.tsx +++ b/web/src/views/ApplicationConfig/components/Editor/plugin/LineBreakPlugin.tsx @@ -1,8 +1,22 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:25:09 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:25:09 + */ +/** + * Line Break Plugin + * Handles line breaks and triggers onChange callback when editor content changes + * Converts \n escape sequences to actual line breaks + */ + import { type FC, useEffect } from 'react'; import { $getRoot } from 'lexical'; import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'; -// 处理换行的插件 +/** + * Plugin to handle line breaks and content changes + */ const LineBreakPlugin: FC<{ onChange?: (value: string) => void }> = ({ onChange }) => { const [editor] = useLexicalComposerContext(); @@ -11,7 +25,7 @@ const LineBreakPlugin: FC<{ onChange?: (value: string) => void }> = ({ onChange editorState.read(() => { const root = $getRoot(); const textContent = root.getTextContent(); - // 将\n转换为实际换行 + // Convert \n to actual line breaks const processedContent = textContent.replace(/\\n/g, '\n'); onChange?.(processedContent); }); diff --git a/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx index 1e59f26d..297e9faa 100644 --- a/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/Knowledge.tsx @@ -1,6 +1,19 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:25:32 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:25:32 + */ +/** + * Knowledge Base Component + * Manages knowledge base associations for the application + * Allows adding, configuring, and removing knowledge bases + */ + import { type FC, useRef, useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' import { Space, Button, List } from 'antd' + import knowledgeEmpty from '@/assets/images/application/knowledgeEmpty.svg' import type { KnowledgeConfigForm, @@ -19,6 +32,11 @@ import Tag from '@/components/Tag' import { getKnowledgeBaseList } from '@/api/knowledgeBase' import Card from '../Card' +/** + * Knowledge base management component + * @param value - Current knowledge configuration + * @param onChange - Callback when configuration changes + */ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfig) => void}> = ({value = {knowledge_bases: []}, onChange}) => { const { t } = useTranslation() const knowledgeModalRef = useRef(null) @@ -32,10 +50,10 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi setEditConfig({ ...(value || {}) }) const knowledge_bases = [...(value.knowledge_bases || [])] - // 检查是否有knowledge_bases缺少name字段 + // Check if knowledge_bases are missing name field const basesWithoutName = knowledge_bases.filter(base => !base.name) if (basesWithoutName.length > 0) { - // 调用接口获取完整的知识库信息 + // Call API to get complete knowledge base information getKnowledgeBaseList().then(res => { const fullBases = knowledge_bases.map(base => { if (!base.name) { @@ -54,12 +72,15 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi } }, [value]) + /** Open global knowledge configuration modal */ const handleKnowledgeConfig = () => { knowledgeGlobalConfigModalRef.current?.handleOpen() } + /** Open knowledge base selection modal */ const handleAddKnowledge = () => { knowledgeModalRef.current?.handleOpen() } + /** Remove knowledge base from list */ const handleDeleteKnowledge = (id: string) => { const list = knowledgeList.filter(item => item.id !== id) setKnowledgeList([...list]) @@ -68,9 +89,11 @@ const Knowledge: FC<{value?: KnowledgeConfig; onChange?: (config: KnowledgeConfi knowledge_bases: [...list], }) } + /** Open knowledge base configuration modal */ const handleEditKnowledge = (item: KnowledgeBase) => { knowledgeConfigModalRef.current?.handleOpen(item) } + /** Update knowledge configuration */ const refresh = (values: KnowledgeBase[] | KnowledgeConfigForm | RerankerConfig, type: 'knowledge' | 'knowledgeConfig' | 'rerankerConfig') => { if (type === 'knowledge') { let list = [...knowledgeList] diff --git a/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx index 70b17a11..9adcd168 100644 --- a/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx @@ -1,3 +1,15 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-03 16:25:37 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-02-03 16:25:37 + */ +/** + * Knowledge Configuration Modal + * Configures retrieval settings for individual knowledge bases + * Supports different retrieval modes: participle, semantic, and hybrid + */ + import { forwardRef, useEffect, useImperativeHandle, useState } from 'react'; import { Form, Select, InputNumber } from 'antd'; import { useTranslation } from 'react-i18next'; @@ -9,11 +21,22 @@ import { formatDateTime } from '@/utils/format'; const FormItem = Form.Item; +/** + * Component props + */ interface KnowledgeConfigModalProps { + /** Callback to update knowledge configuration */ refresh: (values: KnowledgeConfigForm, type: 'knowledgeConfig') => void; } + +/** + * Available retrieval types + */ const retrieveTypes: RetrieveType[] = ['participle', 'semantic', 'hybrid'] +/** + * Modal for configuring knowledge base retrieval settings + */ const KnowledgeConfigModal = forwardRef(({ refresh, }, ref) => { @@ -24,13 +47,14 @@ const KnowledgeConfigModal = forwardRef([], form); - // 封装取消方法,添加关闭弹窗逻辑 + /** Close modal and reset form */ const handleClose = () => { setVisible(false); form.resetFields(); setData(null) }; + /** Open modal with knowledge base data */ const handleOpen = (data: KnowledgeBase) => { form.setFieldsValue({ retrieve_type: data?.config?.retrieve_type || retrieveTypes[0], @@ -44,7 +68,7 @@ const KnowledgeConfigModal = forwardRef { form .validateFields() @@ -57,7 +81,7 @@ const KnowledgeConfigModal = forwardRef ({ handleOpen, handleClose @@ -94,7 +118,7 @@ const KnowledgeConfigModal = forwardRef )}