[MODIFY] Code optimization

This commit is contained in:
Mark
2025-12-15 14:09:43 +08:00
parent d2a630addb
commit a4e276ab27
157 changed files with 15976 additions and 3601 deletions

View File

@@ -5,9 +5,10 @@ import uuid
import hashlib
import time
import jwt
from app.services import task_service, workspace_service
from typing import Optional, Dict
from functools import wraps
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
from app.db import get_db
from app.core.response_utils import success
from app.core.logging_config import get_business_logger
@@ -21,8 +22,10 @@ from app.services.shared_chat_service import SharedChatService
from app.services.conversation_service import ConversationService
from app.services.auth_service import create_access_token
from app.dependencies import get_share_user_id, ShareTokenData
from app.models.user_model import User
from app.repositories.app_repository import AppRepository
from app.repositories.workspace_repository import WorkspaceRepository
from app.repositories import knowledge_repository
router = APIRouter(prefix="/public/share", tags=["Public Share"])
logger = get_business_logger()
@@ -95,7 +98,7 @@ def get_access_token(
access_token = create_access_token(user_id, share_token)
logger.info(
f"生成访问 token",
"生成访问 token",
extra={
"share_token": share_token,
"user_id": user_id
@@ -270,7 +273,7 @@ def get_conversation(
async def chat(
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
db: Session = Depends(get_db)
):
"""发送消息并获取回复
@@ -313,6 +316,45 @@ async def chat(
original_user_id=user_id # Save original user_id to other_id
)
appid=share.app_id
"""获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app
from app.models.app_model import App
app = db.query(App).filter(App.id == appid).first()
if not app:
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
workspace_id = app.workspace_id
# 直接从 workspace 获取 storage_type公开分享场景无需权限检查
storage_type = workspace_service.get_workspace_storage_type_without_auth(
db=db,
workspace_id=workspace_id
)
if storage_type is None:
storage_type = 'neo4j'
user_rag_memory_id = ''
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag':
if workspace_id:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j'
else:
logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
# 获取应用类型
app_type = release.app.type if release.app else None
@@ -339,7 +381,7 @@ async def chat(
)
logger.debug(
f"参数验证完成",
"参数验证完成",
extra={
"share_token": share_token,
"app_type": app_type,
@@ -365,7 +407,9 @@ async def chat(
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -388,7 +432,9 @@ async def chat(
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.MULTI_AGENT:
@@ -403,7 +449,9 @@ async def chat(
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -426,7 +474,9 @@ async def chat(
variables=payload.variables,
password=password,
web_search=payload.web_search,
memory=payload.memory
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
return success(data=conversation_schema.ChatResponse(**result))