[add] app chat v1
This commit is contained in:
@@ -361,7 +361,8 @@ async def draft_run(
|
|||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user=current_user
|
user=current_user
|
||||||
)
|
)
|
||||||
if storage_type is None: storage_type = 'neo4j'
|
if storage_type is None:
|
||||||
|
storage_type = 'neo4j'
|
||||||
user_rag_memory_id = ''
|
user_rag_memory_id = ''
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
|
|
||||||
@@ -370,7 +371,8 @@ async def draft_run(
|
|||||||
name="USER_RAG_MERORY",
|
name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
if knowledge:
|
||||||
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
|
||||||
|
|
||||||
# 提前验证和准备(在流式响应开始前完成)
|
# 提前验证和准备(在流式响应开始前完成)
|
||||||
|
|||||||
97
api/app/controllers/order_controller.py
Normal file
97
api/app/controllers/order_controller.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
from fastapi import APIRouter, Depends, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
import os
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas.order_schema import CreateOrderRequest
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.services.order_service import get_order_service
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import success, error
|
||||||
|
|
||||||
|
# Get API logger
|
||||||
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/order",
|
||||||
|
tags=["Order"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=ApiResponse)
|
||||||
|
async def create_order(
|
||||||
|
order_data: CreateOrderRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_logger.info(f"User {current_user.id} creating order for product {order_data.product_id}")
|
||||||
|
|
||||||
|
# Get external API configuration from environment
|
||||||
|
external_api_url = os.getenv("EXTERNAL_ORDER_API_URL")
|
||||||
|
api_key = os.getenv("EXTERNAL_ORDER_API_KEY")
|
||||||
|
|
||||||
|
# Get order service instance
|
||||||
|
order_service = get_order_service(
|
||||||
|
external_api_url=external_api_url,
|
||||||
|
api_key=api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
# Forward request to external API
|
||||||
|
result = await order_service.create_order(
|
||||||
|
order_data=order_data,
|
||||||
|
user_id=str(current_user.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(f"Order created successfully: {result.get('order_id')}")
|
||||||
|
|
||||||
|
return success(data=result, msg="Order created successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to create order: {str(e)}", exc_info=True)
|
||||||
|
return error(msg=str(e), code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{order_id}", response_model=ApiResponse)
|
||||||
|
async def get_order(
|
||||||
|
order_id: str,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""Get order details from external API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order_id: Order ID
|
||||||
|
db: Database session
|
||||||
|
current_user: Current authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API response with order details
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
api_logger.info(f"User {current_user.id} fetching order {order_id}")
|
||||||
|
|
||||||
|
# Get external API configuration
|
||||||
|
external_api_url = os.getenv("EXTERNAL_ORDER_API_URL")
|
||||||
|
api_key = os.getenv("EXTERNAL_ORDER_API_KEY")
|
||||||
|
|
||||||
|
# Get order service instance
|
||||||
|
order_service = get_order_service(
|
||||||
|
external_api_url=external_api_url,
|
||||||
|
api_key=api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch order from external API
|
||||||
|
result = await order_service.get_order(order_id)
|
||||||
|
|
||||||
|
api_logger.info(f"Order {order_id} fetched successfully")
|
||||||
|
|
||||||
|
return success(data=result, msg="Order fetched successfully")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to fetch order {order_id}: {str(e)}", exc_info=True)
|
||||||
|
return error(msg=str(e), code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
@@ -2,14 +2,30 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from fastapi import APIRouter, Depends, Request, Body
|
from fastapi import APIRouter, Depends, Request, Body
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Optional, Annotated
|
||||||
|
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
|
from app.core.api_key_auth import require_api_key
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.api_key_auth import require_api_key
|
from app.dependencies import get_app_or_workspace
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
from app.repositories import knowledge_repository
|
||||||
|
from app.schemas import AppChatRequest, conversation_schema
|
||||||
|
from app.models.app_model import App
|
||||||
|
from app.models.app_model import AppType
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.services import workspace_service
|
||||||
|
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||||
|
from app.services.app_service import AppService
|
||||||
|
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||||
|
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||||
|
from app.utils.app_config_utils import dict_to_multi_agent_config,dict_to_agent_config,dict_to_workflow_config
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["V1 - App API"])
|
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -19,28 +35,232 @@ async def list_apps():
|
|||||||
return success(data=[], msg="App API - Coming Soon")
|
return success(data=[], msg="App API - Coming Soon")
|
||||||
|
|
||||||
# /v1/apps/{resource_id}/chat
|
# /v1/apps/{resource_id}/chat
|
||||||
@router.post("/{resource_id}/chat")
|
|
||||||
@require_api_key(scopes=["app"])
|
|
||||||
async def chat_with_agent_demo(
|
# async def chat(
|
||||||
resource_id: uuid.UUID,
|
# request: Request,
|
||||||
request: Request,
|
# api_key_auth: ApiKeyAuth = None,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
# db: Session = Depends(get_db),
|
||||||
|
# message: str = Body(..., description="聊天消息内容"),
|
||||||
|
# ):
|
||||||
|
# """
|
||||||
|
# Agent 聊天接口demo
|
||||||
|
|
||||||
|
# scopes: 所需的权限范围列表["app", "rag", "memory"]
|
||||||
|
|
||||||
|
# Args:
|
||||||
|
# resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
|
||||||
|
# message: 请求参数
|
||||||
|
# request: 声明请求
|
||||||
|
# api_key_auth: 包含验证后的API Key 信息
|
||||||
|
# db: db_session
|
||||||
|
# """
|
||||||
|
# logger.info(f"API Key Auth: {api_key_auth}")
|
||||||
|
# logger.info(f"Resource ID: {resource_id}")
|
||||||
|
# logger.info(f"Message: {message}")
|
||||||
|
# return success(data={"received": True}, msg="消息已接收")
|
||||||
|
|
||||||
|
|
||||||
|
def _checkAppConfig(app: App):
|
||||||
|
if app.type == AppType.AGENT:
|
||||||
|
if not app.current_release.config:
|
||||||
|
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
elif app.type == AppType.MULTI_AGENT:
|
||||||
|
if not app.current_release.config:
|
||||||
|
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
elif app.type == AppType.WORKFLOW:
|
||||||
|
if not app.current_release.config:
|
||||||
|
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
else:
|
||||||
|
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
|
@router.post("/chat")
|
||||||
|
# @require_api_key(scopes=["app"])
|
||||||
|
async def chat(
|
||||||
|
payload: AppChatRequest,
|
||||||
|
app: App = Depends(get_app_or_workspace),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
message: str = Body(..., description="聊天消息内容"),
|
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
|
||||||
|
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||||
|
|
||||||
):
|
):
|
||||||
"""
|
other_id = payload.user_id
|
||||||
Agent 聊天接口demo
|
workspace_id = app.workspace_id
|
||||||
|
end_user_repo = EndUserRepository(db)
|
||||||
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=app.id,
|
||||||
|
other_id=other_id,
|
||||||
|
original_user_id=other_id # Save original user_id to other_id
|
||||||
|
)
|
||||||
|
end_user_id = str(new_end_user.id)
|
||||||
|
|
||||||
scopes: 所需的权限范围列表["app", "rag", "memory"]
|
# 提前验证和准备(在流式响应开始前完成)
|
||||||
|
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||||
|
db=db,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
if storage_type is None:
|
||||||
|
storage_type = 'neo4j'
|
||||||
|
user_rag_memory_id = ''
|
||||||
|
if storage_type == 'rag':
|
||||||
|
if workspace_id:
|
||||||
|
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
|
db=db,
|
||||||
|
name="USER_RAG_MERORY",
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
if knowledge:
|
||||||
|
user_rag_memory_id = str(knowledge.id)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||||
|
storage_type = 'neo4j'
|
||||||
|
else:
|
||||||
|
logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
|
storage_type = 'neo4j'
|
||||||
|
app_type = app.type
|
||||||
|
# check app config
|
||||||
|
_checkAppConfig(app)
|
||||||
|
|
||||||
Args:
|
# 获取或创建会话(提前验证)
|
||||||
resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
|
conversation = conversation_service.create_or_get_conversation(
|
||||||
message: 请求参数
|
app_id=app.id,
|
||||||
request: 声明请求
|
workspace_id=workspace_id,
|
||||||
api_key_auth: 包含验证后的API Key 信息
|
user_id=end_user_id,
|
||||||
db: db_session
|
is_draft=False
|
||||||
"""
|
)
|
||||||
logger.info(f"API Key Auth: {api_key_auth}")
|
|
||||||
logger.info(f"Resource ID: {resource_id}")
|
if app_type == AppType.AGENT:
|
||||||
logger.info(f"Message: {message}")
|
agent_config = dict_to_agent_config(app.current_release.config)
|
||||||
return success(data={"received": True}, msg="消息已接收")
|
# 流式返回
|
||||||
|
if payload.stream:
|
||||||
|
async def event_generator():
|
||||||
|
async for event in app_chat_service.agnet_chat_stream(
|
||||||
|
message=payload.message,
|
||||||
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
user_id= end_user_id, # 转换为字符串
|
||||||
|
variables=payload.variables,
|
||||||
|
web_search=payload.web_search,
|
||||||
|
config=app.current_release.config,
|
||||||
|
memory=payload.memory,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 非流式返回
|
||||||
|
result = await app_chat_service.agnet_chat(
|
||||||
|
message=payload.message,
|
||||||
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
user_id=end_user_id, # 转换为字符串
|
||||||
|
variables=payload.variables,
|
||||||
|
config= agent_config,
|
||||||
|
web_search=payload.web_search,
|
||||||
|
memory=payload.memory,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
)
|
||||||
|
return success(data=conversation_schema.ChatResponse(**result))
|
||||||
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
|
# 多 Agent 流式返回
|
||||||
|
config = dict_to_multi_agent_config(app.current_release.config)
|
||||||
|
if payload.stream:
|
||||||
|
async def event_generator():
|
||||||
|
async for event in app_chat_service.multi_agent_chat_stream(
|
||||||
|
|
||||||
|
message=payload.message,
|
||||||
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
user_id=end_user_id, # 转换为字符串
|
||||||
|
variables=payload.variables,
|
||||||
|
config=config,
|
||||||
|
web_search=payload.web_search,
|
||||||
|
memory=payload.memory,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 多 Agent 非流式返回
|
||||||
|
result = await app_chat_service.multi_agent_chat(
|
||||||
|
|
||||||
|
message=payload.message,
|
||||||
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
user_id=end_user_id, # 转换为字符串
|
||||||
|
variables=payload.variables,
|
||||||
|
config=config,
|
||||||
|
web_search=payload.web_search,
|
||||||
|
memory=payload.memory,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(data=conversation_schema.ChatResponse(**result))
|
||||||
|
elif app_type == AppType.WORKFLOW:
|
||||||
|
# 多 Agent 流式返回
|
||||||
|
config = dict_to_workflow_config(app.current_release.config)
|
||||||
|
if payload.stream:
|
||||||
|
async def event_generator():
|
||||||
|
async for event in app_chat_service.workflow_chat_stream(
|
||||||
|
|
||||||
|
message=payload.message,
|
||||||
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
user_id=end_user_id, # 转换为字符串
|
||||||
|
variables=payload.variables,
|
||||||
|
config=config,
|
||||||
|
web_search=payload.web_search,
|
||||||
|
memory=payload.memory,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 非流式返回
|
||||||
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
|
message=payload.message,
|
||||||
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
user_id=end_user_id, # 转换为字符串
|
||||||
|
variables=payload.variables,
|
||||||
|
config=config,
|
||||||
|
web_search=payload.web_search,
|
||||||
|
memory=payload.memory,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(data=conversation_schema.ChatResponse(**result))
|
||||||
|
else:
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
pass
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from fastapi import APIRouter, Depends, status
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status, Request
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from jose import jwt, JWTError
|
from jose import jwt, JWTError
|
||||||
|
|
||||||
from app.db import get_db, SessionLocal
|
from app.db import get_db, SessionLocal
|
||||||
|
from app.models import App
|
||||||
from app.schemas import token_schema
|
from app.schemas import token_schema
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.security import get_token_id
|
from app.core.security import get_token_id
|
||||||
@@ -27,6 +28,51 @@ security_logger = get_security_logger()
|
|||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyExtractor:
|
||||||
|
"""
|
||||||
|
Custom dependency to extract API Key from request headers
|
||||||
|
|
||||||
|
Supports two formats:
|
||||||
|
1. Authorization: Bearer <api_key>
|
||||||
|
2. X-API-Key: <api_key>
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def __call__(self, request: Request) -> str:
|
||||||
|
"""Extract API Key from request headers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI Request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API Key string
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If API Key is not found
|
||||||
|
"""
|
||||||
|
# Try Authorization header first
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
if auth_header and " " in auth_header:
|
||||||
|
auth_scheme, auth_token = auth_header.split(" ", 1)
|
||||||
|
if auth_scheme.lower() == "bearer":
|
||||||
|
return auth_token
|
||||||
|
|
||||||
|
# Try X-API-Key header
|
||||||
|
api_key = request.headers.get("X-API-Key")
|
||||||
|
if api_key:
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
# No API Key found
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="API Key not found in request headers",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
api_key_extractor = APIKeyExtractor()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
@@ -304,7 +350,7 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False):
|
|||||||
- db: Session = Depends(get_db)
|
- db: Session = Depends(get_db)
|
||||||
- user 或 current_user: User = Depends(get_current_user)
|
- user 或 current_user: User = Depends(get_current_user)
|
||||||
- workspace_id: uuid.UUID (query/path 参数)或 payload: AppCreate(body,含 workspace_id)
|
- workspace_id: uuid.UUID (query/path 参数)或 payload: AppCreate(body,含 workspace_id)
|
||||||
|
|
||||||
支持同步和异步函数。
|
支持同步和异步函数。
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -360,7 +406,7 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False):
|
|||||||
def get_uow() -> IUnitOfWork:
|
def get_uow() -> IUnitOfWork:
|
||||||
"""
|
"""
|
||||||
获取工作单元实例
|
获取工作单元实例
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IUnitOfWork: 工作单元实例
|
IUnitOfWork: 工作单元实例
|
||||||
"""
|
"""
|
||||||
@@ -373,7 +419,7 @@ def cur_workspace_access_guard():
|
|||||||
要求端点函数签名包含:
|
要求端点函数签名包含:
|
||||||
- db: Session = Depends(get_db)
|
- db: Session = Depends(get_db)
|
||||||
- current_user: User = Depends(get_current_user)
|
- current_user: User = Depends(get_current_user)
|
||||||
|
|
||||||
支持同步和异步函数。
|
支持同步和异步函数。
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -423,10 +469,10 @@ async def get_share_user_id(
|
|||||||
) -> ShareTokenData:
|
) -> ShareTokenData:
|
||||||
"""
|
"""
|
||||||
从分享访问 token 中获取用户 ID 和 share_token
|
从分享访问 token 中获取用户 ID 和 share_token
|
||||||
|
|
||||||
这个函数用于公开分享的接口,验证访问 token 并返回用户信息
|
这个函数用于公开分享的接口,验证访问 token 并返回用户信息
|
||||||
不需要验证用户是否存在或激活,只需要验证 token 的有效性和 share_token 是否有效
|
不需要验证用户是否存在或激活,只需要验证 token 的有效性和 share_token 是否有效
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ShareTokenData: 包含 user_id 和 share_token
|
ShareTokenData: 包含 user_id 和 share_token
|
||||||
"""
|
"""
|
||||||
@@ -469,4 +515,75 @@ async def get_share_user_id(
|
|||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
||||||
|
|
||||||
|
async def get_app_or_workspace(
|
||||||
|
api_key: str = Depends(api_key_extractor),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> App | Workspace:
|
||||||
|
"""
|
||||||
|
Get App or Workspace from API Key
|
||||||
|
|
||||||
|
Supports two API Key formats:
|
||||||
|
1. Authorization: Bearer <api_key>
|
||||||
|
2. X-API-Key: <api_key>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API Key extracted from request headers
|
||||||
|
db: Database session
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
App or Workspace object based on API Key
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If API Key is invalid or not found
|
||||||
|
"""
|
||||||
|
from app.services.api_key_service import ApiKeyAuthService
|
||||||
|
from app.repositories.app_repository import get_apps_by_id
|
||||||
|
from app.repositories.workspace_repository import get_workspace_by_id
|
||||||
|
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate API Key",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
auth_logger.debug(f"Validating API Key: {api_key[:10]}...")
|
||||||
|
|
||||||
|
# Validate API Key
|
||||||
|
api_key_obj = ApiKeyAuthService.validate_api_key(db, api_key)
|
||||||
|
if not api_key_obj:
|
||||||
|
auth_logger.warning(f"Invalid or expired API Key: {api_key[:10]}...")
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
auth_logger.debug(f"API Key validated successfully, type: {api_key_obj.type}")
|
||||||
|
|
||||||
|
# Return App or Workspace based on API Key type
|
||||||
|
if (api_key_obj.type == "agent" or api_key.type == "multi_agent") and api_key_obj.resource_id:
|
||||||
|
# App API Key
|
||||||
|
app = get_apps_by_id(db, api_key_obj.resource_id)
|
||||||
|
if not app:
|
||||||
|
auth_logger.warning(f"App not found for API Key: {api_key_obj.resource_id}")
|
||||||
|
raise credentials_exception
|
||||||
|
auth_logger.info(f"App access granted: {app.id}")
|
||||||
|
return app
|
||||||
|
|
||||||
|
elif api_key_obj.type == "service":
|
||||||
|
# Workspace API Key
|
||||||
|
workspace = get_workspace_by_id(db, api_key_obj.workspace_id)
|
||||||
|
if not workspace:
|
||||||
|
auth_logger.warning(f"Workspace not found for API Key: {api_key_obj.workspace_id}")
|
||||||
|
raise credentials_exception
|
||||||
|
auth_logger.info(f"Workspace access granted: {workspace.id}")
|
||||||
|
return workspace
|
||||||
|
|
||||||
|
else:
|
||||||
|
auth_logger.warning(f"Unsupported API Key type: {api_key_obj.type}")
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
auth_logger.error(f"Error validating API Key: {str(e)}", exc_info=True)
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,38 +2,10 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, APIRouter
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from app.core.response_utils import fail
|
|
||||||
from app.core.logging_config import LoggingConfig, get_logger
|
|
||||||
from app.core.exceptions import BusinessException
|
|
||||||
from app.core.error_codes import BizCode, HTTP_MAPPING
|
|
||||||
from app.controllers import (
|
|
||||||
model_controller,
|
|
||||||
task_controller,
|
|
||||||
test_controller,
|
|
||||||
user_controller,
|
|
||||||
auth_controller,
|
|
||||||
workspace_controller,
|
|
||||||
setup_controller,
|
|
||||||
file_controller,
|
|
||||||
document_controller,
|
|
||||||
knowledge_controller,
|
|
||||||
chunk_controller,
|
|
||||||
knowledgeshare_controller,
|
|
||||||
app_controller,
|
|
||||||
upload_controller,
|
|
||||||
memory_agent_controller,
|
|
||||||
memory_storage_controller,
|
|
||||||
memory_dashboard_controller,
|
|
||||||
multi_agent_controller,
|
|
||||||
)
|
|
||||||
|
|
||||||
from fastapi import FastAPI, APIRouter
|
|
||||||
|
|
||||||
app = FastAPI(title="Data Config API", version="1.0.0")
|
|
||||||
router = APIRouter(prefix="/memory", tags=["Memory"])
|
|
||||||
|
|
||||||
# 管理端 API (JWT 认证)
|
# 管理端 API (JWT 认证)
|
||||||
from app.controllers import manager_router
|
from app.controllers import manager_router
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ from .file_schema import File, FileCreate, FileUpdate
|
|||||||
from .tenant_schema import Tenant, TenantCreate, TenantUpdate
|
from .tenant_schema import Tenant, TenantCreate, TenantUpdate
|
||||||
from .chunk_schema import ChunkCreate, ChunkUpdate, ChunkRetrieve
|
from .chunk_schema import ChunkCreate, ChunkUpdate, ChunkRetrieve
|
||||||
from .knowledgeshare_schema import KnowledgeShare, KnowledgeShareCreate
|
from .knowledgeshare_schema import KnowledgeShare, KnowledgeShareCreate
|
||||||
|
from .order_schema import CreateOrderRequest, OrderResponse, ExternalOrderResponse
|
||||||
from .app_schema import (
|
from .app_schema import (
|
||||||
|
AppChatRequest,
|
||||||
DraftRunRequest,
|
DraftRunRequest,
|
||||||
DraftRunResponse,
|
DraftRunResponse,
|
||||||
DraftRunStreamChunk,
|
DraftRunStreamChunk,
|
||||||
@@ -73,6 +75,10 @@ __all__ = [
|
|||||||
"ChunkRetrieve",
|
"ChunkRetrieve",
|
||||||
"KnowledgeShare",
|
"KnowledgeShare",
|
||||||
"KnowledgeShareCreate",
|
"KnowledgeShareCreate",
|
||||||
|
"CreateOrderRequest",
|
||||||
|
"OrderResponse",
|
||||||
|
"ExternalOrderResponse",
|
||||||
|
"AppChatRequest",
|
||||||
"DraftRunRequest",
|
"DraftRunRequest",
|
||||||
"DraftRunResponse",
|
"DraftRunResponse",
|
||||||
"DraftRunStreamChunk",
|
"DraftRunStreamChunk",
|
||||||
|
|||||||
@@ -334,6 +334,13 @@ class AppShare(BaseModel):
|
|||||||
|
|
||||||
# ---------- Draft Run Schemas ----------
|
# ---------- Draft Run Schemas ----------
|
||||||
|
|
||||||
|
class AppChatRequest(BaseModel):
|
||||||
|
message: str = Field(..., description="用户消息")
|
||||||
|
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
||||||
|
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||||
|
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||||
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
|
|
||||||
class DraftRunRequest(BaseModel):
|
class DraftRunRequest(BaseModel):
|
||||||
"""试运行请求"""
|
"""试运行请求"""
|
||||||
message: str = Field(..., description="用户消息")
|
message: str = Field(..., description="用户消息")
|
||||||
|
|||||||
63
api/app/schemas/order_schema.py
Normal file
63
api/app/schemas/order_schema.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
Order Schema
|
||||||
|
|
||||||
|
Defines request and response models for order operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class CreateOrderRequest(BaseModel):
|
||||||
|
"""Create order request model"""
|
||||||
|
|
||||||
|
product_id: str = Field(..., description="Product ID")
|
||||||
|
quantity: int = Field(..., gt=0, description="Order quantity")
|
||||||
|
customer_name: Optional[str] = Field(None, description="Customer name")
|
||||||
|
customer_email: Optional[str] = Field(None, description="Customer email")
|
||||||
|
notes: Optional[str] = Field(None, description="Order notes")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"product_id": "PROD-001",
|
||||||
|
"quantity": 2,
|
||||||
|
"customer_name": "John Doe",
|
||||||
|
"customer_email": "john@example.com",
|
||||||
|
"notes": "Please deliver before 5pm"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OrderResponse(BaseModel):
|
||||||
|
"""Order response model"""
|
||||||
|
|
||||||
|
order_id: str = Field(..., description="Order ID")
|
||||||
|
status: str = Field(..., description="Order status")
|
||||||
|
product_id: str = Field(..., description="Product ID")
|
||||||
|
quantity: int = Field(..., description="Order quantity")
|
||||||
|
total_amount: Optional[float] = Field(None, description="Total amount")
|
||||||
|
created_at: Optional[str] = Field(None, description="Creation timestamp")
|
||||||
|
message: Optional[str] = Field(None, description="Response message")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"order_id": "ORD-20231224-001",
|
||||||
|
"status": "pending",
|
||||||
|
"product_id": "PROD-001",
|
||||||
|
"quantity": 2,
|
||||||
|
"total_amount": 199.99,
|
||||||
|
"created_at": "2023-12-24T10:30:00Z",
|
||||||
|
"message": "Order created successfully"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalOrderResponse(BaseModel):
|
||||||
|
"""External API response model (flexible structure)"""
|
||||||
|
|
||||||
|
success: bool = Field(default=True, description="Request success status")
|
||||||
|
data: Optional[Any] = Field(None, description="Response data")
|
||||||
|
error: Optional[str] = Field(None, description="Error message")
|
||||||
|
code: Optional[int] = Field(None, description="Response code")
|
||||||
485
api/app/services/app_chat_service.py
Normal file
485
api/app/services/app_chat_service.py
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
"""基于分享链接的聊天服务"""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Optional, Dict, Any, AsyncGenerator, Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.db import get_db
|
||||||
|
from app.models import MultiAgentConfig, AgentConfig
|
||||||
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
|
from app.services.conversation_service import ConversationService
|
||||||
|
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||||
|
from app.services.draft_run_service import create_web_search_tool
|
||||||
|
from app.services.model_service import ModelApiKeyService
|
||||||
|
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||||
|
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class AppChatService:
|
||||||
|
"""基于分享链接的聊天服务"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
self.conversation_service = ConversationService(db)
|
||||||
|
|
||||||
|
async def agnet_chat(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
conversation_id: uuid.UUID,
|
||||||
|
config: AgentConfig,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
|
web_search: bool = False,
|
||||||
|
memory: bool = True,
|
||||||
|
storage_type: Optional[str] = None,
|
||||||
|
user_rag_memory_id: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""聊天(非流式)"""
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
config_id = None
|
||||||
|
|
||||||
|
if variables is None:
|
||||||
|
variables = {}
|
||||||
|
|
||||||
|
# 获取模型配置ID
|
||||||
|
model_config_id = config.default_model_config_id
|
||||||
|
api_key_obj = ModelApiKeyService.get_a_api_key(model_config_id)
|
||||||
|
# 处理系统提示词(支持变量替换)
|
||||||
|
system_prompt = config.get("system_prompt", "")
|
||||||
|
if variables:
|
||||||
|
system_prompt_rendered = render_prompt_message(
|
||||||
|
system_prompt,
|
||||||
|
PromptMessageRole.USER,
|
||||||
|
variables
|
||||||
|
)
|
||||||
|
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||||
|
|
||||||
|
# 准备工具列表
|
||||||
|
tools = []
|
||||||
|
|
||||||
|
# 添加知识库检索工具
|
||||||
|
knowledge_retrieval = config.get("knowledge_retrieval")
|
||||||
|
if knowledge_retrieval:
|
||||||
|
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||||
|
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||||
|
if kb_ids:
|
||||||
|
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
||||||
|
tools.append(kb_tool)
|
||||||
|
|
||||||
|
# 添加长期记忆工具
|
||||||
|
memory_flag = False
|
||||||
|
if memory == True:
|
||||||
|
memory_config = config.get("memory", {})
|
||||||
|
if memory_config.get("enabled") and user_id:
|
||||||
|
memory_flag = True
|
||||||
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
|
tools.append(memory_tool)
|
||||||
|
|
||||||
|
web_tools = config.get("tools")
|
||||||
|
web_search_choice = web_tools.get("web_search", {})
|
||||||
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
|
if web_search == True:
|
||||||
|
if web_search_enable == True:
|
||||||
|
search_tool = create_web_search_tool({})
|
||||||
|
tools.append(search_tool)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"已添加网络搜索工具",
|
||||||
|
extra={
|
||||||
|
"tool_count": len(tools)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取模型参数
|
||||||
|
model_parameters = config.get("model_parameters", {})
|
||||||
|
|
||||||
|
# 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_obj.model_name,
|
||||||
|
api_key=api_key_obj.api_key,
|
||||||
|
provider=api_key_obj.provider,
|
||||||
|
api_base=api_key_obj.api_base,
|
||||||
|
temperature=model_parameters.get("temperature", 0.7),
|
||||||
|
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
# 加载历史消息
|
||||||
|
history = []
|
||||||
|
memory_config = {"enabled": True, 'max_history': 10}
|
||||||
|
if memory_config.get("enabled"):
|
||||||
|
messages = self.conversation_service.get_messages(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
limit=memory_config.get("max_history", 10)
|
||||||
|
)
|
||||||
|
history = [
|
||||||
|
{"role": msg.role, "content": msg.content}
|
||||||
|
for msg in messages
|
||||||
|
]
|
||||||
|
|
||||||
|
# 调用 Agent
|
||||||
|
result = await agent.chat(
|
||||||
|
message=message,
|
||||||
|
history=history,
|
||||||
|
context=None,
|
||||||
|
end_user_id=user_id,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
config_id=config_id,
|
||||||
|
memory_flag=memory_flag
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存消息
|
||||||
|
self.conversation_service.save_conversation_messages(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
user_message=message,
|
||||||
|
assistant_message=result["content"]
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
return {
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"message": result["content"],
|
||||||
|
"usage": result.get("usage", {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
}),
|
||||||
|
"elapsed_time": elapsed_time
|
||||||
|
}
|
||||||
|
|
||||||
|
async def agnet_chat_stream(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
conversation_id: uuid.UUID,
|
||||||
|
config: AgentConfig,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
|
web_search: bool = False,
|
||||||
|
memory: bool = True,
|
||||||
|
storage_type: Optional[str] = None,
|
||||||
|
user_rag_memory_id: Optional[str] = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""聊天(流式)"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
config_id = None
|
||||||
|
|
||||||
|
if variables is None:
|
||||||
|
variables = {}
|
||||||
|
|
||||||
|
# 获取模型配置ID
|
||||||
|
model_config_id = config.default_model_config_id
|
||||||
|
api_key_obj = ModelApiKeyService.get_a_api_key(model_config_id)
|
||||||
|
# 处理系统提示词(支持变量替换)
|
||||||
|
system_prompt = config.get("system_prompt", "")
|
||||||
|
if variables:
|
||||||
|
system_prompt_rendered = render_prompt_message(
|
||||||
|
system_prompt,
|
||||||
|
PromptMessageRole.USER,
|
||||||
|
variables
|
||||||
|
)
|
||||||
|
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||||
|
|
||||||
|
# 准备工具列表
|
||||||
|
tools = []
|
||||||
|
|
||||||
|
# 添加知识库检索工具
|
||||||
|
knowledge_retrieval = config.get("knowledge_retrieval")
|
||||||
|
if knowledge_retrieval:
|
||||||
|
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||||
|
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||||
|
if kb_ids:
|
||||||
|
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
||||||
|
tools.append(kb_tool)
|
||||||
|
|
||||||
|
# 添加长期记忆工具
|
||||||
|
memory_flag = False
|
||||||
|
if memory:
|
||||||
|
memory_config = config.get("memory", {})
|
||||||
|
if memory_config.get("enabled") and user_id:
|
||||||
|
memory_flag = True
|
||||||
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
|
tools.append(memory_tool)
|
||||||
|
|
||||||
|
web_tools = config.get("tools")
|
||||||
|
web_search_choice = web_tools.get("web_search", {})
|
||||||
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
|
if web_search == True:
|
||||||
|
if web_search_enable == True:
|
||||||
|
search_tool = create_web_search_tool({})
|
||||||
|
tools.append(search_tool)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"已添加网络搜索工具",
|
||||||
|
extra={
|
||||||
|
"tool_count": len(tools)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取模型参数
|
||||||
|
model_parameters = config.get("model_parameters", {})
|
||||||
|
|
||||||
|
# 创建 LangChain Agent
|
||||||
|
agent = LangChainAgent(
|
||||||
|
model_name=api_key_obj.model_name,
|
||||||
|
api_key=api_key_obj.api_key,
|
||||||
|
provider=api_key_obj.provider,
|
||||||
|
api_base=api_key_obj.api_base,
|
||||||
|
temperature=model_parameters.get("temperature", 0.7),
|
||||||
|
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
tools=tools,
|
||||||
|
streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 加载历史消息
|
||||||
|
history = []
|
||||||
|
memory_config = {"enabled": True, 'max_history': 10}
|
||||||
|
if memory_config.get("enabled"):
|
||||||
|
messages = self.conversation_service.get_messages(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
limit=memory_config.get("max_history", 10)
|
||||||
|
)
|
||||||
|
history = [
|
||||||
|
{"role": msg.role, "content": msg.content}
|
||||||
|
for msg in messages
|
||||||
|
]
|
||||||
|
|
||||||
|
# 发送开始事件
|
||||||
|
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
# 流式调用 Agent
|
||||||
|
full_content = ""
|
||||||
|
async for chunk in agent.chat_stream(
|
||||||
|
message=message,
|
||||||
|
history=history,
|
||||||
|
context=None,
|
||||||
|
end_user_id=user_id,
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
config_id=config_id,
|
||||||
|
memory_flag=memory_flag
|
||||||
|
):
|
||||||
|
full_content += chunk
|
||||||
|
# 发送消息块事件
|
||||||
|
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# 保存消息
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="user",
|
||||||
|
content=message
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=full_content,
|
||||||
|
meta_data={
|
||||||
|
"model": api_key_obj.model_name,
|
||||||
|
"usage": {}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 发送结束事件
|
||||||
|
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||||
|
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"流式聊天完成",
|
||||||
|
extra={
|
||||||
|
"conversation_id": str(conversation_id),
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"message_length": len(full_content)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except (GeneratorExit, asyncio.CancelledError):
|
||||||
|
# 生成器被关闭或任务被取消,正常退出
|
||||||
|
logger.debug("流式聊天被中断")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
|
||||||
|
# 发送错误事件
|
||||||
|
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
async def multi_agent_chat(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
conversation_id: uuid.UUID,
|
||||||
|
config: MultiAgentConfig,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
|
web_search: bool = False,
|
||||||
|
memory: bool = True,
|
||||||
|
storage_type: Optional[str] = None,
|
||||||
|
user_rag_memory_id: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""多 Agent 聊天(非流式)"""
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
actual_config_id = None
|
||||||
|
config_id = actual_config_id
|
||||||
|
|
||||||
|
if variables is None:
|
||||||
|
variables = {}
|
||||||
|
|
||||||
|
# 2. 创建编排器
|
||||||
|
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||||
|
|
||||||
|
# 3. 执行任务
|
||||||
|
result = await orchestrator.execute(
|
||||||
|
message=message,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
user_id=user_id,
|
||||||
|
variables=variables,
|
||||||
|
use_llm_routing=True, # 默认启用 LLM 路由
|
||||||
|
web_search=web_search, # 网络搜索参数
|
||||||
|
memory=memory # 记忆功能参数
|
||||||
|
)
|
||||||
|
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# 保存消息
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="user",
|
||||||
|
content=message
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=result.get("message", ""),
|
||||||
|
meta_data={
|
||||||
|
"mode": result.get("mode"),
|
||||||
|
"elapsed_time": result.get("elapsed_time"),
|
||||||
|
"sub_results": result.get("sub_results")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"message": result.get("message", ""),
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0
|
||||||
|
},
|
||||||
|
"elapsed_time": elapsed_time
|
||||||
|
}
|
||||||
|
|
||||||
|
async def multi_agent_chat_stream(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
conversation_id: uuid.UUID,
|
||||||
|
config: MultiAgentConfig,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
|
web_search: bool = False,
|
||||||
|
memory: bool = True,
|
||||||
|
storage_type: Optional[str] = None,
|
||||||
|
user_rag_memory_id: Optional[str] = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""多 Agent 聊天(流式)"""
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
actual_config_id = None
|
||||||
|
config_id = actual_config_id
|
||||||
|
|
||||||
|
if variables is None:
|
||||||
|
variables = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
# 发送开始事件
|
||||||
|
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
full_content = ""
|
||||||
|
|
||||||
|
# 2. 创建编排器
|
||||||
|
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||||
|
|
||||||
|
# 3. 流式执行任务
|
||||||
|
async for event in orchestrator.execute_stream(
|
||||||
|
message=message,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
user_id=user_id,
|
||||||
|
variables=variables,
|
||||||
|
use_llm_routing=True,
|
||||||
|
web_search=web_search, # 网络搜索参数
|
||||||
|
memory=memory, # 记忆功能参数
|
||||||
|
storage_type=storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
# 尝试提取内容(用于保存)
|
||||||
|
if "data:" in event:
|
||||||
|
try:
|
||||||
|
data_line = event.split("data: ", 1)[1].strip()
|
||||||
|
data = json.loads(data_line)
|
||||||
|
if "content" in data:
|
||||||
|
full_content += data["content"]
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# 保存消息
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="user",
|
||||||
|
content=message
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=full_content,
|
||||||
|
meta_data={
|
||||||
|
"elapsed_time": elapsed_time
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"多 Agent 流式聊天完成",
|
||||||
|
extra={
|
||||||
|
"conversation_id": str(conversation_id),
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"message_length": len(full_content)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
except (GeneratorExit, asyncio.CancelledError):
|
||||||
|
# 生成器被关闭或任务被取消,正常退出
|
||||||
|
logger.debug("多 Agent 流式聊天被中断")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"多 Agent 流式聊天失败: {str(e)}", exc_info=True)
|
||||||
|
# 发送错误事件
|
||||||
|
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
|
def get_app_chat_service(
|
||||||
|
db: Annotated[Session, Depends(get_db)]
|
||||||
|
) -> ChatService:
|
||||||
|
"""获取工作流服务(依赖注入)"""
|
||||||
|
return ChatService(db)
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
"""会话服务"""
|
"""会话服务"""
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, List, Tuple
|
from typing import Optional, List, Tuple, Annotated
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import select, desc
|
from sqlalchemy import select, desc
|
||||||
|
|
||||||
|
from app.db import get_db
|
||||||
from app.models import Conversation, Message
|
from app.models import Conversation, Message
|
||||||
from app.core.exceptions import ResourceNotFoundException, BusinessException
|
from app.core.exceptions import ResourceNotFoundException, BusinessException
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
@@ -14,10 +17,10 @@ logger = get_business_logger()
|
|||||||
|
|
||||||
class ConversationService:
|
class ConversationService:
|
||||||
"""会话服务"""
|
"""会话服务"""
|
||||||
|
|
||||||
def __init__(self, db: Session):
|
def __init__(self, db: Session):
|
||||||
self.db = db
|
self.db = db
|
||||||
|
|
||||||
def create_conversation(
|
def create_conversation(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
@@ -36,11 +39,11 @@ class ConversationService:
|
|||||||
is_draft=is_draft,
|
is_draft=is_draft,
|
||||||
config_snapshot=config_snapshot
|
config_snapshot=config_snapshot
|
||||||
)
|
)
|
||||||
|
|
||||||
self.db.add(conversation)
|
self.db.add(conversation)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(conversation)
|
self.db.refresh(conversation)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"创建会话成功",
|
"创建会话成功",
|
||||||
extra={
|
extra={
|
||||||
@@ -50,9 +53,9 @@ class ConversationService:
|
|||||||
"is_draft": is_draft
|
"is_draft": is_draft
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
def get_conversation(
|
def get_conversation(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
@@ -60,17 +63,17 @@ class ConversationService:
|
|||||||
) -> Conversation:
|
) -> Conversation:
|
||||||
"""获取会话"""
|
"""获取会话"""
|
||||||
stmt = select(Conversation).where(Conversation.id == conversation_id)
|
stmt = select(Conversation).where(Conversation.id == conversation_id)
|
||||||
|
|
||||||
if workspace_id:
|
if workspace_id:
|
||||||
stmt = stmt.where(Conversation.workspace_id == workspace_id)
|
stmt = stmt.where(Conversation.workspace_id == workspace_id)
|
||||||
|
|
||||||
conversation = self.db.scalars(stmt).first()
|
conversation = self.db.scalars(stmt).first()
|
||||||
|
|
||||||
if not conversation:
|
if not conversation:
|
||||||
raise ResourceNotFoundException("会话", str(conversation_id))
|
raise ResourceNotFoundException("会话", str(conversation_id))
|
||||||
|
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
def list_conversations(
|
def list_conversations(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
@@ -86,25 +89,25 @@ class ConversationService:
|
|||||||
Conversation.workspace_id == workspace_id,
|
Conversation.workspace_id == workspace_id,
|
||||||
Conversation.is_active == True
|
Conversation.is_active == True
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
stmt = stmt.where(Conversation.user_id == user_id)
|
stmt = stmt.where(Conversation.user_id == user_id)
|
||||||
|
|
||||||
if is_draft is not None:
|
if is_draft is not None:
|
||||||
stmt = stmt.where(Conversation.is_draft == is_draft)
|
stmt = stmt.where(Conversation.is_draft == is_draft)
|
||||||
|
|
||||||
# 总数
|
# 总数
|
||||||
count_stmt = stmt.with_only_columns(Conversation.id)
|
count_stmt = stmt.with_only_columns(Conversation.id)
|
||||||
total = len(self.db.execute(count_stmt).all())
|
total = len(self.db.execute(count_stmt).all())
|
||||||
|
|
||||||
# 分页
|
# 分页
|
||||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||||
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
||||||
|
|
||||||
conversations = list(self.db.scalars(stmt).all())
|
conversations = list(self.db.scalars(stmt).all())
|
||||||
|
|
||||||
return conversations, total
|
return conversations, total
|
||||||
|
|
||||||
def add_message(
|
def add_message(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
@@ -119,22 +122,22 @@ class ConversationService:
|
|||||||
content=content,
|
content=content,
|
||||||
meta_data=meta_data
|
meta_data=meta_data
|
||||||
)
|
)
|
||||||
|
|
||||||
self.db.add(message)
|
self.db.add(message)
|
||||||
|
|
||||||
# 更新会话的消息计数和更新时间
|
# 更新会话的消息计数和更新时间
|
||||||
conversation = self.get_conversation(conversation_id)
|
conversation = self.get_conversation(conversation_id)
|
||||||
conversation.message_count += 1
|
conversation.message_count += 1
|
||||||
|
|
||||||
# 如果是第一条用户消息,可以用它作为标题
|
# 如果是第一条用户消息,可以用它作为标题
|
||||||
if conversation.message_count == 1 and role == "user":
|
if conversation.message_count == 1 and role == "user":
|
||||||
conversation.title = content[:50] + ("..." if len(content) > 50 else "")
|
conversation.title = content[:50] + ("..." if len(content) > 50 else "")
|
||||||
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(message)
|
self.db.refresh(message)
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def get_messages(
|
def get_messages(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
@@ -144,30 +147,30 @@ class ConversationService:
|
|||||||
stmt = select(Message).where(
|
stmt = select(Message).where(
|
||||||
Message.conversation_id == conversation_id
|
Message.conversation_id == conversation_id
|
||||||
).order_by(Message.created_at)
|
).order_by(Message.created_at)
|
||||||
|
|
||||||
if limit:
|
if limit:
|
||||||
stmt = stmt.limit(limit)
|
stmt = stmt.limit(limit)
|
||||||
|
|
||||||
messages = list(self.db.scalars(stmt).all())
|
messages = list(self.db.scalars(stmt).all())
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def get_conversation_history(
|
def get_conversation_history(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
max_history: Optional[int] = None
|
max_history: Optional[int] = None
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""获取会话历史消息
|
"""获取会话历史消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
conversation_id: 会话ID
|
conversation_id: 会话ID
|
||||||
max_history: 最大历史消息数量
|
max_history: 最大历史消息数量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...]
|
List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...]
|
||||||
"""
|
"""
|
||||||
messages = self.get_messages(conversation_id, limit=max_history)
|
messages = self.get_messages(conversation_id, limit=max_history)
|
||||||
|
|
||||||
# 转换为字典格式
|
# 转换为字典格式
|
||||||
history = [
|
history = [
|
||||||
{
|
{
|
||||||
@@ -176,9 +179,9 @@ class ConversationService:
|
|||||||
}
|
}
|
||||||
for msg in messages
|
for msg in messages
|
||||||
]
|
]
|
||||||
|
|
||||||
return history
|
return history
|
||||||
|
|
||||||
def save_conversation_messages(
|
def save_conversation_messages(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
@@ -192,14 +195,14 @@ class ConversationService:
|
|||||||
role="user",
|
role="user",
|
||||||
content=user_message
|
content=user_message
|
||||||
)
|
)
|
||||||
|
|
||||||
# 添加助手消息
|
# 添加助手消息
|
||||||
self.add_message(
|
self.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=assistant_message
|
content=assistant_message
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"保存会话消息成功",
|
"保存会话消息成功",
|
||||||
extra={
|
extra={
|
||||||
@@ -208,7 +211,7 @@ class ConversationService:
|
|||||||
"assistant_message_length": len(assistant_message)
|
"assistant_message_length": len(assistant_message)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_conversation(
|
def delete_conversation(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
@@ -217,9 +220,9 @@ class ConversationService:
|
|||||||
"""删除会话(软删除)"""
|
"""删除会话(软删除)"""
|
||||||
conversation = self.get_conversation(conversation_id, workspace_id)
|
conversation = self.get_conversation(conversation_id, workspace_id)
|
||||||
conversation.is_active = False
|
conversation.is_active = False
|
||||||
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"删除会话成功",
|
"删除会话成功",
|
||||||
extra={
|
extra={
|
||||||
@@ -227,3 +230,53 @@ class ConversationService:
|
|||||||
"workspace_id": str(workspace_id)
|
"workspace_id": str(workspace_id)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_or_get_conversation(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
is_draft: bool = False,
|
||||||
|
conversation_id: Optional[uuid.UUID] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
) -> Conversation:
|
||||||
|
"""创建或获取会话"""
|
||||||
|
|
||||||
|
# 如果提供了 conversation_id,尝试获取现有会话
|
||||||
|
if conversation_id:
|
||||||
|
try:
|
||||||
|
conversation = self.get_conversation(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证会话是否属于该应用
|
||||||
|
if conversation.app_id != app_id:
|
||||||
|
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
|
||||||
|
return conversation
|
||||||
|
except ResourceNotFoundException:
|
||||||
|
logger.warning(
|
||||||
|
"会话不存在,将创建新会话",
|
||||||
|
extra={"conversation_id": str(conversation_id)}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建新会话(使用发布版本的配置)
|
||||||
|
conversation = self.create_conversation(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
user_id=user_id,
|
||||||
|
is_draft=is_draft
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"为分享链接创建新会话"
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversation
|
||||||
|
|
||||||
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
|
def get_conversation_service(
|
||||||
|
db: Annotated[Session, Depends(get_db)]
|
||||||
|
) -> ConversationService:
|
||||||
|
"""获取工作流服务(依赖注入)"""
|
||||||
|
return ConversationService(db)
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class ModelConfigService:
|
|||||||
"""获取模型配置列表"""
|
"""获取模型配置列表"""
|
||||||
models, total = ModelConfigRepository.get_list(db, query, tenant_id=tenant_id)
|
models, total = ModelConfigRepository.get_list(db, query, tenant_id=tenant_id)
|
||||||
pages = math.ceil(total / query.pagesize) if total > 0 else 0
|
pages = math.ceil(total / query.pagesize) if total > 0 else 0
|
||||||
|
|
||||||
return PageData(
|
return PageData(
|
||||||
page=PageMeta(
|
page=PageMeta(
|
||||||
page=query.page,
|
page=query.page,
|
||||||
@@ -72,7 +72,7 @@ class ModelConfigService:
|
|||||||
test_message: str = "Hello"
|
test_message: str = "Hello"
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""验证模型配置是否有效
|
"""验证模型配置是否有效
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
model_name: 模型名称
|
model_name: 模型名称
|
||||||
@@ -81,7 +81,7 @@ class ModelConfigService:
|
|||||||
api_base: API基础URL
|
api_base: API基础URL
|
||||||
model_type: 模型类型 (llm/chat/embedding/rerank)
|
model_type: 模型类型 (llm/chat/embedding/rerank)
|
||||||
test_message: 测试消息
|
test_message: 测试消息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 验证结果
|
Dict: 验证结果
|
||||||
"""
|
"""
|
||||||
@@ -89,10 +89,10 @@ class ModelConfigService:
|
|||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
from app.core.models.embedding import RedBearEmbeddings
|
from app.core.models.embedding import RedBearEmbeddings
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
model_config = RedBearModelConfig(
|
model_config = RedBearModelConfig(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
@@ -101,16 +101,16 @@ class ModelConfigService:
|
|||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
max_tokens=100
|
max_tokens=100
|
||||||
)
|
)
|
||||||
|
|
||||||
# 根据模型类型选择不同的验证方式
|
# 根据模型类型选择不同的验证方式
|
||||||
model_type_lower = model_type.lower()
|
model_type_lower = model_type.lower()
|
||||||
|
|
||||||
if model_type_lower in ["llm", "chat"]:
|
if model_type_lower in ["llm", "chat"]:
|
||||||
# LLM/Chat 模型验证 - 统一使用字符串输入
|
# LLM/Chat 模型验证 - 统一使用字符串输入
|
||||||
llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT)
|
llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT)
|
||||||
response = await llm.ainvoke(test_message)
|
response = await llm.ainvoke(test_message)
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
content = response.content if hasattr(response, 'content') else str(response)
|
content = response.content if hasattr(response, 'content') else str(response)
|
||||||
usage = None
|
usage = None
|
||||||
if hasattr(response, 'usage_metadata'):
|
if hasattr(response, 'usage_metadata'):
|
||||||
@@ -119,7 +119,7 @@ class ModelConfigService:
|
|||||||
"output_tokens": getattr(response.usage_metadata, 'output_tokens', 0),
|
"output_tokens": getattr(response.usage_metadata, 'output_tokens', 0),
|
||||||
"total_tokens": getattr(response.usage_metadata, 'total_tokens', 0)
|
"total_tokens": getattr(response.usage_metadata, 'total_tokens', 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"valid": True,
|
"valid": True,
|
||||||
"message": f"{model_type.upper()} 模型配置验证成功",
|
"message": f"{model_type.upper()} 模型配置验证成功",
|
||||||
@@ -128,14 +128,14 @@ class ModelConfigService:
|
|||||||
"usage": usage,
|
"usage": usage,
|
||||||
"error": None
|
"error": None
|
||||||
}
|
}
|
||||||
|
|
||||||
elif model_type_lower == "embedding":
|
elif model_type_lower == "embedding":
|
||||||
# Embedding 模型验证(在线程中运行同步方法)
|
# Embedding 模型验证(在线程中运行同步方法)
|
||||||
embedding = RedBearEmbeddings(model_config)
|
embedding = RedBearEmbeddings(model_config)
|
||||||
test_texts = [test_message, "测试文本"]
|
test_texts = [test_message, "测试文本"]
|
||||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"valid": True,
|
"valid": True,
|
||||||
"message": "Embedding 模型配置验证成功",
|
"message": "Embedding 模型配置验证成功",
|
||||||
@@ -148,7 +148,7 @@ class ModelConfigService:
|
|||||||
},
|
},
|
||||||
"error": None
|
"error": None
|
||||||
}
|
}
|
||||||
|
|
||||||
elif model_type_lower == "rerank":
|
elif model_type_lower == "rerank":
|
||||||
# Rerank 模型验证(在线程中运行同步方法)
|
# Rerank 模型验证(在线程中运行同步方法)
|
||||||
rerank = RedBearRerank(model_config)
|
rerank = RedBearRerank(model_config)
|
||||||
@@ -156,7 +156,7 @@ class ModelConfigService:
|
|||||||
documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"]
|
documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"]
|
||||||
results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3)
|
results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3)
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"valid": True,
|
"valid": True,
|
||||||
"message": "Rerank 模型配置验证成功",
|
"message": "Rerank 模型配置验证成功",
|
||||||
@@ -169,7 +169,7 @@ class ModelConfigService:
|
|||||||
},
|
},
|
||||||
"error": None
|
"error": None
|
||||||
}
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"valid": False,
|
"valid": False,
|
||||||
@@ -179,7 +179,7 @@ class ModelConfigService:
|
|||||||
"usage": None,
|
"usage": None,
|
||||||
"error": f"不支持的模型类型: {model_type}"
|
"error": f"不支持的模型类型: {model_type}"
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 提取详细的错误信息
|
# 提取详细的错误信息
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
@@ -203,12 +203,12 @@ class ModelConfigService:
|
|||||||
error_message = f"无效请求: {error_message}"
|
error_message = f"无效请求: {error_message}"
|
||||||
elif "model_copy" in error_message:
|
elif "model_copy" in error_message:
|
||||||
error_message = "模型消息格式错误: 请确保使用正确的模型类型(LLM/Chat)"
|
error_message = "模型消息格式错误: 请确保使用正确的模型类型(LLM/Chat)"
|
||||||
|
|
||||||
# 记录详细错误日志
|
# 记录详细错误日志
|
||||||
logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}")
|
logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}")
|
||||||
logger.error(f"错误详情: {error_message}")
|
logger.error(f"错误详情: {error_message}")
|
||||||
logger.debug(f"完整堆栈: {traceback.format_exc()}")
|
logger.debug(f"完整堆栈: {traceback.format_exc()}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"valid": False,
|
"valid": False,
|
||||||
"message": f"{model_type.upper()} 模型配置验证失败",
|
"message": f"{model_type.upper()} 模型配置验证失败",
|
||||||
@@ -249,7 +249,7 @@ class ModelConfigService:
|
|||||||
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
|
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
|
||||||
# 添加租户ID
|
# 添加租户ID
|
||||||
model_config_data["tenant_id"] = tenant_id
|
model_config_data["tenant_id"] = tenant_id
|
||||||
|
|
||||||
model = ModelConfigRepository.create(db, model_config_data)
|
model = ModelConfigRepository.create(db, model_config_data)
|
||||||
db.flush() # 获取生成的 ID
|
db.flush() # 获取生成的 ID
|
||||||
|
|
||||||
@@ -259,7 +259,7 @@ class ModelConfigService:
|
|||||||
**api_key_data.dict()
|
**api_key_data.dict()
|
||||||
)
|
)
|
||||||
ModelApiKeyRepository.create(db, api_key_create_schema)
|
ModelApiKeyRepository.create(db, api_key_create_schema)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(model)
|
db.refresh(model)
|
||||||
return model
|
return model
|
||||||
@@ -270,11 +270,11 @@ class ModelConfigService:
|
|||||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||||
if not existing_model:
|
if not existing_model:
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
if model_data.name and model_data.name != existing_model.name:
|
if model_data.name and model_data.name != existing_model.name:
|
||||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(model)
|
db.refresh(model)
|
||||||
@@ -285,7 +285,7 @@ class ModelConfigService:
|
|||||||
"""删除模型配置"""
|
"""删除模型配置"""
|
||||||
if not ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id):
|
if not ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id):
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
success = ModelConfigRepository.delete(db, model_id, tenant_id=tenant_id)
|
success = ModelConfigRepository.delete(db, model_id, tenant_id=tenant_id)
|
||||||
db.commit()
|
db.commit()
|
||||||
return success
|
return success
|
||||||
@@ -316,20 +316,20 @@ class ModelApiKeyService:
|
|||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> List[ModelApiKey]:
|
def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> list[ModelApiKey]:
|
||||||
"""根据模型配置ID获取API Key列表"""
|
"""根据模型配置ID获取API Key列表"""
|
||||||
if not ModelConfigRepository.get_by_id(db, model_config_id):
|
if not ModelConfigRepository.get_by_id(db, model_config_id):
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
|
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
||||||
"""创建API Key"""
|
"""创建API Key"""
|
||||||
model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id)
|
model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
validation_result = await ModelConfigService.validate_model_config(
|
validation_result = await ModelConfigService.validate_model_config(
|
||||||
db=db,
|
db=db,
|
||||||
model_name=api_key_data.model_name,
|
model_name=api_key_data.model_name,
|
||||||
@@ -345,7 +345,7 @@ class ModelApiKeyService:
|
|||||||
f"模型配置验证失败: {validation_result['error']}",
|
f"模型配置验证失败: {validation_result['error']}",
|
||||||
BizCode.INVALID_PARAMETER
|
BizCode.INVALID_PARAMETER
|
||||||
)
|
)
|
||||||
|
|
||||||
api_key = ModelApiKeyRepository.create(db, api_key_data)
|
api_key = ModelApiKeyRepository.create(db, api_key_data)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(api_key)
|
db.refresh(api_key)
|
||||||
@@ -357,12 +357,12 @@ class ModelApiKeyService:
|
|||||||
existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||||
if not existing_api_key:
|
if not existing_api_key:
|
||||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
# 获取关联的模型配置以获取模型类型
|
# 获取关联的模型配置以获取模型类型
|
||||||
model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id)
|
model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
validation_result = await ModelConfigService.validate_model_config(
|
validation_result = await ModelConfigService.validate_model_config(
|
||||||
db=db,
|
db=db,
|
||||||
model_name=api_key_data.model_name,
|
model_name=api_key_data.model_name,
|
||||||
@@ -378,7 +378,7 @@ class ModelApiKeyService:
|
|||||||
f"模型配置验证失败: {validation_result['error']}",
|
f"模型配置验证失败: {validation_result['error']}",
|
||||||
BizCode.INVALID_PARAMETER
|
BizCode.INVALID_PARAMETER
|
||||||
)
|
)
|
||||||
|
|
||||||
api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data)
|
api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(api_key)
|
db.refresh(api_key)
|
||||||
@@ -389,7 +389,7 @@ class ModelApiKeyService:
|
|||||||
"""删除API Key"""
|
"""删除API Key"""
|
||||||
if not ModelApiKeyRepository.get_by_id(db, api_key_id):
|
if not ModelApiKeyRepository.get_by_id(db, api_key_id):
|
||||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
success = ModelApiKeyRepository.delete(db, api_key_id)
|
success = ModelApiKeyRepository.delete(db, api_key_id)
|
||||||
db.commit()
|
db.commit()
|
||||||
return success
|
return success
|
||||||
@@ -409,3 +409,11 @@ class ModelApiKeyService:
|
|||||||
if success:
|
if success:
|
||||||
db.commit()
|
db.commit()
|
||||||
return success
|
return success
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey:
|
||||||
|
|
||||||
|
api_kes = ModelApiKeyService.get_api_keys_by_model(db, model_config_id)
|
||||||
|
if api_kes and len(api_kes) > 0:
|
||||||
|
return api_kes[0]
|
||||||
|
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|||||||
205
api/app/services/order_service.py
Normal file
205
api/app/services/order_service.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""
|
||||||
|
Order Service
|
||||||
|
|
||||||
|
Handles order operations including forwarding requests to external APIs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import httpx
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from app.schemas.order_schema import CreateOrderRequest
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OrderService:
|
||||||
|
"""Order service for handling order operations"""
|
||||||
|
|
||||||
|
def __init__(self, external_api_url: Optional[str] = None, api_key: Optional[str] = None):
|
||||||
|
"""Initialize order service
|
||||||
|
|
||||||
|
Args:
|
||||||
|
external_api_url: External API base URL
|
||||||
|
api_key: API key for authentication
|
||||||
|
"""
|
||||||
|
# Default external API URL (replace with actual URL)
|
||||||
|
self.external_api_url = external_api_url or "https://api.example.com/v1"
|
||||||
|
self.api_key = api_key
|
||||||
|
self.timeout = 30.0 # 30 seconds timeout
|
||||||
|
|
||||||
|
async def create_order(
|
||||||
|
self,
|
||||||
|
order_data: CreateOrderRequest,
|
||||||
|
user_id: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Create order by forwarding request to external API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order_data: Order creation data
|
||||||
|
user_id: Current user ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Order response data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
httpx.HTTPError: If external API request fails
|
||||||
|
Exception: For other errors
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Prepare request payload
|
||||||
|
payload = {
|
||||||
|
"product_id": order_data.product_id,
|
||||||
|
"quantity": order_data.quantity,
|
||||||
|
"customer_name": order_data.customer_name,
|
||||||
|
"customer_email": order_data.customer_email,
|
||||||
|
"notes": order_data.notes,
|
||||||
|
"user_id": user_id # Include user ID for tracking
|
||||||
|
}
|
||||||
|
|
||||||
|
# Prepare headers
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"User-Agent": "MemoryBear-OrderService/1.0"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add API key if configured
|
||||||
|
if self.api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
|
|
||||||
|
logger.info(f"Forwarding order creation request to external API: {self.external_api_url}/orders")
|
||||||
|
logger.debug(f"Request payload: {payload}")
|
||||||
|
|
||||||
|
# Make async HTTP request to external API
|
||||||
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{self.external_api_url}/orders",
|
||||||
|
json=payload,
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log response status
|
||||||
|
logger.info(f"External API response status: {response.status_code}")
|
||||||
|
|
||||||
|
# Raise exception for 4xx/5xx status codes
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
response_data = response.json()
|
||||||
|
logger.debug(f"External API response data: {response_data}")
|
||||||
|
|
||||||
|
# Transform external API response to internal format
|
||||||
|
return self._transform_external_response(response_data)
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.error(f"External API returned error status: {e.response.status_code}")
|
||||||
|
logger.error(f"Error response: {e.response.text}")
|
||||||
|
|
||||||
|
# Try to parse error response
|
||||||
|
try:
|
||||||
|
error_data = e.response.json()
|
||||||
|
error_message = error_data.get("message") or error_data.get("error") or "External API error"
|
||||||
|
except Exception:
|
||||||
|
error_message = f"External API error: {e.response.status_code}"
|
||||||
|
|
||||||
|
raise Exception(f"Failed to create order: {error_message}")
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.error(f"External API request timeout after {self.timeout}s")
|
||||||
|
raise Exception("Order creation timeout - external service not responding")
|
||||||
|
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"External API request failed: {str(e)}")
|
||||||
|
raise Exception(f"Failed to connect to external order service: {str(e)}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error during order creation: {str(e)}", exc_info=True)
|
||||||
|
raise Exception(f"Order creation failed: {str(e)}")
|
||||||
|
|
||||||
|
def _transform_external_response(self, external_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Transform external API response to internal format
|
||||||
|
|
||||||
|
Args:
|
||||||
|
external_data: Response data from external API
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed response data
|
||||||
|
"""
|
||||||
|
# Handle different response formats from external API
|
||||||
|
# Adjust this based on actual external API response structure
|
||||||
|
|
||||||
|
if "data" in external_data:
|
||||||
|
# Format 1: {"success": true, "data": {...}}
|
||||||
|
data = external_data["data"]
|
||||||
|
elif "order" in external_data:
|
||||||
|
# Format 2: {"order": {...}}
|
||||||
|
data = external_data["order"]
|
||||||
|
else:
|
||||||
|
# Format 3: Direct response
|
||||||
|
data = external_data
|
||||||
|
|
||||||
|
# Extract fields with fallbacks
|
||||||
|
return {
|
||||||
|
"order_id": data.get("order_id") or data.get("id") or "UNKNOWN",
|
||||||
|
"status": data.get("status") or "pending",
|
||||||
|
"product_id": data.get("product_id") or "",
|
||||||
|
"quantity": data.get("quantity") or 0,
|
||||||
|
"total_amount": data.get("total_amount") or data.get("amount"),
|
||||||
|
"created_at": data.get("created_at") or data.get("timestamp"),
|
||||||
|
"message": external_data.get("message") or "Order created successfully"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_order(self, order_id: str) -> Dict[str, Any]:
|
||||||
|
"""Get order details from external API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
order_id: Order ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Order details
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if self.api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||||
|
|
||||||
|
logger.info(f"Fetching order {order_id} from external API")
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{self.external_api_url}/orders/{order_id}",
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to fetch order {order_id}: {str(e)}")
|
||||||
|
raise Exception(f"Failed to fetch order: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
_order_service_instance: Optional[OrderService] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_order_service(
|
||||||
|
external_api_url: Optional[str] = None,
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
) -> OrderService:
|
||||||
|
"""Get order service instance
|
||||||
|
|
||||||
|
Args:
|
||||||
|
external_api_url: External API URL (optional, uses default if not provided)
|
||||||
|
api_key: API key (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OrderService instance
|
||||||
|
"""
|
||||||
|
global _order_service_instance
|
||||||
|
|
||||||
|
if _order_service_instance is None:
|
||||||
|
_order_service_instance = OrderService(
|
||||||
|
external_api_url=external_api_url,
|
||||||
|
api_key=api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
return _order_service_instance
|
||||||
@@ -1,35 +1,32 @@
|
|||||||
from sqlalchemy.orm import Session
|
|
||||||
from typing import List, Optional
|
|
||||||
import uuid
|
|
||||||
import secrets
|
|
||||||
import hashlib
|
|
||||||
import datetime
|
import datetime
|
||||||
from fastapi import HTTPException, status
|
import hashlib
|
||||||
|
import secrets
|
||||||
|
import uuid
|
||||||
|
from os import getenv
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||||
from app.models.tenant_model import Tenants
|
from app.core.logging_config import get_business_logger
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.models.app_model import App
|
from app.models.workspace_model import Workspace, WorkspaceRole, InviteStatus, WorkspaceMember
|
||||||
from app.models.end_user_model import EndUser
|
from app.repositories import workspace_repository
|
||||||
from app.models.workspace_model import Workspace, WorkspaceRole, WorkspaceInvite, InviteStatus, WorkspaceMember
|
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||||
from app.schemas.workspace_schema import (
|
from app.schemas.workspace_schema import (
|
||||||
WorkspaceCreate,
|
WorkspaceCreate,
|
||||||
WorkspaceUpdate,
|
WorkspaceUpdate,
|
||||||
WorkspaceInviteCreate,
|
WorkspaceInviteCreate,
|
||||||
WorkspaceInviteResponse,
|
WorkspaceInviteResponse,
|
||||||
InviteValidateResponse,
|
InviteValidateResponse,
|
||||||
InviteAcceptRequest,
|
InviteAcceptRequest,
|
||||||
WorkspaceMemberUpdate
|
WorkspaceMemberUpdate
|
||||||
)
|
)
|
||||||
from app.repositories import workspace_repository
|
|
||||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.core.config import settings
|
|
||||||
from app.services import user_service
|
|
||||||
from os import getenv
|
|
||||||
# 获取业务逻辑专用日志器
|
# 获取业务逻辑专用日志器
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
import os #
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
def switch_workspace(
|
def switch_workspace(
|
||||||
@@ -39,10 +36,10 @@ def switch_workspace(
|
|||||||
):
|
):
|
||||||
"""切换工作空间"""
|
"""切换工作空间"""
|
||||||
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
|
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
|
||||||
|
|
||||||
# 检查用户是否为成员或超级管理员
|
# 检查用户是否为成员或超级管理员
|
||||||
_check_workspace_member_permission(db, workspace_id, user)
|
_check_workspace_member_permission(db, workspace_id, user)
|
||||||
|
|
||||||
# 更新当前用户的工作空间上下文
|
# 更新当前用户的工作空间上下文
|
||||||
try:
|
try:
|
||||||
user.current_workspace_id = workspace_id
|
user.current_workspace_id = workspace_id
|
||||||
@@ -63,22 +60,22 @@ def delete_workspace_member(
|
|||||||
):
|
):
|
||||||
"""删除工作空间成员"""
|
"""删除工作空间成员"""
|
||||||
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
_check_workspace_admin_permission(db, workspace_id, user)
|
_check_workspace_admin_permission(db, workspace_id, user)
|
||||||
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
||||||
if not workspace_member:
|
if not workspace_member:
|
||||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||||
|
|
||||||
if workspace_member.workspace_id != workspace_id:
|
if workspace_member.workspace_id != workspace_id:
|
||||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
workspace_member.is_active = False
|
workspace_member.is_active = False
|
||||||
workspace_member.user.current_workspace_id = None
|
workspace_member.user.current_workspace_id = None
|
||||||
db.commit()
|
db.commit()
|
||||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||||
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||||
|
|
||||||
|
|
||||||
@@ -94,7 +91,7 @@ def _create_workspace_only(
|
|||||||
db: Session, workspace: WorkspaceCreate, owner: User
|
db: Session, workspace: WorkspaceCreate, owner: User
|
||||||
) -> Workspace:
|
) -> Workspace:
|
||||||
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
|
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Create the workspace without adding any members
|
# Create the workspace without adding any members
|
||||||
business_logger.debug(f"创建工作空间: {workspace.name}")
|
business_logger.debug(f"创建工作空间: {workspace.name}")
|
||||||
@@ -126,7 +123,7 @@ def create_workspace(
|
|||||||
business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}")
|
business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}")
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_workspace)
|
db.refresh(db_workspace)
|
||||||
|
|
||||||
# 如果 storage_type 是 "rag",自动创建知识库
|
# 如果 storage_type 是 "rag",自动创建知识库
|
||||||
if workspace.storage_type == "rag":
|
if workspace.storage_type == "rag":
|
||||||
business_logger.info(
|
business_logger.info(
|
||||||
@@ -138,7 +135,7 @@ def create_workspace(
|
|||||||
from app.schemas.knowledge_schema import KnowledgeCreate
|
from app.schemas.knowledge_schema import KnowledgeCreate
|
||||||
from app.models.knowledge_model import KnowledgeType, PermissionType
|
from app.models.knowledge_model import KnowledgeType, PermissionType
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
|
|
||||||
# 创建知识库数据
|
# 创建知识库数据
|
||||||
knowledge_data = KnowledgeCreate(
|
knowledge_data = KnowledgeCreate(
|
||||||
workspace_id=db_workspace.id,
|
workspace_id=db_workspace.id,
|
||||||
@@ -162,10 +159,10 @@ def create_workspace(
|
|||||||
"html4excel": False
|
"html4excel": False
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 直接使用 repository 创建知识库,避免 service 层的额外逻辑
|
# 直接使用 repository 创建知识库,避免 service 层的额外逻辑
|
||||||
db_knowledge = knowledge_repository.create_knowledge(
|
db_knowledge = knowledge_repository.create_knowledge(
|
||||||
db=db,
|
db=db,
|
||||||
knowledge=knowledge_data
|
knowledge=knowledge_data
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -179,12 +176,12 @@ def create_workspace(
|
|||||||
)
|
)
|
||||||
db.rollback()
|
db.rollback()
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
f"工作空间创建成功,但知识库创建失败: {str(kb_error)}",
|
f"工作空间创建成功,但知识库创建失败: {str(kb_error)}",
|
||||||
BizCode.INTERNAL_ERROR
|
BizCode.INTERNAL_ERROR
|
||||||
)
|
)
|
||||||
|
|
||||||
return db_workspace
|
return db_workspace
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
business_logger.error(f"工作空间创建失败: {workspace.name} - {str(e)}")
|
business_logger.error(f"工作空间创建失败: {workspace.name} - {str(e)}")
|
||||||
db.rollback()
|
db.rollback()
|
||||||
@@ -195,7 +192,7 @@ def update_workspace(
|
|||||||
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
|
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
|
||||||
) -> Workspace:
|
) -> Workspace:
|
||||||
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
|
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||||
|
|
||||||
db_workspace = _check_workspace_admin_permission(db,workspace_id,user)
|
db_workspace = _check_workspace_admin_permission(db,workspace_id,user)
|
||||||
try:
|
try:
|
||||||
# 更新工作空间
|
# 更新工作空间
|
||||||
@@ -219,8 +216,8 @@ def get_workspace_members(
|
|||||||
db: Session, workspace_id: uuid.UUID, user: User
|
db: Session, workspace_id: uuid.UUID, user: User
|
||||||
) -> List[WorkspaceMember]:
|
) -> List[WorkspaceMember]:
|
||||||
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
|
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
|
||||||
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
|
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||||
|
|
||||||
# 查找工作空间
|
# 查找工作空间
|
||||||
business_logger.debug(f"查找工作空间: {workspace_id}")
|
business_logger.debug(f"查找工作空间: {workspace_id}")
|
||||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||||
@@ -237,10 +234,10 @@ def get_workspace_members(
|
|||||||
db=db, user_id=user.id, workspace_id=workspace_id
|
db=db, user_id=user.id, workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
workspace_memberships = {workspace_id} if member else set()
|
workspace_memberships = {workspace_id} if member else set()
|
||||||
|
|
||||||
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
||||||
resource = Resource.from_workspace(workspace)
|
resource = Resource.from_workspace(workspace)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
permission_service.require_permission(
|
permission_service.require_permission(
|
||||||
subject,
|
subject,
|
||||||
@@ -265,7 +262,7 @@ def get_workspace_members(
|
|||||||
|
|
||||||
def _generate_invite_token() -> tuple[str, str]:
|
def _generate_invite_token() -> tuple[str, str]:
|
||||||
"""生成邀请令牌和其哈希值
|
"""生成邀请令牌和其哈希值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (原始令牌, 令牌哈希)
|
tuple: (原始令牌, 令牌哈希)
|
||||||
"""
|
"""
|
||||||
@@ -285,21 +282,21 @@ def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, use
|
|||||||
message="Workspace not found",
|
message="Workspace not found",
|
||||||
code=BizCode.WORKSPACE_NOT_FOUND
|
code=BizCode.WORKSPACE_NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用统一权限服务检查访问权限
|
# 使用统一权限服务检查访问权限
|
||||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||||
|
|
||||||
# 获取用户的工作空间成员关系
|
# 获取用户的工作空间成员关系
|
||||||
member = workspace_repository.get_member_in_workspace(
|
member = workspace_repository.get_member_in_workspace(
|
||||||
db=db, user_id=user.id, workspace_id=workspace_id
|
db=db, user_id=user.id, workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 任何成员都有访问权限
|
# 任何成员都有访问权限
|
||||||
workspace_memberships = {workspace_id} if member else set()
|
workspace_memberships = {workspace_id} if member else set()
|
||||||
|
|
||||||
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
||||||
resource = Resource.from_workspace(db_workspace)
|
resource = Resource.from_workspace(db_workspace)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
permission_service.require_permission(
|
permission_service.require_permission(
|
||||||
subject,
|
subject,
|
||||||
@@ -323,21 +320,21 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
|
|||||||
message="Workspace not found",
|
message="Workspace not found",
|
||||||
code=BizCode.WORKSPACE_NOT_FOUND
|
code=BizCode.WORKSPACE_NOT_FOUND
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用统一权限服务检查管理权限
|
# 使用统一权限服务检查管理权限
|
||||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||||
|
|
||||||
# 获取用户的工作空间成员关系
|
# 获取用户的工作空间成员关系
|
||||||
member = workspace_repository.get_member_in_workspace(
|
member = workspace_repository.get_member_in_workspace(
|
||||||
db=db, user_id=user.id, workspace_id=workspace_id
|
db=db, user_id=user.id, workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 只有 manager 才有管理权限
|
# 只有 manager 才有管理权限
|
||||||
workspace_memberships = {workspace_id} if (member and member.role == WorkspaceRole.manager) else set()
|
workspace_memberships = {workspace_id} if (member and member.role == WorkspaceRole.manager) else set()
|
||||||
|
|
||||||
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
||||||
resource = Resource.from_workspace(db_workspace)
|
resource = Resource.from_workspace(db_workspace)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
permission_service.require_permission(
|
permission_service.require_permission(
|
||||||
subject,
|
subject,
|
||||||
@@ -353,14 +350,14 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
|
|||||||
|
|
||||||
|
|
||||||
def create_workspace_invite(
|
def create_workspace_invite(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
invite_data: WorkspaceInviteCreate,
|
invite_data: WorkspaceInviteCreate,
|
||||||
user: User
|
user: User
|
||||||
) -> WorkspaceInviteResponse:
|
) -> WorkspaceInviteResponse:
|
||||||
"""创建工作空间邀请"""
|
"""创建工作空间邀请"""
|
||||||
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
|
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查权限
|
# 检查权限
|
||||||
_check_workspace_admin_permission(db, workspace_id, user)
|
_check_workspace_admin_permission(db, workspace_id, user)
|
||||||
@@ -368,7 +365,7 @@ def create_workspace_invite(
|
|||||||
# 检查被邀请用户是否已经在工作空间中
|
# 检查被邀请用户是否已经在工作空间中
|
||||||
from app.repositories import user_repository
|
from app.repositories import user_repository
|
||||||
invited_user = user_repository.get_user_by_email(db, invite_data.email)
|
invited_user = user_repository.get_user_by_email(db, invite_data.email)
|
||||||
|
|
||||||
if invited_user:
|
if invited_user:
|
||||||
# 用户存在,检查是否已经是工作空间成员
|
# 用户存在,检查是否已经是工作空间成员
|
||||||
existing_member = workspace_repository.get_member_in_workspace(
|
existing_member = workspace_repository.get_member_in_workspace(
|
||||||
@@ -379,14 +376,14 @@ def create_workspace_invite(
|
|||||||
if existing_member:
|
if existing_member:
|
||||||
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
|
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
|
||||||
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
|
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
|
||||||
|
|
||||||
# 检查是否已有待处理的邀请
|
# 检查是否已有待处理的邀请
|
||||||
invite_repo = WorkspaceInviteRepository(db)
|
invite_repo = WorkspaceInviteRepository(db)
|
||||||
existing_invite = invite_repo.get_pending_invite_by_email_and_workspace(
|
existing_invite = invite_repo.get_pending_invite_by_email_and_workspace(
|
||||||
email=invite_data.email,
|
email=invite_data.email,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
invite_token = None
|
invite_token = None
|
||||||
if existing_invite:
|
if existing_invite:
|
||||||
business_logger.info(f"邮箱 {invite_data.email} 在工作空间 {workspace_id} 已有待处理邀请,返回现有邀请")
|
business_logger.info(f"邮箱 {invite_data.email} 在工作空间 {workspace_id} 已有待处理邀请,返回现有邀请")
|
||||||
@@ -409,17 +406,17 @@ def create_workspace_invite(
|
|||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_invite)
|
db.refresh(db_invite)
|
||||||
invite_token = token
|
invite_token = token
|
||||||
|
|
||||||
invite_obj = existing_invite or db_invite
|
invite_obj = existing_invite or db_invite
|
||||||
business_logger.info(f"工作空间邀请创建成功: invite_id={invite_obj.id}, email={invite_data.email}")
|
business_logger.info(f"工作空间邀请创建成功: invite_id={invite_obj.id}, email={invite_data.email}")
|
||||||
|
|
||||||
# 构造响应
|
# 构造响应
|
||||||
response = WorkspaceInviteResponse.model_validate(invite_obj)
|
response = WorkspaceInviteResponse.model_validate(invite_obj)
|
||||||
response.invite_token = invite_token
|
response.invite_token = invite_token
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
|
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
|
||||||
@@ -427,8 +424,8 @@ def create_workspace_invite(
|
|||||||
|
|
||||||
|
|
||||||
def get_workspace_invites(
|
def get_workspace_invites(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
user: User,
|
user: User,
|
||||||
status: Optional[InviteStatus] = None,
|
status: Optional[InviteStatus] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
@@ -436,15 +433,15 @@ def get_workspace_invites(
|
|||||||
) -> List[WorkspaceInviteResponse]:
|
) -> List[WorkspaceInviteResponse]:
|
||||||
"""获取工作空间邀请列表"""
|
"""获取工作空间邀请列表"""
|
||||||
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
|
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||||
|
|
||||||
# 检查工作空间是否存在
|
# 检查工作空间是否存在
|
||||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||||
if not workspace:
|
if not workspace:
|
||||||
raise BusinessException("工作空间不存在", BizCode.WORKSPACE_NOT_FOUND)
|
raise BusinessException("工作空间不存在", BizCode.WORKSPACE_NOT_FOUND)
|
||||||
|
|
||||||
# 检查权限
|
# 检查权限
|
||||||
_check_workspace_admin_permission(db, workspace_id, user)
|
_check_workspace_admin_permission(db, workspace_id, user)
|
||||||
|
|
||||||
# 获取邀请列表
|
# 获取邀请列表
|
||||||
invite_repo = WorkspaceInviteRepository(db)
|
invite_repo = WorkspaceInviteRepository(db)
|
||||||
invites = invite_repo.get_workspace_invites(
|
invites = invite_repo.get_workspace_invites(
|
||||||
@@ -453,35 +450,35 @@ def get_workspace_invites(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
offset=offset
|
offset=offset
|
||||||
)
|
)
|
||||||
|
|
||||||
return [WorkspaceInviteResponse.model_validate(invite) for invite in invites]
|
return [WorkspaceInviteResponse.model_validate(invite) for invite in invites]
|
||||||
|
|
||||||
|
|
||||||
def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
||||||
"""验证邀请令牌"""
|
"""验证邀请令牌"""
|
||||||
business_logger.info("验证邀请令牌")
|
business_logger.info("验证邀请令牌")
|
||||||
|
|
||||||
# 生成令牌哈希
|
# 生成令牌哈希
|
||||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||||
|
|
||||||
# 查找邀请
|
# 查找邀请
|
||||||
invite_repo = WorkspaceInviteRepository(db)
|
invite_repo = WorkspaceInviteRepository(db)
|
||||||
invite = invite_repo.get_invite_by_token_hash(token_hash)
|
invite = invite_repo.get_invite_by_token_hash(token_hash)
|
||||||
|
|
||||||
if not invite:
|
if not invite:
|
||||||
business_logger.warning("邀请令牌无效")
|
business_logger.warning("邀请令牌无效")
|
||||||
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||||
|
|
||||||
# 检查邀请状态和过期时间
|
# 检查邀请状态和过期时间
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
is_expired = invite.expires_at < now or invite.status != InviteStatus.pending
|
is_expired = invite.expires_at < now or invite.status != InviteStatus.pending
|
||||||
is_valid = not is_expired
|
is_valid = not is_expired
|
||||||
|
|
||||||
# 获取工作空间信息
|
# 获取工作空间信息
|
||||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
||||||
|
|
||||||
business_logger.info(f"邀请令牌验证完成: valid={is_valid}, expired={is_expired}")
|
business_logger.info(f"邀请令牌验证完成: valid={is_valid}, expired={is_expired}")
|
||||||
|
|
||||||
return InviteValidateResponse(
|
return InviteValidateResponse(
|
||||||
workspace_name=workspace.name,
|
workspace_name=workspace.name,
|
||||||
workspace_id=invite.workspace_id,
|
workspace_id=invite.workspace_id,
|
||||||
@@ -493,32 +490,32 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
|||||||
|
|
||||||
|
|
||||||
def accept_workspace_invite(
|
def accept_workspace_invite(
|
||||||
db: Session,
|
db: Session,
|
||||||
accept_request: InviteAcceptRequest,
|
accept_request: InviteAcceptRequest,
|
||||||
user: User
|
user: User
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""接受工作空间邀请"""
|
"""接受工作空间邀请"""
|
||||||
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
|
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
# 生成令牌哈希
|
# 生成令牌哈希
|
||||||
token_hash = hashlib.sha256(accept_request.token.encode()).hexdigest()
|
token_hash = hashlib.sha256(accept_request.token.encode()).hexdigest()
|
||||||
|
|
||||||
# 查找邀请
|
# 查找邀请
|
||||||
invite_repo = WorkspaceInviteRepository(db)
|
invite_repo = WorkspaceInviteRepository(db)
|
||||||
invite = invite_repo.get_invite_by_token_hash(token_hash)
|
invite = invite_repo.get_invite_by_token_hash(token_hash)
|
||||||
|
|
||||||
if not invite:
|
if not invite:
|
||||||
business_logger.warning("邀请令牌无效")
|
business_logger.warning("邀请令牌无效")
|
||||||
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||||
|
|
||||||
# 检查邀请状态
|
# 检查邀请状态
|
||||||
if invite.status != InviteStatus.pending:
|
if invite.status != InviteStatus.pending:
|
||||||
business_logger.warning(f"邀请已被处理: status={invite.status}")
|
business_logger.warning(f"邀请已被处理: status={invite.status}")
|
||||||
raise BusinessException(f"邀请已被{invite.status}", BizCode.WORKSPACE_INVITE_INVALID)
|
raise BusinessException(f"邀请已被{invite.status}", BizCode.WORKSPACE_INVITE_INVALID)
|
||||||
|
|
||||||
# 检查过期时间
|
# 检查过期时间
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
if invite.expires_at < now:
|
if invite.expires_at < now:
|
||||||
@@ -526,31 +523,31 @@ def accept_workspace_invite(
|
|||||||
# 标记为过期
|
# 标记为过期
|
||||||
invite_repo.update_invite_status(invite.id, InviteStatus.expired)
|
invite_repo.update_invite_status(invite.id, InviteStatus.expired)
|
||||||
raise BusinessException("邀请已过期", BizCode.WORKSPACE_INVITE_EXPIRED)
|
raise BusinessException("邀请已过期", BizCode.WORKSPACE_INVITE_EXPIRED)
|
||||||
|
|
||||||
# 检查邮箱是否匹配
|
# 检查邮箱是否匹配
|
||||||
if invite.email != user.email:
|
if invite.email != user.email:
|
||||||
business_logger.warning(f"邮箱不匹配: invite_email={invite.email}, user_email={user.email}")
|
business_logger.warning(f"邮箱不匹配: invite_email={invite.email}, user_email={user.email}")
|
||||||
raise BusinessException("邮箱与邀请邮箱不匹配", BizCode.FORBIDDEN)
|
raise BusinessException("邮箱与邀请邮箱不匹配", BizCode.FORBIDDEN)
|
||||||
|
|
||||||
# 如果启用单工作空间模式,检查用户是否已有工作空间
|
# 如果启用单工作空间模式,检查用户是否已有工作空间
|
||||||
if settings.ENABLE_SINGLE_WORKSPACE:
|
if settings.ENABLE_SINGLE_WORKSPACE:
|
||||||
user_workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
|
user_workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
|
||||||
if user_workspaces:
|
if user_workspaces:
|
||||||
business_logger.warning(f"单工作空间模式下用户已有工作空间: user={user.username}")
|
business_logger.warning(f"单工作空间模式下用户已有工作空间: user={user.username}")
|
||||||
raise BusinessException("用户只能加入一个工作空间", BizCode.FORBIDDEN)
|
raise BusinessException("用户只能加入一个工作空间", BizCode.FORBIDDEN)
|
||||||
|
|
||||||
# 检查用户是否已经是工作空间成员
|
# 检查用户是否已经是工作空间成员
|
||||||
existing_member = workspace_repository.get_member_in_workspace(
|
existing_member = workspace_repository.get_member_in_workspace(
|
||||||
db=db,
|
db=db,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
workspace_id=invite.workspace_id
|
workspace_id=invite.workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if existing_member:
|
if existing_member:
|
||||||
business_logger.info("用户已是工作空间成员,更新邀请状态")
|
business_logger.info("用户已是工作空间成员,更新邀请状态")
|
||||||
invite_repo.update_invite_status(
|
invite_repo.update_invite_status(
|
||||||
invite.id,
|
invite.id,
|
||||||
InviteStatus.accepted,
|
InviteStatus.accepted,
|
||||||
accepted_at=now
|
accepted_at=now
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -559,10 +556,10 @@ def accept_workspace_invite(
|
|||||||
"message": "You are already a member of this workspace",
|
"message": "You are already a member of this workspace",
|
||||||
"workspace": workspace
|
"workspace": workspace
|
||||||
}
|
}
|
||||||
|
|
||||||
# 将角色映射到工作空间角色(现在直接使用相同的角色)
|
# 将角色映射到工作空间角色(现在直接使用相同的角色)
|
||||||
workspace_role = invite.role
|
workspace_role = invite.role
|
||||||
|
|
||||||
# 添加用户到工作空间
|
# 添加用户到工作空间
|
||||||
workspace_repository.add_member_to_workspace(
|
workspace_repository.add_member_to_workspace(
|
||||||
db=db,
|
db=db,
|
||||||
@@ -570,27 +567,27 @@ def accept_workspace_invite(
|
|||||||
workspace_id=invite.workspace_id,
|
workspace_id=invite.workspace_id,
|
||||||
role=workspace_role
|
role=workspace_role
|
||||||
)
|
)
|
||||||
|
|
||||||
# 标记邀请为已接受
|
# 标记邀请为已接受
|
||||||
invite_repo.update_invite_status(
|
invite_repo.update_invite_status(
|
||||||
invite.id,
|
invite.id,
|
||||||
InviteStatus.accepted,
|
InviteStatus.accepted,
|
||||||
accepted_at=now
|
accepted_at=now
|
||||||
)
|
)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# 获取工作空间信息
|
# 获取工作空间信息
|
||||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
||||||
|
|
||||||
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
|
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"message": "Successfully joined the workspace",
|
"message": "Successfully joined the workspace",
|
||||||
"workspace": workspace,
|
"workspace": workspace,
|
||||||
"role": workspace_role
|
"role": workspace_role
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
business_logger.error(f"接受工作空间邀请失败: user={user.username} - {str(e)}")
|
business_logger.error(f"接受工作空间邀请失败: user={user.username} - {str(e)}")
|
||||||
@@ -598,34 +595,34 @@ def accept_workspace_invite(
|
|||||||
|
|
||||||
|
|
||||||
def revoke_workspace_invite(
|
def revoke_workspace_invite(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
invite_id: uuid.UUID,
|
invite_id: uuid.UUID,
|
||||||
user: User
|
user: User
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""撤销工作空间邀请"""
|
"""撤销工作空间邀请"""
|
||||||
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
|
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查权限
|
# 检查权限
|
||||||
_check_workspace_admin_permission(db, workspace_id, user)
|
_check_workspace_admin_permission(db, workspace_id, user)
|
||||||
|
|
||||||
# 撤销邀请
|
# 撤销邀请
|
||||||
invite_repo = WorkspaceInviteRepository(db)
|
invite_repo = WorkspaceInviteRepository(db)
|
||||||
invite = invite_repo.revoke_invite(invite_id)
|
invite = invite_repo.revoke_invite(invite_id)
|
||||||
|
|
||||||
if not invite:
|
if not invite:
|
||||||
business_logger.warning(f"邀请不存在: invite_id={invite_id}")
|
business_logger.warning(f"邀请不存在: invite_id={invite_id}")
|
||||||
raise BusinessException("邀请不存在", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
raise BusinessException("邀请不存在", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||||
|
|
||||||
if invite.workspace_id != workspace_id:
|
if invite.workspace_id != workspace_id:
|
||||||
business_logger.warning(f"邀请不属于指定工作空间: invite_id={invite_id}, workspace_id={workspace_id}")
|
business_logger.warning(f"邀请不属于指定工作空间: invite_id={invite_id}, workspace_id={workspace_id}")
|
||||||
raise BusinessException("邀请不属于指定工作空间", BizCode.BAD_REQUEST)
|
raise BusinessException("邀请不属于指定工作空间", BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
business_logger.info(f"工作空间邀请撤销成功: invite_id={invite_id}")
|
business_logger.info(f"工作空间邀请撤销成功: invite_id={invite_id}")
|
||||||
return {"message": "邀请撤销成功"}
|
return {"message": "邀请撤销成功"}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
business_logger.error(f"撤销工作空间邀请失败: invite_id={invite_id} - {str(e)}")
|
business_logger.error(f"撤销工作空间邀请失败: invite_id={invite_id} - {str(e)}")
|
||||||
@@ -640,48 +637,48 @@ def update_workspace_member_roles(
|
|||||||
) -> List[WorkspaceMember]:
|
) -> List[WorkspaceMember]:
|
||||||
"""更新工作空间成员角色"""
|
"""更新工作空间成员角色"""
|
||||||
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
|
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
|
||||||
|
|
||||||
# 检查管理员权限
|
# 检查管理员权限
|
||||||
_check_workspace_admin_permission(db, workspace_id, user)
|
_check_workspace_admin_permission(db, workspace_id, user)
|
||||||
|
|
||||||
# 获取所有当前成员
|
# 获取所有当前成员
|
||||||
all_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
|
all_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
|
||||||
member_map = {m.id: m for m in all_members}
|
member_map = {m.id: m for m in all_members}
|
||||||
|
|
||||||
# 验证和业务规则检查
|
# 验证和业务规则检查
|
||||||
update_ids = set()
|
update_ids = set()
|
||||||
for upd in updates:
|
for upd in updates:
|
||||||
# 检查成员是否存在
|
# 检查成员是否存在
|
||||||
if upd.id not in member_map:
|
if upd.id not in member_map:
|
||||||
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||||
|
|
||||||
member = member_map[upd.id]
|
member = member_map[upd.id]
|
||||||
|
|
||||||
# 检查成员是否属于该工作空间
|
# 检查成员是否属于该工作空间
|
||||||
if member.workspace_id != workspace_id:
|
if member.workspace_id != workspace_id:
|
||||||
raise BusinessException(f"成员 {upd.id} 不属于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
raise BusinessException(f"成员 {upd.id} 不属于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||||
|
|
||||||
# 不能修改自己的角色
|
# 不能修改自己的角色
|
||||||
if member.user_id == user.id:
|
if member.user_id == user.id:
|
||||||
raise BusinessException("不能修改自己的角色", BizCode.BAD_REQUEST)
|
raise BusinessException("不能修改自己的角色", BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
update_ids.add(upd.id)
|
update_ids.add(upd.id)
|
||||||
|
|
||||||
# 检查是否至少保留一个 manager
|
# 检查是否至少保留一个 manager
|
||||||
current_managers = [m for m in all_members if m.role == WorkspaceRole.manager]
|
current_managers = [m for m in all_members if m.role == WorkspaceRole.manager]
|
||||||
managers_after_update = [
|
managers_after_update = [
|
||||||
m for m in all_members
|
m for m in all_members
|
||||||
if m.id not in update_ids and m.role == WorkspaceRole.manager
|
if m.id not in update_ids and m.role == WorkspaceRole.manager
|
||||||
]
|
]
|
||||||
|
|
||||||
# 添加更新后会成为 manager 的成员
|
# 添加更新后会成为 manager 的成员
|
||||||
for upd in updates:
|
for upd in updates:
|
||||||
if upd.role == WorkspaceRole.manager:
|
if upd.role == WorkspaceRole.manager:
|
||||||
managers_after_update.append(member_map[upd.id])
|
managers_after_update.append(member_map[upd.id])
|
||||||
|
|
||||||
if len(managers_after_update) == 0:
|
if len(managers_after_update) == 0:
|
||||||
raise BusinessException("工作空间至少需要一个管理员", BizCode.BAD_REQUEST)
|
raise BusinessException("工作空间至少需要一个管理员", BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新
|
||||||
try:
|
try:
|
||||||
for upd in updates:
|
for upd in updates:
|
||||||
@@ -691,15 +688,15 @@ def update_workspace_member_roles(
|
|||||||
role=upd.role,
|
role=upd.role,
|
||||||
)
|
)
|
||||||
business_logger.debug(f"更新成员 {upd.id} 角色为 {upd.role}")
|
business_logger.debug(f"更新成员 {upd.id} 角色为 {upd.role}")
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
# 重新获取更新后的成员列表
|
# 重新获取更新后的成员列表
|
||||||
updated_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
|
updated_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
|
||||||
business_logger.info(f"成员角色更新完成: workspace_id={workspace_id}, 更新数量={len(updates)}")
|
business_logger.info(f"成员角色更新完成: workspace_id={workspace_id}, 更新数量={len(updates)}")
|
||||||
|
|
||||||
return updated_members
|
return updated_members
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
business_logger.error(f"更新工作空间成员角色失败: workspace_id={workspace_id} - {str(e)}")
|
business_logger.error(f"更新工作空间成员角色失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
@@ -789,7 +786,7 @@ def get_workspace_models_configs(
|
|||||||
|
|
||||||
# 查询工作空间模型配置
|
# 查询工作空间模型配置
|
||||||
configs = workspace_repository.get_workspace_models_configs(db=db, workspace_id=workspace_id)
|
configs = workspace_repository.get_workspace_models_configs(db=db, workspace_id=workspace_id)
|
||||||
|
|
||||||
if configs is None:
|
if configs is None:
|
||||||
business_logger.error(f"工作空间不存在: workspace_id={workspace_id}")
|
business_logger.error(f"工作空间不存在: workspace_id={workspace_id}")
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
@@ -801,4 +798,5 @@ def get_workspace_models_configs(
|
|||||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||||
)
|
)
|
||||||
return configs
|
return configs
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user