diff --git a/.gitignore b/.gitignore
index 2fcdbcd6..2fb41537 100644
--- a/.gitignore
+++ b/.gitignore
@@ -37,5 +37,4 @@ tika-server*.jar*
cl100k_base.tiktoken
libssl*.deb
-sandbox/lib/seccomp_python/target
-sandbox/lib/seccomp_nodejs/target
+sandbox/lib/seccomp_redbear/target
diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py
index 5831586c..85550f94 100644
--- a/api/app/controllers/__init__.py
+++ b/api/app/controllers/__init__.py
@@ -19,6 +19,8 @@ from . import (
implicit_memory_controller,
knowledge_controller,
knowledgeshare_controller,
+ mcp_market_controller,
+ mcp_market_config_controller,
memory_agent_controller,
memory_dashboard_controller,
memory_episodic_controller,
@@ -60,6 +62,8 @@ manager_router.include_router(model_controller.router)
manager_router.include_router(file_controller.router)
manager_router.include_router(document_controller.router)
manager_router.include_router(knowledge_controller.router)
+manager_router.include_router(mcp_market_controller.router)
+manager_router.include_router(mcp_market_config_controller.router)
manager_router.include_router(chunk_controller.router)
manager_router.include_router(test_controller.router)
manager_router.include_router(knowledgeshare_controller.router)
diff --git a/api/app/controllers/auth_controller.py b/api/app/controllers/auth_controller.py
index a6960096..708cbaa2 100644
--- a/api/app/controllers/auth_controller.py
+++ b/api/app/controllers/auth_controller.py
@@ -61,6 +61,7 @@ async def login_for_access_token(
user = auth_service.register_user_with_invite(
db=db,
email=form_data.email,
+ username=form_data.username,
password=form_data.password,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
diff --git a/api/app/controllers/mcp_market_config_controller.py b/api/app/controllers/mcp_market_config_controller.py
new file mode 100644
index 00000000..98012568
--- /dev/null
+++ b/api/app/controllers/mcp_market_config_controller.py
@@ -0,0 +1,336 @@
+import datetime
+import json
+from typing import Optional
+import uuid
+
+from fastapi import APIRouter, Depends, HTTPException, status, Query
+from fastapi.encoders import jsonable_encoder
+import requests
+from sqlalchemy import or_
+from sqlalchemy.orm import Session
+from modelscope.hub.errors import raise_for_http_status
+from modelscope.hub.mcp_api import MCPApi
+
+from app.core.logging_config import get_api_logger
+from app.core.response_utils import success, fail
+from app.db import get_db
+from app.dependencies import get_current_user
+from app.models import mcp_market_config_model
+from app.models.user_model import User
+from app.schemas import mcp_market_config_schema
+from app.schemas.response_schema import ApiResponse
+from app.services import mcp_market_config_service
+
+# Obtain a dedicated API logger
+api_logger = get_api_logger()
+
+router = APIRouter(
+ prefix="/mcp_market_configs",
+ tags=["mcp_market_configs"],
+ dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
+)
+
+
+@router.get("/mcp_servers", response_model=ApiResponse)
+async def get_mcp_servers(
+ mcp_market_config_id: uuid.UUID,
+ page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
+ pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
+ keywords: Optional[str] = Query(None, description="Search keywords (Optional search query string,e.g. Chinese service name, English service name, author/owner username)"),
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """
+ Query the mcp servers list in pages
+ - Support keyword search for name,author,owner
+ - Return paging metadata + mcp server list
+ """
+ api_logger.info(
+ f"Query mcp server list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
+
+ # 1. parameter validation
+ if page < 1 or pagesize < 1:
+ api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="The paging parameter must be greater than 0"
+ )
+
+ # 2. Query mcp market config information from the database
+ api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
+ db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
+ mcp_market_config_id=mcp_market_config_id,
+ current_user=current_user)
+ if not db_mcp_market_config:
+ api_logger.warning(
+ f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="The mcp market config does not exist or access is denied"
+ )
+
+ # 3. Execute paged query
+ api = MCPApi()
+ token = db_mcp_market_config.token
+ api.login(token)
+
+ body = {
+ 'filter': {},
+ 'page_number': page,
+ 'page_size': pagesize,
+ 'search': keywords
+ }
+
+ try:
+ cookies = api.get_cookies(token)
+ r = api.session.put(
+ url=api.mcp_base_url,
+ headers=api.builder_headers(api.headers),
+ json=body,
+ cookies=cookies)
+ raise_for_http_status(r)
+ except requests.exceptions.RequestException as e:
+ api_logger.error(f"mFailed to get MCP servers: {str(e)}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to get MCP servers: {str(e)}"
+ )
+
+ data = api._handle_response(r)
+ total = data.get('total_count', 0)
+ mcp_server_list = data.get('mcp_server_list', [])
+ # items = [{
+ # 'name': item.get('name', ''),
+ # 'id': item.get('id', ''),
+ # 'description': item.get('description', '')
+ # } for item in mcp_server_list]
+
+ # 4. Return structured response
+ result = {
+ "items": mcp_server_list,
+ "page": {
+ "page": page,
+ "pagesize": pagesize,
+ "total": total,
+ "has_next": True if page * pagesize < total else False
+ }
+ }
+ return success(data=result, msg="Query of mcp servers list successful")
+
+
+@router.get("/mcp_server", response_model=ApiResponse)
+async def get_mcp_server(
+ mcp_market_config_id: uuid.UUID,
+ server_id: str,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """
+ Get detailed information for a specific MCP Server
+ """
+ api_logger.info(
+ f"Query mcp server: tenant_id={current_user.tenant_id}, mcp_market_config_id={mcp_market_config_id}, server_id={server_id}, username: {current_user.username}")
+
+ # 1. Query mcp market config information from the database
+ api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
+ db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
+ mcp_market_config_id=mcp_market_config_id,
+ current_user=current_user)
+ if not db_mcp_market_config:
+ api_logger.warning(
+ f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="The mcp market config does not exist or access is denied"
+ )
+
+ # 2. Get detailed information for a specific MCP Server
+ api = MCPApi()
+ token = db_mcp_market_config.token
+ api.login(token)
+
+ result = api.get_mcp_server(server_id=server_id)
+ return success(data=result, msg="Query of mcp servers list successful")
+
+
+@router.post("/mcp_market_config", response_model=ApiResponse)
+async def create_mcp_market_config(
+ create_data: mcp_market_config_schema.McpMarketConfigCreate,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """
+ create mcp market config
+ """
+ api_logger.info(
+ f"Request to create a mcp market config: mcp_market_id={create_data.mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
+
+ try:
+ api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
+ # 1. Check if the mcp market name already exists
+ db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
+ if db_mcp_market_config_exist:
+ api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
+ )
+ db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
+ api_logger.info(
+ f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
+ return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
+ msg="The mcp market config has been successfully created")
+ except Exception as e:
+ api_logger.error(f"The creation of the mcp market config failed: {create_data.mcp_market_id} - {str(e)}")
+ raise
+
+
+@router.get("/{mcp_market_config_id}", response_model=ApiResponse)
+async def get_mcp_market_config(
+ mcp_market_config_id: uuid.UUID,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """
+ Retrieve mcp market config information based on mcp_market_config_id
+ """
+ api_logger.info(
+ f"Obtain details of the mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
+
+ try:
+ # 1. Query mcp market config information from the database
+ api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
+ db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
+ if not db_mcp_market_config:
+ api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="The mcp market config does not exist or access is denied"
+ )
+
+ api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
+ return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
+ msg="Successfully obtained mcp market config information")
+ except HTTPException:
+ raise
+ except Exception as e:
+ api_logger.error(f"mcp market config query failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
+ raise
+
+
+@router.get("/mcp_market_id/{mcp_market_id}", response_model=ApiResponse)
+async def get_mcp_market_config_by_mcp_market_id(
+ mcp_market_id: uuid.UUID,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """
+ Retrieve mcp market config information based on mcp_market_id
+ """
+ api_logger.info(
+ f"Request to create a mcp market config: mcp_market_id={mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
+
+ try:
+ # 1. Query mcp market config information from the database
+ api_logger.debug(f"Query mcp market config: mcp_market_id={mcp_market_id}")
+ db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
+ if not db_mcp_market_config:
+ api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="The mcp market config does not exist or access is denied"
+ )
+
+ api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
+ return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
+ msg="Successfully obtained mcp market config information")
+ except HTTPException:
+ raise
+ except Exception as e:
+ api_logger.error(f"mcp market config query failed: mcp_market_id={mcp_market_id} - {str(e)}")
+ raise
+
+
+@router.put("/{mcp_market_config_id}", response_model=ApiResponse)
+async def update_mcp_market_config(
+ mcp_market_config_id: uuid.UUID,
+ update_data: mcp_market_config_schema.McpMarketConfigUpdate,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ # 1. Check if the mcp market config exists
+ api_logger.debug(f"Query the mcp market config to be updated: {mcp_market_config_id}")
+ db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
+
+ if not db_mcp_market_config:
+ api_logger.warning(
+ f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="The mcp market config does not exist or you do not have permission to access it"
+ )
+
+ # 2. Update fields (only update non-null fields)
+ api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
+ update_dict = update_data.dict(exclude_unset=True)
+ updated_fields = []
+ for field, value in update_dict.items():
+ if hasattr(db_mcp_market_config, field):
+ old_value = getattr(db_mcp_market_config, field)
+ if old_value != value:
+ # update value
+ setattr(db_mcp_market_config, field, value)
+ updated_fields.append(f"{field}: {old_value} -> {value}")
+
+ if updated_fields:
+ api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
+
+ # 3. Save to database
+ try:
+ db.commit()
+ db.refresh(db_mcp_market_config)
+ api_logger.info(f"The mcp market config has been successfully updated: (ID: {db_mcp_market_config.id})")
+ except Exception as e:
+ db.rollback()
+ api_logger.error(f"The mcp market config update failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"The mcp market config update failed: {str(e)}"
+ )
+
+ # 4. Return the updated mcp market config
+ return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
+ msg="The mcp market config information updated successfully")
+
+
+@router.delete("/{mcp_market_config_id}", response_model=ApiResponse)
+async def delete_mcp_market_config(
+ mcp_market_config_id: uuid.UUID,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """
+ delete mcp market config
+ """
+ api_logger.info(f"Request to delete mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
+
+ try:
+ # 1. Check whether the mcp market config exists
+ api_logger.debug(f"Check whether the mcp market config exists: {mcp_market_config_id}")
+ db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
+
+ if not db_mcp_market_config:
+ api_logger.warning(
+ f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="The mcp market config does not exist or you do not have permission to access it"
+ )
+
+ # 2. Deleting mcp market config
+ mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
+ api_logger.info(f"The mcp market config has been successfully deleted: (ID: {mcp_market_config_id})")
+ return success(msg="The mcp market config has been successfully deleted")
+ except Exception as e:
+ api_logger.error(f"Failed to delete from the mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
+ raise
diff --git a/api/app/controllers/mcp_market_controller.py b/api/app/controllers/mcp_market_controller.py
new file mode 100644
index 00000000..61531a0f
--- /dev/null
+++ b/api/app/controllers/mcp_market_controller.py
@@ -0,0 +1,262 @@
+import datetime
+import json
+from typing import Optional
+import uuid
+
+from fastapi import APIRouter, Depends, HTTPException, status, Query
+from fastapi.encoders import jsonable_encoder
+from sqlalchemy import or_
+from sqlalchemy.orm import Session
+
+from app.core.logging_config import get_api_logger
+from app.core.response_utils import success, fail
+from app.db import get_db
+from app.dependencies import get_current_user
+from app.models import mcp_market_model
+from app.models.user_model import User
+from app.schemas import mcp_market_schema
+from app.schemas.response_schema import ApiResponse
+from app.services import mcp_market_service
+
+# Obtain a dedicated API logger
+api_logger = get_api_logger()
+
+router = APIRouter(
+ prefix="/mcp_markets",
+ tags=["mcp_markets"],
+ dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
+)
+
+
+@router.get("/mcp_markets", response_model=ApiResponse)
+async def get_mcp_markets(
+ page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
+ pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
+ orderby: Optional[str] = Query(None, description="Sort fields, such as: category, created_at"),
+ desc: Optional[bool] = Query(False, description="Is it descending order"),
+ keywords: Optional[str] = Query(None, description="Search keywords (mcp_market base name)"),
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """
+ Query the mcp markets list in pages
+ - Support keyword search for name,description
+ - Support dynamic sorting
+ - Return paging metadata + mcp_market list
+ """
+ api_logger.info(
+ f"Query mcp market list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
+
+ # 1. parameter validation
+ if page < 1 or pagesize < 1:
+ api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="The paging parameter must be greater than 0"
+ )
+
+ # 2. Construct query conditions
+ filters = []
+
+ # Keyword search (fuzzy matching of mcp market name,description)
+ if keywords:
+ api_logger.debug(f"Add keyword search criteria: {keywords}")
+ filters.append(
+ or_(
+ mcp_market_model.McpMarket.name.ilike(f"%{keywords}%"),
+ mcp_market_model.McpMarket.description.ilike(f"%{keywords}%")
+ )
+ )
+ # 3. Execute paged query
+ try:
+ api_logger.debug("Start executing mcp market paging query")
+ total, items = mcp_market_service.get_mcp_markets_paginated(
+ db=db,
+ filters=filters,
+ page=page,
+ pagesize=pagesize,
+ orderby=orderby,
+ desc=desc,
+ current_user=current_user
+ )
+ api_logger.info(f"mcp market query successful: total={total}, returned={len(items)} records")
+ except Exception as e:
+ api_logger.error(f"mcp market query failed: {str(e)}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Query failed: {str(e)}"
+ )
+
+ # 4. Return structured response
+ result = {
+ "items": items,
+ "page": {
+ "page": page,
+ "pagesize": pagesize,
+ "total": total,
+ "has_next": True if page * pagesize < total else False
+ }
+ }
+ return success(data=jsonable_encoder(result), msg="Query of mcp market list successful")
+
+
+@router.post("/mcp_market", response_model=ApiResponse)
+async def create_mcp_market(
+ create_data: mcp_market_schema.McpMarketCreate,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """
+ create mcp market
+ """
+ api_logger.info(
+ f"Request to create a mcp market: name={create_data.name}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
+
+ try:
+ api_logger.debug(f"Start creating the mcp market: {create_data.name}")
+ # 1. Check if the mcp market name already exists
+ db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=create_data.name, current_user=current_user)
+ if db_mcp_market_exist:
+ api_logger.warning(f"The mcp market name already exists: {create_data.name}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"The mcp market name already exists: {create_data.name}"
+ )
+ db_mcp_market = mcp_market_service.create_mcp_market(db=db, mcp_market=create_data, current_user=current_user)
+ api_logger.info(
+ f"The mcp market has been successfully created: {db_mcp_market.name} (ID: {db_mcp_market.id})")
+ return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
+ msg="The mcp market has been successfully created")
+ except Exception as e:
+ api_logger.error(f"The creation of the mcp market failed: {create_data.name} - {str(e)}")
+ raise
+
+
+@router.get("/{mcp_market_id}", response_model=ApiResponse)
+async def get_mcp_market(
+ mcp_market_id: uuid.UUID,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """
+ Retrieve mcp market information based on mcp_market_id
+ """
+ api_logger.info(
+ f"Obtain details of the mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
+
+ try:
+ # 1. Query mcp market information from the database
+ api_logger.debug(f"Query mcp market: {mcp_market_id}")
+ db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
+ if not db_mcp_market:
+ api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={mcp_market_id}")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="The mcp market does not exist or access is denied"
+ )
+
+ api_logger.info(f"mcp market query successful: {db_mcp_market.name} (ID: {db_mcp_market.id})")
+ return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
+ msg="Successfully obtained mcp market information")
+ except HTTPException:
+ raise
+ except Exception as e:
+ api_logger.error(f"mcp market query failed: mcp_market_id={mcp_market_id} - {str(e)}")
+ raise
+
+
+@router.put("/{mcp_market_id}", response_model=ApiResponse)
+async def update_mcp_market(
+ mcp_market_id: uuid.UUID,
+ update_data: mcp_market_schema.McpMarketUpdate,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ # 1. Check if the mcp market exists
+ api_logger.debug(f"Query the mcp market to be updated: {mcp_market_id}")
+ db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
+
+ if not db_mcp_market:
+ api_logger.warning(
+ f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="The mcp market does not exist or you do not have permission to access it"
+ )
+
+ # 2. not updating the name (name already exists)
+ update_dict = update_data.dict(exclude_unset=True)
+ if "name" in update_dict:
+ name = update_dict["name"]
+ if name != db_mcp_market.name:
+ # Check if the mcp market name already exists
+ db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=name, current_user=current_user)
+ if db_mcp_market_exist:
+ api_logger.warning(f"The mcp market name already exists: {name}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"The mcp market name already exists: {name}"
+ )
+ # 3. Update fields (only update non-null fields)
+ api_logger.debug(f"Start updating the mcp market fields: {mcp_market_id}")
+ updated_fields = []
+ for field, value in update_dict.items():
+ if hasattr(db_mcp_market, field):
+ old_value = getattr(db_mcp_market, field)
+ if old_value != value:
+ # update value
+ setattr(db_mcp_market, field, value)
+ updated_fields.append(f"{field}: {old_value} -> {value}")
+
+ if updated_fields:
+ api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
+
+ # 4. Save to database
+ try:
+ db.commit()
+ db.refresh(db_mcp_market)
+ api_logger.info(f"The mcp market has been successfully updated: {db_mcp_market.name} (ID: {db_mcp_market.id})")
+ except Exception as e:
+ db.rollback()
+ api_logger.error(f"The mcp market update failed: mcp_market_id={mcp_market_id} - {str(e)}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"The mcp market update failed: {str(e)}"
+ )
+
+ # 5. Return the updated mcp market
+ return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
+ msg="The mcp market information updated successfully")
+
+
+@router.delete("/{mcp_market_id}", response_model=ApiResponse)
+async def delete_mcp_market(
+ mcp_market_id: uuid.UUID,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user)
+):
+ """
+ delete mcp market
+ """
+ api_logger.info(f"Request to delete mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
+
+ try:
+ # 1. Check whether the mcp market exists
+ api_logger.debug(f"Check whether the mcp market exists: {mcp_market_id}")
+ db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
+
+ if not db_mcp_market:
+ api_logger.warning(
+ f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="The mcp market does not exist or you do not have permission to access it"
+ )
+
+ # 2. Deleting mcp market
+ mcp_market_service.delete_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
+ api_logger.info(f"The mcp market has been successfully deleted: (ID: {mcp_market_id})")
+ return success(msg="The mcp market has been successfully deleted")
+ except Exception as e:
+ api_logger.error(f"Failed to delete from the mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
+ raise
diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py
index 83753744..bb1ba526 100644
--- a/api/app/controllers/model_controller.py
+++ b/api/app/controllers/model_controller.py
@@ -328,7 +328,7 @@ async def update_composite_model(
try:
if model_data.type is not None:
- raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
+ raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER)
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
@@ -368,6 +368,9 @@ def update_model(
更新模型配置
"""
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
+
+ if model_data.type is not None or model_data.provider is not None:
+ raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
try:
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
diff --git a/api/app/controllers/user_controller.py b/api/app/controllers/user_controller.py
index 57495a7c..2806da1a 100644
--- a/api/app/controllers/user_controller.py
+++ b/api/app/controllers/user_controller.py
@@ -2,15 +2,23 @@ from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
import uuid
+from app.core.error_codes import BizCode
+from app.core.exceptions import BusinessException
from app.db import get_db
from app.dependencies import get_current_user, get_current_superuser
from app.models.user_model import User
from app.schemas import user_schema
-from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
+from app.schemas.user_schema import (
+ ChangePasswordRequest,
+ AdminChangePasswordRequest,
+ SendEmailCodeRequest,
+ VerifyEmailCodeRequest,
+ VerifyPasswordRequest)
from app.schemas.response_schema import ApiResponse
from app.services import user_service
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
+from app.core.security import verify_password
# 获取API专用日志器
api_logger = get_api_logger()
@@ -92,7 +100,7 @@ def get_current_user_info(
result_schema.current_workspace_name = current_workspace.name
for ws in result.workspaces:
- if ws.workspace_id == current_user.current_workspace_id:
+ if ws.workspace_id == current_user.current_workspace_id and ws.is_active:
result_schema.role = ws.role
break
@@ -120,6 +128,7 @@ def get_tenant_superusers(
return success(data=superusers_schema, msg="租户超管列表获取成功")
+
@router.get("/{user_id}", response_model=ApiResponse)
def get_user_info_by_id(
user_id: uuid.UUID,
@@ -180,4 +189,54 @@ async def admin_change_password(
return success(msg="密码修改成功")
else:
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
- return success(data=generated_password, msg="密码重置成功")
\ No newline at end of file
+ return success(data=generated_password, msg="密码重置成功")
+
+
+@router.post("/verify_pwd", response_model=ApiResponse)
+def verify_pwd(
+ request: VerifyPasswordRequest,
+ current_user: User = Depends(get_current_user),
+):
+ """验证当前用户密码"""
+ api_logger.info(f"用户验证密码请求: {current_user.username}")
+
+ is_valid = verify_password(request.password, current_user.hashed_password)
+ api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
+ if not is_valid:
+ raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
+ return success(data={"valid": is_valid}, msg="验证完成")
+
+
+@router.post("/send-email-code", response_model=ApiResponse)
+async def send_email_code(
+ request: SendEmailCodeRequest,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+):
+ """发送邮箱验证码"""
+ api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
+
+ await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
+
+ api_logger.info(f"邮箱验证码已发送: {current_user.username}")
+ return success(msg="验证码已发送到您的邮箱,请查收")
+
+
+@router.put("/change-email", response_model=ApiResponse)
+async def change_email(
+ request: VerifyEmailCodeRequest,
+ db: Session = Depends(get_db),
+ current_user: User = Depends(get_current_user),
+):
+ """验证验证码并修改邮箱"""
+ api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
+
+ await user_service.verify_and_change_email(
+ db=db,
+ user_id=current_user.id,
+ new_email=request.new_email,
+ code=request.code
+ )
+
+ api_logger.info(f"用户邮箱修改成功: {current_user.username}")
+ return success(msg="邮箱修改成功")
diff --git a/api/app/core/__init__.py b/api/app/core/__init__.py
new file mode 100644
index 00000000..559af4a5
--- /dev/null
+++ b/api/app/core/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2026/2/9 16:24
diff --git a/api/app/core/config.py b/api/app/core/config.py
index b1354b9f..3a0c97b4 100644
--- a/api/app/core/config.py
+++ b/api/app/core/config.py
@@ -193,6 +193,12 @@ class Settings:
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
+ # SMTP Email Configuration
+ SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
+ SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
+ SMTP_USER: str = os.getenv("SMTP_USER", "")
+ SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
+
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
diff --git a/api/app/core/workflow/engine/__init__.py b/api/app/core/workflow/engine/__init__.py
new file mode 100644
index 00000000..bdd44b47
--- /dev/null
+++ b/api/app/core/workflow/engine/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2026/2/9 16:28
diff --git a/api/app/core/workflow/engine/event_stream_handler.py b/api/app/core/workflow/engine/event_stream_handler.py
new file mode 100644
index 00000000..5b7d8de2
--- /dev/null
+++ b/api/app/core/workflow/engine/event_stream_handler.py
@@ -0,0 +1,281 @@
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2026/2/10 13:33
+import datetime
+
+from langchain_core.runnables import RunnableConfig
+from langgraph.graph.state import CompiledStateGraph
+
+from app.core.logging_config import get_logger
+from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
+from app.core.workflow.engine.variable_pool import VariablePool
+
+logger = get_logger(__name__)
+
+
+class EventStreamHandler:
+ def __init__(
+ self,
+ output_coordinator: StreamOutputCoordinator,
+ variable_pool: VariablePool,
+ execution_id: str,
+ ):
+ self.coordinator = output_coordinator
+ self.variable_pool = variable_pool
+ self.execution_id = execution_id
+
+ def update_stream_output_status(self, activate: dict, data: dict):
+ """
+ Update the stream output state of End nodes based on workflow state updates.
+
+ This method checks which nodes/scopes are activated and propagates
+ activation to End nodes accordingly.
+
+ Args:
+ activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
+ data (dict): Mapping of node_id -> node runtime data, including outputs.
+
+ Behavior:
+ For each node in `data`:
+ 1. If the node is activated (`activate[node_id]` is True),
+ retrieve its output status from `runtime_vars`.
+ 2. Call `_update_scope_activate` to propagate the activation
+ to all relevant End nodes and update `self.activate_end`.
+ """
+ for node_id in data.keys():
+ if activate.get(node_id):
+ node_output_status = self.variable_pool.get_value(f"{node_id}.output", default=None, strict=False)
+ self.coordinator.update_scope_activation(node_id, status=node_output_status)
+
+ async def handle_updates_event(
+ self,
+ data: dict,
+ graph: CompiledStateGraph,
+ checkpoint_config: RunnableConfig
+ ):
+ """
+ Handle workflow state update events ("updates") and stream active End node outputs.
+
+ Steps:
+ 1. Retrieve the current graph state.
+ 2. Extract node activation information from the state.
+ 3. Update the activation status of all End nodes.
+ 4. While there is an active End node:
+ - Call _emit_active_chunks() to yield all currently active output segments.
+ - After all segments are processed, update activate_end if there are remaining End nodes.
+ 5. Log a debug message indicating state update received.
+
+ Args:
+ data (dict): The latest node state updates.
+ graph (CompiledStateGraph): The compiled LangGraph state machine.
+ checkpoint_config (RunnableConfig): Configuration for the current execution context.)
+
+ Yields:
+ dict: Streamed output event, each chunk in the format:
+ {"event": "message", "data": {"chunk": ...}}
+ """
+ state = graph.get_state(config=checkpoint_config).values
+ activate = state.get("activate", {})
+
+ self.update_stream_output_status(activate, data)
+ wait = False
+ while self.coordinator.activate_end and not wait:
+ async for msg_event in self.coordinator.emit_activate_chunk(self.variable_pool):
+ yield msg_event
+
+ if self.coordinator.activate_end:
+ wait = True
+ else:
+ self.update_stream_output_status(activate, data)
+
+ logger.debug(f"[UPDATES] Received state update from nodes: {list(data.keys())} "
+ f"- execution_id: {self.execution_id}")
+
+ async def handle_node_chunk_event(self, data: dict):
+ """
+ Handle streaming chunk events from individual nodes ("node_chunk").
+
+ This method processes output segments for the currently active End node.
+ If the segment depends on the provided node_id:
+ - If the node has finished execution (`done=True`), advance the cursor.
+ - If all segments are processed, deactivate the End node.
+ - Otherwise, yield the current chunk as a streaming message.
+
+ Args:
+ data (dict): Node chunk event data, expected keys:
+ - "node_id": ID of the node producing this chunk
+ - "chunk": Chunk of output text
+ - "done": Boolean indicating whether the node finished producing output
+
+ Yields:
+ dict: Streaming message event in the format:
+ {"event": "message", "data": {"chunk": ...}}
+ """
+ node_id = data.get("node_id")
+ if self.coordinator.activate_end:
+ end_info = self.coordinator.current_activate_end_info
+ if not end_info or end_info.cursor >= len(end_info.outputs):
+ return
+ current_output = end_info.outputs[end_info.cursor]
+ if current_output.is_variable and current_output.depends_on_scope(node_id):
+ if data.get("done"):
+ end_info.cursor += 1
+ if end_info.cursor >= len(end_info.outputs):
+ self.coordinator.pop_current_activate_end()
+ else:
+ yield {
+ "event": "message",
+ "data": {
+ "chunk": data.get("chunk")
+ }
+ }
+
+ @staticmethod
+ async def handle_node_error_event(data: dict):
+ """
+ Handle node error events ("node_error") during workflow execution.
+
+ This method streams an error event for a node that has failed. The event
+ contains the node ID, status, input data, elapsed time, and error message.
+
+ Args:
+ data (dict): Node error event data, expected keys:
+ - "node_id": ID of the node that failed
+ - "input_data": The input data that caused the error
+ - "elapsed_time": Execution time before the error occurred
+ - "error": Error message or exception string
+
+ Yields:
+ dict: Node error event in the format:
+ {
+ "event": "node_error",
+ "data": {
+ "node_id": str,
+ "status": "failed",
+ "input": ...,
+ "elapsed_time": float,
+ "output": None,
+ "error": str
+ }
+ }
+ """
+ node_id = data.get("node_id")
+ yield {
+ "event": "node_error",
+ "data": {
+ "node_id": node_id,
+ "status": "failed",
+ "input": data.get("input_data"),
+ "elapsed_time": data.get("elapsed_time"),
+ "output": None,
+ "error": data.get("error")
+ }
+ }
+
+ async def handle_debug_event(self, data: dict, input_data: dict):
+ """
+ Handle debug events ("debug") related to node execution status.
+
+ This method streams debug events for nodes, including when a node starts
+ execution ("node_start") and when it completes execution ("node_end").
+ It filters out nodes with names starting with "nop" as no-operation nodes.
+
+ Args:
+ data (dict): Debug event data, expected keys:
+ - "type": Event type ("task" for start, "task_result" for completion)
+ - "payload": Node-related information, including:
+ - "name": Node name / ID
+ - "input": Node input data (for "task" type)
+ - "result": Node execution result (for "task_result" type)
+ - "timestamp": ISO timestamp string of the event
+ input_data (dict): Original workflow input data (used to get conversation_id)
+
+ Yields:
+ dict: Node debug event in one of the following formats:
+ 1. Node start:
+ {
+ "event": "node_start",
+ "data": {
+ "node_id": str,
+ "conversation_id": str,
+ "execution_id": str,
+ "timestamp": int (ms)
+ }
+ }
+ 2. Node end:
+ {
+ "event": "node_end",
+ "data": {
+ "node_id": str,
+ "conversation_id": str,
+ "execution_id": str,
+ "timestamp": int (ms),
+ "input": dict,
+ "output": Any,
+ "elapsed_time": float
+ }
+ }
+ """
+ event_type = data.get("type")
+ payload = data.get("payload", {})
+ node_name = payload.get("name")
+ conversation_id = input_data.get("conversation_id")
+
+ # Skip no-operation nodes
+ if node_name and node_name.startswith("nop"):
+ return
+
+ if event_type == "task":
+ # Node starts execution
+ inputv = payload.get("input", {})
+ if not inputv.get("activate", {}).get(node_name):
+ return
+
+ logger.info(
+ f"[NODE-START] Node '{node_name}' execution started - execution_id: {self.execution_id}")
+
+ yield {
+ "event": "node_start",
+ "data": {
+ "node_id": node_name,
+ "conversation_id": conversation_id,
+ "execution_id": self.execution_id,
+ "timestamp": int(datetime.datetime.fromisoformat(
+ data.get("timestamp")
+ ).timestamp() * 1000),
+ }
+ }
+ elif event_type == "task_result":
+ # Node execution completed
+ result = payload.get("result", {})
+ if not result.get("activate", {}).get(node_name):
+ return
+
+ logger.info(
+ f"[NODE-END] Node '{node_name}' execution completed - execution_id: {self.execution_id}")
+
+ yield {
+ "event": "node_end",
+ "data": {
+ "node_id": node_name,
+ "conversation_id": conversation_id,
+ "execution_id": self.execution_id,
+ "timestamp": int(datetime.datetime.fromisoformat(
+ data.get("timestamp")
+ ).timestamp() * 1000),
+ "input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
+ "output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
+ "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
+ "token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
+ }
+ }
+
+ @staticmethod
+ async def handle_cycle_item_event(data: dict):
+ yield {
+ "event": "cycle_item",
+ "data": data.get("data")
+ }
+
+
diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py
similarity index 78%
rename from api/app/core/workflow/graph_builder.py
rename to api/app/core/workflow/engine/graph_builder.py
index 8620bb9a..7b5c059c 100644
--- a/api/app/core/workflow/graph_builder.py
+++ b/api/app/core/workflow/engine/graph_builder.py
@@ -1,177 +1,28 @@
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2026/2/10 13:33
import logging
import re
import uuid
from collections import defaultdict
from functools import lru_cache
-from typing import Any
+from typing import Any, Iterable
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.types import Send
-from pydantic import BaseModel, Field
-from app.core.workflow.expression_evaluator import evaluate_condition
-from app.core.workflow.nodes import WorkflowState, NodeFactory
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.stream_output_coordinator import OutputContent, StreamOutputConfig
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes import NodeFactory
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
-from app.core.workflow.variable_pool import VariablePool
+from app.core.workflow.utils.expression_evaluator import evaluate_condition
logger = logging.getLogger(__name__)
-SCOPE_PATTERN = re.compile(
- r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
-)
-
-
-class OutputContent(BaseModel):
- """
- Represents a single output segment of an End node.
-
- An output segment can be either:
- - literal text (static string)
- - a variable placeholder (e.g. {{ node.field }})
-
- Each segment has its own activation state, which is especially
- important in stream mode.
- """
-
- literal: str = Field(
- ...,
- description="Raw output content. Can be literal text or a variable placeholder."
- )
-
- activate: bool = Field(
- ...,
- description=(
- "Whether this output segment is currently active.\n"
- "- True: allowed to be emitted/output\n"
- "- False: blocked until activated by branch control"
- )
- )
-
- is_variable: bool = Field(
- ...,
- description=(
- "Whether this segment represents a variable placeholder.\n"
- "True -> variable (e.g. {{ node.field }})\n"
- "False -> literal text"
- )
- )
-
- _SCOPE: str | None = None
-
- def get_scope(self) -> str:
- self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
- return self._SCOPE
-
- def depends_on_scope(self, scope: str) -> bool:
- """
- Check if this segment depends on a given scope.
-
- Args:
- scope (str): Node ID or special variable prefix (e.g., "sys").
-
- Returns:
- bool: True if this segment references the given scope.
- """
- if self._SCOPE:
- return self._SCOPE == scope
- return self.get_scope() == scope
-
-
-class StreamOutputConfig(BaseModel):
- """
- Streaming output configuration for an End node.
-
- This configuration describes how the End node output behaves in streaming mode,
- including:
- - whether output emission is globally activated
- - which upstream branch/control nodes gate the activation
- - how each parsed output segment is streamed and activated
- """
-
- activate: bool = Field(
- ...,
- description=(
- "Global activation flag for the End node output.\n"
- "When False, output segments should not be emitted even if available.\n"
- "This flag typically becomes True once required control branch conditions "
- "are satisfied."
- )
- )
-
- control_nodes: dict[str, list[str]] = Field(
- ...,
- description=(
- "Control branch conditions for this End node output.\n"
- "Mapping of `branch_node_id -> expected_branch_label`.\n"
- "The End node output becomes globally active when a controlling branch node "
- "reports a matching completion status."
- )
- )
-
- outputs: list[OutputContent] = Field(
- ...,
- description=(
- "Ordered list of output segments parsed from the output template.\n"
- "Each segment represents either a literal text block or a variable placeholder "
- "that may be activated independently."
- )
- )
-
- cursor: int = Field(
- ...,
- description=(
- "Streaming cursor index.\n"
- "Indicates the next output segment index to be emitted.\n"
- "Segments with index < cursor are considered already streamed."
- )
- )
-
- def update_activate(self, scope: str, status=None):
- """
- Update streaming activation state based on an upstream node or special variable.
-
- Args:
- scope (str):
- Identifier of the completed upstream entity.
- - If a control branch node, it should match a key in `control_nodes`.
- - If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
- status (optional):
- Completion status of the control branch node.
- Required when `scope` refers to a control node.
-
- Behavior:
- 1. Control branch nodes:
- - If `scope` matches a key in `control_nodes` and `status` matches the expected
- branch label, the End node output becomes globally active (`activate = True`).
-
- 2. Variable output segments:
- - For each segment that is a variable (`is_variable=True`):
- - If the segment literal references `scope`, mark the segment as active.
- - This applies both to regular node variables (e.g., "node_id.field")
- and special system variables (e.g., "sys.xxx").
-
- Notes:
- - This method does not emit output or advance the streaming cursor.
- - It only updates activation flags based on upstream events or special variables.
- """
-
- # Case 1: resolve control branch dependency
- if scope in self.control_nodes.keys():
- if status is None:
- raise RuntimeError("[Stream Output] Control node activation status not provided")
- if status in self.control_nodes[scope]:
- self.activate = True
-
- # Case 2: activate variable segments related to this node
- for i in range(len(self.outputs)):
- if (
- self.outputs[i].is_variable
- and self.outputs[i].depends_on_scope(scope)
- ):
- self.outputs[i].activate = True
-
class GraphBuilder:
def __init__(
@@ -230,7 +81,7 @@ class GraphBuilder:
raise RuntimeError(f"Node not found: Id={node_id}")
@staticmethod
- def _merge_control_nodes(control_nodes: list[tuple[str, str]]) -> dict[str, list]:
+ def _merge_control_nodes(control_nodes: Iterable[tuple[str, str]]) -> dict[str, list]:
result = defaultdict(list)
for node in control_nodes:
result[node[0]].append(node[1])
diff --git a/api/app/core/workflow/engine/result_builder.py b/api/app/core/workflow/engine/result_builder.py
new file mode 100644
index 00000000..31bccf57
--- /dev/null
+++ b/api/app/core/workflow/engine/result_builder.py
@@ -0,0 +1,104 @@
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2026/2/10 13:33
+from app.core.workflow.engine.variable_pool import VariablePool
+
+
+class WorkflowResultBuilder:
+ def build_final_output(
+ self,
+ result: dict,
+ variable_pool: VariablePool,
+ elapsed_time: float,
+ final_output: str,
+ ):
+ """Construct the final standardized output of the workflow execution.
+
+ This method aggregates node outputs, token usage, conversation and system
+ variables, messages, and other metadata into a consistent dictionary
+ structure suitable for returning from workflow execution.
+
+ Args:
+ result (dict): The runtime state returned by the workflow graph execution.
+ Expected keys include:
+ - "node_outputs" (dict): Outputs of executed nodes.
+ - "messages" (list): Conversation messages exchanged during execution.
+ - "error" (str, optional): Error message if any node failed.
+ variable_pool (VariablePool): Variable Pool
+ elapsed_time (float): Total execution time in seconds.
+ final_output (Any): The aggregated or final output content of the workflow
+ (e.g., combined messages from all End nodes).
+
+ Returns:
+ dict: A dictionary containing the final workflow execution result with keys:
+ - "status": Execution status ("completed")
+ - "output": Aggregated final output content
+ - "variables": Namespace dictionary with:
+ - "conv": Conversation variables
+ - "sys": System variables
+ - "node_outputs": Outputs from all executed nodes
+ - "messages": Conversation messages exchanged
+ - "conversation_id": ID of the current conversation
+ - "elapsed_time": Total execution time in seconds
+ - "token_usage": Aggregated token usage across nodes (if available)
+ - "error": Error message if any occurred during execution
+ """
+ node_outputs = result.get("node_outputs", {})
+ token_usage = self.aggregate_token_usage(node_outputs)
+ conversation_id = variable_pool.get_value("sys.conversation_id")
+
+ return {
+ "status": "completed",
+ "output": final_output,
+ "variables": {
+ "conv": variable_pool.get_all_conversation_vars(),
+ "sys": variable_pool.get_all_system_vars()
+ },
+ "node_outputs": node_outputs,
+ "messages": result.get("messages", []),
+ "conversation_id": conversation_id,
+ "elapsed_time": elapsed_time,
+ "token_usage": token_usage,
+ "error": result.get("error"),
+ }
+
+ @staticmethod
+ def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
+ """
+ Aggregate token usage statistics across all nodes.
+
+ Args:
+ node_outputs (dict): A dictionary of all node outputs.
+
+ Returns:
+ dict | None: Aggregated token usage in the format:
+ {
+ "prompt_tokens": int,
+ "completion_tokens": int,
+ "total_tokens": int
+ }
+ Returns None if no token usage information is available.
+ """
+ total_prompt_tokens = 0
+ total_completion_tokens = 0
+ total_tokens = 0
+ has_token_info = False
+
+ for node_output in node_outputs.values():
+ if isinstance(node_output, dict):
+ token_usage = node_output.get("token_usage")
+ if token_usage and isinstance(token_usage, dict):
+ has_token_info = True
+ total_prompt_tokens += token_usage.get("prompt_tokens", 0)
+ total_completion_tokens += token_usage.get("completion_tokens", 0)
+ total_tokens += token_usage.get("total_tokens", 0)
+
+ if not has_token_info:
+ return None
+
+ return {
+ "prompt_tokens": total_prompt_tokens,
+ "completion_tokens": total_completion_tokens,
+ "total_tokens": total_tokens
+ }
diff --git a/api/app/core/workflow/engine/runtime_schema.py b/api/app/core/workflow/engine/runtime_schema.py
new file mode 100644
index 00000000..e4bf65af
--- /dev/null
+++ b/api/app/core/workflow/engine/runtime_schema.py
@@ -0,0 +1,29 @@
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2026/2/10 13:33
+import uuid
+
+from langchain_core.runnables import RunnableConfig
+from pydantic import BaseModel
+
+
+class ExecutionContext(BaseModel):
+ execution_id: str
+ workspace_id: str
+ user_id: str
+ checkpoint_config: RunnableConfig
+
+ @classmethod
+ def create(cls, execution_id: str, workspace_id: str, user_id: str):
+ return cls(
+ execution_id=execution_id,
+ workspace_id=workspace_id,
+ user_id=user_id,
+ checkpoint_config=RunnableConfig(
+ configurable={
+ "thread_id": uuid.uuid4(),
+ }
+ )
+ )
+
diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py
new file mode 100644
index 00000000..0a4a1463
--- /dev/null
+++ b/api/app/core/workflow/engine/state_manager.py
@@ -0,0 +1,99 @@
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2026/2/10 13:33
+from typing import Annotated, Any
+
+from app.core.workflow.engine.runtime_schema import ExecutionContext
+from app.core.workflow.nodes.enums import NodeType
+
+
+def merge_activate_state(x, y):
+ return {
+ k: x.get(k, False) or y.get(k, False)
+ for k in set(x) | set(y)
+ }
+
+
+def merge_looping_state(x, y):
+ return y if y > x else x
+
+
+class WorkflowState(dict):
+ """Workflow state
+
+ The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
+ """
+ __required_keys__ = frozenset({
+ "messages",
+ "cycle_nodes",
+ "looping",
+ "node_outputs",
+ "execution_id",
+ "workspace_id",
+ "user_id",
+ "activate",
+ })
+ __optional_keys__ = frozenset({
+ "error",
+ "error_node",
+ })
+
+ # List of messages (append mode)
+ messages: Annotated[list[dict[str, str]], lambda x, y: y]
+
+ # Set of loop node IDs, used for assigning values in loop nodes
+ cycle_nodes: list
+ looping: Annotated[int, merge_looping_state]
+
+ # Node outputs (stores execution results of each node for variable references)
+ # Uses a custom merge function to combine new node outputs into the existing dictionary
+ node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
+
+ # Execution context
+ execution_id: str
+ workspace_id: str
+ user_id: str
+
+ # Error information (for error edges)
+ error: str | None
+ error_node: str | None
+
+ # node activate status
+ activate: Annotated[dict[str, bool], merge_activate_state]
+
+
+class WorkflowStateManager:
+ def create_initial_state(
+ self,
+ workflow_config: dict,
+ input_data: dict,
+ execution_context: ExecutionContext,
+ start_node_id: str
+ ) -> WorkflowState:
+ conversation_messages = input_data.get("conv_messages", [])
+
+ return WorkflowState(
+ messages=conversation_messages,
+ node_outputs={},
+ execution_id=execution_context.execution_id,
+ workspace_id=execution_context.workspace_id,
+ user_id=execution_context.user_id,
+ error=None,
+ error_node=None,
+ cycle_nodes=self._identify_cycle_nodes(workflow_config),
+ looping=0,
+ activate={
+ start_node_id: True
+ }
+ )
+
+ @staticmethod
+ def _identify_cycle_nodes(
+ workflow_config: dict
+ ):
+ return [
+ node.get("id")
+ for node in workflow_config.get("nodes")
+ if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
+ ]
diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py
new file mode 100644
index 00000000..5155a76f
--- /dev/null
+++ b/api/app/core/workflow/engine/stream_output_coordinator.py
@@ -0,0 +1,327 @@
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2026/2/9 15:11
+import re
+from typing import AsyncGenerator
+
+from pydantic import BaseModel, Field
+
+from app.core.logging_config import get_logger
+from app.core.workflow.engine.variable_pool import VariablePool
+
+logger = get_logger(__name__)
+
+SCOPE_PATTERN = re.compile(
+ r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
+)
+
+
+class OutputContent(BaseModel):
+ """
+ Represents a single output segment of an End node.
+
+ An output segment can be either:
+ - literal text (static string)
+ - a variable placeholder (e.g. {{ node.field }})
+
+ Each segment has its own activation state, which is especially
+ important in stream mode.
+ """
+
+ literal: str = Field(
+ ...,
+ description="Raw output content. Can be literal text or a variable placeholder."
+ )
+
+ activate: bool = Field(
+ ...,
+ description=(
+ "Whether this output segment is currently active.\n"
+ "- True: allowed to be emitted/output\n"
+ "- False: blocked until activated by branch control"
+ )
+ )
+
+ is_variable: bool = Field(
+ ...,
+ description=(
+ "Whether this segment represents a variable placeholder.\n"
+ "True -> variable (e.g. {{ node.field }})\n"
+ "False -> literal text"
+ )
+ )
+
+ _SCOPE: str | None = None
+
+ def get_scope(self) -> str:
+ self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
+ return self._SCOPE
+
+ def depends_on_scope(self, scope: str) -> bool:
+ """
+ Check if this segment depends on a given scope.
+
+ Args:
+ scope (str): Node ID or special variable prefix (e.g., "sys").
+
+ Returns:
+ bool: True if this segment references the given scope.
+ """
+ if self._SCOPE:
+ return self._SCOPE == scope
+ return self.get_scope() == scope
+
+
+class StreamOutputConfig(BaseModel):
+ """
+ Streaming output configuration for an End node.
+
+ This configuration describes how the End node output behaves in streaming mode,
+ including:
+ - whether output emission is globally activated
+ - which upstream branch/control nodes gate the activation
+ - how each parsed output segment is streamed and activated
+ """
+
+ activate: bool = Field(
+ ...,
+ description=(
+ "Global activation flag for the End node output.\n"
+ "When False, output segments should not be emitted even if available.\n"
+ "This flag typically becomes True once required control branch conditions "
+ "are satisfied."
+ )
+ )
+
+ control_nodes: dict[str, list[str]] = Field(
+ ...,
+ description=(
+ "Control branch conditions for this End node output.\n"
+ "Mapping of `branch_node_id -> expected_branch_label`.\n"
+ "The End node output becomes globally active when a controlling branch node "
+ "reports a matching completion status."
+ )
+ )
+
+ outputs: list[OutputContent] = Field(
+ ...,
+ description=(
+ "Ordered list of output segments parsed from the output template.\n"
+ "Each segment represents either a literal text block or a variable placeholder "
+ "that may be activated independently."
+ )
+ )
+
+ cursor: int = Field(
+ ...,
+ description=(
+ "Streaming cursor index.\n"
+ "Indicates the next output segment index to be emitted.\n"
+ "Segments with index < cursor are considered already streamed."
+ )
+ )
+
+ def update_activate(self, scope: str, status=None):
+ """
+ Update streaming activation state based on an upstream node or special variable.
+
+ Args:
+ scope (str):
+ Identifier of the completed upstream entity.
+ - If a control branch node, it should match a key in `control_nodes`.
+ - If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
+ status (optional):
+ Completion status of the control branch node.
+ Required when `scope` refers to a control node.
+
+ Behavior:
+ 1. Control branch nodes:
+ - If `scope` matches a key in `control_nodes` and `status` matches the expected
+ branch label, the End node output becomes globally active (`activate = True`).
+
+ 2. Variable output segments:
+ - For each segment that is a variable (`is_variable=True`):
+ - If the segment literal references `scope`, mark the segment as active.
+ - This applies both to regular node variables (e.g., "node_id.field")
+ and special system variables (e.g., "sys.xxx").
+
+ Notes:
+ - This method does not emit output or advance the streaming cursor.
+ - It only updates activation flags based on upstream events or special variables.
+ """
+
+ # Case 1: resolve control branch dependency
+ if scope in self.control_nodes.keys():
+ if status is None:
+ raise RuntimeError("[Stream Output] Control node activation status not provided")
+ if status in self.control_nodes[scope]:
+ self.activate = True
+
+ # Case 2: activate variable segments related to this node
+ for i in range(len(self.outputs)):
+ if (
+ self.outputs[i].is_variable
+ and self.outputs[i].depends_on_scope(scope)
+ ):
+ self.outputs[i].activate = True
+
+
+class StreamOutputCoordinator:
+ def __init__(self):
+ self.end_outputs: dict[str, StreamOutputConfig] = {}
+ self.activate_end: str | None = None
+
+ def initialize_end_outputs(
+ self,
+ end_node_map: dict[str, StreamOutputConfig]
+ ):
+ self.end_outputs = end_node_map
+
+ @property
+ def current_activate_end_info(self):
+ return self.end_outputs.get(self.activate_end)
+
+ def pop_current_activate_end(self):
+ self.end_outputs.pop(self.activate_end)
+ self.activate_end = None
+
+ def update_scope_activation(
+ self,
+ scope: str,
+ status: str | None = None
+ ):
+ """
+ Update the activation state of all End nodes based on a completed scope (node or variable).
+
+ Iterates over all End nodes in `self.end_outputs` and calls
+ `update_activate` on each, which may:
+ - Activate variable segments that depend on the completed node/scope.
+ - Activate the entire End node output if any control conditions are met.
+
+ If any End node becomes active and `self.activate_end` is not yet set,
+ this node will be marked as the currently active End node.
+
+ Args:
+ scope (str): The node ID or scope that has completed execution.
+ status (str | None): Optional status of the node (used for branch/control nodes).
+ """
+ for node in self.end_outputs.keys():
+ self.end_outputs[node].update_activate(scope, status)
+ if self.end_outputs[node].activate and self.activate_end is None:
+ self.activate_end = node
+
+ async def emit_activate_chunk(
+ self,
+ variable_pool: VariablePool,
+ force: bool = False
+ ) -> AsyncGenerator[dict[str, str | dict], None]:
+ """
+ Process and yield all currently active output segments for the currently active End node.
+
+ This method handles stream-mode output for an End node by iterating through its output segments
+ (`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
+ `force=True`, which allows all segments to be processed regardless of their activation state.
+
+ Behavior:
+ 1. Iterates from the current `cursor` position to the end of the outputs list.
+ 2. For each segment:
+ - If the segment is literal text (`is_variable=False`), append it directly.
+ - If the segment is a variable (`is_variable=True`), evaluate it using
+ `evaluate_expression` with the given `node_outputs` and `variables`,
+ then transform the result with `_trans_output_string`.
+ 3. Yield a stream event of type "message" containing the processed chunk.
+ 4. Move the `cursor` forward after processing each segment.
+ 5. When all segments have been processed, remove this End node from `end_outputs`
+ and reset `activate_end` to None.
+
+ Args:
+ variable_pool (VariablePool): Pool of variables for evaluating segment values.
+ force (bool, default=False): If True, process segments even if `activate=False`.
+
+ Yields:
+ dict: A stream event of type "message" containing the processed chunk.
+
+ Notes:
+ - Segments that fail evaluation (ValueError) are skipped with a warning logged.
+ - This method only processes the currently active End node (`self.activate_end`).
+ - Use `force=True` for final emission regardless of activation state.
+ """
+ end_info = self.end_outputs[self.activate_end]
+
+ while end_info.cursor < len(end_info.outputs):
+ final_chunk = ''
+ current_segment = end_info.outputs[end_info.cursor]
+
+ if not current_segment.activate and not force:
+ # Stop processing until this segment becomes active
+ break
+
+ # Literal segment
+ if not current_segment.is_variable:
+ final_chunk += current_segment.literal
+ else:
+ # Variable segment: evaluate and transform
+ try:
+ chunk = variable_pool.get_literal(current_segment.literal)
+ final_chunk += chunk
+ except Exception as e:
+ # Log failed evaluation but continue streaming
+ logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
+
+ if final_chunk:
+ logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
+ yield {
+ "event": "message",
+ "data": {
+ "chunk": final_chunk
+ }
+ }
+
+ # Advance cursor after processing
+ end_info.cursor += 1
+
+ if end_info.cursor >= len(end_info.outputs):
+ self.end_outputs.pop(self.activate_end)
+ self.activate_end = None
+
+ async def flush_remaining_chunk(
+ self,
+ variable_pool: VariablePool
+ ) -> AsyncGenerator[dict[str, str | dict], None]:
+ """
+ Flush and yield all remaining output segments from active End nodes.
+
+ This method ensures that any remaining chunks of output, which may not have
+ been emitted during normal streaming due to activation conditions, are fully
+ processed. It is typically called at the end of a workflow to guarantee
+ that all output is delivered.
+
+ Behavior:
+ 1. Filter `end_outputs` to only keep End nodes that are still active.
+ 2. While there is an active End node (`self.activate_end`):
+ - Call `_emit_active_chunks(force=True)` to emit all segments regardless
+ of their activation state.
+ - If the current End node finishes, move to the next active End node
+ if any remain.
+
+ Yields:
+ dict: Streamed output events in the format:
+ {"event": "message", "data": {"chunk": ...}}
+ """
+ # Keep only active End nodes
+ self.end_outputs = {
+ node_id: node_info
+ for node_id, node_info in self.end_outputs.items()
+ if node_info.activate
+ }
+
+ if self.end_outputs or self.activate_end:
+ while self.activate_end:
+ # Force emit all remaining chunks of the active End node
+ async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
+ yield msg_event
+
+ # Move to next active End node if current one is done
+ if not self.activate_end and self.end_outputs:
+ self.activate_end = list(self.end_outputs.keys())[0]
diff --git a/api/app/core/workflow/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py
similarity index 79%
rename from api/app/core/workflow/variable_pool.py
rename to api/app/core/workflow/engine/variable_pool.py
index ae56bcb4..22be08c8 100644
--- a/api/app/core/workflow/variable_pool.py
+++ b/api/app/core/workflow/engine/variable_pool.py
@@ -1,14 +1,7 @@
-"""
-变量池 (Variable Pool)
-
-工作流执行的数据中心,管理所有变量的存储和访问。
-
-变量类型:
-1. 系统变量 (sys.*) - 系统内置变量(execution_id, workspace_id, user_id, message 等)
-2. 节点输出 (node_id.*) - 节点执行结果
-3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
-"""
-
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2025/12/15 19:50
import logging
import re
from asyncio import Lock
@@ -18,7 +11,8 @@ from typing import Any, Generic
from pydantic import BaseModel
-from app.core.workflow.variable.base_variable import VariableType
+from app.core.workflow.engine.runtime_schema import ExecutionContext
+from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.variable.variable_objects import T, create_variable_instance
logger = logging.getLogger(__name__)
@@ -359,3 +353,77 @@ class VariablePool:
f" runtime_vars={len(runtime_vars)}\n"
f")"
)
+
+
+class VariablePoolInitializer:
+ def __init__(self, workflow_config: dict):
+ self.workflow_config = workflow_config
+
+ async def initialize(
+ self,
+ variable_pool: VariablePool,
+ input_data: dict,
+ execution_context: ExecutionContext
+ ) -> None:
+ await self._init_conversation_vars(variable_pool, input_data)
+ await self._init_system_vars(variable_pool, input_data, execution_context)
+
+ async def _init_conversation_vars(
+ self,
+ variable_pool: VariablePool,
+ input_data: dict
+ ):
+ init_conv_vars: list[dict] = self.workflow_config.get("variables") or []
+ runtime_conv_vars: dict[str, Any] = input_data.get("conv", {})
+
+ for var_def in init_conv_vars:
+ var_name = var_def.get("name")
+ var_default = runtime_conv_vars.get(var_name, var_def.get("default"))
+ var_type = var_def.get("type")
+ if var_name:
+ if var_default:
+ var_value = var_default
+ else:
+ var_value = DEFAULT_VALUE(var_type)
+ await variable_pool.new(
+ namespace="conv",
+ key=var_name,
+ value=var_value,
+ var_type=var_type,
+ mut=True
+ )
+
+ @staticmethod
+ async def _init_system_vars(
+ variable_pool: VariablePool,
+ input_data: dict,
+ context: ExecutionContext
+ ):
+ user_message = input_data.get("message") or ""
+ user_files = input_data.get("files") or []
+ conversations = input_data.get("conv_messages", [])
+ conversation_index = len(conversations) // 2
+
+ input_variables = input_data.get("variables") or {}
+ sys_vars = {
+ "message": (user_message, VariableType.STRING),
+ "conversation_index": (conversation_index, VariableType.NUMBER),
+ "conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
+ "execution_id": (context.execution_id, VariableType.STRING),
+ "workspace_id": (context.workspace_id, VariableType.STRING),
+ "user_id": (context.user_id, VariableType.STRING),
+ "input_variables": (input_variables, VariableType.OBJECT),
+ "files": (user_files, VariableType.ARRAY_FILE)
+ }
+ for key, var_def in sys_vars.items():
+ value = var_def[0]
+ var_type = var_def[1]
+ await variable_pool.new(
+ namespace='sys',
+ key=key,
+ value=value,
+ var_type=var_type,
+ mut=False
+ )
+
+
diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py
index bebb67fc..2b554a60 100644
--- a/api/app/core/workflow/executor.py
+++ b/api/app/core/workflow/executor.py
@@ -1,21 +1,20 @@
-"""
-工作流执行器
-
-基于 LangGraph 的工作流执行引擎。
-"""
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2026/2/9 13:51
import datetime
import logging
-import uuid
from typing import Any
-from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
-from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
-from app.core.workflow.nodes import WorkflowState
-from app.core.workflow.nodes.enums import NodeType
-from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
-from app.core.workflow.variable_pool import VariablePool
+from app.core.workflow.engine.event_stream_handler import EventStreamHandler
+from app.core.workflow.engine.graph_builder import GraphBuilder
+from app.core.workflow.engine.result_builder import WorkflowResultBuilder
+from app.core.workflow.engine.runtime_schema import ExecutionContext
+from app.core.workflow.engine.state_manager import WorkflowStateManager
+from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
+from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
logger = logging.getLogger(__name__)
@@ -30,9 +29,7 @@ class WorkflowExecutor:
def __init__(
self,
workflow_config: dict[str, Any],
- execution_id: str,
- workspace_id: str,
- user_id: str,
+ execution_context: ExecutionContext,
):
"""Initialize Workflow Executor.
@@ -41,13 +38,10 @@ class WorkflowExecutor:
Args:
workflow_config (dict): The workflow configuration dictionary.
- execution_id (str): Unique identifier for this workflow execution.
- workspace_id (str): Workspace or project ID.
- user_id (str): User ID executing the workflow.
+ execution_context (ExecutionContext): The workflow execution context
+ include execution_id, workspace_id, user_id, checkpoint_config
Attributes:
- self.nodes (list): List of node definitions from workflow_config.
- self.edges (list): List of edge definitions from workflow_config.
self.execution_config (dict): Optional execution parameters from workflow_config.
self.start_node_id (str | None): ID of the Start node, set after graph build.
self.end_outputs (dict[str, StreamOutputConfig]): End node output configs.
@@ -57,555 +51,18 @@ class WorkflowExecutor:
self.checkpoint_config (RunnableConfig): Config for LangGraph checkpointing.
"""
self.workflow_config = workflow_config
- self.execution_id = execution_id
- self.workspace_id = workspace_id
- self.user_id = user_id
- self.nodes = workflow_config.get("nodes", [])
- self.edges = workflow_config.get("edges", [])
+ self.execution_context = execution_context
self.execution_config = workflow_config.get("execution_config", {})
- self.start_node_id = None
- self.end_outputs: dict[str, StreamOutputConfig] = {}
- self.activate_end: str | None = None
+ self.start_node_id: str | None = None
self.variable_pool: VariablePool | None = None
-
self.graph: CompiledStateGraph | None = None
- self.checkpoint_config = RunnableConfig(
- configurable={
- "thread_id": uuid.uuid4(),
- }
- )
- async def __init_variable_pool(self, input_data: dict[str, Any]):
- """Initialize the variable pool with system, conversation, and input variables.
-
- This method populates the VariablePool instance with:
- - Conversation-level variables (`conv` namespace) from workflow config or provided values.
- - System variables (`sys` namespace) such as message, files, conversation_id, execution_id, workspace_id, user_id, and input_variables.
-
- Args:
- input_data (dict): Input data for workflow execution, may contain:
- - "message": user message (str)
- - "file": list of user-uploaded files
- - "conv": existing conversation variables (dict)
- - "variables": custom variables for the Start node (dict)
- - "conversation_id": conversation identifier
- """
- user_message = input_data.get("message") or ""
- user_files = input_data.get("files") or []
-
- config_variables_list = self.workflow_config.get("variables") or []
- conv_vars = input_data.get("conv", {})
-
- # Initialize conversation variables (conv namespace)
- for var_def in config_variables_list:
- var_name = var_def.get("name")
- var_default = conv_vars.get(var_name, var_def.get("default"))
- var_type = var_def.get("type")
- if var_name:
- if var_default:
- var_value = var_default
- else:
- var_value = DEFAULT_VALUE(var_type)
- await self.variable_pool.new(
- namespace="conv",
- key=var_name,
- value=var_value,
- var_type=var_type,
- mut=True
- )
-
- # Initialize system variables (sys namespace)
- input_variables = input_data.get("variables") or {}
- sys_vars = {
- "message": (user_message, VariableType.STRING),
- "conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
- "execution_id": (self.execution_id, VariableType.STRING),
- "workspace_id": (self.workspace_id, VariableType.STRING),
- "user_id": (self.user_id, VariableType.STRING),
- "input_variables": (input_variables, VariableType.OBJECT),
- "files": (user_files, VariableType.ARRAY_FILE)
- }
- for key, var_def in sys_vars.items():
- value = var_def[0]
- var_type = var_def[1]
- await self.variable_pool.new(
- namespace='sys',
- key=key,
- value=value,
- var_type=var_type,
- mut=False
- )
-
- def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
- """Generate the initial workflow state for execution.
-
- This method prepares the runtime state dictionary with system variables,
- conversation variables, node outputs, loop tracking, and activation flags.
-
- Args:
- input_data (dict): The input payload for workflow execution.
- Expected keys:
- - "conv_messages" (list, optional): Historical conversation messages
- to include in the workflow state.
-
- Returns:
- WorkflowState: A dictionary representing the initialized workflow state
- with the following keys:
- - "messages": List of conversation messages
- - "node_outputs": Empty dict to store outputs of executed nodes
- - "execution_id": Current workflow execution ID
- - "workspace_id": Current workspace ID
- - "user_id": ID of the user triggering execution
- - "error": None initially, will store error message if a node fails
- - "error_node": None initially, will store ID of node that caused error
- - "cycle_nodes": List of node IDs that are of type LOOP or ITERATION
- - "looping": Integer flag indicating loop execution state (0 = not looping)
- - "activate": Dict mapping node IDs to activation status; initially
- only the start node is active
- """
- conversation_messages = input_data.get("conv_messages") or []
-
- return {
- "messages": conversation_messages,
- "node_outputs": {},
- "execution_id": self.execution_id,
- "workspace_id": self.workspace_id,
- "user_id": self.user_id,
- "error": None,
- "error_node": None,
- "cycle_nodes": [
- node.get("id")
- for node in self.workflow_config.get("nodes")
- if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
- ], # loop, iteration node id
- "looping": 0, # loop runing flag, only use in loop node,not use in main loop
- "activate": {
- self.start_node_id: True
- }
- }
-
- def _build_final_output(self, result, elapsed_time, final_output):
- """Construct the final standardized output of the workflow execution.
-
- This method aggregates node outputs, token usage, conversation and system
- variables, messages, and other metadata into a consistent dictionary
- structure suitable for returning from workflow execution.
-
- Args:
- result (dict): The runtime state returned by the workflow graph execution.
- Expected keys include:
- - "node_outputs" (dict): Outputs of executed nodes.
- - "messages" (list): Conversation messages exchanged during execution.
- - "error" (str, optional): Error message if any node failed.
- elapsed_time (float): Total execution time in seconds.
- final_output (Any): The aggregated or final output content of the workflow
- (e.g., combined messages from all End nodes).
-
- Returns:
- dict: A dictionary containing the final workflow execution result with keys:
- - "status": Execution status ("completed")
- - "output": Aggregated final output content
- - "variables": Namespace dictionary with:
- - "conv": Conversation variables
- - "sys": System variables
- - "node_outputs": Outputs from all executed nodes
- - "messages": Conversation messages exchanged
- - "conversation_id": ID of the current conversation
- - "elapsed_time": Total execution time in seconds
- - "token_usage": Aggregated token usage across nodes (if available)
- - "error": Error message if any occurred during execution
- """
- node_outputs = result.get("node_outputs", {})
- token_usage = self._aggregate_token_usage(node_outputs)
- conversation_id = self.variable_pool.get_value("sys.conversation_id")
-
- return {
- "status": "completed",
- "output": final_output,
- "variables": {
- "conv": self.variable_pool.get_all_conversation_vars(),
- "sys": self.variable_pool.get_all_system_vars()
- },
- "node_outputs": node_outputs,
- "messages": result.get("messages", []),
- "conversation_id": conversation_id,
- "elapsed_time": elapsed_time,
- "token_usage": token_usage,
- "error": result.get("error"),
- }
-
- def _update_scope_activate(self, scope, status=None):
- """
- Update the activation state of all End nodes based on a completed scope (node or variable).
-
- Iterates over all End nodes in `self.end_outputs` and calls
- `update_activate` on each, which may:
- - Activate variable segments that depend on the completed node/scope.
- - Activate the entire End node output if any control conditions are met.
-
- If any End node becomes active and `self.activate_end` is not yet set,
- this node will be marked as the currently active End node.
-
- Args:
- scope (str): The node ID or scope that has completed execution.
- status (str | None): Optional status of the node (used for branch/control nodes).
- """
- for node in self.end_outputs.keys():
- self.end_outputs[node].update_activate(scope, status)
- if self.end_outputs[node].activate and self.activate_end is None:
- self.activate_end = node
-
- def _update_stream_output_status(self, activate, data):
- """
- Update the stream output state of End nodes based on workflow state updates.
-
- This method checks which nodes/scopes are activated and propagates
- activation to End nodes accordingly.
-
- Args:
- activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
- data (dict): Mapping of node_id -> node runtime data, including outputs.
-
- Behavior:
- For each node in `data`:
- 1. If the node is activated (`activate[node_id]` is True),
- retrieve its output status from `runtime_vars`.
- 2. Call `_update_scope_activate` to propagate the activation
- to all relevant End nodes and update `self.activate_end`.
- """
- for node_id in data.keys():
- if activate.get(node_id):
- node_output_status = self.variable_pool.get_value(f"{node_id}.output", default=None, strict=False)
- self._update_scope_activate(node_id, status=node_output_status)
-
- async def _emit_active_chunks(
- self,
- force=False
- ):
- """
- Process and yield all currently active output segments for the currently active End node.
-
- This method handles stream-mode output for an End node by iterating through its output segments
- (`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
- `force=True`, which allows all segments to be processed regardless of their activation state.
-
- Behavior:
- 1. Iterates from the current `cursor` position to the end of the outputs list.
- 2. For each segment:
- - If the segment is literal text (`is_variable=False`), append it directly.
- - If the segment is a variable (`is_variable=True`), evaluate it using
- `evaluate_expression` with the given `node_outputs` and `variables`,
- then transform the result with `_trans_output_string`.
- 3. Yield a stream event of type "message" containing the processed chunk.
- 4. Move the `cursor` forward after processing each segment.
- 5. When all segments have been processed, remove this End node from `end_outputs`
- and reset `activate_end` to None.
-
- Args:
- force (bool, default=False): If True, process segments even if `activate=False`.
-
- Yields:
- dict: A stream event of type "message" containing the processed chunk.
-
- Notes:
- - Segments that fail evaluation (ValueError) are skipped with a warning logged.
- - This method only processes the currently active End node (`self.activate_end`).
- - Use `force=True` for final emission regardless of activation state.
- """
-
- end_info = self.end_outputs[self.activate_end]
-
- while end_info.cursor < len(end_info.outputs):
- final_chunk = ''
- current_segment = end_info.outputs[end_info.cursor]
-
- if not current_segment.activate and not force:
- # Stop processing until this segment becomes active
- break
-
- # Literal segment
- if not current_segment.is_variable:
- final_chunk += current_segment.literal
- else:
- # Variable segment: evaluate and transform
- try:
- chunk = self.variable_pool.get_literal(current_segment.literal)
- final_chunk += chunk
- except KeyError:
- # Log failed evaluation but continue streaming
- logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
-
- if final_chunk:
- logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
- yield {
- "event": "message",
- "data": {
- "chunk": final_chunk
- }
- }
-
- # Advance cursor after processing
- end_info.cursor += 1
-
- # Remove End node from active tracking if all segments have been processed
- if end_info.cursor >= len(end_info.outputs):
- self.end_outputs.pop(self.activate_end)
- self.activate_end = None
-
- async def _handle_updates_event(self, data):
- """
- Handle workflow state update events ("updates") and stream active End node outputs.
-
- Steps:
- 1. Retrieve the current graph state.
- 2. Extract node activation information from the state.
- 3. Update the activation status of all End nodes.
- 4. While there is an active End node:
- - Call _emit_active_chunks() to yield all currently active output segments.
- - After all segments are processed, update activate_end if there are remaining End nodes.
- 5. Log a debug message indicating state update received.
-
- Args:
- data (dict): The latest node state updates.
-
- Yields:
- dict: Streamed output event, each chunk in the format:
- {"event": "message", "data": {"chunk": ...}}
- """
- # Get the latest workflow state
- state = self.graph.get_state(config=self.checkpoint_config).values
- activate = state.get("activate", {})
-
- # Update End node activation based on the new state
- self._update_stream_output_status(activate, data)
- wait = False
- while self.activate_end and not wait:
- async for msg_event in self._emit_active_chunks():
- yield msg_event
-
- if self.activate_end:
- wait = True
- else:
- self._update_stream_output_status(activate, data)
-
- logger.debug(f"[UPDATES] Received state update from nodes: {list(data.keys())} "
- f"- execution_id: {self.execution_id}")
-
- async def _handle_node_chunk_event(self, data):
- """
- Handle streaming chunk events from individual nodes ("node_chunk").
-
- This method processes output segments for the currently active End node.
- If the segment depends on the provided node_id:
- - If the node has finished execution (`done=True`), advance the cursor.
- - If all segments are processed, deactivate the End node.
- - Otherwise, yield the current chunk as a streaming message.
-
- Args:
- data (dict): Node chunk event data, expected keys:
- - "node_id": ID of the node producing this chunk
- - "chunk": Chunk of output text
- - "done": Boolean indicating whether the node finished producing output
-
- Yields:
- dict: Streaming message event in the format:
- {"event": "message", "data": {"chunk": ...}}
- """
- node_id = data.get("node_id")
- if self.activate_end:
- end_info = self.end_outputs.get(self.activate_end)
- if not end_info or end_info.cursor >= len(end_info.outputs):
- return
- current_output = end_info.outputs[end_info.cursor]
- if current_output.is_variable and current_output.depends_on_scope(node_id):
- if data.get("done"):
- end_info.cursor += 1
- if end_info.cursor >= len(end_info.outputs):
- self.end_outputs.pop(self.activate_end)
- self.activate_end = None
- else:
- yield {
- "event": "message",
- "data": {
- "chunk": data.get("chunk")
- }
- }
-
- async def _handle_node_error_event(self, data):
- """
- Handle node error events ("node_error") during workflow execution.
-
- This method streams an error event for a node that has failed. The event
- contains the node ID, status, input data, elapsed time, and error message.
-
- Args:
- data (dict): Node error event data, expected keys:
- - "node_id": ID of the node that failed
- - "input_data": The input data that caused the error
- - "elapsed_time": Execution time before the error occurred
- - "error": Error message or exception string
-
- Yields:
- dict: Node error event in the format:
- {
- "event": "node_error",
- "data": {
- "node_id": str,
- "status": "failed",
- "input": ...,
- "elapsed_time": float,
- "output": None,
- "error": str
- }
- }
- """
- node_id = data.get("node_id")
- yield {
- "event": "node_error",
- "data": {
- "node_id": node_id,
- "status": "failed",
- "input": data.get("input_data"),
- "elapsed_time": data.get("elapsed_time"),
- "output": None,
- "error": data.get("error")
- }
- }
-
- async def _handle_debug_event(self, data, input_data):
- """
- Handle debug events ("debug") related to node execution status.
-
- This method streams debug events for nodes, including when a node starts
- execution ("node_start") and when it completes execution ("node_end").
- It filters out nodes with names starting with "nop" as no-operation nodes.
-
- Args:
- data (dict): Debug event data, expected keys:
- - "type": Event type ("task" for start, "task_result" for completion)
- - "payload": Node-related information, including:
- - "name": Node name / ID
- - "input": Node input data (for "task" type)
- - "result": Node execution result (for "task_result" type)
- - "timestamp": ISO timestamp string of the event
- input_data (dict): Original workflow input data (used to get conversation_id)
-
- Yields:
- dict: Node debug event in one of the following formats:
- 1. Node start:
- {
- "event": "node_start",
- "data": {
- "node_id": str,
- "conversation_id": str,
- "execution_id": str,
- "timestamp": int (ms)
- }
- }
- 2. Node end:
- {
- "event": "node_end",
- "data": {
- "node_id": str,
- "conversation_id": str,
- "execution_id": str,
- "timestamp": int (ms),
- "input": dict,
- "output": Any,
- "elapsed_time": float
- }
- }
- """
- event_type = data.get("type")
- payload = data.get("payload", {})
- node_name = payload.get("name")
-
- # Skip no-operation nodes
- if node_name and node_name.startswith("nop"):
- return
-
- if event_type == "task":
- # Node starts execution
- inputv = payload.get("input", {})
- if not inputv.get("activate", {}).get(node_name):
- return
- conversation_id = input_data.get("conversation_id")
- logger.info(f"[NODE-START] Node '{node_name}' execution started - execution_id: {self.execution_id}")
-
- yield {
- "event": "node_start",
- "data": {
- "node_id": node_name,
- "conversation_id": conversation_id,
- "execution_id": self.execution_id,
- "timestamp": int(datetime.datetime.fromisoformat(
- data.get("timestamp")
- ).timestamp() * 1000),
- }
- }
- elif event_type == "task_result":
- # Node execution completed
- result = payload.get("result", {})
- if not result.get("activate", {}).get(node_name):
- return
-
- conversation_id = input_data.get("conversation_id")
- logger.info(f"[NODE-END] Node '{node_name}' execution completed - execution_id: {self.execution_id}")
-
- yield {
- "event": "node_end",
- "data": {
- "node_id": node_name,
- "conversation_id": conversation_id,
- "execution_id": self.execution_id,
- "timestamp": int(datetime.datetime.fromisoformat(
- data.get("timestamp")
- ).timestamp() * 1000),
- "input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
- "output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
- "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
- "token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
- }
- }
-
- async def _flush_remaining_chunk(self):
- """
- Flush and yield all remaining output segments from active End nodes.
-
- This method ensures that any remaining chunks of output, which may not have
- been emitted during normal streaming due to activation conditions, are fully
- processed. It is typically called at the end of a workflow to guarantee
- that all output is delivered.
-
- Behavior:
- 1. Filter `end_outputs` to only keep End nodes that are still active.
- 2. While there is an active End node (`self.activate_end`):
- - Call `_emit_active_chunks(force=True)` to emit all segments regardless
- of their activation state.
- - If the current End node finishes, move to the next active End node
- if any remain.
-
- Yields:
- dict: Streamed output events in the format:
- {"event": "message", "data": {"chunk": ...}}
- """
- # Keep only active End nodes
- self.end_outputs = {
- node_id: node_info
- for node_id, node_info in self.end_outputs.items()
- if node_info.activate
- }
-
- if self.end_outputs or self.activate_end:
- while self.activate_end:
- # Force emit all remaining chunks of the active End node
- async for msg_event in self._emit_active_chunks(force=True):
- yield msg_event
-
- # Move to next active End node if current one is done
- if not self.activate_end and self.end_outputs:
- self.activate_end = list(self.end_outputs.keys())[0]
+ self.variable_initializer = VariablePoolInitializer(workflow_config)
+ self.state_manager = WorkflowStateManager()
+ self.result_builder = WorkflowResultBuilder()
+ self.stream_coordinator = StreamOutputCoordinator()
+ self.event_handler: EventStreamHandler | None = None
def build_graph(self, stream=False) -> CompiledStateGraph:
"""
@@ -624,16 +81,22 @@ class WorkflowExecutor:
Returns:
CompiledStateGraph: The compiled and ready-to-run state graph.
"""
- logger.info(f"Starting workflow graph build: execution_id={self.execution_id}")
+ logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
builder = GraphBuilder(
self.workflow_config,
stream=stream,
)
self.start_node_id = builder.start_node_id
- self.end_outputs = builder.end_node_map
self.variable_pool = builder.variable_pool
self.graph = builder.build()
- logger.info(f"Workflow graph build completed: execution_id={self.execution_id}")
+
+ self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
+ self.event_handler = EventStreamHandler(
+ output_coordinator=self.stream_coordinator,
+ variable_pool=self.variable_pool,
+ execution_id=self.execution_context.execution_id
+ )
+ logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}")
return self.graph
@@ -665,7 +128,7 @@ class WorkflowExecutor:
- token_usage: aggregated token usage if available
- error: error message if any
"""
- logger.info(f"Starting workflow execution: execution_id={self.execution_id}")
+ logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
start_time = datetime.datetime.now()
@@ -673,16 +136,25 @@ class WorkflowExecutor:
graph = self.build_graph()
# Initialize the variable pool with input data
- await self.__init_variable_pool(input_data)
- initial_state = self._prepare_initial_state(input_data)
+ await self.variable_initializer.initialize(
+ variable_pool=self.variable_pool,
+ input_data=input_data,
+ execution_context=self.execution_context
+ )
+ initial_state = self.state_manager.create_initial_state(
+ workflow_config=self.workflow_config,
+ input_data=input_data,
+ execution_context=self.execution_context,
+ start_node_id=self.start_node_id
+ )
# Execute the workflow
try:
- result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
+ result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
# Aggregate output from all End nodes
full_content = ''
- for end_id in self.end_outputs.keys():
+ for end_id in self.stream_coordinator.end_outputs.keys():
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
# Append messages for user and assistant
@@ -703,15 +175,16 @@ class WorkflowExecutor:
elapsed_time = (end_time - start_time).total_seconds()
logger.info(
- f"Workflow execution completed: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
+ f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s")
- return self._build_final_output(result, elapsed_time, full_content)
+ return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
except Exception as e:
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
- logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True)
+ logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
+ exc_info=True)
return {
"status": "failed",
"error": str(e),
@@ -744,15 +217,15 @@ class WorkflowExecutor:
"data": {...}
}
"""
- logger.info(f"Starting workflow execution (streaming): execution_id={self.execution_id}")
+ logger.info(f"Starting workflow execution (streaming): execution_id={self.execution_context.execution_id}")
start_time = datetime.datetime.now()
yield {
"event": "workflow_start",
"data": {
- "execution_id": self.execution_id,
- "workspace_id": self.workspace_id,
+ "execution_id": self.execution_context.execution_id,
+ "workspace_id": self.execution_context.workspace_id,
"conversation_id": input_data.get("conversation_id"),
"timestamp": int(start_time.timestamp() * 1000)
}
@@ -762,18 +235,27 @@ class WorkflowExecutor:
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
- await self.__init_variable_pool(input_data)
- initial_state = self._prepare_initial_state(input_data)
+ await self.variable_initializer.initialize(
+ variable_pool=self.variable_pool,
+ input_data=input_data,
+ execution_context=self.execution_context
+ )
+ initial_state = self.state_manager.create_initial_state(
+ workflow_config=self.workflow_config,
+ input_data=input_data,
+ execution_context=self.execution_context,
+ start_node_id=self.start_node_id
+ )
try:
full_content = ''
- self._update_scope_activate("sys")
+ self.stream_coordinator.update_scope_activation("sys")
# Execute the workflow with streaming
async for event in graph.astream(
initial_state,
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
- config=self.checkpoint_config
+ config=self.execution_context.checkpoint_config
):
# event should be a tuple: (mode, data)
# But let's handle both cases
@@ -782,38 +264,46 @@ class WorkflowExecutor:
else:
# Unexpected format, log and skip
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}"
- f"- execution_id: {self.execution_id}")
+ f"- execution_id: {self.execution_context.execution_id}")
continue
if mode == "custom":
# Handle custom streaming events (chunks from nodes via stream writer)
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
if event_type == "node_chunk":
- async for msg_event in self._handle_node_chunk_event(data):
+ async for msg_event in self.event_handler.handle_node_chunk_event(data):
full_content += msg_event["data"]["chunk"]
yield msg_event
elif event_type == "node_error":
- async for error_event in self._handle_node_error_event(data):
+ async for error_event in self.event_handler.handle_node_error_event(data):
yield error_event
+ elif event_type == "cycle_item":
+ async for cycle_event in self.event_handler.handle_cycle_item_event(data):
+ yield cycle_event
+
elif mode == "debug":
- async for debug_event in self._handle_debug_event(data, input_data):
+ async for debug_event in self.event_handler.handle_debug_event(data, input_data):
yield debug_event
elif mode == "updates":
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
- f"- execution_id: {self.execution_id}")
- async for msg_event in self._handle_updates_event(data):
+ f"- execution_id: {self.execution_context.execution_id}")
+ async for msg_event in self.event_handler.handle_updates_event(
+ data,
+ self.graph,
+ self.execution_context.checkpoint_config
+ ):
full_content += msg_event["data"]['chunk']
yield msg_event
# Flush any remaining chunks
- async for msg_event in self._flush_remaining_chunk():
+ async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
full_content += msg_event["data"]['chunk']
yield msg_event
- result = graph.get_state(self.checkpoint_config).values
+ result = graph.get_state(self.execution_context.checkpoint_config).values
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
@@ -832,24 +322,25 @@ class WorkflowExecutor:
)
logger.info(
f"Workflow execution completed (streaming), "
- f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
+ f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}"
)
yield {
"event": "workflow_end",
- "data": self._build_final_output(result, elapsed_time, full_content)
+ "data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
}
except Exception as e:
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
- logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True)
+ logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
+ exc_info=True)
yield {
"event": "workflow_end",
"data": {
- "execution_id": self.execution_id,
+ "execution_id": self.execution_context.execution_id,
"status": "failed",
"error": str(e),
"elapsed_time": elapsed_time,
@@ -857,46 +348,6 @@ class WorkflowExecutor:
}
}
- @staticmethod
- def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
- """
- Aggregate token usage statistics across all nodes.
-
- Args:
- node_outputs (dict): A dictionary of all node outputs.
-
- Returns:
- dict | None: Aggregated token usage in the format:
- {
- "prompt_tokens": int,
- "completion_tokens": int,
- "total_tokens": int
- }
- Returns None if no token usage information is available.
- """
- total_prompt_tokens = 0
- total_completion_tokens = 0
- total_tokens = 0
- has_token_info = False
-
- for node_output in node_outputs.values():
- if isinstance(node_output, dict):
- token_usage = node_output.get("token_usage")
- if token_usage and isinstance(token_usage, dict):
- has_token_info = True
- total_prompt_tokens += token_usage.get("prompt_tokens", 0)
- total_completion_tokens += token_usage.get("completion_tokens", 0)
- total_tokens += token_usage.get("total_tokens", 0)
-
- if not has_token_info:
- return None
-
- return {
- "prompt_tokens": total_prompt_tokens,
- "completion_tokens": total_completion_tokens,
- "total_tokens": total_tokens
- }
-
async def execute_workflow(
workflow_config: dict[str, Any],
@@ -918,12 +369,15 @@ async def execute_workflow(
Returns:
dict: Workflow execution result.
"""
- executor = WorkflowExecutor(
- workflow_config=workflow_config,
+ execution_context = ExecutionContext.create(
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
)
+ executor = WorkflowExecutor(
+ workflow_config=workflow_config,
+ execution_context=execution_context
+ )
return await executor.execute(input_data)
@@ -947,11 +401,14 @@ async def execute_workflow_stream(
Yields:
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
"""
- executor = WorkflowExecutor(
- workflow_config=workflow_config,
+ execution_context = ExecutionContext.create(
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
)
+ executor = WorkflowExecutor(
+ workflow_config=workflow_config,
+ execution_context=execution_context
+ )
async for event in executor.execute_stream(input_data):
yield event
diff --git a/api/app/core/workflow/nodes/__init__.py b/api/app/core/workflow/nodes/__init__.py
index 885dfbc9..7c24d079 100644
--- a/api/app/core/workflow/nodes/__init__.py
+++ b/api/app/core/workflow/nodes/__init__.py
@@ -6,7 +6,8 @@
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.assigner import AssignerNode
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
+from app.core.workflow.nodes.base_node import BaseNode
+from app.core.workflow.nodes.code import CodeNode
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.http_request import HttpRequestNode
from app.core.workflow.nodes.if_else import IfElseNode
@@ -14,16 +15,14 @@ from app.core.workflow.nodes.jinja_render import JinjaRenderNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
-from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
+from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.tool import ToolNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
-from app.core.workflow.nodes.code import CodeNode
__all__ = [
"BaseNode",
- "WorkflowState",
"LLMNode",
"AgentNode",
"IfElseNode",
diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py
index 0818749c..98d8bb75 100644
--- a/api/app/core/workflow/nodes/agent/node.py
+++ b/api/app/core/workflow/nodes/agent/node.py
@@ -7,14 +7,16 @@ Agent 节点实现
import logging
from typing import Any
+
from langchain_core.messages import AIMessage
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
-from app.services.draft_run_service import DraftRunService
-from app.models import AppRelease
from app.db import get_db
+from app.models import AppRelease
+from app.services.draft_run_service import DraftRunService
logger = logging.getLogger(__name__)
diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py
index e1bb6e9d..be51f81d 100644
--- a/api/app/core/workflow/nodes/assigner/node.py
+++ b/api/app/core/workflow/nodes/assigner/node.py
@@ -2,12 +2,13 @@ import logging
import re
from typing import Any
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
+from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import AssignmentOperator
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py
index 107567e1..a01ffbe3 100644
--- a/api/app/core/workflow/nodes/base_node.py
+++ b/api/app/core/workflow/nodes/base_node.py
@@ -5,57 +5,17 @@ from functools import cached_property
from typing import Any, AsyncGenerator
from langgraph.config import get_stream_writer
-from typing_extensions import TypedDict, Annotated
from app.core.config import settings
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
from app.services.multimodal_service import PROVIDER_STRATEGIES
logger = logging.getLogger(__name__)
-def merge_activate_state(x, y):
- return {
- k: x.get(k, False) or y.get(k, False)
- for k in set(x) | set(y)
- }
-
-
-def merge_looping_state(x, y):
- return y if y > x else x
-
-
-class WorkflowState(TypedDict):
- """Workflow state
-
- The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
- """
- # List of messages (append mode)
- messages: Annotated[list[dict[str, str]], lambda x, y: y]
-
- # Set of loop node IDs, used for assigning values in loop nodes
- cycle_nodes: list
- looping: Annotated[int, merge_looping_state]
-
- # Node outputs (stores execution results of each node for variable references)
- # Uses a custom merge function to combine new node outputs into the existing dictionary
- node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
-
- # Execution context
- execution_id: str
- workspace_id: str
- user_id: str
-
- # Error information (for error edges)
- error: str | None
- error_node: str | None
-
- # node activate status
- activate: Annotated[dict[str, bool], merge_activate_state]
-
-
class BaseNode(ABC):
"""Base class for workflow nodes.
@@ -584,7 +544,7 @@ class BaseNode(ABC):
Returns:
The rendered string with all variables substituted.
"""
- from app.core.workflow.template_renderer import render_template
+ from app.core.workflow.utils.template_renderer import render_template
return render_template(
template=template,
@@ -611,7 +571,7 @@ class BaseNode(ABC):
Returns:
The boolean result of evaluating the expression.
"""
- from app.core.workflow.expression_evaluator import evaluate_condition
+ from app.core.workflow.utils.expression_evaluator import evaluate_condition
return evaluate_condition(
expression=expression,
diff --git a/api/app/core/workflow/nodes/breaker/node.py b/api/app/core/workflow/nodes/breaker/node.py
index 8b772d6a..34162c1d 100644
--- a/api/app/core/workflow/nodes/breaker/node.py
+++ b/api/app/core/workflow/nodes/breaker/node.py
@@ -1,9 +1,10 @@
import logging
from typing import Any
-from app.core.workflow.nodes import BaseNode, WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes import BaseNode
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py
index f6176edf..9303302d 100644
--- a/api/app/core/workflow/nodes/code/node.py
+++ b/api/app/core/workflow/nodes/code/node.py
@@ -6,13 +6,14 @@ import urllib.parse
from string import Template
from textwrap import dedent
from typing import Any
-import urllib.parse
+
import httpx
-from app.core.workflow.nodes import BaseNode, WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes import BaseNode
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py
index 762da847..e4026f2d 100644
--- a/api/app/core/workflow/nodes/cycle_graph/iteration.py
+++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py
@@ -1,14 +1,18 @@
import asyncio
import logging
import re
+import uuid
from typing import Any
+from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
+from langgraph.config import get_stream_writer
-from app.core.workflow.nodes import WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
+from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -25,6 +29,7 @@ class IterationRuntime:
def __init__(
self,
start_id: str,
+ stream: bool,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
@@ -42,6 +47,7 @@ class IterationRuntime:
state: Current workflow state at the point of iteration.
"""
self.start_id = start_id
+ self.stream = stream
self.graph = graph
self.state = state
self.node_id = node_id
@@ -49,6 +55,12 @@ class IterationRuntime:
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
+ self.event_write = get_stream_writer()
+ self.checkpoint = RunnableConfig(
+ configurable={
+ "thread_id": uuid.uuid4()
+ }
+ )
self.output_value = None
self.result: list = []
@@ -91,7 +103,46 @@ class IterationRuntime:
item: The input element for this iteration.
idx: The index of this iteration.
"""
- result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
+ if self.stream:
+ async for event in self.graph.astream(
+ await self._init_iteration_state(item, idx),
+ stream_mode=["debug"],
+ config=self.checkpoint
+ ):
+ if isinstance(event, tuple) and len(event) == 2:
+ mode, data = event
+ else:
+ continue
+ if mode == "debug":
+ event_type = data.get("type")
+ payload = data.get("payload", {})
+ node_name = payload.get("name")
+
+ if node_name and node_name.startswith("nop"):
+ continue
+ if event_type == "task_result":
+ result = payload.get("result", {})
+ if not result.get("activate", {}).get(node_name):
+ continue
+ node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
+ cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
+ self.event_write({
+ "type": "cycle_item",
+ "data": {
+ "cycle_id": self.node_id,
+ "cycle_idx": idx,
+ "node_id": node_name,
+ "input": result.get("node_outputs", {}).get(node_name, {}).get("input")
+ if not cycle_variable else cycle_variable,
+ "output": result.get("node_outputs", {}).get(node_name, {}).get("output")
+ if not cycle_variable else cycle_variable,
+ "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
+ "token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
+ }
+ })
+ result = self.graph.get_state(config=self.checkpoint).values
+ else:
+ result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
output = self.child_variable_pool.get_value(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
@@ -152,16 +203,9 @@ class IterationRuntime:
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
- result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
- child_state.append(result)
- output = self.child_variable_pool.get_value(self.output_value)
+ result = await self.run_task(item, idx)
self.merge_conv_vars()
- if isinstance(output, list) and self.typed_config.flatten:
- self.result.extend(output)
- else:
- self.result.append(output)
- if result["looping"] == 2:
- self.looping = False
+ child_state.append(result)
idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed")
return {
diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py
index 7204a642..cebadfdc 100644
--- a/api/app/core/workflow/nodes/cycle_graph/loop.py
+++ b/api/app/core/workflow/nodes/cycle_graph/loop.py
@@ -1,14 +1,17 @@
import logging
+import uuid
from typing import Any
+from langchain_core.runnables import RunnableConfig
+from langgraph.config import get_stream_writer
from langgraph.graph.state import CompiledStateGraph
-from app.core.workflow.expression_evaluator import evaluate_expression
-from app.core.workflow.nodes import WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
-from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
+from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
-from app.core.workflow.variable_pool import VariablePool
+from app.core.workflow.utils.expression_evaluator import evaluate_expression
logger = logging.getLogger(__name__)
@@ -27,6 +30,7 @@ class LoopRuntime:
def __init__(
self,
start_id: str,
+ stream: bool,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
@@ -46,6 +50,7 @@ class LoopRuntime:
child_variable_pool: A VariablePool instance for managing child node outputs.
"""
self.start_id = start_id
+ self.stream = stream
self.graph = graph
self.state = state
self.node_id = node_id
@@ -53,6 +58,13 @@ class LoopRuntime:
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
+ self.event_write = get_stream_writer()
+
+ self.checkpoint = RunnableConfig(
+ configurable={
+ "thread_id": uuid.uuid4()
+ }
+ )
async def _init_loop_state(self):
"""
@@ -142,10 +154,12 @@ class LoopRuntime:
case _:
raise ValueError(f"Invalid condition: {operator}")
- def merge_conv_vars(self):
+ def merge_conv_vars(self, loopstate):
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables.get("conv", {})
)
+ loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
+ loopstate["node_outputs"][self.node_id] = loop_vars
def evaluate_conditional(self) -> bool:
"""
@@ -175,6 +189,50 @@ class LoopRuntime:
else:
return any(conditions)
+ async def _run(self, loopstate, idx):
+ if self.stream:
+ async for event in self.graph.astream(
+ loopstate,
+ stream_mode=["debug"],
+ config=self.checkpoint
+ ):
+ if isinstance(event, tuple) and len(event) == 2:
+ mode, data = event
+ else:
+ continue
+ if mode == "debug":
+ event_type = data.get("type")
+ payload = data.get("payload", {})
+ node_name = payload.get("name")
+
+ if node_name and node_name.startswith("nop"):
+ continue
+ if event_type == "task_result":
+ result = payload.get("result", {})
+ node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
+ if not result.get("activate", {}).get(node_name):
+ continue
+ cycle_variable = None
+ if node_type == NodeType.CYCLE_START:
+ cycle_variable = loopstate.get("node_outputs", {}).get(self.node_id, {})
+ self.event_write({
+ "type": "cycle_item",
+ "data": {
+ "cycle_id": self.node_id,
+ "cycle_idx": idx,
+ "node_id": node_name,
+ "input": result.get("node_outputs", {}).get(node_name, {}).get("input")
+ if not cycle_variable else cycle_variable,
+ "output": result.get("node_outputs", {}).get(node_name, {}).get("output")
+ if not cycle_variable else cycle_variable,
+ "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
+ "token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
+ }
+ })
+ return self.graph.get_state(config=self.checkpoint).values
+ else:
+ return await self.graph.ainvoke(loopstate)
+
async def run(self):
"""
Execute the loop node until termination conditions are met.
@@ -190,15 +248,17 @@ class LoopRuntime:
loopstate = await self._init_loop_state()
loop_time = self.typed_config.max_loop
child_state = []
+ idx = 0
while not self.evaluate_conditional() and self.looping and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
- result = await self.graph.ainvoke(loopstate)
+ result = await self._run(loopstate, idx)
child_state.append(result)
- self.merge_conv_vars()
+ self.merge_conv_vars(loopstate)
if result["looping"] == 2:
self.looping = False
loop_time -= 1
+ idx += 1
logger.info(f"loop node {self.node_id}: execution completed")
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py
index 6908cb73..f2912e2c 100644
--- a/api/app/core/workflow/nodes/cycle_graph/node.py
+++ b/api/app/core/workflow/nodes/cycle_graph/node.py
@@ -4,14 +4,14 @@ from typing import Any
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
-from app.core.workflow.nodes import WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -136,7 +136,7 @@ class CycleGraphNode(BaseNode):
2. Construct a StateGraph using GraphBuilder in subgraph mode
3. Compile the graph for runtime execution
"""
- from app.core.workflow.graph_builder import GraphBuilder
+ from app.core.workflow.engine.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.child_variable_pool = VariablePool()
builder = GraphBuilder(
@@ -172,6 +172,7 @@ class CycleGraphNode(BaseNode):
if self.node_type == NodeType.LOOP:
return await LoopRuntime(
start_id=self.start_node_id,
+ stream=False,
graph=self.graph,
node_id=self.node_id,
config=self.config,
@@ -182,6 +183,7 @@ class CycleGraphNode(BaseNode):
if self.node_type == NodeType.ITERATION:
return await IterationRuntime(
start_id=self.start_node_id,
+ stream=False,
graph=self.graph,
node_id=self.node_id,
config=self.config,
@@ -190,3 +192,36 @@ class CycleGraphNode(BaseNode):
child_variable_pool=self.child_variable_pool
).run()
raise RuntimeError("Unknown cycle node type")
+
+ async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
+ if self.node_type == NodeType.LOOP:
+ yield {
+ "__final__": True,
+ "result": await LoopRuntime(
+ start_id=self.start_node_id,
+ stream=True,
+ graph=self.graph,
+ node_id=self.node_id,
+ config=self.config,
+ state=state,
+ variable_pool=variable_pool,
+ child_variable_pool=self.child_variable_pool,
+ ).run()
+ }
+ return
+ if self.node_type == NodeType.ITERATION:
+ yield {
+ "__final__": True,
+ "result": await IterationRuntime(
+ start_id=self.start_node_id,
+ stream=True,
+ graph=self.graph,
+ node_id=self.node_id,
+ config=self.config,
+ state=state,
+ variable_pool=variable_pool,
+ child_variable_pool=self.child_variable_pool
+ ).run()
+ }
+ return
+ raise RuntimeError("Unknown cycle node type")
diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py
index a13a8153..2799316a 100644
--- a/api/app/core/workflow/nodes/end/node.py
+++ b/api/app/core/workflow/nodes/end/node.py
@@ -6,9 +6,10 @@ End 节点实现
import logging
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py
index 64fdfcb9..cdb34b57 100644
--- a/api/app/core/workflow/nodes/http_request/node.py
+++ b/api/app/core/workflow/nodes/http_request/node.py
@@ -7,11 +7,12 @@ import httpx
# import filetypes # TODO: File support (Feature)
from httpx import AsyncClient, Response, Timeout
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__file__)
diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py
index 3e5ea22a..894898f0 100644
--- a/api/app/core/workflow/nodes/if_else/config.py
+++ b/api/app/core/workflow/nodes/if_else/config.py
@@ -60,7 +60,7 @@ class IfElseNodeConfig(BaseNodeConfig):
@field_validator("cases")
@classmethod
- def validate_case_number(cls, v, info):
+ def validate_case_number(cls, v):
if len(v) < 1:
raise ValueError("At least one cases are required")
return v
diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py
index 3c6d0e36..29f7085b 100644
--- a/api/app/core/workflow/nodes/if_else/node.py
+++ b/api/app/core/workflow/nodes/if_else/node.py
@@ -2,12 +2,13 @@ import logging
import re
from typing import Any
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py
index 240b003b..e13709d4 100644
--- a/api/app/core/workflow/nodes/jinja_render/node.py
+++ b/api/app/core/workflow/nodes/jinja_render/node.py
@@ -1,12 +1,12 @@
import logging
from typing import Any
-from app.core.workflow.nodes import WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
-from app.core.workflow.template_renderer import TemplateRenderer
+from app.core.workflow.utils.template_renderer import TemplateRenderer
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py
index 1e146721..17f55319 100644
--- a/api/app/core/workflow/nodes/knowledge/node.py
+++ b/api/app/core/workflow/nodes/knowledge/node.py
@@ -6,10 +6,11 @@ from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository, knowledgeshare_repository
diff --git a/api/app/core/workflow/nodes/llm/config.py b/api/app/core/workflow/nodes/llm/config.py
index 1229450f..771262c1 100644
--- a/api/app/core/workflow/nodes/llm/config.py
+++ b/api/app/core/workflow/nodes/llm/config.py
@@ -1,6 +1,7 @@
"""LLM 节点配置"""
from typing import Any
+import uuid
from pydantic import BaseModel, Field, field_validator
@@ -56,7 +57,7 @@ class LLMNodeConfig(BaseNodeConfig):
2. 消息模式:使用 messages 字段(推荐)
"""
- model_id: str = Field(
+ model_id: uuid.UUID = Field(
...,
description="模型配置 ID"
)
@@ -148,7 +149,7 @@ class LLMNodeConfig(BaseNodeConfig):
@field_validator("messages", "prompt")
@classmethod
- def validate_input_mode(cls, v, info):
+ def validate_input_mode(cls, v):
"""验证输入模式:prompt 和 messages 至少有一个"""
# 这个验证在 model_validator 中更合适
return v
diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py
index 761a2e22..fdd5df58 100644
--- a/api/app/core/workflow/nodes/llm/node.py
+++ b/api/app/core/workflow/nodes/llm/node.py
@@ -13,10 +13,11 @@ from langchain_core.messages import AIMessage
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearLLM, RedBearModelConfig
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService
@@ -268,7 +269,7 @@ class LLMNode(BaseNode):
llm = await self._prepare_llm(state, variable_pool, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
- logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
+ # logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 累积完整响应
full_response = ""
diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py
index 654ea0c6..1d42e82e 100644
--- a/api/app/core/workflow/nodes/memory/node.py
+++ b/api/app/core/workflow/nodes/memory/node.py
@@ -1,10 +1,10 @@
from typing import Any
-from app.core.workflow.nodes import WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task
diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py
index 251d6a79..be33d35a 100644
--- a/api/app/core/workflow/nodes/operators.py
+++ b/api/app/core/workflow/nodes/operators.py
@@ -3,9 +3,9 @@ import re
from abc import ABC
from typing import Union, Type, NoReturn, Any
-from app.core.workflow.variable.base_variable import VariableType
+from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.enums import ValueInputType
-from app.core.workflow.variable_pool import VariablePool
+from app.core.workflow.variable.base_variable import VariableType
class TypeTransformer:
diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py
index 9dd91cad..4811c118 100644
--- a/api/app/core/workflow/nodes/parameter_extractor/node.py
+++ b/api/app/core/workflow/nodes/parameter_extractor/node.py
@@ -1,19 +1,18 @@
-import os
import logging
-
-import json_repair
+import os
from typing import Any
+import json_repair
from jinja2 import Template
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearLLM, RedBearModelConfig
-from app.core.workflow.nodes import WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py
index 5b041a6a..e2fd97ae 100644
--- a/api/app/core/workflow/nodes/question_classifier/node.py
+++ b/api/app/core/workflow/nodes/question_classifier/node.py
@@ -1,13 +1,14 @@
import logging
from typing import Any
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
-from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
-from app.core.models import RedBearLLM, RedBearModelConfig
-from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
+from app.core.exceptions import BusinessException
+from app.core.models import RedBearLLM, RedBearModelConfig
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes.base_node import BaseNode
+from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
from app.db import get_db_read
from app.models import ModelType
from app.services.model_service import ModelConfigService
diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py
index db66bc65..a9618f7b 100644
--- a/api/app/core/workflow/nodes/start/node.py
+++ b/api/app/core/workflow/nodes/start/node.py
@@ -7,10 +7,11 @@ Start 节点实现
import logging
from typing import Any
-from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.start.config import StartNodeConfig
-from app.core.workflow.variable_pool import VariablePool
+from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
logger = logging.getLogger(__name__)
diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py
index adc55d87..096f498f 100644
--- a/api/app/core/workflow/nodes/tool/node.py
+++ b/api/app/core/workflow/nodes/tool/node.py
@@ -4,16 +4,17 @@ import re
import uuid
from typing import Any
-from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
+from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.tool.config import ToolNodeConfig
from app.core.workflow.variable.base_variable import VariableType
-from app.core.workflow.variable_pool import VariablePool
-from app.services.tool_service import ToolService
from app.db import get_db_read
+from app.services.tool_service import ToolService
logger = logging.getLogger(__name__)
-TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
+TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
class ToolNode(BaseNode):
diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py
index 56ab4cfb..de82f8ff 100644
--- a/api/app/core/workflow/nodes/variable_aggregator/node.py
+++ b/api/app/core/workflow/nodes/variable_aggregator/node.py
@@ -2,11 +2,11 @@ import logging
import re
from typing import Any
-from app.core.workflow.nodes import WorkflowState
+from app.core.workflow.engine.state_manager import WorkflowState
+from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
-from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
diff --git a/api/app/core/workflow/utils/__init__.py b/api/app/core/workflow/utils/__init__.py
new file mode 100644
index 00000000..559af4a5
--- /dev/null
+++ b/api/app/core/workflow/utils/__init__.py
@@ -0,0 +1,4 @@
+# -*- coding: UTF-8 -*-
+# Author: Eternity
+# @Email: 1533512157@qq.com
+# @Time : 2026/2/9 16:24
diff --git a/api/app/core/workflow/expression_evaluator.py b/api/app/core/workflow/utils/expression_evaluator.py
similarity index 100%
rename from api/app/core/workflow/expression_evaluator.py
rename to api/app/core/workflow/utils/expression_evaluator.py
diff --git a/api/app/core/workflow/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py
similarity index 99%
rename from api/app/core/workflow/template_renderer.py
rename to api/app/core/workflow/utils/template_renderer.py
index 9e2a28e8..236e0840 100644
--- a/api/app/core/workflow/template_renderer.py
+++ b/api/app/core/workflow/utils/template_renderer.py
@@ -5,7 +5,6 @@
"""
import logging
-from collections import defaultdict
from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py
index c846a1c4..47256b75 100644
--- a/api/app/core/workflow/validator.py
+++ b/api/app/core/workflow/validator.py
@@ -187,7 +187,7 @@ class WorkflowValidator:
)
# 8. 验证变量名
- from app.core.workflow.expression_evaluator import ExpressionEvaluator
+ from app.core.workflow.utils.expression_evaluator import ExpressionEvaluator
var_errors = ExpressionEvaluator.validate_variable_names(variables)
errors.extend(var_errors)
diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py
index daf03841..b1b723e9 100644
--- a/api/app/models/__init__.py
+++ b/api/app/models/__init__.py
@@ -9,6 +9,8 @@ from .generic_file_model import GenericFile
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey, ModelBase, LoadBalanceStrategy
from .memory_short_model import ShortTermMemory, LongTermMemory
from .knowledgeshare_model import KnowledgeShare
+from .mcp_market_model import McpMarket
+from .mcp_market_config_model import McpMarketConfig
from .app_model import App
from .agent_app_config_model import AgentConfig
from .app_release_model import AppRelease
@@ -50,6 +52,8 @@ __all__ = [
"ModelType",
"ModelApiKey",
"KnowledgeShare",
+ "McpMarket",
+ "McpMarketConfig",
"App",
"AgentConfig",
"AppRelease",
diff --git a/api/app/models/file_metadata_model.py b/api/app/models/file_metadata_model.py
index baf9bd97..28e87367 100644
--- a/api/app/models/file_metadata_model.py
+++ b/api/app/models/file_metadata_model.py
@@ -35,7 +35,7 @@ class FileMetadata(Base):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True, comment="Tenant ID")
- workspace_id = Column(UUID(as_uuid=True), nullable=False, index=True, comment="Workspace ID")
+ workspace_id = Column(UUID(as_uuid=True), nullable=True, index=True, comment="Workspace ID")
file_key = Column(String(512), nullable=False, unique=True, index=True, comment="Storage file key")
file_name = Column(String(255), nullable=False, comment="Original file name")
file_ext = Column(String(32), nullable=False, comment="File extension")
diff --git a/api/app/models/mcp_market_config_model.py b/api/app/models/mcp_market_config_model.py
new file mode 100644
index 00000000..a7051a91
--- /dev/null
+++ b/api/app/models/mcp_market_config_model.py
@@ -0,0 +1,16 @@
+import datetime
+import uuid
+from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
+from sqlalchemy.dialects.postgresql import UUID
+from app.db import Base
+
+class McpMarketConfig(Base):
+ __tablename__ = "mcp_market_configs"
+
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
+ mcp_market_id = Column(UUID(as_uuid=True), nullable=False, comment="mcp_markets.id")
+ token = Column(String, nullable=True, comment="mcp market token")
+ status = Column(Integer, default=0, comment="connect status(0: Not connected, 1: connected)")
+ tenant_id = Column(UUID(as_uuid=True), nullable=False, comment="tenant.id")
+ created_by = Column(UUID(as_uuid=True), nullable=False, comment="users.id")
+ created_at = Column(DateTime, default=datetime.datetime.now)
\ No newline at end of file
diff --git a/api/app/models/mcp_market_model.py b/api/app/models/mcp_market_model.py
new file mode 100644
index 00000000..95c9cec4
--- /dev/null
+++ b/api/app/models/mcp_market_model.py
@@ -0,0 +1,18 @@
+import datetime
+import uuid
+from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
+from sqlalchemy.dialects.postgresql import UUID
+from app.db import Base
+
+class McpMarket(Base):
+ __tablename__ = "mcp_markets"
+
+ id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
+ name = Column(String, index=True, nullable=False, comment="mcp market name")
+ description = Column(String, index=True, nullable=True, comment="mcp market description")
+ logo_url = Column(String, index=True, nullable=True, comment="logo url")
+ mcp_count = Column(Integer, default=1, comment="mcp count")
+ url = Column(String, index=True, nullable=False, comment="mcp market url")
+ category = Column(String, index=True, nullable=False, comment="category")
+ created_by = Column(UUID(as_uuid=True), nullable=False, comment="users.id")
+ created_at = Column(DateTime, default=datetime.datetime.now)
\ No newline at end of file
diff --git a/api/app/repositories/mcp_market_config_repository.py b/api/app/repositories/mcp_market_config_repository.py
new file mode 100644
index 00000000..ec31becf
--- /dev/null
+++ b/api/app/repositories/mcp_market_config_repository.py
@@ -0,0 +1,72 @@
+import uuid
+import datetime
+from sqlalchemy.orm import Session
+from app.models.mcp_market_config_model import McpMarketConfig
+from app.schemas import mcp_market_config_schema
+from app.core.logging_config import get_db_logger
+
+# Obtain a dedicated logger for the database
+db_logger = get_db_logger()
+
+
+def create_mcp_market_config(db: Session, mcp_market_config: mcp_market_config_schema.McpMarketConfigCreate) -> McpMarketConfig:
+ db_logger.debug(f"Create a mcp market config record: mcp_market_id={mcp_market_config.mcp_market_id}")
+
+ try:
+ db_mcp_market_config = McpMarketConfig(**mcp_market_config.model_dump())
+ db.add(db_mcp_market_config)
+ db.commit()
+ db_logger.info(f"McpMarketConfig record created successfully: {mcp_market_config.mcp_market_id} (ID: {db_mcp_market_config.id})")
+ return db_mcp_market_config
+ except Exception as e:
+ db_logger.error(f"Failed to create a mcp market config record: mcp_market_id={mcp_market_config.mcp_market_id} - {str(e)}")
+ db.rollback()
+ raise
+
+
+def get_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID) -> McpMarketConfig | None:
+ db_logger.debug(f"Query mcp market config based on ID: mcp_market_config_id={mcp_market_config_id}")
+
+ try:
+ db_mcp_market_config = db.query(McpMarketConfig).filter(McpMarketConfig.id == mcp_market_config_id).first()
+ if db_mcp_market_config:
+ db_logger.debug(f"McpMarketConfig query successful: (ID: {mcp_market_config_id})")
+ else:
+ db_logger.debug(f"McpMarketConfig does not exist: mcp_market_config_id={mcp_market_config_id}")
+ return db_mcp_market_config
+ except Exception as e:
+ db_logger.error(f"Failed to query the mcp market config based on the ID: {mcp_market_config_id} - {str(e)}")
+ raise
+
+
+def get_mcp_market_config_by_mcp_market_id(db: Session, mcp_market_id: uuid.UUID, tenant_id: uuid.UUID) -> McpMarketConfig | None:
+ db_logger.debug(f"Query mcp market config based on mcp_market_id: {mcp_market_id}")
+
+ try:
+ db_mcp_market_config = db.query(McpMarketConfig).filter(McpMarketConfig.mcp_market_id == mcp_market_id, McpMarketConfig.tenant_id == tenant_id).first()
+ if db_mcp_market_config:
+ db_logger.debug(f"McpMarketConfig query successful: (mcp_market_id: {mcp_market_id})")
+ else:
+ db_logger.debug(f"McpMarketConfig does not exist: mcp_market_id={mcp_market_id}")
+ return db_mcp_market_config
+ except Exception as e:
+ db_logger.error(f"Failed to query the mcp market config based on the mcp_market_id: {mcp_market_id} - {str(e)}")
+ raise
+
+
+def delete_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID):
+ db_logger.debug(f"Delete McpMarketConfig record: mcp_market_config_id={mcp_market_config_id}")
+
+ try:
+ # First, query the mcp market config information for logging purposes
+ result = db.query(McpMarketConfig).filter(McpMarketConfig.id == mcp_market_config_id).delete()
+ db.commit()
+
+ if result > 0:
+ db_logger.info(f"McpMarketConfig record deleted successfully: (ID: {mcp_market_config_id})")
+ else:
+ db_logger.warning(f"The mcp market config record does not exist, and cannot be deleted: id={mcp_market_config_id}")
+ except Exception as e:
+ db_logger.error(f"Failed to delete mcp market config record: id={mcp_market_config_id} - {str(e)}")
+ db.rollback()
+ raise
diff --git a/api/app/repositories/mcp_market_repository.py b/api/app/repositories/mcp_market_repository.py
new file mode 100644
index 00000000..d5089815
--- /dev/null
+++ b/api/app/repositories/mcp_market_repository.py
@@ -0,0 +1,124 @@
+import uuid
+import datetime
+from sqlalchemy.orm import Session
+from app.models.mcp_market_model import McpMarket
+from app.schemas import mcp_market_schema
+from app.core.logging_config import get_db_logger
+
+# Obtain a dedicated logger for the database
+db_logger = get_db_logger()
+
+
+def get_mcp_markets_paginated(
+ db: Session,
+ filters: list,
+ page: int,
+ pagesize: int,
+ orderby: str = None,
+ desc: bool = False
+) -> tuple[int, list]:
+ """
+ Paged query mcp market (with filtering and sorting)
+ """
+ db_logger.debug(
+ f"Query mcp market in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")
+
+ try:
+ query = db.query(McpMarket)
+
+ # Apply filter conditions
+ for filter_cond in filters:
+ query = query.filter(filter_cond)
+
+ # Calculate the total count (for pagination)
+ total = query.count()
+ db_logger.debug(f"Total number of mcp_market queries: {total}")
+
+ # sort
+ if orderby:
+ order_attr = getattr(McpMarket, orderby, None)
+ if order_attr is not None:
+ if desc:
+ query = query.order_by(order_attr.desc())
+ else:
+ query = query.order_by(order_attr.asc())
+ db_logger.debug(f"sort: {orderby}, desc={desc}")
+
+ # pagination
+ items = query.offset((page - 1) * pagesize).limit(pagesize).all()
+ db_logger.info(
+ f"The mcp market paging query has been successful: total={total}, Number of current page={len(items)}")
+
+ return total, [mcp_market_schema.McpMarket.model_validate(item) for item in items]
+ except Exception as e:
+ db_logger.error(f"Querying mcp_market pagination failed: page={page}, pagesize={pagesize} - {str(e)}")
+ raise
+
+
+def create_mcp_market(db: Session, mcp_market: mcp_market_schema.McpMarketCreate) -> McpMarket:
+ db_logger.debug(f"Create a mcp market record: name={mcp_market.name}")
+
+ try:
+ db_mcp_market = McpMarket(**mcp_market.model_dump())
+ db.add(db_mcp_market)
+ db.commit()
+ db_logger.info(f"McpMarket record created successfully: {mcp_market.name} (ID: {db_mcp_market.id})")
+ return db_mcp_market
+ except Exception as e:
+ db_logger.error(f"Failed to create a mcp market record: title={mcp_market.name} - {str(e)}")
+ db.rollback()
+ raise
+
+
+def get_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID) -> McpMarket | None:
+ db_logger.debug(f"Query mcp market based on ID: mcp_market_id={mcp_market_id}")
+
+ try:
+ db_mcp_market = db.query(McpMarket).filter(McpMarket.id == mcp_market_id).first()
+ if db_mcp_market:
+ db_logger.debug(f"McpMarket query successful: {db_mcp_market.name} (ID: {mcp_market_id})")
+ else:
+ db_logger.debug(f"McpMarket does not exist: mcp_market_id={mcp_market_id}")
+ return db_mcp_market
+ except Exception as e:
+ db_logger.error(f"Failed to query the mcp market based on the ID: mcp_market_id={mcp_market_id} - {str(e)}")
+ raise
+
+
+def get_mcp_market_by_name(db: Session, name: str) -> McpMarket | None:
+ db_logger.debug(f"Query mcp market based on name: name={name}")
+
+ try:
+ db_mcp_market = db.query(McpMarket).filter(McpMarket.name == name).first()
+ if db_mcp_market:
+ db_logger.debug(f"mcp market query successful: {name} (ID: {db_mcp_market.id})")
+ else:
+ db_logger.debug(f"mcp market does not exist: name={name}")
+ return db_mcp_market
+ except Exception as e:
+ db_logger.error(f"Failed to query the mcp market based on the name: {name} - {str(e)}")
+ raise
+
+
+def delete_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID):
+ db_logger.debug(f"Delete McpMarket record: mcp_market_id={mcp_market_id}")
+
+ try:
+ # First, query the mcp market information for logging purposes
+ db_mcp_market = db.query(McpMarket).filter(McpMarket.id == mcp_market_id).first()
+ if db_mcp_market:
+ name = db_mcp_market.name
+ else:
+ name = "unknown"
+
+ result = db.query(McpMarket).filter(McpMarket.id == mcp_market_id).delete()
+ db.commit()
+
+ if result > 0:
+ db_logger.info(f"McpMarket record deleted successfully: {name} (ID: {mcp_market_id})")
+ else:
+ db_logger.warning(f"The mcp market record does not exist, and cannot be deleted: mcp_market_id={mcp_market_id}")
+ except Exception as e:
+ db_logger.error(f"Failed to delete mcp market record: mcp_market_id={mcp_market_id} - {str(e)}")
+ db.rollback()
+ raise
diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py
index f323b30c..2c513e82 100644
--- a/api/app/repositories/model_repository.py
+++ b/api/app/repositories/model_repository.py
@@ -48,13 +48,17 @@ class ModelConfigRepository:
raise
@staticmethod
- def get_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
- """根据名称获取模型配置"""
- db_logger.debug(f"根据名称查询模型配置: name={name}, tenant_id={tenant_id}")
+ def get_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
+ """根据名称和供应商获取模型配置"""
+ db_logger.debug(f"根据名称查询模型配置: name={name}, provider={provider}, tenant_id={tenant_id}")
try:
query = db.query(ModelConfig).filter(ModelConfig.name == name)
+ # 添加供应商过滤
+ if provider:
+ query = query.filter(ModelConfig.provider == provider)
+
# 添加租户过滤
if tenant_id:
query = query.filter(
@@ -69,7 +73,7 @@ class ModelConfigRepository:
db_logger.debug(f"模型配置查询成功: {model.name}")
return model
except Exception as e:
- db_logger.error(f"根据名称查询模型配置失败: name={name} - {str(e)}")
+ db_logger.error(f"根据名称查询模型配置失败: name={name}, provider={provider} - {str(e)}")
raise
@staticmethod
diff --git a/api/app/repositories/workspace_repository.py b/api/app/repositories/workspace_repository.py
index 70ed7521..87b0e20f 100644
--- a/api/app/repositories/workspace_repository.py
+++ b/api/app/repositories/workspace_repository.py
@@ -115,6 +115,7 @@ class WorkspaceRepository:
self.db.query(Workspace)
.join(WorkspaceMember, Workspace.id == WorkspaceMember.workspace_id)
.filter(WorkspaceMember.user_id == user_id)
+ .filter(WorkspaceMember.is_active.is_(True))
.filter(Workspace.is_active.is_(True))
.order_by(Workspace.updated_at.desc())
.all()
diff --git a/api/app/schemas/__init__.py b/api/app/schemas/__init__.py
index 299251f4..96c42ce7 100644
--- a/api/app/schemas/__init__.py
+++ b/api/app/schemas/__init__.py
@@ -8,6 +8,8 @@ from .file_schema import File, FileCreate, FileUpdate
from .tenant_schema import Tenant, TenantCreate, TenantUpdate
from .chunk_schema import ChunkCreate, ChunkUpdate, ChunkRetrieve
from .knowledgeshare_schema import KnowledgeShare, KnowledgeShareCreate
+from .mcp_market_schema import McpMarketCreate, McpMarketUpdate, McpMarket
+from .mcp_market_config_schema import McpMarketConfigCreate, McpMarketConfigUpdate, McpMarketConfig
from .order_schema import CreateOrderRequest, OrderResponse, ExternalOrderResponse
from .app_schema import (
AppChatRequest,
@@ -78,6 +80,12 @@ __all__ = [
"ChunkRetrieve",
"KnowledgeShare",
"KnowledgeShareCreate",
+ "McpMarketCreate",
+ "McpMarketUpdate",
+ "McpMarket",
+ "McpMarketConfigCreate",
+ "McpMarketConfigUpdate",
+ "McpMarketConfig",
"CreateOrderRequest",
"OrderResponse",
"ExternalOrderResponse",
diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py
index 792a32ac..8cf81b92 100644
--- a/api/app/schemas/app_schema.py
+++ b/api/app/schemas/app_schema.py
@@ -439,7 +439,7 @@ class DraftRunRequest(BaseModel):
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
stream: bool = Field(default=False, description="是否流式返回")
- files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
+ files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
class DraftRunResponse(BaseModel):
diff --git a/api/app/schemas/mcp_market_config_schema.py b/api/app/schemas/mcp_market_config_schema.py
new file mode 100644
index 00000000..c33239cf
--- /dev/null
+++ b/api/app/schemas/mcp_market_config_schema.py
@@ -0,0 +1,31 @@
+from pydantic import BaseModel, Field, field_serializer, ConfigDict
+import datetime
+import uuid
+
+
+class McpMarketConfigBase(BaseModel):
+ mcp_market_id: uuid.UUID
+ token: str | None = None
+ status: int | None = None
+ tenant_id: uuid.UUID | None = None
+ created_by: uuid.UUID | None = None
+
+
+class McpMarketConfigCreate(McpMarketConfigBase):
+ pass
+
+
+class McpMarketConfigUpdate(BaseModel):
+ token: str | None = None
+ status: int | None = None
+
+
+class McpMarketConfig(McpMarketConfigBase):
+ id: uuid.UUID
+ created_at: datetime.datetime
+
+ model_config = ConfigDict(from_attributes=True)
+
+ @field_serializer("created_at", when_used="json")
+ def _serialize_created_at(self, dt: datetime.datetime):
+ return int(dt.timestamp() * 1000) if dt else None
diff --git a/api/app/schemas/mcp_market_schema.py b/api/app/schemas/mcp_market_schema.py
new file mode 100644
index 00000000..54d3b35e
--- /dev/null
+++ b/api/app/schemas/mcp_market_schema.py
@@ -0,0 +1,37 @@
+from pydantic import BaseModel, Field, field_serializer, ConfigDict
+import datetime
+import uuid
+
+
+class McpMarketBase(BaseModel):
+ name: str
+ description: str | None = None
+ logo_url: str | None = None
+ mcp_count: int
+ url: str
+ category: str
+ created_by: uuid.UUID | None = None
+
+
+class McpMarketCreate(McpMarketBase):
+ pass
+
+
+class McpMarketUpdate(BaseModel):
+ name: str | None = Field(None)
+ description: str | None = Field(None)
+ logo_url: str | None = Field(None)
+ mcp_count: int | None = Field(None)
+ url: str | None = Field(None)
+ category: str | None = Field(None)
+
+
+class McpMarket(McpMarketBase):
+ id: uuid.UUID
+ created_at: datetime.datetime
+
+ model_config = ConfigDict(from_attributes=True)
+
+ @field_serializer("created_at", when_used="json")
+ def _serialize_created_at(self, dt: datetime.datetime):
+ return int(dt.timestamp() * 1000) if dt else None
diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py
index a2d3650a..0c0bbeed 100644
--- a/api/app/schemas/model_schema.py
+++ b/api/app/schemas/model_schema.py
@@ -25,9 +25,9 @@ class ModelConfigBase(BaseModel):
class ApiKeyCreateNested(BaseModel):
"""用于在创建模型时内嵌创建API Key的Schema"""
- model_name: str = Field(..., description="模型实际名称", max_length=255)
+ model_name: Optional[str] = Field(None, description="模型实际名称", max_length=255)
description: Optional[str] = Field(None, description="备注")
- provider: ModelProvider = Field(..., description="API Key提供商")
+ provider: Optional[str] = Field(None, description="API Key提供商")
api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
@@ -57,6 +57,8 @@ class ModelConfigUpdate(BaseModel):
"""更新模型配置Schema"""
name: Optional[str] = Field(None, description="模型显示名称", max_length=255)
type: Optional[ModelType] = Field(None, description="模型类型")
+ provider: Optional[str] = Field(None, description="供应商")
+ logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
description: Optional[str] = Field(None, description="模型描述")
config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数")
is_active: Optional[bool] = Field(None, description="是否激活")
diff --git a/api/app/schemas/token_schema.py b/api/app/schemas/token_schema.py
index 310e98a0..3bbea35e 100644
--- a/api/app/schemas/token_schema.py
+++ b/api/app/schemas/token_schema.py
@@ -27,4 +27,5 @@ class TokenRequest(BaseModel):
email: EmailStr
password: str
invite: Optional[str] = None
+ username: Optional[str] = None
diff --git a/api/app/schemas/user_schema.py b/api/app/schemas/user_schema.py
index 60f52aaf..7b9e201d 100644
--- a/api/app/schemas/user_schema.py
+++ b/api/app/schemas/user_schema.py
@@ -36,6 +36,28 @@ class AdminChangePasswordRequest(BaseModel):
new_password: Optional[str] = Field(None, min_length=6, description="新密码,至少6位。如果不提供则自动生成随机密码")
+class ChangeEmailRequest(BaseModel):
+ """修改邮箱请求"""
+ password: str = Field(..., description="当前密码")
+ new_email: EmailStr = Field(..., description="新邮箱地址")
+
+
+class SendEmailCodeRequest(BaseModel):
+ """发送邮箱验证码请求"""
+ email: EmailStr = Field(..., description="邮箱地址")
+
+
+class VerifyEmailCodeRequest(BaseModel):
+ """验证邮箱验证码并修改邮箱请求"""
+ new_email: EmailStr = Field(..., description="新邮箱地址")
+ code: str = Field(..., min_length=6, max_length=6, description="验证码")
+
+
+class VerifyPasswordRequest(BaseModel):
+ """验证密码请求"""
+ password: str = Field(..., description="密码")
+
+
class ChangePasswordResponse(BaseModel):
"""修改密码响应"""
message: str
diff --git a/api/app/services/auth_service.py b/api/app/services/auth_service.py
index 877d8d5c..03e1ebc0 100644
--- a/api/app/services/auth_service.py
+++ b/api/app/services/auth_service.py
@@ -129,7 +129,8 @@ def register_user_with_invite(
email: str,
password: str,
invite_token: str,
- workspace_id: str
+ workspace_id: str,
+ username: Optional[str] = None,
) -> User:
"""
使用邀请码注册新用户并加入工作空间
@@ -139,6 +140,7 @@ def register_user_with_invite(
:param password: 用户密码
:param invite_token: 邀请令牌
:param workspace_id: 工作空间ID
+ :param username: 用户名
:return: 创建的用户对象
"""
from app.schemas.user_schema import UserCreate
@@ -154,7 +156,7 @@ def register_user_with_invite(
user_create = UserCreate(
email=email,
password=password,
- username=email.split('@')[0]
+ username=email.split('@')[0] if not username else username
)
user = user_service.create_user(db=db, user=user_create)
logger.info(f"用户创建成功: {user.email} (ID: {user.id})")
diff --git a/api/app/services/email_service.py b/api/app/services/email_service.py
new file mode 100644
index 00000000..d7b255dc
--- /dev/null
+++ b/api/app/services/email_service.py
@@ -0,0 +1,88 @@
+import smtplib
+import re
+import asyncio
+from email.mime.text import MIMEText
+from email.mime.multipart import MIMEMultipart
+from email.header import Header
+from email.utils import formataddr
+from concurrent.futures import ThreadPoolExecutor
+
+from app.core.config import settings
+from app.core.error_codes import BizCode
+from app.core.exceptions import BusinessException
+from app.core.logging_config import get_business_logger
+
+business_logger = get_business_logger()
+
+
+def _send_email_sync(to_email: str, subject: str, html_content: str, text_content: str = None):
+ """同步发送邮件"""
+ smtp_server = settings.SMTP_SERVER
+ smtp_port = settings.SMTP_PORT
+ smtp_user = settings.SMTP_USER
+ smtp_password = settings.SMTP_PASSWORD
+
+ if not smtp_server or not smtp_user or not smtp_password:
+ raise BusinessException("邮件服务未配置", code=BizCode.SERVICE_UNAVAILABLE)
+
+ msg = MIMEMultipart('alternative')
+ msg['Subject'] = Header(subject, "utf-8")
+ from_name = "MemoryBear系统"
+ msg['From'] = formataddr((Header(from_name, 'utf-8').encode(), smtp_user))
+ msg['To'] = Header(to_email, "utf-8")
+
+ if not text_content:
+ text_content = html_content.replace('
', '\n').replace('
', '\n').replace('
', '\n') + text_content = re.sub(r'<.*?>', '', text_content) + text_part = MIMEText(text_content, 'plain', 'utf-8') + msg.attach(text_part) + + html_part = MIMEText(html_content, 'html', 'utf-8') + msg.attach(html_part) + + if smtp_port == 465: + with smtplib.SMTP_SSL(smtp_server, smtp_port, timeout=10) as server: + server.login(smtp_user, smtp_password) + server.send_message(msg) + else: + with smtplib.SMTP(smtp_server, smtp_port, timeout=10) as server: + server.starttls() + server.login(smtp_user, smtp_password) + server.send_message(msg) + + +async def send_email(to_email: str, subject: str, html_content: str, text_content: str = None): + """异步发送邮件""" + to_email = to_email.strip() + if not to_email or not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', to_email): + err_msg = f"收件人邮箱格式无效: {to_email}" + business_logger.error(err_msg) + raise BusinessException(err_msg, code=BizCode.INVALID_PARAMETER) + + try: + loop = asyncio.get_event_loop() + with ThreadPoolExecutor() as executor: + await loop.run_in_executor( + executor, + _send_email_sync, + to_email, + subject, + html_content, + text_content + ) + business_logger.info(f"邮件发送成功: {to_email}") + except smtplib.SMTPAuthenticationError: + err_msg = "SMTP认证失败,请检查SMTP账号/密码是否正确" + business_logger.error(f"邮件发送失败: {to_email} - {err_msg}") + raise BusinessException(err_msg, code=BizCode.UNAUTHORIZED) + except smtplib.SMTPConnectError: + err_msg = "SMTP服务器连接失败,请检查服务器地址/端口是否正确" + business_logger.error(f"邮件发送失败: {to_email} - {err_msg}") + raise BusinessException(err_msg, code=BizCode.SERVICE_UNAVAILABLE) + except TimeoutError: + err_msg = "邮件发送超时,请检查SMTP服务器配置" + business_logger.error(f"邮件发送失败: {to_email} - {err_msg}") + raise BusinessException(err_msg, code=BizCode.BAD_REQUEST) + except Exception as e: + business_logger.error(f"邮件发送失败: {to_email} - {str(e)}") + raise BusinessException(f"邮件发送失败: {str(e)}", code=BizCode.SERVICE_UNAVAILABLE) diff --git a/api/app/services/file_storage_service.py b/api/app/services/file_storage_service.py index 672e1cff..bb9f1894 100644 --- a/api/app/services/file_storage_service.py +++ b/api/app/services/file_storage_service.py @@ -26,7 +26,7 @@ logger = get_business_logger() def generate_file_key( tenant_id: uuid.UUID, - workspace_id: uuid.UUID, + workspace_id: uuid.UUID | None, file_id: uuid.UUID, file_ext: str, ) -> str: @@ -56,8 +56,9 @@ def generate_file_key( # Ensure file_ext starts with a dot if file_ext and not file_ext.startswith('.'): file_ext = f'.{file_ext}' - - return f"{tenant_id}/{workspace_id}/{file_id}{file_ext}" + if workspace_id: + return f"{tenant_id}/{workspace_id}/{file_id}{file_ext}" + return f"{tenant_id}/{file_id}{file_ext}" class FileStorageService: @@ -96,7 +97,7 @@ class FileStorageService: async def upload_file( self, tenant_id: uuid.UUID, - workspace_id: uuid.UUID, + workspace_id: uuid.UUID | None, file_id: uuid.UUID, file_ext: str, content: bytes, diff --git a/api/app/services/mcp_market_config_service.py b/api/app/services/mcp_market_config_service.py new file mode 100644 index 00000000..86485902 --- /dev/null +++ b/api/app/services/mcp_market_config_service.py @@ -0,0 +1,83 @@ +import uuid +from sqlalchemy.orm import Session +from app.models.user_model import User +from app.models.mcp_market_config_model import McpMarketConfig +from app.schemas.mcp_market_config_schema import McpMarketConfigCreate, McpMarketConfigUpdate +from app.repositories import mcp_market_config_repository +from app.core.logging_config import get_business_logger + +# Obtain a dedicated logger for business logic +business_logger = get_business_logger() + + +def create_mcp_market_config( + db: Session, mcp_market_config: McpMarketConfigCreate, current_user: User +) -> McpMarketConfig: + business_logger.info(f"Create a mcp market config base: {mcp_market_config.mcp_market_id}, creator: {current_user.username}") + + try: + mcp_market_config.tenant_id = current_user.tenant_id + mcp_market_config.created_by = current_user.id + business_logger.debug(f"Start creating the mcp market config on mcp_market_id: {mcp_market_config.mcp_market_id}") + db_mcp_market_config = mcp_market_config_repository.create_mcp_market_config( + db=db, mcp_market_config=mcp_market_config + ) + business_logger.info( + f"The mcp market config has been successfully created: {mcp_market_config.mcp_market_id} (ID: {db_mcp_market_config.id}), creator: {current_user.username}") + return db_mcp_market_config + except Exception as e: + business_logger.error(f"Failed to create a mcp marke config: {mcp_market_config.mcp_market_id} - {str(e)}") + raise + + +def get_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID, current_user: User) -> McpMarketConfig | None: + business_logger.debug( + f"Query mcp market config based on ID: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}") + + try: + mcpMarketConfig = mcp_market_config_repository.get_mcp_market_config_by_id(db=db, mcp_market_config_id=mcp_market_config_id) + if mcpMarketConfig: + business_logger.info(f"mcp market config query successful: (ID: {mcp_market_config_id})") + else: + business_logger.warning(f"mcp market config does not exist: mcp_market_config_id={mcp_market_config_id}") + return mcpMarketConfig + except Exception as e: + business_logger.error( + f"Failed to query the mcp market config based on the ID: {mcp_market_config_id} - {str(e)}") + raise + + +def get_mcp_market_config_by_mcp_market_id(db: Session, mcp_market_id: uuid.UUID, current_user: User) -> McpMarketConfig | None: + business_logger.debug( + f"Query mcp market config based on mcp_market_id: {mcp_market_id}, username: {current_user.username}") + + try: + mcpMarketConfig = mcp_market_config_repository.get_mcp_market_config_by_mcp_market_id(db=db, mcp_market_id=mcp_market_id, tenant_id=current_user.tenant_id) + if mcpMarketConfig: + business_logger.info(f"mcp market config query successful: (mcp_market_id: {mcp_market_id})") + else: + business_logger.warning(f"mcp market config does not exist: mcp_market_id={mcp_market_id}") + return mcpMarketConfig + except Exception as e: + business_logger.error( + f"Failed to query the mcp market config based on the mcp_market_id: {mcp_market_id} - {str(e)}") + raise + + +def delete_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID, current_user: User) -> None: + business_logger.info(f"Delete mcp market config: mcp_market_config_id={mcp_market_config_id}, operator: {current_user.username}") + + try: + # First, query the mcp market config information for logging purposes + mcpMarketConfig = mcp_market_config_repository.get_mcp_market_config_by_id(db=db, mcp_market_config_id=mcp_market_config_id) + if mcpMarketConfig: + business_logger.debug(f"Execute mcp market config deletion: (ID: {mcp_market_config_id})") + else: + business_logger.warning(f"The mcp market config to be deleted does not exist: mcp_market_config_id={mcp_market_config_id}") + + mcp_market_config_repository.delete_mcp_market_config_by_id(db=db, mcp_market_config_id=mcp_market_config_id) + business_logger.info( + f"mcp market config record deleted successfully: mcp_market_config_id={mcp_market_config_id}, operator: {current_user.username}") + except Exception as e: + business_logger.error(f"Failed to delete mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}") + raise diff --git a/api/app/services/mcp_market_service.py b/api/app/services/mcp_market_service.py new file mode 100644 index 00000000..6d9d26fc --- /dev/null +++ b/api/app/services/mcp_market_service.py @@ -0,0 +1,109 @@ +import uuid +from sqlalchemy.orm import Session +from app.models.user_model import User +from app.models.mcp_market_model import McpMarket +from app.schemas.mcp_market_schema import McpMarketCreate, McpMarketUpdate +from app.repositories import mcp_market_repository +from app.core.logging_config import get_business_logger + +# Obtain a dedicated logger for business logic +business_logger = get_business_logger() + + +def get_mcp_markets_paginated( + db: Session, + current_user: User, + filters: list, + page: int, + pagesize: int, + orderby: str = None, + desc: bool = False +) -> tuple[int, list]: + business_logger.debug( + f"Query mcp market in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}") + + try: + total, items = mcp_market_repository.get_mcp_markets_paginated( + db=db, + filters=filters, + page=page, + pagesize=pagesize, + orderby=orderby, + desc=desc + ) + business_logger.info( + f"The mcp market paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}") + return total, items + except Exception as e: + business_logger.error(f"Querying mcp market pagination failed: username={current_user.username} - {str(e)}") + raise + + +def create_mcp_market( + db: Session, mcp_market: McpMarketCreate, current_user: User +) -> McpMarket: + business_logger.info(f"Create a mcp market base: {mcp_market.name}, creator: {current_user.username}") + + try: + mcp_market.created_by = current_user.id + business_logger.debug(f"Start creating the mcp market: {mcp_market.name}") + db_mcp_market = mcp_market_repository.create_mcp_market( + db=db, mcp_market=mcp_market + ) + business_logger.info( + f"The mcp market has been successfully created: {mcp_market.name} (ID: {db_mcp_market.id}), creator: {current_user.username}") + return db_mcp_market + except Exception as e: + business_logger.error(f"Failed to create a mcp market: {mcp_market.name} - {str(e)}") + raise + + +def get_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID, current_user: User) -> McpMarket | None: + business_logger.debug( + f"Query mcp market based on ID: mcp_market_id={mcp_market_id}, username: {current_user.username}") + + try: + mcpMarket = mcp_market_repository.get_mcp_market_by_id(db=db, mcp_market_id=mcp_market_id) + if mcpMarket: + business_logger.info(f"mcp market query successful: {mcpMarket.name} (ID: {mcp_market_id})") + else: + business_logger.warning(f"mcp market does not exist: mcp_market_id={mcp_market_id}") + return mcpMarket + except Exception as e: + business_logger.error( + f"Failed to query the mcp market based on the ID: {mcp_market_id} - {str(e)}") + raise + + +def get_mcp_market_by_name(db: Session, name: str, current_user: User) -> McpMarket | None: + business_logger.debug(f"Query mcp market based on name: name={name}, username: {current_user.username}") + + try: + db_mcp_market = mcp_market_repository.get_mcp_market_by_name(db=db, name=name) + if db_mcp_market: + business_logger.info(f"mcp market query successful: {name} (ID: {db_mcp_market.id})") + else: + business_logger.warning(f"mcp market does not exist: name={name}") + return db_mcp_market + except Exception as e: + business_logger.error(f"Failed to query the mcp market based on the name: name={name} - {str(e)}") + raise + + +def delete_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID, current_user: User) -> None: + business_logger.info(f"Delete mcp market: mcp_market_id={mcp_market_id}, operator: {current_user.username}") + + try: + # First, query the mcp market information for logging purposes + mcpMarket = mcp_market_repository.get_mcp_market_by_id(db=db, mcp_market_id=mcp_market_id) + if mcpMarket: + business_logger.debug(f"Execute mcp market deletion: {mcpMarket.name} (ID: {mcp_market_id})") + else: + business_logger.warning(f"The mcp market to be deleted does not exist: mcp_market_id={mcp_market_id}") + + mcp_market_repository.delete_mcp_market_by_id(db=db, mcp_market_id=mcp_market_id) + business_logger.info( + f"mcp market record deleted successfully: mcp_market_id={mcp_market_id}, operator: {current_user.username}") + except Exception as e: + business_logger.error(f"Failed to delete mcp market: mcp_market_id={mcp_market_id} - {str(e)}") + raise diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index d382b1b1..aa8cfbac 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -6,7 +6,7 @@ import math import time import asyncio -from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy +from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy, ModelProvider from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository from app.schemas import model_schema from app.schemas.model_schema import ( @@ -69,9 +69,9 @@ class ModelConfigService: return items @staticmethod - def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig: + def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig: """根据名称获取模型配置""" - model = ModelConfigRepository.get_by_name(db, name, tenant_id=tenant_id) + model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id) if not model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) return model @@ -244,7 +244,7 @@ class ModelConfigService: async def create_model(db: Session, model_data: ModelConfigCreate, tenant_id: uuid.UUID) -> ModelConfig: """创建模型配置""" # 检查名称是否已存在(同租户内) - if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=model_data.provider, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) # 验证配置 @@ -253,8 +253,8 @@ class ModelConfigService: for api_key_data in api_key_data_list: validation_result = await ModelConfigService.validate_model_config( db=db, - model_name=api_key_data.model_name, - provider=api_key_data.provider, + model_name=model_data.name, + provider=model_data.provider, api_key=api_key_data.api_key, api_base=api_key_data.api_base, model_type=model_data.type, # 传递模型类型 @@ -277,6 +277,8 @@ class ModelConfigService: if api_key_datas: for api_key_data in api_key_datas: + api_key_data.model_name = model_data.name + api_key_data.provider = model_data.provider api_key_create_schema = ModelApiKeyCreate( model_config_ids=[model.id], **api_key_data.model_dump() @@ -295,7 +297,7 @@ class ModelConfigService: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) @@ -306,7 +308,7 @@ class ModelConfigService: @staticmethod async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: """创建组合模型""" - if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) # 验证所有 API Key 存在且类型匹配 @@ -341,7 +343,7 @@ class ModelConfigService: "type": model_data.type, "logo": model_data.logo, "description": model_data.description, - "provider": "composite", + "provider": ModelProvider.COMPOSITE, "config": model_data.config, "is_active": model_data.is_active, "is_public": model_data.is_public, @@ -369,6 +371,10 @@ class ModelConfigService: existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) + + if model_data.name and model_data.name != existing_model.name: + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): + raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) if not existing_model.is_composite: raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER) @@ -471,11 +477,14 @@ class ModelApiKeyService: # 从ModelBase获取model_name model_name = model_config.model_base.name if model_config.model_base else model_config.name - # 检查是否存在API Key(包括软删除) - existing_key = db.query(ModelApiKey).filter( + # 检查是否存在API Key(包括软删除),需要考虑tenant_id + existing_key = db.query(ModelApiKey).join( + ModelApiKey.model_configs + ).filter( ModelApiKey.api_key == data.api_key, ModelApiKey.provider == data.provider, - ModelApiKey.model_name == model_name + ModelApiKey.model_name == model_name, + ModelConfig.tenant_id == model_config.tenant_id ).first() if existing_key: @@ -542,11 +551,14 @@ class ModelApiKeyService: if not model_config: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - # 检查API Key是否已存在(包括软删除) - existing_key = db.query(ModelApiKey).filter( + # 检查API Key是否已存在(包括软删除),需要考虑tenant_id + existing_key = db.query(ModelApiKey).join( + ModelApiKey.model_configs + ).filter( ModelApiKey.api_key == api_key_data.api_key, ModelApiKey.provider == api_key_data.provider, - ModelApiKey.model_name == api_key_data.model_name + ModelApiKey.model_name == api_key_data.model_name, + ModelConfig.tenant_id == model_config.tenant_id ).first() if existing_key: diff --git a/api/app/services/user_service.py b/api/app/services/user_service.py index d97e2fb2..22dabed7 100644 --- a/api/app/services/user_service.py +++ b/api/app/services/user_service.py @@ -1,13 +1,18 @@ import datetime +import json import secrets import string + +from pydantic import EmailStr from sqlalchemy.orm import Session import uuid +from app.aioRedis import aio_redis_set, aio_redis_get, aio_redis_delete from app.models.user_model import User from app.repositories import user_repository from app.schemas.user_schema import UserCreate from app.schemas.tenant_schema import TenantCreate +from app.services.email_service import send_email from app.services.tenant_service import TenantService from app.services.session_service import SessionService from app.core.security import get_password_hash, verify_password @@ -563,3 +568,175 @@ def generate_random_password(length: int = 12) -> str: secrets.SystemRandom().shuffle(password) return ''.join(password) + + +def generate_email_code() -> str: + """生成6位数字验证码""" + return ''.join([str(secrets.randbelow(10)) for _ in range(6)]) + + +async def send_email_code_method(db: Session, email: EmailStr, user_id: uuid.UUID): + """发送邮箱验证码""" + business_logger.info(f"发送邮箱验证码: email={email}") + + # 检查发送间隔 + rate_limit_key = f"email_code_rate:{user_id}" + last_send = await aio_redis_get(rate_limit_key) + + if last_send: + raise BusinessException("请稍后再试,验证码发送间隔为1分钟", code=BizCode.RATE_LIMITED) + + # 检查新邮箱是否已被使用 + existing_user = user_repository.get_user_by_email(db=db, email=email) + if existing_user and existing_user.id != user_id: + raise BusinessException("邮箱已被使用", code=BizCode.DUPLICATE_NAME) + + if existing_user and existing_user.id == user_id: + raise BusinessException("新邮箱与当前邮箱相同", code=BizCode.DUPLICATE_NAME) + + # 生成验证码 + code = generate_email_code() + + # 存储到 Redis,5分钟过期 + cache_key = f"email_code:{user_id}:{email}" + await aio_redis_set(cache_key, json.dumps(code), expire=300) + + # 发送邮件 + await send_email( + email, + "邮箱验证码", + f'您的验证码是:{code}
验证码在5分钟内有效。
' + ) + + # 设置发送间隔限制,60秒 + await aio_redis_set(rate_limit_key, "1", expire=60) + + business_logger.info(f"邮箱验证码已发送: {email}") + + +async def verify_and_change_email(db: Session, user_id: uuid.UUID, new_email: EmailStr, code: str) -> User: + """验证验证码并修改邮箱""" + business_logger.info(f"验证并修改邮箱: user_id={user_id}, new_email={new_email}") + + db_user = user_repository.get_user_by_id(db=db, user_id=user_id) + if not db_user: + raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) + + # 验证验证码 + cache_key = f"email_code:{user_id}:{new_email}" + cached_code = await aio_redis_get(cache_key) + + if not cached_code: + raise BusinessException("验证码已过期", code=BizCode.VALIDATION_FAILED) + + if json.loads(cached_code) != code: + raise BusinessException("验证码错误", code=BizCode.VALIDATION_FAILED) + + # 修改邮箱 + db_user.email = new_email + db.commit() + db.refresh(db_user) + + # 删除验证码 + await aio_redis_delete(cache_key) + + # 使所有旧 tokens 失效 + # await SessionService.invalidate_all_user_tokens(str(user_id)) + + business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}") + return db_user + + +# def generate_email_token(user_id: str, old_email: str, new_email: str) -> str: +# """生成邮箱修改token""" +# payload = { +# "user_id": user_id, +# "old_email": old_email, +# "new_email": new_email, +# "exp": datetime.datetime.now(datetime.timezone.utc) + timedelta(hours=24) +# } +# return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM) +# +# +# def verify_email_token(token: str) -> dict: +# """验证邮箱修改token""" +# try: +# payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) +# return payload +# except jwt.ExpiredSignatureError: +# raise BusinessException("链接已过期", code=BizCode.VALIDATION_FAILED) +# except jwt.InvalidTokenError: +# raise BusinessException("无效的链接", code=BizCode.VALIDATION_FAILED) +# +# +# async def request_change_email(db: Session, user_id: uuid.UUID, new_email: EmailStr, current_user: User): +# """请求修改邮箱,发送验证邮件""" +# business_logger.info(f"用户请求修改邮箱: user_id={user_id}, new_email={new_email}") +# +# if current_user.id != user_id: +# raise PermissionDeniedException("只能修改自己的邮箱") +# +# db_user = user_repository.get_user_by_id(db=db, user_id=user_id) +# if not db_user: +# raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) +# +# if db_user.email == new_email: +# raise BusinessException("新邮箱与当前邮箱相同", code=BizCode.VALIDATION_FAILED) +# +# existing_user = user_repository.get_user_by_email(db=db, email=new_email) +# if existing_user and existing_user.id != user_id: +# raise BusinessException("邮箱已被使用", code=BizCode.DUPLICATE_NAME) +# +# token = generate_email_token(str(user_id), db_user.email, new_email) +# +# # 发送确认邮件到旧邮箱 +# old_email_link = f"{settings.BASE_URL}/api/users/email/confirm-email-change?token={token}" +# await send_email( +# db_user.email, +# "确认修改邮箱", +# f'请点击以下链接确认修改邮箱:
确认修改' +# ) +# +# business_logger.info(f"邮箱修改确认邮件已发送到旧邮箱: {db_user.email}") +# +# +# async def confirm_email_change(db: Session, token: str): +# """确认修改邮箱(旧邮箱确认)""" +# payload = verify_email_token(token) +# user_id = uuid.UUID(payload["user_id"]) +# new_email = payload["new_email"] +# +# db_user = user_repository.get_user_by_id(db=db, user_id=user_id) +# if not db_user: +# raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) +# +# # 发送激活邮件到新邮箱 +# activate_link = f"{settings.BASE_URL}/api/users/email/activate-new-email?token={token}" +# await send_email( +# new_email, +# "激活新邮箱", +# f'请点击以下链接激活新邮箱:
激活邮箱' +# ) +# +# business_logger.info(f"新邮箱激活邮件已发送: {new_email}") +# +# +# async def activate_new_email(db: Session, token: str) -> User: +# """激活新邮箱""" +# payload = verify_email_token(token) +# user_id = uuid.UUID(payload["user_id"]) +# new_email = payload["new_email"] +# +# db_user = user_repository.get_user_by_id(db=db, user_id=user_id) +# if not db_user: +# raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND) +# +# db_user.email = new_email +# db.commit() +# db.refresh(db_user) +# +# # 使所有旧 tokens 失效 +# await SessionService.invalidate_all_user_tokens(str(user_id)) +# +# business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}") +# return db_user diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index fb88f804..d06a05d7 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -588,7 +588,7 @@ class WorkflowService: "message_length": len(payload.get("output", "")) } } - case "node_start" | "node_end" | "node_error": + case "node_start" | "node_end" | "node_error" | "cycle_item": return None case _: return event diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index 9ee98fa0..6f102695 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -70,10 +70,10 @@ def delete_workspace_member( _check_workspace_admin_permission(db, workspace_id, user) workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id) if not workspace_member: - raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND) + raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND) if workspace_member.workspace_id != workspace_id: - raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND) + raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_NOT_FOUND) try: workspace_member.is_active = False diff --git a/api/app/tasks.py b/api/app/tasks.py index 5202bf60..d408a0da 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,4 +1,5 @@ import asyncio +from concurrent.futures import ThreadPoolExecutor import json import os import re @@ -13,7 +14,6 @@ from typing import Any, Dict, List, Optional import redis import requests -import trio # Import a unified Celery instance from app.celery_app import celery_app @@ -66,6 +66,10 @@ def parse_document(file_path: str, document_id: uuid.UUID): """ Document parsing, vectorization, and storage """ + # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) + import trio + import importlib + importlib.reload(trio) db = next(get_db()) # Manually call the generator db_document = None db_knowledge = None @@ -292,6 +296,10 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): """ build knowledge graph """ + # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) + import trio + import importlib + importlib.reload(trio) db = next(get_db()) # Manually call the generator db_documents = None db_knowledge = None @@ -362,7 +370,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n") return result - try: + def sync_task(): trio.run( lambda: _run( row=task, @@ -377,8 +385,15 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): with_community=with_community, ) ) + try: + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(sync_task) + future.result() # Blocks until the task completes except Exception as e: print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n") + finally: + if db: + db.close() print(f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)") result = f"build knowledge graph '{db_knowledge.name}' processed successfully." @@ -389,7 +404,8 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): result = f"build knowledge grap '{db_knowledge.name}' failed." return result finally: - db.close() + if db: + db.close() @celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb") diff --git a/api/app/version_info.json b/api/app/version_info.json index 991369d7..7d82eabc 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,68 @@ { + "v0.2.5": { + "introduction": { + "codeName": "行云", + "releaseDate": "2026-2-26", + "upgradePosition": "🐻 精炼根基,优化核心用户体验与系统稳定性", + "coreUpgrades": [ + "1. 用户体验与国际化 🎨