[add] app chat v1

This commit is contained in:
Mark
2025-12-24 20:35:04 +08:00
parent 63d5047d21
commit bbd73d5e95
14 changed files with 1497 additions and 264 deletions

View File

@@ -361,7 +361,8 @@ async def draft_run(
workspace_id=workspace_id,
user=current_user
)
if storage_type is None: storage_type = 'neo4j'
if storage_type is None:
storage_type = 'neo4j'
user_rag_memory_id = ''
if workspace_id:
@@ -370,7 +371,8 @@ async def draft_run(
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge: user_rag_memory_id = str(knowledge.id)
if knowledge:
user_rag_memory_id = str(knowledge.id)
# 提前验证和准备(在流式响应开始前完成)

View File

@@ -0,0 +1,97 @@
from fastapi import APIRouter, Depends, status
from sqlalchemy.orm import Session
import os
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.order_schema import CreateOrderRequest
from app.schemas.response_schema import ApiResponse
from app.services.order_service import get_order_service
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, error
# Get API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/order",
tags=["Order"],
)
@router.post("", response_model=ApiResponse)
async def create_order(
order_data: CreateOrderRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
try:
api_logger.info(f"User {current_user.id} creating order for product {order_data.product_id}")
# Get external API configuration from environment
external_api_url = os.getenv("EXTERNAL_ORDER_API_URL")
api_key = os.getenv("EXTERNAL_ORDER_API_KEY")
# Get order service instance
order_service = get_order_service(
external_api_url=external_api_url,
api_key=api_key
)
# Forward request to external API
result = await order_service.create_order(
order_data=order_data,
user_id=str(current_user.id)
)
api_logger.info(f"Order created successfully: {result.get('order_id')}")
return success(data=result, msg="Order created successfully")
except Exception as e:
api_logger.error(f"Failed to create order: {str(e)}", exc_info=True)
return error(msg=str(e), code=status.HTTP_500_INTERNAL_SERVER_ERROR)
@router.get("/{order_id}", response_model=ApiResponse)
async def get_order(
order_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""Get order details from external API
Args:
order_id: Order ID
db: Database session
current_user: Current authenticated user
Returns:
API response with order details
"""
try:
api_logger.info(f"User {current_user.id} fetching order {order_id}")
# Get external API configuration
external_api_url = os.getenv("EXTERNAL_ORDER_API_URL")
api_key = os.getenv("EXTERNAL_ORDER_API_KEY")
# Get order service instance
order_service = get_order_service(
external_api_url=external_api_url,
api_key=api_key
)
# Fetch order from external API
result = await order_service.get_order(order_id)
api_logger.info(f"Order {order_id} fetched successfully")
return success(data=result, msg="Order fetched successfully")
except Exception as e:
api_logger.error(f"Failed to fetch order {order_id}: {str(e)}", exc_info=True)
return error(msg=str(e), code=status.HTTP_500_INTERNAL_SERVER_ERROR)

View File

@@ -2,14 +2,30 @@
import uuid
from fastapi import APIRouter, Depends, Request, Body
from sqlalchemy.orm import Session
from typing import Optional, Annotated
from starlette.responses import StreamingResponse
from app.core.api_key_auth import require_api_key
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
from app.core.api_key_auth import require_api_key
from app.schemas.api_key_schema import ApiKeyAuth
from app.dependencies import get_app_or_workspace
from app.repositories import knowledge_repository
from app.schemas import AppChatRequest, conversation_schema
from app.models.app_model import App
from app.models.app_model import AppType
from app.repositories.end_user_repository import EndUserRepository
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.app_service import AppService
from app.services.conversation_service import ConversationService, get_conversation_service
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.utils.app_config_utils import dict_to_multi_agent_config,dict_to_agent_config,dict_to_workflow_config
router = APIRouter(prefix="/apps", tags=["V1 - App API"])
router = APIRouter(prefix="/app", tags=["V1 - App API"])
logger = get_business_logger()
@@ -19,28 +35,232 @@ async def list_apps():
return success(data=[], msg="App API - Coming Soon")
# /v1/apps/{resource_id}/chat
@router.post("/{resource_id}/chat")
@require_api_key(scopes=["app"])
async def chat_with_agent_demo(
resource_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
# async def chat(
# request: Request,
# api_key_auth: ApiKeyAuth = None,
# db: Session = Depends(get_db),
# message: str = Body(..., description="聊天消息内容"),
# ):
# """
# Agent 聊天接口demo
# scopes: 所需的权限范围列表["app", "rag", "memory"]
# Args:
# resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
# message: 请求参数
# request: 声明请求
# api_key_auth: 包含验证后的API Key 信息
# db: db_session
# """
# logger.info(f"API Key Auth: {api_key_auth}")
# logger.info(f"Resource ID: {resource_id}")
# logger.info(f"Message: {message}")
# return success(data={"received": True}, msg="消息已接收")
def _checkAppConfig(app: App):
if app.type == AppType.AGENT:
if not app.current_release.config:
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif app.type == AppType.MULTI_AGENT:
if not app.current_release.config:
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif app.type == AppType.WORKFLOW:
if not app.current_release.config:
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
else:
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
@router.post("/chat")
# @require_api_key(scopes=["app"])
async def chat(
payload: AppChatRequest,
app: App = Depends(get_app_or_workspace),
db: Session = Depends(get_db),
message: str = Body(..., description="聊天消息内容"),
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
):
"""
Agent 聊天接口demo
other_id = payload.user_id
workspace_id = app.workspace_id
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app.id,
other_id=other_id,
original_user_id=other_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
scopes: 所需的权限范围列表["app", "rag", "memory"]
# 提前验证和准备(在流式响应开始前完成)
storage_type = workspace_service.get_workspace_storage_type_without_auth(
db=db,
workspace_id=workspace_id
)
if storage_type is None:
storage_type = 'neo4j'
user_rag_memory_id = ''
if storage_type == 'rag':
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j'
else:
logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
app_type = app.type
# check app config
_checkAppConfig(app)
Args:
resource_id: 如果是应用的apikey传的是应用id; 如果是服务的apikey传的是工作空间id
message: 请求参数
request: 声明请求
api_key_auth: 包含验证后的API Key 信息
db: db_session
"""
logger.info(f"API Key Auth: {api_key_auth}")
logger.info(f"Resource ID: {resource_id}")
logger.info(f"Message: {message}")
return success(data={"received": True}, msg="消息已接收")
# 获取或创建会话(提前验证)
conversation = conversation_service.create_or_get_conversation(
app_id=app.id,
workspace_id=workspace_id,
user_id=end_user_id,
is_draft=False
)
if app_type == AppType.AGENT:
agent_config = dict_to_agent_config(app.current_release.config)
# 流式返回
if payload.stream:
async def event_generator():
async for event in app_chat_service.agnet_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id= end_user_id, # 转换为字符串
variables=payload.variables,
web_search=payload.web_search,
config=app.current_release.config,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 非流式返回
result = await app_chat_service.agnet_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config= agent_config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.MULTI_AGENT:
# 多 Agent 流式返回
config = dict_to_multi_agent_config(app.current_release.config)
if payload.stream:
async def event_generator():
async for event in app_chat_service.multi_agent_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 多 Agent 非流式返回
result = await app_chat_service.multi_agent_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW:
# 多 Agent 流式返回
config = dict_to_workflow_config(app.current_release.config)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 非流式返回
result = await app_chat_service.workflow_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
return success(data=conversation_schema.ChatResponse(**result))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass

View File

@@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, status
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
import uuid

View File

@@ -1,12 +1,13 @@
import uuid
from functools import wraps
from fastapi import Depends, HTTPException, status
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from jose import jwt, JWTError
from app.db import get_db, SessionLocal
from app.models import App
from app.schemas import token_schema
from app.core.config import settings
from app.core.security import get_token_id
@@ -27,6 +28,51 @@ security_logger = get_security_logger()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class APIKeyExtractor:
"""
Custom dependency to extract API Key from request headers
Supports two formats:
1. Authorization: Bearer <api_key>
2. X-API-Key: <api_key>
"""
async def __call__(self, request: Request) -> str:
"""Extract API Key from request headers
Args:
request: FastAPI Request object
Returns:
API Key string
Raises:
HTTPException: If API Key is not found
"""
# Try Authorization header first
auth_header = request.headers.get("Authorization")
if auth_header and " " in auth_header:
auth_scheme, auth_token = auth_header.split(" ", 1)
if auth_scheme.lower() == "bearer":
return auth_token
# Try X-API-Key header
api_key = request.headers.get("X-API-Key")
if api_key:
return api_key
# No API Key found
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API Key not found in request headers",
headers={"WWW-Authenticate": "Bearer"},
)
api_key_extractor = APIKeyExtractor()
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db)
@@ -304,7 +350,7 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False):
- db: Session = Depends(get_db)
- user 或 current_user: User = Depends(get_current_user)
- workspace_id: uuid.UUID query/path 参数)或 payload: AppCreatebody含 workspace_id
支持同步和异步函数。
"""
import asyncio
@@ -360,7 +406,7 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False):
def get_uow() -> IUnitOfWork:
"""
获取工作单元实例
Returns:
IUnitOfWork: 工作单元实例
"""
@@ -373,7 +419,7 @@ def cur_workspace_access_guard():
要求端点函数签名包含:
- db: Session = Depends(get_db)
- current_user: User = Depends(get_current_user)
支持同步和异步函数。
"""
import asyncio
@@ -423,10 +469,10 @@ async def get_share_user_id(
) -> ShareTokenData:
"""
从分享访问 token 中获取用户 ID 和 share_token
这个函数用于公开分享的接口,验证访问 token 并返回用户信息
不需要验证用户是否存在或激活,只需要验证 token 的有效性和 share_token 是否有效
Returns:
ShareTokenData: 包含 user_id 和 share_token
"""
@@ -469,4 +515,75 @@ async def get_share_user_id(
raise credentials_exception
async def get_app_or_workspace(
api_key: str = Depends(api_key_extractor),
db: Session = Depends(get_db)
) -> App | Workspace:
"""
Get App or Workspace from API Key
Supports two API Key formats:
1. Authorization: Bearer <api_key>
2. X-API-Key: <api_key>
Args:
api_key: API Key extracted from request headers
db: Database session
Returns:
App or Workspace object based on API Key
Raises:
HTTPException: If API Key is invalid or not found
"""
from app.services.api_key_service import ApiKeyAuthService
from app.repositories.app_repository import get_apps_by_id
from app.repositories.workspace_repository import get_workspace_by_id
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate API Key",
headers={"WWW-Authenticate": "Bearer"},
)
try:
auth_logger.debug(f"Validating API Key: {api_key[:10]}...")
# Validate API Key
api_key_obj = ApiKeyAuthService.validate_api_key(db, api_key)
if not api_key_obj:
auth_logger.warning(f"Invalid or expired API Key: {api_key[:10]}...")
raise credentials_exception
auth_logger.debug(f"API Key validated successfully, type: {api_key_obj.type}")
# Return App or Workspace based on API Key type
if (api_key_obj.type == "agent" or api_key.type == "multi_agent") and api_key_obj.resource_id:
# App API Key
app = get_apps_by_id(db, api_key_obj.resource_id)
if not app:
auth_logger.warning(f"App not found for API Key: {api_key_obj.resource_id}")
raise credentials_exception
auth_logger.info(f"App access granted: {app.id}")
return app
elif api_key_obj.type == "service":
# Workspace API Key
workspace = get_workspace_by_id(db, api_key_obj.workspace_id)
if not workspace:
auth_logger.warning(f"Workspace not found for API Key: {api_key_obj.workspace_id}")
raise credentials_exception
auth_logger.info(f"Workspace access granted: {workspace.id}")
return workspace
else:
auth_logger.warning(f"Unsupported API Key type: {api_key_obj.type}")
raise credentials_exception
except HTTPException:
raise
except Exception as e:
auth_logger.error(f"Error validating API Key: {str(e)}", exc_info=True)
raise credentials_exception

