feat(app): add API to retrieve app configuration fields

This commit is contained in:
Eternity
2026-03-03 10:27:01 +08:00
parent 7cec966979
commit 07fea23dd0
3 changed files with 243 additions and 214 deletions

View File

@@ -2,25 +2,32 @@ import hashlib
import json import json
import uuid import uuid
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.response_utils import success from app.core.response_utils import success, fail
from app.db import get_db, get_db_read from app.db import get_db, get_db_read
from app.dependencies import get_share_user_id, ShareTokenData from app.dependencies import get_share_user_id, ShareTokenData
from app.models.app_model import App
from app.models.app_model import AppType
from app.repositories import knowledge_repository from app.repositories import knowledge_repository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.workflow_repository import WorkflowConfigRepository from app.repositories.workflow_repository import WorkflowConfigRepository
from app.schemas import release_share_schema, conversation_schema from app.schemas import release_share_schema, conversation_schema
from app.schemas.response_schema import PageData, PageMeta from app.schemas.response_schema import PageData, PageMeta
from app.services import workspace_service from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.auth_service import create_access_token from app.services.auth_service import create_access_token
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService from app.services.shared_chat_service import SharedChatService
from app.services.app_chat_service import AppChatService, get_app_chat_service from app.services.workflow_service import WorkflowService
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \ from app.utils.app_config_utils import workflow_config_4_app_release, \
agent_config_4_app_release, multi_agent_config_4_app_release agent_config_4_app_release, multi_agent_config_4_app_release
router = APIRouter(prefix="/public/share", tags=["Public Share"]) router = APIRouter(prefix="/public/share", tags=["Public Share"])
@@ -206,15 +213,13 @@ def list_conversations(
logger.debug(f"share_data:{share_data.user_id}") logger.debug(f"share_data:{share_data.user_id}")
other_id = share_data.user_id other_id = share_data.user_id
service = SharedChatService(db) service = SharedChatService(db)
share, release = service._get_release_by_share_token(share_data.share_token, password) share, release = service.get_release_by_share_token(share_data.share_token, password)
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id, app_id=share.app_id,
other_id=other_id other_id=other_id
) )
logger.debug(new_end_user.id) logger.debug(new_end_user.id)
service = SharedChatService(db)
conversations, total = service.list_conversations( conversations, total = service.list_conversations(
share_token=share_data.share_token, share_token=share_data.share_token,
user_id=str(new_end_user.id), user_id=str(new_end_user.id),
@@ -293,19 +298,15 @@ async def chat(
# 提前验证和准备(在流式响应开始前完成) # 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错 # 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType
try: try:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.app_service import AppService
# 验证分享链接和密码 # 验证分享链接和密码
share, release = service._get_release_by_share_token(share_token, password) share, release = service.get_release_by_share_token(share_token, password)
# # Create end_user_id by concatenating app_id with user_id # # Create end_user_id by concatenating app_id with user_id
# end_user_id = f"{share.app_id}_{user_id}" # end_user_id = f"{share.app_id}_{user_id}"
# Store end_user_id in database with original user_id # Store end_user_id in database with original user_id
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id, app_id=share.app_id,
@@ -318,7 +319,6 @@ async def chat(
"""获取存储类型和工作空间的ID""" """获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app仅查询未删除的应用 # 直接通过 SQLAlchemy 查询 app仅查询未删除的应用
from app.models.app_model import App
app = db.query(App).filter( app = db.query(App).filter(
App.id == appid, App.id == appid,
App.is_active.is_(True) App.is_active.is_(True)
@@ -359,12 +359,12 @@ async def chat(
app_type = release.app.type if release.app else None app_type = release.app.type if release.app else None
# 根据应用类型验证配置 # 根据应用类型验证配置
if app_type == "agent": if app_type == AppType.AGENT:
# Agent 类型:验证模型配置 # Agent 类型:验证模型配置
model_config_id = release.default_model_config_id model_config_id = release.default_model_config_id
if not model_config_id: if not model_config_id:
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif app_type == "multi_agent": elif app_type == AppType.MULTI_AGENT:
# Multi-Agent 类型:验证多 Agent 配置 # Multi-Agent 类型:验证多 Agent 配置
config = release.config or {} config = release.config or {}
if not config.get("sub_agents"): if not config.get("sub_agents"):
@@ -638,6 +638,34 @@ async def chat(
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) # return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
else: else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
@router.get("/config", summary="获取应用启动配置")
async def config_query(
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
share_service = SharedChatService(db)
share_token = share_data.share_token
share, release = share_service.get_release_by_share_token(share_token, password)
if release.app.type == AppType.WORKFLOW:
workflow_service = WorkflowService(db)
content = {
"app_type": release.app.type,
"variables": workflow_service.get_start_node_variables(release.config)
}
elif release.app.type == AppType.AGENT:
content = {
"app_type": release.app.type,
"variables": release.config.get("variables")
}
elif release.app.type == AppType.MULTI_AGENT:
content = {
"app_type": release.app.type,
"variables": []
}
else:
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
return success(data=content)

View File

@@ -21,63 +21,64 @@ from app.repositories import knowledge_repository
import json import json
from app.services.task_service import get_task_memory_write_result from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task from app.tasks import write_message_task
logger = get_business_logger() logger = get_business_logger()
class SharedChatService: class SharedChatService:
"""基于分享链接的聊天服务""" """基于分享链接的聊天服务"""
def __init__(self, db: Session): def __init__(self, db: Session):
self.db = db self.db = db
self.conversation_service = ConversationService(db) self.conversation_service = ConversationService(db)
self.share_service = ReleaseShareService(db) self.share_service = ReleaseShareService(db)
def _get_release_by_share_token( def get_release_by_share_token(
self, self,
share_token: str, share_token: str,
password: Optional[str] = None password: Optional[str] = None
) -> tuple[ReleaseShare, AppRelease]: ) -> tuple[ReleaseShare, AppRelease]:
"""通过 share_token 获取发布版本""" """通过 share_token 获取发布版本"""
# 获取分享配置 # 获取分享配置
share = self.share_service.repo.get_by_share_token(share_token) share = self.share_service.repo.get_by_share_token(share_token)
if not share: if not share:
raise ResourceNotFoundException("分享链接", share_token) raise ResourceNotFoundException("分享链接", share_token)
# 验证分享是否启用 # 验证分享是否启用
if not share.is_enabled: if not share.is_enabled:
raise BusinessException("该分享链接已被禁用", BizCode.SHARE_DISABLED) raise BusinessException("该分享链接已被禁用", BizCode.SHARE_DISABLED)
# 验证密码 # 验证密码
if share.require_password: if share.require_password:
if not password: if not password:
raise BusinessException("需要提供访问密码", BizCode.PASSWORD_REQUIRED) raise BusinessException("需要提供访问密码", BizCode.PASSWORD_REQUIRED)
if not self.share_service.verify_password(share_token, password): if not self.share_service.verify_password(share_token, password):
raise BusinessException("访问密码错误", BizCode.INVALID_PASSWORD) raise BusinessException("访问密码错误", BizCode.INVALID_PASSWORD)
# 获取发布版本 # 获取发布版本
release = self.db.get(AppRelease, share.release_id) release = self.db.get(AppRelease, share.release_id)
if not release: if not release:
raise ResourceNotFoundException("发布版本", str(share.release_id)) raise ResourceNotFoundException("发布版本", str(share.release_id))
# 更新访问统计 # 更新访问统计
try: try:
self.share_service.repo.increment_view_count(share.id) self.share_service.repo.increment_view_count(share.id)
except Exception as e: except Exception as e:
logger.warning(f"更新访问统计失败: {str(e)}") logger.warning(f"更新访问统计失败: {str(e)}")
return share, release return share, release
def create_or_get_conversation( def create_or_get_conversation(
self, self,
share_token: str, share_token: str,
conversation_id: Optional[uuid.UUID] = None, conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
password: Optional[str] = None password: Optional[str] = None
) -> Conversation: ) -> Conversation:
"""创建或获取会话""" """创建或获取会话"""
share, release = self._get_release_by_share_token(share_token, password) share, release = self.get_release_by_share_token(share_token, password)
# 如果提供了 conversation_id尝试获取现有会话 # 如果提供了 conversation_id尝试获取现有会话
if conversation_id: if conversation_id:
try: try:
@@ -85,18 +86,18 @@ class SharedChatService:
conversation_id=conversation_id, conversation_id=conversation_id,
workspace_id=release.app.workspace_id workspace_id=release.app.workspace_id
) )
# 验证会话是否属于该应用 # 验证会话是否属于该应用
if conversation.app_id != release.app_id: if conversation.app_id != release.app_id:
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION) raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
return conversation return conversation
except ResourceNotFoundException: except ResourceNotFoundException:
logger.warning( logger.warning(
"会话不存在,将创建新会话", "会话不存在,将创建新会话",
extra={"conversation_id": str(conversation_id)} extra={"conversation_id": str(conversation_id)}
) )
# 创建新会话(使用发布版本的配置) # 创建新会话(使用发布版本的配置)
conversation = self.conversation_service.create_conversation( conversation = self.conversation_service.create_conversation(
app_id=release.app_id, app_id=release.app_id,
@@ -105,7 +106,7 @@ class SharedChatService:
is_draft=False, # 分享链接使用发布版本 is_draft=False, # 分享链接使用发布版本
config_snapshot=release.config config_snapshot=release.config
) )
logger.info( logger.info(
"为分享链接创建新会话", "为分享链接创建新会话",
extra={ extra={
@@ -114,25 +115,25 @@ class SharedChatService:
"release_id": str(release.id) "release_id": str(release.id)
} }
) )
return conversation return conversation
async def chat( async def chat(
self, self,
share_token: str, share_token: str,
message: str, message: str,
conversation_id: Optional[uuid.UUID] = None, conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None, password: Optional[str] = None,
web_search: bool = False, web_search: bool = False,
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""聊天(非流式)""" """聊天(非流式)"""
actual_config_id = None actual_config_id = None
config_id=actual_config_id config_id = actual_config_id
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_parameter_merger import ModelParameterMerger
@@ -140,32 +141,30 @@ class SharedChatService:
from sqlalchemy import select from sqlalchemy import select
from app.models import ModelApiKey from app.models import ModelApiKey
start_time = time.time() start_time = time.time()
actual_config_id=None actual_config_id = None
config_id=actual_config_id config_id = actual_config_id
if variables is None: if variables is None:
variables = {} variables = {}
# 获取发布版本和配置 # 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password) share, release = self.get_release_by_share_token(share_token, password)
# 获取 Agent 配置 # 获取 Agent 配置
config = release.config or {} config = release.config or {}
# 获取模型配置ID # 获取模型配置ID
model_config_id = release.default_model_config_id model_config_id = release.default_model_config_id
if not model_config_id: if not model_config_id:
raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING)
# 获取模型配置 # 获取模型配置
from app.models import ModelConfig from app.models import ModelConfig
model_config = self.db.get(ModelConfig, model_config_id) model_config = self.db.get(ModelConfig, model_config_id)
if not model_config: if not model_config:
raise ResourceNotFoundException("模型配置", str(model_config_id)) raise ResourceNotFoundException("模型配置", str(model_config_id))
# 获取 API Key # 获取 API Key
# stmt = ( # stmt = (
# select(ModelApiKey).join( # select(ModelApiKey).join(
@@ -184,7 +183,7 @@ class SharedChatService:
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key_obj: if not api_key_obj:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
# 获取或创建会话 # 获取或创建会话
conversation = self.create_or_get_conversation( conversation = self.create_or_get_conversation(
share_token=share_token, share_token=share_token,
@@ -192,7 +191,7 @@ class SharedChatService:
user_id=user_id, user_id=user_id,
password=password password=password
) )
# 处理系统提示词(支持变量替换) # 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "你是一个专业的AI助手") system_prompt = config.get("system_prompt", "你是一个专业的AI助手")
if variables: if variables:
@@ -202,31 +201,31 @@ class SharedChatService:
variables variables
) )
system_prompt = system_prompt_rendered.get_text_content() or system_prompt system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# 准备工具列表 # 准备工具列表
tools = [] tools = []
# 添加知识库检索工具 # 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval") knowledge_retrieval = config.get("knowledge_retrieval")
if knowledge_retrieval: if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids: if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id) kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool) tools.append(kb_tool)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag=False memory_flag = False
if memory: if memory:
memory_config = config.get("memory", {}) memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id: if memory_config.get("enabled") and user_id:
memory_flag=True memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id) memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool) tools.append(memory_tool)
web_tools=config.get("tools") web_tools = config.get("tools")
web_search_choice = web_tools.get("web_search", {}) web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled",False) web_search_enable = web_search_choice.get("enabled", False)
if web_search: if web_search:
if web_search_enable: if web_search_enable:
search_tool = create_web_search_tool({}) search_tool = create_web_search_tool({})
@@ -238,10 +237,10 @@ class SharedChatService:
"tool_count": len(tools) "tool_count": len(tools)
} }
) )
# 获取模型参数 # 获取模型参数
model_parameters = config.get("model_parameters", {}) model_parameters = config.get("model_parameters", {})
# 创建 LangChain Agent # 创建 LangChain Agent
agent = LangChainAgent( agent = LangChainAgent(
model_name=api_key_obj.model_name, model_name=api_key_obj.model_name,
@@ -254,10 +253,10 @@ class SharedChatService:
tools=tools, tools=tools,
) )
# 加载历史消息 # 加载历史消息
history = [] history = []
memory_config={"enabled":True,'max_history':10} memory_config = {"enabled": True, 'max_history': 10}
if memory_config.get("enabled"): if memory_config.get("enabled"):
messages = self.conversation_service.get_messages( messages = self.conversation_service.get_messages(
conversation_id=conversation.id, conversation_id=conversation.id,
@@ -267,7 +266,7 @@ class SharedChatService:
{"role": msg.role, "content": msg.content} {"role": msg.role, "content": msg.content}
for msg in messages for msg in messages
] ]
# 调用 Agent # 调用 Agent
result = await agent.chat( result = await agent.chat(
message=message, message=message,
@@ -279,7 +278,7 @@ class SharedChatService:
config_id=config_id, config_id=config_id,
memory_flag=memory_flag memory_flag=memory_flag
) )
# 保存消息 # 保存消息
self.conversation_service.save_conversation_messages( self.conversation_service.save_conversation_messages(
conversation_id=conversation.id, conversation_id=conversation.id,
@@ -298,7 +297,7 @@ class SharedChatService:
# role="user", # role="user",
# content=message # content=message
# ) # )
# self.conversation_service.add_message( # self.conversation_service.add_message(
# conversation_id=conversation.id, # conversation_id=conversation.id,
# role="assistant", # role="assistant",
@@ -308,12 +307,11 @@ class SharedChatService:
# "usage": result.get("usage", {}) # "usage": result.get("usage", {})
# } # }
# ) # )
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
return { return {
"conversation_id": conversation.id, "conversation_id": conversation.id,
"message": result["content"], "message": result["content"],
@@ -324,19 +322,19 @@ class SharedChatService:
}), }),
"elapsed_time": elapsed_time "elapsed_time": elapsed_time
} }
async def chat_stream( async def chat_stream(
self, self,
share_token: str, share_token: str,
message: str, message: str,
conversation_id: Optional[uuid.UUID] = None, conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None, password: Optional[str] = None,
web_search: bool = False, web_search: bool = False,
memory: bool = True, memory: bool = True,
storage_type:Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""聊天(流式)""" """聊天(流式)"""
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
@@ -345,36 +343,35 @@ class SharedChatService:
from sqlalchemy import select from sqlalchemy import select
from app.models import ModelApiKey from app.models import ModelApiKey
import json import json
start_time = time.time()
actual_config_id=None
config_id=actual_config_id
start_time = time.time()
actual_config_id = None
config_id = actual_config_id
if variables is None: if variables is None:
variables = {} variables = {}
# 兼容新旧字段名:使用 memory_config_id # 兼容新旧字段名:使用 memory_config_id
memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10} memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10}
try: try:
# 获取发布版本和配置 # 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password) share, release = self.get_release_by_share_token(share_token, password)
# 获取 Agent 配置 # 获取 Agent 配置
config = release.config or {} config = release.config or {}
agent_config_data = config.get("agent_config", {}) agent_config_data = config.get("agent_config", {})
# 获取模型配置ID # 获取模型配置ID
model_config_id = release.default_model_config_id model_config_id = release.default_model_config_id
if not model_config_id: if not model_config_id:
raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING)
# 获取模型配置 # 获取模型配置
from app.models import ModelConfig from app.models import ModelConfig
model_config = self.db.get(ModelConfig, model_config_id) model_config = self.db.get(ModelConfig, model_config_id)
if not model_config: if not model_config:
raise ResourceNotFoundException("模型配置", str(model_config_id)) raise ResourceNotFoundException("模型配置", str(model_config_id))
# 获取 API Key # 获取 API Key
# stmt = ( # stmt = (
# select(ModelApiKey).join( # select(ModelApiKey).join(
@@ -393,7 +390,7 @@ class SharedChatService:
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key_obj: if not api_key_obj:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
# 获取或创建会话 # 获取或创建会话
conversation = self.create_or_get_conversation( conversation = self.create_or_get_conversation(
share_token=share_token, share_token=share_token,
@@ -401,7 +398,7 @@ class SharedChatService:
user_id=user_id, user_id=user_id,
password=password password=password
) )
# 处理系统提示词(支持变量替换) # 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "你是一个专业的AI助手") system_prompt = config.get("system_prompt", "你是一个专业的AI助手")
if variables: if variables:
@@ -411,21 +408,21 @@ class SharedChatService:
variables variables
) )
system_prompt = system_prompt_rendered.get_text_content() or system_prompt system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# 准备工具列表 # 准备工具列表
tools = [] tools = []
# 添加知识库检索工具 # 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval") knowledge_retrieval = config.get("knowledge_retrieval")
if knowledge_retrieval: if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids: if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id) kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool) tools.append(kb_tool)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag=False memory_flag = False
if memory: if memory:
memory_config = config.get("memory", {}) memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id: if memory_config.get("enabled") and user_id:
@@ -450,7 +447,7 @@ class SharedChatService:
# 获取模型参数 # 获取模型参数
model_parameters = config.get("model_parameters", {}) model_parameters = config.get("model_parameters", {})
# 创建 LangChain Agent # 创建 LangChain Agent
agent = LangChainAgent( agent = LangChainAgent(
model_name=api_key_obj.model_name, model_name=api_key_obj.model_name,
@@ -463,7 +460,7 @@ class SharedChatService:
tools=tools, tools=tools,
streaming=True streaming=True
) )
# 加载历史消息 # 加载历史消息
history = [] history = []
memory_config = {"enabled": True, 'max_history': 10} memory_config = {"enabled": True, 'max_history': 10}
@@ -476,22 +473,22 @@ class SharedChatService:
{"role": msg.role, "content": msg.content} {"role": msg.role, "content": msg.content}
for msg in messages for msg in messages
] ]
# 发送开始事件 # 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n" yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n"
# 流式调用 Agent # 流式调用 Agent
full_content = "" full_content = ""
total_tokens = 0 total_tokens = 0
async for chunk in agent.chat_stream( async for chunk in agent.chat_stream(
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id, end_user_id=user_id,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id, user_rag_memory_id=user_rag_memory_id,
config_id=config_id, config_id=config_id,
memory_flag=memory_flag memory_flag=memory_flag
): ):
if isinstance(chunk, int): if isinstance(chunk, int):
total_tokens = chunk total_tokens = chunk
@@ -499,16 +496,16 @@ class SharedChatService:
full_content += chunk full_content += chunk
# 发送消息块事件 # 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# 保存消息 # 保存消息
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation.id, conversation_id=conversation.id,
role="user", role="user",
content=message content=message
) )
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation.id, conversation_id=conversation.id,
role="assistant", role="assistant",
@@ -524,7 +521,7 @@ class SharedChatService:
# 发送结束事件 # 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info( logger.info(
"流式聊天完成", "流式聊天完成",
extra={ extra={
@@ -533,7 +530,7 @@ class SharedChatService:
"message_length": len(full_content) "message_length": len(full_content)
} }
) )
except (GeneratorExit, asyncio.CancelledError): except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出 # 生成器被关闭或任务被取消,正常退出
logger.debug("流式聊天被中断") logger.debug("流式聊天被中断")
@@ -542,39 +539,39 @@ class SharedChatService:
logger.error(f"流式聊天失败: {str(e)}", exc_info=True) logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
# 发送错误事件 # 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
def get_conversation_messages( def get_conversation_messages(
self, self,
share_token: str, share_token: str,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
password: Optional[str] = None password: Optional[str] = None
) -> Conversation: ) -> Conversation:
"""获取会话消息""" """获取会话消息"""
share, release = self._get_release_by_share_token(share_token, password) share, release = self.get_release_by_share_token(share_token, password)
# 获取会话 # 获取会话
conversation = self.conversation_service.get_conversation( conversation = self.conversation_service.get_conversation(
conversation_id=conversation_id, conversation_id=conversation_id,
workspace_id=release.app.workspace_id workspace_id=release.app.workspace_id
) )
# 验证会话是否属于该应用 # 验证会话是否属于该应用
if conversation.app_id != release.app_id: if conversation.app_id != release.app_id:
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION) raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
return conversation return conversation
def list_conversations( def list_conversations(
self, self,
share_token: str, share_token: str,
user_id: Optional[str] = None, user_id: Optional[str] = None,
password: Optional[str] = None, password: Optional[str] = None,
page: int = 1, page: int = 1,
pagesize: int = 20 pagesize: int = 20
) -> tuple[list[Conversation], int]: ) -> tuple[list[Conversation], int]:
"""列出会话""" """列出会话"""
share, release = self._get_release_by_share_token(share_token, password) share, release = self.get_release_by_share_token(share_token, password)
conversations, total = self.conversation_service.list_conversations( conversations, total = self.conversation_service.list_conversations(
app_id=release.app_id, app_id=release.app_id,
workspace_id=release.app.workspace_id, workspace_id=release.app.workspace_id,
@@ -583,19 +580,19 @@ class SharedChatService:
page=page, page=page,
pagesize=pagesize pagesize=pagesize
) )
return conversations, total return conversations, total
async def multi_agent_chat( async def multi_agent_chat(
self, self,
share_token: str, share_token: str,
message: str, message: str,
conversation_id: Optional[uuid.UUID] = None, conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None, password: Optional[str] = None,
web_search: bool = False, web_search: bool = False,
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -603,18 +600,16 @@ class SharedChatService:
from app.services.multi_agent_service import MultiAgentService from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig from app.models import MultiAgentConfig
start_time = time.time() start_time = time.time()
actual_config_id=None actual_config_id = None
config_id=actual_config_id config_id = actual_config_id
if variables is None: if variables is None:
variables = {} variables = {}
# 获取发布版本和配置 # 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password) share, release = self.get_release_by_share_token(share_token, password)
# 获取或创建会话 # 获取或创建会话
conversation = self.create_or_get_conversation( conversation = self.create_or_get_conversation(
share_token=share_token, share_token=share_token,
@@ -622,19 +617,19 @@ class SharedChatService:
user_id=user_id, user_id=user_id,
password=password password=password
) )
# 获取多 Agent 配置 # 获取多 Agent 配置
multi_agent_config = self.db.query(MultiAgentConfig).filter( multi_agent_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == release.app_id, MultiAgentConfig.app_id == release.app_id,
MultiAgentConfig.is_active.is_(True) MultiAgentConfig.is_active.is_(True)
).first() ).first()
if not multi_agent_config: if not multi_agent_config:
raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 构建多 Agent 运行请求 # 构建多 Agent 运行请求
from app.schemas.multi_agent_schema import MultiAgentRunRequest from app.schemas.multi_agent_schema import MultiAgentRunRequest
multi_agent_request = MultiAgentRunRequest( multi_agent_request = MultiAgentRunRequest(
message=message, message=message,
conversation_id=conversation.id, conversation_id=conversation.id,
@@ -644,23 +639,23 @@ class SharedChatService:
web_search=web_search, web_search=web_search,
memory=memory memory=memory
) )
# 使用多 Agent 服务执行 # 使用多 Agent 服务执行
multi_agent_service = MultiAgentService(self.db) multi_agent_service = MultiAgentService(self.db)
result = await multi_agent_service.run( result = await multi_agent_service.run(
app_id=release.app_id, app_id=release.app_id,
request=multi_agent_request request=multi_agent_request
) )
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# 保存消息 # 保存消息
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation.id, conversation_id=conversation.id,
role="user", role="user",
content=message content=message
) )
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation.id, conversation_id=conversation.id,
role="assistant", role="assistant",
@@ -672,8 +667,6 @@ class SharedChatService:
} }
) )
return { return {
"conversation_id": conversation.id, "conversation_id": conversation.id,
"message": result.get("message", ""), "message": result.get("message", ""),
@@ -684,34 +677,33 @@ class SharedChatService:
}, },
"elapsed_time": elapsed_time "elapsed_time": elapsed_time
} }
async def multi_agent_chat_stream( async def multi_agent_chat_stream(
self, self,
share_token: str, share_token: str,
message: str, message: str,
conversation_id: Optional[uuid.UUID] = None, conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None, password: Optional[str] = None,
web_search: bool = False, web_search: bool = False,
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id:Optional[str] = None user_rag_memory_id: Optional[str] = None
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""多 Agent 聊天(流式)""" """多 Agent 聊天(流式)"""
start_time = time.time() start_time = time.time()
actual_config_id=None actual_config_id = None
config_id=actual_config_id config_id = actual_config_id
if variables is None: if variables is None:
variables = {} variables = {}
try: try:
# 获取发布版本和配置 # 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password) share, release = self.get_release_by_share_token(share_token, password)
# 获取或创建会话 # 获取或创建会话
conversation = self.create_or_get_conversation( conversation = self.create_or_get_conversation(
share_token=share_token, share_token=share_token,
@@ -719,28 +711,28 @@ class SharedChatService:
user_id=user_id, user_id=user_id,
password=password password=password
) )
# 获取多 Agent 配置 # 获取多 Agent 配置
multi_agent_config = self.db.query(MultiAgentConfig).filter( multi_agent_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == release.app_id, MultiAgentConfig.app_id == release.app_id,
MultiAgentConfig.is_active.is_(True) MultiAgentConfig.is_active.is_(True)
).first() ).first()
if not multi_agent_config: if not multi_agent_config:
raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 获取 storage_type 和 user_rag_memory_id # 获取 storage_type 和 user_rag_memory_id
workspace_id = release.app.workspace_id workspace_id = release.app.workspace_id
storage_type = 'neo4j' # 默认值 storage_type = 'neo4j' # 默认值
user_rag_memory_id = '' user_rag_memory_id = ''
try: try:
# 获取工作空间的存储类型(不需要用户权限检查,因为是公开分享) # 获取工作空间的存储类型(不需要用户权限检查,因为是公开分享)
from app.models import Workspace from app.models import Workspace
workspace = self.db.get(Workspace, workspace_id) workspace = self.db.get(Workspace, workspace_id)
if workspace and workspace.storage_type: if workspace and workspace.storage_type:
storage_type = workspace.storage_type storage_type = workspace.storage_type
# 获取 USER_RAG_MERORY 知识库 ID # 获取 USER_RAG_MERORY 知识库 ID
knowledge = knowledge_repository.get_knowledge_by_name( knowledge = knowledge_repository.get_knowledge_by_name(
db=self.db, db=self.db,
@@ -751,13 +743,13 @@ class SharedChatService:
user_rag_memory_id = str(knowledge.id) user_rag_memory_id = str(knowledge.id)
except Exception as e: except Exception as e:
logger.warning(f"获取 storage_type 或 user_rag_memory_id 失败,使用默认值: {str(e)}") logger.warning(f"获取 storage_type 或 user_rag_memory_id 失败,使用默认值: {str(e)}")
# 发送开始事件 # 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n" yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n"
# 构建多 Agent 运行请求 # 构建多 Agent 运行请求
from app.schemas.multi_agent_schema import MultiAgentRunRequest from app.schemas.multi_agent_schema import MultiAgentRunRequest
multi_agent_request = MultiAgentRunRequest( multi_agent_request = MultiAgentRunRequest(
message=message, message=message,
conversation_id=conversation.id, conversation_id=conversation.id,
@@ -767,20 +759,20 @@ class SharedChatService:
web_search=web_search, web_search=web_search,
memory=memory memory=memory
) )
# 使用多 Agent 服务流式执行 # 使用多 Agent 服务流式执行
multi_agent_service = MultiAgentService(self.db) multi_agent_service = MultiAgentService(self.db)
full_content = "" full_content = ""
async for event in multi_agent_service.run_stream( async for event in multi_agent_service.run_stream(
app_id=release.app_id, app_id=release.app_id,
request=multi_agent_request, request=multi_agent_request,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
): ):
# 直接转发事件 # 直接转发事件
yield event yield event
# 尝试提取内容(用于保存) # 尝试提取内容(用于保存)
if "data:" in event: if "data:" in event:
try: try:
@@ -790,16 +782,16 @@ class SharedChatService:
full_content += data["content"] full_content += data["content"]
except: except:
pass pass
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# 保存消息 # 保存消息
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation.id, conversation_id=conversation.id,
role="user", role="user",
content=message content=message
) )
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation.id, conversation_id=conversation.id,
role="assistant", role="assistant",
@@ -808,7 +800,7 @@ class SharedChatService:
"elapsed_time": elapsed_time "elapsed_time": elapsed_time
} }
) )
logger.info( logger.info(
"多 Agent 流式聊天完成", "多 Agent 流式聊天完成",
extra={ extra={
@@ -818,7 +810,6 @@ class SharedChatService:
} }
) )
except (GeneratorExit, asyncio.CancelledError): except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出 # 生成器被关闭或任务被取消,正常退出
logger.debug("多 Agent 流式聊天被中断") logger.debug("多 Agent 流式聊天被中断")

View File

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config from app.core.workflow.validator import validate_workflow_config
from app.db import get_db from app.db import get_db
from app.models import App from app.models import App
@@ -617,7 +618,8 @@ class WorkflowService:
"event": "end", "event": "end",
"data": { "data": {
"elapsed_time": payload.get("elapsed_time"), "elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")) "message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
} }
} }
case "node_start" | "node_end" | "node_error" | "cycle_item": case "node_start" | "node_end" | "node_error" | "cycle_item":
@@ -779,6 +781,14 @@ class WorkflowService:
} }
} }
@staticmethod
def get_start_node_variables(config: dict) -> list:
nodes = config.get("nodes", [])
for node in nodes:
if node.get("type") == NodeType.START:
return node.get("config", {}).get("variables", [])
raise BusinessException("workflow config error - start node not found")
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]: def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
"""清理事件数据,移除不可序列化的对象 """清理事件数据,移除不可序列化的对象