Merge pull request #421 from SuanmoSuanyangTechnology/release/v0.2.5
Release/v0.2.5
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -37,5 +37,4 @@ tika-server*.jar*
|
||||
cl100k_base.tiktoken
|
||||
libssl*.deb
|
||||
|
||||
sandbox/lib/seccomp_python/target
|
||||
sandbox/lib/seccomp_nodejs/target
|
||||
sandbox/lib/seccomp_redbear/target
|
||||
|
||||
@@ -19,6 +19,8 @@ from . import (
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
mcp_market_controller,
|
||||
mcp_market_config_controller,
|
||||
memory_agent_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_episodic_controller,
|
||||
@@ -60,6 +62,8 @@ manager_router.include_router(model_controller.router)
|
||||
manager_router.include_router(file_controller.router)
|
||||
manager_router.include_router(document_controller.router)
|
||||
manager_router.include_router(knowledge_controller.router)
|
||||
manager_router.include_router(mcp_market_controller.router)
|
||||
manager_router.include_router(mcp_market_config_controller.router)
|
||||
manager_router.include_router(chunk_controller.router)
|
||||
manager_router.include_router(test_controller.router)
|
||||
manager_router.include_router(knowledgeshare_controller.router)
|
||||
|
||||
@@ -61,6 +61,7 @@ async def login_for_access_token(
|
||||
user = auth_service.register_user_with_invite(
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
|
||||
336
api/app/controllers/mcp_market_config_controller.py
Normal file
336
api/app/controllers/mcp_market_config_controller.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
import requests
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
from modelscope.hub.errors import raise_for_http_status
|
||||
from modelscope.hub.mcp_api import MCPApi
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import mcp_market_config_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_config_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/mcp_market_configs",
|
||||
tags=["mcp_market_configs"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mcp_servers", response_model=ApiResponse)
|
||||
async def get_mcp_servers(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (Optional search query string,e.g. Chinese service name, English service name, author/owner username)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the mcp servers list in pages
|
||||
- Support keyword search for name,author,owner
|
||||
- Return paging metadata + mcp server list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp server list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 3. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
'filter': {},
|
||||
'page_number': page,
|
||||
'page_size': pagesize,
|
||||
'search': keywords
|
||||
}
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(token)
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=api.builder_headers(api.headers),
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"mFailed to get MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get MCP servers: {str(e)}"
|
||||
)
|
||||
|
||||
data = api._handle_response(r)
|
||||
total = data.get('total_count', 0)
|
||||
mcp_server_list = data.get('mcp_server_list', [])
|
||||
# items = [{
|
||||
# 'name': item.get('name', ''),
|
||||
# 'id': item.get('id', ''),
|
||||
# 'description': item.get('description', '')
|
||||
# } for item in mcp_server_list]
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": mcp_server_list,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/mcp_server", response_model=ApiResponse)
|
||||
async def get_mcp_server(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
server_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get detailed information for a specific MCP Server
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp server: tenant_id={current_user.tenant_id}, mcp_market_config_id={mcp_market_config_id}, server_id={server_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. Get detailed information for a specific MCP Server
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
result = api.get_mcp_server(server_id=server_id)
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.post("/mcp_market_config", response_model=ApiResponse)
|
||||
async def create_mcp_market_config(
|
||||
create_data: mcp_market_config_schema.McpMarketConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create mcp market config
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market config: mcp_market_id={create_data.mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
||||
if db_mcp_market_config_exist:
|
||||
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||
)
|
||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the mcp market config failed: {create_data.mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market config information based on mcp_market_config_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Obtain details of the mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="Successfully obtained mcp market config information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market config query failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/mcp_market_id/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market_config_by_mcp_market_id(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market config information based on mcp_market_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market config: mcp_market_id={mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: mcp_market_id={mcp_market_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="Successfully obtained mcp market config information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market config query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def update_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
update_data: mcp_market_config_schema.McpMarketConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# 1. Check if the mcp market config exists
|
||||
api_logger.debug(f"Query the mcp market config to be updated: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_mcp_market_config, field):
|
||||
old_value = getattr(db_mcp_market_config, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_mcp_market_config, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 3. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market_config)
|
||||
api_logger.info(f"The mcp market config has been successfully updated: (ID: {db_mcp_market_config.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"The mcp market config update failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"The mcp market config update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return the updated mcp market config
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def delete_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete mcp market config
|
||||
"""
|
||||
api_logger.info(f"Request to delete mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the mcp market config exists
|
||||
api_logger.debug(f"Check whether the mcp market config exists: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Deleting mcp market config
|
||||
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
api_logger.info(f"The mcp market config has been successfully deleted: (ID: {mcp_market_config_id})")
|
||||
return success(msg="The mcp market config has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
262
api/app/controllers/mcp_market_controller.py
Normal file
262
api/app/controllers/mcp_market_controller.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import mcp_market_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/mcp_markets",
|
||||
tags=["mcp_markets"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mcp_markets", response_model=ApiResponse)
|
||||
async def get_mcp_markets(
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: category, created_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (mcp_market base name)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the mcp markets list in pages
|
||||
- Support keyword search for name,description
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + mcp_market list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp market list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Construct query conditions
|
||||
filters = []
|
||||
|
||||
# Keyword search (fuzzy matching of mcp market name,description)
|
||||
if keywords:
|
||||
api_logger.debug(f"Add keyword search criteria: {keywords}")
|
||||
filters.append(
|
||||
or_(
|
||||
mcp_market_model.McpMarket.name.ilike(f"%{keywords}%"),
|
||||
mcp_market_model.McpMarket.description.ilike(f"%{keywords}%")
|
||||
)
|
||||
)
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug("Start executing mcp market paging query")
|
||||
total, items = mcp_market_service.get_mcp_markets_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"mcp market query successful: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market query failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Query failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=jsonable_encoder(result), msg="Query of mcp market list successful")
|
||||
|
||||
|
||||
@router.post("/mcp_market", response_model=ApiResponse)
|
||||
async def create_mcp_market(
|
||||
create_data: mcp_market_schema.McpMarketCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create mcp market
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market: name={create_data.name}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market: {create_data.name}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=create_data.name, current_user=current_user)
|
||||
if db_mcp_market_exist:
|
||||
api_logger.warning(f"The mcp market name already exists: {create_data.name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market name already exists: {create_data.name}"
|
||||
)
|
||||
db_mcp_market = mcp_market_service.create_mcp_market(db=db, mcp_market=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market has been successfully created: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="The mcp market has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the mcp market failed: {create_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market information based on mcp_market_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Obtain details of the mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market information from the database
|
||||
api_logger.debug(f"Query mcp market: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market query successful: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="Successfully obtained mcp market information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def update_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
update_data: mcp_market_schema.McpMarketUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# 1. Check if the mcp market exists
|
||||
api_logger.debug(f"Query the mcp market to be updated: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(
|
||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. not updating the name (name already exists)
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
if "name" in update_dict:
|
||||
name = update_dict["name"]
|
||||
if name != db_mcp_market.name:
|
||||
# Check if the mcp market name already exists
|
||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=name, current_user=current_user)
|
||||
if db_mcp_market_exist:
|
||||
api_logger.warning(f"The mcp market name already exists: {name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market name already exists: {name}"
|
||||
)
|
||||
# 3. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market fields: {mcp_market_id}")
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_mcp_market, field):
|
||||
old_value = getattr(db_mcp_market, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_mcp_market, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 4. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market)
|
||||
api_logger.info(f"The mcp market has been successfully updated: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"The mcp market update failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"The mcp market update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 5. Return the updated mcp market
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="The mcp market information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def delete_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete mcp market
|
||||
"""
|
||||
api_logger.info(f"Request to delete mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the mcp market exists
|
||||
api_logger.debug(f"Check whether the mcp market exists: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(
|
||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Deleting mcp market
|
||||
mcp_market_service.delete_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
api_logger.info(f"The mcp market has been successfully deleted: (ID: {mcp_market_id})")
|
||||
return success(msg="The mcp market has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
@@ -328,7 +328,7 @@ async def update_composite_model(
|
||||
|
||||
try:
|
||||
if model_data.type is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER)
|
||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
@@ -368,6 +368,9 @@ def update_model(
|
||||
更新模型配置
|
||||
"""
|
||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
if model_data.type is not None or model_data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
|
||||
@@ -2,15 +2,23 @@ from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.models.user_model import User
|
||||
from app.schemas import user_schema
|
||||
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
|
||||
from app.schemas.user_schema import (
|
||||
ChangePasswordRequest,
|
||||
AdminChangePasswordRequest,
|
||||
SendEmailCodeRequest,
|
||||
VerifyEmailCodeRequest,
|
||||
VerifyPasswordRequest)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.security import verify_password
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -92,7 +100,7 @@ def get_current_user_info(
|
||||
result_schema.current_workspace_name = current_workspace.name
|
||||
|
||||
for ws in result.workspaces:
|
||||
if ws.workspace_id == current_user.current_workspace_id:
|
||||
if ws.workspace_id == current_user.current_workspace_id and ws.is_active:
|
||||
result_schema.role = ws.role
|
||||
break
|
||||
|
||||
@@ -120,6 +128,7 @@ def get_tenant_superusers(
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ApiResponse)
|
||||
def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
@@ -180,4 +189,54 @@ async def admin_change_password(
|
||||
return success(msg="密码修改成功")
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
|
||||
|
||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||
def verify_pwd(
|
||||
request: VerifyPasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""验证当前用户密码"""
|
||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||
|
||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||
if not is_valid:
|
||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg="验证完成")
|
||||
|
||||
|
||||
@router.post("/send-email-code", response_model=ApiResponse)
|
||||
async def send_email_code(
|
||||
request: SendEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""发送邮箱验证码"""
|
||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||
|
||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||
|
||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
||||
|
||||
|
||||
@router.put("/change-email", response_model=ApiResponse)
|
||||
async def change_email(
|
||||
request: VerifyEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""验证验证码并修改邮箱"""
|
||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||
|
||||
await user_service.verify_and_change_email(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
new_email=request.new_email,
|
||||
code=request.code
|
||||
)
|
||||
|
||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||
return success(msg="邮箱修改成功")
|
||||
|
||||
4
api/app/core/__init__.py
Normal file
4
api/app/core/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:24
|
||||
@@ -193,6 +193,12 @@ class Settings:
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
|
||||
# SMTP Email Configuration
|
||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
||||
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
||||
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
||||
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
|
||||
4
api/app/core/workflow/engine/__init__.py
Normal file
4
api/app/core/workflow/engine/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:28
|
||||
281
api/app/core/workflow/engine/event_stream_handler.py
Normal file
281
api/app/core/workflow/engine/event_stream_handler.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EventStreamHandler:
|
||||
def __init__(
|
||||
self,
|
||||
output_coordinator: StreamOutputCoordinator,
|
||||
variable_pool: VariablePool,
|
||||
execution_id: str,
|
||||
):
|
||||
self.coordinator = output_coordinator
|
||||
self.variable_pool = variable_pool
|
||||
self.execution_id = execution_id
|
||||
|
||||
def update_stream_output_status(self, activate: dict, data: dict):
|
||||
"""
|
||||
Update the stream output state of End nodes based on workflow state updates.
|
||||
|
||||
This method checks which nodes/scopes are activated and propagates
|
||||
activation to End nodes accordingly.
|
||||
|
||||
Args:
|
||||
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
||||
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
||||
|
||||
Behavior:
|
||||
For each node in `data`:
|
||||
1. If the node is activated (`activate[node_id]` is True),
|
||||
retrieve its output status from `runtime_vars`.
|
||||
2. Call `_update_scope_activate` to propagate the activation
|
||||
to all relevant End nodes and update `self.activate_end`.
|
||||
"""
|
||||
for node_id in data.keys():
|
||||
if activate.get(node_id):
|
||||
node_output_status = self.variable_pool.get_value(f"{node_id}.output", default=None, strict=False)
|
||||
self.coordinator.update_scope_activation(node_id, status=node_output_status)
|
||||
|
||||
async def handle_updates_event(
|
||||
self,
|
||||
data: dict,
|
||||
graph: CompiledStateGraph,
|
||||
checkpoint_config: RunnableConfig
|
||||
):
|
||||
"""
|
||||
Handle workflow state update events ("updates") and stream active End node outputs.
|
||||
|
||||
Steps:
|
||||
1. Retrieve the current graph state.
|
||||
2. Extract node activation information from the state.
|
||||
3. Update the activation status of all End nodes.
|
||||
4. While there is an active End node:
|
||||
- Call _emit_active_chunks() to yield all currently active output segments.
|
||||
- After all segments are processed, update activate_end if there are remaining End nodes.
|
||||
5. Log a debug message indicating state update received.
|
||||
|
||||
Args:
|
||||
data (dict): The latest node state updates.
|
||||
graph (CompiledStateGraph): The compiled LangGraph state machine.
|
||||
checkpoint_config (RunnableConfig): Configuration for the current execution context.)
|
||||
|
||||
Yields:
|
||||
dict: Streamed output event, each chunk in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
state = graph.get_state(config=checkpoint_config).values
|
||||
activate = state.get("activate", {})
|
||||
|
||||
self.update_stream_output_status(activate, data)
|
||||
wait = False
|
||||
while self.coordinator.activate_end and not wait:
|
||||
async for msg_event in self.coordinator.emit_activate_chunk(self.variable_pool):
|
||||
yield msg_event
|
||||
|
||||
if self.coordinator.activate_end:
|
||||
wait = True
|
||||
else:
|
||||
self.update_stream_output_status(activate, data)
|
||||
|
||||
logger.debug(f"[UPDATES] Received state update from nodes: {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
async def handle_node_chunk_event(self, data: dict):
|
||||
"""
|
||||
Handle streaming chunk events from individual nodes ("node_chunk").
|
||||
|
||||
This method processes output segments for the currently active End node.
|
||||
If the segment depends on the provided node_id:
|
||||
- If the node has finished execution (`done=True`), advance the cursor.
|
||||
- If all segments are processed, deactivate the End node.
|
||||
- Otherwise, yield the current chunk as a streaming message.
|
||||
|
||||
Args:
|
||||
data (dict): Node chunk event data, expected keys:
|
||||
- "node_id": ID of the node producing this chunk
|
||||
- "chunk": Chunk of output text
|
||||
- "done": Boolean indicating whether the node finished producing output
|
||||
|
||||
Yields:
|
||||
dict: Streaming message event in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
if self.coordinator.activate_end:
|
||||
end_info = self.coordinator.current_activate_end_info
|
||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||
return
|
||||
current_output = end_info.outputs[end_info.cursor]
|
||||
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||
if data.get("done"):
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.coordinator.pop_current_activate_end()
|
||||
else:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def handle_node_error_event(data: dict):
|
||||
"""
|
||||
Handle node error events ("node_error") during workflow execution.
|
||||
|
||||
This method streams an error event for a node that has failed. The event
|
||||
contains the node ID, status, input data, elapsed time, and error message.
|
||||
|
||||
Args:
|
||||
data (dict): Node error event data, expected keys:
|
||||
- "node_id": ID of the node that failed
|
||||
- "input_data": The input data that caused the error
|
||||
- "elapsed_time": Execution time before the error occurred
|
||||
- "error": Error message or exception string
|
||||
|
||||
Yields:
|
||||
dict: Node error event in the format:
|
||||
{
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"status": "failed",
|
||||
"input": ...,
|
||||
"elapsed_time": float,
|
||||
"output": None,
|
||||
"error": str
|
||||
}
|
||||
}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
yield {
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": node_id,
|
||||
"status": "failed",
|
||||
"input": data.get("input_data"),
|
||||
"elapsed_time": data.get("elapsed_time"),
|
||||
"output": None,
|
||||
"error": data.get("error")
|
||||
}
|
||||
}
|
||||
|
||||
async def handle_debug_event(self, data: dict, input_data: dict):
|
||||
"""
|
||||
Handle debug events ("debug") related to node execution status.
|
||||
|
||||
This method streams debug events for nodes, including when a node starts
|
||||
execution ("node_start") and when it completes execution ("node_end").
|
||||
It filters out nodes with names starting with "nop" as no-operation nodes.
|
||||
|
||||
Args:
|
||||
data (dict): Debug event data, expected keys:
|
||||
- "type": Event type ("task" for start, "task_result" for completion)
|
||||
- "payload": Node-related information, including:
|
||||
- "name": Node name / ID
|
||||
- "input": Node input data (for "task" type)
|
||||
- "result": Node execution result (for "task_result" type)
|
||||
- "timestamp": ISO timestamp string of the event
|
||||
input_data (dict): Original workflow input data (used to get conversation_id)
|
||||
|
||||
Yields:
|
||||
dict: Node debug event in one of the following formats:
|
||||
1. Node start:
|
||||
{
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms)
|
||||
}
|
||||
}
|
||||
2. Node end:
|
||||
{
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms),
|
||||
"input": dict,
|
||||
"output": Any,
|
||||
"elapsed_time": float
|
||||
}
|
||||
}
|
||||
"""
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
|
||||
# Skip no-operation nodes
|
||||
if node_name and node_name.startswith("nop"):
|
||||
return
|
||||
|
||||
if event_type == "task":
|
||||
# Node starts execution
|
||||
inputv = payload.get("input", {})
|
||||
if not inputv.get("activate", {}).get(node_name):
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[NODE-START] Node '{node_name}' execution started - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
}
|
||||
}
|
||||
elif event_type == "task_result":
|
||||
# Node execution completed
|
||||
result = payload.get("result", {})
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[NODE-END] Node '{node_name}' execution completed - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def handle_cycle_item_event(data: dict):
|
||||
yield {
|
||||
"event": "cycle_item",
|
||||
"data": data.get("data")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,177 +1,28 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import Any, Iterable
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import START, END
|
||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||
from langgraph.types import Send
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.stream_output_coordinator import OutputContent, StreamOutputConfig
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
||||
)
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
|
||||
"""
|
||||
Represents a single output segment of an End node.
|
||||
|
||||
An output segment can be either:
|
||||
- literal text (static string)
|
||||
- a variable placeholder (e.g. {{ node.field }})
|
||||
|
||||
Each segment has its own activation state, which is especially
|
||||
important in stream mode.
|
||||
"""
|
||||
|
||||
literal: str = Field(
|
||||
...,
|
||||
description="Raw output content. Can be literal text or a variable placeholder."
|
||||
)
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this output segment is currently active.\n"
|
||||
"- True: allowed to be emitted/output\n"
|
||||
"- False: blocked until activated by branch control"
|
||||
)
|
||||
)
|
||||
|
||||
is_variable: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this segment represents a variable placeholder.\n"
|
||||
"True -> variable (e.g. {{ node.field }})\n"
|
||||
"False -> literal text"
|
||||
)
|
||||
)
|
||||
|
||||
_SCOPE: str | None = None
|
||||
|
||||
def get_scope(self) -> str:
|
||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
||||
return self._SCOPE
|
||||
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
"""
|
||||
Check if this segment depends on a given scope.
|
||||
|
||||
Args:
|
||||
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||
|
||||
Returns:
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
if self._SCOPE:
|
||||
return self._SCOPE == scope
|
||||
return self.get_scope() == scope
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
|
||||
"""
|
||||
Streaming output configuration for an End node.
|
||||
|
||||
This configuration describes how the End node output behaves in streaming mode,
|
||||
including:
|
||||
- whether output emission is globally activated
|
||||
- which upstream branch/control nodes gate the activation
|
||||
- how each parsed output segment is streamed and activated
|
||||
"""
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Global activation flag for the End node output.\n"
|
||||
"When False, output segments should not be emitted even if available.\n"
|
||||
"This flag typically becomes True once required control branch conditions "
|
||||
"are satisfied."
|
||||
)
|
||||
)
|
||||
|
||||
control_nodes: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Control branch conditions for this End node output.\n"
|
||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||
"The End node output becomes globally active when a controlling branch node "
|
||||
"reports a matching completion status."
|
||||
)
|
||||
)
|
||||
|
||||
outputs: list[OutputContent] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Ordered list of output segments parsed from the output template.\n"
|
||||
"Each segment represents either a literal text block or a variable placeholder "
|
||||
"that may be activated independently."
|
||||
)
|
||||
)
|
||||
|
||||
cursor: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Streaming cursor index.\n"
|
||||
"Indicates the next output segment index to be emitted.\n"
|
||||
"Segments with index < cursor are considered already streamed."
|
||||
)
|
||||
)
|
||||
|
||||
def update_activate(self, scope: str, status=None):
|
||||
"""
|
||||
Update streaming activation state based on an upstream node or special variable.
|
||||
|
||||
Args:
|
||||
scope (str):
|
||||
Identifier of the completed upstream entity.
|
||||
- If a control branch node, it should match a key in `control_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||
status (optional):
|
||||
Completion status of the control branch node.
|
||||
Required when `scope` refers to a control node.
|
||||
|
||||
Behavior:
|
||||
1. Control branch nodes:
|
||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||
branch label, the End node output becomes globally active (`activate = True`).
|
||||
|
||||
2. Variable output segments:
|
||||
- For each segment that is a variable (`is_variable=True`):
|
||||
- If the segment literal references `scope`, mark the segment as active.
|
||||
- This applies both to regular node variables (e.g., "node_id.field")
|
||||
and special system variables (e.g., "sys.xxx").
|
||||
|
||||
Notes:
|
||||
- This method does not emit output or advance the streaming cursor.
|
||||
- It only updates activation flags based on upstream events or special variables.
|
||||
"""
|
||||
|
||||
# Case 1: resolve control branch dependency
|
||||
if scope in self.control_nodes.keys():
|
||||
if status is None:
|
||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||
if status in self.control_nodes[scope]:
|
||||
self.activate = True
|
||||
|
||||
# Case 2: activate variable segments related to this node
|
||||
for i in range(len(self.outputs)):
|
||||
if (
|
||||
self.outputs[i].is_variable
|
||||
and self.outputs[i].depends_on_scope(scope)
|
||||
):
|
||||
self.outputs[i].activate = True
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
def __init__(
|
||||
@@ -230,7 +81,7 @@ class GraphBuilder:
|
||||
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||
|
||||
@staticmethod
|
||||
def _merge_control_nodes(control_nodes: list[tuple[str, str]]) -> dict[str, list]:
|
||||
def _merge_control_nodes(control_nodes: Iterable[tuple[str, str]]) -> dict[str, list]:
|
||||
result = defaultdict(list)
|
||||
for node in control_nodes:
|
||||
result[node[0]].append(node[1])
|
||||
104
api/app/core/workflow/engine/result_builder.py
Normal file
104
api/app/core/workflow/engine/result_builder.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
|
||||
class WorkflowResultBuilder:
|
||||
def build_final_output(
|
||||
self,
|
||||
result: dict,
|
||||
variable_pool: VariablePool,
|
||||
elapsed_time: float,
|
||||
final_output: str,
|
||||
):
|
||||
"""Construct the final standardized output of the workflow execution.
|
||||
|
||||
This method aggregates node outputs, token usage, conversation and system
|
||||
variables, messages, and other metadata into a consistent dictionary
|
||||
structure suitable for returning from workflow execution.
|
||||
|
||||
Args:
|
||||
result (dict): The runtime state returned by the workflow graph execution.
|
||||
Expected keys include:
|
||||
- "node_outputs" (dict): Outputs of executed nodes.
|
||||
- "messages" (list): Conversation messages exchanged during execution.
|
||||
- "error" (str, optional): Error message if any node failed.
|
||||
variable_pool (VariablePool): Variable Pool
|
||||
elapsed_time (float): Total execution time in seconds.
|
||||
final_output (Any): The aggregated or final output content of the workflow
|
||||
(e.g., combined messages from all End nodes).
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the final workflow execution result with keys:
|
||||
- "status": Execution status ("completed")
|
||||
- "output": Aggregated final output content
|
||||
- "variables": Namespace dictionary with:
|
||||
- "conv": Conversation variables
|
||||
- "sys": System variables
|
||||
- "node_outputs": Outputs from all executed nodes
|
||||
- "messages": Conversation messages exchanged
|
||||
- "conversation_id": ID of the current conversation
|
||||
- "elapsed_time": Total execution time in seconds
|
||||
- "token_usage": Aggregated token usage across nodes (if available)
|
||||
- "error": Error message if any occurred during execution
|
||||
"""
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
token_usage = self.aggregate_token_usage(node_outputs)
|
||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"variables": {
|
||||
"conv": variable_pool.get_all_conversation_vars(),
|
||||
"sys": variable_pool.get_all_system_vars()
|
||||
},
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
|
||||
"""
|
||||
Aggregate token usage statistics across all nodes.
|
||||
|
||||
Args:
|
||||
node_outputs (dict): A dictionary of all node outputs.
|
||||
|
||||
Returns:
|
||||
dict | None: Aggregated token usage in the format:
|
||||
{
|
||||
"prompt_tokens": int,
|
||||
"completion_tokens": int,
|
||||
"total_tokens": int
|
||||
}
|
||||
Returns None if no token usage information is available.
|
||||
"""
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
has_token_info = False
|
||||
|
||||
for node_output in node_outputs.values():
|
||||
if isinstance(node_output, dict):
|
||||
token_usage = node_output.get("token_usage")
|
||||
if token_usage and isinstance(token_usage, dict):
|
||||
has_token_info = True
|
||||
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||
total_tokens += token_usage.get("total_tokens", 0)
|
||||
|
||||
if not has_token_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
29
api/app/core/workflow/engine/runtime_schema.py
Normal file
29
api/app/core/workflow/engine/runtime_schema.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
import uuid
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
checkpoint_config: RunnableConfig
|
||||
|
||||
@classmethod
|
||||
def create(cls, execution_id: str, workspace_id: str, user_id: str):
|
||||
return cls(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
checkpoint_config=RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
99
api/app/core/workflow/engine/state_manager.py
Normal file
99
api/app/core/workflow/engine/state_manager.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
from typing import Annotated, Any
|
||||
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
def merge_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
return y if y > x else x
|
||||
|
||||
|
||||
class WorkflowState(dict):
|
||||
"""Workflow state
|
||||
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
__required_keys__ = frozenset({
|
||||
"messages",
|
||||
"cycle_nodes",
|
||||
"looping",
|
||||
"node_outputs",
|
||||
"execution_id",
|
||||
"workspace_id",
|
||||
"user_id",
|
||||
"activate",
|
||||
})
|
||||
__optional_keys__ = frozenset({
|
||||
"error",
|
||||
"error_node",
|
||||
})
|
||||
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[dict[str, str]], lambda x, y: y]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: Annotated[int, merge_looping_state]
|
||||
|
||||
# Node outputs (stores execution results of each node for variable references)
|
||||
# Uses a custom merge function to combine new node outputs into the existing dictionary
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# Execution context
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
|
||||
# Error information (for error edges)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
|
||||
class WorkflowStateManager:
|
||||
def create_initial_state(
|
||||
self,
|
||||
workflow_config: dict,
|
||||
input_data: dict,
|
||||
execution_context: ExecutionContext,
|
||||
start_node_id: str
|
||||
) -> WorkflowState:
|
||||
conversation_messages = input_data.get("conv_messages", [])
|
||||
|
||||
return WorkflowState(
|
||||
messages=conversation_messages,
|
||||
node_outputs={},
|
||||
execution_id=execution_context.execution_id,
|
||||
workspace_id=execution_context.workspace_id,
|
||||
user_id=execution_context.user_id,
|
||||
error=None,
|
||||
error_node=None,
|
||||
cycle_nodes=self._identify_cycle_nodes(workflow_config),
|
||||
looping=0,
|
||||
activate={
|
||||
start_node_id: True
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _identify_cycle_nodes(
|
||||
workflow_config: dict
|
||||
):
|
||||
return [
|
||||
node.get("id")
|
||||
for node in workflow_config.get("nodes")
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
]
|
||||
327
api/app/core/workflow/engine/stream_output_coordinator.py
Normal file
327
api/app/core/workflow/engine/stream_output_coordinator.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 15:11
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
||||
)
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
|
||||
"""
|
||||
Represents a single output segment of an End node.
|
||||
|
||||
An output segment can be either:
|
||||
- literal text (static string)
|
||||
- a variable placeholder (e.g. {{ node.field }})
|
||||
|
||||
Each segment has its own activation state, which is especially
|
||||
important in stream mode.
|
||||
"""
|
||||
|
||||
literal: str = Field(
|
||||
...,
|
||||
description="Raw output content. Can be literal text or a variable placeholder."
|
||||
)
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this output segment is currently active.\n"
|
||||
"- True: allowed to be emitted/output\n"
|
||||
"- False: blocked until activated by branch control"
|
||||
)
|
||||
)
|
||||
|
||||
is_variable: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this segment represents a variable placeholder.\n"
|
||||
"True -> variable (e.g. {{ node.field }})\n"
|
||||
"False -> literal text"
|
||||
)
|
||||
)
|
||||
|
||||
_SCOPE: str | None = None
|
||||
|
||||
def get_scope(self) -> str:
|
||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
||||
return self._SCOPE
|
||||
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
"""
|
||||
Check if this segment depends on a given scope.
|
||||
|
||||
Args:
|
||||
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||
|
||||
Returns:
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
if self._SCOPE:
|
||||
return self._SCOPE == scope
|
||||
return self.get_scope() == scope
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
|
||||
"""
|
||||
Streaming output configuration for an End node.
|
||||
|
||||
This configuration describes how the End node output behaves in streaming mode,
|
||||
including:
|
||||
- whether output emission is globally activated
|
||||
- which upstream branch/control nodes gate the activation
|
||||
- how each parsed output segment is streamed and activated
|
||||
"""
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Global activation flag for the End node output.\n"
|
||||
"When False, output segments should not be emitted even if available.\n"
|
||||
"This flag typically becomes True once required control branch conditions "
|
||||
"are satisfied."
|
||||
)
|
||||
)
|
||||
|
||||
control_nodes: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Control branch conditions for this End node output.\n"
|
||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||
"The End node output becomes globally active when a controlling branch node "
|
||||
"reports a matching completion status."
|
||||
)
|
||||
)
|
||||
|
||||
outputs: list[OutputContent] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Ordered list of output segments parsed from the output template.\n"
|
||||
"Each segment represents either a literal text block or a variable placeholder "
|
||||
"that may be activated independently."
|
||||
)
|
||||
)
|
||||
|
||||
cursor: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Streaming cursor index.\n"
|
||||
"Indicates the next output segment index to be emitted.\n"
|
||||
"Segments with index < cursor are considered already streamed."
|
||||
)
|
||||
)
|
||||
|
||||
def update_activate(self, scope: str, status=None):
|
||||
"""
|
||||
Update streaming activation state based on an upstream node or special variable.
|
||||
|
||||
Args:
|
||||
scope (str):
|
||||
Identifier of the completed upstream entity.
|
||||
- If a control branch node, it should match a key in `control_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||
status (optional):
|
||||
Completion status of the control branch node.
|
||||
Required when `scope` refers to a control node.
|
||||
|
||||
Behavior:
|
||||
1. Control branch nodes:
|
||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||
branch label, the End node output becomes globally active (`activate = True`).
|
||||
|
||||
2. Variable output segments:
|
||||
- For each segment that is a variable (`is_variable=True`):
|
||||
- If the segment literal references `scope`, mark the segment as active.
|
||||
- This applies both to regular node variables (e.g., "node_id.field")
|
||||
and special system variables (e.g., "sys.xxx").
|
||||
|
||||
Notes:
|
||||
- This method does not emit output or advance the streaming cursor.
|
||||
- It only updates activation flags based on upstream events or special variables.
|
||||
"""
|
||||
|
||||
# Case 1: resolve control branch dependency
|
||||
if scope in self.control_nodes.keys():
|
||||
if status is None:
|
||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||
if status in self.control_nodes[scope]:
|
||||
self.activate = True
|
||||
|
||||
# Case 2: activate variable segments related to this node
|
||||
for i in range(len(self.outputs)):
|
||||
if (
|
||||
self.outputs[i].is_variable
|
||||
and self.outputs[i].depends_on_scope(scope)
|
||||
):
|
||||
self.outputs[i].activate = True
|
||||
|
||||
|
||||
class StreamOutputCoordinator:
|
||||
def __init__(self):
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
|
||||
def initialize_end_outputs(
|
||||
self,
|
||||
end_node_map: dict[str, StreamOutputConfig]
|
||||
):
|
||||
self.end_outputs = end_node_map
|
||||
|
||||
@property
|
||||
def current_activate_end_info(self):
|
||||
return self.end_outputs.get(self.activate_end)
|
||||
|
||||
def pop_current_activate_end(self):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
def update_scope_activation(
|
||||
self,
|
||||
scope: str,
|
||||
status: str | None = None
|
||||
):
|
||||
"""
|
||||
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||
|
||||
Iterates over all End nodes in `self.end_outputs` and calls
|
||||
`update_activate` on each, which may:
|
||||
- Activate variable segments that depend on the completed node/scope.
|
||||
- Activate the entire End node output if any control conditions are met.
|
||||
|
||||
If any End node becomes active and `self.activate_end` is not yet set,
|
||||
this node will be marked as the currently active End node.
|
||||
|
||||
Args:
|
||||
scope (str): The node ID or scope that has completed execution.
|
||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
|
||||
async def emit_activate_chunk(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
force: bool = False
|
||||
) -> AsyncGenerator[dict[str, str | dict], None]:
|
||||
"""
|
||||
Process and yield all currently active output segments for the currently active End node.
|
||||
|
||||
This method handles stream-mode output for an End node by iterating through its output segments
|
||||
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
||||
`force=True`, which allows all segments to be processed regardless of their activation state.
|
||||
|
||||
Behavior:
|
||||
1. Iterates from the current `cursor` position to the end of the outputs list.
|
||||
2. For each segment:
|
||||
- If the segment is literal text (`is_variable=False`), append it directly.
|
||||
- If the segment is a variable (`is_variable=True`), evaluate it using
|
||||
`evaluate_expression` with the given `node_outputs` and `variables`,
|
||||
then transform the result with `_trans_output_string`.
|
||||
3. Yield a stream event of type "message" containing the processed chunk.
|
||||
4. Move the `cursor` forward after processing each segment.
|
||||
5. When all segments have been processed, remove this End node from `end_outputs`
|
||||
and reset `activate_end` to None.
|
||||
|
||||
Args:
|
||||
variable_pool (VariablePool): Pool of variables for evaluating segment values.
|
||||
force (bool, default=False): If True, process segments even if `activate=False`.
|
||||
|
||||
Yields:
|
||||
dict: A stream event of type "message" containing the processed chunk.
|
||||
|
||||
Notes:
|
||||
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
||||
- This method only processes the currently active End node (`self.activate_end`).
|
||||
- Use `force=True` for final emission regardless of activation state.
|
||||
"""
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
|
||||
while end_info.cursor < len(end_info.outputs):
|
||||
final_chunk = ''
|
||||
current_segment = end_info.outputs[end_info.cursor]
|
||||
|
||||
if not current_segment.activate and not force:
|
||||
# Stop processing until this segment becomes active
|
||||
break
|
||||
|
||||
# Literal segment
|
||||
if not current_segment.is_variable:
|
||||
final_chunk += current_segment.literal
|
||||
else:
|
||||
# Variable segment: evaluate and transform
|
||||
try:
|
||||
chunk = variable_pool.get_literal(current_segment.literal)
|
||||
final_chunk += chunk
|
||||
except Exception as e:
|
||||
# Log failed evaluation but continue streaming
|
||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
||||
|
||||
if final_chunk:
|
||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
# Advance cursor after processing
|
||||
end_info.cursor += 1
|
||||
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
async def flush_remaining_chunk(
|
||||
self,
|
||||
variable_pool: VariablePool
|
||||
) -> AsyncGenerator[dict[str, str | dict], None]:
|
||||
"""
|
||||
Flush and yield all remaining output segments from active End nodes.
|
||||
|
||||
This method ensures that any remaining chunks of output, which may not have
|
||||
been emitted during normal streaming due to activation conditions, are fully
|
||||
processed. It is typically called at the end of a workflow to guarantee
|
||||
that all output is delivered.
|
||||
|
||||
Behavior:
|
||||
1. Filter `end_outputs` to only keep End nodes that are still active.
|
||||
2. While there is an active End node (`self.activate_end`):
|
||||
- Call `_emit_active_chunks(force=True)` to emit all segments regardless
|
||||
of their activation state.
|
||||
- If the current End node finishes, move to the next active End node
|
||||
if any remain.
|
||||
|
||||
Yields:
|
||||
dict: Streamed output events in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
# Keep only active End nodes
|
||||
self.end_outputs = {
|
||||
node_id: node_info
|
||||
for node_id, node_info in self.end_outputs.items()
|
||||
if node_info.activate
|
||||
}
|
||||
|
||||
if self.end_outputs or self.activate_end:
|
||||
while self.activate_end:
|
||||
# Force emit all remaining chunks of the active End node
|
||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||
yield msg_event
|
||||
|
||||
# Move to next active End node if current one is done
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
@@ -1,14 +1,7 @@
|
||||
"""
|
||||
变量池 (Variable Pool)
|
||||
|
||||
工作流执行的数据中心,管理所有变量的存储和访问。
|
||||
|
||||
变量类型:
|
||||
1. 系统变量 (sys.*) - 系统内置变量(execution_id, workspace_id, user_id, message 等)
|
||||
2. 节点输出 (node_id.*) - 节点执行结果
|
||||
3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
|
||||
"""
|
||||
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2025/12/15 19:50
|
||||
import logging
|
||||
import re
|
||||
from asyncio import Lock
|
||||
@@ -18,7 +11,8 @@ from typing import Any, Generic
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -359,3 +353,77 @@ class VariablePool:
|
||||
f" runtime_vars={len(runtime_vars)}\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class VariablePoolInitializer:
|
||||
def __init__(self, workflow_config: dict):
|
||||
self.workflow_config = workflow_config
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
input_data: dict,
|
||||
execution_context: ExecutionContext
|
||||
) -> None:
|
||||
await self._init_conversation_vars(variable_pool, input_data)
|
||||
await self._init_system_vars(variable_pool, input_data, execution_context)
|
||||
|
||||
async def _init_conversation_vars(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
input_data: dict
|
||||
):
|
||||
init_conv_vars: list[dict] = self.workflow_config.get("variables") or []
|
||||
runtime_conv_vars: dict[str, Any] = input_data.get("conv", {})
|
||||
|
||||
for var_def in init_conv_vars:
|
||||
var_name = var_def.get("name")
|
||||
var_default = runtime_conv_vars.get(var_name, var_def.get("default"))
|
||||
var_type = var_def.get("type")
|
||||
if var_name:
|
||||
if var_default:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
await variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
value=var_value,
|
||||
var_type=var_type,
|
||||
mut=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _init_system_vars(
|
||||
variable_pool: VariablePool,
|
||||
input_data: dict,
|
||||
context: ExecutionContext
|
||||
):
|
||||
user_message = input_data.get("message") or ""
|
||||
user_files = input_data.get("files") or []
|
||||
conversations = input_data.get("conv_messages", [])
|
||||
conversation_index = len(conversations) // 2
|
||||
|
||||
input_variables = input_data.get("variables") or {}
|
||||
sys_vars = {
|
||||
"message": (user_message, VariableType.STRING),
|
||||
"conversation_index": (conversation_index, VariableType.NUMBER),
|
||||
"conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
|
||||
"execution_id": (context.execution_id, VariableType.STRING),
|
||||
"workspace_id": (context.workspace_id, VariableType.STRING),
|
||||
"user_id": (context.user_id, VariableType.STRING),
|
||||
"input_variables": (input_variables, VariableType.OBJECT),
|
||||
"files": (user_files, VariableType.ARRAY_FILE)
|
||||
}
|
||||
for key, var_def in sys_vars.items():
|
||||
value = var_def[0]
|
||||
var_type = var_def[1]
|
||||
await variable_pool.new(
|
||||
namespace='sys',
|
||||
key=key,
|
||||
value=value,
|
||||
var_type=var_type,
|
||||
mut=False
|
||||
)
|
||||
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
"""
|
||||
工作流执行器
|
||||
|
||||
基于 LangGraph 的工作流执行引擎。
|
||||
"""
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 13:51
|
||||
import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.engine.event_stream_handler import EventStreamHandler
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
from app.core.workflow.engine.result_builder import WorkflowResultBuilder
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.engine.state_manager import WorkflowStateManager
|
||||
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
|
||||
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,9 +29,7 @@ class WorkflowExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
):
|
||||
"""Initialize Workflow Executor.
|
||||
|
||||
@@ -41,13 +38,10 @@ class WorkflowExecutor:
|
||||
|
||||
Args:
|
||||
workflow_config (dict): The workflow configuration dictionary.
|
||||
execution_id (str): Unique identifier for this workflow execution.
|
||||
workspace_id (str): Workspace or project ID.
|
||||
user_id (str): User ID executing the workflow.
|
||||
execution_context (ExecutionContext): The workflow execution context
|
||||
include execution_id, workspace_id, user_id, checkpoint_config
|
||||
|
||||
Attributes:
|
||||
self.nodes (list): List of node definitions from workflow_config.
|
||||
self.edges (list): List of edge definitions from workflow_config.
|
||||
self.execution_config (dict): Optional execution parameters from workflow_config.
|
||||
self.start_node_id (str | None): ID of the Start node, set after graph build.
|
||||
self.end_outputs (dict[str, StreamOutputConfig]): End node output configs.
|
||||
@@ -57,555 +51,18 @@ class WorkflowExecutor:
|
||||
self.checkpoint_config (RunnableConfig): Config for LangGraph checkpointing.
|
||||
"""
|
||||
self.workflow_config = workflow_config
|
||||
self.execution_id = execution_id
|
||||
self.workspace_id = workspace_id
|
||||
self.user_id = user_id
|
||||
self.nodes = workflow_config.get("nodes", [])
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_context = execution_context
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
self.start_node_id = None
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
self.start_node_id: str | None = None
|
||||
self.variable_pool: VariablePool | None = None
|
||||
|
||||
self.graph: CompiledStateGraph | None = None
|
||||
self.checkpoint_config = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
}
|
||||
)
|
||||
|
||||
async def __init_variable_pool(self, input_data: dict[str, Any]):
|
||||
"""Initialize the variable pool with system, conversation, and input variables.
|
||||
|
||||
This method populates the VariablePool instance with:
|
||||
- Conversation-level variables (`conv` namespace) from workflow config or provided values.
|
||||
- System variables (`sys` namespace) such as message, files, conversation_id, execution_id, workspace_id, user_id, and input_variables.
|
||||
|
||||
Args:
|
||||
input_data (dict): Input data for workflow execution, may contain:
|
||||
- "message": user message (str)
|
||||
- "file": list of user-uploaded files
|
||||
- "conv": existing conversation variables (dict)
|
||||
- "variables": custom variables for the Start node (dict)
|
||||
- "conversation_id": conversation identifier
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
user_files = input_data.get("files") or []
|
||||
|
||||
config_variables_list = self.workflow_config.get("variables") or []
|
||||
conv_vars = input_data.get("conv", {})
|
||||
|
||||
# Initialize conversation variables (conv namespace)
|
||||
for var_def in config_variables_list:
|
||||
var_name = var_def.get("name")
|
||||
var_default = conv_vars.get(var_name, var_def.get("default"))
|
||||
var_type = var_def.get("type")
|
||||
if var_name:
|
||||
if var_default:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
await self.variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
value=var_value,
|
||||
var_type=var_type,
|
||||
mut=True
|
||||
)
|
||||
|
||||
# Initialize system variables (sys namespace)
|
||||
input_variables = input_data.get("variables") or {}
|
||||
sys_vars = {
|
||||
"message": (user_message, VariableType.STRING),
|
||||
"conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
|
||||
"execution_id": (self.execution_id, VariableType.STRING),
|
||||
"workspace_id": (self.workspace_id, VariableType.STRING),
|
||||
"user_id": (self.user_id, VariableType.STRING),
|
||||
"input_variables": (input_variables, VariableType.OBJECT),
|
||||
"files": (user_files, VariableType.ARRAY_FILE)
|
||||
}
|
||||
for key, var_def in sys_vars.items():
|
||||
value = var_def[0]
|
||||
var_type = var_def[1]
|
||||
await self.variable_pool.new(
|
||||
namespace='sys',
|
||||
key=key,
|
||||
value=value,
|
||||
var_type=var_type,
|
||||
mut=False
|
||||
)
|
||||
|
||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||
"""Generate the initial workflow state for execution.
|
||||
|
||||
This method prepares the runtime state dictionary with system variables,
|
||||
conversation variables, node outputs, loop tracking, and activation flags.
|
||||
|
||||
Args:
|
||||
input_data (dict): The input payload for workflow execution.
|
||||
Expected keys:
|
||||
- "conv_messages" (list, optional): Historical conversation messages
|
||||
to include in the workflow state.
|
||||
|
||||
Returns:
|
||||
WorkflowState: A dictionary representing the initialized workflow state
|
||||
with the following keys:
|
||||
- "messages": List of conversation messages
|
||||
- "node_outputs": Empty dict to store outputs of executed nodes
|
||||
- "execution_id": Current workflow execution ID
|
||||
- "workspace_id": Current workspace ID
|
||||
- "user_id": ID of the user triggering execution
|
||||
- "error": None initially, will store error message if a node fails
|
||||
- "error_node": None initially, will store ID of node that caused error
|
||||
- "cycle_nodes": List of node IDs that are of type LOOP or ITERATION
|
||||
- "looping": Integer flag indicating loop execution state (0 = not looping)
|
||||
- "activate": Dict mapping node IDs to activation status; initially
|
||||
only the start node is active
|
||||
"""
|
||||
conversation_messages = input_data.get("conv_messages") or []
|
||||
|
||||
return {
|
||||
"messages": conversation_messages,
|
||||
"node_outputs": {},
|
||||
"execution_id": self.execution_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None,
|
||||
"cycle_nodes": [
|
||||
node.get("id")
|
||||
for node in self.workflow_config.get("nodes")
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
], # loop, iteration node id
|
||||
"looping": 0, # loop runing flag, only use in loop node,not use in main loop
|
||||
"activate": {
|
||||
self.start_node_id: True
|
||||
}
|
||||
}
|
||||
|
||||
def _build_final_output(self, result, elapsed_time, final_output):
|
||||
"""Construct the final standardized output of the workflow execution.
|
||||
|
||||
This method aggregates node outputs, token usage, conversation and system
|
||||
variables, messages, and other metadata into a consistent dictionary
|
||||
structure suitable for returning from workflow execution.
|
||||
|
||||
Args:
|
||||
result (dict): The runtime state returned by the workflow graph execution.
|
||||
Expected keys include:
|
||||
- "node_outputs" (dict): Outputs of executed nodes.
|
||||
- "messages" (list): Conversation messages exchanged during execution.
|
||||
- "error" (str, optional): Error message if any node failed.
|
||||
elapsed_time (float): Total execution time in seconds.
|
||||
final_output (Any): The aggregated or final output content of the workflow
|
||||
(e.g., combined messages from all End nodes).
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the final workflow execution result with keys:
|
||||
- "status": Execution status ("completed")
|
||||
- "output": Aggregated final output content
|
||||
- "variables": Namespace dictionary with:
|
||||
- "conv": Conversation variables
|
||||
- "sys": System variables
|
||||
- "node_outputs": Outputs from all executed nodes
|
||||
- "messages": Conversation messages exchanged
|
||||
- "conversation_id": ID of the current conversation
|
||||
- "elapsed_time": Total execution time in seconds
|
||||
- "token_usage": Aggregated token usage across nodes (if available)
|
||||
- "error": Error message if any occurred during execution
|
||||
"""
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
conversation_id = self.variable_pool.get_value("sys.conversation_id")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"variables": {
|
||||
"conv": self.variable_pool.get_all_conversation_vars(),
|
||||
"sys": self.variable_pool.get_all_system_vars()
|
||||
},
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
def _update_scope_activate(self, scope, status=None):
|
||||
"""
|
||||
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||
|
||||
Iterates over all End nodes in `self.end_outputs` and calls
|
||||
`update_activate` on each, which may:
|
||||
- Activate variable segments that depend on the completed node/scope.
|
||||
- Activate the entire End node output if any control conditions are met.
|
||||
|
||||
If any End node becomes active and `self.activate_end` is not yet set,
|
||||
this node will be marked as the currently active End node.
|
||||
|
||||
Args:
|
||||
scope (str): The node ID or scope that has completed execution.
|
||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
|
||||
def _update_stream_output_status(self, activate, data):
|
||||
"""
|
||||
Update the stream output state of End nodes based on workflow state updates.
|
||||
|
||||
This method checks which nodes/scopes are activated and propagates
|
||||
activation to End nodes accordingly.
|
||||
|
||||
Args:
|
||||
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
||||
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
||||
|
||||
Behavior:
|
||||
For each node in `data`:
|
||||
1. If the node is activated (`activate[node_id]` is True),
|
||||
retrieve its output status from `runtime_vars`.
|
||||
2. Call `_update_scope_activate` to propagate the activation
|
||||
to all relevant End nodes and update `self.activate_end`.
|
||||
"""
|
||||
for node_id in data.keys():
|
||||
if activate.get(node_id):
|
||||
node_output_status = self.variable_pool.get_value(f"{node_id}.output", default=None, strict=False)
|
||||
self._update_scope_activate(node_id, status=node_output_status)
|
||||
|
||||
async def _emit_active_chunks(
|
||||
self,
|
||||
force=False
|
||||
):
|
||||
"""
|
||||
Process and yield all currently active output segments for the currently active End node.
|
||||
|
||||
This method handles stream-mode output for an End node by iterating through its output segments
|
||||
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
||||
`force=True`, which allows all segments to be processed regardless of their activation state.
|
||||
|
||||
Behavior:
|
||||
1. Iterates from the current `cursor` position to the end of the outputs list.
|
||||
2. For each segment:
|
||||
- If the segment is literal text (`is_variable=False`), append it directly.
|
||||
- If the segment is a variable (`is_variable=True`), evaluate it using
|
||||
`evaluate_expression` with the given `node_outputs` and `variables`,
|
||||
then transform the result with `_trans_output_string`.
|
||||
3. Yield a stream event of type "message" containing the processed chunk.
|
||||
4. Move the `cursor` forward after processing each segment.
|
||||
5. When all segments have been processed, remove this End node from `end_outputs`
|
||||
and reset `activate_end` to None.
|
||||
|
||||
Args:
|
||||
force (bool, default=False): If True, process segments even if `activate=False`.
|
||||
|
||||
Yields:
|
||||
dict: A stream event of type "message" containing the processed chunk.
|
||||
|
||||
Notes:
|
||||
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
||||
- This method only processes the currently active End node (`self.activate_end`).
|
||||
- Use `force=True` for final emission regardless of activation state.
|
||||
"""
|
||||
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
|
||||
while end_info.cursor < len(end_info.outputs):
|
||||
final_chunk = ''
|
||||
current_segment = end_info.outputs[end_info.cursor]
|
||||
|
||||
if not current_segment.activate and not force:
|
||||
# Stop processing until this segment becomes active
|
||||
break
|
||||
|
||||
# Literal segment
|
||||
if not current_segment.is_variable:
|
||||
final_chunk += current_segment.literal
|
||||
else:
|
||||
# Variable segment: evaluate and transform
|
||||
try:
|
||||
chunk = self.variable_pool.get_literal(current_segment.literal)
|
||||
final_chunk += chunk
|
||||
except KeyError:
|
||||
# Log failed evaluation but continue streaming
|
||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
|
||||
|
||||
if final_chunk:
|
||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
# Advance cursor after processing
|
||||
end_info.cursor += 1
|
||||
|
||||
# Remove End node from active tracking if all segments have been processed
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
async def _handle_updates_event(self, data):
|
||||
"""
|
||||
Handle workflow state update events ("updates") and stream active End node outputs.
|
||||
|
||||
Steps:
|
||||
1. Retrieve the current graph state.
|
||||
2. Extract node activation information from the state.
|
||||
3. Update the activation status of all End nodes.
|
||||
4. While there is an active End node:
|
||||
- Call _emit_active_chunks() to yield all currently active output segments.
|
||||
- After all segments are processed, update activate_end if there are remaining End nodes.
|
||||
5. Log a debug message indicating state update received.
|
||||
|
||||
Args:
|
||||
data (dict): The latest node state updates.
|
||||
|
||||
Yields:
|
||||
dict: Streamed output event, each chunk in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
# Get the latest workflow state
|
||||
state = self.graph.get_state(config=self.checkpoint_config).values
|
||||
activate = state.get("activate", {})
|
||||
|
||||
# Update End node activation based on the new state
|
||||
self._update_stream_output_status(activate, data)
|
||||
wait = False
|
||||
while self.activate_end and not wait:
|
||||
async for msg_event in self._emit_active_chunks():
|
||||
yield msg_event
|
||||
|
||||
if self.activate_end:
|
||||
wait = True
|
||||
else:
|
||||
self._update_stream_output_status(activate, data)
|
||||
|
||||
logger.debug(f"[UPDATES] Received state update from nodes: {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
async def _handle_node_chunk_event(self, data):
|
||||
"""
|
||||
Handle streaming chunk events from individual nodes ("node_chunk").
|
||||
|
||||
This method processes output segments for the currently active End node.
|
||||
If the segment depends on the provided node_id:
|
||||
- If the node has finished execution (`done=True`), advance the cursor.
|
||||
- If all segments are processed, deactivate the End node.
|
||||
- Otherwise, yield the current chunk as a streaming message.
|
||||
|
||||
Args:
|
||||
data (dict): Node chunk event data, expected keys:
|
||||
- "node_id": ID of the node producing this chunk
|
||||
- "chunk": Chunk of output text
|
||||
- "done": Boolean indicating whether the node finished producing output
|
||||
|
||||
Yields:
|
||||
dict: Streaming message event in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
if self.activate_end:
|
||||
end_info = self.end_outputs.get(self.activate_end)
|
||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||
return
|
||||
current_output = end_info.outputs[end_info.cursor]
|
||||
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||
if data.get("done"):
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
else:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
}
|
||||
}
|
||||
|
||||
async def _handle_node_error_event(self, data):
|
||||
"""
|
||||
Handle node error events ("node_error") during workflow execution.
|
||||
|
||||
This method streams an error event for a node that has failed. The event
|
||||
contains the node ID, status, input data, elapsed time, and error message.
|
||||
|
||||
Args:
|
||||
data (dict): Node error event data, expected keys:
|
||||
- "node_id": ID of the node that failed
|
||||
- "input_data": The input data that caused the error
|
||||
- "elapsed_time": Execution time before the error occurred
|
||||
- "error": Error message or exception string
|
||||
|
||||
Yields:
|
||||
dict: Node error event in the format:
|
||||
{
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"status": "failed",
|
||||
"input": ...,
|
||||
"elapsed_time": float,
|
||||
"output": None,
|
||||
"error": str
|
||||
}
|
||||
}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
yield {
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": node_id,
|
||||
"status": "failed",
|
||||
"input": data.get("input_data"),
|
||||
"elapsed_time": data.get("elapsed_time"),
|
||||
"output": None,
|
||||
"error": data.get("error")
|
||||
}
|
||||
}
|
||||
|
||||
async def _handle_debug_event(self, data, input_data):
|
||||
"""
|
||||
Handle debug events ("debug") related to node execution status.
|
||||
|
||||
This method streams debug events for nodes, including when a node starts
|
||||
execution ("node_start") and when it completes execution ("node_end").
|
||||
It filters out nodes with names starting with "nop" as no-operation nodes.
|
||||
|
||||
Args:
|
||||
data (dict): Debug event data, expected keys:
|
||||
- "type": Event type ("task" for start, "task_result" for completion)
|
||||
- "payload": Node-related information, including:
|
||||
- "name": Node name / ID
|
||||
- "input": Node input data (for "task" type)
|
||||
- "result": Node execution result (for "task_result" type)
|
||||
- "timestamp": ISO timestamp string of the event
|
||||
input_data (dict): Original workflow input data (used to get conversation_id)
|
||||
|
||||
Yields:
|
||||
dict: Node debug event in one of the following formats:
|
||||
1. Node start:
|
||||
{
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms)
|
||||
}
|
||||
}
|
||||
2. Node end:
|
||||
{
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms),
|
||||
"input": dict,
|
||||
"output": Any,
|
||||
"elapsed_time": float
|
||||
}
|
||||
}
|
||||
"""
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
# Skip no-operation nodes
|
||||
if node_name and node_name.startswith("nop"):
|
||||
return
|
||||
|
||||
if event_type == "task":
|
||||
# Node starts execution
|
||||
inputv = payload.get("input", {})
|
||||
if not inputv.get("activate", {}).get(node_name):
|
||||
return
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
logger.info(f"[NODE-START] Node '{node_name}' execution started - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
}
|
||||
}
|
||||
elif event_type == "task_result":
|
||||
# Node execution completed
|
||||
result = payload.get("result", {})
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
return
|
||||
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
logger.info(f"[NODE-END] Node '{node_name}' execution completed - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
}
|
||||
|
||||
async def _flush_remaining_chunk(self):
|
||||
"""
|
||||
Flush and yield all remaining output segments from active End nodes.
|
||||
|
||||
This method ensures that any remaining chunks of output, which may not have
|
||||
been emitted during normal streaming due to activation conditions, are fully
|
||||
processed. It is typically called at the end of a workflow to guarantee
|
||||
that all output is delivered.
|
||||
|
||||
Behavior:
|
||||
1. Filter `end_outputs` to only keep End nodes that are still active.
|
||||
2. While there is an active End node (`self.activate_end`):
|
||||
- Call `_emit_active_chunks(force=True)` to emit all segments regardless
|
||||
of their activation state.
|
||||
- If the current End node finishes, move to the next active End node
|
||||
if any remain.
|
||||
|
||||
Yields:
|
||||
dict: Streamed output events in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
# Keep only active End nodes
|
||||
self.end_outputs = {
|
||||
node_id: node_info
|
||||
for node_id, node_info in self.end_outputs.items()
|
||||
if node_info.activate
|
||||
}
|
||||
|
||||
if self.end_outputs or self.activate_end:
|
||||
while self.activate_end:
|
||||
# Force emit all remaining chunks of the active End node
|
||||
async for msg_event in self._emit_active_chunks(force=True):
|
||||
yield msg_event
|
||||
|
||||
# Move to next active End node if current one is done
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
self.variable_initializer = VariablePoolInitializer(workflow_config)
|
||||
self.state_manager = WorkflowStateManager()
|
||||
self.result_builder = WorkflowResultBuilder()
|
||||
self.stream_coordinator = StreamOutputCoordinator()
|
||||
self.event_handler: EventStreamHandler | None = None
|
||||
|
||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||
"""
|
||||
@@ -624,16 +81,22 @@ class WorkflowExecutor:
|
||||
Returns:
|
||||
CompiledStateGraph: The compiled and ready-to-run state graph.
|
||||
"""
|
||||
logger.info(f"Starting workflow graph build: execution_id={self.execution_id}")
|
||||
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
|
||||
builder = GraphBuilder(
|
||||
self.workflow_config,
|
||||
stream=stream,
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.end_outputs = builder.end_node_map
|
||||
self.variable_pool = builder.variable_pool
|
||||
self.graph = builder.build()
|
||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_id}")
|
||||
|
||||
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
||||
self.event_handler = EventStreamHandler(
|
||||
output_coordinator=self.stream_coordinator,
|
||||
variable_pool=self.variable_pool,
|
||||
execution_id=self.execution_context.execution_id
|
||||
)
|
||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}")
|
||||
|
||||
return self.graph
|
||||
|
||||
@@ -665,7 +128,7 @@ class WorkflowExecutor:
|
||||
- token_usage: aggregated token usage if available
|
||||
- error: error message if any
|
||||
"""
|
||||
logger.info(f"Starting workflow execution: execution_id={self.execution_id}")
|
||||
logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
@@ -673,16 +136,25 @@ class WorkflowExecutor:
|
||||
graph = self.build_graph()
|
||||
|
||||
# Initialize the variable pool with input data
|
||||
await self.__init_variable_pool(input_data)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
# Execute the workflow
|
||||
try:
|
||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
|
||||
# Aggregate output from all End nodes
|
||||
full_content = ''
|
||||
for end_id in self.end_outputs.keys():
|
||||
for end_id in self.stream_coordinator.end_outputs.keys():
|
||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
|
||||
# Append messages for user and assistant
|
||||
@@ -703,15 +175,16 @@ class WorkflowExecutor:
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Workflow execution completed: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return self._build_final_output(result, elapsed_time, full_content)
|
||||
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
|
||||
except Exception as e:
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
@@ -744,15 +217,15 @@ class WorkflowExecutor:
|
||||
"data": {...}
|
||||
}
|
||||
"""
|
||||
logger.info(f"Starting workflow execution (streaming): execution_id={self.execution_id}")
|
||||
logger.info(f"Starting workflow execution (streaming): execution_id={self.execution_context.execution_id}")
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
yield {
|
||||
"event": "workflow_start",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"execution_id": self.execution_context.execution_id,
|
||||
"workspace_id": self.execution_context.workspace_id,
|
||||
"conversation_id": input_data.get("conversation_id"),
|
||||
"timestamp": int(start_time.timestamp() * 1000)
|
||||
}
|
||||
@@ -762,18 +235,27 @@ class WorkflowExecutor:
|
||||
graph = self.build_graph(stream=True)
|
||||
|
||||
# Initialize the variable pool and system variables
|
||||
await self.__init_variable_pool(input_data)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
try:
|
||||
full_content = ''
|
||||
self._update_scope_activate("sys")
|
||||
self.stream_coordinator.update_scope_activation("sys")
|
||||
|
||||
# Execute the workflow with streaming
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||
config=self.checkpoint_config
|
||||
config=self.execution_context.checkpoint_config
|
||||
):
|
||||
# event should be a tuple: (mode, data)
|
||||
# But let's handle both cases
|
||||
@@ -782,38 +264,46 @@ class WorkflowExecutor:
|
||||
else:
|
||||
# Unexpected format, log and skip
|
||||
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}"
|
||||
f"- execution_id: {self.execution_id}")
|
||||
f"- execution_id: {self.execution_context.execution_id}")
|
||||
continue
|
||||
|
||||
if mode == "custom":
|
||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||
if event_type == "node_chunk":
|
||||
async for msg_event in self._handle_node_chunk_event(data):
|
||||
async for msg_event in self.event_handler.handle_node_chunk_event(data):
|
||||
full_content += msg_event["data"]["chunk"]
|
||||
yield msg_event
|
||||
|
||||
elif event_type == "node_error":
|
||||
async for error_event in self._handle_node_error_event(data):
|
||||
async for error_event in self.event_handler.handle_node_error_event(data):
|
||||
yield error_event
|
||||
|
||||
elif event_type == "cycle_item":
|
||||
async for cycle_event in self.event_handler.handle_cycle_item_event(data):
|
||||
yield cycle_event
|
||||
|
||||
elif mode == "debug":
|
||||
async for debug_event in self._handle_debug_event(data, input_data):
|
||||
async for debug_event in self.event_handler.handle_debug_event(data, input_data):
|
||||
yield debug_event
|
||||
|
||||
elif mode == "updates":
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
async for msg_event in self._handle_updates_event(data):
|
||||
f"- execution_id: {self.execution_context.execution_id}")
|
||||
async for msg_event in self.event_handler.handle_updates_event(
|
||||
data,
|
||||
self.graph,
|
||||
self.execution_context.checkpoint_config
|
||||
):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
yield msg_event
|
||||
|
||||
# Flush any remaining chunks
|
||||
async for msg_event in self._flush_remaining_chunk():
|
||||
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
yield msg_event
|
||||
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
result = graph.get_state(self.execution_context.checkpoint_config).values
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
@@ -832,24 +322,25 @@ class WorkflowExecutor:
|
||||
)
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
|
||||
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}"
|
||||
)
|
||||
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": self._build_final_output(result, elapsed_time, full_content)
|
||||
"data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
exc_info=True)
|
||||
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"execution_id": self.execution_context.execution_id,
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -857,46 +348,6 @@ class WorkflowExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
"""
|
||||
Aggregate token usage statistics across all nodes.
|
||||
|
||||
Args:
|
||||
node_outputs (dict): A dictionary of all node outputs.
|
||||
|
||||
Returns:
|
||||
dict | None: Aggregated token usage in the format:
|
||||
{
|
||||
"prompt_tokens": int,
|
||||
"completion_tokens": int,
|
||||
"total_tokens": int
|
||||
}
|
||||
Returns None if no token usage information is available.
|
||||
"""
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
has_token_info = False
|
||||
|
||||
for node_output in node_outputs.values():
|
||||
if isinstance(node_output, dict):
|
||||
token_usage = node_output.get("token_usage")
|
||||
if token_usage and isinstance(token_usage, dict):
|
||||
has_token_info = True
|
||||
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||
total_tokens += token_usage.get("total_tokens", 0)
|
||||
|
||||
if not has_token_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
|
||||
|
||||
async def execute_workflow(
|
||||
workflow_config: dict[str, Any],
|
||||
@@ -918,12 +369,15 @@ async def execute_workflow(
|
||||
Returns:
|
||||
dict: Workflow execution result.
|
||||
"""
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context=execution_context
|
||||
)
|
||||
return await executor.execute(input_data)
|
||||
|
||||
|
||||
@@ -947,11 +401,14 @@ async def execute_workflow_stream(
|
||||
Yields:
|
||||
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
||||
"""
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context=execution_context
|
||||
)
|
||||
async for event in executor.execute_stream(input_data):
|
||||
yield event
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.code import CodeNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.http_request import HttpRequestNode
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
@@ -14,16 +15,14 @@ from app.core.workflow.nodes.jinja_render import JinjaRenderNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.code import CodeNode
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
"WorkflowState",
|
||||
"LLMNode",
|
||||
"AgentNode",
|
||||
"IfElseNode",
|
||||
|
||||
@@ -7,14 +7,16 @@ Agent 节点实现
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.models import AppRelease
|
||||
from app.db import get_db
|
||||
from app.models import AppRelease
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,12 +2,13 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -5,57 +5,17 @@ from functools import cached_property
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langgraph.config import get_stream_writer
|
||||
from typing_extensions import TypedDict, Annotated
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.multimodal_service import PROVIDER_STRATEGIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def merge_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
return y if y > x else x
|
||||
|
||||
|
||||
class WorkflowState(TypedDict):
|
||||
"""Workflow state
|
||||
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[dict[str, str]], lambda x, y: y]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: Annotated[int, merge_looping_state]
|
||||
|
||||
# Node outputs (stores execution results of each node for variable references)
|
||||
# Uses a custom merge function to combine new node outputs into the existing dictionary
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# Execution context
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
|
||||
# Error information (for error edges)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
"""Base class for workflow nodes.
|
||||
|
||||
@@ -584,7 +544,7 @@ class BaseNode(ABC):
|
||||
Returns:
|
||||
The rendered string with all variables substituted.
|
||||
"""
|
||||
from app.core.workflow.template_renderer import render_template
|
||||
from app.core.workflow.utils.template_renderer import render_template
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
@@ -611,7 +571,7 @@ class BaseNode(ABC):
|
||||
Returns:
|
||||
The boolean result of evaluating the expression.
|
||||
"""
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,13 +6,14 @@ import urllib.parse
|
||||
from string import Template
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import BaseNode
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +29,7 @@ class IterationRuntime:
|
||||
def __init__(
|
||||
self,
|
||||
start_id: str,
|
||||
stream: bool,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
@@ -42,6 +47,7 @@ class IterationRuntime:
|
||||
state: Current workflow state at the point of iteration.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.stream = stream
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
@@ -49,6 +55,12 @@ class IterationRuntime:
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
self.event_write = get_stream_writer()
|
||||
self.checkpoint = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4()
|
||||
}
|
||||
)
|
||||
|
||||
self.output_value = None
|
||||
self.result: list = []
|
||||
@@ -91,7 +103,46 @@ class IterationRuntime:
|
||||
item: The input element for this iteration.
|
||||
idx: The index of this iteration.
|
||||
"""
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
if self.stream:
|
||||
async for event in self.graph.astream(
|
||||
await self._init_iteration_state(item, idx),
|
||||
stream_mode=["debug"],
|
||||
config=self.checkpoint
|
||||
):
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
else:
|
||||
continue
|
||||
if mode == "debug":
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if node_name and node_name.startswith("nop"):
|
||||
continue
|
||||
if event_type == "task_result":
|
||||
result = payload.get("result", {})
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
continue
|
||||
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
||||
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
|
||||
self.event_write({
|
||||
"type": "cycle_item",
|
||||
"data": {
|
||||
"cycle_id": self.node_id,
|
||||
"cycle_idx": idx,
|
||||
"node_id": node_name,
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
})
|
||||
result = self.graph.get_state(config=self.checkpoint).values
|
||||
else:
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
@@ -152,16 +203,9 @@ class IterationRuntime:
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
child_state.append(result)
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
result = await self.run_task(item, idx)
|
||||
self.merge_conv_vars()
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
child_state.append(result)
|
||||
idx += 1
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return {
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.config import get_stream_writer
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_expression
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
|
||||
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_expression
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,6 +30,7 @@ class LoopRuntime:
|
||||
def __init__(
|
||||
self,
|
||||
start_id: str,
|
||||
stream: bool,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
@@ -46,6 +50,7 @@ class LoopRuntime:
|
||||
child_variable_pool: A VariablePool instance for managing child node outputs.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.stream = stream
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
@@ -53,6 +58,13 @@ class LoopRuntime:
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
self.event_write = get_stream_writer()
|
||||
|
||||
self.checkpoint = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4()
|
||||
}
|
||||
)
|
||||
|
||||
async def _init_loop_state(self):
|
||||
"""
|
||||
@@ -142,10 +154,12 @@ class LoopRuntime:
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {operator}")
|
||||
|
||||
def merge_conv_vars(self):
|
||||
def merge_conv_vars(self, loopstate):
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables.get("conv", {})
|
||||
)
|
||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
||||
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||
|
||||
def evaluate_conditional(self) -> bool:
|
||||
"""
|
||||
@@ -175,6 +189,50 @@ class LoopRuntime:
|
||||
else:
|
||||
return any(conditions)
|
||||
|
||||
async def _run(self, loopstate, idx):
|
||||
if self.stream:
|
||||
async for event in self.graph.astream(
|
||||
loopstate,
|
||||
stream_mode=["debug"],
|
||||
config=self.checkpoint
|
||||
):
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
else:
|
||||
continue
|
||||
if mode == "debug":
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if node_name and node_name.startswith("nop"):
|
||||
continue
|
||||
if event_type == "task_result":
|
||||
result = payload.get("result", {})
|
||||
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
continue
|
||||
cycle_variable = None
|
||||
if node_type == NodeType.CYCLE_START:
|
||||
cycle_variable = loopstate.get("node_outputs", {}).get(self.node_id, {})
|
||||
self.event_write({
|
||||
"type": "cycle_item",
|
||||
"data": {
|
||||
"cycle_id": self.node_id,
|
||||
"cycle_idx": idx,
|
||||
"node_id": node_name,
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
})
|
||||
return self.graph.get_state(config=self.checkpoint).values
|
||||
else:
|
||||
return await self.graph.ainvoke(loopstate)
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
Execute the loop node until termination conditions are met.
|
||||
@@ -190,15 +248,17 @@ class LoopRuntime:
|
||||
loopstate = await self._init_loop_state()
|
||||
loop_time = self.typed_config.max_loop
|
||||
child_state = []
|
||||
idx = 0
|
||||
while not self.evaluate_conditional() and self.looping and loop_time > 0:
|
||||
logger.info(f"loop node {self.node_id}: running")
|
||||
result = await self.graph.ainvoke(loopstate)
|
||||
result = await self._run(loopstate, idx)
|
||||
child_state.append(result)
|
||||
|
||||
self.merge_conv_vars()
|
||||
self.merge_conv_vars(loopstate)
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
loop_time -= 1
|
||||
idx += 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
||||
|
||||
@@ -4,14 +4,14 @@ from typing import Any
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
|
||||
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -136,7 +136,7 @@ class CycleGraphNode(BaseNode):
|
||||
2. Construct a StateGraph using GraphBuilder in subgraph mode
|
||||
3. Compile the graph for runtime execution
|
||||
"""
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
self.child_variable_pool = VariablePool()
|
||||
builder = GraphBuilder(
|
||||
@@ -172,6 +172,7 @@ class CycleGraphNode(BaseNode):
|
||||
if self.node_type == NodeType.LOOP:
|
||||
return await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=False,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
@@ -182,6 +183,7 @@ class CycleGraphNode(BaseNode):
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
return await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=False,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
@@ -190,3 +192,36 @@ class CycleGraphNode(BaseNode):
|
||||
child_variable_pool=self.child_variable_pool
|
||||
).run()
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
if self.node_type == NodeType.LOOP:
|
||||
yield {
|
||||
"__final__": True,
|
||||
"result": await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=True,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool,
|
||||
).run()
|
||||
}
|
||||
return
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
yield {
|
||||
"__final__": True,
|
||||
"result": await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=True,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool
|
||||
).run()
|
||||
}
|
||||
return
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
@@ -6,9 +6,10 @@ End 节点实现
|
||||
|
||||
import logging
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -7,11 +7,12 @@ import httpx
|
||||
# import filetypes # TODO: File support (Feature)
|
||||
from httpx import AsyncClient, Response, Timeout
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class IfElseNodeConfig(BaseNodeConfig):
|
||||
|
||||
@field_validator("cases")
|
||||
@classmethod
|
||||
def validate_case_number(cls, v, info):
|
||||
def validate_case_number(cls, v):
|
||||
if len(v) < 1:
|
||||
raise ValueError("At least one cases are required")
|
||||
return v
|
||||
|
||||
@@ -2,12 +2,13 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
||||
from app.core.workflow.template_renderer import TemplateRenderer
|
||||
from app.core.workflow.utils.template_renderer import TemplateRenderer
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,10 +6,11 @@ from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""LLM 节点配置"""
|
||||
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@@ -56,7 +57,7 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
2. 消息模式:使用 messages 字段(推荐)
|
||||
"""
|
||||
|
||||
model_id: str = Field(
|
||||
model_id: uuid.UUID = Field(
|
||||
...,
|
||||
description="模型配置 ID"
|
||||
)
|
||||
@@ -148,7 +149,7 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
|
||||
@field_validator("messages", "prompt")
|
||||
@classmethod
|
||||
def validate_input_mode(cls, v, info):
|
||||
def validate_input_mode(cls, v):
|
||||
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||
# 这个验证在 model_validator 中更合适
|
||||
return v
|
||||
|
||||
@@ -13,10 +13,11 @@ from langchain_core.messages import AIMessage
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -268,7 +269,7 @@ class LLMNode(BaseNode):
|
||||
llm = await self._prepare_llm(state, variable_pool, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
|
||||
@@ -3,9 +3,9 @@ import re
|
||||
from abc import ABC
|
||||
from typing import Union, Type, NoReturn, Any
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.enums import ValueInputType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class TypeTransformer:
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
import json_repair
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
from jinja2 import Template
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -7,10 +7,11 @@ Start 节点实现
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,16 +4,17 @@ import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.tool_service import ToolService
|
||||
from app.db import get_db_read
|
||||
from app.services.tool_service import ToolService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
|
||||
@@ -2,11 +2,11 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
4
api/app/core/workflow/utils/__init__.py
Normal file
4
api/app/core/workflow/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:24
|
||||
@@ -5,7 +5,6 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||
@@ -187,7 +187,7 @@ class WorkflowValidator:
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
from app.core.workflow.utils.expression_evaluator import ExpressionEvaluator
|
||||
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
||||
errors.extend(var_errors)
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ from .generic_file_model import GenericFile
|
||||
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey, ModelBase, LoadBalanceStrategy
|
||||
from .memory_short_model import ShortTermMemory, LongTermMemory
|
||||
from .knowledgeshare_model import KnowledgeShare
|
||||
from .mcp_market_model import McpMarket
|
||||
from .mcp_market_config_model import McpMarketConfig
|
||||
from .app_model import App
|
||||
from .agent_app_config_model import AgentConfig
|
||||
from .app_release_model import AppRelease
|
||||
@@ -50,6 +52,8 @@ __all__ = [
|
||||
"ModelType",
|
||||
"ModelApiKey",
|
||||
"KnowledgeShare",
|
||||
"McpMarket",
|
||||
"McpMarketConfig",
|
||||
"App",
|
||||
"AgentConfig",
|
||||
"AppRelease",
|
||||
|
||||
@@ -35,7 +35,7 @@ class FileMetadata(Base):
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True, comment="Tenant ID")
|
||||
workspace_id = Column(UUID(as_uuid=True), nullable=False, index=True, comment="Workspace ID")
|
||||
workspace_id = Column(UUID(as_uuid=True), nullable=True, index=True, comment="Workspace ID")
|
||||
file_key = Column(String(512), nullable=False, unique=True, index=True, comment="Storage file key")
|
||||
file_name = Column(String(255), nullable=False, comment="Original file name")
|
||||
file_ext = Column(String(32), nullable=False, comment="File extension")
|
||||
|
||||
16
api/app/models/mcp_market_config_model.py
Normal file
16
api/app/models/mcp_market_config_model.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.db import Base
|
||||
|
||||
class McpMarketConfig(Base):
|
||||
__tablename__ = "mcp_market_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
mcp_market_id = Column(UUID(as_uuid=True), nullable=False, comment="mcp_markets.id")
|
||||
token = Column(String, nullable=True, comment="mcp market token")
|
||||
status = Column(Integer, default=0, comment="connect status(0: Not connected, 1: connected)")
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, comment="tenant.id")
|
||||
created_by = Column(UUID(as_uuid=True), nullable=False, comment="users.id")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
18
api/app/models/mcp_market_model.py
Normal file
18
api/app/models/mcp_market_model.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.db import Base
|
||||
|
||||
class McpMarket(Base):
|
||||
__tablename__ = "mcp_markets"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
name = Column(String, index=True, nullable=False, comment="mcp market name")
|
||||
description = Column(String, index=True, nullable=True, comment="mcp market description")
|
||||
logo_url = Column(String, index=True, nullable=True, comment="logo url")
|
||||
mcp_count = Column(Integer, default=1, comment="mcp count")
|
||||
url = Column(String, index=True, nullable=False, comment="mcp market url")
|
||||
category = Column(String, index=True, nullable=False, comment="category")
|
||||
created_by = Column(UUID(as_uuid=True), nullable=False, comment="users.id")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
72
api/app/repositories/mcp_market_config_repository.py
Normal file
72
api/app/repositories/mcp_market_config_repository.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.mcp_market_config_model import McpMarketConfig
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Obtain a dedicated logger for the database
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
def create_mcp_market_config(db: Session, mcp_market_config: mcp_market_config_schema.McpMarketConfigCreate) -> McpMarketConfig:
|
||||
db_logger.debug(f"Create a mcp market config record: mcp_market_id={mcp_market_config.mcp_market_id}")
|
||||
|
||||
try:
|
||||
db_mcp_market_config = McpMarketConfig(**mcp_market_config.model_dump())
|
||||
db.add(db_mcp_market_config)
|
||||
db.commit()
|
||||
db_logger.info(f"McpMarketConfig record created successfully: {mcp_market_config.mcp_market_id} (ID: {db_mcp_market_config.id})")
|
||||
return db_mcp_market_config
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to create a mcp market config record: mcp_market_id={mcp_market_config.mcp_market_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID) -> McpMarketConfig | None:
|
||||
db_logger.debug(f"Query mcp market config based on ID: mcp_market_config_id={mcp_market_config_id}")
|
||||
|
||||
try:
|
||||
db_mcp_market_config = db.query(McpMarketConfig).filter(McpMarketConfig.id == mcp_market_config_id).first()
|
||||
if db_mcp_market_config:
|
||||
db_logger.debug(f"McpMarketConfig query successful: (ID: {mcp_market_config_id})")
|
||||
else:
|
||||
db_logger.debug(f"McpMarketConfig does not exist: mcp_market_config_id={mcp_market_config_id}")
|
||||
return db_mcp_market_config
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query the mcp market config based on the ID: {mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_config_by_mcp_market_id(db: Session, mcp_market_id: uuid.UUID, tenant_id: uuid.UUID) -> McpMarketConfig | None:
|
||||
db_logger.debug(f"Query mcp market config based on mcp_market_id: {mcp_market_id}")
|
||||
|
||||
try:
|
||||
db_mcp_market_config = db.query(McpMarketConfig).filter(McpMarketConfig.mcp_market_id == mcp_market_id, McpMarketConfig.tenant_id == tenant_id).first()
|
||||
if db_mcp_market_config:
|
||||
db_logger.debug(f"McpMarketConfig query successful: (mcp_market_id: {mcp_market_id})")
|
||||
else:
|
||||
db_logger.debug(f"McpMarketConfig does not exist: mcp_market_id={mcp_market_id}")
|
||||
return db_mcp_market_config
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query the mcp market config based on the mcp_market_id: {mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID):
|
||||
db_logger.debug(f"Delete McpMarketConfig record: mcp_market_config_id={mcp_market_config_id}")
|
||||
|
||||
try:
|
||||
# First, query the mcp market config information for logging purposes
|
||||
result = db.query(McpMarketConfig).filter(McpMarketConfig.id == mcp_market_config_id).delete()
|
||||
db.commit()
|
||||
|
||||
if result > 0:
|
||||
db_logger.info(f"McpMarketConfig record deleted successfully: (ID: {mcp_market_config_id})")
|
||||
else:
|
||||
db_logger.warning(f"The mcp market config record does not exist, and cannot be deleted: id={mcp_market_config_id}")
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to delete mcp market config record: id={mcp_market_config_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
124
api/app/repositories/mcp_market_repository.py
Normal file
124
api/app/repositories/mcp_market_repository.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.mcp_market_model import McpMarket
|
||||
from app.schemas import mcp_market_schema
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Obtain a dedicated logger for the database
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
def get_mcp_markets_paginated(
|
||||
db: Session,
|
||||
filters: list,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
orderby: str = None,
|
||||
desc: bool = False
|
||||
) -> tuple[int, list]:
|
||||
"""
|
||||
Paged query mcp market (with filtering and sorting)
|
||||
"""
|
||||
db_logger.debug(
|
||||
f"Query mcp market in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")
|
||||
|
||||
try:
|
||||
query = db.query(McpMarket)
|
||||
|
||||
# Apply filter conditions
|
||||
for filter_cond in filters:
|
||||
query = query.filter(filter_cond)
|
||||
|
||||
# Calculate the total count (for pagination)
|
||||
total = query.count()
|
||||
db_logger.debug(f"Total number of mcp_market queries: {total}")
|
||||
|
||||
# sort
|
||||
if orderby:
|
||||
order_attr = getattr(McpMarket, orderby, None)
|
||||
if order_attr is not None:
|
||||
if desc:
|
||||
query = query.order_by(order_attr.desc())
|
||||
else:
|
||||
query = query.order_by(order_attr.asc())
|
||||
db_logger.debug(f"sort: {orderby}, desc={desc}")
|
||||
|
||||
# pagination
|
||||
items = query.offset((page - 1) * pagesize).limit(pagesize).all()
|
||||
db_logger.info(
|
||||
f"The mcp market paging query has been successful: total={total}, Number of current page={len(items)}")
|
||||
|
||||
return total, [mcp_market_schema.McpMarket.model_validate(item) for item in items]
|
||||
except Exception as e:
|
||||
db_logger.error(f"Querying mcp_market pagination failed: page={page}, pagesize={pagesize} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_mcp_market(db: Session, mcp_market: mcp_market_schema.McpMarketCreate) -> McpMarket:
|
||||
db_logger.debug(f"Create a mcp market record: name={mcp_market.name}")
|
||||
|
||||
try:
|
||||
db_mcp_market = McpMarket(**mcp_market.model_dump())
|
||||
db.add(db_mcp_market)
|
||||
db.commit()
|
||||
db_logger.info(f"McpMarket record created successfully: {mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return db_mcp_market
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to create a mcp market record: title={mcp_market.name} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID) -> McpMarket | None:
|
||||
db_logger.debug(f"Query mcp market based on ID: mcp_market_id={mcp_market_id}")
|
||||
|
||||
try:
|
||||
db_mcp_market = db.query(McpMarket).filter(McpMarket.id == mcp_market_id).first()
|
||||
if db_mcp_market:
|
||||
db_logger.debug(f"McpMarket query successful: {db_mcp_market.name} (ID: {mcp_market_id})")
|
||||
else:
|
||||
db_logger.debug(f"McpMarket does not exist: mcp_market_id={mcp_market_id}")
|
||||
return db_mcp_market
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query the mcp market based on the ID: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_by_name(db: Session, name: str) -> McpMarket | None:
|
||||
db_logger.debug(f"Query mcp market based on name: name={name}")
|
||||
|
||||
try:
|
||||
db_mcp_market = db.query(McpMarket).filter(McpMarket.name == name).first()
|
||||
if db_mcp_market:
|
||||
db_logger.debug(f"mcp market query successful: {name} (ID: {db_mcp_market.id})")
|
||||
else:
|
||||
db_logger.debug(f"mcp market does not exist: name={name}")
|
||||
return db_mcp_market
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query the mcp market based on the name: {name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID):
|
||||
db_logger.debug(f"Delete McpMarket record: mcp_market_id={mcp_market_id}")
|
||||
|
||||
try:
|
||||
# First, query the mcp market information for logging purposes
|
||||
db_mcp_market = db.query(McpMarket).filter(McpMarket.id == mcp_market_id).first()
|
||||
if db_mcp_market:
|
||||
name = db_mcp_market.name
|
||||
else:
|
||||
name = "unknown"
|
||||
|
||||
result = db.query(McpMarket).filter(McpMarket.id == mcp_market_id).delete()
|
||||
db.commit()
|
||||
|
||||
if result > 0:
|
||||
db_logger.info(f"McpMarket record deleted successfully: {name} (ID: {mcp_market_id})")
|
||||
else:
|
||||
db_logger.warning(f"The mcp market record does not exist, and cannot be deleted: mcp_market_id={mcp_market_id}")
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to delete mcp market record: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
@@ -48,13 +48,17 @@ class ModelConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
|
||||
"""根据名称获取模型配置"""
|
||||
db_logger.debug(f"根据名称查询模型配置: name={name}, tenant_id={tenant_id}")
|
||||
def get_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
|
||||
"""根据名称和供应商获取模型配置"""
|
||||
db_logger.debug(f"根据名称查询模型配置: name={name}, provider={provider}, tenant_id={tenant_id}")
|
||||
|
||||
try:
|
||||
query = db.query(ModelConfig).filter(ModelConfig.name == name)
|
||||
|
||||
# 添加供应商过滤
|
||||
if provider:
|
||||
query = query.filter(ModelConfig.provider == provider)
|
||||
|
||||
# 添加租户过滤
|
||||
if tenant_id:
|
||||
query = query.filter(
|
||||
@@ -69,7 +73,7 @@ class ModelConfigRepository:
|
||||
db_logger.debug(f"模型配置查询成功: {model.name}")
|
||||
return model
|
||||
except Exception as e:
|
||||
db_logger.error(f"根据名称查询模型配置失败: name={name} - {str(e)}")
|
||||
db_logger.error(f"根据名称查询模型配置失败: name={name}, provider={provider} - {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -115,6 +115,7 @@ class WorkspaceRepository:
|
||||
self.db.query(Workspace)
|
||||
.join(WorkspaceMember, Workspace.id == WorkspaceMember.workspace_id)
|
||||
.filter(WorkspaceMember.user_id == user_id)
|
||||
.filter(WorkspaceMember.is_active.is_(True))
|
||||
.filter(Workspace.is_active.is_(True))
|
||||
.order_by(Workspace.updated_at.desc())
|
||||
.all()
|
||||
|
||||
@@ -8,6 +8,8 @@ from .file_schema import File, FileCreate, FileUpdate
|
||||
from .tenant_schema import Tenant, TenantCreate, TenantUpdate
|
||||
from .chunk_schema import ChunkCreate, ChunkUpdate, ChunkRetrieve
|
||||
from .knowledgeshare_schema import KnowledgeShare, KnowledgeShareCreate
|
||||
from .mcp_market_schema import McpMarketCreate, McpMarketUpdate, McpMarket
|
||||
from .mcp_market_config_schema import McpMarketConfigCreate, McpMarketConfigUpdate, McpMarketConfig
|
||||
from .order_schema import CreateOrderRequest, OrderResponse, ExternalOrderResponse
|
||||
from .app_schema import (
|
||||
AppChatRequest,
|
||||
@@ -78,6 +80,12 @@ __all__ = [
|
||||
"ChunkRetrieve",
|
||||
"KnowledgeShare",
|
||||
"KnowledgeShareCreate",
|
||||
"McpMarketCreate",
|
||||
"McpMarketUpdate",
|
||||
"McpMarket",
|
||||
"McpMarketConfigCreate",
|
||||
"McpMarketConfigUpdate",
|
||||
"McpMarketConfig",
|
||||
"CreateOrderRequest",
|
||||
"OrderResponse",
|
||||
"ExternalOrderResponse",
|
||||
|
||||
@@ -439,7 +439,7 @@ class DraftRunRequest(BaseModel):
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||
files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
class DraftRunResponse(BaseModel):
|
||||
|
||||
31
api/app/schemas/mcp_market_config_schema.py
Normal file
31
api/app/schemas/mcp_market_config_schema.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class McpMarketConfigBase(BaseModel):
|
||||
mcp_market_id: uuid.UUID
|
||||
token: str | None = None
|
||||
status: int | None = None
|
||||
tenant_id: uuid.UUID | None = None
|
||||
created_by: uuid.UUID | None = None
|
||||
|
||||
|
||||
class McpMarketConfigCreate(McpMarketConfigBase):
|
||||
pass
|
||||
|
||||
|
||||
class McpMarketConfigUpdate(BaseModel):
|
||||
token: str | None = None
|
||||
status: int | None = None
|
||||
|
||||
|
||||
class McpMarketConfig(McpMarketConfigBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
37
api/app/schemas/mcp_market_schema.py
Normal file
37
api/app/schemas/mcp_market_schema.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class McpMarketBase(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
logo_url: str | None = None
|
||||
mcp_count: int
|
||||
url: str
|
||||
category: str
|
||||
created_by: uuid.UUID | None = None
|
||||
|
||||
|
||||
class McpMarketCreate(McpMarketBase):
|
||||
pass
|
||||
|
||||
|
||||
class McpMarketUpdate(BaseModel):
|
||||
name: str | None = Field(None)
|
||||
description: str | None = Field(None)
|
||||
logo_url: str | None = Field(None)
|
||||
mcp_count: int | None = Field(None)
|
||||
url: str | None = Field(None)
|
||||
category: str | None = Field(None)
|
||||
|
||||
|
||||
class McpMarket(McpMarketBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -25,9 +25,9 @@ class ModelConfigBase(BaseModel):
|
||||
|
||||
class ApiKeyCreateNested(BaseModel):
|
||||
"""用于在创建模型时内嵌创建API Key的Schema"""
|
||||
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
||||
model_name: Optional[str] = Field(None, description="模型实际名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="备注")
|
||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||
provider: Optional[str] = Field(None, description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||
@@ -57,6 +57,8 @@ class ModelConfigUpdate(BaseModel):
|
||||
"""更新模型配置Schema"""
|
||||
name: Optional[str] = Field(None, description="模型显示名称", max_length=255)
|
||||
type: Optional[ModelType] = Field(None, description="模型类型")
|
||||
provider: Optional[str] = Field(None, description="供应商")
|
||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数")
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
|
||||
@@ -27,4 +27,5 @@ class TokenRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
invite: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
@@ -36,6 +36,28 @@ class AdminChangePasswordRequest(BaseModel):
|
||||
new_password: Optional[str] = Field(None, min_length=6, description="新密码,至少6位。如果不提供则自动生成随机密码")
|
||||
|
||||
|
||||
class ChangeEmailRequest(BaseModel):
|
||||
"""修改邮箱请求"""
|
||||
password: str = Field(..., description="当前密码")
|
||||
new_email: EmailStr = Field(..., description="新邮箱地址")
|
||||
|
||||
|
||||
class SendEmailCodeRequest(BaseModel):
|
||||
"""发送邮箱验证码请求"""
|
||||
email: EmailStr = Field(..., description="邮箱地址")
|
||||
|
||||
|
||||
class VerifyEmailCodeRequest(BaseModel):
|
||||
"""验证邮箱验证码并修改邮箱请求"""
|
||||
new_email: EmailStr = Field(..., description="新邮箱地址")
|
||||
code: str = Field(..., min_length=6, max_length=6, description="验证码")
|
||||
|
||||
|
||||
class VerifyPasswordRequest(BaseModel):
|
||||
"""验证密码请求"""
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
|
||||
class ChangePasswordResponse(BaseModel):
|
||||
"""修改密码响应"""
|
||||
message: str
|
||||
|
||||
@@ -129,7 +129,8 @@ def register_user_with_invite(
|
||||
email: str,
|
||||
password: str,
|
||||
invite_token: str,
|
||||
workspace_id: str
|
||||
workspace_id: str,
|
||||
username: Optional[str] = None,
|
||||
) -> User:
|
||||
"""
|
||||
使用邀请码注册新用户并加入工作空间
|
||||
@@ -139,6 +140,7 @@ def register_user_with_invite(
|
||||
:param password: 用户密码
|
||||
:param invite_token: 邀请令牌
|
||||
:param workspace_id: 工作空间ID
|
||||
:param username: 用户名
|
||||
:return: 创建的用户对象
|
||||
"""
|
||||
from app.schemas.user_schema import UserCreate
|
||||
@@ -154,7 +156,7 @@ def register_user_with_invite(
|
||||
user_create = UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
username=email.split('@')[0]
|
||||
username=email.split('@')[0] if not username else username
|
||||
)
|
||||
user = user_service.create_user(db=db, user=user_create)
|
||||
logger.info(f"用户创建成功: {user.email} (ID: {user.id})")
|
||||
|
||||
88
api/app/services/email_service.py
Normal file
88
api/app/services/email_service.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import smtplib
|
||||
import re
|
||||
import asyncio
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.header import Header
|
||||
from email.utils import formataddr
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _send_email_sync(to_email: str, subject: str, html_content: str, text_content: str = None):
|
||||
"""同步发送邮件"""
|
||||
smtp_server = settings.SMTP_SERVER
|
||||
smtp_port = settings.SMTP_PORT
|
||||
smtp_user = settings.SMTP_USER
|
||||
smtp_password = settings.SMTP_PASSWORD
|
||||
|
||||
if not smtp_server or not smtp_user or not smtp_password:
|
||||
raise BusinessException("邮件服务未配置", code=BizCode.SERVICE_UNAVAILABLE)
|
||||
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['Subject'] = Header(subject, "utf-8")
|
||||
from_name = "MemoryBear系统"
|
||||
msg['From'] = formataddr((Header(from_name, 'utf-8').encode(), smtp_user))
|
||||
msg['To'] = Header(to_email, "utf-8")
|
||||
|
||||
if not text_content:
|
||||
text_content = html_content.replace('<br>', '\n').replace('<p>', '\n').replace('</p>', '\n')
|
||||
text_content = re.sub(r'<.*?>', '', text_content)
|
||||
text_part = MIMEText(text_content, 'plain', 'utf-8')
|
||||
msg.attach(text_part)
|
||||
|
||||
html_part = MIMEText(html_content, 'html', 'utf-8')
|
||||
msg.attach(html_part)
|
||||
|
||||
if smtp_port == 465:
|
||||
with smtplib.SMTP_SSL(smtp_server, smtp_port, timeout=10) as server:
|
||||
server.login(smtp_user, smtp_password)
|
||||
server.send_message(msg)
|
||||
else:
|
||||
with smtplib.SMTP(smtp_server, smtp_port, timeout=10) as server:
|
||||
server.starttls()
|
||||
server.login(smtp_user, smtp_password)
|
||||
server.send_message(msg)
|
||||
|
||||
|
||||
async def send_email(to_email: str, subject: str, html_content: str, text_content: str = None):
|
||||
"""异步发送邮件"""
|
||||
to_email = to_email.strip()
|
||||
if not to_email or not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', to_email):
|
||||
err_msg = f"收件人邮箱格式无效: {to_email}"
|
||||
business_logger.error(err_msg)
|
||||
raise BusinessException(err_msg, code=BizCode.INVALID_PARAMETER)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
with ThreadPoolExecutor() as executor:
|
||||
await loop.run_in_executor(
|
||||
executor,
|
||||
_send_email_sync,
|
||||
to_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content
|
||||
)
|
||||
business_logger.info(f"邮件发送成功: {to_email}")
|
||||
except smtplib.SMTPAuthenticationError:
|
||||
err_msg = "SMTP认证失败,请检查SMTP账号/密码是否正确"
|
||||
business_logger.error(f"邮件发送失败: {to_email} - {err_msg}")
|
||||
raise BusinessException(err_msg, code=BizCode.UNAUTHORIZED)
|
||||
except smtplib.SMTPConnectError:
|
||||
err_msg = "SMTP服务器连接失败,请检查服务器地址/端口是否正确"
|
||||
business_logger.error(f"邮件发送失败: {to_email} - {err_msg}")
|
||||
raise BusinessException(err_msg, code=BizCode.SERVICE_UNAVAILABLE)
|
||||
except TimeoutError:
|
||||
err_msg = "邮件发送超时,请检查SMTP服务器配置"
|
||||
business_logger.error(f"邮件发送失败: {to_email} - {err_msg}")
|
||||
raise BusinessException(err_msg, code=BizCode.BAD_REQUEST)
|
||||
except Exception as e:
|
||||
business_logger.error(f"邮件发送失败: {to_email} - {str(e)}")
|
||||
raise BusinessException(f"邮件发送失败: {str(e)}", code=BizCode.SERVICE_UNAVAILABLE)
|
||||
@@ -26,7 +26,7 @@ logger = get_business_logger()
|
||||
|
||||
def generate_file_key(
|
||||
tenant_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID | None,
|
||||
file_id: uuid.UUID,
|
||||
file_ext: str,
|
||||
) -> str:
|
||||
@@ -56,8 +56,9 @@ def generate_file_key(
|
||||
# Ensure file_ext starts with a dot
|
||||
if file_ext and not file_ext.startswith('.'):
|
||||
file_ext = f'.{file_ext}'
|
||||
|
||||
return f"{tenant_id}/{workspace_id}/{file_id}{file_ext}"
|
||||
if workspace_id:
|
||||
return f"{tenant_id}/{workspace_id}/{file_id}{file_ext}"
|
||||
return f"{tenant_id}/{file_id}{file_ext}"
|
||||
|
||||
|
||||
class FileStorageService:
|
||||
@@ -96,7 +97,7 @@ class FileStorageService:
|
||||
async def upload_file(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID | None,
|
||||
file_id: uuid.UUID,
|
||||
file_ext: str,
|
||||
content: bytes,
|
||||
|
||||
83
api/app/services/mcp_market_config_service.py
Normal file
83
api/app/services/mcp_market_config_service.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user_model import User
|
||||
from app.models.mcp_market_config_model import McpMarketConfig
|
||||
from app.schemas.mcp_market_config_schema import McpMarketConfigCreate, McpMarketConfigUpdate
|
||||
from app.repositories import mcp_market_config_repository
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
# Obtain a dedicated logger for business logic
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def create_mcp_market_config(
|
||||
db: Session, mcp_market_config: McpMarketConfigCreate, current_user: User
|
||||
) -> McpMarketConfig:
|
||||
business_logger.info(f"Create a mcp market config base: {mcp_market_config.mcp_market_id}, creator: {current_user.username}")
|
||||
|
||||
try:
|
||||
mcp_market_config.tenant_id = current_user.tenant_id
|
||||
mcp_market_config.created_by = current_user.id
|
||||
business_logger.debug(f"Start creating the mcp market config on mcp_market_id: {mcp_market_config.mcp_market_id}")
|
||||
db_mcp_market_config = mcp_market_config_repository.create_mcp_market_config(
|
||||
db=db, mcp_market_config=mcp_market_config
|
||||
)
|
||||
business_logger.info(
|
||||
f"The mcp market config has been successfully created: {mcp_market_config.mcp_market_id} (ID: {db_mcp_market_config.id}), creator: {current_user.username}")
|
||||
return db_mcp_market_config
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to create a mcp marke config: {mcp_market_config.mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID, current_user: User) -> McpMarketConfig | None:
|
||||
business_logger.debug(
|
||||
f"Query mcp market config based on ID: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
mcpMarketConfig = mcp_market_config_repository.get_mcp_market_config_by_id(db=db, mcp_market_config_id=mcp_market_config_id)
|
||||
if mcpMarketConfig:
|
||||
business_logger.info(f"mcp market config query successful: (ID: {mcp_market_config_id})")
|
||||
else:
|
||||
business_logger.warning(f"mcp market config does not exist: mcp_market_config_id={mcp_market_config_id}")
|
||||
return mcpMarketConfig
|
||||
except Exception as e:
|
||||
business_logger.error(
|
||||
f"Failed to query the mcp market config based on the ID: {mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_config_by_mcp_market_id(db: Session, mcp_market_id: uuid.UUID, current_user: User) -> McpMarketConfig | None:
|
||||
business_logger.debug(
|
||||
f"Query mcp market config based on mcp_market_id: {mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
mcpMarketConfig = mcp_market_config_repository.get_mcp_market_config_by_mcp_market_id(db=db, mcp_market_id=mcp_market_id, tenant_id=current_user.tenant_id)
|
||||
if mcpMarketConfig:
|
||||
business_logger.info(f"mcp market config query successful: (mcp_market_id: {mcp_market_id})")
|
||||
else:
|
||||
business_logger.warning(f"mcp market config does not exist: mcp_market_id={mcp_market_id}")
|
||||
return mcpMarketConfig
|
||||
except Exception as e:
|
||||
business_logger.error(
|
||||
f"Failed to query the mcp market config based on the mcp_market_id: {mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID, current_user: User) -> None:
|
||||
business_logger.info(f"Delete mcp market config: mcp_market_config_id={mcp_market_config_id}, operator: {current_user.username}")
|
||||
|
||||
try:
|
||||
# First, query the mcp market config information for logging purposes
|
||||
mcpMarketConfig = mcp_market_config_repository.get_mcp_market_config_by_id(db=db, mcp_market_config_id=mcp_market_config_id)
|
||||
if mcpMarketConfig:
|
||||
business_logger.debug(f"Execute mcp market config deletion: (ID: {mcp_market_config_id})")
|
||||
else:
|
||||
business_logger.warning(f"The mcp market config to be deleted does not exist: mcp_market_config_id={mcp_market_config_id}")
|
||||
|
||||
mcp_market_config_repository.delete_mcp_market_config_by_id(db=db, mcp_market_config_id=mcp_market_config_id)
|
||||
business_logger.info(
|
||||
f"mcp market config record deleted successfully: mcp_market_config_id={mcp_market_config_id}, operator: {current_user.username}")
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to delete mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
109
api/app/services/mcp_market_service.py
Normal file
109
api/app/services/mcp_market_service.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user_model import User
|
||||
from app.models.mcp_market_model import McpMarket
|
||||
from app.schemas.mcp_market_schema import McpMarketCreate, McpMarketUpdate
|
||||
from app.repositories import mcp_market_repository
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
# Obtain a dedicated logger for business logic
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def get_mcp_markets_paginated(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
filters: list,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
orderby: str = None,
|
||||
desc: bool = False
|
||||
) -> tuple[int, list]:
|
||||
business_logger.debug(
|
||||
f"Query mcp market in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}")
|
||||
|
||||
try:
|
||||
total, items = mcp_market_repository.get_mcp_markets_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc
|
||||
)
|
||||
business_logger.info(
|
||||
f"The mcp market paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}")
|
||||
return total, items
|
||||
except Exception as e:
|
||||
business_logger.error(f"Querying mcp market pagination failed: username={current_user.username} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_mcp_market(
|
||||
db: Session, mcp_market: McpMarketCreate, current_user: User
|
||||
) -> McpMarket:
|
||||
business_logger.info(f"Create a mcp market base: {mcp_market.name}, creator: {current_user.username}")
|
||||
|
||||
try:
|
||||
mcp_market.created_by = current_user.id
|
||||
business_logger.debug(f"Start creating the mcp market: {mcp_market.name}")
|
||||
db_mcp_market = mcp_market_repository.create_mcp_market(
|
||||
db=db, mcp_market=mcp_market
|
||||
)
|
||||
business_logger.info(
|
||||
f"The mcp market has been successfully created: {mcp_market.name} (ID: {db_mcp_market.id}), creator: {current_user.username}")
|
||||
return db_mcp_market
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to create a mcp market: {mcp_market.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID, current_user: User) -> McpMarket | None:
|
||||
business_logger.debug(
|
||||
f"Query mcp market based on ID: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
mcpMarket = mcp_market_repository.get_mcp_market_by_id(db=db, mcp_market_id=mcp_market_id)
|
||||
if mcpMarket:
|
||||
business_logger.info(f"mcp market query successful: {mcpMarket.name} (ID: {mcp_market_id})")
|
||||
else:
|
||||
business_logger.warning(f"mcp market does not exist: mcp_market_id={mcp_market_id}")
|
||||
return mcpMarket
|
||||
except Exception as e:
|
||||
business_logger.error(
|
||||
f"Failed to query the mcp market based on the ID: {mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_by_name(db: Session, name: str, current_user: User) -> McpMarket | None:
|
||||
business_logger.debug(f"Query mcp market based on name: name={name}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
db_mcp_market = mcp_market_repository.get_mcp_market_by_name(db=db, name=name)
|
||||
if db_mcp_market:
|
||||
business_logger.info(f"mcp market query successful: {name} (ID: {db_mcp_market.id})")
|
||||
else:
|
||||
business_logger.warning(f"mcp market does not exist: name={name}")
|
||||
return db_mcp_market
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to query the mcp market based on the name: name={name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID, current_user: User) -> None:
|
||||
business_logger.info(f"Delete mcp market: mcp_market_id={mcp_market_id}, operator: {current_user.username}")
|
||||
|
||||
try:
|
||||
# First, query the mcp market information for logging purposes
|
||||
mcpMarket = mcp_market_repository.get_mcp_market_by_id(db=db, mcp_market_id=mcp_market_id)
|
||||
if mcpMarket:
|
||||
business_logger.debug(f"Execute mcp market deletion: {mcpMarket.name} (ID: {mcp_market_id})")
|
||||
else:
|
||||
business_logger.warning(f"The mcp market to be deleted does not exist: mcp_market_id={mcp_market_id}")
|
||||
|
||||
mcp_market_repository.delete_mcp_market_by_id(db=db, mcp_market_id=mcp_market_id)
|
||||
business_logger.info(
|
||||
f"mcp market record deleted successfully: mcp_market_id={mcp_market_id}, operator: {current_user.username}")
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to delete mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
@@ -6,7 +6,7 @@ import math
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy, ModelProvider
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
|
||||
from app.schemas import model_schema
|
||||
from app.schemas.model_schema import (
|
||||
@@ -69,9 +69,9 @@ class ModelConfigService:
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_name(db, name, tenant_id=tenant_id)
|
||||
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
||||
if not model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
return model
|
||||
@@ -244,7 +244,7 @@ class ModelConfigService:
|
||||
async def create_model(db: Session, model_data: ModelConfigCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建模型配置"""
|
||||
# 检查名称是否已存在(同租户内)
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=model_data.provider, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
# 验证配置
|
||||
@@ -253,8 +253,8 @@ class ModelConfigService:
|
||||
for api_key_data in api_key_data_list:
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
model_name=model_data.name,
|
||||
provider=model_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_data.type, # 传递模型类型
|
||||
@@ -277,6 +277,8 @@ class ModelConfigService:
|
||||
|
||||
if api_key_datas:
|
||||
for api_key_data in api_key_datas:
|
||||
api_key_data.model_name = model_data.name
|
||||
api_key_data.provider = model_data.provider
|
||||
api_key_create_schema = ModelApiKeyCreate(
|
||||
model_config_ids=[model.id],
|
||||
**api_key_data.model_dump()
|
||||
@@ -295,7 +297,7 @@ class ModelConfigService:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||
@@ -306,7 +308,7 @@ class ModelConfigService:
|
||||
@staticmethod
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建组合模型"""
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
@@ -341,7 +343,7 @@ class ModelConfigService:
|
||||
"type": model_data.type,
|
||||
"logo": model_data.logo,
|
||||
"description": model_data.description,
|
||||
"provider": "composite",
|
||||
"provider": ModelProvider.COMPOSITE,
|
||||
"config": model_data.config,
|
||||
"is_active": model_data.is_active,
|
||||
"is_public": model_data.is_public,
|
||||
@@ -369,6 +371,10 @@ class ModelConfigService:
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
if not existing_model.is_composite:
|
||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
||||
@@ -471,11 +477,14 @@ class ModelApiKeyService:
|
||||
# 从ModelBase获取model_name
|
||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||
|
||||
# 检查是否存在API Key(包括软删除)
|
||||
existing_key = db.query(ModelApiKey).filter(
|
||||
# 检查是否存在API Key(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
).filter(
|
||||
ModelApiKey.api_key == data.api_key,
|
||||
ModelApiKey.provider == data.provider,
|
||||
ModelApiKey.model_name == model_name
|
||||
ModelApiKey.model_name == model_name,
|
||||
ModelConfig.tenant_id == model_config.tenant_id
|
||||
).first()
|
||||
|
||||
if existing_key:
|
||||
@@ -542,11 +551,14 @@ class ModelApiKeyService:
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
# 检查API Key是否已存在(包括软删除)
|
||||
existing_key = db.query(ModelApiKey).filter(
|
||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
).filter(
|
||||
ModelApiKey.api_key == api_key_data.api_key,
|
||||
ModelApiKey.provider == api_key_data.provider,
|
||||
ModelApiKey.model_name == api_key_data.model_name
|
||||
ModelApiKey.model_name == api_key_data.model_name,
|
||||
ModelConfig.tenant_id == model_config.tenant_id
|
||||
).first()
|
||||
|
||||
if existing_key:
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import datetime
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
|
||||
from pydantic import EmailStr
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.aioRedis import aio_redis_set, aio_redis_get, aio_redis_delete
|
||||
from app.models.user_model import User
|
||||
from app.repositories import user_repository
|
||||
from app.schemas.user_schema import UserCreate
|
||||
from app.schemas.tenant_schema import TenantCreate
|
||||
from app.services.email_service import send_email
|
||||
from app.services.tenant_service import TenantService
|
||||
from app.services.session_service import SessionService
|
||||
from app.core.security import get_password_hash, verify_password
|
||||
@@ -563,3 +568,175 @@ def generate_random_password(length: int = 12) -> str:
|
||||
secrets.SystemRandom().shuffle(password)
|
||||
|
||||
return ''.join(password)
|
||||
|
||||
|
||||
def generate_email_code() -> str:
|
||||
"""生成6位数字验证码"""
|
||||
return ''.join([str(secrets.randbelow(10)) for _ in range(6)])
|
||||
|
||||
|
||||
async def send_email_code_method(db: Session, email: EmailStr, user_id: uuid.UUID):
|
||||
"""发送邮箱验证码"""
|
||||
business_logger.info(f"发送邮箱验证码: email={email}")
|
||||
|
||||
# 检查发送间隔
|
||||
rate_limit_key = f"email_code_rate:{user_id}"
|
||||
last_send = await aio_redis_get(rate_limit_key)
|
||||
|
||||
if last_send:
|
||||
raise BusinessException("请稍后再试,验证码发送间隔为1分钟", code=BizCode.RATE_LIMITED)
|
||||
|
||||
# 检查新邮箱是否已被使用
|
||||
existing_user = user_repository.get_user_by_email(db=db, email=email)
|
||||
if existing_user and existing_user.id != user_id:
|
||||
raise BusinessException("邮箱已被使用", code=BizCode.DUPLICATE_NAME)
|
||||
|
||||
if existing_user and existing_user.id == user_id:
|
||||
raise BusinessException("新邮箱与当前邮箱相同", code=BizCode.DUPLICATE_NAME)
|
||||
|
||||
# 生成验证码
|
||||
code = generate_email_code()
|
||||
|
||||
# 存储到 Redis,5分钟过期
|
||||
cache_key = f"email_code:{user_id}:{email}"
|
||||
await aio_redis_set(cache_key, json.dumps(code), expire=300)
|
||||
|
||||
# 发送邮件
|
||||
await send_email(
|
||||
email,
|
||||
"邮箱验证码",
|
||||
f'<p>您的验证码是:<strong>{code}</strong></p><p>验证码在5分钟内有效。</p>'
|
||||
)
|
||||
|
||||
# 设置发送间隔限制,60秒
|
||||
await aio_redis_set(rate_limit_key, "1", expire=60)
|
||||
|
||||
business_logger.info(f"邮箱验证码已发送: {email}")
|
||||
|
||||
|
||||
async def verify_and_change_email(db: Session, user_id: uuid.UUID, new_email: EmailStr, code: str) -> User:
|
||||
"""验证验证码并修改邮箱"""
|
||||
business_logger.info(f"验证并修改邮箱: user_id={user_id}, new_email={new_email}")
|
||||
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 验证验证码
|
||||
cache_key = f"email_code:{user_id}:{new_email}"
|
||||
cached_code = await aio_redis_get(cache_key)
|
||||
|
||||
if not cached_code:
|
||||
raise BusinessException("验证码已过期", code=BizCode.VALIDATION_FAILED)
|
||||
|
||||
if json.loads(cached_code) != code:
|
||||
raise BusinessException("验证码错误", code=BizCode.VALIDATION_FAILED)
|
||||
|
||||
# 修改邮箱
|
||||
db_user.email = new_email
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
# 删除验证码
|
||||
await aio_redis_delete(cache_key)
|
||||
|
||||
# 使所有旧 tokens 失效
|
||||
# await SessionService.invalidate_all_user_tokens(str(user_id))
|
||||
|
||||
business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}")
|
||||
return db_user
|
||||
|
||||
|
||||
# def generate_email_token(user_id: str, old_email: str, new_email: str) -> str:
|
||||
# """生成邮箱修改token"""
|
||||
# payload = {
|
||||
# "user_id": user_id,
|
||||
# "old_email": old_email,
|
||||
# "new_email": new_email,
|
||||
# "exp": datetime.datetime.now(datetime.timezone.utc) + timedelta(hours=24)
|
||||
# }
|
||||
# return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
#
|
||||
#
|
||||
# def verify_email_token(token: str) -> dict:
|
||||
# """验证邮箱修改token"""
|
||||
# try:
|
||||
# payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
# return payload
|
||||
# except jwt.ExpiredSignatureError:
|
||||
# raise BusinessException("链接已过期", code=BizCode.VALIDATION_FAILED)
|
||||
# except jwt.InvalidTokenError:
|
||||
# raise BusinessException("无效的链接", code=BizCode.VALIDATION_FAILED)
|
||||
#
|
||||
#
|
||||
# async def request_change_email(db: Session, user_id: uuid.UUID, new_email: EmailStr, current_user: User):
|
||||
# """请求修改邮箱,发送验证邮件"""
|
||||
# business_logger.info(f"用户请求修改邮箱: user_id={user_id}, new_email={new_email}")
|
||||
#
|
||||
# if current_user.id != user_id:
|
||||
# raise PermissionDeniedException("只能修改自己的邮箱")
|
||||
#
|
||||
# db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
# if not db_user:
|
||||
# raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
#
|
||||
# if db_user.email == new_email:
|
||||
# raise BusinessException("新邮箱与当前邮箱相同", code=BizCode.VALIDATION_FAILED)
|
||||
#
|
||||
# existing_user = user_repository.get_user_by_email(db=db, email=new_email)
|
||||
# if existing_user and existing_user.id != user_id:
|
||||
# raise BusinessException("邮箱已被使用", code=BizCode.DUPLICATE_NAME)
|
||||
#
|
||||
# token = generate_email_token(str(user_id), db_user.email, new_email)
|
||||
#
|
||||
# # 发送确认邮件到旧邮箱
|
||||
# old_email_link = f"{settings.BASE_URL}/api/users/email/confirm-email-change?token={token}"
|
||||
# await send_email(
|
||||
# db_user.email,
|
||||
# "确认修改邮箱",
|
||||
# f'<p>请点击以下链接确认修改邮箱:</p><a href="{old_email_link}">确认修改</a>'
|
||||
# )
|
||||
#
|
||||
# business_logger.info(f"邮箱修改确认邮件已发送到旧邮箱: {db_user.email}")
|
||||
#
|
||||
#
|
||||
# async def confirm_email_change(db: Session, token: str):
|
||||
# """确认修改邮箱(旧邮箱确认)"""
|
||||
# payload = verify_email_token(token)
|
||||
# user_id = uuid.UUID(payload["user_id"])
|
||||
# new_email = payload["new_email"]
|
||||
#
|
||||
# db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
# if not db_user:
|
||||
# raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
#
|
||||
# # 发送激活邮件到新邮箱
|
||||
# activate_link = f"{settings.BASE_URL}/api/users/email/activate-new-email?token={token}"
|
||||
# await send_email(
|
||||
# new_email,
|
||||
# "激活新邮箱",
|
||||
# f'<p>请点击以下链接激活新邮箱:</p><a href="{activate_link}">激活邮箱</a>'
|
||||
# )
|
||||
#
|
||||
# business_logger.info(f"新邮箱激活邮件已发送: {new_email}")
|
||||
#
|
||||
#
|
||||
# async def activate_new_email(db: Session, token: str) -> User:
|
||||
# """激活新邮箱"""
|
||||
# payload = verify_email_token(token)
|
||||
# user_id = uuid.UUID(payload["user_id"])
|
||||
# new_email = payload["new_email"]
|
||||
#
|
||||
# db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
# if not db_user:
|
||||
# raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
#
|
||||
# db_user.email = new_email
|
||||
# db.commit()
|
||||
# db.refresh(db_user)
|
||||
#
|
||||
# # 使所有旧 tokens 失效
|
||||
# await SessionService.invalidate_all_user_tokens(str(user_id))
|
||||
#
|
||||
# business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}")
|
||||
# return db_user
|
||||
|
||||
@@ -588,7 +588,7 @@ class WorkflowService:
|
||||
"message_length": len(payload.get("output", ""))
|
||||
}
|
||||
}
|
||||
case "node_start" | "node_end" | "node_error":
|
||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||
return None
|
||||
case _:
|
||||
return event
|
||||
|
||||
@@ -70,10 +70,10 @@ def delete_workspace_member(
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
||||
if not workspace_member:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
|
||||
|
||||
if workspace_member.workspace_id != workspace_id:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_NOT_FOUND)
|
||||
|
||||
try:
|
||||
workspace_member.is_active = False
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -13,7 +14,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
import requests
|
||||
import trio
|
||||
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
@@ -66,6 +66,10 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
"""
|
||||
Document parsing, vectorization, and storage
|
||||
"""
|
||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||
import trio
|
||||
import importlib
|
||||
importlib.reload(trio)
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_document = None
|
||||
db_knowledge = None
|
||||
@@ -292,6 +296,10 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
"""
|
||||
build knowledge graph
|
||||
"""
|
||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||
import trio
|
||||
import importlib
|
||||
importlib.reload(trio)
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_documents = None
|
||||
db_knowledge = None
|
||||
@@ -362,7 +370,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n")
|
||||
return result
|
||||
|
||||
try:
|
||||
def sync_task():
|
||||
trio.run(
|
||||
lambda: _run(
|
||||
row=task,
|
||||
@@ -377,8 +385,15 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
future.result() # Blocks until the task completes
|
||||
except Exception as e:
|
||||
print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n")
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
print(f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)")
|
||||
|
||||
result = f"build knowledge graph '{db_knowledge.name}' processed successfully."
|
||||
@@ -389,7 +404,8 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
result = f"build knowledge grap '{db_knowledge.name}' failed."
|
||||
return result
|
||||
finally:
|
||||
db.close()
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
||||
|
||||
@@ -1,4 +1,68 @@
|
||||
{
|
||||
"v0.2.5": {
|
||||
"introduction": {
|
||||
"codeName": "行云",
|
||||
"releaseDate": "2026-2-26",
|
||||
"upgradePosition": "🐻 精炼根基,优化核心用户体验与系统稳定性",
|
||||
"coreUpgrades": [
|
||||
"1. 用户体验与国际化 🎨<br>* 语言参数修复:语言偏好现正确保留<br>* 邮箱修改支持:用户可直接在用户管理系统中修改邮箱地址",
|
||||
"2. 工作流可视化增强 💬<br>* 循环与迭代节点输出展示:实时显示执行进度和中间输出,便于调试复杂迭代过程<br>* 变量支持回车选择:支持回车键确认变量选择,简化工作流配置流程",
|
||||
"3. 优化模型管理 ⚙️<br>* 模型广场移除自定义模型,优化模型使用体验",
|
||||
"4. 稳健性与缺陷修复 🔧<br>* 知识图谱构建修复:解决知识图谱构建流程稳定性问题,确保更可靠的实体提取和关系映射",
|
||||
"<br>",
|
||||
"版本 0.2.5 通过解决国际化边界情况和改进工作流透明度,构建更具生产就绪性的平台。工作流可视化改进为更复杂的调试和监控能力奠定基础。未来将继续深化企业就绪性,扩展用户管理功能、优化知识图谱智能和增强工作流编排能力,在可观测性、性能优化和无缝集成模式方面持续改进。",
|
||||
"智慧致远 🐻✨"
|
||||
]
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": "Flowing Clouds",
|
||||
"releaseDate": "2026-2-26",
|
||||
"upgradePosition": "🐻 Refined foundations with enhanced user experience and system stability",
|
||||
"coreUpgrades": [
|
||||
"1. User Experience & Internationalization 🎨<br>* Language parameter fix: language preferences are now correctly retained<br>* Email Update Support: Users can now modify email addresses directly in user management system",
|
||||
"2. Workflow Visualization Enhancements 💬<br>* Loop & Iteration Node Output Display: Real-time display of execution progress and intermediate outputs for easier debugging<br>* Variable Selection with Enter Key: Enabled Enter key confirmation for streamlined variable assignment",
|
||||
"3. Optimized Model Management ⚙️<br>* Custom models have been removed from the Model marketplace to optimize the model usage experience",
|
||||
"4. Robustness & Bug Fixes 🔧<br>* Knowledge Graph Construction Fix: Addressed stability issues in knowledge graph pipeline for more reliable entity extraction and relationship mapping",
|
||||
"<br>",
|
||||
"Version 0.2.5 matures MemoryBear's operational foundations by addressing internationalization edge cases and improving workflow transparency. The workflow visualization improvements lay groundwork for sophisticated debugging and monitoring capabilities. Looking forward, we will deepen enterprise readiness by expanding user management features, refining knowledge graph intelligence, and enhancing workflow orchestration with continued improvements in observability, performance optimization, and seamless integration patterns.",
|
||||
"Intelligent Resilience 🐻✨"
|
||||
]
|
||||
}
|
||||
},
|
||||
"v0.2.4": {
|
||||
"introduction": {
|
||||
"codeName": "智远",
|
||||
"releaseDate": "2026-2-11",
|
||||
"upgradePosition": "🐻 生产级稳健性升级版本,智慧致远,从容应对复杂场景",
|
||||
"coreUpgrades": [
|
||||
"1. Skills 技能框架 🛠️<br>* Skills 支持:引入全新的Skills技能系统,支持可扩展的能力模块,可在Agent和工作流中动态加载与编排",
|
||||
"2. 多模态与交互 💬<br>* 文件多模态支持:全面支持消息输入、LLM处理和输出渲染中的多模态文件处理,实现更丰富的媒体感知对话<br>* 语音交互:语音交互功能正在积极开发中,为免提对话体验奠定基础(开发中)",
|
||||
"3. 知识库集成 📚<br>* 飞书知识库:无缝对接飞书文档库,支持企业知识检索<br>* 语雀知识库:原生连接语雀文档平台,扩展对国内企业工具生态的覆盖<br>* Web站点知识库:通用Web站点抓取与索引,支持从公开网页内容构建知识库<br>* 视觉模型选择优化:知识库视觉模型配置现已支持LLM和Chat两种模型类型,移除了此前仅限Chat类型的限制",
|
||||
"4. 记忆智能 🧠<br>* 本体工程(二期):基于本体工程的高级记忆场景分类与萃取,实现结构化、领域感知的记忆组织,提升分类准确性<br>* 默认模型配置:情绪分析、反思和记忆萃取模块现默认使用空间级模型,确保开箱即用的一致性行为<br>* 智能模型回退:当已配置的情绪或反思模型为空或不可用时,系统自动回退至空间默认模型,避免静默失败<br>* 记忆模型回退兜底:当记忆中配置的模型为空或不可用时,系统优雅降级至空间默认模型",
|
||||
"5. 性能与扩展 ⚡<br>* 模型并发(model_api_keys):支持并发模型API Key管理,实现并行模型调用,提升高负载场景下的吞吐能力",
|
||||
"6. 稳健性与缺陷修复 🔧<br>* 记忆配置版本固定:修复用户记忆配置未跟随应用版本发布固定的问题,消除跨部署的行为不一致<br>* 空间默认记忆保护:空间级默认记忆配置现不可删除;用户级配置仍可删除<br>* Agent与工作流配置兜底:解决Agent和工作流节点中记忆配置可能为空、或已选择但未配置的边界情况——全面的回退处理现可防止运行时错误<br>* 隐形记忆字段重命名:将隐形记忆接口JSON响应中的user_id修正为end_user_id,与规范数据模型对齐<br>* 记忆配置ID迁移:将Agent和工作流记忆配置中的memory_content重命名为memory_config_id,保持API一致性<br>* Worker-Memory告警解决:解决worker-memory服务中的告警级别问题,提升运维监控清晰度<br>* 双语接口修复:修复记忆相关API接口的中英文不一致问题<br>* 新用户记忆配置自动回填:新创建的EndUser若memory_config_id为None,系统自动从最新Release获取memory_config_id并回填<br>* 存量用户记忆配置自动回填:已有EndUser若memory_config_id为None,系统同样从最新Release获取并回填,确保向后兼容,无需手动迁移",
|
||||
"<br>",
|
||||
"Memory Bear v0.2.4 向生产级稳健性迈进,Skills框架与多模态支持开启认知平台新篇章。",
|
||||
"记忆熊,智慧致远,从容应对真实世界的多样性。🐻✨"
|
||||
]
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": "ZhiYuan",
|
||||
"releaseDate": "2026-2-11",
|
||||
"upgradePosition": "🐻 Production-grade resilience release — Wisdom Reaching Far, gracefully handling complex scenarios",
|
||||
"coreUpgrades": [
|
||||
"1. Skills Framework 🛠️<br>* Skills Support: Introduced a new Skills system, enabling extensible capability modules that can be dynamically loaded and orchestrated within agents and workflows",
|
||||
"2. Multimodal & Interaction 💬<br>* File Multimodal Support: Full multimodal file handling across message input, LLM processing, and output rendering — supporting richer, media-aware conversations<br>* Voice Interaction: Voice-based interaction capabilities are under active development, laying the groundwork for hands-free conversational experiences (In Progress)",
|
||||
"3. Knowledge Base Integration 📚<br>* Feishu Knowledge Base: Seamless integration with Feishu (Lark) document repositories for enterprise knowledge retrieval<br>* Yuque Knowledge Base: Native connector for Yuque documentation platforms, expanding coverage of Chinese enterprise tooling<br>* Web Site Knowledge Base: General-purpose web site crawling and indexing for knowledge base construction from public web content<br>* Visual Model Selection: Knowledge base visual model configuration now supports both LLM and Chat model types, removing the previous restriction to Chat-only selection",
|
||||
"4. Memory Intelligence 🧠<br>* Ontology Engineering (Phase 2): Advanced memory scene classification and extraction powered by ontology engineering — enabling structured, domain-aware memory organization with improved categorization accuracy<br>* Default Model Configuration: Emotion analysis, reflection, and memory extraction modules now default to the space-level model, ensuring consistent behavior out of the box<br>* Intelligent Model Fallback: If configured emotion or reflection models are empty or unavailable, the system automatically falls back to the space default model — preventing silent failures<br>* Memory Config Fallback for Models: When any memory-configured model is empty or unavailable, the system gracefully degrades to the space default model",
|
||||
"5. Performance & Scalability ⚡<br>* Model Concurrency (model_api_keys): Support for concurrent model API key management, enabling parallel model invocations and improved throughput for high-load scenarios",
|
||||
"6. Robustness & Bug Fixes 🔧<br>* Memory Config Version Pinning: Fixed an issue where user memory configurations were not pinned to application release versions, causing inconsistent behavior across deployments<br>* Space Default Memory Protection: Space-level default memory configurations are now protected from deletion; user-level configurations remain deletable<br>* Agent & Workflow Config Fallback: Resolved edge cases in Agent and Workflow nodes where memory config could be empty or selected but unconfigured — comprehensive fallback handling now prevents runtime errors<br>* Implicit Memory Field Rename: Corrected user_id to end_user_id in JSON responses from implicit memory interfaces, aligning with the canonical data model<br>* Memory Config ID Migration: Renamed memory_content to memory_config_id in Agent and Workflow memory configurations for API consistency<br>* Worker-Memory Alerts: Resolved warning-level alerts in the worker-memory service, improving operational monitoring clarity<br>* Bilingual Interface Fixes: Fixed Chinese/English language inconsistencies across memory-related API interfaces<br>* EndUser Memory Config Auto-Backfill (New Users): When a newly created EndUser has memory_config_id as None, the system automatically fetches the latest release's memory_config_id and backfills it<br>* EndUser Memory Config Auto-Backfill (Existing Users): For existing EndUsers with memory_config_id as None, the system similarly retrieves and backfills from the latest release — ensuring backward compatibility without manual migration",
|
||||
"<br>",
|
||||
"Memory Bear v0.2.4 advances toward production-grade resilience, with the Skills framework and multimodal support opening a new chapter for the cognitive platform.",
|
||||
"MemoryBear — Wisdom Reaching Far, gracefully handling real-world variability. 🐻✨"
|
||||
]
|
||||
}
|
||||
},
|
||||
"v0.2.3": {
|
||||
"introduction": {
|
||||
"codeName": "归墟",
|
||||
|
||||
@@ -64,6 +64,9 @@ LANGCHAIN_ENDPOINT=
|
||||
# Generate a new one with: openssl rand -hex 32
|
||||
SECRET_KEY=your-secret-key-here-generate-with-openssl-rand-hex-32
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION=
|
||||
|
||||
# JWT Token expiration settings
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
@@ -129,6 +132,12 @@ KB_image2text_id=
|
||||
config_id=
|
||||
reranker_id=
|
||||
|
||||
# Email Configuration
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USER=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# 本体类型融合配置 (记得写入env_example)
|
||||
GENERAL_ONTOLOGY_FILES=General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导)
|
||||
|
||||
66
api/migrations/versions/75e28690ae87_202602251230.py
Normal file
66
api/migrations/versions/75e28690ae87_202602251230.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""202602251230
|
||||
|
||||
Revision ID: 75e28690ae87
|
||||
Revises: bab823f7cc82
|
||||
Create Date: 2026-02-25 12:27:36.919237
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '75e28690ae87'
|
||||
down_revision: Union[str, None] = 'bab823f7cc82'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('mcp_market_configs',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('mcp_market_id', sa.UUID(), nullable=False, comment='mcp_markets.id'),
|
||||
sa.Column('token', sa.String(), nullable=True, comment='mcp market token'),
|
||||
sa.Column('status', sa.Integer(), nullable=True, comment='connect status(0: Not connected, 1: connected)'),
|
||||
sa.Column('tenant_id', sa.UUID(), nullable=False, comment='tenant.id'),
|
||||
sa.Column('created_by', sa.UUID(), nullable=False, comment='users.id'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_mcp_market_configs_id'), 'mcp_market_configs', ['id'], unique=False)
|
||||
op.create_table('mcp_markets',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False, comment='mcp market name'),
|
||||
sa.Column('description', sa.String(), nullable=True, comment='mcp market description'),
|
||||
sa.Column('logo_url', sa.String(), nullable=True, comment='logo url'),
|
||||
sa.Column('mcp_count', sa.Integer(), nullable=True, comment='mcp count'),
|
||||
sa.Column('url', sa.String(), nullable=False, comment='mcp market url'),
|
||||
sa.Column('category', sa.String(), nullable=False, comment='category'),
|
||||
sa.Column('created_by', sa.UUID(), nullable=False, comment='users.id'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_mcp_markets_category'), 'mcp_markets', ['category'], unique=False)
|
||||
op.create_index(op.f('ix_mcp_markets_description'), 'mcp_markets', ['description'], unique=False)
|
||||
op.create_index(op.f('ix_mcp_markets_id'), 'mcp_markets', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_mcp_markets_logo_url'), 'mcp_markets', ['logo_url'], unique=False)
|
||||
op.create_index(op.f('ix_mcp_markets_name'), 'mcp_markets', ['name'], unique=False)
|
||||
op.create_index(op.f('ix_mcp_markets_url'), 'mcp_markets', ['url'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_mcp_markets_url'), table_name='mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_markets_name'), table_name='mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_markets_logo_url'), table_name='mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_markets_id'), table_name='mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_markets_description'), table_name='mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_markets_category'), table_name='mcp_markets')
|
||||
op.drop_table('mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_market_configs_id'), table_name='mcp_market_configs')
|
||||
op.drop_table('mcp_market_configs')
|
||||
# ### end Alembic commands ###
|
||||
36
api/migrations/versions/7672d8f0f939_202602271020.py
Normal file
36
api/migrations/versions/7672d8f0f939_202602271020.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""202602271020
|
||||
|
||||
Revision ID: 7672d8f0f939
|
||||
Revises: 75e28690ae87
|
||||
Create Date: 2026-02-27 10:21:46.951584
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7672d8f0f939'
|
||||
down_revision: Union[str, None] = '75e28690ae87'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column('file_metadata', 'workspace_id',
|
||||
existing_type=sa.UUID(),
|
||||
nullable=True,
|
||||
existing_comment='Workspace ID')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column('file_metadata', 'workspace_id',
|
||||
existing_type=sa.UUID(),
|
||||
nullable=False,
|
||||
existing_comment='Workspace ID')
|
||||
# ### end Alembic commands ###
|
||||
@@ -144,6 +144,7 @@ dependencies = [
|
||||
"rdflib>=7.0.0",
|
||||
"lxml>=4.9.0",
|
||||
"httpx>=0.28.0",
|
||||
"modelscope>=1.34.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -137,3 +137,4 @@ boto3>=1.28.0
|
||||
aiofiles>=23.0.0
|
||||
lxml>=4.9.0
|
||||
httpx>=0.28.0
|
||||
modelscope>=1.34.0
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.engine.variable_pool import VariablePool, VariableSelector
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool, VariableSelector
|
||||
|
||||
|
||||
# ==================== VariableSelector 测试 ====================
|
||||
|
||||
@@ -6,8 +6,8 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
TEST_WORKSPACE_ID = "test_workspace_id"
|
||||
TEST_USER_ID = "test_user_id"
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import StartNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from tests.workflow.nodes.base import (
|
||||
simple_state,
|
||||
simple_state,
|
||||
simple_vairable_pool,
|
||||
TEST_EXECUTION_ID,
|
||||
TEST_WORKSPACE_ID,
|
||||
|
||||
Submodule redbear-mem-benchmark updated: 0c4bcafbc1...4b0257bb4e
@@ -33,8 +33,6 @@ key = b64decode(key)
|
||||
|
||||
os.chdir(running_path)
|
||||
|
||||
# Preload code
|
||||
{{preload}}
|
||||
|
||||
# Apply security if library is available
|
||||
init_status = lib.init_seccomp({{uid}}, {{gid}}, {{enable_network}})
|
||||
@@ -42,6 +40,8 @@ if init_status != 0:
|
||||
raise Exception(f"code executor err - {str(init_status)}")
|
||||
del lib
|
||||
|
||||
# Preload code
|
||||
{{preload}}
|
||||
# Decrypt and execute code
|
||||
code = b64decode("{{code}}")
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ function App() {
|
||||
const { checkJump } = useUser();
|
||||
useEffect(() => {
|
||||
const authToken = cookieUtils.get('authToken')
|
||||
if (!authToken && !window.location.hash.includes('#/login') && !window.location.hash.includes('#/conversation/') && !window.location.hash.includes('#/jump')) {
|
||||
if (!authToken && !window.location.hash.includes('#/login') && !window.location.hash.includes('#/conversation/') && !window.location.hash.includes('#/jump') && !window.location.hash.includes('#/invite-register')) {
|
||||
window.location.href = `/#/login`;
|
||||
} else {
|
||||
checkJump()
|
||||
|
||||
@@ -66,9 +66,9 @@ export const addModelPlaza = (model_base_id: string) => {
|
||||
}
|
||||
// Create custom model
|
||||
export const addCustomModel = (data: CustomModelForm) => {
|
||||
return request.post('/models/model_plaza', data)
|
||||
return request.post('/models', data)
|
||||
}
|
||||
// Update custom model
|
||||
export const updateCustomModel = (model_base_id: string, data: CustomModelForm) => {
|
||||
return request.put(`/models/model_plaza/${model_base_id}`, data)
|
||||
return request.put(`/models/${model_base_id}`, data)
|
||||
}
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 14:00:23
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 14:00:23
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-25 11:17:44
|
||||
*/
|
||||
import { request } from '@/utils/request'
|
||||
import type { CreateModalData } from '@/views/UserManagement/types'
|
||||
import type { CreateModalData, ChangeEmailModalForm } from '@/views/UserManagement/types'
|
||||
import { cookieUtils } from '@/utils/request'
|
||||
|
||||
// User info
|
||||
@@ -28,6 +28,10 @@ export const refreshToken = () => {
|
||||
export const changePassword = (data: { user_id: string; new_password: string }) => {
|
||||
return request.put('/users/admin/change-password', data)
|
||||
}
|
||||
// Verify password
|
||||
export const verifyPassword = (data: { password: string }) => {
|
||||
return request.post('/users/verify_pwd', data)
|
||||
}
|
||||
// Disable user
|
||||
export const deleteUser = (user_id: string) => {
|
||||
return request.delete(`/users/${user_id}`)
|
||||
@@ -44,4 +48,12 @@ export const addUser = (data: CreateModalData) => {
|
||||
export const logoutUrl = '/logout'
|
||||
export const logout = () => {
|
||||
return request.post(logoutUrl)
|
||||
}
|
||||
// Send email verification code
|
||||
export const sendEmailCode = (data: { email: string }) => {
|
||||
return request.post('/users/send-email-code', data)
|
||||
}
|
||||
// Verify code and change email
|
||||
export const changeEmail = (data: ChangeEmailModalForm) => {
|
||||
return request.put('/users/change-email', data)
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-02 15:03:25
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-02 15:47:31
|
||||
* @Last Modified time: 2026-02-25 11:14:25
|
||||
*/
|
||||
/**
|
||||
* Empty Component
|
||||
@@ -13,7 +13,7 @@
|
||||
* @component
|
||||
*/
|
||||
|
||||
import { type FC } from 'react';
|
||||
import { type FC, type ReactElement } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import emptyIcon from '@/assets/images/empty/empty.svg';
|
||||
@@ -24,7 +24,7 @@ interface EmptyProps {
|
||||
/** Icon size - single number or [width, height] array */
|
||||
size?: number | number[];
|
||||
/** Main title text */
|
||||
title?: string;
|
||||
title?: string | ReactElement;
|
||||
/** Whether to show subtitle */
|
||||
isNeedSubTitle?: boolean;
|
||||
/** Custom subtitle text */
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-02 15:09:47
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-02 15:51:54
|
||||
* @Last Modified time: 2026-02-25 11:40:47
|
||||
*/
|
||||
/**
|
||||
* UserInfoModal Component
|
||||
@@ -15,7 +15,7 @@
|
||||
*/
|
||||
|
||||
import { forwardRef, useImperativeHandle, useState, useRef } from 'react';
|
||||
import { Button } from 'antd';
|
||||
import { Button, Space } from 'antd';
|
||||
import { UnlockOutlined } from '@ant-design/icons';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -23,7 +23,9 @@ import { useUser } from '@/store/user';
|
||||
import RbModal from '@/components/RbModal'
|
||||
import { formatDateTime } from '@/utils/format';
|
||||
import ResetPasswordModal from '@/views/UserManagement/components/ResetPasswordModal'
|
||||
import type { ResetPasswordModalRef } from '@/views/UserManagement/types'
|
||||
import type { ResetPasswordModalRef, VerifyPasswordModalRef, ChangeEmailModalRef } from '@/views/UserManagement/types'
|
||||
import VerifyPasswordModal from '@/views/UserManagement/components/VerifyPasswordModal'
|
||||
import ChangeEmailModal from '@/views/UserManagement/components/ChangeEmailModal'
|
||||
|
||||
/** Interface for UserInfoModal ref methods exposed to parent components */
|
||||
export interface UserInfoModalRef {
|
||||
@@ -37,8 +39,10 @@ export interface UserInfoModalRef {
|
||||
const UserInfoModal = forwardRef<UserInfoModalRef>((_props, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const resetPasswordModalRef = useRef<ResetPasswordModalRef>(null)
|
||||
const { user } = useUser();
|
||||
const { user, getUserInfo } = useUser();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const verifyPasswordModalRef = useRef<VerifyPasswordModalRef>(null)
|
||||
const changeEmailModalRef = useRef<ChangeEmailModalRef>(null)
|
||||
|
||||
/** Close the modal */
|
||||
const handleClose = () => {
|
||||
@@ -50,6 +54,17 @@ const UserInfoModal = forwardRef<UserInfoModalRef>((_props, ref) => {
|
||||
setVisible(true);
|
||||
};
|
||||
|
||||
/** Open password verification modal before editing email */
|
||||
const handleEditEmail = () => {
|
||||
verifyPasswordModalRef.current?.handleOpen()
|
||||
}
|
||||
|
||||
/** Update user information after email change */
|
||||
const updateUserInfo = () => {
|
||||
localStorage.removeItem('user')
|
||||
getUserInfo()
|
||||
}
|
||||
|
||||
/** Expose handleOpen and handleClose methods to parent component via ref */
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
@@ -74,7 +89,13 @@ const UserInfoModal = forwardRef<UserInfoModalRef>((_props, ref) => {
|
||||
{/* Email */}
|
||||
<div className="rb:flex rb:justify-between rb:text-[#5B6167] rb:text-[14px] rb:leading-5 rb:mb-3">
|
||||
<span className="rb:whitespace-nowrap">{t('user.email')}</span>
|
||||
<span className="rb:text-[#212332]">{user.email}</span>
|
||||
<Space size={8} className="rb:text-[#212332]">
|
||||
{user.email}
|
||||
<div
|
||||
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/editBorder.svg')] rb:hover:bg-[url('@/assets/images/editBg.svg')]"
|
||||
onClick={handleEditEmail}
|
||||
></div>
|
||||
</Space>
|
||||
</div>
|
||||
{/* Role */}
|
||||
<div className="rb:flex rb:justify-between rb:text-[#5B6167] rb:text-[14px] rb:leading-5 rb:mb-3">
|
||||
@@ -106,6 +127,14 @@ const UserInfoModal = forwardRef<UserInfoModalRef>((_props, ref) => {
|
||||
ref={resetPasswordModalRef}
|
||||
source="changePassword"
|
||||
/>
|
||||
<VerifyPasswordModal
|
||||
ref={verifyPasswordModalRef}
|
||||
refresh={() => changeEmailModalRef.current?.handleOpen()}
|
||||
/>
|
||||
<ChangeEmailModal
|
||||
ref={changeEmailModalRef}
|
||||
refresh={updateUserInfo}
|
||||
/>
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-02 15:12:42
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-04 14:06:28
|
||||
* @Last Modified time: 2026-02-28 17:28:41
|
||||
*/
|
||||
/**
|
||||
* BasicLayout Component
|
||||
* BasicAuthLayout Component
|
||||
*
|
||||
* A minimal layout wrapper that provides:
|
||||
* - User information initialization
|
||||
@@ -26,12 +26,12 @@ import { useUser } from '@/store/user';
|
||||
* Basic layout component for pages without navigation UI.
|
||||
* Fetches user info and storage type on mount, then renders child routes.
|
||||
*/
|
||||
const BasicLayout: FC = () => {
|
||||
const BasicAuthLayout: FC = () => {
|
||||
const { getUserInfo } = useUser();
|
||||
|
||||
// Fetch user information and storage type on component mount
|
||||
useEffect(() => {
|
||||
getUserInfo();
|
||||
getUserInfo(undefined, true); // Pass true to skip navigation jump
|
||||
}, [getUserInfo]);
|
||||
|
||||
return (
|
||||
@@ -42,4 +42,4 @@ const BasicLayout: FC = () => {
|
||||
)
|
||||
};
|
||||
|
||||
export default BasicLayout;
|
||||
export default BasicAuthLayout;
|
||||
@@ -20,6 +20,7 @@
|
||||
import { type FC, type Key, type ReactNode, useEffect } from 'react';
|
||||
import { type RadioGroupProps } from 'antd';
|
||||
import clsx from 'clsx'
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
/** Radio card option interface */
|
||||
interface RadioCardOption {
|
||||
@@ -33,6 +34,8 @@ interface RadioCardOption {
|
||||
icon?: string;
|
||||
/** Whether the option is disabled */
|
||||
disabled?: boolean;
|
||||
/** Whether the option is recommended */
|
||||
recommend?: boolean;
|
||||
/** Additional properties */
|
||||
[key: string]: string | number | boolean | undefined | null | Key;
|
||||
}
|
||||
@@ -63,6 +66,7 @@ const RadioGroupCard: FC<RadioCardProps> = ({
|
||||
allowClear = true,
|
||||
block = false,
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
/** Listen to value changes and trigger side effects via onValueChange callback */
|
||||
useEffect(() => {
|
||||
if (onValueChange) {
|
||||
@@ -91,12 +95,13 @@ const RadioGroupCard: FC<RadioCardProps> = ({
|
||||
})}>
|
||||
{/* Render each option as a selectable card */}
|
||||
{options.map(option => (
|
||||
<div key={String(option.value)} className={clsx("rb:border rb:rounded-lg rb:w-full rb:p-[20px_12px] rb:text-center rb:cursor-pointer", {
|
||||
<div key={String(option.value)} className={clsx("rb:relative rb:border rb:rounded-lg rb:w-full rb:p-[20px_12px] rb:text-center rb:cursor-pointer", {
|
||||
'rb:bg-[rgba(21,94,239,0.06)] rb:border-[#155EEF]': option.value === value,
|
||||
'rb:border-[#EBEBEB] rb:bg-[#ffffff]': option.value !== value,
|
||||
'rb:opacity-[0.75]': option.disabled,
|
||||
'rb:flex rb:items-center rb:text-left rb:gap-4': block,
|
||||
})} onClick={() => handleChange(option)}>
|
||||
{option.recommend && <div className="rb:absolute rb:right-0 rb:top-0 rb:bg-[#FF5D34] rb:rounded-[0px_7px_0px_8px] rb:text-[12px] rb:text-white rb:font-regular rb:leading-4 rb:p-[4px_8px]">{t('common.recommend')}</div>}
|
||||
{/* Use custom render or default card layout */}
|
||||
{itemRender ? itemRender(option) : (
|
||||
<>
|
||||
|
||||
@@ -93,7 +93,7 @@ const UploadImages = forwardRef<UploadImagesRef, UploadImagesProps>(({
|
||||
onChange,
|
||||
disabled = false,
|
||||
fileSize,
|
||||
fileType = ['png', 'jpg', 'gif'],
|
||||
fileType = ['png', 'jpg', 'gif', 'svg'],
|
||||
isAutoUpload = true,
|
||||
maxCount = 1,
|
||||
className = 'rb:size-24! rb:leading-1!',
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
export const en = {
|
||||
translation: {
|
||||
welcome: 'Welcome to React Font CLI',
|
||||
title: 'Memory Bear.AI ',
|
||||
memoryBear: 'Memory Bear.AI',
|
||||
index:{
|
||||
@@ -248,6 +247,7 @@ export const en = {
|
||||
usernameOrAccount: 'Username / Login Account',
|
||||
displayName: 'Display Name',
|
||||
role: 'Role',
|
||||
password: 'Password',
|
||||
status: 'Status',
|
||||
createTime: 'Creation Time',
|
||||
lastLoginTime: 'Last Login Time',
|
||||
@@ -256,10 +256,12 @@ export const en = {
|
||||
resetPasswordSuccess: 'Password reset successful',
|
||||
resetPasswordFailed: 'Password reset failed',
|
||||
enabled: 'Enabled',
|
||||
enabledOpera: 'Activate',
|
||||
enabledConfirm: 'Are you sure to enable this user?',
|
||||
enabledConfirmSuccess: 'Enabled successfully',
|
||||
enabledConfirmFailed: 'Enabled failed',
|
||||
disabled: 'Disabled',
|
||||
disabledOpera: 'Deactivate',
|
||||
disabledConfirm: 'Are you sure to disable this user?',
|
||||
disabledConfirmSuccess: 'Disabled successfully',
|
||||
disabledConfirmFailed: 'Disabled failed',
|
||||
@@ -274,6 +276,28 @@ export const en = {
|
||||
createdAt: 'Creation Time',
|
||||
member: 'Member',
|
||||
passwordRule: 'password should have at least 6 characters',
|
||||
authVerify: 'Identity Verification',
|
||||
authVerifyDesc: 'For security reasons, please verify your login password first',
|
||||
verify: 'Verify',
|
||||
loginPassword: 'Login Password',
|
||||
loginPasswordPlaceholder: 'Please enter the login password for the current account',
|
||||
loginPasswordVerifyFailed: 'Incorrect password, please try again',
|
||||
bindNewEmail: 'Bind New Email',
|
||||
sureChange: 'Confirm Change',
|
||||
sendEmailCode: 'Send Verification Code',
|
||||
currentEmail: 'Current Email',
|
||||
newEmail: 'New Email Address',
|
||||
emailCode: 'Verification Code',
|
||||
emailCodePlaceholder: 'Please enter the verification code received by the new email',
|
||||
sureChangeEmail: 'Confirm to change the bound email to',
|
||||
sureChangeEmailDesc: '?',
|
||||
changeSuccess: 'Changed successfully',
|
||||
sendSuccess: 'Verification code has been sent, please check',
|
||||
newEmailSameAsOld: 'New email cannot be the same as current email',
|
||||
emailCodeLengthRule: 'Please enter a 6-digit verification code',
|
||||
emailFormatError: 'Incorrect email format',
|
||||
sendCodeTooFrequent: 'Please resend after {{seconds}}s',
|
||||
retrySend: 'Can resend after {{seconds}}s',
|
||||
},
|
||||
timezones: {
|
||||
'Asia/Shanghai': 'China Standard Time (UTC+8)',
|
||||
@@ -428,6 +452,7 @@ export const en = {
|
||||
nextStep: 'Next Step',
|
||||
prevStep: 'Previous Step',
|
||||
exportSuccess: 'Export successful',
|
||||
recommend: 'Recommend',
|
||||
},
|
||||
model: {
|
||||
searchPlaceholder: 'search model…',
|
||||
@@ -578,6 +603,9 @@ export const en = {
|
||||
bedrock: "Bedrock"
|
||||
},
|
||||
knowledgeBase: {
|
||||
home: 'Home',
|
||||
selectSpace: 'Please select space',
|
||||
preview: 'Preview',
|
||||
pleaseUploadFileFirst: 'Please upload file first',
|
||||
shareSuccess: 'Share successfully',
|
||||
shareFailed: 'Share failed',
|
||||
@@ -1168,7 +1196,7 @@ export const en = {
|
||||
stateSharingStrategy: 'State Sharing Strategy',
|
||||
intermediateResultProcessing: 'Intermediate Result Processing',
|
||||
metadataTransfer: 'Metadata Transfer',
|
||||
|
||||
knowledgeConfig: 'Knowledge Base Configuration',
|
||||
temperature: 'Temperature',
|
||||
temperature_desc: 'Temperature parameters, control the randomness of output',
|
||||
max_tokens: 'Max Tokens',
|
||||
@@ -1312,6 +1340,13 @@ export const en = {
|
||||
analyTask: 'Analyze Task Intent',
|
||||
dynamicMatchSkill: 'Dynamic Match Skill',
|
||||
executeTask: 'Execute Task',
|
||||
|
||||
upload: 'Upload & Parse',
|
||||
complex: 'Compatibility Analysis',
|
||||
node: 'Node Mapping',
|
||||
configCheck: 'Configuration Validation',
|
||||
sureInfo: 'Information Confirmation',
|
||||
completed: 'Import Completed',
|
||||
},
|
||||
userMemory: {
|
||||
userMemory: 'User Memory',
|
||||
@@ -1988,6 +2023,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
query: 'Query Variable',
|
||||
knowledge_retrieval: 'Knowledge Base',
|
||||
recallConfig: 'Recall Test',
|
||||
addKnowledge: 'Add Knowledge Base'
|
||||
},
|
||||
'parameter-extractor': {
|
||||
model_id: 'Model',
|
||||
@@ -2148,6 +2184,14 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
input: 'Input',
|
||||
output: 'Output',
|
||||
error: 'Error Message',
|
||||
loopNum: ' loops',
|
||||
iterationNum: ' iterations',
|
||||
runtime: {
|
||||
loop: 'Loop',
|
||||
iteration: 'Iteration',
|
||||
input_cycle_vars: 'Initial Loop Variables',
|
||||
output_cycle_vars: 'Final Loop Variables',
|
||||
}
|
||||
},
|
||||
emotionEngine: {
|
||||
emotionEngineConfig: 'Emotion Engine Configuration',
|
||||
|
||||
@@ -2,7 +2,6 @@ export const zh = {
|
||||
translation: {
|
||||
title: '记忆熊',
|
||||
memoryBear: '记忆熊',
|
||||
welcome: '欢迎使用 React Font CLI',
|
||||
index:{
|
||||
viewGuide: '查看引导',
|
||||
watchVideo: '观看视频',
|
||||
@@ -92,6 +91,7 @@ export const zh = {
|
||||
memberManagement: '成员管理',
|
||||
memorySummary: '记忆摘要',
|
||||
memoryConversation: '记忆验证',
|
||||
helpCenter: '帮助中心',
|
||||
memorySummaryHandlers: '记忆摘要处理器',
|
||||
createMemorySummary: '创建记忆摘要',
|
||||
memoryManagement: '记忆管理',
|
||||
@@ -189,6 +189,7 @@ export const zh = {
|
||||
customText: '自定义文本',
|
||||
customContent: '自定义内容',
|
||||
createContentError: '创建自定义文件失败',
|
||||
createLinkError: '创建链接内容失败',
|
||||
manuallyInputText: '手动输入一段文本作为数据集',
|
||||
openKnowledgeBase: '打开知识库',
|
||||
searchPlaceholder: '搜索',
|
||||
@@ -243,6 +244,7 @@ export const zh = {
|
||||
processing: '处理中',
|
||||
name: '名称',
|
||||
processingMode: '处理模式',
|
||||
processMsg: '处理消息',
|
||||
dataSize: '数据量',
|
||||
createUpdateTime: '创建/更新时间',
|
||||
datasets: '通用知识库',
|
||||
@@ -440,8 +442,6 @@ export const zh = {
|
||||
agentDesc: '创建单个智能代理',
|
||||
multi_agent: '集群',
|
||||
multi_agentDesc: '创建Agent集群',
|
||||
cluster: '集群',
|
||||
clusterDesc: '创建Agent集群',
|
||||
workflow: '工作流',
|
||||
workflowDesc: '创建策略工作流',
|
||||
editApplication: '编辑应用信息',
|
||||
@@ -550,7 +550,6 @@ export const zh = {
|
||||
|
||||
versionList: '版本列表',
|
||||
versionListDesc: '所有发布记录和状态',
|
||||
fullAmount: '全量',
|
||||
current: '当前',
|
||||
rolledBack: '已回滚',
|
||||
history: '历史',
|
||||
@@ -574,6 +573,10 @@ export const zh = {
|
||||
clusterName: '集群名称',
|
||||
clusterDescription: '集群描述',
|
||||
clusterDescriptionPlaceholder: '这是一个专门处理核心业务的Agent集群,能够协作完成复杂的业务处理任务。',
|
||||
toolCalling: '工具调用',
|
||||
toolCallingDesc: '主控代理将子代理作为工具调用',
|
||||
toolCallingFeature: '集中控制,适合结构化工作流',
|
||||
handoffsFeature: '分散控制,适合复杂对话场景',
|
||||
recommend: '推荐',
|
||||
advanced: '高级',
|
||||
multiAgentArchitecture: '多代理架构模式',
|
||||
@@ -586,7 +589,7 @@ export const zh = {
|
||||
addSubAgent: '添加子代理',
|
||||
versionName: '版本名称',
|
||||
versionNameTip: '版本号格式:v[主版本号].[次版本号].[修订号](例如 v1.3.0)',
|
||||
agentName: '代理名称',
|
||||
agentName: 'Agent名称',
|
||||
roleType: '角色类型',
|
||||
coordinator: '协调者',
|
||||
analyzer: '分析者',
|
||||
@@ -595,6 +598,9 @@ export const zh = {
|
||||
updateSubAgent: '更新子代理',
|
||||
subAgentMaxLength: '子代理最多{{maxLength}}个',
|
||||
capabilities: '能力',
|
||||
subAgent: '子代理',
|
||||
maxChatCount: '最多添加4个模型',
|
||||
ReplyException: '回复异常',
|
||||
contextEngineering: '上下文工程',
|
||||
dialogueHistoryManagement: '对话历史管理',
|
||||
stateSharingStrategy: '状态共享策略',
|
||||
@@ -737,55 +743,6 @@ export const zh = {
|
||||
sureInfo: '信息确认',
|
||||
completed: '完成导入',
|
||||
},
|
||||
role: {
|
||||
roleManagement: '角色管理',
|
||||
roleId: '角色ID',
|
||||
roleName: '角色名称',
|
||||
roleCode: '角色编码',
|
||||
description: '角色描述',
|
||||
status: '状态',
|
||||
enabled: '已启用',
|
||||
disabled: '已停用',
|
||||
createTime: '创建时间',
|
||||
createRole: '新建角色',
|
||||
editRole: '编辑角色',
|
||||
roleTemplate: '角色模板',
|
||||
emptyTemplate: '空模板',
|
||||
adminTemplate: '管理员模板',
|
||||
userTemplate: '用户模板',
|
||||
confirmDelete: '确定要删除这个角色吗?',
|
||||
createSuccess: '角色创建成功',
|
||||
updateSuccess: '角色更新成功',
|
||||
deleteSuccess: '角色删除成功',
|
||||
createFailed: '角色创建失败',
|
||||
updateFailed: '角色更新失败',
|
||||
deleteFailed: '角色删除失败'
|
||||
},
|
||||
tenant: {
|
||||
tenantId: '租户ID',
|
||||
tenantName: '租户名称',
|
||||
contactPerson: '联系人',
|
||||
contactInfo: '联系方式',
|
||||
status: '状态',
|
||||
enabled: '启用',
|
||||
disabled: '禁用',
|
||||
expiryDate: '到期时间',
|
||||
createTenant: '新增租户',
|
||||
editTenant: '编辑租户',
|
||||
searchPlaceholder: '搜索租户ID、名称、联系人或联系方式',
|
||||
confirmDelete: '确定要删除该租户吗?',
|
||||
confirmBatchDelete: '确定要批量删除选中的租户吗?',
|
||||
fetchFailed: '获取租户数据失败',
|
||||
batchEnableSuccess: '批量启用成功',
|
||||
batchEnableFailed: '批量启用失败',
|
||||
batchDisableSuccess: '批量停用成功',
|
||||
batchDisableFailed: '批量停用失败',
|
||||
exportSuccess: '导出成功',
|
||||
batchDeleteSuccess: '批量删除成功',
|
||||
batchDeleteFailed: '批量删除失败',
|
||||
saveFailed: '保存失败',
|
||||
batchImport: '批量导入'
|
||||
},
|
||||
table: {
|
||||
totalRecords: '共 {{total}} 条记录'
|
||||
},
|
||||
@@ -916,18 +873,17 @@ export const zh = {
|
||||
subUsername: '或登录账号',
|
||||
usernameOrAccount: '用户名 / 登录账号',
|
||||
displayName: '显示名称',
|
||||
tenantName: '所属租户',
|
||||
role: '角色',
|
||||
password: '密码',
|
||||
initialPassword: '初始密码',
|
||||
expiryDate: '有效期',
|
||||
expiryDateDue: '有效期至',
|
||||
status: '状态',
|
||||
enabled: '已启用',
|
||||
enabledOpera: '启用',
|
||||
enabledConfirm: '确定要启用此用户吗?',
|
||||
enabledConfirmSuccess: '启用成功',
|
||||
enabledConfirmFailed: '启用失败',
|
||||
disabled: '已停用',
|
||||
disabledOpera: '停用',
|
||||
disabledConfirm: '确定要停用此用户吗?',
|
||||
disabledConfirmSuccess: '停用成功',
|
||||
disabledConfirmFailed: '停用失败',
|
||||
@@ -946,18 +902,29 @@ export const zh = {
|
||||
email: '邮箱',
|
||||
createdAt: '创建时间',
|
||||
member: '成员',
|
||||
batchImport: '批量导入',
|
||||
batchImportUser: '批量导入用户',
|
||||
downloadTemplate: '下载导入模板',
|
||||
templateDownloadSuccess: '模板下载成功',
|
||||
startImport: '开始导入',
|
||||
batchImportSuccess: '批量导入成功',
|
||||
importFailed: '导入失败,请检查文件格式',
|
||||
noFileSelected: '请选择要导入的文件',
|
||||
onlyXlsxOrCsv: '只能上传 .xlsx 或 .csv 格式的文件',
|
||||
reselect: '重新选择',
|
||||
noFileSelectedTip: '未选择任何文件',
|
||||
downloadTemplateTip: '请下载模板,填写用户信息后上传。'
|
||||
passwordRule: '密码至少需要6个字符',
|
||||
authVerify: '身份验证',
|
||||
authVerifyDesc: '出于安全考虑,请先验证您的登录密码',
|
||||
verify: '验证',
|
||||
loginPassword: '登录密码',
|
||||
loginPasswordPlaceholder: '请输入当前账号的登录密码',
|
||||
loginPasswordVerifyFailed: '密码错误,请重新输入',
|
||||
bindNewEmail: '绑定新邮箱',
|
||||
sureChange: '确认修改',
|
||||
sendEmailCode: '发送验证码',
|
||||
currentEmail: '当前邮箱',
|
||||
newEmail: '新邮箱地址',
|
||||
emailCode: '验证码',
|
||||
emailCodePlaceholder: '请输入新邮箱收到的验证码',
|
||||
sureChangeEmail: '确认将绑定邮箱修改为',
|
||||
sureChangeEmailDesc: '吗?',
|
||||
changeSuccess: '修改成功',
|
||||
sendSuccess: '验证码已发送,请查收',
|
||||
newEmailSameAsOld: '新邮箱不能与当前邮箱相同',
|
||||
emailCodeLengthRule: '请输入6位的验证码',
|
||||
emailFormatError: '邮箱格式不正确',
|
||||
sendCodeTooFrequent: '请在{{seconds}}s后重新发送',
|
||||
retrySend: '{{seconds}}s后可重发',
|
||||
},
|
||||
common: {
|
||||
search: '搜索',
|
||||
@@ -980,6 +947,7 @@ export const zh = {
|
||||
exportList: '导出列表',
|
||||
selectPlaceholder: '请选择{{title}}',
|
||||
inputPlaceholder: '请输入{{title}}',
|
||||
enterPlaceholder: '输入 {{title}}',
|
||||
saveSuccess: '保存成功',
|
||||
saveFailure: '保存失败',
|
||||
pleaseSelect: '请选择',
|
||||
@@ -1019,6 +987,7 @@ export const zh = {
|
||||
confirmChangeStatusDesc: '确定要更改【{{name}}】的状态吗?',
|
||||
operationSuccess: '操作成功',
|
||||
operateSuccess: '操作成功',
|
||||
deleted: '已删除',
|
||||
pleaseUpload: '请上传',
|
||||
returnToSpace: '返回空间',
|
||||
createSuccess: '创建成功',
|
||||
@@ -1050,24 +1019,7 @@ export const zh = {
|
||||
nextStep: '下一步',
|
||||
prevStep: '上一步',
|
||||
exportSuccess: '导出成功',
|
||||
},
|
||||
product: {
|
||||
applicationManagement: '应用管理',
|
||||
createApplication: '创建应用',
|
||||
applicationName: '应用名称',
|
||||
applicationIcon: '应用图标',
|
||||
applicationNameRequired: '请输入应用名称',
|
||||
associationStatus: '关联状态',
|
||||
associated: '已关联',
|
||||
notAssociated: '未关联',
|
||||
unassociate: '解除关联',
|
||||
unassociateSuccess: '解除关联成功',
|
||||
unassociateFailed: '解除关联失败',
|
||||
viewKey: '查看KEY',
|
||||
viewStats: '查看统计',
|
||||
disableSuccess: '停用成功',
|
||||
enableSuccess: '启用成功',
|
||||
operationFailed: '操作失败',
|
||||
recommend: '推荐',
|
||||
},
|
||||
model: {
|
||||
searchPlaceholder: '搜索模型…',
|
||||
@@ -1745,6 +1697,7 @@ export const zh = {
|
||||
name: '姓名',
|
||||
nameSubTitle: '(可选,用于团队成员识别)',
|
||||
namePlaceholder: '请输入您的姓名',
|
||||
inviteLinkInvalid: '邀请链接无效',
|
||||
|
||||
passwordStrength: '密码强度',
|
||||
noSet: '未设置',
|
||||
@@ -1768,21 +1721,6 @@ export const zh = {
|
||||
pageEmpty: '哎呀!暂无搜索结果',
|
||||
pageEmptyDesc: '红熊歪着头等待您更换新的关键词,让我们一起探索吧。',
|
||||
},
|
||||
|
||||
home: {
|
||||
title: '首页',
|
||||
welcome: '欢迎使用我们的带单页路由的 React 应用!',
|
||||
counterCard: '计数器演示',
|
||||
aboutCard: '关于我们',
|
||||
workflowCard: '工作流编辑器',
|
||||
websocketDemoCard: 'WebSocket 演示',
|
||||
sseDemoCard: 'SSE演示'
|
||||
},
|
||||
notFound: {
|
||||
title: '页面未找到',
|
||||
description: '请求的页面不存在。',
|
||||
backToHome: '返回首页'
|
||||
},
|
||||
apiKey: {
|
||||
name: '项目名称',
|
||||
createApiKey: '创建API Key',
|
||||
@@ -2242,6 +2180,14 @@ export const zh = {
|
||||
input: '输入',
|
||||
output: '输出',
|
||||
error: '错误信息',
|
||||
loopNum: '个循环',
|
||||
iterationNum: '个迭代',
|
||||
runtime: {
|
||||
loop: '循环',
|
||||
iteration: '迭代',
|
||||
input_cycle_vars: '初始循环变量',
|
||||
output_cycle_vars: '最终循环变量',
|
||||
}
|
||||
},
|
||||
emotionEngine: {
|
||||
emotionEngineConfig: '情感引擎配置',
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-02 16:33:54
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-04 18:30:10
|
||||
* @Last Modified time: 2026-02-28 17:21:20
|
||||
*/
|
||||
/**
|
||||
* User Store
|
||||
@@ -44,7 +44,7 @@ export interface UserState {
|
||||
/** Update login information */
|
||||
updateLoginInfo: (values: LoginInfo) => void;
|
||||
/** Get user information */
|
||||
getUserInfo: (flag?: boolean) => void;
|
||||
getUserInfo: (flag?: boolean, notNeedJump?: boolean) => void;
|
||||
/** Clear user information */
|
||||
clearUserInfo: () => void;
|
||||
/** Logout user */
|
||||
@@ -73,13 +73,13 @@ export const useUser = create<UserState>((set, get) => ({
|
||||
cookieUtils.set('refreshToken', values.refresh_token);
|
||||
set({ loginInfo: values });
|
||||
},
|
||||
getUserInfo: async (flag?: boolean) => {
|
||||
getUserInfo: async (flag?: boolean, notNeedJump?: boolean) => {
|
||||
if (!cookieUtils.get('authToken')) {
|
||||
return
|
||||
}
|
||||
const { checkJump } = get()
|
||||
const localUser = JSON.parse(localStorage.getItem('user') || '{}') as User;
|
||||
if (localUser.id) {
|
||||
if (localUser.id && !notNeedJump) {
|
||||
checkJump()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:29:21
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-10 18:46:40
|
||||
* @Last Modified time: 2026-02-25 18:11:49
|
||||
*/
|
||||
import { type FC, type ReactNode, useEffect, useRef, useState, forwardRef, useImperativeHandle } from 'react';
|
||||
import clsx from 'clsx'
|
||||
@@ -168,7 +168,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
setLoading(true)
|
||||
getApplicationConfig(id as string).then(res => {
|
||||
const response = res as Config
|
||||
const { skills } = response
|
||||
const { skills, variables } = response
|
||||
let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : []
|
||||
let allTools = Array.isArray(response.tools) ? response.tools : []
|
||||
const memoryContent = response.memory?.memory_config_id
|
||||
@@ -187,6 +187,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
skill_ids: allSkills
|
||||
}
|
||||
})
|
||||
updateVariableList([...variables])
|
||||
setData({
|
||||
...response,
|
||||
tools: allTools
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-04 18:34:36
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-09 15:46:07
|
||||
* @Last Modified time: 2026-02-10 15:49:29
|
||||
*/
|
||||
import { useEffect, type FC } from 'react'
|
||||
import { useNavigate, useSearchParams } from 'react-router-dom'
|
||||
|
||||
import { cookieUtils } from '@/utils/request'
|
||||
import { useI18n } from '@/store/locale'
|
||||
import { clearAuthData } from '@/utils/auth'
|
||||
|
||||
/**
|
||||
* JumpPage Component
|
||||
@@ -30,6 +31,7 @@ const JumpPage: FC = () => {
|
||||
const { changeLanguage } = useI18n()
|
||||
|
||||
useEffect(() => {
|
||||
clearAuthData()
|
||||
// Convert URLSearchParams to a plain object for easier access
|
||||
const data = Object.fromEntries(searchParams)
|
||||
const { access_token, refresh_token, target, language } = data
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:50:10
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 16:50:10
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-27 10:20:51
|
||||
*/
|
||||
/**
|
||||
* Model List View
|
||||
@@ -10,27 +10,28 @@
|
||||
* Shows model tags and allows viewing model details
|
||||
*/
|
||||
|
||||
import { useRef, useState, useEffect, type FC } from 'react';
|
||||
import { useRef, useState, useEffect, forwardRef, useImperativeHandle } from 'react';
|
||||
import { Button, Flex, Row, Col } from 'antd'
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { ProviderModelItem, KeyConfigModalRef, ModelListDetailRef } from './types'
|
||||
import type { ProviderModelItem, KeyConfigModalRef, ModelListDetailRef, ModelListItem, BaseRef } from './types'
|
||||
import RbCard from '@/components/RbCard/Card'
|
||||
import { getModelNewList } from '@/api/models'
|
||||
import PageEmpty from '@/components/Empty/PageEmpty';
|
||||
import Tag from '@/components/Tag';
|
||||
import KeyConfigModal from './components/KeyConfigModal'
|
||||
import ModelListDetail from './components/ModelListDetail'
|
||||
import { getLogoUrl } from './utils'
|
||||
import { getListLogoUrl } from './utils'
|
||||
|
||||
/**
|
||||
* Model list component
|
||||
*/
|
||||
const ModelList: FC<{ query: any }> = ({ query }) => {
|
||||
const ModelList = forwardRef<BaseRef, { query: any; handleEdit: (vo?: ModelListItem) => void; }> (({ query, handleEdit }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const keyConfigModalRef = useRef<KeyConfigModalRef>(null)
|
||||
const modelListDetailRef = useRef<ModelListDetailRef>(null)
|
||||
const [list, setList] = useState<ProviderModelItem[]>([])
|
||||
|
||||
useEffect(() => {
|
||||
getList()
|
||||
}, [query])
|
||||
@@ -54,6 +55,11 @@ const ModelList: FC<{ query: any }> = ({ query }) => {
|
||||
keyConfigModalRef.current?.handleOpen(vo)
|
||||
}
|
||||
|
||||
/** Expose methods to parent component */
|
||||
useImperativeHandle(ref, () => ({
|
||||
getList,
|
||||
modelListDetailRefresh: () => modelListDetailRef.current?.handleRefresh()
|
||||
}));
|
||||
return (
|
||||
<>
|
||||
{list.length === 0
|
||||
@@ -64,7 +70,7 @@ const ModelList: FC<{ query: any }> = ({ query }) => {
|
||||
<RbCard
|
||||
key={item.provider}
|
||||
title={t(`modelNew.${item.provider}`)}
|
||||
avatarUrl={getLogoUrl(item.logo)}
|
||||
avatarUrl={getListLogoUrl(item.provider, item.logo)}
|
||||
avatar={
|
||||
<div className="rb:w-12 rb:h-12 rb:rounded-lg rb:mr-3.25 rb:bg-[#155eef] rb:flex rb:items-center rb:justify-center rb:text-[28px] rb:text-[#ffffff]">
|
||||
{item.provider[0].toUpperCase()}
|
||||
@@ -96,9 +102,10 @@ const ModelList: FC<{ query: any }> = ({ query }) => {
|
||||
<ModelListDetail
|
||||
ref={modelListDetailRef}
|
||||
refresh={getList}
|
||||
handleEdit={handleEdit}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
export default ModelList
|
||||
@@ -26,7 +26,7 @@ import { getLogoUrl } from './utils'
|
||||
/**
|
||||
* Model square component
|
||||
*/
|
||||
const ModelSquare = forwardRef <BaseRef, { query: any; handleEdit: (vo?: ModelPlazaItem) => void; }>(({ query, handleEdit }, ref) => {
|
||||
const ModelSquare = forwardRef <BaseRef, { query: any; }>(({ query }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp()
|
||||
const modelSquareDetailRef = useRef<ModelSquareDetailRef>(null)
|
||||
@@ -96,7 +96,6 @@ const ModelSquare = forwardRef <BaseRef, { query: any; handleEdit: (vo?: ModelPl
|
||||
<Flex justify="space-between">
|
||||
<Space size={8}><UsergroupAddOutlined /> {item.add_count}</Space>
|
||||
<Space>
|
||||
{!item.is_official && <Button type="primary" disabled={item.is_deprecated} onClick={() => handleEdit(item)}>{t('modelNew.edit')}</Button>}
|
||||
{item.is_added
|
||||
? <Button type="primary" disabled>{t('modelNew.added')}</Button>
|
||||
: <Button type="primary" ghost disabled={item.is_deprecated} onClick={() => handleAdd(item)}>{item.is_deprecated ? t('modelNew.deprecated') : `+ ${t('common.add')}`}</Button>
|
||||
@@ -114,7 +113,6 @@ const ModelSquare = forwardRef <BaseRef, { query: any; handleEdit: (vo?: ModelPl
|
||||
<ModelSquareDetail
|
||||
ref={modelSquareDetailRef}
|
||||
refresh={getList}
|
||||
handleEdit={handleEdit}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:49:28
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 16:49:28
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-28 17:24:05
|
||||
*/
|
||||
/**
|
||||
* Custom Model Modal
|
||||
@@ -11,10 +11,10 @@
|
||||
*/
|
||||
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, App, Select } from 'antd';
|
||||
import { Form, Input, App } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { CustomModelForm, ModelPlazaItem, CustomModelModalRef, CustomModelModalProps } from '../types';
|
||||
import type { CustomModelForm, ModelListItem, CustomModelModalRef, CustomModelModalProps } from '../types';
|
||||
import RbModal from '@/components/RbModal'
|
||||
import CustomSelect from '@/components/CustomSelect'
|
||||
import UploadImages from '@/components/Upload/UploadImages'
|
||||
@@ -30,28 +30,27 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [model, setModel] = useState<ModelPlazaItem>({} as ModelPlazaItem);
|
||||
const [model, setModel] = useState<ModelListItem>({} as ModelListItem);
|
||||
const [isEdit, setIsEdit] = useState(false);
|
||||
const [form] = Form.useForm<CustomModelForm>();
|
||||
const [loading, setLoading] = useState(false)
|
||||
const formValues = Form.useWatch([], form)
|
||||
|
||||
/** Close modal and reset state */
|
||||
const handleClose = () => {
|
||||
setModel({} as ModelPlazaItem);
|
||||
setModel({} as ModelListItem);
|
||||
form.resetFields();
|
||||
setLoading(false)
|
||||
setVisible(false);
|
||||
};
|
||||
|
||||
/** Open modal with optional model data for editing */
|
||||
const handleOpen = (model?: ModelPlazaItem) => {
|
||||
const handleOpen = (model?: ModelListItem) => {
|
||||
if (model) {
|
||||
setIsEdit(true);
|
||||
setModel(model);
|
||||
form.setFieldsValue({
|
||||
...model,
|
||||
logo: model.logo ? { url: model.logo, uid: model.logo, status: 'done', name: 'logo' } : undefined
|
||||
logo: model.logo && model.logo.startsWith('http') ? { url: model.logo, uid: model.logo, status: 'done', name: 'logo' } : undefined
|
||||
});
|
||||
} else {
|
||||
setIsEdit(false);
|
||||
@@ -66,7 +65,7 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
const res = isEdit ? updateCustomModel(model.id, rest) : addCustomModel(data)
|
||||
|
||||
res.then(() => {
|
||||
refresh && refresh()
|
||||
refresh && refresh(isEdit)
|
||||
handleClose()
|
||||
message.success(isEdit ? t('common.updateSuccess') : t('common.createSuccess'))
|
||||
})
|
||||
@@ -79,12 +78,10 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
form
|
||||
.validateFields()
|
||||
.then((values) => {
|
||||
setLoading(true)
|
||||
const { logo, ...rest } = values;
|
||||
let formData: CustomModelForm = {
|
||||
...rest
|
||||
}
|
||||
formData.is_official = false;
|
||||
|
||||
if (typeof logo === 'object' && logo?.response?.data.file_id) {
|
||||
getFileLink(logo?.response?.data.file_id)
|
||||
@@ -111,8 +108,6 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
handleOpen,
|
||||
}));
|
||||
|
||||
console.log('formValues', formValues)
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={isEdit ? `${model.name} - ${t('modelNew.modelConfiguration')}` : t('modelNew.createCustomModel')}
|
||||
@@ -174,11 +169,22 @@ const CustomModelModal = forwardRef<CustomModelModalRef, CustomModelModalProps>(
|
||||
>
|
||||
<Input.TextArea placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
|
||||
|
||||
<Form.Item
|
||||
name="tags"
|
||||
label={t('modelNew.tags')}
|
||||
name={["api_keys", 0, "api_key"]}
|
||||
label={t('modelNew.api_key')}
|
||||
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.api_key') }) }]}
|
||||
>
|
||||
<Select mode="tags" placeholder={t('common.pleaseEnter')} />
|
||||
<Input.Password placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name={["api_keys", 0, "api_base"]}
|
||||
label={t('modelNew.api_base')}
|
||||
rules={[{ required: true, message: t('common.inputPlaceholder', { title: t('modelNew.api_base') }) }]}
|
||||
>
|
||||
<Input placeholder="https://api.example.com/v1" />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</RbModal>
|
||||
|
||||
@@ -30,12 +30,13 @@ import CustomSelect from '@/components/CustomSelect'
|
||||
interface ModelListDetailProps {
|
||||
/** Callback to refresh parent list */
|
||||
refresh?: () => void;
|
||||
handleEdit: (vo?: ModelListItem) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Model list detail drawer component
|
||||
*/
|
||||
const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({ refresh }, ref) => {
|
||||
const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({ refresh, handleEdit }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [data, setData] = useState<ProviderModelItem>({} as ProviderModelItem)
|
||||
@@ -95,7 +96,8 @@ const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({
|
||||
|
||||
/** Expose methods to parent component */
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
handleOpen,
|
||||
handleRefresh,
|
||||
}));
|
||||
|
||||
/** Filter models by selected type */
|
||||
@@ -149,7 +151,10 @@ const ModelListDetail = forwardRef<ModelListDetailRef, ModelListDetailProps>(({
|
||||
</Tooltip>
|
||||
<div className="rb:absolute rb:bottom-4 rb:left-6 rb:right-6">
|
||||
<Row gutter={12}>
|
||||
<Col span={24}>
|
||||
<Col span={12}>
|
||||
<Button block onClick={() => handleEdit(item)}>{t('modelNew.modelConfiguration')}</Button>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Button type="primary" ghost block onClick={() => handleKeyConfig(item)}>{t('modelNew.keyConfig')}</Button>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
@@ -29,14 +29,12 @@ import { getLogoUrl } from '../utils'
|
||||
interface ModelSquareDetailProps {
|
||||
/** Callback to refresh parent list */
|
||||
refresh: () => void;
|
||||
/** Callback to edit model */
|
||||
handleEdit: (vo: ModelPlazaItem) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Model square detail drawer component
|
||||
*/
|
||||
const ModelSquareDetail = forwardRef<ModelSquareDetailRef, ModelSquareDetailProps>(({ refresh, handleEdit }, ref) => {
|
||||
const ModelSquareDetail = forwardRef<ModelSquareDetailRef, ModelSquareDetailProps>(({ refresh }, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp()
|
||||
const [model, setModel] = useState<ModelPlaza>({} as ModelPlaza)
|
||||
@@ -112,7 +110,6 @@ const ModelSquareDetail = forwardRef<ModelSquareDetailRef, ModelSquareDetailProp
|
||||
<Flex justify="space-between">
|
||||
<Space size={8}><UsergroupAddOutlined /> {item.add_count}</Space>
|
||||
<Space>
|
||||
{!item.is_official && <Button type="primary" disabled={item.is_deprecated} onClick={() => handleEdit(item)}>{t('modelNew.edit')}</Button>}
|
||||
{item.is_added
|
||||
? <Button type="primary" disabled>{t('modelNew.added')}</Button>
|
||||
: <Button type="primary" ghost disabled={item.is_deprecated} onClick={() => handleAdd(item)}>{item.is_deprecated ? t('modelNew.deprecated') : `+ ${t('common.add')}`}</Button>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user