* [fix]Fix the interface for statistics of recent activities and applications * [changes]Modify the code based on the AI review 1.Use the boolean auxiliary methods provided by SQLAlchemy instead of using == True in the is_active filter. 2.The calculation of the "PROJECT_ROOT" has now been hardcoded with five levels of nested os.path.dirname calls. * [fix]Fix the interface for statistics of recent activities and applications * [changes]Modify the code based on the AI review 1.Use the boolean auxiliary methods provided by SQLAlchemy instead of using == True in the is_active filter. 2.The calculation of the "PROJECT_ROOT" has now been hardcoded with five levels of nested os.path.dirname calls.
192 lines
5.6 KiB
Python
192 lines
5.6 KiB
Python
"""Agent 注册表服务"""
|
||
import uuid
|
||
from typing import Optional, List, Dict, Any
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import select, or_, and_
|
||
|
||
from app.models import AgentConfig, App
|
||
from app.core.logging_config import get_business_logger
|
||
|
||
logger = get_business_logger()
|
||
|
||
|
||
class AgentRegistry:
|
||
"""Agent 注册表 - 管理所有可用的 Agent"""
|
||
|
||
def __init__(self, db: Session):
|
||
self.db = db
|
||
self._cache: Dict[str, Dict[str, Any]] = {}
|
||
|
||
def register_agent(self, agent: AgentConfig) -> None:
|
||
"""注册 Agent 到系统
|
||
|
||
Args:
|
||
agent: Agent 配置对象
|
||
"""
|
||
agent_info = self._to_agent_info(agent)
|
||
self._cache[str(agent.id)] = agent_info
|
||
|
||
logger.info(
|
||
"Agent 注册成功",
|
||
extra={
|
||
"agent_id": str(agent.id),
|
||
"name": agent.app.name,
|
||
"domain": agent.agent_domain
|
||
}
|
||
)
|
||
|
||
def discover_agents(
|
||
self,
|
||
query: Optional[str] = None,
|
||
domain: Optional[str] = None,
|
||
capabilities: Optional[List[str]] = None,
|
||
workspace_id: Optional[uuid.UUID] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""发现可用的 Agent
|
||
|
||
Args:
|
||
query: 搜索关键词
|
||
domain: 专业领域
|
||
capabilities: 所需能力列表
|
||
workspace_id: 工作空间ID(权限过滤)
|
||
|
||
Returns:
|
||
匹配的 Agent 列表
|
||
"""
|
||
# 构建查询
|
||
stmt = select(AgentConfig).join(App).where(
|
||
AgentConfig.is_active.is_(True),
|
||
App.is_active.is_(True)
|
||
)
|
||
|
||
# 工作空间过滤(同工作空间或公开)
|
||
if workspace_id:
|
||
stmt = stmt.where(
|
||
or_(
|
||
App.workspace_id == workspace_id,
|
||
App.visibility == "public"
|
||
)
|
||
)
|
||
|
||
# 领域过滤
|
||
if domain:
|
||
stmt = stmt.where(AgentConfig.agent_domain == domain)
|
||
|
||
# 能力过滤
|
||
if capabilities:
|
||
# PostgreSQL JSON 数组包含查询
|
||
for cap in capabilities:
|
||
stmt = stmt.where(
|
||
AgentConfig.capabilities.contains([cap])
|
||
)
|
||
|
||
# 关键词搜索
|
||
if query:
|
||
stmt = stmt.where(
|
||
or_(
|
||
App.name.ilike(f"%{query}%"),
|
||
App.description.ilike(f"%{query}%")
|
||
)
|
||
)
|
||
|
||
agents = self.db.scalars(stmt).all()
|
||
|
||
logger.debug(
|
||
"Agent 发现",
|
||
extra={
|
||
"query": query,
|
||
"domain": domain,
|
||
"capabilities": capabilities,
|
||
"found_count": len(agents)
|
||
}
|
||
)
|
||
|
||
return [self._to_agent_info(agent) for agent in agents]
|
||
|
||
def get_agent(self, agent_id: uuid.UUID) -> Optional[Dict[str, Any]]:
|
||
"""获取 Agent 信息
|
||
|
||
Args:
|
||
agent_id: Agent ID
|
||
|
||
Returns:
|
||
Agent 信息字典,如果不存在返回 None
|
||
"""
|
||
agent_id_str = str(agent_id)
|
||
|
||
# 先查缓存
|
||
if agent_id_str in self._cache:
|
||
return self._cache[agent_id_str]
|
||
|
||
# 查数据库
|
||
agent = self.db.get(AgentConfig, agent_id)
|
||
if agent and agent.is_active:
|
||
agent_info = self._to_agent_info(agent)
|
||
self._cache[agent_id_str] = agent_info
|
||
return agent_info
|
||
|
||
return None
|
||
|
||
def _to_agent_info(self, agent: AgentConfig) -> Dict[str, Any]:
|
||
"""转换为 Agent 信息字典
|
||
|
||
Args:
|
||
agent: Agent 配置对象
|
||
|
||
Returns:
|
||
Agent 信息字典
|
||
"""
|
||
return {
|
||
"id": str(agent.id),
|
||
"name": agent.app.name,
|
||
"description": agent.app.description,
|
||
"domain": agent.agent_domain,
|
||
"role": agent.agent_role,
|
||
"capabilities": agent.capabilities or [],
|
||
"tools": list(agent.tools.keys()) if agent.tools else [],
|
||
"knowledge_bases": self._extract_kb_ids(agent),
|
||
"system_prompt": self._truncate_prompt(agent.system_prompt),
|
||
"status": "active" if agent.is_active else "inactive",
|
||
"workspace_id": str(agent.app.workspace_id),
|
||
"visibility": agent.app.visibility
|
||
}
|
||
|
||
def _extract_kb_ids(self, agent: AgentConfig) -> List[str]:
|
||
"""提取知识库 ID 列表
|
||
|
||
Args:
|
||
agent: Agent 配置对象
|
||
|
||
Returns:
|
||
知识库 ID 列表
|
||
"""
|
||
if not agent.knowledge_retrieval:
|
||
return []
|
||
|
||
kb_config = agent.knowledge_retrieval
|
||
knowledge_bases = kb_config.get("knowledge_bases", [])
|
||
return [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||
|
||
def _truncate_prompt(self, prompt: Optional[str], max_length: int = 200) -> Optional[str]:
|
||
"""截断提示词
|
||
|
||
Args:
|
||
prompt: 提示词
|
||
max_length: 最大长度
|
||
|
||
Returns:
|
||
截断后的提示词
|
||
"""
|
||
if not prompt:
|
||
return None
|
||
|
||
if len(prompt) <= max_length:
|
||
return prompt
|
||
|
||
return prompt[:max_length] + "..."
|
||
|
||
def clear_cache(self) -> None:
|
||
"""清空缓存"""
|
||
self._cache.clear()
|
||
logger.debug("Agent 注册表缓存已清空")
|