[modify] share chat
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Annotated
|
||||||
from fastapi import APIRouter, Depends, Query, Request
|
from fastapi import APIRouter, Depends, Query, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -17,6 +17,8 @@ from app.services.auth_service import create_access_token
|
|||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.release_share_service import ReleaseShareService
|
from app.services.release_share_service import ReleaseShareService
|
||||||
from app.services.shared_chat_service import SharedChatService
|
from app.services.shared_chat_service import SharedChatService
|
||||||
|
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||||
|
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||||
|
|
||||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -265,7 +267,8 @@ def get_conversation(
|
|||||||
async def chat(
|
async def chat(
|
||||||
payload: conversation_schema.ChatRequest,
|
payload: conversation_schema.ChatRequest,
|
||||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||||
):
|
):
|
||||||
"""发送消息并获取回复
|
"""发送消息并获取回复
|
||||||
|
|
||||||
@@ -307,7 +310,7 @@ async def chat(
|
|||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=user_id # Save original user_id to other_id
|
original_user_id=user_id # Save original user_id to other_id
|
||||||
)
|
)
|
||||||
|
end_user_id = str(new_end_user.id)
|
||||||
|
|
||||||
appid=share.app_id
|
appid=share.app_id
|
||||||
"""获取存储类型和工作空间的ID"""
|
"""获取存储类型和工作空间的ID"""
|
||||||
@@ -390,15 +393,38 @@ async def chat(
|
|||||||
if app_type == AppType.AGENT:
|
if app_type == AppType.AGENT:
|
||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
|
# async def event_generator():
|
||||||
|
# async for event in service.chat_stream(
|
||||||
|
# share_token=share_token,
|
||||||
|
# message=payload.message,
|
||||||
|
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
# user_id=str(new_end_user.id), # 转换为字符串
|
||||||
|
# variables=payload.variables,
|
||||||
|
# password=password,
|
||||||
|
# web_search=payload.web_search,
|
||||||
|
# memory=payload.memory,
|
||||||
|
# storage_type=storage_type,
|
||||||
|
# user_rag_memory_id=user_rag_memory_id
|
||||||
|
# ):
|
||||||
|
# yield event
|
||||||
|
|
||||||
|
# return StreamingResponse(
|
||||||
|
# event_generator(),
|
||||||
|
# media_type="text/event-stream",
|
||||||
|
# headers={
|
||||||
|
# "Cache-Control": "no-cache",
|
||||||
|
# "Connection": "keep-alive",
|
||||||
|
# "X-Accel-Buffering": "no"
|
||||||
|
# }
|
||||||
|
# )
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in service.chat_stream(
|
async for event in app_chat_service.agnet_chat_stream(
|
||||||
share_token=share_token,
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=str(new_end_user.id), # 转换为字符串
|
user_id= str(new_end_user.id), # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
password=password,
|
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
|
config=payload.agent_config,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
@@ -414,32 +440,43 @@ async def chat(
|
|||||||
"X-Accel-Buffering": "no"
|
"X-Accel-Buffering": "no"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 非流式返回
|
# 非流式返回
|
||||||
result = await service.chat(
|
# result = await service.chat(
|
||||||
share_token=share_token,
|
# share_token=share_token,
|
||||||
|
# message=payload.message,
|
||||||
|
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
# user_id=str(new_end_user.id), # 转换为字符串
|
||||||
|
# variables=payload.variables,
|
||||||
|
# password=password,
|
||||||
|
# web_search=payload.web_search,
|
||||||
|
# memory=payload.memory,
|
||||||
|
# storage_type=storage_type,
|
||||||
|
# user_rag_memory_id=user_rag_memory_id
|
||||||
|
# )
|
||||||
|
# return success(data=conversation_schema.ChatResponse(**result))
|
||||||
|
result = await app_chat_service.agnet_chat(
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=str(new_end_user.id), # 转换为字符串
|
user_id=str(new_end_user.id), # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
password=password,
|
config= payload.agent_config,
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
return success(data=conversation_schema.ChatResponse(**result))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
# 多 Agent 流式返回
|
config = multi_agent_config_4_app_release(release)
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in service.multi_agent_chat_stream(
|
async for event in app_chat_service.multi_agent_chat_stream(
|
||||||
share_token=share_token,
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=str(new_end_user.id), # 转换为字符串
|
user_id=str(new_end_user.id), # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
password=password,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
@@ -458,20 +495,62 @@ async def chat(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 多 Agent 非流式返回
|
# 多 Agent 非流式返回
|
||||||
result = await service.multi_agent_chat(
|
result = await app_chat_service.multi_agent_chat(
|
||||||
share_token=share_token,
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=str(new_end_user.id), # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
password=password,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
|
|
||||||
return success(data=conversation_schema.ChatResponse(**result))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
|
# 多 Agent 流式返回
|
||||||
|
# if payload.stream:
|
||||||
|
# async def event_generator():
|
||||||
|
# async for event in service.multi_agent_chat_stream(
|
||||||
|
# share_token=share_token,
|
||||||
|
# message=payload.message,
|
||||||
|
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
# user_id=str(new_end_user.id), # 转换为字符串
|
||||||
|
# variables=payload.variables,
|
||||||
|
# password=password,
|
||||||
|
# web_search=payload.web_search,
|
||||||
|
# memory=payload.memory,
|
||||||
|
# storage_type=storage_type,
|
||||||
|
# user_rag_memory_id=user_rag_memory_id
|
||||||
|
# ):
|
||||||
|
# yield event
|
||||||
|
|
||||||
|
# return StreamingResponse(
|
||||||
|
# event_generator(),
|
||||||
|
# media_type="text/event-stream",
|
||||||
|
# headers={
|
||||||
|
# "Cache-Control": "no-cache",
|
||||||
|
# "Connection": "keep-alive",
|
||||||
|
# "X-Accel-Buffering": "no"
|
||||||
|
# }
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # 多 Agent 非流式返回
|
||||||
|
# result = await service.multi_agent_chat(
|
||||||
|
# share_token=share_token,
|
||||||
|
# message=payload.message,
|
||||||
|
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
|
# user_id=str(new_end_user.id), # 转换为字符串
|
||||||
|
# variables=payload.variables,
|
||||||
|
# password=password,
|
||||||
|
# web_search=payload.web_search,
|
||||||
|
# memory=payload.memory,
|
||||||
|
# storage_type=storage_type,
|
||||||
|
# user_rag_memory_id=user_rag_memory_id
|
||||||
|
# )
|
||||||
|
|
||||||
|
# return success(data=conversation_schema.ChatResponse(**result))
|
||||||
else:
|
else:
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
|
|||||||
@@ -137,10 +137,10 @@ async def chat(
|
|||||||
|
|
||||||
if app_type == AppType.AGENT:
|
if app_type == AppType.AGENT:
|
||||||
|
|
||||||
print("="*50)
|
# print("="*50)
|
||||||
print(app.current_release.default_model_config_id)
|
# print(app.current_release.default_model_config_id)
|
||||||
agent_config = agent_config_4_app_release(app.current_release)
|
agent_config = agent_config_4_app_release(app.current_release)
|
||||||
print(agent_config.default_model_config_id)
|
# print(agent_config.default_model_config_id)
|
||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
|
|||||||
Reference in New Issue
Block a user