[fix] public share agent run
This commit is contained in:
@@ -392,7 +392,7 @@ async def chat(
|
|||||||
|
|
||||||
if app_type == AppType.AGENT:
|
if app_type == AppType.AGENT:
|
||||||
# 流式返回
|
# 流式返回
|
||||||
agent_config = agent_config_4_app_release(app.current_release)
|
agent_config = agent_config_4_app_release(release)
|
||||||
|
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
# async def event_generator():
|
# async def event_generator():
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
|
|||||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
|
from app.models.multi_agent_model import PydanticType
|
||||||
|
from app.schemas import ModelParameters
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(Base):
|
class AgentConfig(Base):
|
||||||
@@ -19,7 +21,10 @@ class AgentConfig(Base):
|
|||||||
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID")
|
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID")
|
||||||
|
|
||||||
# 结构化配置(直接存储 JSON)
|
# 结构化配置(直接存储 JSON)
|
||||||
model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)")
|
# model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)")
|
||||||
|
model_parameters = Column(PydanticType(ModelParameters), nullable=True,
|
||||||
|
comment="模型参数配置(temperature、max_tokens等)")
|
||||||
|
|
||||||
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
|
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
|
||||||
memory = Column(JSON, nullable=True, comment="记忆配置")
|
memory = Column(JSON, nullable=True, comment="记忆配置")
|
||||||
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ class AppChatService:
|
|||||||
model_config_id = config.default_model_config_id
|
model_config_id = config.default_model_config_id
|
||||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
|
||||||
# 处理系统提示词(支持变量替换)
|
# 处理系统提示词(支持变量替换)
|
||||||
system_prompt = config.get("system_prompt", "")
|
system_prompt = config.system_prompt
|
||||||
if variables:
|
if variables:
|
||||||
system_prompt_rendered = render_prompt_message(
|
system_prompt_rendered = render_prompt_message(
|
||||||
system_prompt,
|
system_prompt,
|
||||||
@@ -197,7 +197,7 @@ class AppChatService:
|
|||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
knowledge_retrieval = config.knowledge_retrieval
|
||||||
if knowledge_retrieval:
|
if knowledge_retrieval:
|
||||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||||
@@ -208,13 +208,13 @@ class AppChatService:
|
|||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
if memory:
|
if memory:
|
||||||
memory_config = config.get("memory", {})
|
memory_config = config.memory
|
||||||
if memory_config.get("enabled") and user_id:
|
if memory_config.get("enabled") and user_id:
|
||||||
memory_flag = True
|
memory_flag = True
|
||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
web_tools = config.get("tools")
|
web_tools = config.tools
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
web_search_choice = web_tools.get("web_search", {})
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
if web_search == True:
|
if web_search == True:
|
||||||
@@ -230,7 +230,7 @@ class AppChatService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.get("model_parameters", {})
|
model_parameters = config.model_parameters
|
||||||
|
|
||||||
# 创建 LangChain Agent
|
# 创建 LangChain Agent
|
||||||
agent = LangChainAgent(
|
agent = LangChainAgent(
|
||||||
@@ -763,6 +763,7 @@ class AppChatService:
|
|||||||
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
|
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
|
||||||
# 发送错误事件
|
# 发送错误事件
|
||||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
# ==================== 依赖注入函数 ====================
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
def get_app_chat_service(
|
def get_app_chat_service(
|
||||||
|
|||||||
Reference in New Issue
Block a user