Merge branch 'develop' into fix/memory-enduser-config

This commit is contained in:
Ke Sun
2026-02-06 11:56:21 +08:00
294 changed files with 9936 additions and 4180 deletions

View File

@@ -9,7 +9,7 @@ from app.schemas.app_schema import (
VariableDefinition,
ToolConfig,
AgentConfigCreate,
AgentConfigUpdate, ToolOldConfig,
AgentConfigUpdate, ToolOldConfig, SkillConfig,
)
@@ -48,6 +48,9 @@ class AgentConfigConverter:
# 5. 工具配置
if hasattr(config, 'tools') and config.tools:
result["tools"] = [tool.model_dump() for tool in config.tools]
if hasattr(config, "skills") and config.skills:
result["skills"] = config.skills.model_dump()
return result
@@ -58,6 +61,7 @@ class AgentConfigConverter:
memory: Optional[Dict[str, Any]],
variables: Optional[list],
tools: Optional[Union[list, Dict[str, Any]]],
skills: Optional[dict]
) -> Dict[str, Any]:
"""
将数据库存储格式转换为 Pydantic 对象
@@ -68,6 +72,7 @@ class AgentConfigConverter:
memory: 记忆配置
variables: 变量配置
tools: 工具配置
skills: 技能列表
Returns:
包含 Pydantic 对象的字典
@@ -78,6 +83,7 @@ class AgentConfigConverter:
"memory": MemoryConfig(enabled=True),
"variables": [],
"tools": [],
"skills": SkillConfig(enabled=False, all_skills=False, skill_ids=[])
}
# 1. 解析模型参数配置
@@ -117,5 +123,10 @@ class AgentConfigConverter:
name: ToolOldConfig(**tool_data)
for name, tool_data in tools.items()
}
if skills:
result["skills"] = SkillConfig(**skills)
else:
result["skills"] = SkillConfig(enabled=False, all_skills=False, skill_ids=[])
return result

View File

@@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
memory=agent_cfg.memory,
variables=agent_cfg.variables,
tools=agent_cfg.tools,
skills=agent_cfg.skills
)
# 将解析后的字段添加到对象上(用于序列化)
@@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
agent_cfg.memory = parsed["memory"]
agent_cfg.variables = parsed["variables"]
agent_cfg.tools = parsed["tools"]
agent_cfg.skills = parsed["skills"]
return agent_cfg

View File

@@ -8,6 +8,7 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -63,7 +64,7 @@ class AppChatService:
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt
if variables:
@@ -79,21 +80,55 @@ class AppChatService:
# 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
@@ -113,22 +148,6 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.model_parameters
@@ -192,6 +211,8 @@ class AppChatService:
}
)
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
elapsed_time = time.time() - start_time
return {
@@ -230,7 +251,7 @@ class AppChatService:
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
# 处理系统提示词(支持变量替换)
system_prompt = config.system_prompt
if variables:
@@ -246,20 +267,54 @@ class AppChatService:
# 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
@@ -279,22 +334,6 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.model_parameters
@@ -374,6 +413,8 @@ class AppChatService:
}
)
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
@@ -618,6 +659,7 @@ class AppChatService:
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
public=False
) -> AsyncGenerator[dict, None]:
"""聊天(流式)"""
@@ -634,7 +676,8 @@ class AppChatService:
payload=payload,
config=config,
workspace_id=workspace_id,
release_id=release_id
release_id=release_id,
public=public
):
yield event

View File

@@ -313,6 +313,7 @@ class AppService:
memory=storage_data.get("memory"),
variables=storage_data.get("variables", []),
tools=storage_data.get("tools", []),
skills=storage_data.get("skills", {}),
is_active=True,
created_at=now,
updated_at=now,
@@ -916,6 +917,7 @@ class AppService:
agent_cfg.variables = storage_data.get("variables", [])
# if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.skills = storage_data.get("skills", {})
agent_cfg.updated_at = now
@@ -1003,11 +1005,12 @@ class AppService:
},
memory={
"enabled": True,
"memory_content": None,
"memory_config_id": None,
"max_history": 10
},
variables=[],
tools=[],
skills=[],
is_active=True,
created_at=now,
updated_at=now,
@@ -1403,6 +1406,7 @@ class AppService:
"memory": agent_cfg.memory,
"variables": agent_cfg.variables or [],
"tools": agent_cfg.tools or [],
"skills": agent_cfg.skills or {},
}
# config = AgentConfigConverter.from_storage_format(agent_cfg)
default_model_config_id = agent_cfg.default_model_config_id

View File

@@ -1,15 +1,13 @@
"""应用统计服务"""
from datetime import datetime, timedelta
from typing import Dict, Any, List
from typing import Dict, Any
import uuid
from sqlalchemy import func, and_, cast, Date
from sqlalchemy.orm import Session
from app.models.conversation_model import Conversation, Message
from app.models.end_user_model import EndUser
from app.models.api_key_model import ApiKey, ApiKeyLog
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.models.api_key_model import ApiKey, ApiKeyLog, ApiKeyType
class AppStatisticsService:
@@ -146,7 +144,6 @@ class AppStatisticsService:
end_dt: datetime
) -> Dict[str, Any]:
"""获取Token消耗统计从Message的meta_data中提取"""
from sqlalchemy import text
# 查询所有相关消息的token使用情况
# meta_data中可能包含: {"usage": {"total_tokens": 100}} 或 {"tokens": 100}
@@ -187,7 +184,80 @@ class AppStatisticsService:
daily_tokens[date_str] = 0
daily_tokens[date_str] += int(tokens)
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
total = sum(row["tokens"] for row in daily_data)
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
total = sum(row["count"] for row in daily_data)
return {"daily": daily_data, "total": total}
def get_workspace_api_statistics(
    self,
    workspace_id: uuid.UUID,
    start_date: int,
    end_date: int
) -> list[Any]:
    """Return per-day API call counts for a workspace.

    Calls are counted from ``ApiKeyLog`` rows joined to their ``ApiKey`` and
    split into two buckets by key type: application keys (agent, cluster,
    workflow) and service keys.

    Args:
        workspace_id: ID of the workspace whose API keys are counted.
        start_date: Start of the window, millisecond timestamp (inclusive).
        end_date: End of the window, millisecond timestamp (inclusive).

    Returns:
        A list of dicts sorted by date, each with the keys ``date``,
        ``total_calls``, ``app_calls`` and ``service_calls``. Days without
        a logged call in either bucket do not appear.
    """
    # Millisecond timestamps -> naive datetimes in the process's local timezone.
    window_start = datetime.fromtimestamp(start_date / 1000)
    window_end = datetime.fromtimestamp(end_date / 1000)

    def count_calls_per_day(type_condition):
        """Map each day (as a string) to the number of logged calls whose
        API key satisfies ``type_condition``."""
        rows = (
            self.db.query(
                cast(ApiKeyLog.created_at, Date).label('date'),
                func.count(ApiKeyLog.id).label('count'),
            )
            .join(ApiKey, ApiKeyLog.api_key_id == ApiKey.id)
            .filter(
                and_(
                    ApiKey.workspace_id == workspace_id,
                    type_condition,
                    ApiKeyLog.created_at >= window_start,
                    ApiKeyLog.created_at <= window_end,
                )
            )
            .group_by(cast(ApiKeyLog.created_at, Date))
            .all()
        )
        return {str(row.date): row.count for row in rows}

    # Application-type keys: agent, multi-agent (cluster) and workflow.
    app_calls = count_calls_per_day(
        ApiKey.type.in_([ApiKeyType.AGENT, ApiKeyType.CLUSTER, ApiKeyType.WORKFLOW])
    )
    # Service-type keys.
    service_calls = count_calls_per_day(ApiKey.type == ApiKeyType.SERVICE)

    # Report every day that shows up in either bucket, oldest first.
    return [
        {
            "date": day,
            "total_calls": app_calls.get(day, 0) + service_calls.get(day, 0),
            "app_calls": app_calls.get(day, 0),
            "service_calls": service_calls.get(day, 0),
        }
        for day in sorted(app_calls.keys() | service_calls.keys())
    ]

View File

@@ -24,6 +24,7 @@ from app.core.error_codes import BizCode
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.models import ModelType
from app.services.model_service import ModelApiKeyService
logger = get_business_logger()
@@ -357,6 +358,8 @@ class CollaborativeOrchestrator:
"usage": response.get("usage", {"total_tokens": 0}),
"is_final_answer": True
}
ModelApiKeyService.record_api_key_usage(self.db, agent_config.get("api_key_id"))
# 检查是否有工具调用handoff
tool_calls = response.get("tool_calls", [])
@@ -427,7 +430,7 @@ class CollaborativeOrchestrator:
)
# 获取 API Key
api_key_config = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
api_key_config = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key_config:
raise BusinessException(
f"Agent 模型没有可用的 API Key: {agent_id}",
@@ -442,7 +445,8 @@ class CollaborativeOrchestrator:
"provider": api_key_config.provider,
"api_key": api_key_config.api_key,
"api_base": api_key_config.api_base,
"model_parameters": config_data.get("model_parameters", {})
"model_parameters": config_data.get("model_parameters", {}),
"api_key_id": api_key_config.id
}
except ValueError:

View File

@@ -10,6 +10,11 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
@@ -24,12 +29,11 @@ from app.services import task_service
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.services.model_service import ModelApiKeyService
from app.services.tool_service import ToolService
from app.services.multimodal_service import MultimodalService
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -59,7 +63,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
长期记忆工具
"""
# search_switch = memory_config.get("search_switch", "2")
config_id= memory_config.get("memory_content") or memory_config.get("memory_config",None)
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None)
logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
@tool(args_schema=LongTermMemoryInput)
def long_term_memory(question: str) -> str:
@@ -310,6 +315,7 @@ class DraftRunService:
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
@@ -320,9 +326,7 @@ class DraftRunService:
print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
@@ -345,6 +349,25 @@ class DraftRunService:
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
@@ -433,7 +456,8 @@ class DraftRunService:
)
memory_config_= agent_config.memory
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
# 8. 调用 Agent支持多模态
result = await agent.chat(
@@ -450,6 +474,8 @@ class DraftRunService:
elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
# 9. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
await self._save_conversation_message(
@@ -558,6 +584,7 @@ class DraftRunService:
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
@@ -567,9 +594,7 @@ class DraftRunService:
# print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
@@ -592,6 +617,25 @@ class DraftRunService:
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -628,7 +672,6 @@ class DraftRunService:
}
)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_config["model_name"],
@@ -677,7 +720,8 @@ class DraftRunService:
})
memory_config_ = agent_config.memory
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
# 9. 流式调用 Agent支持多模态
full_content = ""
@@ -704,6 +748,8 @@ class DraftRunService:
elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
if sub_agent:
yield self._format_sse_event("sub_usage", {
"total_tokens": total_tokens
@@ -770,7 +816,7 @@ class DraftRunService:
Raises:
BusinessException: 当没有可用的 API Key 时
"""
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
# stmt = (
# select(ModelApiKey).join(
# ModelConfig, ModelApiKey.model_configs
@@ -784,7 +830,8 @@ class DraftRunService:
# )
#
# api_key = self.db.scalars(stmt).first()
api_key = api_keys[0] if api_keys else None
# api_key = api_keys[0] if api_keys else None
api_key = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
@@ -793,7 +840,8 @@ class DraftRunService:
"model_name": api_key.model_name,
"provider": api_key.provider,
"api_key": api_key.api_key,
"api_base": api_key.api_base
"api_base": api_key.api_base,
"api_key_id": api_key.id
}
async def _ensure_conversation(
@@ -1051,7 +1099,7 @@ class DraftRunService:
except Exception as e:
# 对于多 Agent 应用,没有直接的 AgentConfig 是正常的
logger.debug("获取配置快照失败(可能是多 Agent 应用)", extra={"error": str(e)})
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
return {}
def _replace_variables(

View File

@@ -537,7 +537,7 @@ def convert_multi_agent_config_to_handoffs(
# 获取该 Agent 的模型配置
if release.default_model_config_id:
model_api_key = ModelApiKeyService.get_a_api_key(db, release.default_model_config_id)
model_api_key = ModelApiKeyService.get_available_api_key(db, release.default_model_config_id)
if model_api_key:
model_config = RedBearModelConfig(
model_name=model_api_key.model_name,
@@ -551,6 +551,7 @@ def convert_multi_agent_config_to_handoffs(
}
)
logger.debug(f"Agent {agent_name} 使用模型: {model_api_key.model_name}")
ModelApiKeyService.record_api_key_usage(db, model_api_key.id)
else:
logger.warning(f"Agent {agent_name} 模型配置无效: {release.default_model_config_id}")
else:

View File

@@ -382,6 +382,7 @@ class LLMRouter:
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.models import ModelApiKey, ModelType
from app.services.model_service import ModelApiKeyService
# 获取 API Key 配置(通过关联关系)
# api_key_config = self.db.query(ModelApiKey).join(
@@ -389,8 +390,9 @@ class LLMRouter:
# ).filter(ModelConfig.id == self.routing_model_config.id,
# ModelApiKey.is_active == True
# ).first()
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
api_key_config = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
# api_key_config = api_keys[0] if api_keys else None
api_key_config = ModelApiKeyService.get_available_api_key(self.db, self.routing_model_config.id)
if not api_key_config:
raise Exception("路由模型没有可用的 API Key")
@@ -424,7 +426,6 @@ class LLMRouter:
# 调用模型
response = await llm.ainvoke(prompt)
from app.services.model_service import ModelApiKeyService
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
# 提取响应内容

View File

@@ -349,7 +349,7 @@ class MasterAgentRouter:
from app.models import ModelApiKey, ModelType
# 获取 API Key 配置
api_key_config = ModelApiKeyService.get_a_api_key(self.db, self.master_model_config.id)
api_key_config = ModelApiKeyService.get_available_api_key(self.db, self.master_model_config.id)
if not api_key_config:
raise Exception("Master Agent 模型没有可用的 API Key")
@@ -400,6 +400,7 @@ class MasterAgentRouter:
# 调用模型
response = await llm.ainvoke(prompt)
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
# 提取响应内容
if hasattr(response, 'content'):

View File

@@ -1194,7 +1194,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
workspace_id=app.workspace_id
)
memory_config_id = str(memory_config.config_id) if memory_config else None
memory_obj = config.get('memory', {})
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
result = {
"end_user_id": str(end_user_id),
@@ -1284,7 +1286,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
if release:
config = release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
if memory_config_id:
# 判断是否为UUID格式
if len(str(memory_config_id))>=5:
@@ -1330,7 +1333,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 从 config 中提取 memory_config_id
config = release.config or {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 获取配置名称使用字符串形式的ID进行查找兼容新旧格式
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None

View File

@@ -53,7 +53,10 @@ def get_workspace_end_users(
workspace_id: uuid.UUID,
current_user: User
) -> List[EndUser]:
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
返回结果按 updated_at 从新到旧排序NULL 值排在最后)
"""
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
@@ -68,9 +71,14 @@ def get_workspace_end_users(
app_ids = [app.id for app in apps_orm]
# 批量查询所有 end_users一次查询而非循环查询
# 按 updated_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性
from app.models.end_user_model import EndUser as EndUserModel
from sqlalchemy import desc, nullslast
end_users_orm = db.query(EndUserModel).filter(
EndUserModel.app_id.in_(app_ids)
).order_by(
nullslast(desc(EndUserModel.updated_at)),
desc(EndUserModel.id)
).all()
# 转换为 Pydantic 模型(只在需要时转换)

View File

@@ -108,13 +108,14 @@ class WorkspaceAppService:
app_info["releases"].append(release_info)
def _extract_memory_content(self, config: Any) -> str:
"""Extract memory_comtent from config"""
"""Extract memory_config_id from config (兼容新旧字段名)"""
if not config or not isinstance(config, dict):
return None
memory_obj = config.get('memory')
if memory_obj and isinstance(memory_obj, dict):
return memory_obj.get('memory_content')
# 兼容新旧字段名:优先使用 memory_config_id回退到 memory_content
return memory_obj.get('memory_config_id') or memory_obj.get('memory_content')
return None

View File

@@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类PostgreSQL
if not params.rerank_id:
params.rerank_id = configs.get('rerank')
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
if not params.reflection_model_id:
params.reflection_model_id = params.llm_id
if not params.emotion_model_id:
params.emotion_model_id = params.llm_id
config = MemoryConfigRepository.create(self.db, params)
self.db.commit()
return {"affected": 1, "config_id": config.config_id}
@@ -203,6 +209,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"end_user_id": config.end_user_id,
"config_id_old": config_id_old,
"apply_id": config.apply_id,
"scene_id": config.scene_id,
"llm_id": config.llm_id,
"embedding_id": config.embedding_id,
"rerank_id": config.rerank_id,

View File

@@ -6,7 +6,7 @@ import math
import time
import asyncio
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
from app.schemas import model_schema
from app.schemas.model_schema import (
@@ -633,19 +633,31 @@ class ModelApiKeyService:
@staticmethod
def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]:
    """Pick an active API key for a model config using its load-balance strategy.

    Args:
        db: Database session used to load the model config.
        model_config_id: ID of the model config whose keys are considered.

    Returns:
        The selected active key, or ``None`` when the model config cannot be
        found or has no active key.
    """
    model_config = ModelConfigRepository.get_by_id(db, model_config_id)
    if not model_config:
        return None

    active_keys = [key for key in model_config.api_keys if key.is_active]
    if not active_keys:
        return None

    if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
        # Least-used key wins; ties go to the key used longest ago, and keys
        # with no last_used_at sort first.
        # NOTE(review): datetime.min is naive; this assumes last_used_at is
        # stored naive as well -- confirm, otherwise a tie would raise TypeError.
        return min(
            active_keys,
            key=lambda k: (int(k.usage_count or "0"), k.last_used_at or datetime.min),
        )

    # Any other strategy: use the first active key.
    return active_keys[0]
@staticmethod
def record_api_key_usage(db: Session, api_key_id: uuid.UUID | None) -> bool:
    """Record one use of an API key.

    Args:
        db: Database session; committed when the usage update succeeds.
        api_key_id: ID of the key that was used. May be ``None`` (callers
            pass values such as ``config.get("api_key_id")``), in which case
            nothing is recorded.

    Returns:
        The result of the repository update, or ``False`` when no key ID
        was supplied.
    """
    # Guard first so a missing ID never reaches the repository.
    if not api_key_id:
        return False

    success = ModelApiKeyRepository.update_usage(db, api_key_id)
    if success:
        db.commit()
    return success
@staticmethod
def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey:

View File

@@ -14,6 +14,7 @@ from app.services.conversation_state_manager import ConversationStateManager
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.services.model_service import ModelApiKeyService
logger = get_business_logger()
@@ -2569,8 +2570,9 @@ class MultiAgentOrchestrator:
# ModelConfig.id == default_model_config_id,
# ModelApiKey.is_active.is_(True)
# ).first()
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
api_key_config = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
# api_key_config = api_keys[0] if api_keys else None
api_key_config = ModelApiKeyService.get_available_api_key(self.db, default_model_config_id)
if not api_key_config:
logger.warning("Master Agent 没有可用的 API Key使用简单整合")
@@ -2601,6 +2603,8 @@ class MultiAgentOrchestrator:
# 调用模型进行整合
response = await llm.ainvoke(merge_prompt)
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
# 提取响应内容
if hasattr(response, 'content'):
merged_response = response.content
@@ -2730,8 +2734,9 @@ class MultiAgentOrchestrator:
# ModelConfig.id == default_model_config_id,
# ModelApiKey.is_active.is_(True)
# ).first()
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
api_key_config = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
# api_key_config = api_keys[0] if api_keys else None
api_key_config = ModelApiKeyService.get_available_api_key(self.db, default_model_config_id)
if not api_key_config:
logger.warning("Master Agent 没有可用的 API Key使用简单整合")
@@ -2790,6 +2795,8 @@ class MultiAgentOrchestrator:
logger.debug(f"收到流式 chunk #{chunk_count}: {content[:30]}...")
yield self._format_sse_event("message", {"content": content})
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
logger.info(f"Master Agent 流式整合完成,共 {chunk_count} 个 chunks")
except AttributeError as e:

View File

@@ -23,7 +23,7 @@ logger = get_business_logger()
class ImageFormatStrategy(Protocol):
"""图片格式策略接口"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""将图片 URL 转换为特定 provider 的格式"""
...
@@ -31,7 +31,7 @@ class ImageFormatStrategy(Protocol):
class DashScopeImageStrategy:
"""通义千问图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""通义千问格式: {"type": "image", "image": "url"}"""
return {
@@ -42,7 +42,7 @@ class DashScopeImageStrategy:
class BedrockImageStrategy:
"""Bedrock/Anthropic 图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""
Bedrock/Anthropic 格式: base64 编码
@@ -51,17 +51,17 @@ class BedrockImageStrategy:
import httpx
import base64
from mimetypes import guess_type
logger.info(f"下载并编码图片: {url}")
# 下载图片
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
# 获取图片数据
image_data = response.content
# 确定 media type
content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"):
@@ -69,12 +69,12 @@ class BedrockImageStrategy:
else:
guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8")
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return {
"type": "image",
"source": {
@@ -87,7 +87,7 @@ class BedrockImageStrategy:
class OpenAIImageStrategy:
"""OpenAI 图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
return {
@@ -109,7 +109,7 @@ PROVIDER_STRATEGIES = {
class MultimodalService:
"""多模态文件处理服务"""
def __init__(self, db: Session, provider: str = "dashscope"):
"""
初始化多模态服务
@@ -120,10 +120,10 @@ class MultimodalService:
"""
self.db = db
self.provider = provider.lower()
async def process_files(
self,
files: Optional[List[FileInput]]
self,
files: Optional[List[FileInput]]
) -> List[Dict[str, Any]]:
"""
处理文件列表,返回 LLM 可用的格式
@@ -136,7 +136,7 @@ class MultimodalService:
"""
if not files:
return []
result = []
for idx, file in enumerate(files):
try:
@@ -168,10 +168,10 @@ class MultimodalService:
"type": "text",
"text": f"[文件处理失败: {str(e)}]"
})
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
"""
处理图片文件
@@ -184,14 +184,10 @@ class MultimodalService:
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
- 通义千问: {"type": "image", "image": "url"}
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
url = file.url
else:
# 本地文件,获取访问 URL
url = await self._get_file_url(file.upload_file_id)
url = await self.get_file_url(file)
logger.debug(f"处理图片: {url}, provider={self.provider}")
# 根据 provider 返回不同格式
if self.provider in ["bedrock", "anthropic"]:
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
@@ -223,7 +219,7 @@ class MultimodalService:
"type": "image",
"image": url
}
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
"""
下载图片并转换为 base64
@@ -237,15 +233,15 @@ class MultimodalService:
import httpx
import base64
from mimetypes import guess_type
# 下载图片
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
# 获取图片数据
image_data = response.content
# 确定 media type
content_type = response.headers.get("content-type")
if content_type and content_type.startswith("image/"):
@@ -254,14 +250,14 @@ class MultimodalService:
# 从 URL 推断
guessed_type, _ = guess_type(url)
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
# 转换为 base64
base64_data = base64.b64encode(image_data).decode("utf-8")
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
return base64_data, media_type
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
"""
处理文档文件PDF、Word 等)
@@ -284,14 +280,14 @@ class MultimodalService:
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file.upload_file_id
).first()
file_name = generic_file.file_name if generic_file else "unknown"
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
"""
处理音频文件
@@ -307,7 +303,7 @@ class MultimodalService:
"type": "text",
"text": "[音频文件,暂不支持处理]"
}
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
"""
处理视频文件
@@ -323,13 +319,13 @@ class MultimodalService:
"type": "text",
"text": "[视频文件,暂不支持处理]"
}
async def _get_file_url(self, file_id: uuid.UUID) -> str:
async def get_file_url(self, file: FileInput) -> str:
"""
获取文件的访问 URL
Args:
file_id: 文件ID
file: File Input Struct
Returns:
str: 文件访问 URL
@@ -337,26 +333,31 @@ class MultimodalService:
Raises:
BusinessException: 文件不存在
"""
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file_id,
GenericFile.status == "active"
).first()
if not generic_file:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# 如果有 access_url直接返回
if generic_file.access_url:
return generic_file.access_url
# 否则,根据 storage_path 生成 URL
# TODO: 根据实际存储方式生成 URL本地存储、OSS 等)
# 这里暂时返回一个占位 URL
return f"/api/files/{file_id}/download"
if file.transfer_method == TransferMethod.REMOTE_URL:
return file.url
else:
# 本地文件,获取访问 URL
file_id = file.upload_file_id
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file.upload_file_id,
GenericFile.status == "active"
).first()
if not generic_file:
raise BusinessException(
f"文件不存在或已删除: {file.upload_file_id}",
BizCode.NOT_FOUND
)
# 如果有 access_url直接返回
if generic_file.access_url:
return generic_file.access_url
# 否则,根据 storage_path 生成 URL
# TODO: 根据实际存储方式生成 URL本地存储、OSS 等)
# 这里暂时返回一个占位 URL
return f"/api/files/{file_id}/download"
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
"""
提取文档文本内容
@@ -371,20 +372,20 @@ class MultimodalService:
GenericFile.id == file_id,
GenericFile.status == "active"
).first()
if not generic_file:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# TODO: 根据文件类型提取文本
# - PDF: 使用 PyPDF2 或 pdfplumber
# - Word: 使用 python-docx
# - TXT/MD: 直接读取
file_ext = generic_file.file_ext.lower()
if file_ext in ['.txt', '.md', '.markdown']:
return await self._read_text_file(generic_file.storage_path)
elif file_ext == '.pdf':
@@ -393,7 +394,7 @@ class MultimodalService:
return await self._extract_word_text(generic_file.storage_path)
else:
return f"[不支持的文档格式: {file_ext}]"
async def _read_text_file(self, storage_path: str) -> str:
"""读取纯文本文件"""
try:
@@ -402,7 +403,7 @@ class MultimodalService:
except Exception as e:
logger.error(f"读取文本文件失败: {e}")
return f"[文件读取失败: {str(e)}]"
async def _extract_pdf_text(self, storage_path: str) -> str:
"""提取 PDF 文本"""
try:
@@ -412,7 +413,7 @@ class MultimodalService:
except Exception as e:
logger.error(f"提取 PDF 文本失败: {e}")
return f"[PDF 提取失败: {str(e)}]"
async def _extract_word_text(self, storage_path: str) -> str:
"""提取 Word 文档文本"""
try:

View File

@@ -1,4 +1,3 @@
{% raw %}
Role: AI Prompt Optimization Expert
Profile
@@ -12,11 +11,11 @@ Skills
Core Optimization Skills
Requirement Analysis: Accurately understand the relationship between the user's current needs and the original prompt.
Structural Reconstruction: Transform vague requirements into clear, block-structured instructions.
Variable Handling: Identify and standardize dynamic variables in prompts.
{% if skill != true %}Variable Handling: Identify and standardize dynamic variables in prompts.{% endif %}
Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs.
Auxiliary Generation Skills
Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.
{% if skill != true %}Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.{% endif %}
Language Consistency: Maintain consistency between label language and user input language.
Executability Verification: Ensure optimized prompts can be directly used in AI tools.
Format Standardization: Strictly adhere to specified output format requirements.
@@ -25,30 +24,30 @@ Rules
Basic Principles
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints
{% if skill != true %}Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints{% endif %}
Language Rule: All label languages must fully match the user input language.
Behavior Guidelines
Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity.
Readability Guideline: Ensure optimized prompts have good readability and logical flow.
Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.
{% if skill != true %}{% raw %}Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %}
Constraints
Output Constraint: Must output in JSON format including the fields "prompt" and "desc".
Content Constraint: Must not include any explanations, analyses, or additional comments.
Language Constraint: Must use clear and concise language.
Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).
{% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %}
Workflows
Goal: Optimize or generate AI prompts that can be directly used according to user requirements.
Step 1: Receive the user's current requirement description {{user_require}} and the original prompt {{original_prompt}}.
Step 2: Analyze requirements, identify conflicts, and prioritize current requirements.
Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
{% if skill != true %}Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
Step 4: Generate a JSON output containing the optimized prompt and its description.
{% else %}Step 3: Generate a JSON output containing the optimized prompt and its description.{% endif %}
Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description.
Initialization
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.
{% endraw %}
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.

