[add] app chat v1
This commit is contained in:
@@ -361,7 +361,8 @@ async def draft_run(
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
|
||||
@@ -370,7 +371,8 @@ async def draft_run(
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
|
||||
97
api/app/controllers/order_controller.py
Normal file
97
api/app/controllers/order_controller.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy.orm import Session
|
||||
import os
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.order_schema import CreateOrderRequest
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.order_service import get_order_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, error
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/order",
|
||||
tags=["Order"],
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
async def create_order(
|
||||
order_data: CreateOrderRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
|
||||
try:
|
||||
api_logger.info(f"User {current_user.id} creating order for product {order_data.product_id}")
|
||||
|
||||
# Get external API configuration from environment
|
||||
external_api_url = os.getenv("EXTERNAL_ORDER_API_URL")
|
||||
api_key = os.getenv("EXTERNAL_ORDER_API_KEY")
|
||||
|
||||
# Get order service instance
|
||||
order_service = get_order_service(
|
||||
external_api_url=external_api_url,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
# Forward request to external API
|
||||
result = await order_service.create_order(
|
||||
order_data=order_data,
|
||||
user_id=str(current_user.id)
|
||||
)
|
||||
|
||||
api_logger.info(f"Order created successfully: {result.get('order_id')}")
|
||||
|
||||
return success(data=result, msg="Order created successfully")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to create order: {str(e)}", exc_info=True)
|
||||
return error(msg=str(e), code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
|
||||
@router.get("/{order_id}", response_model=ApiResponse)
|
||||
async def get_order(
|
||||
order_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Get order details from external API
|
||||
|
||||
Args:
|
||||
order_id: Order ID
|
||||
db: Database session
|
||||
current_user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
API response with order details
|
||||
"""
|
||||
try:
|
||||
api_logger.info(f"User {current_user.id} fetching order {order_id}")
|
||||
|
||||
# Get external API configuration
|
||||
external_api_url = os.getenv("EXTERNAL_ORDER_API_URL")
|
||||
api_key = os.getenv("EXTERNAL_ORDER_API_KEY")
|
||||
|
||||
# Get order service instance
|
||||
order_service = get_order_service(
|
||||
external_api_url=external_api_url,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
# Fetch order from external API
|
||||
result = await order_service.get_order(order_id)
|
||||
|
||||
api_logger.info(f"Order {order_id} fetched successfully")
|
||||
|
||||
return success(data=result, msg="Order fetched successfully")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch order {order_id}: {str(e)}", exc_info=True)
|
||||
return error(msg=str(e), code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
|
||||
@@ -2,14 +2,30 @@
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Request, Body
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Annotated
|
||||
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.dependencies import get_app_or_workspace
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas import AppChatRequest, conversation_schema
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.app_service import AppService
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config,dict_to_agent_config,dict_to_workflow_config
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["V1 - App API"])
|
||||
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -19,28 +35,232 @@ async def list_apps():
|
||||
return success(data=[], msg="App API - Coming Soon")
|
||||
|
||||
# /v1/apps/{resource_id}/chat
|
||||
@router.post("/{resource_id}/chat")
|
||||
@require_api_key(scopes=["app"])
|
||||
async def chat_with_agent_demo(
|
||||
resource_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
|
||||
|
||||
# async def chat(
|
||||
# request: Request,
|
||||
# api_key_auth: ApiKeyAuth = None,
|
||||
# db: Session = Depends(get_db),
|
||||
# message: str = Body(..., description="聊天消息内容"),
|
||||
# ):
|
||||
# """
|
||||
# Agent 聊天接口demo
|
||||
|
||||
# scopes: 所需的权限范围列表["app", "rag", "memory"]
|
||||
|
||||
# Args:
|
||||
# resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
|
||||
# message: 请求参数
|
||||
# request: 声明请求
|
||||
# api_key_auth: 包含验证后的API Key 信息
|
||||
# db: db_session
|
||||
# """
|
||||
# logger.info(f"API Key Auth: {api_key_auth}")
|
||||
# logger.info(f"Resource ID: {resource_id}")
|
||||
# logger.info(f"Message: {message}")
|
||||
# return success(data={"received": True}, msg="消息已接收")
|
||||
|
||||
|
||||
def _checkAppConfig(app: App):
|
||||
if app.type == AppType.AGENT:
|
||||
if not app.current_release.config:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app.type == AppType.MULTI_AGENT:
|
||||
if not app.current_release.config:
|
||||
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app.type == AppType.WORKFLOW:
|
||||
if not app.current_release.config:
|
||||
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
else:
|
||||
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
@router.post("/chat")
|
||||
# @require_api_key(scopes=["app"])
|
||||
async def chat(
|
||||
payload: AppChatRequest,
|
||||
app: App = Depends(get_app_or_workspace),
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
|
||||
):
|
||||
"""
|
||||
Agent 聊天接口demo
|
||||
other_id = payload.user_id
|
||||
workspace_id = app.workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app.id,
|
||||
other_id=other_id,
|
||||
original_user_id=other_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
scopes: 所需的权限范围列表["app", "rag", "memory"]
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
db=db,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
app_type = app.type
|
||||
# check app config
|
||||
_checkAppConfig(app)
|
||||
|
||||
Args:
|
||||
resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
|
||||
message: 请求参数
|
||||
request: 声明请求
|
||||
api_key_auth: 包含验证后的API Key 信息
|
||||
db: db_session
|
||||
"""
|
||||
logger.info(f"API Key Auth: {api_key_auth}")
|
||||
logger.info(f"Resource ID: {resource_id}")
|
||||
logger.info(f"Message: {message}")
|
||||
return success(data={"received": True}, msg="消息已接收")
|
||||
# 获取或创建会话(提前验证)
|
||||
conversation = conversation_service.create_or_get_conversation(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=end_user_id,
|
||||
is_draft=False
|
||||
)
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
agent_config = dict_to_agent_config(app.current_release.config)
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id= end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=payload.web_search,
|
||||
config=app.current_release.config,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
result = await app_chat_service.agnet_chat(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config= agent_config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# 多 Agent 流式返回
|
||||
config = dict_to_multi_agent_config(app.current_release.config)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.multi_agent_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# 多 Agent 流式返回
|
||||
config = dict_to_workflow_config(app.current_release.config)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
pass
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import uuid
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.orm import Session
|
||||
from jose import jwt, JWTError
|
||||
|
||||
from app.db import get_db, SessionLocal
|
||||
from app.models import App
|
||||
from app.schemas import token_schema
|
||||
from app.core.config import settings
|
||||
from app.core.security import get_token_id
|
||||
@@ -27,6 +28,51 @@ security_logger = get_security_logger()
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
class APIKeyExtractor:
|
||||
"""
|
||||
Custom dependency to extract API Key from request headers
|
||||
|
||||
Supports two formats:
|
||||
1. Authorization: Bearer <api_key>
|
||||
2. X-API-Key: <api_key>
|
||||
"""
|
||||
|
||||
async def __call__(self, request: Request) -> str:
|
||||
"""Extract API Key from request headers
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
API Key string
|
||||
|
||||
Raises:
|
||||
HTTPException: If API Key is not found
|
||||
"""
|
||||
# Try Authorization header first
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and " " in auth_header:
|
||||
auth_scheme, auth_token = auth_header.split(" ", 1)
|
||||
if auth_scheme.lower() == "bearer":
|
||||
return auth_token
|
||||
|
||||
# Try X-API-Key header
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
if api_key:
|
||||
return api_key
|
||||
|
||||
# No API Key found
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="API Key not found in request headers",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
api_key_extractor = APIKeyExtractor()
|
||||
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db)
|
||||
@@ -304,7 +350,7 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False):
|
||||
- db: Session = Depends(get_db)
|
||||
- user 或 current_user: User = Depends(get_current_user)
|
||||
- workspace_id: uuid.UUID (query/path 参数)或 payload: AppCreate(body,含 workspace_id)
|
||||
|
||||
|
||||
支持同步和异步函数。
|
||||
"""
|
||||
import asyncio
|
||||
@@ -360,7 +406,7 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False):
|
||||
def get_uow() -> IUnitOfWork:
|
||||
"""
|
||||
获取工作单元实例
|
||||
|
||||
|
||||
Returns:
|
||||
IUnitOfWork: 工作单元实例
|
||||
"""
|
||||
@@ -373,7 +419,7 @@ def cur_workspace_access_guard():
|
||||
要求端点函数签名包含:
|
||||
- db: Session = Depends(get_db)
|
||||
- current_user: User = Depends(get_current_user)
|
||||
|
||||
|
||||
支持同步和异步函数。
|
||||
"""
|
||||
import asyncio
|
||||
@@ -423,10 +469,10 @@ async def get_share_user_id(
|
||||
) -> ShareTokenData:
|
||||
"""
|
||||
从分享访问 token 中获取用户 ID 和 share_token
|
||||
|
||||
|
||||
这个函数用于公开分享的接口,验证访问 token 并返回用户信息
|
||||
不需要验证用户是否存在或激活,只需要验证 token 的有效性和 share_token 是否有效
|
||||
|
||||
|
||||
Returns:
|
||||
ShareTokenData: 包含 user_id 和 share_token
|
||||
"""
|
||||
@@ -469,4 +515,75 @@ async def get_share_user_id(
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
async def get_app_or_workspace(
|
||||
api_key: str = Depends(api_key_extractor),
|
||||
db: Session = Depends(get_db)
|
||||
) -> App | Workspace:
|
||||
"""
|
||||
Get App or Workspace from API Key
|
||||
|
||||
Supports two API Key formats:
|
||||
1. Authorization: Bearer <api_key>
|
||||
2. X-API-Key: <api_key>
|
||||
|
||||
Args:
|
||||
api_key: API Key extracted from request headers
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
App or Workspace object based on API Key
|
||||
|
||||
Raises:
|
||||
HTTPException: If API Key is invalid or not found
|
||||
"""
|
||||
from app.services.api_key_service import ApiKeyAuthService
|
||||
from app.repositories.app_repository import get_apps_by_id
|
||||
from app.repositories.workspace_repository import get_workspace_by_id
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate API Key",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
auth_logger.debug(f"Validating API Key: {api_key[:10]}...")
|
||||
|
||||
# Validate API Key
|
||||
api_key_obj = ApiKeyAuthService.validate_api_key(db, api_key)
|
||||
if not api_key_obj:
|
||||
auth_logger.warning(f"Invalid or expired API Key: {api_key[:10]}...")
|
||||
raise credentials_exception
|
||||
|
||||
auth_logger.debug(f"API Key validated successfully, type: {api_key_obj.type}")
|
||||
|
||||
# Return App or Workspace based on API Key type
|
||||
if (api_key_obj.type == "agent" or api_key.type == "multi_agent") and api_key_obj.resource_id:
|
||||
# App API Key
|
||||
app = get_apps_by_id(db, api_key_obj.resource_id)
|
||||
if not app:
|
||||
auth_logger.warning(f"App not found for API Key: {api_key_obj.resource_id}")
|
||||
raise credentials_exception
|
||||
auth_logger.info(f"App access granted: {app.id}")
|
||||
return app
|
||||
|
||||
elif api_key_obj.type == "service":
|
||||
# Workspace API Key
|
||||
workspace = get_workspace_by_id(db, api_key_obj.workspace_id)
|
||||
if not workspace:
|
||||
auth_logger.warning(f"Workspace not found for API Key: {api_key_obj.workspace_id}")
|
||||
raise credentials_exception
|
||||
auth_logger.info(f"Workspace access granted: {workspace.id}")
|
||||
return workspace
|
||||
|
||||
else:
|
||||
auth_logger.warning(f"Unsupported API Key type: {api_key_obj.type}")
|
||||
raise credentials_exception
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
auth_logger.error(f"Error validating API Key: {str(e)}", exc_info=True)
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
|
||||
@@ -2,38 +2,10 @@ import os
|
||||
import subprocess
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, APIRouter
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from app.core.response_utils import fail
|
||||
from app.core.logging_config import LoggingConfig, get_logger
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode, HTTP_MAPPING
|
||||
from app.controllers import (
|
||||
model_controller,
|
||||
task_controller,
|
||||
test_controller,
|
||||
user_controller,
|
||||
auth_controller,
|
||||
workspace_controller,
|
||||
setup_controller,
|
||||
file_controller,
|
||||
document_controller,
|
||||
knowledge_controller,
|
||||
chunk_controller,
|
||||
knowledgeshare_controller,
|
||||
app_controller,
|
||||
upload_controller,
|
||||
memory_agent_controller,
|
||||
memory_storage_controller,
|
||||
memory_dashboard_controller,
|
||||
multi_agent_controller,
|
||||
)
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
|
||||
app = FastAPI(title="Data Config API", version="1.0.0")
|
||||
router = APIRouter(prefix="/memory", tags=["Memory"])
|
||||
|
||||
# 管理端 API (JWT 认证)
|
||||
from app.controllers import manager_router
|
||||
|
||||
@@ -8,7 +8,9 @@ from .file_schema import File, FileCreate, FileUpdate
|
||||
from .tenant_schema import Tenant, TenantCreate, TenantUpdate
|
||||
from .chunk_schema import ChunkCreate, ChunkUpdate, ChunkRetrieve
|
||||
from .knowledgeshare_schema import KnowledgeShare, KnowledgeShareCreate
|
||||
from .order_schema import CreateOrderRequest, OrderResponse, ExternalOrderResponse
|
||||
from .app_schema import (
|
||||
AppChatRequest,
|
||||
DraftRunRequest,
|
||||
DraftRunResponse,
|
||||
DraftRunStreamChunk,
|
||||
@@ -73,6 +75,10 @@ __all__ = [
|
||||
"ChunkRetrieve",
|
||||
"KnowledgeShare",
|
||||
"KnowledgeShareCreate",
|
||||
"CreateOrderRequest",
|
||||
"OrderResponse",
|
||||
"ExternalOrderResponse",
|
||||
"AppChatRequest",
|
||||
"DraftRunRequest",
|
||||
"DraftRunResponse",
|
||||
"DraftRunStreamChunk",
|
||||
|
||||
@@ -334,6 +334,13 @@ class AppShare(BaseModel):
|
||||
|
||||
# ---------- Draft Run Schemas ----------
|
||||
|
||||
class AppChatRequest(BaseModel):
|
||||
message: str = Field(..., description="用户消息")
|
||||
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
|
||||
class DraftRunRequest(BaseModel):
|
||||
"""试运行请求"""
|
||||
message: str = Field(..., description="用户消息")
|
||||
|
||||
63
api/app/schemas/order_schema.py
Normal file
63
api/app/schemas/order_schema.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Order Schema
|
||||
|
||||
Defines request and response models for order operations.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class CreateOrderRequest(BaseModel):
|
||||
"""Create order request model"""
|
||||
|
||||
product_id: str = Field(..., description="Product ID")
|
||||
quantity: int = Field(..., gt=0, description="Order quantity")
|
||||
customer_name: Optional[str] = Field(None, description="Customer name")
|
||||
customer_email: Optional[str] = Field(None, description="Customer email")
|
||||
notes: Optional[str] = Field(None, description="Order notes")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"product_id": "PROD-001",
|
||||
"quantity": 2,
|
||||
"customer_name": "John Doe",
|
||||
"customer_email": "john@example.com",
|
||||
"notes": "Please deliver before 5pm"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class OrderResponse(BaseModel):
|
||||
"""Order response model"""
|
||||
|
||||
order_id: str = Field(..., description="Order ID")
|
||||
status: str = Field(..., description="Order status")
|
||||
product_id: str = Field(..., description="Product ID")
|
||||
quantity: int = Field(..., description="Order quantity")
|
||||
total_amount: Optional[float] = Field(None, description="Total amount")
|
||||
created_at: Optional[str] = Field(None, description="Creation timestamp")
|
||||
message: Optional[str] = Field(None, description="Response message")
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"order_id": "ORD-20231224-001",
|
||||
"status": "pending",
|
||||
"product_id": "PROD-001",
|
||||
"quantity": 2,
|
||||
"total_amount": 199.99,
|
||||
"created_at": "2023-12-24T10:30:00Z",
|
||||
"message": "Order created successfully"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ExternalOrderResponse(BaseModel):
|
||||
"""External API response model (flexible structure)"""
|
||||
|
||||
success: bool = Field(default=True, description="Request success status")
|
||||
data: Optional[Any] = Field(None, description="Response data")
|
||||
error: Optional[str] = Field(None, description="Error message")
|
||||
code: Optional[int] = Field(None, description="Response code")
|
||||
485
api/app/services/app_chat_service.py
Normal file
485
api/app/services/app_chat_service.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""基于分享链接的聊天服务"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, AsyncGenerator, Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AppChatService:
|
||||
"""基于分享链接的聊天服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.conversation_service = ConversationService(db)
|
||||
|
||||
async def agnet_chat(
|
||||
self,
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(model_config_id)
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.get("system_prompt", "")
|
||||
if variables:
|
||||
system_prompt_rendered = render_prompt_message(
|
||||
system_prompt,
|
||||
PromptMessageRole.USER,
|
||||
variables
|
||||
)
|
||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
||||
if knowledge_retrieval:
|
||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||
if kb_ids:
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
if memory == True:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_flag = True
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
web_tools = config.get("tools")
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.get("model_parameters", {})
|
||||
|
||||
# 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_obj.model_name,
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
history = []
|
||||
memory_config = {"enabled": True, 'max_history': 10}
|
||||
if memory_config.get("enabled"):
|
||||
messages = self.conversation_service.get_messages(
|
||||
conversation_id=conversation_id,
|
||||
limit=memory_config.get("max_history", 10)
|
||||
)
|
||||
history = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 调用 Agent
|
||||
result = await agent.chat(
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.save_conversation_messages(
|
||||
conversation_id=conversation_id,
|
||||
user_message=message,
|
||||
assistant_message=result["content"]
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"conversation_id": conversation_id,
|
||||
"message": result["content"],
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}),
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
async def agnet_chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(model_config_id)
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.get("system_prompt", "")
|
||||
if variables:
|
||||
system_prompt_rendered = render_prompt_message(
|
||||
system_prompt,
|
||||
PromptMessageRole.USER,
|
||||
variables
|
||||
)
|
||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
||||
if knowledge_retrieval:
|
||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||
if kb_ids:
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
if memory:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_flag = True
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
web_tools = config.get("tools")
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.get("model_parameters", {})
|
||||
|
||||
# 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_obj.model_name,
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
history = []
|
||||
memory_config = {"enabled": True, 'max_history': 10}
|
||||
if memory_config.get("enabled"):
|
||||
messages = self.conversation_service.get_messages(
|
||||
conversation_id=conversation_id,
|
||||
limit=memory_config.get("max_history", 10)
|
||||
)
|
||||
history = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 流式调用 Agent
|
||||
full_content = ""
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=full_content,
|
||||
meta_data={
|
||||
"model": api_key_obj.model_name,
|
||||
"usage": {}
|
||||
}
|
||||
)
|
||||
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
logger.info(
|
||||
"流式聊天完成",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"elapsed_time": elapsed_time,
|
||||
"message_length": len(full_content)
|
||||
}
|
||||
)
|
||||
|
||||
except (GeneratorExit, asyncio.CancelledError):
|
||||
# 生成器被关闭或任务被取消,正常退出
|
||||
logger.debug("流式聊天被中断")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
|
||||
# 发送错误事件
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
async def multi_agent_chat(
|
||||
self,
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: MultiAgentConfig,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""多 Agent 聊天(非流式)"""
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id = None
|
||||
config_id = actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
# 3. 执行任务
|
||||
result = await orchestrator.execute(
|
||||
message=message,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
variables=variables,
|
||||
use_llm_routing=True, # 默认启用 LLM 路由
|
||||
web_search=web_search, # 网络搜索参数
|
||||
memory=memory # 记忆功能参数
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=result.get("message", ""),
|
||||
meta_data={
|
||||
"mode": result.get("mode"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"sub_results": result.get("sub_results")
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"conversation_id": conversation_id,
|
||||
"message": result.get("message", ""),
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
},
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
async def multi_agent_chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: MultiAgentConfig,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""多 Agent 聊天(流式)"""
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id = None
|
||||
config_id = actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
try:
|
||||
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
full_content = ""
|
||||
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
# 3. 流式执行任务
|
||||
async for event in orchestrator.execute_stream(
|
||||
message=message,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
variables=variables,
|
||||
use_llm_routing=True,
|
||||
web_search=web_search, # 网络搜索参数
|
||||
memory=memory, # 记忆功能参数
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
# 尝试提取内容(用于保存)
|
||||
if "data:" in event:
|
||||
try:
|
||||
data_line = event.split("data: ", 1)[1].strip()
|
||||
data = json.loads(data_line)
|
||||
if "content" in data:
|
||||
full_content += data["content"]
|
||||
except:
|
||||
pass
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=full_content,
|
||||
meta_data={
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"多 Agent 流式聊天完成",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"elapsed_time": elapsed_time,
|
||||
"message_length": len(full_content)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
except (GeneratorExit, asyncio.CancelledError):
|
||||
# 生成器被关闭或任务被取消,正常退出
|
||||
logger.debug("多 Agent 流式聊天被中断")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"多 Agent 流式聊天失败: {str(e)}", exc_info=True)
|
||||
# 发送错误事件
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
def get_app_chat_service(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> ChatService:
|
||||
"""获取工作流服务(依赖注入)"""
|
||||
return ChatService(db)
|
||||
@@ -1,9 +1,12 @@
|
||||
"""会话服务"""
|
||||
import uuid
|
||||
from typing import Optional, List, Tuple
|
||||
from typing import Optional, List, Tuple, Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, desc
|
||||
|
||||
from app.db import get_db
|
||||
from app.models import Conversation, Message
|
||||
from app.core.exceptions import ResourceNotFoundException, BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -14,10 +17,10 @@ logger = get_business_logger()
|
||||
|
||||
class ConversationService:
|
||||
"""会话服务"""
|
||||
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
|
||||
def create_conversation(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -36,11 +39,11 @@ class ConversationService:
|
||||
is_draft=is_draft,
|
||||
config_snapshot=config_snapshot
|
||||
)
|
||||
|
||||
|
||||
self.db.add(conversation)
|
||||
self.db.commit()
|
||||
self.db.refresh(conversation)
|
||||
|
||||
|
||||
logger.info(
|
||||
"创建会话成功",
|
||||
extra={
|
||||
@@ -50,9 +53,9 @@ class ConversationService:
|
||||
"is_draft": is_draft
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
def get_conversation(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
@@ -60,17 +63,17 @@ class ConversationService:
|
||||
) -> Conversation:
|
||||
"""获取会话"""
|
||||
stmt = select(Conversation).where(Conversation.id == conversation_id)
|
||||
|
||||
|
||||
if workspace_id:
|
||||
stmt = stmt.where(Conversation.workspace_id == workspace_id)
|
||||
|
||||
|
||||
conversation = self.db.scalars(stmt).first()
|
||||
|
||||
|
||||
if not conversation:
|
||||
raise ResourceNotFoundException("会话", str(conversation_id))
|
||||
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
def list_conversations(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -86,25 +89,25 @@ class ConversationService:
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active == True
|
||||
)
|
||||
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(Conversation.user_id == user_id)
|
||||
|
||||
|
||||
if is_draft is not None:
|
||||
stmt = stmt.where(Conversation.is_draft == is_draft)
|
||||
|
||||
|
||||
# 总数
|
||||
count_stmt = stmt.with_only_columns(Conversation.id)
|
||||
total = len(self.db.execute(count_stmt).all())
|
||||
|
||||
|
||||
# 分页
|
||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
||||
|
||||
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
|
||||
|
||||
return conversations, total
|
||||
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
@@ -119,22 +122,22 @@ class ConversationService:
|
||||
content=content,
|
||||
meta_data=meta_data
|
||||
)
|
||||
|
||||
|
||||
self.db.add(message)
|
||||
|
||||
|
||||
# 更新会话的消息计数和更新时间
|
||||
conversation = self.get_conversation(conversation_id)
|
||||
conversation.message_count += 1
|
||||
|
||||
|
||||
# 如果是第一条用户消息,可以用它作为标题
|
||||
if conversation.message_count == 1 and role == "user":
|
||||
conversation.title = content[:50] + ("..." if len(content) > 50 else "")
|
||||
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
@@ -144,30 +147,30 @@ class ConversationService:
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conversation_id
|
||||
).order_by(Message.created_at)
|
||||
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
|
||||
messages = list(self.db.scalars(stmt).all())
|
||||
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def get_conversation_history(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
max_history: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
"""获取会话历史消息
|
||||
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
max_history: 最大历史消息数量
|
||||
|
||||
|
||||
Returns:
|
||||
List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...]
|
||||
"""
|
||||
messages = self.get_messages(conversation_id, limit=max_history)
|
||||
|
||||
|
||||
# 转换为字典格式
|
||||
history = [
|
||||
{
|
||||
@@ -176,9 +179,9 @@ class ConversationService:
|
||||
}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def save_conversation_messages(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
@@ -192,14 +195,14 @@ class ConversationService:
|
||||
role="user",
|
||||
content=user_message
|
||||
)
|
||||
|
||||
|
||||
# 添加助手消息
|
||||
self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=assistant_message
|
||||
)
|
||||
|
||||
|
||||
logger.debug(
|
||||
"保存会话消息成功",
|
||||
extra={
|
||||
@@ -208,7 +211,7 @@ class ConversationService:
|
||||
"assistant_message_length": len(assistant_message)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def delete_conversation(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
@@ -217,9 +220,9 @@ class ConversationService:
|
||||
"""删除会话(软删除)"""
|
||||
conversation = self.get_conversation(conversation_id, workspace_id)
|
||||
conversation.is_active = False
|
||||
|
||||
|
||||
self.db.commit()
|
||||
|
||||
|
||||
logger.info(
|
||||
"删除会话成功",
|
||||
extra={
|
||||
@@ -227,3 +230,53 @@ class ConversationService:
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
|
||||
def create_or_get_conversation(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
is_draft: bool = False,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Conversation:
|
||||
"""创建或获取会话"""
|
||||
|
||||
# 如果提供了 conversation_id,尝试获取现有会话
|
||||
if conversation_id:
|
||||
try:
|
||||
conversation = self.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 验证会话是否属于该应用
|
||||
if conversation.app_id != app_id:
|
||||
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
|
||||
return conversation
|
||||
except ResourceNotFoundException:
|
||||
logger.warning(
|
||||
"会话不存在,将创建新会话",
|
||||
extra={"conversation_id": str(conversation_id)}
|
||||
)
|
||||
|
||||
# 创建新会话(使用发布版本的配置)
|
||||
conversation = self.create_conversation(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
is_draft=is_draft
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"为分享链接创建新会话"
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
def get_conversation_service(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> ConversationService:
|
||||
"""获取工作流服务(依赖注入)"""
|
||||
return ConversationService(db)
|
||||
|
||||
@@ -36,7 +36,7 @@ class ModelConfigService:
|
||||
"""获取模型配置列表"""
|
||||
models, total = ModelConfigRepository.get_list(db, query, tenant_id=tenant_id)
|
||||
pages = math.ceil(total / query.pagesize) if total > 0 else 0
|
||||
|
||||
|
||||
return PageData(
|
||||
page=PageMeta(
|
||||
page=query.page,
|
||||
@@ -72,7 +72,7 @@ class ModelConfigService:
|
||||
test_message: str = "Hello"
|
||||
) -> Dict[str, Any]:
|
||||
"""验证模型配置是否有效
|
||||
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
model_name: 模型名称
|
||||
@@ -81,7 +81,7 @@ class ModelConfigService:
|
||||
api_base: API基础URL
|
||||
model_type: 模型类型 (llm/chat/embedding/rerank)
|
||||
test_message: 测试消息
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: 验证结果
|
||||
"""
|
||||
@@ -89,10 +89,10 @@ class ModelConfigService:
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
import traceback
|
||||
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
@@ -101,16 +101,16 @@ class ModelConfigService:
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
|
||||
# 根据模型类型选择不同的验证方式
|
||||
model_type_lower = model_type.lower()
|
||||
|
||||
|
||||
if model_type_lower in ["llm", "chat"]:
|
||||
# LLM/Chat 模型验证 - 统一使用字符串输入
|
||||
llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT)
|
||||
response = await llm.ainvoke(test_message)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
content = response.content if hasattr(response, 'content') else str(response)
|
||||
usage = None
|
||||
if hasattr(response, 'usage_metadata'):
|
||||
@@ -119,7 +119,7 @@ class ModelConfigService:
|
||||
"output_tokens": getattr(response.usage_metadata, 'output_tokens', 0),
|
||||
"total_tokens": getattr(response.usage_metadata, 'total_tokens', 0)
|
||||
}
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": f"{model_type.upper()} 模型配置验证成功",
|
||||
@@ -128,14 +128,14 @@ class ModelConfigService:
|
||||
"usage": usage,
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "embedding":
|
||||
# Embedding 模型验证(在线程中运行同步方法)
|
||||
embedding = RedBearEmbeddings(model_config)
|
||||
test_texts = [test_message, "测试文本"]
|
||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "Embedding 模型配置验证成功",
|
||||
@@ -148,7 +148,7 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "rerank":
|
||||
# Rerank 模型验证(在线程中运行同步方法)
|
||||
rerank = RedBearRerank(model_config)
|
||||
@@ -156,7 +156,7 @@ class ModelConfigService:
|
||||
documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"]
|
||||
results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "Rerank 模型配置验证成功",
|
||||
@@ -169,7 +169,7 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
@@ -179,7 +179,7 @@ class ModelConfigService:
|
||||
"usage": None,
|
||||
"error": f"不支持的模型类型: {model_type}"
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# 提取详细的错误信息
|
||||
error_message = str(e)
|
||||
@@ -203,12 +203,12 @@ class ModelConfigService:
|
||||
error_message = f"无效请求: {error_message}"
|
||||
elif "model_copy" in error_message:
|
||||
error_message = "模型消息格式错误: 请确保使用正确的模型类型(LLM/Chat)"
|
||||
|
||||
|
||||
# 记录详细错误日志
|
||||
logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}")
|
||||
logger.error(f"错误详情: {error_message}")
|
||||
logger.debug(f"完整堆栈: {traceback.format_exc()}")
|
||||
|
||||
|
||||
return {
|
||||
"valid": False,
|
||||
"message": f"{model_type.upper()} 模型配置验证失败",
|
||||
@@ -249,7 +249,7 @@ class ModelConfigService:
|
||||
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
|
||||
# 添加租户ID
|
||||
model_config_data["tenant_id"] = tenant_id
|
||||
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush() # 获取生成的 ID
|
||||
|
||||
@@ -259,7 +259,7 @@ class ModelConfigService:
|
||||
**api_key_data.dict()
|
||||
)
|
||||
ModelApiKeyRepository.create(db, api_key_create_schema)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
@@ -270,11 +270,11 @@ class ModelConfigService:
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
@@ -285,7 +285,7 @@ class ModelConfigService:
|
||||
"""删除模型配置"""
|
||||
if not ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id):
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
|
||||
success = ModelConfigRepository.delete(db, model_id, tenant_id=tenant_id)
|
||||
db.commit()
|
||||
return success
|
||||
@@ -316,20 +316,20 @@ class ModelApiKeyService:
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> List[ModelApiKey]:
|
||||
def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> list[ModelApiKey]:
|
||||
"""根据模型配置ID获取API Key列表"""
|
||||
if not ModelConfigRepository.get_by_id(db, model_config_id):
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
|
||||
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
|
||||
|
||||
@staticmethod
|
||||
@staticmethod
|
||||
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
||||
"""创建API Key"""
|
||||
model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
@@ -345,7 +345,7 @@ class ModelApiKeyService:
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
|
||||
api_key = ModelApiKeyRepository.create(db, api_key_data)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
@@ -357,12 +357,12 @@ class ModelApiKeyService:
|
||||
existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not existing_api_key:
|
||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
# 获取关联的模型配置以获取模型类型
|
||||
model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
@@ -378,7 +378,7 @@ class ModelApiKeyService:
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
|
||||
api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
@@ -389,7 +389,7 @@ class ModelApiKeyService:
|
||||
"""删除API Key"""
|
||||
if not ModelApiKeyRepository.get_by_id(db, api_key_id):
|
||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
success = ModelApiKeyRepository.delete(db, api_key_id)
|
||||
db.commit()
|
||||
return success
|
||||
@@ -409,3 +409,11 @@ class ModelApiKeyService:
|
||||
if success:
|
||||
db.commit()
|
||||
return success
|
||||
|
||||
@staticmethod
|
||||
def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey:
|
||||
|
||||
api_kes = ModelApiKeyService.get_api_keys_by_model(db, model_config_id)
|
||||
if api_kes and len(api_kes) > 0:
|
||||
return api_kes[0]
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
205
api/app/services/order_service.py
Normal file
205
api/app/services/order_service.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Order Service
|
||||
|
||||
Handles order operations including forwarding requests to external APIs.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import httpx
|
||||
from typing import Dict, Any, Optional
|
||||
from app.schemas.order_schema import CreateOrderRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrderService:
|
||||
"""Order service for handling order operations"""
|
||||
|
||||
def __init__(self, external_api_url: Optional[str] = None, api_key: Optional[str] = None):
|
||||
"""Initialize order service
|
||||
|
||||
Args:
|
||||
external_api_url: External API base URL
|
||||
api_key: API key for authentication
|
||||
"""
|
||||
# Default external API URL (replace with actual URL)
|
||||
self.external_api_url = external_api_url or "https://api.example.com/v1"
|
||||
self.api_key = api_key
|
||||
self.timeout = 30.0 # 30 seconds timeout
|
||||
|
||||
async def create_order(
|
||||
self,
|
||||
order_data: CreateOrderRequest,
|
||||
user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Create order by forwarding request to external API
|
||||
|
||||
Args:
|
||||
order_data: Order creation data
|
||||
user_id: Current user ID
|
||||
|
||||
Returns:
|
||||
Order response data
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: If external API request fails
|
||||
Exception: For other errors
|
||||
"""
|
||||
try:
|
||||
# Prepare request payload
|
||||
payload = {
|
||||
"product_id": order_data.product_id,
|
||||
"quantity": order_data.quantity,
|
||||
"customer_name": order_data.customer_name,
|
||||
"customer_email": order_data.customer_email,
|
||||
"notes": order_data.notes,
|
||||
"user_id": user_id # Include user ID for tracking
|
||||
}
|
||||
|
||||
# Prepare headers
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "MemoryBear-OrderService/1.0"
|
||||
}
|
||||
|
||||
# Add API key if configured
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
logger.info(f"Forwarding order creation request to external API: {self.external_api_url}/orders")
|
||||
logger.debug(f"Request payload: {payload}")
|
||||
|
||||
# Make async HTTP request to external API
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.external_api_url}/orders",
|
||||
json=payload,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# Log response status
|
||||
logger.info(f"External API response status: {response.status_code}")
|
||||
|
||||
# Raise exception for 4xx/5xx status codes
|
||||
response.raise_for_status()
|
||||
|
||||
# Parse response
|
||||
response_data = response.json()
|
||||
logger.debug(f"External API response data: {response_data}")
|
||||
|
||||
# Transform external API response to internal format
|
||||
return self._transform_external_response(response_data)
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"External API returned error status: {e.response.status_code}")
|
||||
logger.error(f"Error response: {e.response.text}")
|
||||
|
||||
# Try to parse error response
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
error_message = error_data.get("message") or error_data.get("error") or "External API error"
|
||||
except Exception:
|
||||
error_message = f"External API error: {e.response.status_code}"
|
||||
|
||||
raise Exception(f"Failed to create order: {error_message}")
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error(f"External API request timeout after {self.timeout}s")
|
||||
raise Exception("Order creation timeout - external service not responding")
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"External API request failed: {str(e)}")
|
||||
raise Exception(f"Failed to connect to external order service: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during order creation: {str(e)}", exc_info=True)
|
||||
raise Exception(f"Order creation failed: {str(e)}")
|
||||
|
||||
def _transform_external_response(self, external_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Transform external API response to internal format
|
||||
|
||||
Args:
|
||||
external_data: Response data from external API
|
||||
|
||||
Returns:
|
||||
Transformed response data
|
||||
"""
|
||||
# Handle different response formats from external API
|
||||
# Adjust this based on actual external API response structure
|
||||
|
||||
if "data" in external_data:
|
||||
# Format 1: {"success": true, "data": {...}}
|
||||
data = external_data["data"]
|
||||
elif "order" in external_data:
|
||||
# Format 2: {"order": {...}}
|
||||
data = external_data["order"]
|
||||
else:
|
||||
# Format 3: Direct response
|
||||
data = external_data
|
||||
|
||||
# Extract fields with fallbacks
|
||||
return {
|
||||
"order_id": data.get("order_id") or data.get("id") or "UNKNOWN",
|
||||
"status": data.get("status") or "pending",
|
||||
"product_id": data.get("product_id") or "",
|
||||
"quantity": data.get("quantity") or 0,
|
||||
"total_amount": data.get("total_amount") or data.get("amount"),
|
||||
"created_at": data.get("created_at") or data.get("timestamp"),
|
||||
"message": external_data.get("message") or "Order created successfully"
|
||||
}
|
||||
|
||||
async def get_order(self, order_id: str) -> Dict[str, Any]:
|
||||
"""Get order details from external API
|
||||
|
||||
Args:
|
||||
order_id: Order ID
|
||||
|
||||
Returns:
|
||||
Order details
|
||||
"""
|
||||
try:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
logger.info(f"Fetching order {order_id} from external API")
|
||||
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self.external_api_url}/orders/{order_id}",
|
||||
headers=headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch order {order_id}: {str(e)}")
|
||||
raise Exception(f"Failed to fetch order: {str(e)}")
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_order_service_instance: Optional[OrderService] = None
|
||||
|
||||
|
||||
def get_order_service(
|
||||
external_api_url: Optional[str] = None,
|
||||
api_key: Optional[str] = None
|
||||
) -> OrderService:
|
||||
"""Get order service instance
|
||||
|
||||
Args:
|
||||
external_api_url: External API URL (optional, uses default if not provided)
|
||||
api_key: API key (optional)
|
||||
|
||||
Returns:
|
||||
OrderService instance
|
||||
"""
|
||||
global _order_service_instance
|
||||
|
||||
if _order_service_instance is None:
|
||||
_order_service_instance = OrderService(
|
||||
external_api_url=external_api_url,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
return _order_service_instance
|
||||
@@ -1,35 +1,32 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import secrets
|
||||
import hashlib
|
||||
import datetime
|
||||
from fastapi import HTTPException, status
|
||||
import hashlib
|
||||
import secrets
|
||||
import uuid
|
||||
from os import getenv
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.workspace_model import Workspace, WorkspaceRole, WorkspaceInvite, InviteStatus, WorkspaceMember
|
||||
from app.models.workspace_model import Workspace, WorkspaceRole, InviteStatus, WorkspaceMember
|
||||
from app.repositories import workspace_repository
|
||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||
from app.schemas.workspace_schema import (
|
||||
WorkspaceCreate,
|
||||
WorkspaceUpdate,
|
||||
WorkspaceInviteCreate,
|
||||
WorkspaceCreate,
|
||||
WorkspaceUpdate,
|
||||
WorkspaceInviteCreate,
|
||||
WorkspaceInviteResponse,
|
||||
InviteValidateResponse,
|
||||
InviteAcceptRequest,
|
||||
WorkspaceMemberUpdate
|
||||
)
|
||||
from app.repositories import workspace_repository
|
||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.config import settings
|
||||
from app.services import user_service
|
||||
from os import getenv
|
||||
|
||||
# 获取业务逻辑专用日志器
|
||||
business_logger = get_business_logger()
|
||||
import os #
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
def switch_workspace(
|
||||
@@ -39,10 +36,10 @@ def switch_workspace(
|
||||
):
|
||||
"""切换工作空间"""
|
||||
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
|
||||
|
||||
|
||||
# 检查用户是否为成员或超级管理员
|
||||
_check_workspace_member_permission(db, workspace_id, user)
|
||||
|
||||
|
||||
# 更新当前用户的工作空间上下文
|
||||
try:
|
||||
user.current_workspace_id = workspace_id
|
||||
@@ -63,22 +60,22 @@ def delete_workspace_member(
|
||||
):
|
||||
"""删除工作空间成员"""
|
||||
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
||||
if not workspace_member:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
|
||||
|
||||
if workspace_member.workspace_id != workspace_id:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
workspace_member.is_active = False
|
||||
workspace_member.user.current_workspace_id = None
|
||||
db.commit()
|
||||
db.commit()
|
||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
@@ -94,7 +91,7 @@ def _create_workspace_only(
|
||||
db: Session, workspace: WorkspaceCreate, owner: User
|
||||
) -> Workspace:
|
||||
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
|
||||
|
||||
|
||||
try:
|
||||
# Create the workspace without adding any members
|
||||
business_logger.debug(f"创建工作空间: {workspace.name}")
|
||||
@@ -126,7 +123,7 @@ def create_workspace(
|
||||
business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}")
|
||||
db.commit()
|
||||
db.refresh(db_workspace)
|
||||
|
||||
|
||||
# 如果 storage_type 是 "rag",自动创建知识库
|
||||
if workspace.storage_type == "rag":
|
||||
business_logger.info(
|
||||
@@ -138,7 +135,7 @@ def create_workspace(
|
||||
from app.schemas.knowledge_schema import KnowledgeCreate
|
||||
from app.models.knowledge_model import KnowledgeType, PermissionType
|
||||
from app.repositories import knowledge_repository
|
||||
|
||||
|
||||
# 创建知识库数据
|
||||
knowledge_data = KnowledgeCreate(
|
||||
workspace_id=db_workspace.id,
|
||||
@@ -162,10 +159,10 @@ def create_workspace(
|
||||
"html4excel": False
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 直接使用 repository 创建知识库,避免 service 层的额外逻辑
|
||||
db_knowledge = knowledge_repository.create_knowledge(
|
||||
db=db,
|
||||
db=db,
|
||||
knowledge=knowledge_data
|
||||
)
|
||||
db.commit()
|
||||
@@ -179,12 +176,12 @@ def create_workspace(
|
||||
)
|
||||
db.rollback()
|
||||
raise BusinessException(
|
||||
f"工作空间创建成功,但知识库创建失败: {str(kb_error)}",
|
||||
f"工作空间创建成功,但知识库创建失败: {str(kb_error)}",
|
||||
BizCode.INTERNAL_ERROR
|
||||
)
|
||||
|
||||
|
||||
return db_workspace
|
||||
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"工作空间创建失败: {workspace.name} - {str(e)}")
|
||||
db.rollback()
|
||||
@@ -195,7 +192,7 @@ def update_workspace(
|
||||
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
|
||||
) -> Workspace:
|
||||
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
|
||||
|
||||
db_workspace = _check_workspace_admin_permission(db,workspace_id,user)
|
||||
try:
|
||||
# 更新工作空间
|
||||
@@ -219,8 +216,8 @@ def get_workspace_members(
|
||||
db: Session, workspace_id: uuid.UUID, user: User
|
||||
) -> List[WorkspaceMember]:
|
||||
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
|
||||
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
|
||||
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
|
||||
# 查找工作空间
|
||||
business_logger.debug(f"查找工作空间: {workspace_id}")
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||
@@ -237,10 +234,10 @@ def get_workspace_members(
|
||||
db=db, user_id=user.id, workspace_id=workspace_id
|
||||
)
|
||||
workspace_memberships = {workspace_id} if member else set()
|
||||
|
||||
|
||||
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
||||
resource = Resource.from_workspace(workspace)
|
||||
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
@@ -265,7 +262,7 @@ def get_workspace_members(
|
||||
|
||||
def _generate_invite_token() -> tuple[str, str]:
|
||||
"""生成邀请令牌和其哈希值
|
||||
|
||||
|
||||
Returns:
|
||||
tuple: (原始令牌, 令牌哈希)
|
||||
"""
|
||||
@@ -285,21 +282,21 @@ def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, use
|
||||
message="Workspace not found",
|
||||
code=BizCode.WORKSPACE_NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
# 使用统一权限服务检查访问权限
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
|
||||
|
||||
# 获取用户的工作空间成员关系
|
||||
member = workspace_repository.get_member_in_workspace(
|
||||
db=db, user_id=user.id, workspace_id=workspace_id
|
||||
)
|
||||
|
||||
|
||||
# 任何成员都有访问权限
|
||||
workspace_memberships = {workspace_id} if member else set()
|
||||
|
||||
|
||||
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
||||
resource = Resource.from_workspace(db_workspace)
|
||||
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
@@ -323,21 +320,21 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
|
||||
message="Workspace not found",
|
||||
code=BizCode.WORKSPACE_NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
# 使用统一权限服务检查管理权限
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
|
||||
|
||||
# 获取用户的工作空间成员关系
|
||||
member = workspace_repository.get_member_in_workspace(
|
||||
db=db, user_id=user.id, workspace_id=workspace_id
|
||||
)
|
||||
|
||||
|
||||
# 只有 manager 才有管理权限
|
||||
workspace_memberships = {workspace_id} if (member and member.role == WorkspaceRole.manager) else set()
|
||||
|
||||
|
||||
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
||||
resource = Resource.from_workspace(db_workspace)
|
||||
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
@@ -353,14 +350,14 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
|
||||
|
||||
|
||||
def create_workspace_invite(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
user: User
|
||||
) -> WorkspaceInviteResponse:
|
||||
"""创建工作空间邀请"""
|
||||
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
|
||||
|
||||
|
||||
try:
|
||||
# 检查权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
@@ -368,7 +365,7 @@ def create_workspace_invite(
|
||||
# 检查被邀请用户是否已经在工作空间中
|
||||
from app.repositories import user_repository
|
||||
invited_user = user_repository.get_user_by_email(db, invite_data.email)
|
||||
|
||||
|
||||
if invited_user:
|
||||
# 用户存在,检查是否已经是工作空间成员
|
||||
existing_member = workspace_repository.get_member_in_workspace(
|
||||
@@ -379,14 +376,14 @@ def create_workspace_invite(
|
||||
if existing_member:
|
||||
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
|
||||
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
|
||||
|
||||
|
||||
# 检查是否已有待处理的邀请
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
existing_invite = invite_repo.get_pending_invite_by_email_and_workspace(
|
||||
email=invite_data.email,
|
||||
email=invite_data.email,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
|
||||
invite_token = None
|
||||
if existing_invite:
|
||||
business_logger.info(f"邮箱 {invite_data.email} 在工作空间 {workspace_id} 已有待处理邀请,返回现有邀请")
|
||||
@@ -409,17 +406,17 @@ def create_workspace_invite(
|
||||
)
|
||||
db.commit()
|
||||
db.refresh(db_invite)
|
||||
invite_token = token
|
||||
|
||||
invite_token = token
|
||||
|
||||
invite_obj = existing_invite or db_invite
|
||||
business_logger.info(f"工作空间邀请创建成功: invite_id={invite_obj.id}, email={invite_data.email}")
|
||||
|
||||
|
||||
# 构造响应
|
||||
response = WorkspaceInviteResponse.model_validate(invite_obj)
|
||||
response.invite_token = invite_token
|
||||
return response
|
||||
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
|
||||
@@ -427,8 +424,8 @@ def create_workspace_invite(
|
||||
|
||||
|
||||
def get_workspace_invites(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
status: Optional[InviteStatus] = None,
|
||||
limit: int = 50,
|
||||
@@ -436,15 +433,15 @@ def get_workspace_invites(
|
||||
) -> List[WorkspaceInviteResponse]:
|
||||
"""获取工作空间邀请列表"""
|
||||
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
|
||||
|
||||
# 检查工作空间是否存在
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||
if not workspace:
|
||||
raise BusinessException("工作空间不存在", BizCode.WORKSPACE_NOT_FOUND)
|
||||
|
||||
|
||||
# 检查权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
|
||||
|
||||
# 获取邀请列表
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
invites = invite_repo.get_workspace_invites(
|
||||
@@ -453,35 +450,35 @@ def get_workspace_invites(
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
|
||||
return [WorkspaceInviteResponse.model_validate(invite) for invite in invites]
|
||||
|
||||
|
||||
def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
||||
"""验证邀请令牌"""
|
||||
business_logger.info("验证邀请令牌")
|
||||
|
||||
|
||||
# 生成令牌哈希
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
# 查找邀请
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
invite = invite_repo.get_invite_by_token_hash(token_hash)
|
||||
|
||||
|
||||
if not invite:
|
||||
business_logger.warning("邀请令牌无效")
|
||||
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||
|
||||
|
||||
# 检查邀请状态和过期时间
|
||||
now = datetime.datetime.now()
|
||||
is_expired = invite.expires_at < now or invite.status != InviteStatus.pending
|
||||
is_valid = not is_expired
|
||||
|
||||
|
||||
# 获取工作空间信息
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
||||
|
||||
|
||||
business_logger.info(f"邀请令牌验证完成: valid={is_valid}, expired={is_expired}")
|
||||
|
||||
|
||||
return InviteValidateResponse(
|
||||
workspace_name=workspace.name,
|
||||
workspace_id=invite.workspace_id,
|
||||
@@ -493,32 +490,32 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
||||
|
||||
|
||||
def accept_workspace_invite(
|
||||
db: Session,
|
||||
accept_request: InviteAcceptRequest,
|
||||
db: Session,
|
||||
accept_request: InviteAcceptRequest,
|
||||
user: User
|
||||
) -> dict:
|
||||
"""接受工作空间邀请"""
|
||||
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
|
||||
|
||||
|
||||
try:
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
# 生成令牌哈希
|
||||
token_hash = hashlib.sha256(accept_request.token.encode()).hexdigest()
|
||||
|
||||
|
||||
# 查找邀请
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
invite = invite_repo.get_invite_by_token_hash(token_hash)
|
||||
|
||||
|
||||
if not invite:
|
||||
business_logger.warning("邀请令牌无效")
|
||||
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||
|
||||
|
||||
# 检查邀请状态
|
||||
if invite.status != InviteStatus.pending:
|
||||
business_logger.warning(f"邀请已被处理: status={invite.status}")
|
||||
raise BusinessException(f"邀请已被{invite.status}", BizCode.WORKSPACE_INVITE_INVALID)
|
||||
|
||||
|
||||
# 检查过期时间
|
||||
now = datetime.datetime.now()
|
||||
if invite.expires_at < now:
|
||||
@@ -526,31 +523,31 @@ def accept_workspace_invite(
|
||||
# 标记为过期
|
||||
invite_repo.update_invite_status(invite.id, InviteStatus.expired)
|
||||
raise BusinessException("邀请已过期", BizCode.WORKSPACE_INVITE_EXPIRED)
|
||||
|
||||
|
||||
# 检查邮箱是否匹配
|
||||
if invite.email != user.email:
|
||||
business_logger.warning(f"邮箱不匹配: invite_email={invite.email}, user_email={user.email}")
|
||||
raise BusinessException("邮箱与邀请邮箱不匹配", BizCode.FORBIDDEN)
|
||||
|
||||
|
||||
# 如果启用单工作空间模式,检查用户是否已有工作空间
|
||||
if settings.ENABLE_SINGLE_WORKSPACE:
|
||||
user_workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
|
||||
if user_workspaces:
|
||||
business_logger.warning(f"单工作空间模式下用户已有工作空间: user={user.username}")
|
||||
raise BusinessException("用户只能加入一个工作空间", BizCode.FORBIDDEN)
|
||||
|
||||
|
||||
# 检查用户是否已经是工作空间成员
|
||||
existing_member = workspace_repository.get_member_in_workspace(
|
||||
db=db,
|
||||
user_id=user.id,
|
||||
db=db,
|
||||
user_id=user.id,
|
||||
workspace_id=invite.workspace_id
|
||||
)
|
||||
|
||||
|
||||
if existing_member:
|
||||
business_logger.info("用户已是工作空间成员,更新邀请状态")
|
||||
invite_repo.update_invite_status(
|
||||
invite.id,
|
||||
InviteStatus.accepted,
|
||||
invite.id,
|
||||
InviteStatus.accepted,
|
||||
accepted_at=now
|
||||
)
|
||||
db.commit()
|
||||
@@ -559,10 +556,10 @@ def accept_workspace_invite(
|
||||
"message": "You are already a member of this workspace",
|
||||
"workspace": workspace
|
||||
}
|
||||
|
||||
|
||||
# 将角色映射到工作空间角色(现在直接使用相同的角色)
|
||||
workspace_role = invite.role
|
||||
|
||||
|
||||
# 添加用户到工作空间
|
||||
workspace_repository.add_member_to_workspace(
|
||||
db=db,
|
||||
@@ -570,27 +567,27 @@ def accept_workspace_invite(
|
||||
workspace_id=invite.workspace_id,
|
||||
role=workspace_role
|
||||
)
|
||||
|
||||
|
||||
# 标记邀请为已接受
|
||||
invite_repo.update_invite_status(
|
||||
invite.id,
|
||||
InviteStatus.accepted,
|
||||
invite.id,
|
||||
InviteStatus.accepted,
|
||||
accepted_at=now
|
||||
)
|
||||
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
# 获取工作空间信息
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
||||
|
||||
|
||||
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
|
||||
|
||||
|
||||
return {
|
||||
"message": "Successfully joined the workspace",
|
||||
"workspace": workspace,
|
||||
"role": workspace_role
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"接受工作空间邀请失败: user={user.username} - {str(e)}")
|
||||
@@ -598,34 +595,34 @@ def accept_workspace_invite(
|
||||
|
||||
|
||||
def revoke_workspace_invite(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_id: uuid.UUID,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_id: uuid.UUID,
|
||||
user: User
|
||||
) -> dict:
|
||||
"""撤销工作空间邀请"""
|
||||
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
|
||||
|
||||
|
||||
try:
|
||||
# 检查权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
|
||||
|
||||
# 撤销邀请
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
invite = invite_repo.revoke_invite(invite_id)
|
||||
|
||||
|
||||
if not invite:
|
||||
business_logger.warning(f"邀请不存在: invite_id={invite_id}")
|
||||
raise BusinessException("邀请不存在", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||
|
||||
|
||||
if invite.workspace_id != workspace_id:
|
||||
business_logger.warning(f"邀请不属于指定工作空间: invite_id={invite_id}, workspace_id={workspace_id}")
|
||||
raise BusinessException("邀请不属于指定工作空间", BizCode.BAD_REQUEST)
|
||||
|
||||
|
||||
db.commit()
|
||||
business_logger.info(f"工作空间邀请撤销成功: invite_id={invite_id}")
|
||||
return {"message": "邀请撤销成功"}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"撤销工作空间邀请失败: invite_id={invite_id} - {str(e)}")
|
||||
@@ -640,48 +637,48 @@ def update_workspace_member_roles(
|
||||
) -> List[WorkspaceMember]:
|
||||
"""更新工作空间成员角色"""
|
||||
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
|
||||
|
||||
|
||||
# 检查管理员权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
|
||||
|
||||
# 获取所有当前成员
|
||||
all_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
|
||||
member_map = {m.id: m for m in all_members}
|
||||
|
||||
|
||||
# 验证和业务规则检查
|
||||
update_ids = set()
|
||||
for upd in updates:
|
||||
# 检查成员是否存在
|
||||
if upd.id not in member_map:
|
||||
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
|
||||
|
||||
member = member_map[upd.id]
|
||||
|
||||
|
||||
# 检查成员是否属于该工作空间
|
||||
if member.workspace_id != workspace_id:
|
||||
raise BusinessException(f"成员 {upd.id} 不属于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
|
||||
|
||||
# 不能修改自己的角色
|
||||
if member.user_id == user.id:
|
||||
raise BusinessException("不能修改自己的角色", BizCode.BAD_REQUEST)
|
||||
|
||||
|
||||
update_ids.add(upd.id)
|
||||
|
||||
|
||||
# 检查是否至少保留一个 manager
|
||||
current_managers = [m for m in all_members if m.role == WorkspaceRole.manager]
|
||||
managers_after_update = [
|
||||
m for m in all_members
|
||||
m for m in all_members
|
||||
if m.id not in update_ids and m.role == WorkspaceRole.manager
|
||||
]
|
||||
|
||||
|
||||
# 添加更新后会成为 manager 的成员
|
||||
for upd in updates:
|
||||
if upd.role == WorkspaceRole.manager:
|
||||
managers_after_update.append(member_map[upd.id])
|
||||
|
||||
|
||||
if len(managers_after_update) == 0:
|
||||
raise BusinessException("工作空间至少需要一个管理员", BizCode.BAD_REQUEST)
|
||||
|
||||
|
||||
# 执行更新
|
||||
try:
|
||||
for upd in updates:
|
||||
@@ -691,15 +688,15 @@ def update_workspace_member_roles(
|
||||
role=upd.role,
|
||||
)
|
||||
business_logger.debug(f"更新成员 {upd.id} 角色为 {upd.role}")
|
||||
|
||||
|
||||
db.commit()
|
||||
|
||||
|
||||
# 重新获取更新后的成员列表
|
||||
updated_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
|
||||
business_logger.info(f"成员角色更新完成: workspace_id={workspace_id}, 更新数量={len(updates)}")
|
||||
|
||||
|
||||
return updated_members
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"更新工作空间成员角色失败: workspace_id={workspace_id} - {str(e)}")
|
||||
@@ -789,7 +786,7 @@ def get_workspace_models_configs(
|
||||
|
||||
# 查询工作空间模型配置
|
||||
configs = workspace_repository.get_workspace_models_configs(db=db, workspace_id=workspace_id)
|
||||
|
||||
|
||||
if configs is None:
|
||||
business_logger.error(f"工作空间不存在: workspace_id={workspace_id}")
|
||||
raise BusinessException(
|
||||
@@ -801,4 +798,5 @@ def get_workspace_models_configs(
|
||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||
)
|
||||
return configs
|
||||
return configs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user