[fix] public share agent run

This commit is contained in:
Mark
2026-01-06 14:43:23 +08:00
parent 85c7e531e4
commit f59f508c4d
3 changed files with 17 additions and 11 deletions

View File

@@ -392,8 +392,8 @@ async def chat(
if app_type == AppType.AGENT: if app_type == AppType.AGENT:
# 流式返回 # 流式返回
agent_config = agent_config_4_app_release(app.current_release) agent_config = agent_config_4_app_release(release)
if payload.stream: if payload.stream:
# async def event_generator(): # async def event_generator():
# async for event in service.chat_stream( # async for event in service.chat_stream(

View File

@@ -4,6 +4,8 @@ from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID, JSON from sqlalchemy.dialects.postgresql import UUID, JSON
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from app.db import Base from app.db import Base
from app.models.multi_agent_model import PydanticType
from app.schemas import ModelParameters
class AgentConfig(Base): class AgentConfig(Base):
@@ -17,14 +19,17 @@ class AgentConfig(Base):
# Agent 行为配置 # Agent 行为配置
system_prompt = Column(Text, nullable=True, comment="系统提示词") system_prompt = Column(Text, nullable=True, comment="系统提示词")
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID") default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID")
# 结构化配置(直接存储 JSON # 结构化配置(直接存储 JSON
model_parameters = Column(JSON, nullable=True, comment="模型参数配置temperature、max_tokens等") # model_parameters = Column(JSON, nullable=True, comment="模型参数配置temperature、max_tokens等")
model_parameters = Column(PydanticType(ModelParameters), nullable=True,
comment="模型参数配置temperature、max_tokens等")
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置") knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
memory = Column(JSON, nullable=True, comment="记忆配置") memory = Column(JSON, nullable=True, comment="记忆配置")
variables = Column(JSON, default=list, nullable=True, comment="变量配置") variables = Column(JSON, default=list, nullable=True, comment="变量配置")
tools = Column(JSON, default=dict, nullable=True, comment="工具配置") tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
# 多 Agent 相关字段 # 多 Agent 相关字段
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone") agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
agent_domain = Column(String(50), comment="专业领域: customer_service|technical_support|sales 等") agent_domain = Column(String(50), comment="专业领域: customer_service|technical_support|sales 等")
@@ -41,4 +46,4 @@ class AgentConfig(Base):
parent_agent = relationship("AgentConfig", remote_side=[id], backref="sub_agents") parent_agent = relationship("AgentConfig", remote_side=[id], backref="sub_agents")
def __repr__(self): def __repr__(self):
return f"<AgentConfig(id={self.id}, app_id={self.app_id})>" return f"<AgentConfig(id={self.id}, app_id={self.app_id})>"

View File

@@ -184,7 +184,7 @@ class AppChatService:
model_config_id = config.default_model_config_id model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id) api_key_obj = ModelApiKeyService.get_a_api_key(self.db ,model_config_id)
# 处理系统提示词(支持变量替换) # 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "") system_prompt = config.system_prompt
if variables: if variables:
system_prompt_rendered = render_prompt_message( system_prompt_rendered = render_prompt_message(
system_prompt, system_prompt,
@@ -197,7 +197,7 @@ class AppChatService:
tools = [] tools = []
# 添加知识库检索工具 # 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval") knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval: if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
@@ -208,13 +208,13 @@ class AppChatService:
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
if memory: if memory:
memory_config = config.get("memory", {}) memory_config = config.memory
if memory_config.get("enabled") and user_id: if memory_config.get("enabled") and user_id:
memory_flag = True memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id) memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool) tools.append(memory_tool)
web_tools = config.get("tools") web_tools = config.tools
web_search_choice = web_tools.get("web_search", {}) web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False) web_search_enable = web_search_choice.get("enabled", False)
if web_search == True: if web_search == True:
@@ -230,7 +230,7 @@ class AppChatService:
) )
# 获取模型参数 # 获取模型参数
model_parameters = config.get("model_parameters", {}) model_parameters = config.model_parameters
# 创建 LangChain Agent # 创建 LangChain Agent
agent = LangChainAgent( agent = LangChainAgent(
@@ -763,6 +763,7 @@ class AppChatService:
logger.error(f"流式聊天失败: {str(e)}", exc_info=True) logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
# 发送错误事件 # 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
# ==================== 依赖注入函数 ==================== # ==================== 依赖注入函数 ====================
def get_app_chat_service( def get_app_chat_service(