[modify] share chat

This commit is contained in:
Mark
2026-01-04 20:51:37 +08:00
parent 3a3cd59d8e
commit 3fe2ef6611
2 changed files with 104 additions and 25 deletions

View File

@@ -1,6 +1,6 @@
import hashlib import hashlib
import uuid import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -17,6 +17,8 @@ from app.services.auth_service import create_access_token
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService from app.services.shared_chat_service import SharedChatService
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
router = APIRouter(prefix="/public/share", tags=["Public Share"]) router = APIRouter(prefix="/public/share", tags=["Public Share"])
logger = get_business_logger() logger = get_business_logger()
@@ -265,7 +267,8 @@ def get_conversation(
async def chat( async def chat(
payload: conversation_schema.ChatRequest, payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id), share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db) db: Session = Depends(get_db),
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
): ):
"""发送消息并获取回复 """发送消息并获取回复
@@ -285,7 +288,7 @@ async def chat(
password = None # Token 认证不需要密码 password = None # Token 认证不需要密码
# end_user_id = user_id # end_user_id = user_id
other_id = user_id other_id = user_id
# 提前验证和准备(在流式响应开始前完成) # 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错 # 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType from app.models.app_model import AppType
@@ -307,7 +310,7 @@ async def chat(
other_id=other_id, other_id=other_id,
original_user_id=user_id # Save original user_id to other_id original_user_id=user_id # Save original user_id to other_id
) )
end_user_id = str(new_end_user.id)
appid=share.app_id appid=share.app_id
"""获取存储类型和工作空间的ID""" """获取存储类型和工作空间的ID"""
@@ -390,15 +393,38 @@ async def chat(
if app_type == AppType.AGENT: if app_type == AppType.AGENT:
# 流式返回 # 流式返回
if payload.stream: if payload.stream:
# async def event_generator():
# async for event in service.chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
async def event_generator(): async def event_generator():
async for event in service.chat_stream( async for event in app_chat_service.agnet_chat_stream(
share_token=share_token,
message=payload.message, message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串 user_id= str(new_end_user.id), # 转换为字符串
variables=payload.variables, variables=payload.variables,
password=password,
web_search=payload.web_search, web_search=payload.web_search,
config=payload.agent_config,
memory=payload.memory, memory=payload.memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
@@ -414,32 +440,43 @@ async def chat(
"X-Accel-Buffering": "no" "X-Accel-Buffering": "no"
} }
) )
# 非流式返回 # 非流式返回
result = await service.chat( # result = await service.chat(
share_token=share_token, # share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
result = await app_chat_service.agnet_chat(
message=payload.message, message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串 user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables, variables=payload.variables,
password=password, config= payload.agent_config,
web_search=payload.web_search, web_search=payload.web_search,
memory=payload.memory, memory=payload.memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
) )
return success(data=conversation_schema.ChatResponse(**result)) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT: elif app_type == AppType.MULTI_AGENT:
# 多 Agent 流式返回 config = multi_agent_config_4_app_release(release)
if payload.stream: if payload.stream:
async def event_generator(): async def event_generator():
async for event in service.multi_agent_chat_stream( async for event in app_chat_service.multi_agent_chat_stream(
share_token=share_token,
message=payload.message, message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串 user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables, variables=payload.variables,
password=password, config=config,
web_search=payload.web_search, web_search=payload.web_search,
memory=payload.memory, memory=payload.memory,
storage_type=storage_type, storage_type=storage_type,
@@ -458,20 +495,62 @@ async def chat(
) )
# 多 Agent 非流式返回 # 多 Agent 非流式返回
result = await service.multi_agent_chat( result = await app_chat_service.multi_agent_chat(
share_token=share_token,
message=payload.message, message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串 user_id=end_user_id, # 转换为字符串
variables=payload.variables, variables=payload.variables,
password=password, config=config,
web_search=payload.web_search, web_search=payload.web_search,
memory=payload.memory, memory=payload.memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
) )
return success(data=conversation_schema.ChatResponse(**result)) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
# 多 Agent 流式返回
# if payload.stream:
# async def event_generator():
# async for event in service.multi_agent_chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
# # 多 Agent 非流式返回
# result = await service.multi_agent_chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
else: else:
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode from app.core.error_codes import BizCode

View File

@@ -137,10 +137,10 @@ async def chat(
if app_type == AppType.AGENT: if app_type == AppType.AGENT:
print("="*50) # print("="*50)
print(app.current_release.default_model_config_id) # print(app.current_release.default_model_config_id)
agent_config = agent_config_4_app_release(app.current_release) agent_config = agent_config_4_app_release(app.current_release)
print(agent_config.default_model_config_id) # print(agent_config.default_model_config_id)
# 流式返回 # 流式返回
if payload.stream: if payload.stream:
async def event_generator(): async def event_generator():