diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index 29656608..6c0440b8 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -361,7 +361,8 @@ async def draft_run( workspace_id=workspace_id, user=current_user ) - if storage_type is None: storage_type = 'neo4j' + if storage_type is None: + storage_type = 'neo4j' user_rag_memory_id = '' if workspace_id: @@ -370,7 +371,8 @@ async def draft_run( name="USER_RAG_MERORY", workspace_id=workspace_id ) - if knowledge: user_rag_memory_id = str(knowledge.id) + if knowledge: + user_rag_memory_id = str(knowledge.id) # 提前验证和准备(在流式响应开始前完成) diff --git a/api/app/controllers/order_controller.py b/api/app/controllers/order_controller.py new file mode 100644 index 00000000..d9534a8e --- /dev/null +++ b/api/app/controllers/order_controller.py @@ -0,0 +1,97 @@ +from fastapi import APIRouter, Depends, status +from sqlalchemy.orm import Session +import os + +from app.db import get_db +from app.dependencies import get_current_user +from app.models.user_model import User +from app.schemas.order_schema import CreateOrderRequest +from app.schemas.response_schema import ApiResponse +from app.services.order_service import get_order_service +from app.core.logging_config import get_api_logger +from app.core.response_utils import success, error + +# Get API logger +api_logger = get_api_logger() + +router = APIRouter( + prefix="/order", + tags=["Order"], +) + + +@router.post("", response_model=ApiResponse) +async def create_order( + order_data: CreateOrderRequest, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + + try: + api_logger.info(f"User {current_user.id} creating order for product {order_data.product_id}") + + # Get external API configuration from environment + external_api_url = os.getenv("EXTERNAL_ORDER_API_URL") + api_key = os.getenv("EXTERNAL_ORDER_API_KEY") + + # Get order service instance + order_service = get_order_service( + external_api_url=external_api_url, + api_key=api_key + ) + + # Forward request to external API + result = await order_service.create_order( + order_data=order_data, + user_id=str(current_user.id) + ) + + api_logger.info(f"Order created successfully: {result.get('order_id')}") + + return success(data=result, msg="Order created successfully") + + except Exception as e: + api_logger.error(f"Failed to create order: {str(e)}", exc_info=True) + return error(msg=str(e), code=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@router.get("/{order_id}", response_model=ApiResponse) +async def get_order( + order_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """Get order details from external API + + Args: + order_id: Order ID + db: Database session + current_user: Current authenticated user + + Returns: + API response with order details + """ + try: + api_logger.info(f"User {current_user.id} fetching order {order_id}") + + # Get external API configuration + external_api_url = os.getenv("EXTERNAL_ORDER_API_URL") + api_key = os.getenv("EXTERNAL_ORDER_API_KEY") + + # Get order service instance + order_service = get_order_service( + external_api_url=external_api_url, + api_key=api_key + ) + + # Fetch order from external API + result = await order_service.get_order(order_id) + + api_logger.info(f"Order {order_id} fetched successfully") + + return success(data=result, msg="Order fetched successfully") + + except Exception as e: + api_logger.error(f"Failed to fetch order {order_id}: {str(e)}", exc_info=True) + return error(msg=str(e), code=status.HTTP_500_INTERNAL_SERVER_ERROR) + diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 1731405c..d621caf9 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -2,14 +2,30 @@ import uuid from fastapi import APIRouter, Depends, Request, Body from sqlalchemy.orm import Session +from typing import Optional, Annotated +from starlette.responses import StreamingResponse + +from app.core.api_key_auth import require_api_key from app.db import get_db from app.core.response_utils import success from app.core.logging_config import get_business_logger -from app.core.api_key_auth import require_api_key -from app.schemas.api_key_schema import ApiKeyAuth +from app.dependencies import get_app_or_workspace +from app.repositories import knowledge_repository +from app.schemas import AppChatRequest, conversation_schema +from app.models.app_model import App +from app.models.app_model import AppType +from app.repositories.end_user_repository import EndUserRepository +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.services import workspace_service +from app.services.app_chat_service import AppChatService, get_app_chat_service +from app.services.app_service import AppService +from app.services.conversation_service import ConversationService, get_conversation_service +from app.services.workflow_service import WorkflowService, get_workflow_service +from app.utils.app_config_utils import dict_to_multi_agent_config,dict_to_agent_config,dict_to_workflow_config -router = APIRouter(prefix="/apps", tags=["V1 - App API"]) +router = APIRouter(prefix="/app", tags=["V1 - App API"]) logger = get_business_logger() @@ -19,28 +35,232 @@ async def list_apps(): return success(data=[], msg="App API - Coming Soon") # /v1/apps/{resource_id}/chat -@router.post("/{resource_id}/chat") -@require_api_key(scopes=["app"]) -async def chat_with_agent_demo( - resource_id: uuid.UUID, - request: Request, - api_key_auth: ApiKeyAuth = None, + + +# async def chat( +# request: Request, +# api_key_auth: ApiKeyAuth = None, +# db: Session = Depends(get_db), +# message: str = Body(..., description="聊天消息内容"), +# ): +# """ +# Agent 聊天接口demo + +# scopes: 所需的权限范围列表["app", "rag", "memory"] + +# Args: +# resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id +# message: 请求参数 +# request: 声明请求 +# api_key_auth: 包含验证后的API Key 信息 +# db: db_session +# """ +# logger.info(f"API Key Auth: {api_key_auth}") +# logger.info(f"Resource ID: {resource_id}") +# logger.info(f"Message: {message}") +# return success(data={"received": True}, msg="消息已接收") + + +def _checkAppConfig(app: App): + if app.type == AppType.AGENT: + if not app.current_release.config: + raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING) + elif app.type == AppType.MULTI_AGENT: + if not app.current_release.config: + raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING) + elif app.type == AppType.WORKFLOW: + if not app.current_release.config: + raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING) + else: + raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING) + +@router.post("/chat") +# @require_api_key(scopes=["app"]) +async def chat( + payload: AppChatRequest, + app: App = Depends(get_app_or_workspace), db: Session = Depends(get_db), - message: str = Body(..., description="聊天消息内容"), + conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None, + app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None, + ): - """ - Agent 聊天接口demo + other_id = payload.user_id + workspace_id = app.workspace_id + end_user_repo = EndUserRepository(db) + new_end_user = end_user_repo.get_or_create_end_user( + app_id=app.id, + other_id=other_id, + original_user_id=other_id # Save original user_id to other_id + ) + end_user_id = str(new_end_user.id) - scopes: 所需的权限范围列表["app", "rag", "memory"] + # 提前验证和准备(在流式响应开始前完成) + storage_type = workspace_service.get_workspace_storage_type_without_auth( + db=db, + workspace_id=workspace_id + ) + if storage_type is None: + storage_type = 'neo4j' + user_rag_memory_id = '' + if storage_type == 'rag': + if workspace_id: + knowledge = knowledge_repository.get_knowledge_by_name( + db=db, + name="USER_RAG_MERORY", + workspace_id=workspace_id + ) + if knowledge: + user_rag_memory_id = str(knowledge.id) + else: + logger.warning( + f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") + storage_type = 'neo4j' + else: + logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") + storage_type = 'neo4j' + app_type = app.type + # check app config + _checkAppConfig(app) - Args: - resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id - message: 请求参数 - request: 声明请求 - api_key_auth: 包含验证后的API Key 信息 - db: db_session - """ - logger.info(f"API Key Auth: {api_key_auth}") - logger.info(f"Resource ID: {resource_id}") - logger.info(f"Message: {message}") - return success(data={"received": True}, msg="消息已接收") + # 获取或创建会话(提前验证) + conversation = conversation_service.create_or_get_conversation( + app_id=app.id, + workspace_id=workspace_id, + user_id=end_user_id, + is_draft=False + ) + + if app_type == AppType.AGENT: + agent_config = dict_to_agent_config(app.current_release.config) + # 流式返回 + if payload.stream: + async def event_generator(): + async for event in app_chat_service.agnet_chat_stream( + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id= end_user_id, # 转换为字符串 + variables=payload.variables, + web_search=payload.web_search, + config=app.current_release.config, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id + ): + yield event + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) + + # 非流式返回 + result = await app_chat_service.agnet_chat( + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=end_user_id, # 转换为字符串 + variables=payload.variables, + config= agent_config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id + ) + return success(data=conversation_schema.ChatResponse(**result)) + elif app_type == AppType.MULTI_AGENT: + # 多 Agent 流式返回 + config = dict_to_multi_agent_config(app.current_release.config) + if payload.stream: + async def event_generator(): + async for event in app_chat_service.multi_agent_chat_stream( + + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=end_user_id, # 转换为字符串 + variables=payload.variables, + config=config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id + ): + yield event + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) + + # 多 Agent 非流式返回 + result = await app_chat_service.multi_agent_chat( + + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=end_user_id, # 转换为字符串 + variables=payload.variables, + config=config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id + ) + + return success(data=conversation_schema.ChatResponse(**result)) + elif app_type == AppType.WORKFLOW: + # 多 Agent 流式返回 + config = dict_to_workflow_config(app.current_release.config) + if payload.stream: + async def event_generator(): + async for event in app_chat_service.workflow_chat_stream( + + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=end_user_id, # 转换为字符串 + variables=payload.variables, + config=config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id + ): + yield event + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no" + } + ) + + # 非流式返回 + result = await app_chat_service.workflow_chat( + + message=payload.message, + conversation_id=conversation.id, # 使用已创建的会话 ID + user_id=end_user_id, # 转换为字符串 + variables=payload.variables, + config=config, + web_search=payload.web_search, + memory=payload.memory, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id + ) + + return success(data=conversation_schema.ChatResponse(**result)) + else: + from app.core.exceptions import BusinessException + from app.core.error_codes import BizCode + raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) + pass diff --git a/api/app/controllers/user_controller.py b/api/app/controllers/user_controller.py index b4d1c123..57495a7c 100644 --- a/api/app/controllers/user_controller.py +++ b/api/app/controllers/user_controller.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Depends from sqlalchemy.orm import Session import uuid diff --git a/api/app/dependencies.py b/api/app/dependencies.py index 9e0cd88c..10684788 100644 --- a/api/app/dependencies.py +++ b/api/app/dependencies.py @@ -1,12 +1,13 @@ import uuid from functools import wraps -from fastapi import Depends, HTTPException, status +from fastapi import Depends, HTTPException, status, Request from fastapi.security import OAuth2PasswordBearer from sqlalchemy.orm import Session from jose import jwt, JWTError from app.db import get_db, SessionLocal +from app.models import App from app.schemas import token_schema from app.core.config import settings from app.core.security import get_token_id @@ -27,6 +28,51 @@ security_logger = get_security_logger() oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +class APIKeyExtractor: + """ + Custom dependency to extract API Key from request headers + + Supports two formats: + 1. Authorization: Bearer + 2. X-API-Key: + """ + + async def __call__(self, request: Request) -> str: + """Extract API Key from request headers + + Args: + request: FastAPI Request object + + Returns: + API Key string + + Raises: + HTTPException: If API Key is not found + """ + # Try Authorization header first + auth_header = request.headers.get("Authorization") + if auth_header and " " in auth_header: + auth_scheme, auth_token = auth_header.split(" ", 1) + if auth_scheme.lower() == "bearer": + return auth_token + + # Try X-API-Key header + api_key = request.headers.get("X-API-Key") + if api_key: + return api_key + + # No API Key found + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API Key not found in request headers", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +api_key_extractor = APIKeyExtractor() + + + async def get_current_user( token: str = Depends(oauth2_scheme), db: Session = Depends(get_db) @@ -304,7 +350,7 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False): - db: Session = Depends(get_db) - user 或 current_user: User = Depends(get_current_user) - workspace_id: uuid.UUID (query/path 参数)或 payload: AppCreate(body,含 workspace_id) - + 支持同步和异步函数。 """ import asyncio @@ -360,7 +406,7 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False): def get_uow() -> IUnitOfWork: """ 获取工作单元实例 - + Returns: IUnitOfWork: 工作单元实例 """ @@ -373,7 +419,7 @@ def cur_workspace_access_guard(): 要求端点函数签名包含: - db: Session = Depends(get_db) - current_user: User = Depends(get_current_user) - + 支持同步和异步函数。 """ import asyncio @@ -423,10 +469,10 @@ async def get_share_user_id( ) -> ShareTokenData: """ 从分享访问 token 中获取用户 ID 和 share_token - + 这个函数用于公开分享的接口,验证访问 token 并返回用户信息 不需要验证用户是否存在或激活,只需要验证 token 的有效性和 share_token 是否有效 - + Returns: ShareTokenData: 包含 user_id 和 share_token """ @@ -469,4 +515,75 @@ async def get_share_user_id( raise credentials_exception +async def get_app_or_workspace( + api_key: str = Depends(api_key_extractor), + db: Session = Depends(get_db) +) -> App | Workspace: + """ + Get App or Workspace from API Key + + Supports two API Key formats: + 1. Authorization: Bearer + 2. X-API-Key: + + Args: + api_key: API Key extracted from request headers + db: Database session + + Returns: + App or Workspace object based on API Key + + Raises: + HTTPException: If API Key is invalid or not found + """ + from app.services.api_key_service import ApiKeyAuthService + from app.repositories.app_repository import get_apps_by_id + from app.repositories.workspace_repository import get_workspace_by_id + + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate API Key", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + auth_logger.debug(f"Validating API Key: {api_key[:10]}...") + + # Validate API Key + api_key_obj = ApiKeyAuthService.validate_api_key(db, api_key) + if not api_key_obj: + auth_logger.warning(f"Invalid or expired API Key: {api_key[:10]}...") + raise credentials_exception + + auth_logger.debug(f"API Key validated successfully, type: {api_key_obj.type}") + + # Return App or Workspace based on API Key type + if (api_key_obj.type == "agent" or api_key.type == "multi_agent") and api_key_obj.resource_id: + # App API Key + app = get_apps_by_id(db, api_key_obj.resource_id) + if not app: + auth_logger.warning(f"App not found for API Key: {api_key_obj.resource_id}") + raise credentials_exception + auth_logger.info(f"App access granted: {app.id}") + return app + + elif api_key_obj.type == "service": + # Workspace API Key + workspace = get_workspace_by_id(db, api_key_obj.workspace_id) + if not workspace: + auth_logger.warning(f"Workspace not found for API Key: {api_key_obj.workspace_id}") + raise credentials_exception + auth_logger.info(f"Workspace access granted: {workspace.id}") + return workspace + + else: + auth_logger.warning(f"Unsupported API Key type: {api_key_obj.type}") + raise credentials_exception + + except HTTPException: + raise + except Exception as e: + auth_logger.error(f"Error validating API Key: {str(e)}", exc_info=True) + raise credentials_exception + diff --git a/api/app/main.py b/api/app/main.py index d5efeb35..87bfecf8 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -2,38 +2,10 @@ import os import subprocess from contextlib import asynccontextmanager -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, APIRouter +from fastapi import HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse -from app.core.response_utils import fail -from app.core.logging_config import LoggingConfig, get_logger -from app.core.exceptions import BusinessException -from app.core.error_codes import BizCode, HTTP_MAPPING -from app.controllers import ( - model_controller, - task_controller, - test_controller, - user_controller, - auth_controller, - workspace_controller, - setup_controller, - file_controller, - document_controller, - knowledge_controller, - chunk_controller, - knowledgeshare_controller, - app_controller, - upload_controller, - memory_agent_controller, - memory_storage_controller, - memory_dashboard_controller, - multi_agent_controller, -) - -from fastapi import FastAPI, APIRouter - -app = FastAPI(title="Data Config API", version="1.0.0") -router = APIRouter(prefix="/memory", tags=["Memory"]) # 管理端 API (JWT 认证) from app.controllers import manager_router diff --git a/api/app/schemas/__init__.py b/api/app/schemas/__init__.py index 208adc68..5eb36dd6 100644 --- a/api/app/schemas/__init__.py +++ b/api/app/schemas/__init__.py @@ -8,7 +8,9 @@ from .file_schema import File, FileCreate, FileUpdate from .tenant_schema import Tenant, TenantCreate, TenantUpdate from .chunk_schema import ChunkCreate, ChunkUpdate, ChunkRetrieve from .knowledgeshare_schema import KnowledgeShare, KnowledgeShareCreate +from .order_schema import CreateOrderRequest, OrderResponse, ExternalOrderResponse from .app_schema import ( + AppChatRequest, DraftRunRequest, DraftRunResponse, DraftRunStreamChunk, @@ -73,6 +75,10 @@ __all__ = [ "ChunkRetrieve", "KnowledgeShare", "KnowledgeShareCreate", + "CreateOrderRequest", + "OrderResponse", + "ExternalOrderResponse", + "AppChatRequest", "DraftRunRequest", "DraftRunResponse", "DraftRunStreamChunk", diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 52c5ae81..de0a4c53 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -334,6 +334,13 @@ class AppShare(BaseModel): # ---------- Draft Run Schemas ---------- +class AppChatRequest(BaseModel): + message: str = Field(..., description="用户消息") + conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)") + user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)") + variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") + stream: bool = Field(default=False, description="是否流式返回") + class DraftRunRequest(BaseModel): """试运行请求""" message: str = Field(..., description="用户消息") diff --git a/api/app/schemas/order_schema.py b/api/app/schemas/order_schema.py new file mode 100644 index 00000000..77653fe6 --- /dev/null +++ b/api/app/schemas/order_schema.py @@ -0,0 +1,63 @@ +""" +Order Schema + +Defines request and response models for order operations. +""" + +from pydantic import BaseModel, Field +from typing import Any, Optional + + +class CreateOrderRequest(BaseModel): + """Create order request model""" + + product_id: str = Field(..., description="Product ID") + quantity: int = Field(..., gt=0, description="Order quantity") + customer_name: Optional[str] = Field(None, description="Customer name") + customer_email: Optional[str] = Field(None, description="Customer email") + notes: Optional[str] = Field(None, description="Order notes") + + class Config: + json_schema_extra = { + "example": { + "product_id": "PROD-001", + "quantity": 2, + "customer_name": "John Doe", + "customer_email": "john@example.com", + "notes": "Please deliver before 5pm" + } + } + + +class OrderResponse(BaseModel): + """Order response model""" + + order_id: str = Field(..., description="Order ID") + status: str = Field(..., description="Order status") + product_id: str = Field(..., description="Product ID") + quantity: int = Field(..., description="Order quantity") + total_amount: Optional[float] = Field(None, description="Total amount") + created_at: Optional[str] = Field(None, description="Creation timestamp") + message: Optional[str] = Field(None, description="Response message") + + class Config: + json_schema_extra = { + "example": { + "order_id": "ORD-20231224-001", + "status": "pending", + "product_id": "PROD-001", + "quantity": 2, + "total_amount": 199.99, + "created_at": "2023-12-24T10:30:00Z", + "message": "Order created successfully" + } + } + + +class ExternalOrderResponse(BaseModel): + """External API response model (flexible structure)""" + + success: bool = Field(default=True, description="Request success status") + data: Optional[Any] = Field(None, description="Response data") + error: Optional[str] = Field(None, description="Error message") + code: Optional[int] = Field(None, description="Response code") diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py new file mode 100644 index 00000000..4f3ecf5f --- /dev/null +++ b/api/app/services/app_chat_service.py @@ -0,0 +1,485 @@ +"""基于分享链接的聊天服务""" +import asyncio +import json +import time +import uuid +from typing import Optional, Dict, Any, AsyncGenerator, Annotated + +from fastapi import Depends +from sqlalchemy.orm import Session + +from app.core.agent.langchain_agent import LangChainAgent +from app.core.logging_config import get_business_logger +from app.db import get_db +from app.models import MultiAgentConfig, AgentConfig +from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole +from app.services.conversation_service import ConversationService +from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool +from app.services.draft_run_service import create_web_search_tool +from app.services.model_service import ModelApiKeyService +from app.services.multi_agent_orchestrator import MultiAgentOrchestrator + +logger = get_business_logger() + + +class AppChatService: + """基于分享链接的聊天服务""" + + def __init__(self, db: Session): + self.db = db + self.conversation_service = ConversationService(db) + + async def agnet_chat( + self, + message: str, + conversation_id: uuid.UUID, + config: AgentConfig, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + web_search: bool = False, + memory: bool = True, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None, + ) -> Dict[str, Any]: + """聊天(非流式)""" + + start_time = time.time() + config_id = None + + if variables is None: + variables = {} + + # 获取模型配置ID + model_config_id = config.default_model_config_id + api_key_obj = ModelApiKeyService.get_a_api_key(model_config_id) + # 处理系统提示词(支持变量替换) + system_prompt = config.get("system_prompt", "") + if variables: + system_prompt_rendered = render_prompt_message( + system_prompt, + PromptMessageRole.USER, + variables + ) + system_prompt = system_prompt_rendered.get_text_content() or system_prompt + + # 准备工具列表 + tools = [] + + # 添加知识库检索工具 + knowledge_retrieval = config.get("knowledge_retrieval") + if knowledge_retrieval: + knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) + kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] + if kb_ids: + kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) + tools.append(kb_tool) + + # 添加长期记忆工具 + memory_flag = False + if memory == True: + memory_config = config.get("memory", {}) + if memory_config.get("enabled") and user_id: + memory_flag = True + memory_tool = create_long_term_memory_tool(memory_config, user_id) + tools.append(memory_tool) + + web_tools = config.get("tools") + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search == True: + if web_search_enable == True: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + + # 获取模型参数 + model_parameters = config.get("model_parameters", {}) + + # 创建 LangChain Agent + agent = LangChainAgent( + model_name=api_key_obj.model_name, + api_key=api_key_obj.api_key, + provider=api_key_obj.provider, + api_base=api_key_obj.api_base, + temperature=model_parameters.get("temperature", 0.7), + max_tokens=model_parameters.get("max_tokens", 2000), + system_prompt=system_prompt, + tools=tools, + + ) + + # 加载历史消息 + history = [] + memory_config = {"enabled": True, 'max_history': 10} + if memory_config.get("enabled"): + messages = self.conversation_service.get_messages( + conversation_id=conversation_id, + limit=memory_config.get("max_history", 10) + ) + history = [ + {"role": msg.role, "content": msg.content} + for msg in messages + ] + + # 调用 Agent + result = await agent.chat( + message=message, + history=history, + context=None, + end_user_id=user_id, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + config_id=config_id, + memory_flag=memory_flag + ) + + # 保存消息 + self.conversation_service.save_conversation_messages( + conversation_id=conversation_id, + user_message=message, + assistant_message=result["content"] + ) + + elapsed_time = time.time() - start_time + + return { + "conversation_id": conversation_id, + "message": result["content"], + "usage": result.get("usage", { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }), + "elapsed_time": elapsed_time + } + + async def agnet_chat_stream( + self, + message: str, + conversation_id: uuid.UUID, + config: AgentConfig, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + web_search: bool = False, + memory: bool = True, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None, + ) -> AsyncGenerator[str, None]: + """聊天(流式)""" + + try: + start_time = time.time() + config_id = None + + if variables is None: + variables = {} + + # 获取模型配置ID + model_config_id = config.default_model_config_id + api_key_obj = ModelApiKeyService.get_a_api_key(model_config_id) + # 处理系统提示词(支持变量替换) + system_prompt = config.get("system_prompt", "") + if variables: + system_prompt_rendered = render_prompt_message( + system_prompt, + PromptMessageRole.USER, + variables + ) + system_prompt = system_prompt_rendered.get_text_content() or system_prompt + + # 准备工具列表 + tools = [] + + # 添加知识库检索工具 + knowledge_retrieval = config.get("knowledge_retrieval") + if knowledge_retrieval: + knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) + kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] + if kb_ids: + kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) + tools.append(kb_tool) + + # 添加长期记忆工具 + memory_flag = False + if memory: + memory_config = config.get("memory", {}) + if memory_config.get("enabled") and user_id: + memory_flag = True + memory_tool = create_long_term_memory_tool(memory_config, user_id) + tools.append(memory_tool) + + web_tools = config.get("tools") + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search == True: + if web_search_enable == True: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + + # 获取模型参数 + model_parameters = config.get("model_parameters", {}) + + # 创建 LangChain Agent + agent = LangChainAgent( + model_name=api_key_obj.model_name, + api_key=api_key_obj.api_key, + provider=api_key_obj.provider, + api_base=api_key_obj.api_base, + temperature=model_parameters.get("temperature", 0.7), + max_tokens=model_parameters.get("max_tokens", 2000), + system_prompt=system_prompt, + tools=tools, + streaming=True + ) + + # 加载历史消息 + history = [] + memory_config = {"enabled": True, 'max_history': 10} + if memory_config.get("enabled"): + messages = self.conversation_service.get_messages( + conversation_id=conversation_id, + limit=memory_config.get("max_history", 10) + ) + history = [ + {"role": msg.role, "content": msg.content} + for msg in messages + ] + + # 发送开始事件 + yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" + + # 流式调用 Agent + full_content = "" + async for chunk in agent.chat_stream( + message=message, + history=history, + context=None, + end_user_id=user_id, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + config_id=config_id, + memory_flag=memory_flag + ): + full_content += chunk + # 发送消息块事件 + yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" + + elapsed_time = time.time() - start_time + + # 保存消息 + self.conversation_service.add_message( + conversation_id=conversation_id, + role="user", + content=message + ) + + self.conversation_service.add_message( + conversation_id=conversation_id, + role="assistant", + content=full_content, + meta_data={ + "model": api_key_obj.model_name, + "usage": {} + } + ) + + # 发送结束事件 + end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} + yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" + + logger.info( + "流式聊天完成", + extra={ + "conversation_id": str(conversation_id), + "elapsed_time": elapsed_time, + "message_length": len(full_content) + } + ) + + except (GeneratorExit, asyncio.CancelledError): + # 生成器被关闭或任务被取消,正常退出 + logger.debug("流式聊天被中断") + raise + except Exception as e: + logger.error(f"流式聊天失败: {str(e)}", exc_info=True) + # 发送错误事件 + yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + + async def multi_agent_chat( + self, + message: str, + conversation_id: uuid.UUID, + config: MultiAgentConfig, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + web_search: bool = False, + memory: bool = True, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None, + ) -> Dict[str, Any]: + """多 Agent 聊天(非流式)""" + + start_time = time.time() + actual_config_id = None + config_id = actual_config_id + + if variables is None: + variables = {} + + # 2. 创建编排器 + orchestrator = MultiAgentOrchestrator(self.db, config) + + # 3. 执行任务 + result = await orchestrator.execute( + message=message, + conversation_id=conversation_id, + user_id=user_id, + variables=variables, + use_llm_routing=True, # 默认启用 LLM 路由 + web_search=web_search, # 网络搜索参数 + memory=memory # 记忆功能参数 + ) + + elapsed_time = time.time() - start_time + + # 保存消息 + self.conversation_service.add_message( + conversation_id=conversation_id, + role="user", + content=message + ) + + self.conversation_service.add_message( + conversation_id=conversation_id, + role="assistant", + content=result.get("message", ""), + meta_data={ + "mode": result.get("mode"), + "elapsed_time": result.get("elapsed_time"), + "sub_results": result.get("sub_results") + } + ) + + return { + "conversation_id": conversation_id, + "message": result.get("message", ""), + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }, + "elapsed_time": elapsed_time + } + + async def multi_agent_chat_stream( + self, + message: str, + conversation_id: uuid.UUID, + config: MultiAgentConfig, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + web_search: bool = False, + memory: bool = True, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None, + ) -> AsyncGenerator[str, None]: + """多 Agent 聊天(流式)""" + + start_time = time.time() + actual_config_id = None + config_id = actual_config_id + + if variables is None: + variables = {} + + try: + + # 发送开始事件 + yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" + + full_content = "" + + # 2. 创建编排器 + orchestrator = MultiAgentOrchestrator(self.db, config) + + # 3. 流式执行任务 + async for event in orchestrator.execute_stream( + message=message, + conversation_id=conversation_id, + user_id=user_id, + variables=variables, + use_llm_routing=True, + web_search=web_search, # 网络搜索参数 + memory=memory, # 记忆功能参数 + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id + ): + yield event + # 尝试提取内容(用于保存) + if "data:" in event: + try: + data_line = event.split("data: ", 1)[1].strip() + data = json.loads(data_line) + if "content" in data: + full_content += data["content"] + except: + pass + + elapsed_time = time.time() - start_time + + # 保存消息 + self.conversation_service.add_message( + conversation_id=conversation_id, + role="user", + content=message + ) + + self.conversation_service.add_message( + conversation_id=conversation_id, + role="assistant", + content=full_content, + meta_data={ + "elapsed_time": elapsed_time + } + ) + + logger.info( + "多 Agent 流式聊天完成", + extra={ + "conversation_id": str(conversation_id), + "elapsed_time": elapsed_time, + "message_length": len(full_content) + } + ) + + + except (GeneratorExit, asyncio.CancelledError): + # 生成器被关闭或任务被取消,正常退出 + logger.debug("多 Agent 流式聊天被中断") + raise + except Exception as e: + logger.error(f"多 Agent 流式聊天失败: {str(e)}", exc_info=True) + # 发送错误事件 + yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + + +# ==================== 依赖注入函数 ==================== + +def get_app_chat_service( + db: Annotated[Session, Depends(get_db)] +) -> ChatService: + """获取工作流服务(依赖注入)""" + return ChatService(db) diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 63826726..f618ea17 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -1,9 +1,12 @@ """会话服务""" import uuid -from typing import Optional, List, Tuple +from typing import Optional, List, Tuple, Annotated + +from fastapi import Depends from sqlalchemy.orm import Session from sqlalchemy import select, desc +from app.db import get_db from app.models import Conversation, Message from app.core.exceptions import ResourceNotFoundException, BusinessException from app.core.error_codes import BizCode @@ -14,10 +17,10 @@ logger = get_business_logger() class ConversationService: """会话服务""" - + def __init__(self, db: Session): self.db = db - + def create_conversation( self, app_id: uuid.UUID, @@ -36,11 +39,11 @@ class ConversationService: is_draft=is_draft, config_snapshot=config_snapshot ) - + self.db.add(conversation) self.db.commit() self.db.refresh(conversation) - + logger.info( "创建会话成功", extra={ @@ -50,9 +53,9 @@ class ConversationService: "is_draft": is_draft } ) - + return conversation - + def get_conversation( self, conversation_id: uuid.UUID, @@ -60,17 +63,17 @@ class ConversationService: ) -> Conversation: """获取会话""" stmt = select(Conversation).where(Conversation.id == conversation_id) - + if workspace_id: stmt = stmt.where(Conversation.workspace_id == workspace_id) - + conversation = self.db.scalars(stmt).first() - + if not conversation: raise ResourceNotFoundException("会话", str(conversation_id)) - + return conversation - + def list_conversations( self, app_id: uuid.UUID, @@ -86,25 +89,25 @@ class ConversationService: Conversation.workspace_id == workspace_id, Conversation.is_active == True ) - + if user_id: stmt = stmt.where(Conversation.user_id == user_id) - + if is_draft is not None: stmt = stmt.where(Conversation.is_draft == is_draft) - + # 总数 count_stmt = stmt.with_only_columns(Conversation.id) total = len(self.db.execute(count_stmt).all()) - + # 分页 stmt = stmt.order_by(desc(Conversation.updated_at)) stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) - + conversations = list(self.db.scalars(stmt).all()) - + return conversations, total - + def add_message( self, conversation_id: uuid.UUID, @@ -119,22 +122,22 @@ class ConversationService: content=content, meta_data=meta_data ) - + self.db.add(message) - + # 更新会话的消息计数和更新时间 conversation = self.get_conversation(conversation_id) conversation.message_count += 1 - + # 如果是第一条用户消息,可以用它作为标题 if conversation.message_count == 1 and role == "user": conversation.title = content[:50] + ("..." if len(content) > 50 else "") - + self.db.commit() self.db.refresh(message) - + return message - + def get_messages( self, conversation_id: uuid.UUID, @@ -144,30 +147,30 @@ class ConversationService: stmt = select(Message).where( Message.conversation_id == conversation_id ).order_by(Message.created_at) - + if limit: stmt = stmt.limit(limit) - + messages = list(self.db.scalars(stmt).all()) - + return messages - + def get_conversation_history( self, conversation_id: uuid.UUID, max_history: Optional[int] = None ) -> List[dict]: """获取会话历史消息 - + Args: conversation_id: 会话ID max_history: 最大历史消息数量 - + Returns: List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...] """ messages = self.get_messages(conversation_id, limit=max_history) - + # 转换为字典格式 history = [ { @@ -176,9 +179,9 @@ class ConversationService: } for msg in messages ] - + return history - + def save_conversation_messages( self, conversation_id: uuid.UUID, @@ -192,14 +195,14 @@ class ConversationService: role="user", content=user_message ) - + # 添加助手消息 self.add_message( conversation_id=conversation_id, role="assistant", content=assistant_message ) - + logger.debug( "保存会话消息成功", extra={ @@ -208,7 +211,7 @@ class ConversationService: "assistant_message_length": len(assistant_message) } ) - + def delete_conversation( self, conversation_id: uuid.UUID, @@ -217,9 +220,9 @@ class ConversationService: """删除会话(软删除)""" conversation = self.get_conversation(conversation_id, workspace_id) conversation.is_active = False - + self.db.commit() - + logger.info( "删除会话成功", extra={ @@ -227,3 +230,53 @@ class ConversationService: "workspace_id": str(workspace_id) } ) + + def create_or_get_conversation( + self, + app_id: uuid.UUID, + workspace_id: uuid.UUID, + is_draft: bool = False, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + ) -> Conversation: + """创建或获取会话""" + + # 如果提供了 conversation_id,尝试获取现有会话 + if conversation_id: + try: + conversation = self.get_conversation( + conversation_id=conversation_id, + workspace_id=workspace_id + ) + + # 验证会话是否属于该应用 + if conversation.app_id != app_id: + raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION) + return conversation + except ResourceNotFoundException: + logger.warning( + "会话不存在,将创建新会话", + extra={"conversation_id": str(conversation_id)} + ) + + # 创建新会话(使用发布版本的配置) + conversation = self.create_conversation( + app_id=app_id, + workspace_id=workspace_id, + user_id=user_id, + is_draft=is_draft + ) + + logger.info( + "为分享链接创建新会话" + ) + + return conversation + +# ==================== 依赖注入函数 ==================== + +def get_conversation_service( + db: Annotated[Session, Depends(get_db)] +) -> ConversationService: + """获取工作流服务(依赖注入)""" + return ConversationService(db) diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index 1d2822c0..e94a889b 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -36,7 +36,7 @@ class ModelConfigService: """获取模型配置列表""" models, total = ModelConfigRepository.get_list(db, query, tenant_id=tenant_id) pages = math.ceil(total / query.pagesize) if total > 0 else 0 - + return PageData( page=PageMeta( page=query.page, @@ -72,7 +72,7 @@ class ModelConfigService: test_message: str = "Hello" ) -> Dict[str, Any]: """验证模型配置是否有效 - + Args: db: 数据库会话 model_name: 模型名称 @@ -81,7 +81,7 @@ class ModelConfigService: api_base: API基础URL model_type: 模型类型 (llm/chat/embedding/rerank) test_message: 测试消息 - + Returns: Dict: 验证结果 """ @@ -89,10 +89,10 @@ class ModelConfigService: from app.core.models.base import RedBearModelConfig from app.core.models.embedding import RedBearEmbeddings import traceback - + try: start_time = time.time() - + model_config = RedBearModelConfig( model_name=model_name, provider=provider, @@ -101,16 +101,16 @@ class ModelConfigService: temperature=0.7, max_tokens=100 ) - + # 根据模型类型选择不同的验证方式 model_type_lower = model_type.lower() - + if model_type_lower in ["llm", "chat"]: # LLM/Chat 模型验证 - 统一使用字符串输入 llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT) response = await llm.ainvoke(test_message) elapsed_time = time.time() - start_time - + content = response.content if hasattr(response, 'content') else str(response) usage = None if hasattr(response, 'usage_metadata'): @@ -119,7 +119,7 @@ class ModelConfigService: "output_tokens": getattr(response.usage_metadata, 'output_tokens', 0), "total_tokens": getattr(response.usage_metadata, 'total_tokens', 0) } - + return { "valid": True, "message": f"{model_type.upper()} 模型配置验证成功", @@ -128,14 +128,14 @@ class ModelConfigService: "usage": usage, "error": None } - + elif model_type_lower == "embedding": # Embedding 模型验证(在线程中运行同步方法) embedding = RedBearEmbeddings(model_config) test_texts = [test_message, "测试文本"] vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) elapsed_time = time.time() - start_time - + return { "valid": True, "message": "Embedding 模型配置验证成功", @@ -148,7 +148,7 @@ class ModelConfigService: }, "error": None } - + elif model_type_lower == "rerank": # Rerank 模型验证(在线程中运行同步方法) rerank = RedBearRerank(model_config) @@ -156,7 +156,7 @@ class ModelConfigService: documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"] results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3) elapsed_time = time.time() - start_time - + return { "valid": True, "message": "Rerank 模型配置验证成功", @@ -169,7 +169,7 @@ class ModelConfigService: }, "error": None } - + else: return { "valid": False, @@ -179,7 +179,7 @@ class ModelConfigService: "usage": None, "error": f"不支持的模型类型: {model_type}" } - + except Exception as e: # 提取详细的错误信息 error_message = str(e) @@ -203,12 +203,12 @@ class ModelConfigService: error_message = f"无效请求: {error_message}" elif "model_copy" in error_message: error_message = "模型消息格式错误: 请确保使用正确的模型类型(LLM/Chat)" - + # 记录详细错误日志 logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}") logger.error(f"错误详情: {error_message}") logger.debug(f"完整堆栈: {traceback.format_exc()}") - + return { "valid": False, "message": f"{model_type.upper()} 模型配置验证失败", @@ -249,7 +249,7 @@ class ModelConfigService: model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"}) # 添加租户ID model_config_data["tenant_id"] = tenant_id - + model = ModelConfigRepository.create(db, model_config_data) db.flush() # 获取生成的 ID @@ -259,7 +259,7 @@ class ModelConfigService: **api_key_data.dict() ) ModelApiKeyRepository.create(db, api_key_create_schema) - + db.commit() db.refresh(model) return model @@ -270,11 +270,11 @@ class ModelConfigService: existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - + if model_data.name and model_data.name != existing_model.name: if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - + model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) db.commit() db.refresh(model) @@ -285,7 +285,7 @@ class ModelConfigService: """删除模型配置""" if not ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id): raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - + success = ModelConfigRepository.delete(db, model_id, tenant_id=tenant_id) db.commit() return success @@ -316,20 +316,20 @@ class ModelApiKeyService: return api_key @staticmethod - def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> List[ModelApiKey]: + def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> list[ModelApiKey]: """根据模型配置ID获取API Key列表""" if not ModelConfigRepository.get_by_id(db, model_config_id): raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - + return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active) - @staticmethod + @staticmethod async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey: """创建API Key""" model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id) if not model_config: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) - + validation_result = await ModelConfigService.validate_model_config( db=db, model_name=api_key_data.model_name, @@ -345,7 +345,7 @@ class ModelApiKeyService: f"模型配置验证失败: {validation_result['error']}", BizCode.INVALID_PARAMETER ) - + api_key = ModelApiKeyRepository.create(db, api_key_data) db.commit() db.refresh(api_key) @@ -357,12 +357,12 @@ class ModelApiKeyService: existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not existing_api_key: raise BusinessException("API Key不存在", BizCode.NOT_FOUND) - + # 获取关联的模型配置以获取模型类型 model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id) if not model_config: raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND) - + validation_result = await ModelConfigService.validate_model_config( db=db, model_name=api_key_data.model_name, @@ -378,7 +378,7 @@ class ModelApiKeyService: f"模型配置验证失败: {validation_result['error']}", BizCode.INVALID_PARAMETER ) - + api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data) db.commit() db.refresh(api_key) @@ -389,7 +389,7 @@ class ModelApiKeyService: """删除API Key""" if not ModelApiKeyRepository.get_by_id(db, api_key_id): raise BusinessException("API Key不存在", BizCode.NOT_FOUND) - + success = ModelApiKeyRepository.delete(db, api_key_id) db.commit() return success @@ -409,3 +409,11 @@ class ModelApiKeyService: if success: db.commit() return success + + @staticmethod + def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey: + + api_kes = ModelApiKeyService.get_api_keys_by_model(db, model_config_id) + if api_kes and len(api_kes) > 0: + return api_kes[0] + raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) diff --git a/api/app/services/order_service.py b/api/app/services/order_service.py new file mode 100644 index 00000000..c9649f3a --- /dev/null +++ b/api/app/services/order_service.py @@ -0,0 +1,205 @@ +""" +Order Service + +Handles order operations including forwarding requests to external APIs. +""" + +import logging +import httpx +from typing import Dict, Any, Optional +from app.schemas.order_schema import CreateOrderRequest + +logger = logging.getLogger(__name__) + + +class OrderService: + """Order service for handling order operations""" + + def __init__(self, external_api_url: Optional[str] = None, api_key: Optional[str] = None): + """Initialize order service + + Args: + external_api_url: External API base URL + api_key: API key for authentication + """ + # Default external API URL (replace with actual URL) + self.external_api_url = external_api_url or "https://api.example.com/v1" + self.api_key = api_key + self.timeout = 30.0 # 30 seconds timeout + + async def create_order( + self, + order_data: CreateOrderRequest, + user_id: str + ) -> Dict[str, Any]: + """Create order by forwarding request to external API + + Args: + order_data: Order creation data + user_id: Current user ID + + Returns: + Order response data + + Raises: + httpx.HTTPError: If external API request fails + Exception: For other errors + """ + try: + # Prepare request payload + payload = { + "product_id": order_data.product_id, + "quantity": order_data.quantity, + "customer_name": order_data.customer_name, + "customer_email": order_data.customer_email, + "notes": order_data.notes, + "user_id": user_id # Include user ID for tracking + } + + # Prepare headers + headers = { + "Content-Type": "application/json", + "User-Agent": "MemoryBear-OrderService/1.0" + } + + # Add API key if configured + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + logger.info(f"Forwarding order creation request to external API: {self.external_api_url}/orders") + logger.debug(f"Request payload: {payload}") + + # Make async HTTP request to external API + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.post( + f"{self.external_api_url}/orders", + json=payload, + headers=headers + ) + + # Log response status + logger.info(f"External API response status: {response.status_code}") + + # Raise exception for 4xx/5xx status codes + response.raise_for_status() + + # Parse response + response_data = response.json() + logger.debug(f"External API response data: {response_data}") + + # Transform external API response to internal format + return self._transform_external_response(response_data) + + except httpx.HTTPStatusError as e: + logger.error(f"External API returned error status: {e.response.status_code}") + logger.error(f"Error response: {e.response.text}") + + # Try to parse error response + try: + error_data = e.response.json() + error_message = error_data.get("message") or error_data.get("error") or "External API error" + except Exception: + error_message = f"External API error: {e.response.status_code}" + + raise Exception(f"Failed to create order: {error_message}") + + except httpx.TimeoutException: + logger.error(f"External API request timeout after {self.timeout}s") + raise Exception("Order creation timeout - external service not responding") + + except httpx.RequestError as e: + logger.error(f"External API request failed: {str(e)}") + raise Exception(f"Failed to connect to external order service: {str(e)}") + + except Exception as e: + logger.error(f"Unexpected error during order creation: {str(e)}", exc_info=True) + raise Exception(f"Order creation failed: {str(e)}") + + def _transform_external_response(self, external_data: Dict[str, Any]) -> Dict[str, Any]: + """Transform external API response to internal format + + Args: + external_data: Response data from external API + + Returns: + Transformed response data + """ + # Handle different response formats from external API + # Adjust this based on actual external API response structure + + if "data" in external_data: + # Format 1: {"success": true, "data": {...}} + data = external_data["data"] + elif "order" in external_data: + # Format 2: {"order": {...}} + data = external_data["order"] + else: + # Format 3: Direct response + data = external_data + + # Extract fields with fallbacks + return { + "order_id": data.get("order_id") or data.get("id") or "UNKNOWN", + "status": data.get("status") or "pending", + "product_id": data.get("product_id") or "", + "quantity": data.get("quantity") or 0, + "total_amount": data.get("total_amount") or data.get("amount"), + "created_at": data.get("created_at") or data.get("timestamp"), + "message": external_data.get("message") or "Order created successfully" + } + + async def get_order(self, order_id: str) -> Dict[str, Any]: + """Get order details from external API + + Args: + order_id: Order ID + + Returns: + Order details + """ + try: + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + logger.info(f"Fetching order {order_id} from external API") + + async with httpx.AsyncClient(timeout=self.timeout) as client: + response = await client.get( + f"{self.external_api_url}/orders/{order_id}", + headers=headers + ) + response.raise_for_status() + return response.json() + + except Exception as e: + logger.error(f"Failed to fetch order {order_id}: {str(e)}") + raise Exception(f"Failed to fetch order: {str(e)}") + + +# Singleton instance +_order_service_instance: Optional[OrderService] = None + + +def get_order_service( + external_api_url: Optional[str] = None, + api_key: Optional[str] = None +) -> OrderService: + """Get order service instance + + Args: + external_api_url: External API URL (optional, uses default if not provided) + api_key: API key (optional) + + Returns: + OrderService instance + """ + global _order_service_instance + + if _order_service_instance is None: + _order_service_instance = OrderService( + external_api_url=external_api_url, + api_key=api_key + ) + + return _order_service_instance diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index 04ee647c..5e95517d 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -1,35 +1,32 @@ -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid -import secrets -import hashlib import datetime -from fastapi import HTTPException, status +import hashlib +import secrets +import uuid +from os import getenv +from typing import List, Optional + +from sqlalchemy.orm import Session + +from app.core.config import settings from app.core.error_codes import BizCode from app.core.exceptions import BusinessException, PermissionDeniedException -from app.models.tenant_model import Tenants +from app.core.logging_config import get_business_logger from app.models.user_model import User -from app.models.app_model import App -from app.models.end_user_model import EndUser -from app.models.workspace_model import Workspace, WorkspaceRole, WorkspaceInvite, InviteStatus, WorkspaceMember +from app.models.workspace_model import Workspace, WorkspaceRole, InviteStatus, WorkspaceMember +from app.repositories import workspace_repository +from app.repositories.workspace_invite_repository import WorkspaceInviteRepository from app.schemas.workspace_schema import ( - WorkspaceCreate, - WorkspaceUpdate, - WorkspaceInviteCreate, + WorkspaceCreate, + WorkspaceUpdate, + WorkspaceInviteCreate, WorkspaceInviteResponse, InviteValidateResponse, InviteAcceptRequest, WorkspaceMemberUpdate ) -from app.repositories import workspace_repository -from app.repositories.workspace_invite_repository import WorkspaceInviteRepository -from app.core.logging_config import get_business_logger -from app.core.config import settings -from app.services import user_service -from os import getenv + # 获取业务逻辑专用日志器 business_logger = get_business_logger() -import os # from dotenv import load_dotenv load_dotenv() def switch_workspace( @@ -39,10 +36,10 @@ def switch_workspace( ): """切换工作空间""" business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}") - + # 检查用户是否为成员或超级管理员 _check_workspace_member_permission(db, workspace_id, user) - + # 更新当前用户的工作空间上下文 try: user.current_workspace_id = workspace_id @@ -63,22 +60,22 @@ def delete_workspace_member( ): """删除工作空间成员""" business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}") - _check_workspace_admin_permission(db, workspace_id, user) + _check_workspace_admin_permission(db, workspace_id, user) workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id) if not workspace_member: raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND) - + if workspace_member.workspace_id != workspace_id: raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND) - - try: + + try: workspace_member.is_active = False workspace_member.user.current_workspace_id = None - db.commit() + db.commit() business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}") except Exception as e: db.rollback() - business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}") + business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}") raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR) @@ -94,7 +91,7 @@ def _create_workspace_only( db: Session, workspace: WorkspaceCreate, owner: User ) -> Workspace: business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}") - + try: # Create the workspace without adding any members business_logger.debug(f"创建工作空间: {workspace.name}") @@ -126,7 +123,7 @@ def create_workspace( business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}") db.commit() db.refresh(db_workspace) - + # 如果 storage_type 是 "rag",自动创建知识库 if workspace.storage_type == "rag": business_logger.info( @@ -138,7 +135,7 @@ def create_workspace( from app.schemas.knowledge_schema import KnowledgeCreate from app.models.knowledge_model import KnowledgeType, PermissionType from app.repositories import knowledge_repository - + # 创建知识库数据 knowledge_data = KnowledgeCreate( workspace_id=db_workspace.id, @@ -162,10 +159,10 @@ def create_workspace( "html4excel": False } ) - + # 直接使用 repository 创建知识库,避免 service 层的额外逻辑 db_knowledge = knowledge_repository.create_knowledge( - db=db, + db=db, knowledge=knowledge_data ) db.commit() @@ -179,12 +176,12 @@ def create_workspace( ) db.rollback() raise BusinessException( - f"工作空间创建成功,但知识库创建失败: {str(kb_error)}", + f"工作空间创建成功,但知识库创建失败: {str(kb_error)}", BizCode.INTERNAL_ERROR ) - + return db_workspace - + except Exception as e: business_logger.error(f"工作空间创建失败: {workspace.name} - {str(e)}") db.rollback() @@ -195,7 +192,7 @@ def update_workspace( db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User ) -> Workspace: business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}") - + db_workspace = _check_workspace_admin_permission(db,workspace_id,user) try: # 更新工作空间 @@ -219,8 +216,8 @@ def get_workspace_members( db: Session, workspace_id: uuid.UUID, user: User ) -> List[WorkspaceMember]: """获取某工作空间的成员列表(关系序列化由模型关系支持)""" - business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}") - + business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}") + # 查找工作空间 business_logger.debug(f"查找工作空间: {workspace_id}") workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id) @@ -237,10 +234,10 @@ def get_workspace_members( db=db, user_id=user.id, workspace_id=workspace_id ) workspace_memberships = {workspace_id} if member else set() - + subject = Subject.from_user(user, workspace_memberships=workspace_memberships) resource = Resource.from_workspace(workspace) - + try: permission_service.require_permission( subject, @@ -265,7 +262,7 @@ def get_workspace_members( def _generate_invite_token() -> tuple[str, str]: """生成邀请令牌和其哈希值 - + Returns: tuple: (原始令牌, 令牌哈希) """ @@ -285,21 +282,21 @@ def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, use message="Workspace not found", code=BizCode.WORKSPACE_NOT_FOUND ) - + # 使用统一权限服务检查访问权限 from app.core.permissions import permission_service, Subject, Resource, Action - + # 获取用户的工作空间成员关系 member = workspace_repository.get_member_in_workspace( db=db, user_id=user.id, workspace_id=workspace_id ) - + # 任何成员都有访问权限 workspace_memberships = {workspace_id} if member else set() - + subject = Subject.from_user(user, workspace_memberships=workspace_memberships) resource = Resource.from_workspace(db_workspace) - + try: permission_service.require_permission( subject, @@ -323,21 +320,21 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user message="Workspace not found", code=BizCode.WORKSPACE_NOT_FOUND ) - + # 使用统一权限服务检查管理权限 from app.core.permissions import permission_service, Subject, Resource, Action - + # 获取用户的工作空间成员关系 member = workspace_repository.get_member_in_workspace( db=db, user_id=user.id, workspace_id=workspace_id ) - + # 只有 manager 才有管理权限 workspace_memberships = {workspace_id} if (member and member.role == WorkspaceRole.manager) else set() - + subject = Subject.from_user(user, workspace_memberships=workspace_memberships) resource = Resource.from_workspace(db_workspace) - + try: permission_service.require_permission( subject, @@ -353,14 +350,14 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user def create_workspace_invite( - db: Session, - workspace_id: uuid.UUID, - invite_data: WorkspaceInviteCreate, + db: Session, + workspace_id: uuid.UUID, + invite_data: WorkspaceInviteCreate, user: User ) -> WorkspaceInviteResponse: """创建工作空间邀请""" business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}") - + try: # 检查权限 _check_workspace_admin_permission(db, workspace_id, user) @@ -368,7 +365,7 @@ def create_workspace_invite( # 检查被邀请用户是否已经在工作空间中 from app.repositories import user_repository invited_user = user_repository.get_user_by_email(db, invite_data.email) - + if invited_user: # 用户存在,检查是否已经是工作空间成员 existing_member = workspace_repository.get_member_in_workspace( @@ -379,14 +376,14 @@ def create_workspace_invite( if existing_member: business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员") raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS) - + # 检查是否已有待处理的邀请 invite_repo = WorkspaceInviteRepository(db) existing_invite = invite_repo.get_pending_invite_by_email_and_workspace( - email=invite_data.email, + email=invite_data.email, workspace_id=workspace_id ) - + invite_token = None if existing_invite: business_logger.info(f"邮箱 {invite_data.email} 在工作空间 {workspace_id} 已有待处理邀请,返回现有邀请") @@ -409,17 +406,17 @@ def create_workspace_invite( ) db.commit() db.refresh(db_invite) - invite_token = token - + invite_token = token + invite_obj = existing_invite or db_invite business_logger.info(f"工作空间邀请创建成功: invite_id={invite_obj.id}, email={invite_data.email}") - + # 构造响应 response = WorkspaceInviteResponse.model_validate(invite_obj) response.invite_token = invite_token return response - - + + except Exception as e: db.rollback() business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}") @@ -427,8 +424,8 @@ def create_workspace_invite( def get_workspace_invites( - db: Session, - workspace_id: uuid.UUID, + db: Session, + workspace_id: uuid.UUID, user: User, status: Optional[InviteStatus] = None, limit: int = 50, @@ -436,15 +433,15 @@ def get_workspace_invites( ) -> List[WorkspaceInviteResponse]: """获取工作空间邀请列表""" business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}") - + # 检查工作空间是否存在 workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id) if not workspace: raise BusinessException("工作空间不存在", BizCode.WORKSPACE_NOT_FOUND) - + # 检查权限 _check_workspace_admin_permission(db, workspace_id, user) - + # 获取邀请列表 invite_repo = WorkspaceInviteRepository(db) invites = invite_repo.get_workspace_invites( @@ -453,35 +450,35 @@ def get_workspace_invites( limit=limit, offset=offset ) - + return [WorkspaceInviteResponse.model_validate(invite) for invite in invites] def validate_invite_token(db: Session, token: str) -> InviteValidateResponse: """验证邀请令牌""" business_logger.info("验证邀请令牌") - + # 生成令牌哈希 token_hash = hashlib.sha256(token.encode()).hexdigest() - + # 查找邀请 invite_repo = WorkspaceInviteRepository(db) invite = invite_repo.get_invite_by_token_hash(token_hash) - + if not invite: business_logger.warning("邀请令牌无效") raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND) - + # 检查邀请状态和过期时间 now = datetime.datetime.now() is_expired = invite.expires_at < now or invite.status != InviteStatus.pending is_valid = not is_expired - + # 获取工作空间信息 workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id) - + business_logger.info(f"邀请令牌验证完成: valid={is_valid}, expired={is_expired}") - + return InviteValidateResponse( workspace_name=workspace.name, workspace_id=invite.workspace_id, @@ -493,32 +490,32 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse: def accept_workspace_invite( - db: Session, - accept_request: InviteAcceptRequest, + db: Session, + accept_request: InviteAcceptRequest, user: User ) -> dict: """接受工作空间邀请""" business_logger.info(f"接受工作空间邀请: 用户 {user.username}") - + try: from app.core.config import settings - + # 生成令牌哈希 token_hash = hashlib.sha256(accept_request.token.encode()).hexdigest() - + # 查找邀请 invite_repo = WorkspaceInviteRepository(db) invite = invite_repo.get_invite_by_token_hash(token_hash) - + if not invite: business_logger.warning("邀请令牌无效") raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND) - + # 检查邀请状态 if invite.status != InviteStatus.pending: business_logger.warning(f"邀请已被处理: status={invite.status}") raise BusinessException(f"邀请已被{invite.status}", BizCode.WORKSPACE_INVITE_INVALID) - + # 检查过期时间 now = datetime.datetime.now() if invite.expires_at < now: @@ -526,31 +523,31 @@ def accept_workspace_invite( # 标记为过期 invite_repo.update_invite_status(invite.id, InviteStatus.expired) raise BusinessException("邀请已过期", BizCode.WORKSPACE_INVITE_EXPIRED) - + # 检查邮箱是否匹配 if invite.email != user.email: business_logger.warning(f"邮箱不匹配: invite_email={invite.email}, user_email={user.email}") raise BusinessException("邮箱与邀请邮箱不匹配", BizCode.FORBIDDEN) - + # 如果启用单工作空间模式,检查用户是否已有工作空间 if settings.ENABLE_SINGLE_WORKSPACE: user_workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id) if user_workspaces: business_logger.warning(f"单工作空间模式下用户已有工作空间: user={user.username}") raise BusinessException("用户只能加入一个工作空间", BizCode.FORBIDDEN) - + # 检查用户是否已经是工作空间成员 existing_member = workspace_repository.get_member_in_workspace( - db=db, - user_id=user.id, + db=db, + user_id=user.id, workspace_id=invite.workspace_id ) - + if existing_member: business_logger.info("用户已是工作空间成员,更新邀请状态") invite_repo.update_invite_status( - invite.id, - InviteStatus.accepted, + invite.id, + InviteStatus.accepted, accepted_at=now ) db.commit() @@ -559,10 +556,10 @@ def accept_workspace_invite( "message": "You are already a member of this workspace", "workspace": workspace } - + # 将角色映射到工作空间角色(现在直接使用相同的角色) workspace_role = invite.role - + # 添加用户到工作空间 workspace_repository.add_member_to_workspace( db=db, @@ -570,27 +567,27 @@ def accept_workspace_invite( workspace_id=invite.workspace_id, role=workspace_role ) - + # 标记邀请为已接受 invite_repo.update_invite_status( - invite.id, - InviteStatus.accepted, + invite.id, + InviteStatus.accepted, accepted_at=now ) - + db.commit() - + # 获取工作空间信息 workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id) - + business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}") - + return { "message": "Successfully joined the workspace", "workspace": workspace, "role": workspace_role } - + except Exception as e: db.rollback() business_logger.error(f"接受工作空间邀请失败: user={user.username} - {str(e)}") @@ -598,34 +595,34 @@ def accept_workspace_invite( def revoke_workspace_invite( - db: Session, - workspace_id: uuid.UUID, - invite_id: uuid.UUID, + db: Session, + workspace_id: uuid.UUID, + invite_id: uuid.UUID, user: User ) -> dict: """撤销工作空间邀请""" business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}") - + try: # 检查权限 _check_workspace_admin_permission(db, workspace_id, user) - + # 撤销邀请 invite_repo = WorkspaceInviteRepository(db) invite = invite_repo.revoke_invite(invite_id) - + if not invite: business_logger.warning(f"邀请不存在: invite_id={invite_id}") raise BusinessException("邀请不存在", BizCode.WORKSPACE_INVITE_NOT_FOUND) - + if invite.workspace_id != workspace_id: business_logger.warning(f"邀请不属于指定工作空间: invite_id={invite_id}, workspace_id={workspace_id}") raise BusinessException("邀请不属于指定工作空间", BizCode.BAD_REQUEST) - + db.commit() business_logger.info(f"工作空间邀请撤销成功: invite_id={invite_id}") return {"message": "邀请撤销成功"} - + except Exception as e: db.rollback() business_logger.error(f"撤销工作空间邀请失败: invite_id={invite_id} - {str(e)}") @@ -640,48 +637,48 @@ def update_workspace_member_roles( ) -> List[WorkspaceMember]: """更新工作空间成员角色""" business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}") - + # 检查管理员权限 _check_workspace_admin_permission(db, workspace_id, user) - + # 获取所有当前成员 all_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id) member_map = {m.id: m for m in all_members} - + # 验证和业务规则检查 update_ids = set() for upd in updates: # 检查成员是否存在 if upd.id not in member_map: raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND) - + member = member_map[upd.id] - + # 检查成员是否属于该工作空间 if member.workspace_id != workspace_id: raise BusinessException(f"成员 {upd.id} 不属于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND) - + # 不能修改自己的角色 if member.user_id == user.id: raise BusinessException("不能修改自己的角色", BizCode.BAD_REQUEST) - + update_ids.add(upd.id) - + # 检查是否至少保留一个 manager current_managers = [m for m in all_members if m.role == WorkspaceRole.manager] managers_after_update = [ - m for m in all_members + m for m in all_members if m.id not in update_ids and m.role == WorkspaceRole.manager ] - + # 添加更新后会成为 manager 的成员 for upd in updates: if upd.role == WorkspaceRole.manager: managers_after_update.append(member_map[upd.id]) - + if len(managers_after_update) == 0: raise BusinessException("工作空间至少需要一个管理员", BizCode.BAD_REQUEST) - + # 执行更新 try: for upd in updates: @@ -691,15 +688,15 @@ def update_workspace_member_roles( role=upd.role, ) business_logger.debug(f"更新成员 {upd.id} 角色为 {upd.role}") - + db.commit() - + # 重新获取更新后的成员列表 updated_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id) business_logger.info(f"成员角色更新完成: workspace_id={workspace_id}, 更新数量={len(updates)}") - + return updated_members - + except Exception as e: db.rollback() business_logger.error(f"更新工作空间成员角色失败: workspace_id={workspace_id} - {str(e)}") @@ -789,7 +786,7 @@ def get_workspace_models_configs( # 查询工作空间模型配置 configs = workspace_repository.get_workspace_models_configs(db=db, workspace_id=workspace_id) - + if configs is None: business_logger.error(f"工作空间不存在: workspace_id={workspace_id}") raise BusinessException( @@ -801,4 +798,5 @@ def get_workspace_models_configs( f"成功获取工作空间 {workspace_id} 的模型配置: " f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}" ) - return configs \ No newline at end of file + return configs +