Merge branch 'develop' into fix/memory-enduser-config
This commit is contained in:
@@ -9,7 +9,7 @@ from app.schemas.app_schema import (
|
||||
VariableDefinition,
|
||||
ToolConfig,
|
||||
AgentConfigCreate,
|
||||
AgentConfigUpdate, ToolOldConfig,
|
||||
AgentConfigUpdate, ToolOldConfig, SkillConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -48,6 +48,9 @@ class AgentConfigConverter:
|
||||
# 5. 工具配置
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
result["tools"] = [tool.model_dump() for tool in config.tools]
|
||||
|
||||
if hasattr(config, "skills") and config.skills:
|
||||
result["skills"] = config.skills.model_dump()
|
||||
|
||||
return result
|
||||
|
||||
@@ -58,6 +61,7 @@ class AgentConfigConverter:
|
||||
memory: Optional[Dict[str, Any]],
|
||||
variables: Optional[list],
|
||||
tools: Optional[Union[list, Dict[str, Any]]],
|
||||
skills: Optional[dict]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将数据库存储格式转换为 Pydantic 对象
|
||||
@@ -68,6 +72,7 @@ class AgentConfigConverter:
|
||||
memory: 记忆配置
|
||||
variables: 变量配置
|
||||
tools: 工具配置
|
||||
skills: 技能列表
|
||||
|
||||
Returns:
|
||||
包含 Pydantic 对象的字典
|
||||
@@ -78,6 +83,7 @@ class AgentConfigConverter:
|
||||
"memory": MemoryConfig(enabled=True),
|
||||
"variables": [],
|
||||
"tools": [],
|
||||
"skills": SkillConfig(enabled=False, all_skills=False, skill_ids=[])
|
||||
}
|
||||
|
||||
# 1. 解析模型参数配置
|
||||
@@ -117,5 +123,10 @@ class AgentConfigConverter:
|
||||
name: ToolOldConfig(**tool_data)
|
||||
for name, tool_data in tools.items()
|
||||
}
|
||||
|
||||
if skills:
|
||||
result["skills"] = SkillConfig(**skills)
|
||||
else:
|
||||
result["skills"] = SkillConfig(enabled=False, all_skills=False, skill_ids=[])
|
||||
|
||||
return result
|
||||
|
||||
@@ -26,6 +26,7 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
||||
memory=agent_cfg.memory,
|
||||
variables=agent_cfg.variables,
|
||||
tools=agent_cfg.tools,
|
||||
skills=agent_cfg.skills
|
||||
)
|
||||
|
||||
# 将解析后的字段添加到对象上(用于序列化)
|
||||
@@ -34,5 +35,6 @@ def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
||||
agent_cfg.memory = parsed["memory"]
|
||||
agent_cfg.variables = parsed["variables"]
|
||||
agent_cfg.tools = parsed["tools"]
|
||||
agent_cfg.skills = parsed["skills"]
|
||||
|
||||
return agent_cfg
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -63,7 +64,7 @@ class AppChatService:
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.system_prompt
|
||||
if variables:
|
||||
@@ -79,21 +80,55 @@ class AppChatService:
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, workspace_id))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(config, 'skills') and config.skills:
|
||||
skills = config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.knowledge_retrieval
|
||||
@@ -113,22 +148,6 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
|
||||
@@ -192,6 +211,8 @@ class AppChatService:
|
||||
}
|
||||
)
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
@@ -230,7 +251,7 @@ class AppChatService:
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.system_prompt
|
||||
if variables:
|
||||
@@ -246,20 +267,54 @@ class AppChatService:
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, workspace_id))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(config, 'skills') and config.skills:
|
||||
skills = config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.knowledge_retrieval
|
||||
@@ -279,22 +334,6 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
|
||||
@@ -374,6 +413,8 @@ class AppChatService:
|
||||
}
|
||||
)
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
@@ -618,6 +659,7 @@ class AppChatService:
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
public=False
|
||||
|
||||
) -> AsyncGenerator[dict, None]:
|
||||
"""聊天(流式)"""
|
||||
@@ -634,7 +676,8 @@ class AppChatService:
|
||||
payload=payload,
|
||||
config=config,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release_id
|
||||
release_id=release_id,
|
||||
public=public
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
@@ -313,6 +313,7 @@ class AppService:
|
||||
memory=storage_data.get("memory"),
|
||||
variables=storage_data.get("variables", []),
|
||||
tools=storage_data.get("tools", []),
|
||||
skills=storage_data.get("skills", {}),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -916,6 +917,7 @@ class AppService:
|
||||
agent_cfg.variables = storage_data.get("variables", [])
|
||||
# if data.tools is not None:
|
||||
agent_cfg.tools = storage_data.get("tools", [])
|
||||
agent_cfg.skills = storage_data.get("skills", {})
|
||||
|
||||
agent_cfg.updated_at = now
|
||||
|
||||
@@ -1003,11 +1005,12 @@ class AppService:
|
||||
},
|
||||
memory={
|
||||
"enabled": True,
|
||||
"memory_content": None,
|
||||
"memory_config_id": None,
|
||||
"max_history": 10
|
||||
},
|
||||
variables=[],
|
||||
tools=[],
|
||||
skills=[],
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -1403,6 +1406,7 @@ class AppService:
|
||||
"memory": agent_cfg.memory,
|
||||
"variables": agent_cfg.variables or [],
|
||||
"tools": agent_cfg.tools or [],
|
||||
"skills": agent_cfg.skills or {},
|
||||
}
|
||||
# config = AgentConfigConverter.from_storage_format(agent_cfg)
|
||||
default_model_config_id = agent_cfg.default_model_config_id
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
"""应用统计服务"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, Any
|
||||
import uuid
|
||||
from sqlalchemy import func, and_, cast, Date
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.conversation_model import Conversation, Message
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.api_key_model import ApiKey, ApiKeyLog
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.models.api_key_model import ApiKey, ApiKeyLog, ApiKeyType
|
||||
|
||||
|
||||
class AppStatisticsService:
|
||||
@@ -146,7 +144,6 @@ class AppStatisticsService:
|
||||
end_dt: datetime
|
||||
) -> Dict[str, Any]:
|
||||
"""获取Token消耗统计(从Message的meta_data中提取)"""
|
||||
from sqlalchemy import text
|
||||
|
||||
# 查询所有相关消息的token使用情况
|
||||
# meta_data中可能包含: {"usage": {"total_tokens": 100}} 或 {"tokens": 100}
|
||||
@@ -187,7 +184,80 @@ class AppStatisticsService:
|
||||
daily_tokens[date_str] = 0
|
||||
daily_tokens[date_str] += int(tokens)
|
||||
|
||||
daily_data = [{"date": date, "tokens": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
total = sum(row["tokens"] for row in daily_data)
|
||||
daily_data = [{"date": date, "count": tokens} for date, tokens in sorted(daily_tokens.items()) if tokens != 0]
|
||||
total = sum(row["count"] for row in daily_data)
|
||||
|
||||
return {"daily": daily_data, "total": total}
|
||||
|
||||
def get_workspace_api_statistics(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
start_date: int,
|
||||
end_date: int
|
||||
) -> list[Any]:
|
||||
"""获取工作空间API调用统计
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
每日统计数据列表
|
||||
"""
|
||||
# 将毫秒时间戳转换为 datetime
|
||||
start_time = datetime.fromtimestamp(start_date / 1000)
|
||||
end_time = datetime.fromtimestamp(end_date / 1000)
|
||||
|
||||
# 应用类型(agent, multi_agent, workflow)
|
||||
app_types = [ApiKeyType.AGENT, ApiKeyType.CLUSTER, ApiKeyType.WORKFLOW]
|
||||
|
||||
# 每日应用类型调用次数
|
||||
daily_app_calls = self.db.query(
|
||||
cast(ApiKeyLog.created_at, Date).label('date'),
|
||||
func.count(ApiKeyLog.id).label('count')
|
||||
).join(
|
||||
ApiKey, ApiKeyLog.api_key_id == ApiKey.id
|
||||
).filter(
|
||||
and_(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.type.in_(app_types),
|
||||
ApiKeyLog.created_at >= start_time,
|
||||
ApiKeyLog.created_at <= end_time
|
||||
)
|
||||
).group_by(cast(ApiKeyLog.created_at, Date)).all()
|
||||
|
||||
# 每日服务类型调用次数
|
||||
daily_service_calls = self.db.query(
|
||||
cast(ApiKeyLog.created_at, Date).label('date'),
|
||||
func.count(ApiKeyLog.id).label('count')
|
||||
).join(
|
||||
ApiKey, ApiKeyLog.api_key_id == ApiKey.id
|
||||
).filter(
|
||||
and_(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.type == ApiKeyType.SERVICE,
|
||||
ApiKeyLog.created_at >= start_time,
|
||||
ApiKeyLog.created_at <= end_time
|
||||
)
|
||||
).group_by(cast(ApiKeyLog.created_at, Date)).all()
|
||||
|
||||
# 构建每日数据
|
||||
app_calls_dict = {str(row.date): row.count for row in daily_app_calls}
|
||||
service_calls_dict = {str(row.date): row.count for row in daily_service_calls}
|
||||
|
||||
# 合并所有日期
|
||||
all_dates = sorted(set(app_calls_dict.keys()) | set(service_calls_dict.keys()))
|
||||
|
||||
daily_data = []
|
||||
for date in all_dates:
|
||||
app_count = app_calls_dict.get(date, 0)
|
||||
service_count = service_calls_dict.get(date, 0)
|
||||
daily_data.append({
|
||||
"date": date,
|
||||
"total_calls": app_count + service_count,
|
||||
"app_calls": app_count,
|
||||
"service_calls": service_count
|
||||
})
|
||||
|
||||
return daily_data
|
||||
|
||||
@@ -24,6 +24,7 @@ from app.core.error_codes import BizCode
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -357,6 +358,8 @@ class CollaborativeOrchestrator:
|
||||
"usage": response.get("usage", {"total_tokens": 0}),
|
||||
"is_final_answer": True
|
||||
}
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, agent_config.get("api_key_id"))
|
||||
|
||||
# 检查是否有工具调用(handoff)
|
||||
tool_calls = response.get("tool_calls", [])
|
||||
@@ -427,7 +430,7 @@ class CollaborativeOrchestrator:
|
||||
)
|
||||
|
||||
# 获取 API Key
|
||||
api_key_config = ModelApiKeyService.get_a_api_key(self.db, model_config_id)
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
if not api_key_config:
|
||||
raise BusinessException(
|
||||
f"Agent 模型没有可用的 API Key: {agent_id}",
|
||||
@@ -442,7 +445,8 @@ class CollaborativeOrchestrator:
|
||||
"provider": api_key_config.provider,
|
||||
"api_key": api_key_config.api_key,
|
||||
"api_base": api_key_config.api_base,
|
||||
"model_parameters": config_data.get("model_parameters", {})
|
||||
"model_parameters": config_data.get("model_parameters", {}),
|
||||
"api_key_id": api_key_config.id
|
||||
}
|
||||
|
||||
except ValueError:
|
||||
|
||||
@@ -10,6 +10,11 @@ import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -24,12 +29,11 @@ from app.services import task_service
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
|
||||
|
||||
logger = get_business_logger()
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
@@ -59,7 +63,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
长期记忆工具
|
||||
"""
|
||||
# search_switch = memory_config.get("search_switch", "2")
|
||||
config_id= memory_config.get("memory_content") or memory_config.get("memory_config",None)
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
config_id = memory_config.get("memory_config_id") or memory_config.get("memory_content", None)
|
||||
logger.info(f"创建长期记忆工具,配置: end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}")
|
||||
@tool(args_schema=LongTermMemoryInput)
|
||||
def long_term_memory(question: str) -> str:
|
||||
@@ -310,6 +315,7 @@ class DraftRunService:
|
||||
tools = []
|
||||
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
@@ -320,9 +326,7 @@ class DraftRunService:
|
||||
print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
@@ -345,6 +349,25 @@ class DraftRunService:
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||
skills = agent_config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
kb_config = agent_config.knowledge_retrieval
|
||||
@@ -433,7 +456,8 @@ class DraftRunService:
|
||||
)
|
||||
|
||||
memory_config_= agent_config.memory
|
||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
|
||||
|
||||
# 8. 调用 Agent(支持多模态)
|
||||
result = await agent.chat(
|
||||
@@ -450,6 +474,8 @@ class DraftRunService:
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||
|
||||
# 9. 保存会话消息
|
||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||
await self._save_conversation_message(
|
||||
@@ -558,6 +584,7 @@ class DraftRunService:
|
||||
tools = []
|
||||
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
@@ -567,9 +594,7 @@ class DraftRunService:
|
||||
# print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
@@ -592,6 +617,25 @@ class DraftRunService:
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||
skills = agent_config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
@@ -628,7 +672,6 @@ class DraftRunService:
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 4. 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_config["model_name"],
|
||||
@@ -677,7 +720,8 @@ class DraftRunService:
|
||||
})
|
||||
|
||||
memory_config_ = agent_config.memory
|
||||
config_id = memory_config_.get("memory_content") or memory_config_.get("memory_config",None)
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
|
||||
|
||||
# 9. 流式调用 Agent(支持多模态)
|
||||
full_content = ""
|
||||
@@ -704,6 +748,8 @@ class DraftRunService:
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||
|
||||
if sub_agent:
|
||||
yield self._format_sse_event("sub_usage", {
|
||||
"total_tokens": total_tokens
|
||||
@@ -770,7 +816,7 @@ class DraftRunService:
|
||||
Raises:
|
||||
BusinessException: 当没有可用的 API Key 时
|
||||
"""
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
# stmt = (
|
||||
# select(ModelApiKey).join(
|
||||
# ModelConfig, ModelApiKey.model_configs
|
||||
@@ -784,7 +830,8 @@ class DraftRunService:
|
||||
# )
|
||||
#
|
||||
# api_key = self.db.scalars(stmt).first()
|
||||
api_key = api_keys[0] if api_keys else None
|
||||
# api_key = api_keys[0] if api_keys else None
|
||||
api_key = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
|
||||
if not api_key:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
@@ -793,7 +840,8 @@ class DraftRunService:
|
||||
"model_name": api_key.model_name,
|
||||
"provider": api_key.provider,
|
||||
"api_key": api_key.api_key,
|
||||
"api_base": api_key.api_base
|
||||
"api_base": api_key.api_base,
|
||||
"api_key_id": api_key.id
|
||||
}
|
||||
|
||||
async def _ensure_conversation(
|
||||
@@ -1051,7 +1099,7 @@ class DraftRunService:
|
||||
|
||||
except Exception as e:
|
||||
# 对于多 Agent 应用,没有直接的 AgentConfig 是正常的
|
||||
logger.debug("获取配置快照失败(可能是多 Agent 应用)", extra={"error": str(e)})
|
||||
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
|
||||
return {}
|
||||
|
||||
def _replace_variables(
|
||||
|
||||
@@ -537,7 +537,7 @@ def convert_multi_agent_config_to_handoffs(
|
||||
|
||||
# 获取该 Agent 的模型配置
|
||||
if release.default_model_config_id:
|
||||
model_api_key = ModelApiKeyService.get_a_api_key(db, release.default_model_config_id)
|
||||
model_api_key = ModelApiKeyService.get_available_api_key(db, release.default_model_config_id)
|
||||
if model_api_key:
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_api_key.model_name,
|
||||
@@ -551,6 +551,7 @@ def convert_multi_agent_config_to_handoffs(
|
||||
}
|
||||
)
|
||||
logger.debug(f"Agent {agent_name} 使用模型: {model_api_key.model_name}")
|
||||
ModelApiKeyService.record_api_key_usage(db, model_api_key.id)
|
||||
else:
|
||||
logger.warning(f"Agent {agent_name} 模型配置无效: {release.default_model_config_id}")
|
||||
else:
|
||||
|
||||
@@ -382,6 +382,7 @@ class LLMRouter:
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.models import ModelApiKey, ModelType
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
# 获取 API Key 配置(通过关联关系)
|
||||
# api_key_config = self.db.query(ModelApiKey).join(
|
||||
@@ -389,8 +390,9 @@ class LLMRouter:
|
||||
# ).filter(ModelConfig.id == self.routing_model_config.id,
|
||||
# ModelApiKey.is_active == True
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, self.routing_model_config.id)
|
||||
# api_key_config = api_keys[0] if api_keys else None
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(self.db, self.routing_model_config.id)
|
||||
|
||||
if not api_key_config:
|
||||
raise Exception("路由模型没有可用的 API Key")
|
||||
@@ -424,7 +426,6 @@ class LLMRouter:
|
||||
# 调用模型
|
||||
response = await llm.ainvoke(prompt)
|
||||
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取响应内容
|
||||
|
||||
@@ -349,7 +349,7 @@ class MasterAgentRouter:
|
||||
from app.models import ModelApiKey, ModelType
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = ModelApiKeyService.get_a_api_key(self.db, self.master_model_config.id)
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(self.db, self.master_model_config.id)
|
||||
|
||||
if not api_key_config:
|
||||
raise Exception("Master Agent 模型没有可用的 API Key")
|
||||
@@ -400,6 +400,7 @@ class MasterAgentRouter:
|
||||
|
||||
# 调用模型
|
||||
response = await llm.ainvoke(prompt)
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
|
||||
@@ -1194,7 +1194,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
workspace_id=app.workspace_id
|
||||
)
|
||||
|
||||
memory_config_id = str(memory_config.config_id) if memory_config else None
|
||||
memory_obj = config.get('memory', {})
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
result = {
|
||||
"end_user_id": str(end_user_id),
|
||||
@@ -1284,7 +1286,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
if release:
|
||||
config = release.config or {}
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
if memory_config_id:
|
||||
# 判断是否为UUID格式
|
||||
if len(str(memory_config_id))>=5:
|
||||
@@ -1330,7 +1333,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 从 config 中提取 memory_config_id
|
||||
config = release.config or {}
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
memory_config_id = memory_obj.get('memory_config_id') or memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
# 获取配置名称(使用字符串形式的ID进行查找,兼容新旧格式)
|
||||
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None
|
||||
|
||||
@@ -53,7 +53,10 @@ def get_workspace_end_users(
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> List[EndUser]:
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
||||
|
||||
返回结果按 updated_at 从新到旧排序(NULL 值排在最后)
|
||||
"""
|
||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
@@ -68,9 +71,14 @@ def get_workspace_end_users(
|
||||
app_ids = [app.id for app in apps_orm]
|
||||
|
||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||
# 按 updated_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||
from app.models.end_user_model import EndUser as EndUserModel
|
||||
from sqlalchemy import desc, nullslast
|
||||
end_users_orm = db.query(EndUserModel).filter(
|
||||
EndUserModel.app_id.in_(app_ids)
|
||||
).order_by(
|
||||
nullslast(desc(EndUserModel.updated_at)),
|
||||
desc(EndUserModel.id)
|
||||
).all()
|
||||
|
||||
# 转换为 Pydantic 模型(只在需要时转换)
|
||||
|
||||
@@ -108,13 +108,14 @@ class WorkspaceAppService:
|
||||
app_info["releases"].append(release_info)
|
||||
|
||||
def _extract_memory_content(self, config: Any) -> str:
|
||||
"""Extract memory_comtent from config"""
|
||||
"""Extract memory_config_id from config (兼容新旧字段名)"""
|
||||
if not config or not isinstance(config, dict):
|
||||
return None
|
||||
|
||||
memory_obj = config.get('memory')
|
||||
if memory_obj and isinstance(memory_obj, dict):
|
||||
return memory_obj.get('memory_content')
|
||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||
return memory_obj.get('memory_config_id') or memory_obj.get('memory_content')
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -129,6 +129,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not params.rerank_id:
|
||||
params.rerank_id = configs.get('rerank')
|
||||
|
||||
# reflection_model_id 和 emotion_model_id 默认与 llm_id 一致
|
||||
if not params.reflection_model_id:
|
||||
params.reflection_model_id = params.llm_id
|
||||
if not params.emotion_model_id:
|
||||
params.emotion_model_id = params.llm_id
|
||||
|
||||
config = MemoryConfigRepository.create(self.db, params)
|
||||
self.db.commit()
|
||||
return {"affected": 1, "config_id": config.config_id}
|
||||
@@ -203,6 +209,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"end_user_id": config.end_user_id,
|
||||
"config_id_old": config_id_old,
|
||||
"apply_id": config.apply_id,
|
||||
"scene_id": config.scene_id,
|
||||
"llm_id": config.llm_id,
|
||||
"embedding_id": config.embedding_id,
|
||||
"rerank_id": config.rerank_id,
|
||||
|
||||
@@ -6,7 +6,7 @@ import math
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
|
||||
from app.schemas import model_schema
|
||||
from app.schemas.model_schema import (
|
||||
@@ -633,19 +633,31 @@ class ModelApiKeyService:
|
||||
|
||||
@staticmethod
|
||||
def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]:
|
||||
"""获取可用的API Key(按优先级和负载均衡)"""
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active=True)
|
||||
"""获取可用的API Key(根据负载均衡策略)"""
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
return None
|
||||
|
||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||
if not api_keys:
|
||||
return None
|
||||
return min(api_keys, key=lambda x: int(x.usage_count or "0"))
|
||||
|
||||
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
|
||||
# 否则返回第一个
|
||||
return api_keys[0]
|
||||
|
||||
@staticmethod
|
||||
def record_api_key_usage(db: Session, api_key_id: uuid.UUID) -> bool:
|
||||
def record_api_key_usage(db: Session, api_key_id: uuid.UUID | None) -> bool:
|
||||
"""记录API Key使用"""
|
||||
success = ModelApiKeyRepository.update_usage(db, api_key_id)
|
||||
if success:
|
||||
db.commit()
|
||||
return success
|
||||
if api_key_id:
|
||||
success = ModelApiKeyRepository.update_usage(db, api_key_id)
|
||||
if success:
|
||||
db.commit()
|
||||
return success
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_a_api_key(db: Session, model_config_id: uuid.UUID) -> ModelApiKey:
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -2569,8 +2570,9 @@ class MultiAgentOrchestrator:
|
||||
# ModelConfig.id == default_model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
# api_key_config = api_keys[0] if api_keys else None
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(self.db, default_model_config_id)
|
||||
|
||||
if not api_key_config:
|
||||
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
||||
@@ -2601,6 +2603,8 @@ class MultiAgentOrchestrator:
|
||||
# 调用模型进行整合
|
||||
response = await llm.ainvoke(merge_prompt)
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
merged_response = response.content
|
||||
@@ -2730,8 +2734,9 @@ class MultiAgentOrchestrator:
|
||||
# ModelConfig.id == default_model_config_id,
|
||||
# ModelApiKey.is_active.is_(True)
|
||||
# ).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
api_key_config = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, default_model_config_id)
|
||||
# api_key_config = api_keys[0] if api_keys else None
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(self.db, default_model_config_id)
|
||||
|
||||
if not api_key_config:
|
||||
logger.warning("Master Agent 没有可用的 API Key,使用简单整合")
|
||||
@@ -2790,6 +2795,8 @@ class MultiAgentOrchestrator:
|
||||
logger.debug(f"收到流式 chunk #{chunk_count}: {content[:30]}...")
|
||||
yield self._format_sse_event("message", {"content": content})
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id)
|
||||
|
||||
logger.info(f"Master Agent 流式整合完成,共 {chunk_count} 个 chunks")
|
||||
|
||||
except AttributeError as e:
|
||||
|
||||
@@ -23,7 +23,7 @@ logger = get_business_logger()
|
||||
|
||||
class ImageFormatStrategy(Protocol):
|
||||
"""图片格式策略接口"""
|
||||
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""将图片 URL 转换为特定 provider 的格式"""
|
||||
...
|
||||
@@ -31,7 +31,7 @@ class ImageFormatStrategy(Protocol):
|
||||
|
||||
class DashScopeImageStrategy:
|
||||
"""通义千问图片格式策略"""
|
||||
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""通义千问格式: {"type": "image", "image": "url"}"""
|
||||
return {
|
||||
@@ -42,7 +42,7 @@ class DashScopeImageStrategy:
|
||||
|
||||
class BedrockImageStrategy:
|
||||
"""Bedrock/Anthropic 图片格式策略"""
|
||||
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 格式: base64 编码
|
||||
@@ -51,17 +51,17 @@ class BedrockImageStrategy:
|
||||
import httpx
|
||||
import base64
|
||||
from mimetypes import guess_type
|
||||
|
||||
|
||||
logger.info(f"下载并编码图片: {url}")
|
||||
|
||||
|
||||
# 下载图片
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
# 获取图片数据
|
||||
image_data = response.content
|
||||
|
||||
|
||||
# 确定 media type
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("image/"):
|
||||
@@ -69,12 +69,12 @@ class BedrockImageStrategy:
|
||||
else:
|
||||
guessed_type, _ = guess_type(url)
|
||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||
|
||||
|
||||
# 转换为 base64
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
|
||||
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
@@ -87,7 +87,7 @@ class BedrockImageStrategy:
|
||||
|
||||
class OpenAIImageStrategy:
|
||||
"""OpenAI 图片格式策略"""
|
||||
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||
return {
|
||||
@@ -109,7 +109,7 @@ PROVIDER_STRATEGIES = {
|
||||
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope"):
|
||||
"""
|
||||
初始化多模态服务
|
||||
@@ -120,10 +120,10 @@ class MultimodalService:
|
||||
"""
|
||||
self.db = db
|
||||
self.provider = provider.lower()
|
||||
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
files: Optional[List[FileInput]]
|
||||
self,
|
||||
files: Optional[List[FileInput]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理文件列表,返回 LLM 可用的格式
|
||||
@@ -136,7 +136,7 @@ class MultimodalService:
|
||||
"""
|
||||
if not files:
|
||||
return []
|
||||
|
||||
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
@@ -168,10 +168,10 @@ class MultimodalService:
|
||||
"type": "text",
|
||||
"text": f"[文件处理失败: {str(e)}]"
|
||||
})
|
||||
|
||||
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
|
||||
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片文件
|
||||
@@ -184,14 +184,10 @@ class MultimodalService:
|
||||
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||
- 通义千问: {"type": "image", "image": "url"}
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
url = file.url
|
||||
else:
|
||||
# 本地文件,获取访问 URL
|
||||
url = await self._get_file_url(file.upload_file_id)
|
||||
|
||||
url = await self.get_file_url(file)
|
||||
|
||||
logger.debug(f"处理图片: {url}, provider={self.provider}")
|
||||
|
||||
|
||||
# 根据 provider 返回不同格式
|
||||
if self.provider in ["bedrock", "anthropic"]:
|
||||
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
|
||||
@@ -223,7 +219,7 @@ class MultimodalService:
|
||||
"type": "image",
|
||||
"image": url
|
||||
}
|
||||
|
||||
|
||||
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
|
||||
"""
|
||||
下载图片并转换为 base64
|
||||
@@ -237,15 +233,15 @@ class MultimodalService:
|
||||
import httpx
|
||||
import base64
|
||||
from mimetypes import guess_type
|
||||
|
||||
|
||||
# 下载图片
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
# 获取图片数据
|
||||
image_data = response.content
|
||||
|
||||
|
||||
# 确定 media type
|
||||
content_type = response.headers.get("content-type")
|
||||
if content_type and content_type.startswith("image/"):
|
||||
@@ -254,14 +250,14 @@ class MultimodalService:
|
||||
# 从 URL 推断
|
||||
guessed_type, _ = guess_type(url)
|
||||
media_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
|
||||
|
||||
|
||||
# 转换为 base64
|
||||
base64_data = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
|
||||
logger.debug(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||
|
||||
|
||||
return base64_data, media_type
|
||||
|
||||
|
||||
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
@@ -284,14 +280,14 @@ class MultimodalService:
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file.upload_file_id
|
||||
).first()
|
||||
|
||||
|
||||
file_name = generic_file.file_name if generic_file else "unknown"
|
||||
|
||||
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
|
||||
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理音频文件
|
||||
@@ -307,7 +303,7 @@ class MultimodalService:
|
||||
"type": "text",
|
||||
"text": "[音频文件,暂不支持处理]"
|
||||
}
|
||||
|
||||
|
||||
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
|
||||
"""
|
||||
处理视频文件
|
||||
@@ -323,13 +319,13 @@ class MultimodalService:
|
||||
"type": "text",
|
||||
"text": "[视频文件,暂不支持处理]"
|
||||
}
|
||||
|
||||
async def _get_file_url(self, file_id: uuid.UUID) -> str:
|
||||
|
||||
async def get_file_url(self, file: FileInput) -> str:
|
||||
"""
|
||||
获取文件的访问 URL
|
||||
|
||||
Args:
|
||||
file_id: 文件ID
|
||||
file: File Input Struct
|
||||
|
||||
Returns:
|
||||
str: 文件访问 URL
|
||||
@@ -337,26 +333,31 @@ class MultimodalService:
|
||||
Raises:
|
||||
BusinessException: 文件不存在
|
||||
"""
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file_id,
|
||||
GenericFile.status == "active"
|
||||
).first()
|
||||
|
||||
if not generic_file:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# 如果有 access_url,直接返回
|
||||
if generic_file.access_url:
|
||||
return generic_file.access_url
|
||||
|
||||
# 否则,根据 storage_path 生成 URL
|
||||
# TODO: 根据实际存储方式生成 URL(本地存储、OSS 等)
|
||||
# 这里暂时返回一个占位 URL
|
||||
return f"/api/files/{file_id}/download"
|
||||
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
return file.url
|
||||
else:
|
||||
# 本地文件,获取访问 URL
|
||||
file_id = file.upload_file_id
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file.upload_file_id,
|
||||
GenericFile.status == "active"
|
||||
).first()
|
||||
|
||||
if not generic_file:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file.upload_file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# 如果有 access_url,直接返回
|
||||
if generic_file.access_url:
|
||||
return generic_file.access_url
|
||||
|
||||
# 否则,根据 storage_path 生成 URL
|
||||
# TODO: 根据实际存储方式生成 URL(本地存储、OSS 等)
|
||||
# 这里暂时返回一个占位 URL
|
||||
return f"/api/files/{file_id}/download"
|
||||
|
||||
async def _extract_document_text(self, file_id: uuid.UUID) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
@@ -371,20 +372,20 @@ class MultimodalService:
|
||||
GenericFile.id == file_id,
|
||||
GenericFile.status == "active"
|
||||
).first()
|
||||
|
||||
|
||||
if not generic_file:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
# TODO: 根据文件类型提取文本
|
||||
# - PDF: 使用 PyPDF2 或 pdfplumber
|
||||
# - Word: 使用 python-docx
|
||||
# - TXT/MD: 直接读取
|
||||
|
||||
|
||||
file_ext = generic_file.file_ext.lower()
|
||||
|
||||
|
||||
if file_ext in ['.txt', '.md', '.markdown']:
|
||||
return await self._read_text_file(generic_file.storage_path)
|
||||
elif file_ext == '.pdf':
|
||||
@@ -393,7 +394,7 @@ class MultimodalService:
|
||||
return await self._extract_word_text(generic_file.storage_path)
|
||||
else:
|
||||
return f"[不支持的文档格式: {file_ext}]"
|
||||
|
||||
|
||||
async def _read_text_file(self, storage_path: str) -> str:
|
||||
"""读取纯文本文件"""
|
||||
try:
|
||||
@@ -402,7 +403,7 @@ class MultimodalService:
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本文件失败: {e}")
|
||||
return f"[文件读取失败: {str(e)}]"
|
||||
|
||||
|
||||
async def _extract_pdf_text(self, storage_path: str) -> str:
|
||||
"""提取 PDF 文本"""
|
||||
try:
|
||||
@@ -412,7 +413,7 @@ class MultimodalService:
|
||||
except Exception as e:
|
||||
logger.error(f"提取 PDF 文本失败: {e}")
|
||||
return f"[PDF 提取失败: {str(e)}]"
|
||||
|
||||
|
||||
async def _extract_word_text(self, storage_path: str) -> str:
|
||||
"""提取 Word 文档文本"""
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
{% raw %}
|
||||
Role: AI Prompt Optimization Expert
|
||||
|
||||
Profile
|
||||
@@ -12,11 +11,11 @@ Skills
|
||||
Core Optimization Skills
|
||||
Requirement Analysis: Accurately understand the relationship between the user’s current needs and the original prompt.
|
||||
Structural Reconstruction: Transform vague requirements into clear, block-structured instructions.
|
||||
Variable Handling: Identify and standardize dynamic variables in prompts.
|
||||
{% if skill != true %}Variable Handling: Identify and standardize dynamic variables in prompts.{% endif %}
|
||||
Conflict Resolution: Prioritize current requirements when historical requirements conflict with current needs.
|
||||
|
||||
Auxiliary Generation Skills
|
||||
Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.
|
||||
{% if skill != true %}Completeness Check: Ensure all necessary elements (input, output, constraints, etc.) are explicitly defined.{% endif %}
|
||||
Language Consistency: Maintain consistency between label language and user input language.
|
||||
Executability Verification: Ensure optimized prompts can be directly used in AI tools.
|
||||
Format Standardization: Strictly adhere to specified output format requirements.
|
||||
@@ -25,30 +24,30 @@ Rules
|
||||
Basic Principles
|
||||
Priority Rule: When historical requirements conflict with current requirements, unconditionally prioritize current requirements.
|
||||
Completeness Rule: If the original prompt is empty, generate a complete prompt based on the current requirements.
|
||||
Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints
|
||||
{% if skill != true %}Structure Rule: Use a clear block structure, and the contents of each block are roles, tasks, requirements, inputs, outputs, and constraints{% endif %}
|
||||
Language Rule: All label languages must fully match the user input language.
|
||||
|
||||
Behavior Guidelines
|
||||
Precision Guideline: All instructions must be precise and directly executable, avoiding ambiguity.
|
||||
Readability Guideline: Ensure optimized prompts have good readability and logical flow.
|
||||
Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
|
||||
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.
|
||||
{% if skill != true %}{% raw %}Variable Handling Guideline: Use lowercase English variable names wrapped in {{}} when variables are needed.
|
||||
Constraint Handling Guideline: Do not mention variable-related limitations under the [Constraints] label.{% endraw %}{% endif %}
|
||||
|
||||
Constraints
|
||||
Output Constraint: Must output in JSON format including the fields "prompt" and "desc".
|
||||
Content Constraint: Must not include any explanations, analyses, or additional comments.
|
||||
Language Constraint: Must use clear and concise language.
|
||||
Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).
|
||||
{% if skill != true %}Completeness Constraint: Must fully define all missing elements (input details, output format, constraints, etc.).{% endif %}
|
||||
|
||||
Workflows
|
||||
Goal: Optimize or generate AI prompts that can be directly used according to user requirements.
|
||||
Step 1: Receive the user’s current requirement description {{user_require}} and the original prompt {{original_prompt}}.
|
||||
Step 2: Analyze requirements, identify conflicts, and prioritize current requirements.
|
||||
Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
|
||||
{% if skill != true %}Step 3: Optimize or generate the prompt in a block-structured format, ensuring all elements are fully defined.
|
||||
Step 4: Generate a JSON output containing the optimized prompt and its description.
|
||||
{% else %}Step 3: Generate a JSON output containing the optimized prompt and its description.{% endif %}
|
||||
|
||||
Expected Outcome: Obtain a clear, directly executable AI prompt accompanied by an optimization description.
|
||||
|
||||
Initialization
|
||||
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.
|
||||
{% endraw %}
|
||||
As an AI Prompt Optimization Expert, you must follow the above Rules and execute tasks according to the Workflows.
|
||||
@@ -23,6 +23,7 @@ from app.repositories.prompt_optimizer_repository import (
|
||||
PromptReleaseRepository
|
||||
)
|
||||
from app.schemas.prompt_optimizer_schema import OptimizePromptResult
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -128,7 +129,8 @@ class PromptOptimizerService:
|
||||
session_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
current_prompt: str,
|
||||
user_require: str
|
||||
user_require: str,
|
||||
skill: bool = False
|
||||
) -> AsyncGenerator[dict[str, str | Any], Any]:
|
||||
"""
|
||||
Optimize a user-provided prompt using a configured prompt optimizer LLM.
|
||||
@@ -157,6 +159,7 @@ class PromptOptimizerService:
|
||||
user_id (uuid.UUID): Identifier of the user associated with the session.
|
||||
current_prompt (str): Original prompt to optimize.
|
||||
user_require (str): User's requirements or instructions for optimization.
|
||||
skill(bool): Is skill required
|
||||
|
||||
Returns:
|
||||
OptimizePromptResult: An object containing:
|
||||
@@ -174,8 +177,9 @@ class PromptOptimizerService:
|
||||
logger.info(f"Prompt optimization started, user_id={user_id}, session_id={session_id}")
|
||||
|
||||
# Create LLM instance
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
|
||||
api_config: ModelApiKey = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config.id)
|
||||
# api_config: ModelApiKey = api_keys[0] if api_keys else None
|
||||
api_config: ModelApiKey = ModelApiKeyService.get_available_api_key(self.db, model_config.id)
|
||||
llm = RedBearLLM(RedBearModelConfig(
|
||||
model_name=api_config.model_name,
|
||||
provider=api_config.provider,
|
||||
@@ -186,7 +190,7 @@ class PromptOptimizerService:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
with open(os.path.join(prompt_path, 'prompt_optimizer_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render()
|
||||
rendered_system_message = Template(opt_system_prompt).render(skill=skill)
|
||||
|
||||
with open(os.path.join(prompt_path, 'prompt_optimizer_user.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_user_prompt = f.read()
|
||||
@@ -250,6 +254,7 @@ class PromptOptimizerService:
|
||||
optim_result = json_repair.repair_json(buffer, return_objects=True)
|
||||
# prompt = optim_result.get("prompt")
|
||||
desc = optim_result.get("desc")
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_config.id)
|
||||
self.create_message(
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.services.memory_konwledges_server import write_rag
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -178,8 +179,9 @@ class SharedChatService:
|
||||
# .limit(1)
|
||||
# )
|
||||
# api_key_obj = self.db.scalars(stmt).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
api_key_obj = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
# api_key_obj = api_keys[0] if api_keys else None
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
@@ -309,6 +311,8 @@ class SharedChatService:
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
|
||||
return {
|
||||
"conversation_id": conversation.id,
|
||||
@@ -349,7 +353,8 @@ class SharedChatService:
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
memory_config = {"enabled": memory, "memory_content": "17", "max_history": 10}
|
||||
# 兼容新旧字段名:使用 memory_config_id
|
||||
memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10}
|
||||
|
||||
try:
|
||||
# 获取发布版本和配置
|
||||
@@ -383,8 +388,9 @@ class SharedChatService:
|
||||
# .limit(1)
|
||||
# )
|
||||
# api_key_obj = self.db.scalars(stmt).first()
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
api_key_obj = api_keys[0] if api_keys else None
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(self.db, model_config_id)
|
||||
# api_key_obj = api_keys[0] if api_keys else None
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
@@ -513,7 +519,8 @@ class SharedChatService:
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
133
api/app/services/skill_service.py
Normal file
133
api/app/services/skill_service.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Skill Service"""
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.skill_repository import SkillRepository
|
||||
from app.schemas.skill_schema import SkillCreate, SkillUpdate
|
||||
from app.models.skill_model import Skill
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.tool_service import ToolService
|
||||
|
||||
|
||||
class SkillService:
|
||||
"""Skill 业务逻辑层"""
|
||||
|
||||
@staticmethod
|
||||
def create_skill(db: Session, data: SkillCreate, tenant_id: uuid.UUID) -> Skill:
|
||||
"""创建技能"""
|
||||
# 检查同名技能
|
||||
existing = db.query(Skill).filter(
|
||||
Skill.tenant_id == tenant_id,
|
||||
Skill.name == data.name
|
||||
).first()
|
||||
if existing:
|
||||
raise BusinessException(f"技能名称'{data.name}'已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
skill = SkillRepository.create(db, data, tenant_id)
|
||||
db.commit()
|
||||
db.refresh(skill)
|
||||
return skill
|
||||
|
||||
@staticmethod
|
||||
def get_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> Skill:
|
||||
"""获取技能"""
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, skill_id, tenant_id)
|
||||
if not skill:
|
||||
raise BusinessException(f"技能{skill_id}不存在", BizCode.NOT_FOUND)
|
||||
|
||||
# 填充工具详情
|
||||
tool_service = ToolService(db)
|
||||
enriched_tools = []
|
||||
for tool_config in skill.tools:
|
||||
tool_id = tool_config.get("tool_id")
|
||||
if tool_id:
|
||||
tool_info = tool_service.get_tool_info(tool_id, tenant_id)
|
||||
if tool_info:
|
||||
enriched_tools.append({
|
||||
"tool_id": tool_id,
|
||||
"operation": tool_config.get("operation"),
|
||||
"tool_info": tool_info
|
||||
})
|
||||
skill.tools = enriched_tools
|
||||
|
||||
return skill
|
||||
except (BusinessException, SQLAlchemyError) as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def list_skills(
|
||||
db: Session,
|
||||
tenant_id: uuid.UUID,
|
||||
search: str = None,
|
||||
is_active: bool = None,
|
||||
is_public: bool = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 10
|
||||
) -> tuple[list[type[Skill]], int]:
|
||||
"""列出技能"""
|
||||
return SkillRepository.list_skills(
|
||||
db, tenant_id, search, is_active, is_public, page, pagesize
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_skill(db: Session, skill_id: uuid.UUID, data: SkillUpdate, tenant_id: uuid.UUID) -> Skill:
|
||||
"""更新技能"""
|
||||
try:
|
||||
skill = SkillRepository.update(db, skill_id, data, tenant_id)
|
||||
if not skill:
|
||||
raise BusinessException(f"技能{skill_id}不存在或无权限", BizCode.NOT_FOUND)
|
||||
db.commit()
|
||||
db.refresh(skill)
|
||||
return skill
|
||||
except (BusinessException, SQLAlchemyError) as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def delete_skill(db: Session, skill_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
|
||||
"""删除技能"""
|
||||
try:
|
||||
success = SkillRepository.delete(db, skill_id, tenant_id)
|
||||
if not success:
|
||||
raise BusinessException("技能不存在或无权限", BizCode.NOT_FOUND)
|
||||
db.commit()
|
||||
return True
|
||||
except (BusinessException, SQLAlchemyError) as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def load_skill_tools(db: Session, skill_ids: List[str], tenant_id: uuid.UUID) -> tuple[List, dict[str, str]]:
|
||||
"""加载技能关联的工具
|
||||
|
||||
Returns:
|
||||
(tools, tool_to_skill_map) - 工具列表和工具到技能的映射
|
||||
"""
|
||||
tools = []
|
||||
tool_to_skill_map = {} # {tool_name: skill_id}
|
||||
tool_service = ToolService(db)
|
||||
|
||||
for skill_id in skill_ids:
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||
if skill and skill.is_active:
|
||||
# 加载技能关联的工具
|
||||
for tool_config in skill.tools:
|
||||
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool:
|
||||
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
# 建立工具到技能的映射
|
||||
tool_name = getattr(langchain_tool, 'name', str(id(langchain_tool)))
|
||||
tool_to_skill_map[tool_name] = skill_id
|
||||
except Exception as e:
|
||||
print(f"加载技能 {skill_id} 的工具时出错: {e}")
|
||||
continue
|
||||
|
||||
return tools, tool_to_skill_map
|
||||
@@ -4,9 +4,8 @@
|
||||
import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Annotated, AsyncGenerator, Optional
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from deprecated import deprecated
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -23,6 +22,7 @@ from app.repositories.workflow_repository import (
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,6 +36,7 @@ class WorkflowService:
|
||||
self.execution_repo = WorkflowExecutionRepository(db)
|
||||
self.node_execution_repo = WorkflowNodeExecutionRepository(db)
|
||||
self.conversation_service = ConversationService(db)
|
||||
self.multimodal_service = MultimodalService(db)
|
||||
|
||||
# ==================== 配置管理 ====================
|
||||
|
||||
@@ -445,24 +446,22 @@ class WorkflowService:
|
||||
code=BizCode.CONFIG_MISSING,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id}
|
||||
files = []
|
||||
if payload.files:
|
||||
for file in payload.files:
|
||||
files.append(
|
||||
{
|
||||
"type": file.type,
|
||||
"url": await self.multimodal_service.get_file_url(file),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
|
||||
# 转换 user_id 为 UUID
|
||||
triggered_by_uuid = None
|
||||
if payload.user_id:
|
||||
try:
|
||||
triggered_by_uuid = uuid.UUID(payload.user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id, "files": files}
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = None
|
||||
if payload.conversation_id:
|
||||
try:
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = self.create_execution(
|
||||
@@ -544,10 +543,10 @@ class WorkflowService:
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"variables": result.get("variables"),
|
||||
"messages": result.get("messages"),
|
||||
# "variables": result.get("variables"),
|
||||
# "messages": result.get("messages"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
"error_message": result.get("error"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
@@ -566,6 +565,41 @@ class WorkflowService:
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _map_public_event(event: dict) -> dict | None:
|
||||
event_type = event.get("event")
|
||||
payload = event.get("data")
|
||||
match event_type:
|
||||
case "workflow_start":
|
||||
return {
|
||||
"event": "start",
|
||||
"data": {
|
||||
"conversation_id": payload.get("conversation_id"),
|
||||
}
|
||||
}
|
||||
case "workflow_end":
|
||||
return {
|
||||
"event": "end",
|
||||
"data": {
|
||||
"elapsed_time": payload.get("elapsed_time"),
|
||||
"message_length": len(payload.get("output", ""))
|
||||
}
|
||||
}
|
||||
case "node_start" | "node_end" | "node_error":
|
||||
return None
|
||||
case _:
|
||||
return event
|
||||
|
||||
def _emit(self, public: bool, internal_event: dict):
|
||||
"""
|
||||
decide
|
||||
"""
|
||||
if public:
|
||||
mapped = self._map_public_event(internal_event)
|
||||
else:
|
||||
mapped = internal_event
|
||||
return mapped
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -573,6 +607,7 @@ class WorkflowService:
|
||||
config: WorkflowConfig,
|
||||
workspace_id: uuid.UUID,
|
||||
release_id: Optional[uuid.UUID] = None,
|
||||
public: bool = False
|
||||
):
|
||||
"""运行工作流(流式)
|
||||
|
||||
@@ -582,6 +617,7 @@ class WorkflowService:
|
||||
app_id: 应用 ID
|
||||
payload: 请求对象(包含 message, variables, conversation_id 等)
|
||||
config: 存储类型(可选)
|
||||
public: 是否发布
|
||||
|
||||
Yields:
|
||||
SSE 格式的流式事件
|
||||
@@ -597,24 +633,23 @@ class WorkflowService:
|
||||
code=BizCode.CONFIG_MISSING,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id}
|
||||
|
||||
# 转换 user_id 为 UUID
|
||||
triggered_by_uuid = None
|
||||
if payload.user_id:
|
||||
try:
|
||||
triggered_by_uuid = uuid.UUID(payload.user_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 user_id 格式: {payload.user_id}")
|
||||
files = []
|
||||
if payload.files:
|
||||
for file in payload.files:
|
||||
files.append(
|
||||
{
|
||||
"type": file.type,
|
||||
"url": await self.multimodal_service.get_file_url(file),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id, "files": files}
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = None
|
||||
if payload.conversation_id:
|
||||
try:
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id)
|
||||
except (ValueError, AttributeError):
|
||||
logger.warning(f"无效的 conversation_id 格式: {payload.conversation_id}")
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = self.create_execution(
|
||||
@@ -661,7 +696,7 @@ class WorkflowService:
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(workspace_id),
|
||||
user_id=payload.user_id
|
||||
user_id=payload.user_id,
|
||||
):
|
||||
if event.get("event") == "workflow_end":
|
||||
|
||||
@@ -692,7 +727,9 @@ class WorkflowService:
|
||||
)
|
||||
else:
|
||||
logger.error(f"unexpect workflow run status, status: {status}")
|
||||
yield event
|
||||
event = self._emit(public, event)
|
||||
if event:
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
@@ -710,134 +747,6 @@ class WorkflowService:
|
||||
}
|
||||
}
|
||||
|
||||
@deprecated(reason="This method is deprecated. "
|
||||
"Please use WorkflowService.run / run_stream instead.")
|
||||
async def run_workflow(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
input_data: dict[str, Any],
|
||||
triggered_by: uuid.UUID,
|
||||
conversation_id: uuid.UUID | None = None,
|
||||
stream: bool = False
|
||||
) -> AsyncGenerator | dict:
|
||||
"""运行工作流
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
input_data: 输入数据(包含 message 和 variables)
|
||||
triggered_by: 触发用户 ID
|
||||
conversation_id: 会话 ID(可选)
|
||||
stream: 是否流式返回
|
||||
|
||||
Returns:
|
||||
执行结果(非流式)或生成器(流式)
|
||||
|
||||
Raises:
|
||||
BusinessException: 配置不存在或执行失败时抛出
|
||||
"""
|
||||
# 1. 获取工作流配置
|
||||
config = self.get_workflow_config(app_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
code=BizCode.NOT_FOUND,
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
# 2. 创建执行记录
|
||||
execution = self.create_execution(
|
||||
workflow_config_id=config.id,
|
||||
app_id=app_id,
|
||||
trigger_type="manual",
|
||||
triggered_by=triggered_by,
|
||||
conversation_id=conversation_id,
|
||||
input_data=input_data
|
||||
)
|
||||
|
||||
# 3. 构建工作流配置字典
|
||||
workflow_config_dict = {
|
||||
"nodes": config.nodes,
|
||||
"edges": config.edges,
|
||||
"variables": config.variables,
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
from app.models import App
|
||||
app = self.db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException(
|
||||
code=BizCode.NOT_FOUND,
|
||||
message=f"应用不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow
|
||||
|
||||
try:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
if stream:
|
||||
# 流式执行
|
||||
return self._run_workflow_stream(
|
||||
workflow_config_dict,
|
||||
input_data,
|
||||
execution.execution_id,
|
||||
str(app.workspace_id),
|
||||
str(triggered_by)
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
result = await execute_workflow(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id=str(app.workspace_id),
|
||||
user_id=str(triggered_by)
|
||||
)
|
||||
|
||||
# 更新执行结果
|
||||
if result.get("status") == "completed":
|
||||
token_usage = result.get("data").get("token_usage", {}) or {}
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"completed",
|
||||
output_data=result.get("node_outputs", {}),
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
else:
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=result.get("error")
|
||||
)
|
||||
|
||||
# 返回增强的响应结构
|
||||
return {
|
||||
"execution_id": execution.execution_id,
|
||||
"status": result.get("status"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"error_message": result.get("error"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"token_usage": result.get("token_usage")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
||||
self.update_execution_status(
|
||||
execution.execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
raise BusinessException(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
|
||||
"""清理事件数据,移除不可序列化的对象
|
||||
|
||||
@@ -869,72 +778,6 @@ class WorkflowService:
|
||||
|
||||
return clean_value(event)
|
||||
|
||||
@deprecated(reason="This method is deprecated. Please use WorkflowService.run_stream instead.")
|
||||
async def _run_workflow_stream(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str):
|
||||
"""运行工作流(流式,内部方法)
|
||||
|
||||
Args:
|
||||
workflow_config: 工作流配置
|
||||
input_data: 输入数据
|
||||
execution_id: 执行 ID
|
||||
workspace_id: 工作空间 ID
|
||||
user_id: 用户 ID
|
||||
|
||||
Yields:
|
||||
流式事件(格式:{"event": "<type>", "data": {...}})
|
||||
"""
|
||||
from app.core.workflow.executor import execute_workflow_stream
|
||||
|
||||
try:
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config,
|
||||
input_data=input_data,
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
):
|
||||
# 直接转发事件(executor 已经返回正确格式)
|
||||
if event.get("event") == "workflow_end":
|
||||
token_usage = event.get("data").get("token_usage", {}) or {}
|
||||
status = event.get("data", {}).get("status")
|
||||
if status == "completed":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"completed",
|
||||
output_data=event.get("data"),
|
||||
token_usage=token_usage.get("total_tokens", None)
|
||||
)
|
||||
elif status == "failed":
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"failed",
|
||||
output_data=event.get("data")
|
||||
)
|
||||
else:
|
||||
logger.error(f"unexpect workflow run status, status: {status}")
|
||||
yield event
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作流流式执行失败: execution_id={execution_id}, error={e}", exc_info=True)
|
||||
self.update_execution_status(
|
||||
execution_id,
|
||||
"failed",
|
||||
error_message=str(e)
|
||||
)
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {
|
||||
"execution_id": execution_id,
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
|
||||
Reference in New Issue
Block a user