diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index e2e5a250..a7a6203d 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -1,6 +1,6 @@ import hashlib import uuid - +from typing import Annotated from fastapi import APIRouter, Depends, Query, Request from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session @@ -17,6 +17,8 @@ from app.services.auth_service import create_access_token from app.services.conversation_service import ConversationService from app.services.release_share_service import ReleaseShareService from app.services.shared_chat_service import SharedChatService +from app.services.app_chat_service import AppChatService, get_app_chat_service +from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release router = APIRouter(prefix="/public/share", tags=["Public Share"]) logger = get_business_logger() @@ -265,7 +267,8 @@ def get_conversation( async def chat( payload: conversation_schema.ChatRequest, share_data: ShareTokenData = Depends(get_share_user_id), - db: Session = Depends(get_db) + db: Session = Depends(get_db), + app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None, ): """发送消息并获取回复 @@ -285,7 +288,7 @@ async def chat( password = None # Token 认证不需要密码 # end_user_id = user_id other_id = user_id - + # 提前验证和准备(在流式响应开始前完成) # 这样可以确保错误能正确返回,而不是在流式响应中间出错 from app.models.app_model import AppType @@ -307,7 +310,7 @@ async def chat( other_id=other_id, original_user_id=user_id # Save original user_id to other_id ) - + end_user_id = str(new_end_user.id) appid=share.app_id """获取存储类型和工作空间的ID""" @@ -390,15 +393,38 @@ async def chat( if app_type == AppType.AGENT: # 流式返回 if payload.stream: + # async def event_generator(): + # async for event in service.chat_stream( + # share_token=share_token, + # message=payload.message, + # conversation_id=conversation.id, # 使用已创建的会话 ID + # user_id=str(new_end_user.id), # 转换为字符串 + # variables=payload.variables, + # password=password, + # web_search=payload.web_search, + # memory=payload.memory, + # storage_type=storage_type, + # user_rag_memory_id=user_rag_memory_id + # ): + # yield event + + # return StreamingResponse( + # event_generator(), + # media_type="text/event-stream", + # headers={ + # "Cache-Control": "no-cache", + # "Connection": "keep-alive", + # "X-Accel-Buffering": "no" + # } + # ) async def event_generator(): - async for event in service.chat_stream( - share_token=share_token, + async for event in app_chat_service.agnet_chat_stream( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=str(new_end_user.id), # 转换为字符串 + user_id= str(new_end_user.id), # 转换为字符串 variables=payload.variables, - password=password, web_search=payload.web_search, + config=payload.agent_config, memory=payload.memory, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id @@ -414,32 +440,43 @@ async def chat( "X-Accel-Buffering": "no" } ) - # 非流式返回 - result = await service.chat( - share_token=share_token, + # result = await service.chat( + # share_token=share_token, + # message=payload.message, + # conversation_id=conversation.id, # 使用已创建的会话 ID + # user_id=str(new_end_user.id), # 转换为字符串 + # variables=payload.variables, + # password=password, + # web_search=payload.web_search, + # memory=payload.memory, + # storage_type=storage_type, + # user_rag_memory_id=user_rag_memory_id + # ) + # return success(data=conversation_schema.ChatResponse(**result)) + result = await app_chat_service.agnet_chat( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID user_id=str(new_end_user.id), # 转换为字符串 variables=payload.variables, - password=password, + config= payload.agent_config, web_search=payload.web_search, memory=payload.memory, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ) - return success(data=conversation_schema.ChatResponse(**result)) + return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.MULTI_AGENT: - # 多 Agent 流式返回 + config = multi_agent_config_4_app_release(release) if payload.stream: async def event_generator(): - async for event in service.multi_agent_chat_stream( - share_token=share_token, + async for event in app_chat_service.multi_agent_chat_stream( + message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID user_id=str(new_end_user.id), # 转换为字符串 variables=payload.variables, - password=password, + config=config, web_search=payload.web_search, memory=payload.memory, storage_type=storage_type, @@ -458,20 +495,62 @@ async def chat( ) # 多 Agent 非流式返回 - result = await service.multi_agent_chat( - share_token=share_token, + result = await app_chat_service.multi_agent_chat( + message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID - user_id=str(new_end_user.id), # 转换为字符串 + user_id=end_user_id, # 转换为字符串 variables=payload.variables, - password=password, + config=config, web_search=payload.web_search, memory=payload.memory, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ) - return success(data=conversation_schema.ChatResponse(**result)) + return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) + # 多 Agent 流式返回 + # if payload.stream: + # async def event_generator(): + # async for event in service.multi_agent_chat_stream( + # share_token=share_token, + # message=payload.message, + # conversation_id=conversation.id, # 使用已创建的会话 ID + # user_id=str(new_end_user.id), # 转换为字符串 + # variables=payload.variables, + # password=password, + # web_search=payload.web_search, + # memory=payload.memory, + # storage_type=storage_type, + # user_rag_memory_id=user_rag_memory_id + # ): + # yield event + + # return StreamingResponse( + # event_generator(), + # media_type="text/event-stream", + # headers={ + # "Cache-Control": "no-cache", + # "Connection": "keep-alive", + # "X-Accel-Buffering": "no" + # } + # ) + + # # 多 Agent 非流式返回 + # result = await service.multi_agent_chat( + # share_token=share_token, + # message=payload.message, + # conversation_id=conversation.id, # 使用已创建的会话 ID + # user_id=str(new_end_user.id), # 转换为字符串 + # variables=payload.variables, + # password=password, + # web_search=payload.web_search, + # memory=payload.memory, + # storage_type=storage_type, + # user_rag_memory_id=user_rag_memory_id + # ) + + # return success(data=conversation_schema.ChatResponse(**result)) else: from app.core.exceptions import BusinessException from app.core.error_codes import BizCode diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index e190115a..5a78a28b 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -137,10 +137,10 @@ async def chat( if app_type == AppType.AGENT: - print("="*50) - print(app.current_release.default_model_config_id) + # print("="*50) + # print(app.current_release.default_model_config_id) agent_config = agent_config_4_app_release(app.current_release) - print(agent_config.default_model_config_id) + # print(agent_config.default_model_config_id) # 流式返回 if payload.stream: async def event_generator():