[MODIFY] Code optimization
This commit is contained in:
@@ -5,9 +5,10 @@ import uuid
|
||||
import hashlib
|
||||
import time
|
||||
import jwt
|
||||
from app.services import task_service, workspace_service
|
||||
from typing import Optional, Dict
|
||||
from functools import wraps
|
||||
|
||||
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -21,8 +22,10 @@ from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.repositories.app_repository import AppRepository
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.repositories import knowledge_repository
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -95,7 +98,7 @@ def get_access_token(
|
||||
access_token = create_access_token(user_id, share_token)
|
||||
|
||||
logger.info(
|
||||
f"生成访问 token",
|
||||
"生成访问 token",
|
||||
extra={
|
||||
"share_token": share_token,
|
||||
"user_id": user_id
|
||||
@@ -270,7 +273,7 @@ def get_conversation(
|
||||
async def chat(
|
||||
payload: conversation_schema.ChatRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""发送消息并获取回复
|
||||
|
||||
@@ -313,6 +316,45 @@ async def chat(
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
)
|
||||
|
||||
|
||||
appid=share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(App.id == appid).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
db=db,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
# 获取应用类型
|
||||
app_type = release.app.type if release.app else None
|
||||
|
||||
@@ -339,7 +381,7 @@ async def chat(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"参数验证完成",
|
||||
"参数验证完成",
|
||||
extra={
|
||||
"share_token": share_token,
|
||||
"app_type": app_type,
|
||||
@@ -365,7 +407,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -388,7 +432,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
@@ -403,7 +449,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -426,7 +474,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
|
||||
Reference in New Issue
Block a user