View File

@@ -23,6 +23,7 @@ from app.repositories.prompt_optimizer_repository import (
PromptReleaseRepository
)
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
from app.services.model_service import ModelApiKeyService
logger = get_business_logger()
@@ -128,7 +129,8 @@ class PromptOptimizerService:
session_id: uuid.UUID,
user_id: uuid.UUID,
current_prompt: str,
user_require: str
user_require: str,
skill: bool = False
) -> AsyncGenerator[dict[str, str | Any], Any]:
"""
Optimize a user-provided prompt using a configured prompt optimizer LLM.
@@ -157,6 +159,7 @@ class PromptOptimizerService:
user_id (uuid.UUID): Identifier of the user associated with the session.
current_prompt (str): Original prompt to optimize.
user_require (str): User's requirements or instructions for optimization.
skill(bool): Is skill required
Returns:
OptimizePromptResult: An object containing:
@@ -174,8 +177,9 @@ class PromptOptimizerService:
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
# Create LLM instance
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
api_config: ModelApiKey = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
# api_config: ModelApiKey = api_keys[0] if api_keys else None
api_config: ModelApiKey = ModelApiKeyService.get_available_api_key(self.db, model_config.id)
llm = RedBearLLM(RedBearModelConfig(
model_name=api_config.model_name,
provider=api_config.provider,
@@ -186,7 +190,7 @@ class PromptOptimizerService:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render()
rendered_system_message = Template(opt_system_prompt).render(skill=skill)
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
opt_user_prompt = f.read()
@@ -250,6 +254,7 @@ class PromptOptimizerService:
optim_result = json_repair.repair_json(buffer, return_objects=True)
# prompt = optim_result.get("prompt")
desc = optim_result.get("desc")
ModelApiKeyService.record_api_key_usage(self.db, api_config.id)
self.create_message(
tenant_id=tenant_id,
session_id=session_id,

View File

@@ -10,6 +10,7 @@ from app.services.memory_konwledges_server import write_rag
from app.models import ReleaseShare, AppRelease, Conversation
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService
from app.services.release_share_service import ReleaseShareService
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
@@ -178,8 +179,9 @@ class SharedChatService:
# .limit(1)
# )
# api_key_obj = self.db.scalars(stmt).first()
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
api_key_obj = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
# api_key_obj = api_keys[0] if api_keys else None
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key_obj:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
@@ -309,6 +311,8 @@ class SharedChatService:
elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
return {
"conversation_id": conversation.id,
@@ -349,7 +353,8 @@ class SharedChatService:
if variables is None:
variables = {}
memory_config = {"enabled": memory, "memory_content": "17", "max_history": 10}
# 兼容新旧字段名:使用 memory_config_id
memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10}
try:
# 获取发布版本和配置
@@ -383,8 +388,9 @@ class SharedChatService:
# .limit(1)
# )
# api_key_obj = self.db.scalars(stmt).first()
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
api_key_obj = api_keys[0] if api_keys else None
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
# api_key_obj = api_keys[0] if api_keys else None
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key_obj:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
@@ -513,7 +519,8 @@ class SharedChatService:
}
)
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"

View File

@@ -0,0 +1,133 @@
"""Skill Service"""
import uuid
from typing import List
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from app.repositories.skill_repository import SkillRepository
from app.schemas.skill_schema import SkillCreate, SkillUpdate
from app.models.skill_model import Skill
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.tool_service import ToolService
class SkillService:
    """Business-logic layer for skills.

    Every method is a ``@staticmethod`` that receives the SQLAlchemy session
    explicitly. Persistence is delegated to ``SkillRepository``; this layer
    adds the duplicate-name check, tool enrichment and transaction handling
    (commit on success, rollback on failure).
    """

    @staticmethod
    def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
        """Create a skill for a tenant and commit it.

        Args:
            db: Active database session.
            data: Creation payload; ``data.name`` must be unique per tenant.
            tenant_id: Tenant that will own the skill.

        Returns:
            The newly created skill, refreshed from the database.

        Raises:
            BusinessException: ``BizCode.DUPLICATE_NAME`` when the tenant
                already has a skill with the same name.
            SQLAlchemyError: Re-raised after the session has been rolled back.
        """
        # Reject a second skill with the same name inside the same tenant.
        existing = db.query(Skill).filter(
            Skill.tenant_id == tenant_id,
            Skill.name == data.name
        ).first()
        if existing:
            raise BusinessException(f"技能名称'{data.name}'已存在", BizCode.DUPLICATE_NAME)

        try:
            skill = SkillRepository.create(db, data, tenant_id)
            db.commit()
            db.refresh(skill)
            return skill
        except SQLAlchemyError:
            # Keep the session usable after a failed write, matching
            # update_skill / delete_skill.
            db.rollback()
            raise

    @staticmethod
    def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill:
        """Fetch one skill and replace its tool list with enriched entries.

        Each entry of ``skill.tools`` that carries a ``tool_id`` and for which
        ``ToolService.get_tool_info`` returns a truthy value is rewritten as
        ``{"tool_id", "operation", "tool_info"}``; all other entries are
        dropped from the returned list.

        Args:
            db: Active database session.
            skill_id: Identifier of the skill to load.
            tenant_id: Tenant used to scope the lookup.

        Returns:
            The skill with ``tools`` replaced by the enriched list.

        Raises:
            BusinessException: ``BizCode.NOT_FOUND`` when no skill matches.
            SQLAlchemyError: Re-raised after the session has been rolled back.
        """
        try:
            skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
            if not skill:
                raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)

            tool_service = ToolService(db)
            enriched_tools = []
            # ``tools`` may be unset on the row; treat that as "no tools".
            for tool_config in skill.tools or []:
                tool_id = tool_config.get("tool_id")
                if not tool_id:
                    continue
                tool_info = tool_service.get_tool_info(tool_id, tenant_id)
                if not tool_info:
                    continue
                enriched_tools.append({
                    "tool_id": tool_id,
                    "operation": tool_config.get("operation"),
                    "tool_info": tool_info
                })

            # NOTE(review): this assigns onto the object returned by the
            # repository; if it is still attached to ``db`` a later commit on
            # the same session could persist the enriched list - confirm.
            skill.tools = enriched_tools
            return skill
        except (BusinessException, SQLAlchemyError):
            db.rollback()
            raise

    @staticmethod
    def list_skills(
            db: Session,
            tenant_id: uuid.UUID,
            search: str | None = None,
            is_active: bool | None = None,
            is_public: bool | None = None,
            page: int = 1,
            pagesize: int = 10
    ) -> tuple[list[type[Skill]], int]:
        """List skills by forwarding every argument to ``SkillRepository.list_skills``.

        Args:
            db: Active database session.
            tenant_id: Tenant whose skills are listed.
            search: Optional search term.
            is_active: Optional active-flag filter.
            is_public: Optional public-flag filter.
            page: Page number (defaults to 1).
            pagesize: Page size (defaults to 10).

        Returns:
            Whatever the repository returns; annotated as a
            ``(skills, count)`` tuple.
        """
        return SkillRepository.list_skills(
            db, tenant_id, search, is_active, is_public, page, pagesize
        )

    @staticmethod
    def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill:
        """Update a skill and commit the change.

        Args:
            db: Active database session.
            skill_id: Identifier of the skill to update.
            data: Update payload handed to the repository.
            tenant_id: Tenant used to scope the update.

        Returns:
            The updated skill, refreshed from the database.

        Raises:
            BusinessException: ``BizCode.NOT_FOUND`` when the repository
                returns nothing (skill missing or not accessible).
            SQLAlchemyError: Re-raised after the session has been rolled back.
        """
        try:
            skill = SkillRepository.update(db, skill_id, data, tenant_id)
            if not skill:
                raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND)
            db.commit()
            db.refresh(skill)
            return skill
        except (BusinessException, SQLAlchemyError):
            db.rollback()
            raise

    @staticmethod
    def delete_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
        """Delete a skill and commit the change.

        Args:
            db: Active database session.
            skill_id: Identifier of the skill to delete.
            tenant_id: Tenant used to scope the deletion.

        Returns:
            ``True`` once the deletion has been committed.

        Raises:
            BusinessException: ``BizCode.NOT_FOUND`` when the repository
                reports that nothing was deleted.
            SQLAlchemyError: Re-raised after the session has been rolled back.
        """
        try:
            success = SkillRepository.delete(db, skill_id, tenant_id)
            if not success:
                raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND)
            db.commit()
            return True
        except (BusinessException, SQLAlchemyError):
            db.rollback()
            raise

    @staticmethod
    def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]:
        """Load the LangChain tools attached to the given skills.

        Loading is best-effort per skill: any error raised while handling one
        skill (including an invalid UUID string) is logged and that skill is
        skipped, so the remaining skills are still processed. Tools already
        appended for the failing skill are kept.

        Args:
            db: Active database session.
            skill_ids: Skill identifiers as UUID strings.
            tenant_id: Tenant used to scope skill and tool lookups.

        Returns:
            ``(tools, tool_to_skill_map)`` - the list of LangChain tools and a
            mapping from tool name to the skill id that contributed it. A tool
            without a ``name`` attribute is keyed by ``str(id(tool))``.
        """
        # Function-scope import: the module does not import logging itself.
        import logging

        log = logging.getLogger(__name__)

        tools = []
        tool_to_skill_map = {}  # {tool_name: skill_id}
        tool_service = ToolService(db)

        for skill_id in skill_ids:
            try:
                skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
                if not (skill and skill.is_active):
                    continue
                # ``tools`` may be unset on the row; treat that as "no tools".
                for tool_config in skill.tools or []:
                    tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
                    if not tool:
                        continue
                    langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
                    tools.append(langchain_tool)
                    # Remember which skill each tool came from.
                    tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool)))
                    tool_to_skill_map[tool_name] = skill_id
            except Exception as e:
                # Deliberately broad: one broken skill must not prevent the
                # others from loading. Logged instead of printed so the
                # failure reaches the application log with its traceback.
                log.warning("加载技能 %s 的工具时出错: %s", skill_id, e, exc_info=True)
                continue

        return tools, tool_to_skill_map

View File

@@ -4,9 +4,8 @@
import datetime
import logging
import uuid
from typing import Any, Annotated, AsyncGenerator, Optional
from typing import Any, Annotated, Optional
from deprecated import deprecated
from fastapi import Depends
from sqlalchemy.orm import Session
@@ -23,6 +22,7 @@ from app.repositories.workflow_repository import (
from app.schemas import DraftRunRequest
from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
@@ -36,6 +36,7 @@ class WorkflowService:
self.execution_repo = WorkflowExecutionRepository(db)
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
self.conversation_service = ConversationService(db)
self.multimodal_service = MultimodalService(db)
# ==================== 配置管理 ====================
@@ -445,24 +446,22 @@ class WorkflowService:
code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id}
files = []
if payload.files:
for file in payload.files:
files.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
)
# 转换 user_id 为 UUID
triggered_by_uuid = None
if payload.user_id:
try:
triggered_by_uuid = uuid.UUID(payload.user_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id, "files": files}
# 转换 conversation_id 为 UUID
conversation_id_uuid = None
if payload.conversation_id:
try:
conversation_id_uuid = uuid.UUID(payload.conversation_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
# 2. 创建执行记录
execution = self.create_execution(
@@ -544,10 +543,10 @@ class WorkflowService:
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"variables": result.get("variables"),
"messages": result.get("messages"),
# "variables": result.get("variables"),
# "messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
@@ -566,6 +565,41 @@ class WorkflowService:
message=f"工作流执行失败: {str(e)}"
)
@staticmethod
def _map_public_event(event: dict) -> dict | None:
event_type = event.get("event")
payload = event.get("data")
match event_type:
case "workflow_start":
return {
"event": "start",
"data": {
"conversation_id": payload.get("conversation_id"),
}
}
case "workflow_end":
return {
"event": "end",
"data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", ""))
}
}
case "node_start" | "node_end" | "node_error":
return None
case _:
return event
def _emit(self, public: bool, internal_event: dict):
    """Pick the representation of an event to hand to the stream consumer.

    Args:
        public: When truthy, the event is passed through
            ``self._map_public_event``; otherwise it is returned untouched.
        internal_event: Event dict produced by the workflow executor.

    Returns:
        The internal event itself when ``public`` is falsy, otherwise the
        result of ``_map_public_event`` (which may be ``None``).
    """
    # Internal consumers get the raw event.
    if not public:
        return internal_event
    return self._map_public_event(internal_event)
async def run_stream(
self,
app_id: uuid.UUID,
@@ -573,6 +607,7 @@ class WorkflowService:
config: WorkflowConfig,
workspace_id: uuid.UUID,
release_id: Optional[uuid.UUID] = None,
public: bool = False
):
"""运行工作流(流式)
@@ -582,6 +617,7 @@ class WorkflowService:
app_id: 应用 ID
payload: 请求对象(包含 message, variables, conversation_id 等)
config: 存储类型(可选)
public: 是否发布
Yields:
SSE 格式的流式事件
@@ -597,24 +633,23 @@ class WorkflowService:
code=BizCode.CONFIG_MISSING,
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id}
# 转换 user_id 为 UUID
triggered_by_uuid = None
if payload.user_id:
try:
triggered_by_uuid = uuid.UUID(payload.user_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
files = []
if payload.files:
for file in payload.files:
files.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id, "files": files}
# 转换 conversation_id 为 UUID
conversation_id_uuid = None
if payload.conversation_id:
try:
conversation_id_uuid = uuid.UUID(payload.conversation_id)
except (ValueError, AttributeError):
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
# 2. 创建执行记录
execution = self.create_execution(
@@ -661,7 +696,7 @@ class WorkflowService:
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(workspace_id),
user_id=payload.user_id
user_id=payload.user_id,
):
if event.get("event") == "workflow_end":
@@ -692,7 +727,9 @@ class WorkflowService:
)
else:
logger.error(f"unexpect workflow run status, status: {status}")
yield event
event = self._emit(public, event)
if event:
yield event
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
@@ -710,134 +747,6 @@ class WorkflowService:
}
}
@deprecated(reason="This method is deprecated. "
"Please use WorkflowService.run / run_stream instead.")
async def run_workflow(
self,
app_id: uuid.UUID,
input_data: dict[str, Any],
triggered_by: uuid.UUID,
conversation_id: uuid.UUID | None = None,
stream: bool = False
) -> AsyncGenerator | dict:
"""运行工作流
Args:
app_id: 应用 ID
input_data: 输入数据(包含 message 和 variables
triggered_by: 触发用户 ID
conversation_id: 会话 ID可选
stream: 是否流式返回
Returns:
执行结果(非流式)或生成器(流式)
Raises:
BusinessException: 配置不存在或执行失败时抛出
"""
# 1. 获取工作流配置
config = self.get_workflow_config(app_id)
if not config:
raise BusinessException(
code=BizCode.NOT_FOUND,
message=f"工作流配置不存在: app_id={app_id}"
)
# 2. 创建执行记录
execution = self.create_execution(
workflow_config_id=config.id,
app_id=app_id,
trigger_type="manual",
triggered_by=triggered_by,
conversation_id=conversation_id,
input_data=input_data
)
# 3. 构建工作流配置字典
workflow_config_dict = {
"nodes": config.nodes,
"edges": config.edges,
"variables": config.variables,
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
from app.models import App
app = self.db.query(App).filter(
App.id == app_id,
App.is_active.is_(True)
).first()
if not app:
raise BusinessException(
code=BizCode.NOT_FOUND,
message=f"应用不存在: app_id={app_id}"
)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow
try:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
if stream:
# 流式执行
return self._run_workflow_stream(
workflow_config_dict,
input_data,
execution.execution_id,
str(app.workspace_id),
str(triggered_by)
)
else:
# 非流式执行
result = await execute_workflow(
workflow_config=workflow_config_dict,
input_data=input_data,
execution_id=execution.execution_id,
workspace_id=str(app.workspace_id),
user_id=str(triggered_by)
)
# 更新执行结果
if result.get("status") == "completed":
token_usage = result.get("data").get("token_usage", {}) or {}
self.update_execution_status(
execution.execution_id,
"completed",
output_data=result.get("node_outputs", {}),
token_usage=token_usage.get("total_tokens", None)
)
else:
self.update_execution_status(
execution.execution_id,
"failed",
error_message=result.get("error")
)
# 返回增强的响应结构
return {
"execution_id": execution.execution_id,
"status": result.get("status"),
"output": result.get("output"), # 最终输出(字符串)
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"error_message": result.get("error"),
"elapsed_time": result.get("elapsed_time"),
"token_usage": result.get("token_usage")
}
except Exception as e:
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
self.update_execution_status(
execution.execution_id,
"failed",
error_message=str(e)
)
raise BusinessException(
code=BizCode.INTERNAL_ERROR,
message=f"工作流执行失败: {str(e)}"
)
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
"""清理事件数据,移除不可序列化的对象
@@ -869,72 +778,6 @@ class WorkflowService:
return clean_value(event)
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
async def _run_workflow_stream(
self,
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str):
"""运行工作流(流式,内部方法)
Args:
workflow_config: 工作流配置
input_data: 输入数据
execution_id: 执行 ID
workspace_id: 工作空间 ID
user_id: 用户 ID
Yields:
流式事件(格式:{"event": "<type>", "data": {...}}
"""
from app.core.workflow.executor import execute_workflow_stream
try:
async for event in execute_workflow_stream(
workflow_config=workflow_config,
input_data=input_data,
execution_id=execution_id,
workspace_id=workspace_id,
user_id=user_id
):
# 直接转发事件executor 已经返回正确格式)
if event.get("event") == "workflow_end":
token_usage = event.get("data").get("token_usage", {}) or {}
status = event.get("data", {}).get("status")
if status == "completed":
self.update_execution_status(
execution_id,
"completed",
output_data=event.get("data"),
token_usage=token_usage.get("total_tokens", None)
)
elif status == "failed":
self.update_execution_status(
execution_id,
"failed",
output_data=event.get("data")
)
else:
logger.error(f"unexpect workflow run status, status: {status}")
yield event
except Exception as e:
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
self.update_execution_status(
execution_id,
"failed",
error_message=str(e)
)
yield {
"event": "error",
"data": {
"execution_id": execution_id,
"error": str(e)
}
}
# ==================== 依赖注入函数 ====================