View File

@@ -2,38 +2,10 @@ import os
import subprocess
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, APIRouter
from fastapi import HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.core.response_utils import fail
from app.core.logging_config import LoggingConfig, get_logger
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode, HTTP_MAPPING
from app.controllers import (
model_controller,
task_controller,
test_controller,
user_controller,
auth_controller,
workspace_controller,
setup_controller,
file_controller,
document_controller,
knowledge_controller,
chunk_controller,
knowledgeshare_controller,
app_controller,
upload_controller,
memory_agent_controller,
memory_storage_controller,
memory_dashboard_controller,
multi_agent_controller,
)
from fastapi import FastAPI, APIRouter
app = FastAPI(title="Data Config API", version="1.0.0")
router = APIRouter(prefix="/memory", tags=["Memory"])
# 管理端 API (JWT 认证)
from app.controllers import manager_router

View File

@@ -8,7 +8,9 @@ from .file_schema import File, FileCreate, FileUpdate
from .tenant_schema import Tenant, TenantCreate, TenantUpdate
from .chunk_schema import ChunkCreate, ChunkUpdate, ChunkRetrieve
from .knowledgeshare_schema import KnowledgeShare, KnowledgeShareCreate
from .order_schema import CreateOrderRequest, OrderResponse, ExternalOrderResponse
from .app_schema import (
AppChatRequest,
DraftRunRequest,
DraftRunResponse,
DraftRunStreamChunk,
@@ -73,6 +75,10 @@ __all__ = [
"ChunkRetrieve",
"KnowledgeShare",
"KnowledgeShareCreate",
"CreateOrderRequest",
"OrderResponse",
"ExternalOrderResponse",
"AppChatRequest",
"DraftRunRequest",
"DraftRunResponse",
"DraftRunStreamChunk",

View File

@@ -334,6 +334,13 @@ class AppShare(BaseModel):
# ---------- Draft Run Schemas ----------
class AppChatRequest(BaseModel):
message: str = Field(..., description="用户消息")
conversation_id: Optional[str] = Field(default=None, description="会话ID用于多轮对话")
user_id: Optional[str] = Field(default=None, description="用户ID用于会话管理")
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
stream: bool = Field(default=False, description="是否流式返回")
class DraftRunRequest(BaseModel):
"""试运行请求"""
message: str = Field(..., description="用户消息")

View File

@@ -0,0 +1,63 @@
"""
Order Schema
Defines request and response models for order operations.
"""
from pydantic import BaseModel, Field
from typing import Any, Optional
class CreateOrderRequest(BaseModel):
"""Create order request model"""
product_id: str = Field(..., description="Product ID")
quantity: int = Field(..., gt=0, description="Order quantity")
customer_name: Optional[str] = Field(None, description="Customer name")
customer_email: Optional[str] = Field(None, description="Customer email")
notes: Optional[str] = Field(None, description="Order notes")
class Config:
json_schema_extra = {
"example": {
"product_id": "PROD-001",
"quantity": 2,
"customer_name": "John Doe",
"customer_email": "john@example.com",
"notes": "Please deliver before 5pm"
}
}
class OrderResponse(BaseModel):
"""Order response model"""
order_id: str = Field(..., description="Order ID")
status: str = Field(..., description="Order status")
product_id: str = Field(..., description="Product ID")
quantity: int = Field(..., description="Order quantity")
total_amount: Optional[float] = Field(None, description="Total amount")
created_at: Optional[str] = Field(None, description="Creation timestamp")
message: Optional[str] = Field(None, description="Response message")
class Config:
json_schema_extra = {
"example": {
"order_id": "ORD-20231224-001",
"status": "pending",
"product_id": "PROD-001",
"quantity": 2,
"total_amount": 199.99,
"created_at": "2023-12-24T10:30:00Z",
"message": "Order created successfully"
}
}
class ExternalOrderResponse(BaseModel):
"""External API response model (flexible structure)"""
success: bool = Field(default=True, description="Request success status")
data: Optional[Any] = Field(None, description="Response data")
error: Optional[str] = Field(None, description="Error message")
code: Optional[int] = Field(None, description="Response code")

View File

@@ -0,0 +1,485 @@
"""基于分享链接的聊天服务"""
import asyncio
import json
import time
import uuid
from typing import Optional, Dict, Any, AsyncGenerator, Annotated
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.agent.langchain_agent import LangChainAgent
from app.core.logging_config import get_business_logger
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
logger = get_business_logger()
class AppChatService:
"""基于分享链接的聊天服务"""
def __init__(self, db: Session):
self.db = db
self.conversation_service = ConversationService(db)
async def agnet_chat(
self,
message: str,
conversation_id: uuid.UUID,
config: AgentConfig,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""聊天(非流式)"""
start_time = time.time()
config_id = None
if variables is None:
variables = {}
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "")
if variables:
system_prompt_rendered = render_prompt_message(
system_prompt,
PromptMessageRole.USER,
variables
)
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# 准备工具列表
tools = []
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
# 添加长期记忆工具
memory_flag = False
if memory == True:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.get("model_parameters", {})
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
)
# 加载历史消息
history = []
memory_config = {"enabled": True, 'max_history': 10}
if memory_config.get("enabled"):
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=memory_config.get("max_history", 10)
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
# 调用 Agent
result = await agent.chat(
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
)
# 保存消息
self.conversation_service.save_conversation_messages(
conversation_id=conversation_id,
user_message=message,
assistant_message=result["content"]
)
elapsed_time = time.time() - start_time
return {
"conversation_id": conversation_id,
"message": result["content"],
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}),
"elapsed_time": elapsed_time
}
async def agnet_chat_stream(
self,
message: str,
conversation_id: uuid.UUID,
config: AgentConfig,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
try:
start_time = time.time()
config_id = None
if variables is None:
variables = {}
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "")
if variables:
system_prompt_rendered = render_prompt_message(
system_prompt,
PromptMessageRole.USER,
variables
)
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# 准备工具列表
tools = []
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
# 添加长期记忆工具
memory_flag = False
if memory:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.get("model_parameters", {})
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True
)
# 加载历史消息
history = []
memory_config = {"enabled": True, 'max_history': 10}
if memory_config.get("enabled"):
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=memory_config.get("max_history", 10)
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
# 流式调用 Agent
full_content = ""
async for chunk in agent.chat_stream(
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
):
full_content += chunk
# 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
elapsed_time = time.time() - start_time
# 保存消息
self.conversation_service.add_message(
conversation_id=conversation_id,
role="user",
content=message
)
self.conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=full_content,
meta_data={
"model": api_key_obj.model_name,
"usage": {}
}
)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info(
"流式聊天完成",
extra={
"conversation_id": str(conversation_id),
"elapsed_time": elapsed_time,
"message_length": len(full_content)
}
)
except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出
logger.debug("流式聊天被中断")
raise
except Exception as e:
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
# 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
async def multi_agent_chat(
self,
message: str,
conversation_id: uuid.UUID,
config: MultiAgentConfig,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""多 Agent 聊天(非流式)"""
start_time = time.time()
actual_config_id = None
config_id = actual_config_id
if variables is None:
variables = {}
# 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config)
# 3. 执行任务
result = await orchestrator.execute(
message=message,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
use_llm_routing=True, # 默认启用 LLM 路由
web_search=web_search, # 网络搜索参数
memory=memory # 记忆功能参数
)
elapsed_time = time.time() - start_time
# 保存消息
self.conversation_service.add_message(
conversation_id=conversation_id,
role="user",
content=message
)
self.conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=result.get("message", ""),
meta_data={
"mode": result.get("mode"),
"elapsed_time": result.get("elapsed_time"),
"sub_results": result.get("sub_results")
}
)
return {
"conversation_id": conversation_id,
"message": result.get("message", ""),
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
},
"elapsed_time": elapsed_time
}
async def multi_agent_chat_stream(
self,
message: str,
conversation_id: uuid.UUID,
config: MultiAgentConfig,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""多 Agent 聊天(流式)"""
start_time = time.time()
actual_config_id = None
config_id = actual_config_id
if variables is None:
variables = {}
try:
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
full_content = ""
# 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config)
# 3. 流式执行任务
async for event in orchestrator.execute_stream(
message=message,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
use_llm_routing=True,
web_search=web_search, # 网络搜索参数
memory=memory, # 记忆功能参数
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
# 尝试提取内容(用于保存)
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "content" in data:
full_content += data["content"]
except:
pass
elapsed_time = time.time() - start_time
# 保存消息
self.conversation_service.add_message(
conversation_id=conversation_id,
role="user",
content=message
)
self.conversation_service.add_message(
conversation_id=conversation_id,
role="assistant",
content=full_content,
meta_data={
"elapsed_time": elapsed_time
}
)
logger.info(
"多 Agent 流式聊天完成",
extra={
"conversation_id": str(conversation_id),
"elapsed_time": elapsed_time,
"message_length": len(full_content)
}
)
except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出
logger.debug("多 Agent 流式聊天被中断")
raise
except Exception as e:
logger.error(f"多 Agent 流式聊天失败: {str(e)}", exc_info=True)
# 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
# ==================== 依赖注入函数 ====================
def get_app_chat_service(
db: Annotated[Session, Depends(get_db)]
) -> ChatService:
"""获取工作流服务(依赖注入)"""
return ChatService(db)

View File

@@ -1,9 +1,12 @@
"""会话服务"""
import uuid
from typing import Optional, List, Tuple
from typing import Optional, List, Tuple, Annotated
from fastapi import Depends
from sqlalchemy.orm import Session
from sqlalchemy import select, desc
from app.db import get_db
from app.models import Conversation, Message
from app.core.exceptions import ResourceNotFoundException, BusinessException
from app.core.error_codes import BizCode
@@ -14,10 +17,10 @@ logger = get_business_logger()
class ConversationService:
"""会话服务"""
def __init__(self, db: Session):
self.db = db
def create_conversation(
self,
app_id: uuid.UUID,
@@ -36,11 +39,11 @@ class ConversationService:
is_draft=is_draft,
config_snapshot=config_snapshot
)
self.db.add(conversation)
self.db.commit()
self.db.refresh(conversation)
logger.info(
"创建会话成功",
extra={
@@ -50,9 +53,9 @@ class ConversationService:
"is_draft": is_draft
}
)
return conversation
def get_conversation(
self,
conversation_id: uuid.UUID,
@@ -60,17 +63,17 @@ class ConversationService:
) -> Conversation:
"""获取会话"""
stmt = select(Conversation).where(Conversation.id == conversation_id)
if workspace_id:
stmt = stmt.where(Conversation.workspace_id == workspace_id)
conversation = self.db.scalars(stmt).first()
if not conversation:
raise ResourceNotFoundException("会话", str(conversation_id))
return conversation
def list_conversations(
self,
app_id: uuid.UUID,
@@ -86,25 +89,25 @@ class ConversationService:
Conversation.workspace_id == workspace_id,
Conversation.is_active == True
)
if user_id:
stmt = stmt.where(Conversation.user_id == user_id)
if is_draft is not None:
stmt = stmt.where(Conversation.is_draft == is_draft)
# 总数
count_stmt = stmt.with_only_columns(Conversation.id)
total = len(self.db.execute(count_stmt).all())
# 分页
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
conversations = list(self.db.scalars(stmt).all())
return conversations, total
def add_message(
self,
conversation_id: uuid.UUID,
@@ -119,22 +122,22 @@ class ConversationService:
content=content,
meta_data=meta_data
)
self.db.add(message)
# 更新会话的消息计数和更新时间
conversation = self.get_conversation(conversation_id)
conversation.message_count += 1
# 如果是第一条用户消息,可以用它作为标题
if conversation.message_count == 1 and role == "user":
conversation.title = content[:50] + ("..." if len(content) > 50 else "")
self.db.commit()
self.db.refresh(message)
return message
def get_messages(
self,
conversation_id: uuid.UUID,
@@ -144,30 +147,30 @@ class ConversationService:
stmt = select(Message).where(
Message.conversation_id == conversation_id
).order_by(Message.created_at)
if limit:
stmt = stmt.limit(limit)
messages = list(self.db.scalars(stmt).all())
return messages
def get_conversation_history(
self,
conversation_id: uuid.UUID,
max_history: Optional[int] = None
) -> List[dict]:
"""获取会话历史消息
Args:
conversation_id: 会话ID
max_history: 最大历史消息数量
Returns:
List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...]
"""
messages = self.get_messages(conversation_id, limit=max_history)
# 转换为字典格式
history = [
{
@@ -176,9 +179,9 @@ class ConversationService:
}
for msg in messages
]
return history
def save_conversation_messages(
self,
conversation_id: uuid.UUID,
@@ -192,14 +195,14 @@ class ConversationService:
role="user",
content=user_message
)
# 添加助手消息
self.add_message(
conversation_id=conversation_id,
role="assistant",
content=assistant_message
)
logger.debug(
"保存会话消息成功",
extra={
@@ -208,7 +211,7 @@ class ConversationService:
"assistant_message_length": len(assistant_message)
}
)
def delete_conversation(
self,
conversation_id: uuid.UUID,
@@ -217,9 +220,9 @@ class ConversationService:
"""删除会话(软删除)"""
conversation = self.get_conversation(conversation_id, workspace_id)
conversation.is_active = False
self.db.commit()
logger.info(
"删除会话成功",
extra={
@@ -227,3 +230,53 @@ class ConversationService:
"workspace_id": str(workspace_id)
}
)
def create_or_get_conversation(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
is_draft: bool = False,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
) -> Conversation:
"""创建或获取会话"""
# 如果提供了 conversation_id尝试获取现有会话
if conversation_id:
try:
conversation = self.get_conversation(
conversation_id=conversation_id,
workspace_id=workspace_id
)
# 验证会话是否属于该应用
if conversation.app_id != app_id:
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
return conversation
except ResourceNotFoundException:
logger.warning(
"会话不存在,将创建新会话",
extra={"conversation_id": str(conversation_id)}
)
# 创建新会话(使用发布版本的配置)
conversation = self.create_conversation(
app_id=app_id,
workspace_id=workspace_id,
user_id=user_id,
is_draft=is_draft
)
logger.info(
"为分享链接创建新会话"
)
return conversation
# ==================== 依赖注入函数 ====================
def get_conversation_service(
db: Annotated[Session, Depends(get_db)]
) -> ConversationService:
"""获取工作流服务(依赖注入)"""
return ConversationService(db)

View File

@@ -36,7 +36,7 @@ class ModelConfigService:
"""获取模型配置列表"""
models, total = ModelConfigRepository.get_list(db, query, tenant_id=tenant_id)
pages = math.ceil(total / query.pagesize) if total > 0 else 0
return PageData(
page=PageMeta(
page=query.page,
@@ -72,7 +72,7 @@ class ModelConfigService:
test_message: str = "Hello"
) -> Dict[str, Any]:
"""验证模型配置是否有效
Args:
db: 数据库会话
model_name: 模型名称
@@ -81,7 +81,7 @@ class ModelConfigService:
api_base: API基础URL
model_type: 模型类型 (llm/chat/embedding/rerank)
test_message: 测试消息
Returns:
Dict: 验证结果
"""
@@ -89,10 +89,10 @@ class ModelConfigService:
from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings
import traceback
try:
start_time = time.time()
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
@@ -101,16 +101,16 @@ class ModelConfigService:
temperature=0.7,
max_tokens=100
)
# 根据模型类型选择不同的验证方式
model_type_lower = model_type.lower()
if model_type_lower in ["llm", "chat"]:
# LLM/Chat 模型验证 - 统一使用字符串输入
llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT)
response = await llm.ainvoke(test_message)
elapsed_time = time.time() - start_time
content = response.content if hasattr(response, 'content') else str(response)
usage = None
if hasattr(response, 'usage_metadata'):
@@ -119,7 +119,7 @@ class ModelConfigService:
"output_tokens": getattr(response.usage_metadata, 'output_tokens', 0),
"total_tokens": getattr(response.usage_metadata, 'total_tokens', 0)
}
return {
"valid": True,
"message": f"{model_type.upper()} 模型配置验证成功",
@@ -128,14 +128,14 @@ class ModelConfigService:
"usage": usage,
"error": None
}
elif model_type_lower == "embedding":
# Embedding 模型验证(在线程中运行同步方法)
embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"]
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
elapsed_time = time.time() - start_time
return {
"valid": True,
"message": "Embedding 模型配置验证成功",
@@ -148,7 +148,7 @@ class ModelConfigService:
},
"error": None
}
elif model_type_lower == "rerank":
# Rerank 模型验证(在线程中运行同步方法)
rerank = RedBearRerank(model_config)
@@ -156,7 +156,7 @@ class ModelConfigService:
documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"]
results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3)
elapsed_time = time.time() - start_time
return {
"valid": True,
"message": "Rerank 模型配置验证成功",
@@ -169,7 +169,7 @@ class ModelConfigService:
},
"error": None
}
else:
return {
"valid": False,
@@ -179,7 +179,7 @@ class ModelConfigService:
"usage": None,
"error": f"不支持的模型类型: {model_type}"
}
except Exception as e:
# 提取详细的错误信息
error_message = str(e)
@@ -203,12 +203,12 @@ class ModelConfigService:
error_message = f"无效请求: {error_message}"
elif "model_copy" in error_message:
error_message = "模型消息格式错误: 请确保使用正确的模型类型LLM/Chat"
# 记录详细错误日志
logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}")
logger.error(f"错误详情: {error_message}")
logger.debug(f"完整堆栈: {traceback.format_exc()}")
return {
"valid": False,
"message": f"{model_type.upper()} 模型配置验证失败",
@@ -249,7 +249,7 @@ class ModelConfigService:
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
# 添加租户ID
model_config_data["tenant_id"] = tenant_id
model = ModelConfigRepository.create(db, model_config_data)
db.flush() # 获取生成的 ID
@@ -259,7 +259,7 @@ class ModelConfigService:
**api_key_data.dict()
)
ModelApiKeyRepository.create(db, api_key_create_schema)
db.commit()
db.refresh(model)
return model
@@ -270,11 +270,11 @@ class ModelConfigService:
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
db.commit()
db.refresh(model)
@@ -285,7 +285,7 @@ class ModelConfigService:
"""删除模型配置"""
if not ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id):
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
success = ModelConfigRepository.delete(db, model_id, tenant_id=tenant_id)
db.commit()
return success
@@ -316,20 +316,20 @@ class ModelApiKeyService:
return api_key
@staticmethod
def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> List[ModelApiKey]:
def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> list[ModelApiKey]:
"""根据模型配置ID获取API Key列表"""
if not ModelConfigRepository.get_by_id(db, model_config_id):
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
@staticmethod
@staticmethod
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
"""创建API Key"""
model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id)
if not model_config:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=api_key_data.model_name,
@@ -345,7 +345,7 @@ class ModelApiKeyService:
f"模型配置验证失败: {validation_result['error']}",
BizCode.INVALID_PARAMETER
)
api_key = ModelApiKeyRepository.create(db, api_key_data)
db.commit()
db.refresh(api_key)
@@ -357,12 +357,12 @@ class ModelApiKeyService:
existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not existing_api_key:
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
# 获取关联的模型配置以获取模型类型
model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id)
if not model_config:
raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND)
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=api_key_data.model_name,
@@ -378,7 +378,7 @@ class ModelApiKeyService:
f"模型配置验证失败: {validation_result['error']}",
BizCode.INVALID_PARAMETER
)
api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data)
db.commit()
db.refresh(api_key)
@@ -389,7 +389,7 @@ class ModelApiKeyService:
"""删除API Key"""
if not ModelApiKeyRepository.get_by_id(db, api_key_id):
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
success = ModelApiKeyRepository.delete(db, api_key_id)
db.commit()
return success
@@ -409,3 +409,11 @@ class ModelApiKeyService:
if success:
db.commit()
return success
@staticmethod
def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey:
api_kes = ModelApiKeyService.get_api_keys_by_model(db, model_config_id)
if api_kes and len(api_kes) > 0:
return api_kes[0]
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)

View File

@@ -0,0 +1,205 @@
"""
Order Service
Handles order operations including forwarding requests to external APIs.
"""
import logging
import httpx
from typing import Dict, Any, Optional
from app.schemas.order_schema import CreateOrderRequest
logger = logging.getLogger(__name__)
class OrderService:
"""Order service for handling order operations"""
def __init__(self, external_api_url: Optional[str] = None, api_key: Optional[str] = None):
"""Initialize order service
Args:
external_api_url: External API base URL
api_key: API key for authentication
"""
# Default external API URL (replace with actual URL)
self.external_api_url = external_api_url or "https://api.example.com/v1"
self.api_key = api_key
self.timeout = 30.0 # 30 seconds timeout
async def create_order(
self,
order_data: CreateOrderRequest,
user_id: str
) -> Dict[str, Any]:
"""Create order by forwarding request to external API
Args:
order_data: Order creation data
user_id: Current user ID
Returns:
Order response data
Raises:
httpx.HTTPError: If external API request fails
Exception: For other errors
"""
try:
# Prepare request payload
payload = {
"product_id": order_data.product_id,
"quantity": order_data.quantity,
"customer_name": order_data.customer_name,
"customer_email": order_data.customer_email,
"notes": order_data.notes,
"user_id": user_id # Include user ID for tracking
}
# Prepare headers
headers = {
"Content-Type": "application/json",
"User-Agent": "MemoryBear-OrderService/1.0"
}
# Add API key if configured
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
logger.info(f"Forwarding order creation request to external API: {self.external_api_url}/orders")
logger.debug(f"Request payload: {payload}")
# Make async HTTP request to external API
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.external_api_url}/orders",
json=payload,
headers=headers
)
# Log response status
logger.info(f"External API response status: {response.status_code}")
# Raise exception for 4xx/5xx status codes
response.raise_for_status()
# Parse response
response_data = response.json()
logger.debug(f"External API response data: {response_data}")
# Transform external API response to internal format
return self._transform_external_response(response_data)
except httpx.HTTPStatusError as e:
logger.error(f"External API returned error status: {e.response.status_code}")
logger.error(f"Error response: {e.response.text}")
# Try to parse error response
try:
error_data = e.response.json()
error_message = error_data.get("message") or error_data.get("error") or "External API error"
except Exception:
error_message = f"External API error: {e.response.status_code}"
raise Exception(f"Failed to create order: {error_message}")
except httpx.TimeoutException:
logger.error(f"External API request timeout after {self.timeout}s")
raise Exception("Order creation timeout - external service not responding")
except httpx.RequestError as e:
logger.error(f"External API request failed: {str(e)}")
raise Exception(f"Failed to connect to external order service: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error during order creation: {str(e)}", exc_info=True)
raise Exception(f"Order creation failed: {str(e)}")
def _transform_external_response(self, external_data: Dict[str, Any]) -> Dict[str, Any]:
"""Transform external API response to internal format
Args:
external_data: Response data from external API
Returns:
Transformed response data
"""
# Handle different response formats from external API
# Adjust this based on actual external API response structure
if "data" in external_data:
# Format 1: {"success": true, "data": {...}}
data = external_data["data"]
elif "order" in external_data:
# Format 2: {"order": {...}}
data = external_data["order"]
else:
# Format 3: Direct response
data = external_data
# Extract fields with fallbacks
return {
"order_id": data.get("order_id") or data.get("id") or "UNKNOWN",
"status": data.get("status") or "pending",
"product_id": data.get("product_id") or "",
"quantity": data.get("quantity") or 0,
"total_amount": data.get("total_amount") or data.get("amount"),
"created_at": data.get("created_at") or data.get("timestamp"),
"message": external_data.get("message") or "Order created successfully"
}
async def get_order(self, order_id: str) -> Dict[str, Any]:
"""Get order details from external API
Args:
order_id: Order ID
Returns:
Order details
"""
try:
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
logger.info(f"Fetching order {order_id} from external API")
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(
f"{self.external_api_url}/orders/{order_id}",
headers=headers
)
response.raise_for_status()
return response.json()
except Exception as e:
logger.error(f"Failed to fetch order {order_id}: {str(e)}")
raise Exception(f"Failed to fetch order: {str(e)}")
# Singleton instance
_order_service_instance: Optional[OrderService] = None
def get_order_service(
external_api_url: Optional[str] = None,
api_key: Optional[str] = None
) -> OrderService:
"""Get order service instance
Args:
external_api_url: External API URL (optional, uses default if not provided)
api_key: API key (optional)
Returns:
OrderService instance
"""
global _order_service_instance
if _order_service_instance is None:
_order_service_instance = OrderService(
external_api_url=external_api_url,
api_key=api_key
)
return _order_service_instance

View File

@@ -1,35 +1,32 @@
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
import secrets
import hashlib
import datetime
from fastapi import HTTPException, status
import hashlib
import secrets
import uuid
from os import getenv
from typing import List, Optional
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, PermissionDeniedException
from app.models.tenant_model import Tenants
from app.core.logging_config import get_business_logger
from app.models.user_model import User
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.workspace_model import Workspace, WorkspaceRole, WorkspaceInvite, InviteStatus, WorkspaceMember
from app.models.workspace_model import Workspace, WorkspaceRole, InviteStatus, WorkspaceMember
from app.repositories import workspace_repository
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
from app.schemas.workspace_schema import (
WorkspaceCreate,
WorkspaceUpdate,
WorkspaceInviteCreate,
WorkspaceCreate,
WorkspaceUpdate,
WorkspaceInviteCreate,
WorkspaceInviteResponse,
InviteValidateResponse,
InviteAcceptRequest,
WorkspaceMemberUpdate
)
from app.repositories import workspace_repository
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
from app.core.logging_config import get_business_logger
from app.core.config import settings
from app.services import user_service
from os import getenv
# 获取业务逻辑专用日志器
business_logger = get_business_logger()
import os #
from dotenv import load_dotenv
load_dotenv()
def switch_workspace(
@@ -39,10 +36,10 @@ def switch_workspace(
):
"""切换工作空间"""
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
# 检查用户是否为成员或超级管理员
_check_workspace_member_permission(db, workspace_id, user)
# 更新当前用户的工作空间上下文
try:
user.current_workspace_id = workspace_id
@@ -63,22 +60,22 @@ def delete_workspace_member(
):
"""删除工作空间成员"""
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
_check_workspace_admin_permission(db, workspace_id, user)
_check_workspace_admin_permission(db, workspace_id, user)
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
if not workspace_member:
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
if workspace_member.workspace_id != workspace_id:
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
try:
try:
workspace_member.is_active = False
workspace_member.user.current_workspace_id = None
db.commit()
db.commit()
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
except Exception as e:
db.rollback()
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
@@ -94,7 +91,7 @@ def _create_workspace_only(
db: Session, workspace: WorkspaceCreate, owner: User
) -> Workspace:
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
try:
# Create the workspace without adding any members
business_logger.debug(f"创建工作空间: {workspace.name}")
@@ -126,7 +123,7 @@ def create_workspace(
business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}")
db.commit()
db.refresh(db_workspace)
# 如果 storage_type 是 "rag",自动创建知识库
if workspace.storage_type == "rag":
business_logger.info(
@@ -138,7 +135,7 @@ def create_workspace(
from app.schemas.knowledge_schema import KnowledgeCreate
from app.models.knowledge_model import KnowledgeType, PermissionType
from app.repositories import knowledge_repository
# 创建知识库数据
knowledge_data = KnowledgeCreate(
workspace_id=db_workspace.id,
@@ -162,10 +159,10 @@ def create_workspace(
"html4excel": False
}
)
# 直接使用 repository 创建知识库,避免 service 层的额外逻辑
db_knowledge = knowledge_repository.create_knowledge(
db=db,
db=db,
knowledge=knowledge_data
)
db.commit()
@@ -179,12 +176,12 @@ def create_workspace(
)
db.rollback()
raise BusinessException(
f"工作空间创建成功,但知识库创建失败: {str(kb_error)}",
f"工作空间创建成功,但知识库创建失败: {str(kb_error)}",
BizCode.INTERNAL_ERROR
)
return db_workspace
except Exception as e:
business_logger.error(f"工作空间创建失败: {workspace.name} - {str(e)}")
db.rollback()
@@ -195,7 +192,7 @@ def update_workspace(
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
) -> Workspace:
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
db_workspace = _check_workspace_admin_permission(db,workspace_id,user)
try:
# 更新工作空间
@@ -219,8 +216,8 @@ def get_workspace_members(
db: Session, workspace_id: uuid.UUID, user: User
) -> List[WorkspaceMember]:
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
# 查找工作空间
business_logger.debug(f"查找工作空间: {workspace_id}")
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
@@ -237,10 +234,10 @@ def get_workspace_members(
db=db, user_id=user.id, workspace_id=workspace_id
)
workspace_memberships = {workspace_id} if member else set()
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
resource = Resource.from_workspace(workspace)
try:
permission_service.require_permission(
subject,
@@ -265,7 +262,7 @@ def get_workspace_members(
def _generate_invite_token() -> tuple[str, str]:
"""生成邀请令牌和其哈希值
Returns:
tuple: (原始令牌, 令牌哈希)
"""
@@ -285,21 +282,21 @@ def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, use
message="Workspace not found",
code=BizCode.WORKSPACE_NOT_FOUND
)
# 使用统一权限服务检查访问权限
from app.core.permissions import permission_service, Subject, Resource, Action
# 获取用户的工作空间成员关系
member = workspace_repository.get_member_in_workspace(
db=db, user_id=user.id, workspace_id=workspace_id
)
# 任何成员都有访问权限
workspace_memberships = {workspace_id} if member else set()
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
resource = Resource.from_workspace(db_workspace)
try:
permission_service.require_permission(
subject,
@@ -323,21 +320,21 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
message="Workspace not found",
code=BizCode.WORKSPACE_NOT_FOUND
)
# 使用统一权限服务检查管理权限
from app.core.permissions import permission_service, Subject, Resource, Action
# 获取用户的工作空间成员关系
member = workspace_repository.get_member_in_workspace(
db=db, user_id=user.id, workspace_id=workspace_id
)
# 只有 manager 才有管理权限
workspace_memberships = {workspace_id} if (member and member.role == WorkspaceRole.manager) else set()
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
resource = Resource.from_workspace(db_workspace)
try:
permission_service.require_permission(
subject,
@@ -353,14 +350,14 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
def create_workspace_invite(
db: Session,
workspace_id: uuid.UUID,
invite_data: WorkspaceInviteCreate,
db: Session,
workspace_id: uuid.UUID,
invite_data: WorkspaceInviteCreate,
user: User
) -> WorkspaceInviteResponse:
"""创建工作空间邀请"""
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
try:
# 检查权限
_check_workspace_admin_permission(db, workspace_id, user)
@@ -368,7 +365,7 @@ def create_workspace_invite(
# 检查被邀请用户是否已经在工作空间中
from app.repositories import user_repository
invited_user = user_repository.get_user_by_email(db, invite_data.email)
if invited_user:
# 用户存在,检查是否已经是工作空间成员
existing_member = workspace_repository.get_member_in_workspace(
@@ -379,14 +376,14 @@ def create_workspace_invite(
if existing_member:
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
# 检查是否已有待处理的邀请
invite_repo = WorkspaceInviteRepository(db)
existing_invite = invite_repo.get_pending_invite_by_email_and_workspace(
email=invite_data.email,
email=invite_data.email,
workspace_id=workspace_id
)
invite_token = None
if existing_invite:
business_logger.info(f"邮箱 {invite_data.email} 在工作空间 {workspace_id} 已有待处理邀请,返回现有邀请")
@@ -409,17 +406,17 @@ def create_workspace_invite(
)
db.commit()
db.refresh(db_invite)
invite_token = token
invite_token = token
invite_obj = existing_invite or db_invite
business_logger.info(f"工作空间邀请创建成功: invite_id={invite_obj.id}, email={invite_data.email}")
# 构造响应
response = WorkspaceInviteResponse.model_validate(invite_obj)
response.invite_token = invite_token
return response
except Exception as e:
db.rollback()
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
@@ -427,8 +424,8 @@ def create_workspace_invite(
def get_workspace_invites(
db: Session,
workspace_id: uuid.UUID,
db: Session,
workspace_id: uuid.UUID,
user: User,
status: Optional[InviteStatus] = None,
limit: int = 50,
@@ -436,15 +433,15 @@ def get_workspace_invites(
) -> List[WorkspaceInviteResponse]:
"""获取工作空间邀请列表"""
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
# 检查工作空间是否存在
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
if not workspace:
raise BusinessException("工作空间不存在", BizCode.WORKSPACE_NOT_FOUND)
# 检查权限
_check_workspace_admin_permission(db, workspace_id, user)
# 获取邀请列表
invite_repo = WorkspaceInviteRepository(db)
invites = invite_repo.get_workspace_invites(
@@ -453,35 +450,35 @@ def get_workspace_invites(
limit=limit,
offset=offset
)
return [WorkspaceInviteResponse.model_validate(invite) for invite in invites]
def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
"""验证邀请令牌"""
business_logger.info("验证邀请令牌")
# 生成令牌哈希
token_hash = hashlib.sha256(token.encode()).hexdigest()
# 查找邀请
invite_repo = WorkspaceInviteRepository(db)
invite = invite_repo.get_invite_by_token_hash(token_hash)
if not invite:
business_logger.warning("邀请令牌无效")
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
# 检查邀请状态和过期时间
now = datetime.datetime.now()
is_expired = invite.expires_at < now or invite.status != InviteStatus.pending
is_valid = not is_expired
# 获取工作空间信息
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
business_logger.info(f"邀请令牌验证完成: valid={is_valid}, expired={is_expired}")
return InviteValidateResponse(
workspace_name=workspace.name,
workspace_id=invite.workspace_id,
@@ -493,32 +490,32 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
def accept_workspace_invite(
db: Session,
accept_request: InviteAcceptRequest,
db: Session,
accept_request: InviteAcceptRequest,
user: User
) -> dict:
"""接受工作空间邀请"""
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
try:
from app.core.config import settings
# 生成令牌哈希
token_hash = hashlib.sha256(accept_request.token.encode()).hexdigest()
# 查找邀请
invite_repo = WorkspaceInviteRepository(db)
invite = invite_repo.get_invite_by_token_hash(token_hash)
if not invite:
business_logger.warning("邀请令牌无效")
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
# 检查邀请状态
if invite.status != InviteStatus.pending:
business_logger.warning(f"邀请已被处理: status={invite.status}")
raise BusinessException(f"邀请已被{invite.status}", BizCode.WORKSPACE_INVITE_INVALID)
# 检查过期时间
now = datetime.datetime.now()
if invite.expires_at < now:
@@ -526,31 +523,31 @@ def accept_workspace_invite(
# 标记为过期
invite_repo.update_invite_status(invite.id, InviteStatus.expired)
raise BusinessException("邀请已过期", BizCode.WORKSPACE_INVITE_EXPIRED)
# 检查邮箱是否匹配
if invite.email != user.email:
business_logger.warning(f"邮箱不匹配: invite_email={invite.email}, user_email={user.email}")
raise BusinessException("邮箱与邀请邮箱不匹配", BizCode.FORBIDDEN)
# 如果启用单工作空间模式,检查用户是否已有工作空间
if settings.ENABLE_SINGLE_WORKSPACE:
user_workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
if user_workspaces:
business_logger.warning(f"单工作空间模式下用户已有工作空间: user={user.username}")
raise BusinessException("用户只能加入一个工作空间", BizCode.FORBIDDEN)
# 检查用户是否已经是工作空间成员
existing_member = workspace_repository.get_member_in_workspace(
db=db,
user_id=user.id,
db=db,
user_id=user.id,
workspace_id=invite.workspace_id
)
if existing_member:
business_logger.info("用户已是工作空间成员,更新邀请状态")
invite_repo.update_invite_status(
invite.id,
InviteStatus.accepted,
invite.id,
InviteStatus.accepted,
accepted_at=now
)
db.commit()
@@ -559,10 +556,10 @@ def accept_workspace_invite(
"message": "You are already a member of this workspace",
"workspace": workspace
}
# 将角色映射到工作空间角色(现在直接使用相同的角色)
workspace_role = invite.role
# 添加用户到工作空间
workspace_repository.add_member_to_workspace(
db=db,
@@ -570,27 +567,27 @@ def accept_workspace_invite(
workspace_id=invite.workspace_id,
role=workspace_role
)
# 标记邀请为已接受
invite_repo.update_invite_status(
invite.id,
InviteStatus.accepted,
invite.id,
InviteStatus.accepted,
accepted_at=now
)
db.commit()
# 获取工作空间信息
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
return {
"message": "Successfully joined the workspace",
"workspace": workspace,
"role": workspace_role
}
except Exception as e:
db.rollback()
business_logger.error(f"接受工作空间邀请失败: user={user.username} - {str(e)}")
@@ -598,34 +595,34 @@ def accept_workspace_invite(
def revoke_workspace_invite(
db: Session,
workspace_id: uuid.UUID,
invite_id: uuid.UUID,
db: Session,
workspace_id: uuid.UUID,
invite_id: uuid.UUID,
user: User
) -> dict:
"""撤销工作空间邀请"""
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
try:
# 检查权限
_check_workspace_admin_permission(db, workspace_id, user)
# 撤销邀请
invite_repo = WorkspaceInviteRepository(db)
invite = invite_repo.revoke_invite(invite_id)
if not invite:
business_logger.warning(f"邀请不存在: invite_id={invite_id}")
raise BusinessException("邀请不存在", BizCode.WORKSPACE_INVITE_NOT_FOUND)
if invite.workspace_id != workspace_id:
business_logger.warning(f"邀请不属于指定工作空间: invite_id={invite_id}, workspace_id={workspace_id}")
raise BusinessException("邀请不属于指定工作空间", BizCode.BAD_REQUEST)
db.commit()
business_logger.info(f"工作空间邀请撤销成功: invite_id={invite_id}")
return {"message": "邀请撤销成功"}
except Exception as e:
db.rollback()
business_logger.error(f"撤销工作空间邀请失败: invite_id={invite_id} - {str(e)}")
@@ -640,48 +637,48 @@ def update_workspace_member_roles(
) -> List[WorkspaceMember]:
"""更新工作空间成员角色"""
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
# 检查管理员权限
_check_workspace_admin_permission(db, workspace_id, user)
# 获取所有当前成员
all_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
member_map = {m.id: m for m in all_members}
# 验证和业务规则检查
update_ids = set()
for upd in updates:
# 检查成员是否存在
if upd.id not in member_map:
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
member = member_map[upd.id]
# 检查成员是否属于该工作空间
if member.workspace_id != workspace_id:
raise BusinessException(f"成员 {upd.id} 不属于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
# 不能修改自己的角色
if member.user_id == user.id:
raise BusinessException("不能修改自己的角色", BizCode.BAD_REQUEST)
update_ids.add(upd.id)
# 检查是否至少保留一个 manager
current_managers = [m for m in all_members if m.role == WorkspaceRole.manager]
managers_after_update = [
m for m in all_members
m for m in all_members
if m.id not in update_ids and m.role == WorkspaceRole.manager
]
# 添加更新后会成为 manager 的成员
for upd in updates:
if upd.role == WorkspaceRole.manager:
managers_after_update.append(member_map[upd.id])
if len(managers_after_update) == 0:
raise BusinessException("工作空间至少需要一个管理员", BizCode.BAD_REQUEST)
# 执行更新
try:
for upd in updates:
@@ -691,15 +688,15 @@ def update_workspace_member_roles(
role=upd.role,
)
business_logger.debug(f"更新成员 {upd.id} 角色为 {upd.role}")
db.commit()
# 重新获取更新后的成员列表
updated_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
business_logger.info(f"成员角色更新完成: workspace_id={workspace_id}, 更新数量={len(updates)}")
return updated_members
except Exception as e:
db.rollback()
business_logger.error(f"更新工作空间成员角色失败: workspace_id={workspace_id} - {str(e)}")
@@ -789,7 +786,7 @@ def get_workspace_models_configs(
# 查询工作空间模型配置
configs = workspace_repository.get_workspace_models_configs(db=db, workspace_id=workspace_id)
if configs is None:
business_logger.error(f"工作空间不存在: workspace_id={workspace_id}")
raise BusinessException(
@@ -801,4 +798,5 @@ def get_workspace_models_configs(
f"成功获取工作空间 {workspace_id} 的模型配置: "
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
)
return configs
return configs