feat: Add base project structure with API and web components
This commit is contained in:
0
api/app/services/__init__.py
Normal file
0
api/app/services/__init__.py
Normal file
116
api/app/services/agent_config_converter.py
Normal file
116
api/app/services/agent_config_converter.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
Agent 配置格式转换器
|
||||
用于将 Pydantic 模型转换为数据库存储格式
|
||||
"""
|
||||
from typing import Dict, Any, Optional
|
||||
from app.schemas.app_schema import (
|
||||
KnowledgeRetrievalConfig,
|
||||
MemoryConfig,
|
||||
VariableDefinition,
|
||||
ToolConfig,
|
||||
AgentConfigCreate,
|
||||
AgentConfigUpdate,
|
||||
)
|
||||
|
||||
|
||||
class AgentConfigConverter:
|
||||
"""Agent 配置格式转换器"""
|
||||
|
||||
@staticmethod
|
||||
def to_storage_format(config: AgentConfigCreate | AgentConfigUpdate) -> Dict[str, Any]:
|
||||
"""
|
||||
将配置对象转换为数据库存储格式
|
||||
|
||||
Args:
|
||||
config: AgentConfigCreate 或 AgentConfigUpdate 对象
|
||||
|
||||
Returns:
|
||||
包含数据库字段的字典
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# 1. 模型参数配置
|
||||
if hasattr(config, 'model_parameters') and config.model_parameters:
|
||||
result["model_parameters"] = config.model_parameters.model_dump()
|
||||
|
||||
# 2. 知识库检索配置
|
||||
if config.knowledge_retrieval:
|
||||
result["knowledge_retrieval"] = config.knowledge_retrieval.model_dump()
|
||||
|
||||
# 3. 记忆配置
|
||||
if hasattr(config, 'memory') and config.memory:
|
||||
result["memory"] = config.memory.model_dump()
|
||||
|
||||
# 4. 变量配置
|
||||
if hasattr(config, 'variables') and config.variables:
|
||||
result["variables"] = [var.model_dump() for var in config.variables]
|
||||
|
||||
# 5. 工具配置
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
result["tools"] = {
|
||||
name: tool.model_dump()
|
||||
for name, tool in config.tools.items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def from_storage_format(
|
||||
model_parameters: Optional[Dict[str, Any]],
|
||||
knowledge_retrieval: Optional[Dict[str, Any]],
|
||||
memory: Optional[Dict[str, Any]],
|
||||
variables: Optional[list],
|
||||
tools: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将数据库存储格式转换为 Pydantic 对象
|
||||
|
||||
Args:
|
||||
model_parameters: 模型参数配置
|
||||
knowledge_retrieval: 知识库检索配置
|
||||
memory: 记忆配置
|
||||
variables: 变量配置
|
||||
tools: 工具配置
|
||||
|
||||
Returns:
|
||||
包含 Pydantic 对象的字典
|
||||
"""
|
||||
result = {
|
||||
"model_parameters": None,
|
||||
"knowledge_retrieval": None,
|
||||
"memory": MemoryConfig(enabled=True),
|
||||
"variables": [],
|
||||
"tools": {},
|
||||
}
|
||||
|
||||
# 1. 解析模型参数配置
|
||||
if model_parameters:
|
||||
from app.schemas.app_schema import ModelParameters
|
||||
result["model_parameters"] = ModelParameters(**model_parameters)
|
||||
|
||||
# 2. 解析知识库检索配置
|
||||
if knowledge_retrieval:
|
||||
result["knowledge_retrieval"] = KnowledgeRetrievalConfig(**knowledge_retrieval)
|
||||
else:
|
||||
# 提供默认的知识库配置(空列表)
|
||||
result["knowledge_retrieval"] = KnowledgeRetrievalConfig(
|
||||
knowledge_bases=[],
|
||||
merge_strategy="weighted"
|
||||
)
|
||||
|
||||
# 3. 解析记忆配置
|
||||
if memory:
|
||||
result["memory"] = MemoryConfig(**memory)
|
||||
|
||||
# 4. 解析变量配置
|
||||
if variables:
|
||||
result["variables"] = [VariableDefinition(**var) for var in variables]
|
||||
|
||||
# 5. 解析工具配置
|
||||
if tools:
|
||||
result["tools"] = {
|
||||
name: ToolConfig(**tool_data)
|
||||
for name, tool_data in tools.items()
|
||||
}
|
||||
|
||||
return result
|
||||
38
api/app/services/agent_config_helper.py
Normal file
38
api/app/services/agent_config_helper.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Agent 配置辅助函数
|
||||
用于增强 AgentConfig 对象,添加解析后的字段
|
||||
"""
|
||||
from app.models import AgentConfig
|
||||
from app.services.agent_config_converter import AgentConfigConverter
|
||||
|
||||
|
||||
def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
|
||||
"""
|
||||
增强 AgentConfig 对象,添加解析后的配置字段
|
||||
|
||||
Args:
|
||||
agent_cfg: AgentConfig ORM 对象
|
||||
|
||||
Returns:
|
||||
增强后的 AgentConfig 对象(添加了解析字段)
|
||||
"""
|
||||
if not agent_cfg:
|
||||
return agent_cfg
|
||||
|
||||
# 解析数据库存储格式
|
||||
parsed = AgentConfigConverter.from_storage_format(
|
||||
model_parameters=agent_cfg.model_parameters,
|
||||
knowledge_retrieval=agent_cfg.knowledge_retrieval,
|
||||
memory=agent_cfg.memory,
|
||||
variables=agent_cfg.variables,
|
||||
tools=agent_cfg.tools,
|
||||
)
|
||||
|
||||
# 将解析后的字段添加到对象上(用于序列化)
|
||||
agent_cfg.model_parameters = parsed["model_parameters"]
|
||||
agent_cfg.knowledge_retrieval = parsed["knowledge_retrieval"]
|
||||
agent_cfg.memory = parsed["memory"]
|
||||
agent_cfg.variables = parsed["variables"]
|
||||
agent_cfg.tools = parsed["tools"]
|
||||
|
||||
return agent_cfg
|
||||
0
api/app/services/agent_invocation_service.py
Normal file
0
api/app/services/agent_invocation_service.py
Normal file
191
api/app/services/agent_registry.py
Normal file
191
api/app/services/agent_registry.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""Agent 注册表服务"""
|
||||
import uuid
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, or_, and_
|
||||
|
||||
from app.models import AgentConfig, App
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
"""Agent 注册表 - 管理所有可用的 Agent"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def register_agent(self, agent: AgentConfig) -> None:
|
||||
"""注册 Agent 到系统
|
||||
|
||||
Args:
|
||||
agent: Agent 配置对象
|
||||
"""
|
||||
agent_info = self._to_agent_info(agent)
|
||||
self._cache[str(agent.id)] = agent_info
|
||||
|
||||
logger.info(
|
||||
f"Agent 注册成功",
|
||||
extra={
|
||||
"agent_id": str(agent.id),
|
||||
"name": agent.app.name,
|
||||
"domain": agent.agent_domain
|
||||
}
|
||||
)
|
||||
|
||||
def discover_agents(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
domain: Optional[str] = None,
|
||||
capabilities: Optional[List[str]] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""发现可用的 Agent
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
domain: 专业领域
|
||||
capabilities: 所需能力列表
|
||||
workspace_id: 工作空间ID(权限过滤)
|
||||
|
||||
Returns:
|
||||
匹配的 Agent 列表
|
||||
"""
|
||||
# 构建查询
|
||||
stmt = select(AgentConfig).join(App).where(
|
||||
AgentConfig.is_active == True,
|
||||
App.is_active == True
|
||||
)
|
||||
|
||||
# 工作空间过滤(同工作空间或公开)
|
||||
if workspace_id:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
App.workspace_id == workspace_id,
|
||||
App.visibility == "public"
|
||||
)
|
||||
)
|
||||
|
||||
# 领域过滤
|
||||
if domain:
|
||||
stmt = stmt.where(AgentConfig.agent_domain == domain)
|
||||
|
||||
# 能力过滤
|
||||
if capabilities:
|
||||
# PostgreSQL JSON 数组包含查询
|
||||
for cap in capabilities:
|
||||
stmt = stmt.where(
|
||||
AgentConfig.capabilities.contains([cap])
|
||||
)
|
||||
|
||||
# 关键词搜索
|
||||
if query:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
App.name.ilike(f"%{query}%"),
|
||||
App.description.ilike(f"%{query}%")
|
||||
)
|
||||
)
|
||||
|
||||
agents = self.db.scalars(stmt).all()
|
||||
|
||||
logger.debug(
|
||||
f"Agent 发现",
|
||||
extra={
|
||||
"query": query,
|
||||
"domain": domain,
|
||||
"capabilities": capabilities,
|
||||
"found_count": len(agents)
|
||||
}
|
||||
)
|
||||
|
||||
return [self._to_agent_info(agent) for agent in agents]
|
||||
|
||||
def get_agent(self, agent_id: uuid.UUID) -> Optional[Dict[str, Any]]:
|
||||
"""获取 Agent 信息
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
Agent 信息字典,如果不存在返回 None
|
||||
"""
|
||||
agent_id_str = str(agent_id)
|
||||
|
||||
# 先查缓存
|
||||
if agent_id_str in self._cache:
|
||||
return self._cache[agent_id_str]
|
||||
|
||||
# 查数据库
|
||||
agent = self.db.get(AgentConfig, agent_id)
|
||||
if agent and agent.is_active:
|
||||
agent_info = self._to_agent_info(agent)
|
||||
self._cache[agent_id_str] = agent_info
|
||||
return agent_info
|
||||
|
||||
return None
|
||||
|
||||
def _to_agent_info(self, agent: AgentConfig) -> Dict[str, Any]:
|
||||
"""转换为 Agent 信息字典
|
||||
|
||||
Args:
|
||||
agent: Agent 配置对象
|
||||
|
||||
Returns:
|
||||
Agent 信息字典
|
||||
"""
|
||||
return {
|
||||
"id": str(agent.id),
|
||||
"name": agent.app.name,
|
||||
"description": agent.app.description,
|
||||
"domain": agent.agent_domain,
|
||||
"role": agent.agent_role,
|
||||
"capabilities": agent.capabilities or [],
|
||||
"tools": list(agent.tools.keys()) if agent.tools else [],
|
||||
"knowledge_bases": self._extract_kb_ids(agent),
|
||||
"system_prompt": self._truncate_prompt(agent.system_prompt),
|
||||
"status": "active" if agent.is_active else "inactive",
|
||||
"workspace_id": str(agent.app.workspace_id),
|
||||
"visibility": agent.app.visibility
|
||||
}
|
||||
|
||||
def _extract_kb_ids(self, agent: AgentConfig) -> List[str]:
|
||||
"""提取知识库 ID 列表
|
||||
|
||||
Args:
|
||||
agent: Agent 配置对象
|
||||
|
||||
Returns:
|
||||
知识库 ID 列表
|
||||
"""
|
||||
if not agent.knowledge_retrieval:
|
||||
return []
|
||||
|
||||
kb_config = agent.knowledge_retrieval
|
||||
knowledge_bases = kb_config.get("knowledge_bases", [])
|
||||
return [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||
|
||||
def _truncate_prompt(self, prompt: Optional[str], max_length: int = 200) -> Optional[str]:
|
||||
"""截断提示词
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
max_length: 最大长度
|
||||
|
||||
Returns:
|
||||
截断后的提示词
|
||||
"""
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
if len(prompt) <= max_length:
|
||||
return prompt
|
||||
|
||||
return prompt[:max_length] + "..."
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清空缓存"""
|
||||
self._cache.clear()
|
||||
logger.debug("Agent 注册表缓存已清空")
|
||||
130
api/app/services/agent_server.py
Normal file
130
api/app/services/agent_server.py
Normal file
@@ -0,0 +1,130 @@
|
||||
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents import create_agent, AgentState
|
||||
from langchain.agents.middleware import before_model
|
||||
from langchain.tools import tool
|
||||
from langchain_core.messages import RemoveMessage
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.services.api_resquests_server import send_message, model, retrieval
|
||||
|
||||
|
||||
class config(BaseModel):
|
||||
template_str:str
|
||||
params:dict
|
||||
model_configs: List[dict] = []
|
||||
history_memory:bool
|
||||
knowledge_base:bool
|
||||
|
||||
class RemoryInput(BaseModel):
|
||||
question: str
|
||||
end_user_id: str
|
||||
search_switch:str
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
end_user_id: str
|
||||
message: str
|
||||
search_switch:str
|
||||
kb_ids: List[str] = []
|
||||
similarity_threshold:float
|
||||
vector_similarity_weight:float
|
||||
top_k:int
|
||||
hybrid:bool
|
||||
token:str
|
||||
|
||||
class RetrievalInput(BaseModel):
|
||||
message: str
|
||||
kb_ids: List[str] = []
|
||||
similarity_threshold: float
|
||||
vector_similarity_weight: float
|
||||
top_k: int
|
||||
hybrid: bool
|
||||
token: str
|
||||
|
||||
async def tool_Retrieval(req):
|
||||
tool_result = retrieval_search.invoke({
|
||||
"message":req.message, "kb_ids":req.kb_ids,
|
||||
"similarity_threshold":req.similarity_threshold, "vector_similarity_weight":req.vector_similarity_weight,
|
||||
"top_k":req.top_k, "hybrid":req.hybrid, "token":req.token
|
||||
})
|
||||
return tool_result
|
||||
async def tool_memory(req):
|
||||
tool_result = remory_sk.invoke({
|
||||
"question": req.message,
|
||||
"end_user_id": req.end_user_id,
|
||||
"search_switch": req.search_switch
|
||||
})
|
||||
return tool_result
|
||||
|
||||
|
||||
@before_model
|
||||
# ========== 消息剪枝中间件 ==========
|
||||
def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""保留前1条 + 最近3~4条消息"""
|
||||
messages = state["messages"]
|
||||
if len(messages) <= 10:
|
||||
return None
|
||||
first_msg = messages[0]
|
||||
recent_messages = messages[-10:] if len(messages) % 2 == 0 else messages[-11:]
|
||||
new_messages = [first_msg] + recent_messages
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
||||
*new_messages
|
||||
]
|
||||
}
|
||||
|
||||
##-----------历史记忆------------
|
||||
@ tool(args_schema=RemoryInput)
|
||||
def remory_sk(question: str, end_user_id: str, search_switch: str):
|
||||
"""
|
||||
条件调用工具:
|
||||
- 仅当 question 是疑问句时调用 send_message
|
||||
- 根据 end_user_id 和 search_switch 参数决定是否执行检索
|
||||
|
||||
Args:
|
||||
question: 用户的提问内容
|
||||
end_user_id: 用户唯一标识符
|
||||
search_switch: 搜索开关,控制是否执行检索
|
||||
|
||||
Returns:
|
||||
检索结果或空字符串
|
||||
"""
|
||||
# 移除关于 configurable 的描述,避免混淆
|
||||
if not end_user_id or end_user_id == '123':
|
||||
print("警告: 无效的 user_id 参数")
|
||||
return ''
|
||||
|
||||
if search_switch in ['on', 'off'] or not search_switch:
|
||||
print("警告: 无效的 search_switch 参数")
|
||||
return ''
|
||||
return send_message(end_user_id, question, '[]', search_switch)
|
||||
|
||||
#-------------检索------------
|
||||
|
||||
|
||||
@ tool(args_schema=RetrievalInput)
|
||||
def retrieval_search(message,kb_ids,similarity_threshold,vector_similarity_weight,top_k,hybrid,token):
|
||||
'''检索'''
|
||||
search=retrieval(message,kb_ids,similarity_threshold,vector_similarity_weight,top_k,hybrid,token)
|
||||
return search
|
||||
async def create_dynamic_agent(model_name: str,model_id:str,PROMPT:str,token:str):
|
||||
"""根据模型名动态创建代理"""
|
||||
model_name, api_key, api_base=await model(model_id,token)
|
||||
llm = ChatOpenAI(model=model_name, base_url=api_base, temperature=0.2,api_key=api_key)
|
||||
memory = InMemorySaver()
|
||||
return create_agent(
|
||||
llm,
|
||||
tools=[remory_sk,retrieval_search],
|
||||
middleware=[trim_messages],
|
||||
checkpointer=memory,
|
||||
system_prompt=PROMPT
|
||||
)
|
||||
331
api/app/services/agent_tools.py
Normal file
331
api/app/services/agent_tools.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""Agent 发现和调用工具"""
|
||||
import uuid
|
||||
import time
|
||||
import datetime
|
||||
from typing import Optional, Dict, Any, List
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import AgentConfig, ModelConfig, AgentInvocation
|
||||
from app.services.agent_registry import AgentRegistry
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories import workspace_repository, knowledge_repository
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
# ==================== Agent 发现工具 ====================
|
||||
|
||||
class AgentDiscoveryInput(BaseModel):
|
||||
"""Agent 发现工具输入参数"""
|
||||
query: Optional[str] = Field(None, description="搜索关键词,如:'客服'、'技术支持'")
|
||||
domain: Optional[str] = Field(None, description="专业领域,如:'customer_service'、'technical_support'")
|
||||
capabilities: Optional[List[str]] = Field(None, description="所需能力列表,如:['退货处理', '订单查询']")
|
||||
|
||||
|
||||
def create_agent_discovery_tool(registry: AgentRegistry, workspace_id: uuid.UUID):
|
||||
"""创建 Agent 发现工具
|
||||
|
||||
Args:
|
||||
registry: Agent 注册表
|
||||
workspace_id: 当前工作空间 ID
|
||||
|
||||
Returns:
|
||||
Agent 发现工具
|
||||
"""
|
||||
|
||||
@tool(args_schema=AgentDiscoveryInput)
|
||||
def discover_agents(
|
||||
query: Optional[str] = None,
|
||||
domain: Optional[str] = None,
|
||||
capabilities: Optional[List[str]] = None
|
||||
) -> str:
|
||||
"""发现系统中可用的 Agent。当需要找到能够处理特定任务的 Agent 时使用此工具。
|
||||
|
||||
Args:
|
||||
query: 搜索关键词(如:"客服"、"技术支持")
|
||||
domain: 专业领域(如:"customer_service"、"technical_support")
|
||||
capabilities: 所需能力(如:["退货处理", "订单查询"])
|
||||
|
||||
Returns:
|
||||
可用 Agent 的列表和描述
|
||||
"""
|
||||
try:
|
||||
agents = registry.discover_agents(
|
||||
query=query,
|
||||
domain=domain,
|
||||
capabilities=capabilities,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
if not agents:
|
||||
return "未找到匹配的 Agent"
|
||||
|
||||
# 格式化输出
|
||||
result = f"找到 {len(agents)} 个可用的 Agent:\n\n"
|
||||
for i, agent in enumerate(agents, 1):
|
||||
result += f"{i}. {agent['name']}\n"
|
||||
result += f" ID: {agent['id']}\n"
|
||||
if agent['description']:
|
||||
result += f" 描述: {agent['description']}\n"
|
||||
if agent['domain']:
|
||||
result += f" 领域: {agent['domain']}\n"
|
||||
if agent['capabilities']:
|
||||
result += f" 能力: {', '.join(agent['capabilities'])}\n"
|
||||
if agent['tools']:
|
||||
result += f" 工具: {', '.join(agent['tools'])}\n"
|
||||
result += "\n"
|
||||
|
||||
logger.info(
|
||||
f"Agent 发现成功",
|
||||
extra={
|
||||
"query": query,
|
||||
"domain": domain,
|
||||
"found_count": len(agents)
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent 发现失败", extra={"error": str(e)})
|
||||
return f"发现 Agent 失败: {str(e)}"
|
||||
|
||||
return discover_agents
|
||||
|
||||
|
||||
# ==================== Agent 调用工具 ====================
|
||||
|
||||
class AgentInvocationInput(BaseModel):
|
||||
"""Agent 调用工具输入参数"""
|
||||
agent_id: str = Field(..., description="要调用的 Agent ID(通过 discover_agents 工具获取)")
|
||||
message: str = Field(..., description="发送给 Agent 的消息或任务描述")
|
||||
context: Optional[Dict[str, Any]] = Field(None, description="可选的上下文信息(如:用户信息、历史记录等)")
|
||||
|
||||
|
||||
def create_agent_invocation_tool(
|
||||
db: Session,
|
||||
registry: AgentRegistry,
|
||||
workspace_id: uuid.UUID,
|
||||
current_agent_id: uuid.UUID,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
parent_invocation_id: Optional[uuid.UUID] = None,
|
||||
invocation_chain: Optional[List[uuid.UUID]] = None
|
||||
):
|
||||
"""创建 Agent 调用工具
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
registry: Agent 注册表
|
||||
workspace_id: 当前工作空间 ID
|
||||
current_agent_id: 当前 Agent ID
|
||||
conversation_id: 会话 ID
|
||||
parent_invocation_id: 父调用 ID
|
||||
invocation_chain: 调用链(用于检测循环调用)
|
||||
|
||||
Returns:
|
||||
Agent 调用工具
|
||||
"""
|
||||
# 1. 获取工作空间的 storage_type
|
||||
storage_type = 'neo4j' # 默认值
|
||||
user_rag_memory_id = None
|
||||
|
||||
try:
|
||||
workspace = workspace_repository.get_workspace_by_id(db, workspace_id)
|
||||
if workspace and workspace.storage_type:
|
||||
storage_type = workspace.storage_type
|
||||
logger.debug(
|
||||
f"获取工作空间存储类型成功",
|
||||
extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"storage_type": storage_type
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"获取工作空间存储类型失败,使用默认值 neo4j",
|
||||
extra={"workspace_id": str(workspace_id), "error": str(e)}
|
||||
)
|
||||
|
||||
# 2. 如果 storage_type 是 rag,获取知识库 ID
|
||||
if storage_type == 'rag':
|
||||
try:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MEMORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
logger.debug(
|
||||
f"获取 RAG 知识库成功",
|
||||
extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"knowledge_id": user_rag_memory_id
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"未找到名为 'USER_RAG_MEMORY' 的知识库,将使用 neo4j 存储",
|
||||
extra={"workspace_id": str(workspace_id)}
|
||||
)
|
||||
storage_type = 'neo4j'
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"获取 RAG 知识库失败,将使用 neo4j 存储",
|
||||
extra={"workspace_id": str(workspace_id), "error": str(e)}
|
||||
)
|
||||
storage_type = 'neo4j'
|
||||
|
||||
if invocation_chain is None:
|
||||
invocation_chain = []
|
||||
|
||||
@tool(args_schema=AgentInvocationInput)
|
||||
async def invoke_agent(
|
||||
agent_id: str,
|
||||
message: str,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""调用另一个 Agent 来处理任务。当当前 Agent 无法处理某个任务,需要其他专业 Agent 帮助时使用。
|
||||
|
||||
Args:
|
||||
agent_id: 要调用的 Agent ID(通过 discover_agents 工具获取)
|
||||
message: 发送给 Agent 的消息或任务描述
|
||||
context: 可选的上下文信息(如:用户信息、历史记录等)
|
||||
|
||||
Returns:
|
||||
被调用 Agent 的响应结果
|
||||
"""
|
||||
try:
|
||||
# 1. 验证 Agent 存在
|
||||
agent_uuid = uuid.UUID(agent_id)
|
||||
agent_info = registry.get_agent(agent_uuid)
|
||||
if not agent_info:
|
||||
return f"Agent {agent_id} 不存在"
|
||||
|
||||
# 2. 验证权限(同工作空间或公开)
|
||||
if agent_info["workspace_id"] != str(workspace_id) and agent_info["visibility"] != "public":
|
||||
return f"无权访问 Agent {agent_info['name']}"
|
||||
|
||||
# 3. 防止自己调用自己
|
||||
if agent_id == str(current_agent_id):
|
||||
return "不能调用自己"
|
||||
|
||||
# 4. 防止循环调用
|
||||
if agent_uuid in invocation_chain:
|
||||
return f"检测到循环调用:{agent_info['name']} 已在调用链中"
|
||||
|
||||
# 5. 检查调用深度
|
||||
max_depth = 5
|
||||
if len(invocation_chain) >= max_depth:
|
||||
return f"调用深度超过限制(最大 {max_depth} 层)"
|
||||
|
||||
# 6. 获取 Agent 配置
|
||||
agent_config = db.get(AgentConfig, agent_uuid)
|
||||
if not agent_config:
|
||||
return f"Agent 配置不存在"
|
||||
|
||||
# 7. 获取模型配置
|
||||
model_config = db.get(ModelConfig, agent_config.default_model_config_id)
|
||||
if not model_config:
|
||||
return f"Agent 模型配置不存在"
|
||||
|
||||
# 8. 创建调用记录
|
||||
invocation = AgentInvocation(
|
||||
caller_agent_id=current_agent_id,
|
||||
callee_agent_id=agent_uuid,
|
||||
conversation_id=conversation_id,
|
||||
parent_invocation_id=parent_invocation_id,
|
||||
input_message=message,
|
||||
context=context,
|
||||
status="running",
|
||||
started_at=datetime.datetime.now()
|
||||
)
|
||||
db.add(invocation)
|
||||
db.commit()
|
||||
db.refresh(invocation)
|
||||
|
||||
logger.info(
|
||||
f"Agent 调用开始",
|
||||
extra={
|
||||
"invocation_id": str(invocation.id),
|
||||
"caller_agent_id": str(current_agent_id),
|
||||
"callee_agent_id": agent_id,
|
||||
"depth": len(invocation_chain)
|
||||
}
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 9. 调用 Agent
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
|
||||
result = await draft_service.run(
|
||||
agent_config=agent_config,
|
||||
model_config=model_config,
|
||||
message=message,
|
||||
workspace_id=workspace_id,
|
||||
variables=context or {},
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 10. 更新调用记录
|
||||
invocation.status = "completed"
|
||||
invocation.output_message = result["message"]
|
||||
invocation.completed_at = datetime.datetime.now()
|
||||
invocation.elapsed_time = elapsed_time
|
||||
invocation.token_usage = result.get("usage", {})
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Agent 调用成功",
|
||||
extra={
|
||||
"invocation_id": str(invocation.id),
|
||||
"caller_agent_id": str(current_agent_id),
|
||||
"callee_agent_id": agent_id,
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
)
|
||||
|
||||
return result["message"]
|
||||
|
||||
except Exception as e:
|
||||
# 更新调用记录为失败
|
||||
invocation.status = "failed"
|
||||
invocation.error_message = str(e)
|
||||
invocation.completed_at = datetime.datetime.now()
|
||||
invocation.elapsed_time = time.time() - start_time
|
||||
db.commit()
|
||||
|
||||
logger.error(
|
||||
f"Agent 调用失败",
|
||||
extra={
|
||||
"invocation_id": str(invocation.id),
|
||||
"caller_agent_id": str(current_agent_id),
|
||||
"callee_agent_id": agent_id,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Agent 调用异常",
|
||||
extra={
|
||||
"caller_agent_id": str(current_agent_id),
|
||||
"callee_agent_id": agent_id,
|
||||
"error": str(e)
|
||||
}
|
||||
)
|
||||
return f"调用 Agent 失败: {str(e)}"
|
||||
|
||||
return invoke_agent
|
||||
173
api/app/services/api_key_service.py
Normal file
173
api/app/services/api_key_service.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""API Key Service"""
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Tuple, List
|
||||
import uuid
|
||||
import datetime
|
||||
import math
|
||||
|
||||
from app.models.api_key_model import ApiKey, ApiKeyType
|
||||
from app.repositories.api_key_repository import ApiKeyRepository
|
||||
from app.schemas import api_key_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.core.api_key_utils import generate_api_key
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ApiKeyService:
|
||||
"""API Key 业务逻辑服务"""
|
||||
|
||||
@staticmethod
|
||||
def create_api_key(
|
||||
db: Session,
|
||||
*,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyCreate
|
||||
) -> Tuple[ApiKey, str]:
|
||||
"""创建 API Key
|
||||
|
||||
Returns:
|
||||
Tuple[ApiKey, str]: (API Key 对象, API Key 明文)
|
||||
"""
|
||||
# 生成 API Key
|
||||
api_key, key_hash, key_prefix = generate_api_key(data.type)
|
||||
|
||||
# 创建数据
|
||||
api_key_data = {
|
||||
"id": uuid.uuid4(),
|
||||
"name": data.name,
|
||||
"description": data.description,
|
||||
"key_prefix": key_prefix,
|
||||
"key_hash": key_hash,
|
||||
"type": data.type,
|
||||
"scopes": data.scopes,
|
||||
"workspace_id": workspace_id,
|
||||
"resource_id": data.resource_id,
|
||||
"resource_type": data.resource_type,
|
||||
"rate_limit": data.rate_limit,
|
||||
"quota_limit": data.quota_limit,
|
||||
"expires_at": data.expires_at,
|
||||
"created_by": user_id,
|
||||
"created_at": datetime.datetime.now(),
|
||||
"updated_at": datetime.datetime.now(),
|
||||
}
|
||||
|
||||
api_key_obj = ApiKeyRepository.create(db, api_key_data)
|
||||
db.commit()
|
||||
db.refresh(api_key_obj)
|
||||
|
||||
logger.info(f"API Key 创建成功", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
"name": data.name,
|
||||
"type": data.type
|
||||
})
|
||||
|
||||
return api_key_obj, api_key
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> ApiKey:
|
||||
"""获取 API Key"""
|
||||
api_key = ApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException("API Key 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
if api_key.workspace_id != workspace_id:
|
||||
raise BusinessException("无权访问此 API Key", BizCode.FORBIDDEN)
|
||||
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
def list_api_keys(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
query: api_key_schema.ApiKeyQuery
|
||||
) -> PageData:
|
||||
"""列出 API Keys"""
|
||||
items, total = ApiKeyRepository.list_by_workspace(db, workspace_id, query)
|
||||
pages = math.ceil(total / query.pagesize) if total > 0 else 0
|
||||
|
||||
return PageData(
|
||||
page=PageMeta(
|
||||
page=query.page,
|
||||
pagesize=query.pagesize,
|
||||
total=total,
|
||||
hasnext=query.page < pages
|
||||
),
|
||||
items=[api_key_schema.ApiKey.model_validate(item) for item in items]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_api_key(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyUpdate
|
||||
) -> ApiKey:
|
||||
"""更新 API Key"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
ApiKeyRepository.update(db, api_key_id, update_data)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
logger.info(f"API Key 更新成功", extra={"api_key_id": str(api_key_id)})
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
def delete_api_key(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> bool:
|
||||
"""删除 API Key"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
ApiKeyRepository.delete(db, api_key_id)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"API Key 删除成功", extra={"api_key_id": str(api_key_id)})
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def regenerate_api_key(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> Tuple[ApiKey, str]:
|
||||
"""重新生成 API Key"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
# 生成新的 API Key
|
||||
new_api_key, key_hash, key_prefix = generate_api_key(ApiKeyType(api_key.type))
|
||||
|
||||
# 更新
|
||||
ApiKeyRepository.update(db, api_key_id, {
|
||||
"key_hash": key_hash,
|
||||
"key_prefix": key_prefix
|
||||
})
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
logger.info(f"API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
|
||||
return api_key, new_api_key
|
||||
|
||||
@staticmethod
|
||||
def get_stats(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> api_key_schema.ApiKeyStats:
|
||||
"""获取使用统计"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
|
||||
return api_key_schema.ApiKeyStats(**stats_data)
|
||||
1903
api/app/services/app_service.py
Normal file
1903
api/app/services/app_service.py
Normal file
File diff suppressed because it is too large
Load Diff
262
api/app/services/auth_service.py
Normal file
262
api/app/services/auth_service.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Tuple, Union
|
||||
import jwt
|
||||
import time
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.repositories import user_repository
|
||||
from app.core.security import verify_password
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
# Token 配置
|
||||
TOKEN_SECRET_KEY = settings.SECRET_KEY
|
||||
TOKEN_ALGORITHM = "HS256"
|
||||
|
||||
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
|
||||
"""
|
||||
Authenticates a user.
|
||||
|
||||
:param db: The database session.
|
||||
:param email: The email.
|
||||
:param password: The password.
|
||||
:return: The user object if authentication is successful, otherwise None.
|
||||
"""
|
||||
user = user_repository.get_user_by_email(db, email=email)
|
||||
if not user:
|
||||
return None # User not found
|
||||
if not user.is_active:
|
||||
return None # User is inactive
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return None # Incorrect password
|
||||
return user # Authentication successful
|
||||
|
||||
|
||||
def authenticate_user_with_status(db: Session, email: str, password: str) -> Tuple[bool, Optional[User], str]:
|
||||
"""
|
||||
认证用户并返回详细状态(用于需要区分不同失败原因的场景)
|
||||
|
||||
:param db: 数据库会话
|
||||
:param email: 用户邮箱
|
||||
:param password: 用户密码
|
||||
:return: (认证成功, 用户对象, 状态消息)
|
||||
状态消息: "success", "user_not_found", "user_inactive", "password_incorrect"
|
||||
"""
|
||||
from app.core.logging_config import get_auth_logger
|
||||
|
||||
logger = get_auth_logger()
|
||||
|
||||
# 查找用户
|
||||
user = user_repository.get_user_by_email(db, email=email)
|
||||
if not user:
|
||||
logger.warning(f"用户不存在: {email}")
|
||||
return (False, None, "user_not_found")
|
||||
|
||||
# 检查用户状态
|
||||
if not user.is_active:
|
||||
logger.warning(f"用户未激活: {email}")
|
||||
return (False, user, "user_inactive")
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(password, user.hashed_password):
|
||||
logger.warning(f"密码错误: {email}")
|
||||
return (False, user, "password_incorrect")
|
||||
|
||||
logger.info(f"用户认证成功: {email}")
|
||||
return (True, user, "success")
|
||||
|
||||
|
||||
def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
|
||||
"""
|
||||
认证用户,失败时抛出异常(推荐使用)
|
||||
|
||||
:param db: 数据库会话
|
||||
:param email: 用户邮箱
|
||||
:param password: 用户密码
|
||||
:return: 用户对象
|
||||
:raises BusinessException: 认证失败时抛出
|
||||
"""
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_auth_logger
|
||||
|
||||
logger = get_auth_logger()
|
||||
|
||||
# 查找用户
|
||||
user = user_repository.get_user_by_email(db, email=email)
|
||||
if not user:
|
||||
logger.warning(f"用户不存在: {email}")
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查用户状态
|
||||
if not user.is_active:
|
||||
logger.warning(f"用户未激活: {email}")
|
||||
raise BusinessException("用户未激活", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(password, user.hashed_password):
|
||||
logger.warning(f"密码错误: {email}")
|
||||
raise BusinessException("密码错误", code=BizCode.PASSWORD_ERROR)
|
||||
|
||||
logger.info(f"用户认证成功: {email}")
|
||||
return user
|
||||
|
||||
|
||||
def get_user_by_username(db: Session, username: str) -> Optional[User]:
|
||||
"""
|
||||
Get a user by username.
|
||||
|
||||
:param db: The database session.
|
||||
:param username: The username.
|
||||
:return: The user object if found, otherwise None.
|
||||
"""
|
||||
return user_repository.get_user_by_username(db, username=username)
|
||||
|
||||
def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
|
||||
"""
|
||||
Get a user by user_id.
|
||||
|
||||
:param db: The database session.
|
||||
:param user_id: The user id (UUID string).
|
||||
:return: The user object if found, otherwise None.
|
||||
"""
|
||||
return user_repository.get_user_by_id(db, user_id=user_id)
|
||||
|
||||
|
||||
def register_user_with_invite(
|
||||
db: Session,
|
||||
email: str,
|
||||
password: str,
|
||||
invite_token: str,
|
||||
workspace_id: str
|
||||
) -> User:
|
||||
"""
|
||||
使用邀请码注册新用户并加入工作空间
|
||||
|
||||
:param db: 数据库会话
|
||||
:param email: 用户邮箱
|
||||
:param password: 用户密码
|
||||
:param invite_token: 邀请令牌
|
||||
:param workspace_id: 工作空间ID
|
||||
:return: 创建的用户对象
|
||||
"""
|
||||
from app.schemas.user_schema import UserCreate
|
||||
from app.schemas.workspace_schema import InviteAcceptRequest
|
||||
from app.services import user_service, workspace_service
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
logger.info(f"使用邀请码注册用户: {email}")
|
||||
|
||||
try:
|
||||
# 创建用户
|
||||
user_create = UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
username=email.split('@')[0]
|
||||
)
|
||||
user = user_service.create_user(db=db, user=user_create)
|
||||
logger.info(f"用户创建成功: {user.email} (ID: {user.id})")
|
||||
|
||||
# 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit)
|
||||
invite_accept = InviteAcceptRequest(token=invite_token)
|
||||
workspace_service.accept_workspace_invite(db, invite_accept, user)
|
||||
logger.info(f"用户接受邀请成功")
|
||||
|
||||
# 重新查询用户对象以确保获取最新状态
|
||||
from app.repositories import user_repository
|
||||
user = user_repository.get_user_by_id(db, str(user.id))
|
||||
|
||||
# 设置当前工作空间
|
||||
user.current_workspace_id = workspace_id
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
logger.info(f"用户注册并加入工作空间成功: {user.email}, workspace_id: {user.current_workspace_id}")
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"注册用户失败: {email} - {str(e)}")
|
||||
raise
|
||||
|
||||
def bind_workspace_with_invite(
|
||||
db: Session,
|
||||
user: User,
|
||||
invite_token: str,
|
||||
workspace_id: str
|
||||
) -> User:
|
||||
|
||||
from app.schemas.user_schema import UserCreate
|
||||
from app.schemas.workspace_schema import InviteAcceptRequest
|
||||
from app.services import user_service, workspace_service
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
try:
|
||||
|
||||
# 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit)
|
||||
invite_accept = InviteAcceptRequest(token=invite_token)
|
||||
workspace_service.accept_workspace_invite(db, invite_accept, user)
|
||||
logger.info(f"用户接受邀请成功")
|
||||
|
||||
# 重新查询用户对象以确保获取最新状态
|
||||
from app.repositories import user_repository
|
||||
user = user_repository.get_user_by_id(db, str(user.id))
|
||||
|
||||
# 设置当前工作空间
|
||||
user.current_workspace_id = workspace_id
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"绑定工作空间失败: user={user.email} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_access_token(user_id: str, share_token: str) -> str:
|
||||
"""创建访问 token
|
||||
|
||||
Token 不设置过期时间,只要 share_token 有效,token 就有效
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
share_token: 分享 token
|
||||
|
||||
Returns:
|
||||
JWT token
|
||||
"""
|
||||
payload = {
|
||||
"user_id": user_id,
|
||||
"share_token": share_token,
|
||||
"iat": int(time.time()) # 签发时间
|
||||
}
|
||||
|
||||
token = jwt.encode(payload, TOKEN_SECRET_KEY, algorithm=TOKEN_ALGORITHM)
|
||||
return token
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> dict:
|
||||
"""解码访问 token
|
||||
|
||||
Args:
|
||||
token: JWT token
|
||||
|
||||
Returns:
|
||||
包含 user_id 和 share_token 的字典
|
||||
|
||||
Raises:
|
||||
BusinessException: token 无效
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[TOKEN_ALGORITHM])
|
||||
return {
|
||||
"user_id": payload["user_id"],
|
||||
"share_token": payload["share_token"]
|
||||
}
|
||||
except jwt.InvalidTokenError:
|
||||
raise BusinessException("无效的访问 token", BizCode.INVALID_TOKEN)
|
||||
229
api/app/services/conversation_service.py
Normal file
229
api/app/services/conversation_service.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""会话服务"""
|
||||
import uuid
|
||||
from typing import Optional, List, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, desc
|
||||
|
||||
from app.models import Conversation, Message
|
||||
from app.core.exceptions import ResourceNotFoundException, BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ConversationService:
|
||||
"""会话服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def create_conversation(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
is_draft: bool = False,
|
||||
config_snapshot: Optional[dict] = None
|
||||
) -> Conversation:
|
||||
"""创建会话"""
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
title=title or "新会话",
|
||||
is_draft=is_draft,
|
||||
config_snapshot=config_snapshot
|
||||
)
|
||||
|
||||
self.db.add(conversation)
|
||||
self.db.commit()
|
||||
self.db.refresh(conversation)
|
||||
|
||||
logger.info(
|
||||
f"创建会话成功",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"app_id": str(app_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"is_draft": is_draft
|
||||
}
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
def get_conversation(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Conversation:
|
||||
"""获取会话"""
|
||||
stmt = select(Conversation).where(Conversation.id == conversation_id)
|
||||
|
||||
if workspace_id:
|
||||
stmt = stmt.where(Conversation.workspace_id == workspace_id)
|
||||
|
||||
conversation = self.db.scalars(stmt).first()
|
||||
|
||||
if not conversation:
|
||||
raise ResourceNotFoundException("会话", str(conversation_id))
|
||||
|
||||
return conversation
|
||||
|
||||
def list_conversations(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str] = None,
|
||||
is_draft: Optional[bool] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 20
|
||||
) -> Tuple[List[Conversation], int]:
|
||||
"""列出会话"""
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active == True
|
||||
)
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(Conversation.user_id == user_id)
|
||||
|
||||
if is_draft is not None:
|
||||
stmt = stmt.where(Conversation.is_draft == is_draft)
|
||||
|
||||
# 总数
|
||||
count_stmt = stmt.with_only_columns(Conversation.id)
|
||||
total = len(self.db.execute(count_stmt).all())
|
||||
|
||||
# 分页
|
||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
||||
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
|
||||
return conversations, total
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
role: str,
|
||||
content: str,
|
||||
meta_data: Optional[dict] = None
|
||||
) -> Message:
|
||||
"""添加消息"""
|
||||
message = Message(
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=content,
|
||||
meta_data=meta_data
|
||||
)
|
||||
|
||||
self.db.add(message)
|
||||
|
||||
# 更新会话的消息计数和更新时间
|
||||
conversation = self.get_conversation(conversation_id)
|
||||
conversation.message_count += 1
|
||||
|
||||
# 如果是第一条用户消息,可以用它作为标题
|
||||
if conversation.message_count == 1 and role == "user":
|
||||
conversation.title = content[:50] + ("..." if len(content) > 50 else "")
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
|
||||
return message
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
limit: Optional[int] = None
|
||||
) -> List[Message]:
|
||||
"""获取会话消息"""
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conversation_id
|
||||
).order_by(Message.created_at)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
messages = list(self.db.scalars(stmt).all())
|
||||
|
||||
return messages
|
||||
|
||||
def get_conversation_history(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
max_history: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
"""获取会话历史消息
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
max_history: 最大历史消息数量
|
||||
|
||||
Returns:
|
||||
List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...]
|
||||
"""
|
||||
messages = self.get_messages(conversation_id, limit=max_history)
|
||||
|
||||
# 转换为字典格式
|
||||
history = [
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content
|
||||
}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
return history
|
||||
|
||||
def save_conversation_messages(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
user_message: str,
|
||||
assistant_message: str
|
||||
):
|
||||
"""保存会话消息(用户消息和助手回复)"""
|
||||
# 添加用户消息
|
||||
self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=user_message
|
||||
)
|
||||
|
||||
# 添加助手消息
|
||||
self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=assistant_message
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"保存会话消息成功",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"user_message_length": len(user_message),
|
||||
"assistant_message_length": len(assistant_message)
|
||||
}
|
||||
)
|
||||
|
||||
def delete_conversation(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
):
|
||||
"""删除会话(软删除)"""
|
||||
conversation = self.get_conversation(conversation_id, workspace_id)
|
||||
conversation.is_active = False
|
||||
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"删除会话成功",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
261
api/app/services/conversation_state_manager.py
Normal file
261
api/app/services/conversation_state_manager.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""会话状态管理器 - 解决多轮对话路由错乱"""
|
||||
import json
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ConversationStateManager:
|
||||
"""会话状态管理器
|
||||
|
||||
用于管理多轮对话中的会话状态,包括:
|
||||
- 当前使用的 Agent
|
||||
- 路由历史
|
||||
- 主题追踪
|
||||
- Agent 切换统计
|
||||
"""
|
||||
|
||||
def __init__(self, storage_backend: Optional[Any] = None):
|
||||
"""初始化状态管理器
|
||||
|
||||
Args:
|
||||
storage_backend: 存储后端(Redis/内存等)
|
||||
"""
|
||||
self.storage = storage_backend or InMemoryStorage()
|
||||
self.ttl = 3600 # 1小时过期
|
||||
|
||||
def get_state(self, conversation_id: str) -> Dict[str, Any]:
|
||||
"""获取会话状态
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
会话状态字典
|
||||
"""
|
||||
state = self.storage.get(f"conv_state:{conversation_id}")
|
||||
|
||||
if not state:
|
||||
logger.info(f"创建新会话状态: {conversation_id}")
|
||||
return self._create_new_state(conversation_id)
|
||||
|
||||
return state
|
||||
|
||||
def update_state(
|
||||
self,
|
||||
conversation_id: str,
|
||||
agent_id: str,
|
||||
message: str,
|
||||
topic: Optional[str] = None,
|
||||
confidence: float = 1.0
|
||||
) -> Dict[str, Any]:
|
||||
"""更新会话状态
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
agent_id: 当前 Agent ID
|
||||
message: 用户消息
|
||||
topic: 消息主题
|
||||
confidence: 路由置信度
|
||||
|
||||
Returns:
|
||||
更新后的状态
|
||||
"""
|
||||
state = self.get_state(conversation_id)
|
||||
|
||||
# 检测 Agent 切换
|
||||
agent_changed = False
|
||||
if state["current_agent_id"] and state["current_agent_id"] != agent_id:
|
||||
agent_changed = True
|
||||
state["switch_count"] += 1
|
||||
state["previous_agent_id"] = state["current_agent_id"]
|
||||
state["same_agent_turns"] = 0
|
||||
|
||||
logger.info(
|
||||
f"Agent 切换",
|
||||
extra={
|
||||
"conversation_id": conversation_id,
|
||||
"from": state["current_agent_id"],
|
||||
"to": agent_id,
|
||||
"switch_count": state["switch_count"]
|
||||
}
|
||||
)
|
||||
else:
|
||||
state["same_agent_turns"] += 1
|
||||
|
||||
# 更新当前 Agent
|
||||
state["current_agent_id"] = agent_id
|
||||
state["last_message"] = message
|
||||
state["last_topic"] = topic
|
||||
state["updated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加到历史
|
||||
history_item = {
|
||||
"message": message[:100], # 截断长消息
|
||||
"agent_id": agent_id,
|
||||
"topic": topic,
|
||||
"confidence": confidence,
|
||||
"agent_changed": agent_changed,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
state["routing_history"].append(history_item)
|
||||
|
||||
# 保持最近 10 条历史
|
||||
if len(state["routing_history"]) > 10:
|
||||
state["routing_history"] = state["routing_history"][-10:]
|
||||
|
||||
# 保存状态
|
||||
self.storage.set(
|
||||
f"conv_state:{conversation_id}",
|
||||
state,
|
||||
ttl=self.ttl
|
||||
)
|
||||
|
||||
return state
|
||||
|
||||
def clear_state(self, conversation_id: str) -> None:
|
||||
"""清除会话状态
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
"""
|
||||
self.storage.delete(f"conv_state:{conversation_id}")
|
||||
logger.info(f"清除会话状态: {conversation_id}")
|
||||
|
||||
def get_routing_history(
|
||||
self,
|
||||
conversation_id: str,
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取路由历史
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
路由历史列表
|
||||
"""
|
||||
state = self.get_state(conversation_id)
|
||||
history = state.get("routing_history", [])
|
||||
return history[-limit:] if history else []
|
||||
|
||||
def get_statistics(self, conversation_id: str) -> Dict[str, Any]:
|
||||
"""获取会话统计信息
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
统计信息
|
||||
"""
|
||||
state = self.get_state(conversation_id)
|
||||
history = state.get("routing_history", [])
|
||||
|
||||
# 统计各 Agent 使用次数
|
||||
agent_usage = {}
|
||||
for item in history:
|
||||
agent_id = item["agent_id"]
|
||||
agent_usage[agent_id] = agent_usage.get(agent_id, 0) + 1
|
||||
|
||||
# 统计主题分布
|
||||
topic_distribution = {}
|
||||
for item in history:
|
||||
topic = item.get("topic", "未知")
|
||||
topic_distribution[topic] = topic_distribution.get(topic, 0) + 1
|
||||
|
||||
return {
|
||||
"conversation_id": conversation_id,
|
||||
"total_turns": len(history),
|
||||
"switch_count": state.get("switch_count", 0),
|
||||
"current_agent_id": state.get("current_agent_id"),
|
||||
"same_agent_turns": state.get("same_agent_turns", 0),
|
||||
"agent_usage": agent_usage,
|
||||
"topic_distribution": topic_distribution,
|
||||
"created_at": state.get("created_at"),
|
||||
"updated_at": state.get("updated_at")
|
||||
}
|
||||
|
||||
def _create_new_state(self, conversation_id: str) -> Dict[str, Any]:
|
||||
"""创建新的会话状态
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
新的状态字典
|
||||
"""
|
||||
state = {
|
||||
"conversation_id": conversation_id,
|
||||
"current_agent_id": None,
|
||||
"previous_agent_id": None,
|
||||
"routing_history": [],
|
||||
"last_message": None,
|
||||
"last_topic": None,
|
||||
"switch_count": 0,
|
||||
"same_agent_turns": 0,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"updated_at": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# 保存初始状态
|
||||
self.storage.set(
|
||||
f"conv_state:{conversation_id}",
|
||||
state,
|
||||
ttl=self.ttl
|
||||
)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
class InMemoryStorage:
|
||||
"""内存存储后端(用于开发和测试)"""
|
||||
|
||||
def __init__(self):
|
||||
self._storage: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def get(self, key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取数据"""
|
||||
return self._storage.get(key)
|
||||
|
||||
def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None:
|
||||
"""设置数据"""
|
||||
self._storage[key] = value
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
"""删除数据"""
|
||||
if key in self._storage:
|
||||
del self._storage[key]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空所有数据"""
|
||||
self._storage.clear()
|
||||
|
||||
|
||||
class RedisStorage:
|
||||
"""Redis 存储后端(用于生产环境)"""
|
||||
|
||||
def __init__(self, redis_client):
|
||||
"""初始化 Redis 存储
|
||||
|
||||
Args:
|
||||
redis_client: Redis 客户端实例
|
||||
"""
|
||||
self.redis = redis_client
|
||||
|
||||
def get(self, key: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取数据"""
|
||||
data = self.redis.get(key)
|
||||
if data:
|
||||
return json.loads(data)
|
||||
return None
|
||||
|
||||
def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None:
|
||||
"""设置数据"""
|
||||
self.redis.setex(key, ttl, json.dumps(value))
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
"""删除数据"""
|
||||
self.redis.delete(key)
|
||||
85
api/app/services/document_service.py
Normal file
85
api/app/services/document_service.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user_model import User
|
||||
from app.models.document_model import Document
|
||||
from app.schemas.document_schema import DocumentCreate, DocumentUpdate
|
||||
from app.repositories import document_repository
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
# Obtain a dedicated logger for business logic
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def get_documents_paginated(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
filters: list,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
orderby: str = None,
|
||||
desc: bool = False
|
||||
) -> tuple[int, list]:
|
||||
business_logger.debug(f"Query document in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}")
|
||||
|
||||
try:
|
||||
total, items = document_repository.get_documents_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc
|
||||
)
|
||||
business_logger.info(f"The document paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}")
|
||||
return total, items
|
||||
except Exception as e:
|
||||
business_logger.error(f"Querying document pagination failed: username={current_user.username} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_document(
|
||||
db: Session, document: DocumentCreate, current_user: User
|
||||
) -> Document:
|
||||
business_logger.info(f"Create a document: {document.file_name}, creator: {current_user.username}")
|
||||
|
||||
try:
|
||||
document.created_by = current_user.id
|
||||
db_document = document_repository.create_document(
|
||||
db=db, document=document
|
||||
)
|
||||
business_logger.info(f"The document has been successfully created: {document.file_name} (ID: {db_document.id}), creator: {current_user.username}")
|
||||
return db_document
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to create a document: {document.file_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_document_by_id(db: Session, document_id: uuid.UUID, current_user: User) -> Document | None:
|
||||
business_logger.debug(f"Query document based on ID: document_id={document_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
document = document_repository.get_document_by_id(db=db, document_id=document_id)
|
||||
if document:
|
||||
business_logger.info(f"document query successful: {document.file_name} (ID: {document_id})")
|
||||
else:
|
||||
business_logger.warning(f"document does not exist: document_id={document_id}")
|
||||
return document
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to query the document based on the ID: document_id={document_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def reset_documents_progress_by_kb_id(db: Session, kb_id: uuid.UUID, current_user: User) -> int:
|
||||
business_logger.debug(f"Reset the processing progress of all documents under the specified knowledge base: kb_id=={kb_id}, username: {current_user.username}")
|
||||
return document_repository.reset_documents_progress_by_kb_id(db=db, kb_id=kb_id)
|
||||
|
||||
|
||||
def delete_document_by_id(db: Session, document_id: uuid.UUID, current_user: User) -> None:
|
||||
business_logger.info(f"Delete document: document_id={document_id}, operator: {current_user.username}")
|
||||
|
||||
try:
|
||||
document_repository.delete_document_by_id(db=db, document_id=document_id)
|
||||
business_logger.info(f"document deleted successfully: document_id={document_id}, operator: {current_user.username}")
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to delete document: document_id={document_id} - {str(e)}")
|
||||
raise
|
||||
1630
api/app/services/draft_run_service.py
Normal file
1630
api/app/services/draft_run_service.py
Normal file
File diff suppressed because it is too large
Load Diff
87
api/app/services/file_service.py
Normal file
87
api/app/services/file_service.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user_model import User
|
||||
from app.models.file_model import File
|
||||
from app.schemas.file_schema import FileCreate, FileUpdate
|
||||
from app.repositories import file_repository
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
# Obtain a dedicated logger for business logic
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def get_files_paginated(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
filters: list,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
orderby: str = None,
|
||||
desc: bool = False
|
||||
) -> tuple[int, list]:
|
||||
business_logger.debug(f"Query file in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}")
|
||||
|
||||
try:
|
||||
total, items = file_repository.get_files_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc
|
||||
)
|
||||
business_logger.info(f"The file paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}")
|
||||
return total, items
|
||||
except Exception as e:
|
||||
business_logger.error(f"Querying file pagination failed: username={current_user.username} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_file(
|
||||
db: Session, file: FileCreate, current_user: User
|
||||
) -> File:
|
||||
business_logger.info(f"Create a file: {file.file_name}, creator: {current_user.username}")
|
||||
|
||||
try:
|
||||
file.created_by = current_user.id
|
||||
if file.parent_id is None:
|
||||
file.parent_id = file.kb_id
|
||||
db_file = file_repository.create_file(
|
||||
db=db, file=file
|
||||
)
|
||||
business_logger.info(f"The file has been successfully created: {file.file_name} (ID: {db_file.id}), creator: {current_user.username}")
|
||||
return db_file
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to create a file: {file.file_name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_file_by_id(db: Session, file_id: uuid.UUID) -> File | None:
|
||||
business_logger.debug(f"Query file based on ID: file_id={file_id}")
|
||||
|
||||
try:
|
||||
file = file_repository.get_file_by_id(db=db, file_id=file_id)
|
||||
if file:
|
||||
business_logger.info(f"file query successful: {file.file_name} (ID: {file_id})")
|
||||
else:
|
||||
business_logger.warning(f"file does not exist: file_id={file_id}")
|
||||
return file
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to query the file based on the ID: file_id={file_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_files_by_parent_id(db: Session, parent_id: uuid.UUID | None, current_user: User) -> list | None:
|
||||
business_logger.debug(f"Query file based on folder ID: parent_id={parent_id}, username: {current_user.username}")
|
||||
return file_repository.get_files_by_parent_id(db=db, parent_id=parent_id)
|
||||
|
||||
|
||||
def delete_file_by_id(db: Session, file_id: uuid.UUID, current_user: User) -> None:
|
||||
business_logger.info(f"Delete file: file_id={file_id}, operator: {current_user.username}")
|
||||
|
||||
try:
|
||||
file_repository.delete_file_by_id(db=db, file_id=file_id)
|
||||
business_logger.info(f"file_id deleted successfully: file_id={file_id}, operator: {current_user.username}")
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to delete file: file_id={file_id} - {str(e)}")
|
||||
raise
|
||||
126
api/app/services/knowledge_service.py
Normal file
126
api/app/services/knowledge_service.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user_model import User
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.schemas.knowledge_schema import KnowledgeCreate, KnowledgeUpdate
|
||||
from app.repositories import knowledge_repository
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
# Obtain a dedicated logger for business logic
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def get_knowledges_paginated(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
filters: list,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
orderby: str = None,
|
||||
desc: bool = False
|
||||
) -> tuple[int, list]:
|
||||
business_logger.debug(f"Query knowledge base in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}")
|
||||
|
||||
try:
|
||||
total, items = knowledge_repository.get_knowledges_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc
|
||||
)
|
||||
business_logger.info(f"The knowledge base paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}")
|
||||
return total, items
|
||||
except Exception as e:
|
||||
business_logger.error(f"Querying knowledge base pagination failed: username={current_user.username} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_chunded_knowledgeids(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
filters: list
|
||||
) -> list:
|
||||
business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}")
|
||||
|
||||
try:
|
||||
items = knowledge_repository.get_chunded_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
business_logger.info(f"Querying the vectorized knowledge base id list succeeded: username={current_user.username} count={len(items)}")
|
||||
return items
|
||||
except Exception as e:
|
||||
business_logger.error(f"Querying the vectorized knowledge base id list failed: username={current_user.username} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_knowledge(
|
||||
db: Session, knowledge: KnowledgeCreate, current_user: User
|
||||
) -> Knowledge:
|
||||
business_logger.info(f"Create a knowledge base: {knowledge.name}, creator: {current_user.username}")
|
||||
|
||||
try:
|
||||
knowledge.created_by = current_user.id
|
||||
if knowledge.workspace_id is None:
|
||||
knowledge.workspace_id = current_user.current_workspace_id
|
||||
if knowledge.parent_id is None:
|
||||
knowledge.parent_id = knowledge.workspace_id
|
||||
business_logger.debug(f"Start creating the knowledge base: {knowledge.name}")
|
||||
db_knowledge = knowledge_repository.create_knowledge(
|
||||
db=db, knowledge=knowledge
|
||||
)
|
||||
business_logger.info(f"The knowledge base has been successfully created: {knowledge.name} (ID: {db_knowledge.id}), creator: {current_user.username}")
|
||||
return db_knowledge
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to create a knowledge base: {knowledge.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID, current_user: User) -> Knowledge | None:
|
||||
business_logger.debug(f"Query knowledge base based on ID: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=knowledge_id)
|
||||
if knowledge:
|
||||
business_logger.info(f"knowledge base query successful: {knowledge.name} (ID: {knowledge_id})")
|
||||
else:
|
||||
business_logger.warning(f"knowledge base does not exist: knowledge_id={knowledge_id}")
|
||||
return knowledge
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to query the knowledge base based on the ID: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_knowledge_by_name(db: Session, name: str, current_user: User) -> Knowledge | None:
|
||||
business_logger.debug(f"Query knowledge base based on name: name={name}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(db=db, name=name, workspace_id=current_user.current_workspace_id)
|
||||
if knowledge:
|
||||
business_logger.info(f"knowledge base query successful: {name} (ID: {knowledge.id})")
|
||||
else:
|
||||
business_logger.warning(f"knowledge base does not exist: name={name}")
|
||||
return knowledge
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to query the knowledge base based on the name: name={name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_knowledge_by_id(db: Session, knowledge_id: uuid.UUID, current_user: User) -> None:
|
||||
business_logger.info(f"Delete knowledge base: knowledge_id={knowledge_id}, operator: {current_user.username}")
|
||||
|
||||
try:
|
||||
# First, query the knowledge base information for logging purposes
|
||||
knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=knowledge_id)
|
||||
if knowledge:
|
||||
business_logger.debug(f"Execute knowledge base deletion: {knowledge.name} (ID: {knowledge_id})")
|
||||
else:
|
||||
business_logger.warning(f"The knowledge base to be deleted does not exist: knowledge_id={knowledge_id}")
|
||||
|
||||
knowledge_repository.delete_knowledge_by_id(db=db, knowledge_id=knowledge_id)
|
||||
business_logger.info(f"knowledge base record deleted successfully: knowledge_id={knowledge_id}, operator: {current_user.username}")
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to delete knowledge base: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
108
api/app/services/knowledgeshare_service.py
Normal file
108
api/app/services/knowledgeshare_service.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user_model import User
|
||||
from app.models.knowledgeshare_model import KnowledgeShare
|
||||
from app.schemas.knowledgeshare_schema import KnowledgeShareCreate
|
||||
from app.repositories import knowledgeshare_repository
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
# Obtain a dedicated logger for business logic
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def get_knowledgeshares_paginated(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
filters: list,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
orderby: str = None,
|
||||
desc: bool = False
|
||||
) -> tuple[int, list]:
|
||||
business_logger.debug(f"Query knowledge base sharing in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}")
|
||||
|
||||
try:
|
||||
total, items = knowledgeshare_repository.get_knowledgeshares_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc
|
||||
)
|
||||
business_logger.info(f"The knowledge base sharing paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}")
|
||||
return total, items
|
||||
except Exception as e:
|
||||
business_logger.error(f"Querying knowledge base sharing pagination failed: username={current_user.username} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_source_kb_ids_by_target_kb_id(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
filters: list
|
||||
) -> list:
|
||||
business_logger.debug(f"Query the original knowledge base id list by sharing the knowledge base: username={current_user.username}")
|
||||
|
||||
try:
|
||||
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
business_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: username={current_user.username} count={len(items)}")
|
||||
return items
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: username={current_user.username} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_knowledgeshare(
|
||||
db: Session, knowledgeshare: KnowledgeShareCreate, current_user: User
|
||||
) -> KnowledgeShare:
|
||||
business_logger.info(f"Create a knowledge base sharing: creator: {current_user.username}")
|
||||
|
||||
try:
|
||||
knowledgeshare.source_workspace_id = current_user.current_workspace_id
|
||||
knowledgeshare.shared_by = current_user.id
|
||||
business_logger.debug("Start creating a knowledge base sharing")
|
||||
db_knowledgeshare = knowledgeshare_repository.create_knowledgeshare(
|
||||
db=db, knowledgeshare=knowledgeshare
|
||||
)
|
||||
business_logger.info(f"knowledge base sharing created successfully: (ID: {db_knowledgeshare.id}), creator: {current_user.username}")
|
||||
return db_knowledgeshare
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to create a knowledge base sharing - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID, current_user: User) -> KnowledgeShare | None:
|
||||
business_logger.debug(f"Query knowledge base sharing based on ID: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=knowledgeshare_id)
|
||||
if knowledgeshare:
|
||||
business_logger.info(f"knowledge base sharing query successful: (ID: {knowledgeshare_id})")
|
||||
else:
|
||||
business_logger.warning(f"knowledge base sharing does not exist: knowledgeshare_id={knowledgeshare_id}")
|
||||
return knowledgeshare
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to query the knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID, current_user: User) -> None:
|
||||
business_logger.info(f"Delete knowledge base sharing: knowledgeshare_id={knowledgeshare_id}, operator: {current_user.username}")
|
||||
|
||||
try:
|
||||
# First, query the knowledge base sharing information for logging purposes
|
||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=knowledgeshare_id)
|
||||
if knowledgeshare:
|
||||
business_logger.debug(f"Execute knowledge base sharing deletion: (ID: {knowledgeshare_id})")
|
||||
else:
|
||||
business_logger.warning(f"The knowledge base sharing does not exist: knowledgeshare_id={knowledgeshare_id}")
|
||||
|
||||
knowledgeshare_repository.delete_knowledgeshare_by_id(db=db, knowledgeshare_id=knowledgeshare_id)
|
||||
business_logger.info(f"knowledge base sharing deleted successfully: knowledgeshare_id={knowledgeshare_id}, operator: {current_user.username}")
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to delete knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
|
||||
raise
|
||||
51
api/app/services/langchain_tool_server.py
Normal file
51
api/app/services/langchain_tool_server.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import requests
|
||||
import json
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
# 加载.env文件
|
||||
load_dotenv()
|
||||
|
||||
# 读取web_search环境变量
|
||||
web_search_value = os.getenv('web_search')
|
||||
def Search(query):
|
||||
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
|
||||
api_key = web_search_value
|
||||
payload = json.dumps({
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": query
|
||||
}
|
||||
], #搜索输入
|
||||
"edition":"standard", #搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。
|
||||
"search_source": "baidu_search_v2", #使用的搜索引擎版本
|
||||
"resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5
|
||||
"search_filter": {
|
||||
"range": {
|
||||
"page_time": {
|
||||
"gte": "now-1w/d", #时间查询参数,大于或等于
|
||||
"lt": "now/d", #时间查询参数,小于
|
||||
"gt": "", #时间查询参数,大于
|
||||
"lte": "" #时间查询参数,小于或等于
|
||||
}
|
||||
}
|
||||
},
|
||||
"block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表
|
||||
"search_recency_filter":"week", #根据网页发布时间进行筛选,可填值为:week,month,semiyear,year
|
||||
"enable_full_content":True #是否输出网页完整原文
|
||||
}, ensure_ascii=False)
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {api_key}'
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json()
|
||||
content=[]
|
||||
for i in response['references']:
|
||||
title=i['title']
|
||||
snippet=i['snippet']
|
||||
content.append(title+';'+snippet)
|
||||
content='。'.join(content)
|
||||
return content
|
||||
340
api/app/services/llm_client.py
Normal file
340
api/app/services/llm_client.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""LLM 客户端适配器 - 支持多种 LLM 提供商"""
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from abc import ABC, abstractmethod
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class BaseLLMClient(ABC):
|
||||
"""LLM 客户端基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
LLM 响应文本
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIClient(BaseLLMClient):
|
||||
"""OpenAI 客户端"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
base_url: Optional[str] = None
|
||||
):
|
||||
"""初始化 OpenAI 客户端
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
model: 模型名称
|
||||
base_url: API 基础 URL(可选,用于兼容其他服务)
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.model = model
|
||||
self.base_url = base_url
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key 未配置")
|
||||
|
||||
try:
|
||||
from openai import AsyncOpenAI
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("请安装 openai 库: pip install openai")
|
||||
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
**kwargs: 其他参数(temperature, max_tokens 等)
|
||||
|
||||
Returns:
|
||||
LLM 响应文本
|
||||
"""
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
max_tokens=kwargs.get("max_tokens", 500)
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI API 调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class AzureOpenAIClient(BaseLLMClient):
|
||||
"""Azure OpenAI 客户端"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
deployment_name: Optional[str] = None,
|
||||
api_version: str = "2024-02-15-preview"
|
||||
):
|
||||
"""初始化 Azure OpenAI 客户端
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
endpoint: Azure 端点
|
||||
deployment_name: 部署名称
|
||||
api_version: API 版本
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
|
||||
self.endpoint = endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
self.deployment_name = deployment_name or os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
||||
self.api_version = api_version
|
||||
|
||||
if not all([self.api_key, self.endpoint, self.deployment_name]):
|
||||
raise ValueError("Azure OpenAI 配置不完整")
|
||||
|
||||
try:
|
||||
from openai import AsyncAzureOpenAI
|
||||
self.client = AsyncAzureOpenAI(
|
||||
api_key=self.api_key,
|
||||
azure_endpoint=self.endpoint,
|
||||
api_version=self.api_version
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("请安装 openai 库: pip install openai")
|
||||
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求"""
|
||||
try:
|
||||
response = await self.client.chat.completions.create(
|
||||
model=self.deployment_name,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
max_tokens=kwargs.get("max_tokens", 500)
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Azure OpenAI API 调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class AnthropicClient(BaseLLMClient):
|
||||
"""Anthropic Claude 客户端"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "claude-3-sonnet-20240229"
|
||||
):
|
||||
"""初始化 Anthropic 客户端
|
||||
|
||||
Args:
|
||||
api_key: API 密钥
|
||||
model: 模型名称
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
self.model = model
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("Anthropic API key 未配置")
|
||||
|
||||
try:
|
||||
from anthropic import AsyncAnthropic
|
||||
self.client = AsyncAnthropic(api_key=self.api_key)
|
||||
except ImportError:
|
||||
raise ImportError("请安装 anthropic 库: pip install anthropic")
|
||||
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求"""
|
||||
try:
|
||||
response = await self.client.messages.create(
|
||||
model=self.model,
|
||||
max_tokens=kwargs.get("max_tokens", 500),
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
)
|
||||
|
||||
return response.content[0].text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Anthropic API 调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class LocalLLMClient(BaseLLMClient):
|
||||
"""本地 LLM 客户端(通过 HTTP API)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://localhost:8000",
|
||||
model: str = "local-model"
|
||||
):
|
||||
"""初始化本地 LLM 客户端
|
||||
|
||||
Args:
|
||||
base_url: API 基础 URL
|
||||
model: 模型名称
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.model = model
|
||||
|
||||
try:
|
||||
import httpx
|
||||
self.client = httpx.AsyncClient(timeout=30.0)
|
||||
except ImportError:
|
||||
raise ImportError("请安装 httpx 库: pip install httpx")
|
||||
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求"""
|
||||
try:
|
||||
response = await self.client.post(
|
||||
f"{self.base_url}/v1/chat/completions",
|
||||
json={
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": kwargs.get("temperature", 0.3),
|
||||
"max_tokens": kwargs.get("max_tokens", 500)
|
||||
}
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
return data["choices"][0]["message"]["content"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"本地 LLM API 调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class MockLLMClient(BaseLLMClient):
|
||||
"""模拟 LLM 客户端(用于测试)"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化模拟客户端"""
|
||||
self.call_count = 0
|
||||
|
||||
async def chat(self, prompt: str, **kwargs) -> str:
|
||||
"""发送聊天请求(返回模拟结果)"""
|
||||
self.call_count += 1
|
||||
|
||||
logger.info(f"模拟 LLM 调用 (第 {self.call_count} 次)")
|
||||
|
||||
# 简单的规则匹配
|
||||
prompt_lower = prompt.lower()
|
||||
|
||||
if "数学" in prompt_lower or "方程" in prompt_lower or "计算" in prompt_lower:
|
||||
return json.dumps({
|
||||
"agent_id": "math-agent",
|
||||
"confidence": 0.9,
|
||||
"reason": "消息包含数学相关内容"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif "化学" in prompt_lower or "反应" in prompt_lower or "元素" in prompt_lower:
|
||||
return json.dumps({
|
||||
"agent_id": "chemistry-agent",
|
||||
"confidence": 0.85,
|
||||
"reason": "消息包含化学相关内容"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif "物理" in prompt_lower or "力" in prompt_lower or "速度" in prompt_lower:
|
||||
return json.dumps({
|
||||
"agent_id": "physics-agent",
|
||||
"confidence": 0.88,
|
||||
"reason": "消息包含物理相关内容"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif "语文" in prompt_lower or "古诗" in prompt_lower or "作文" in prompt_lower:
|
||||
return json.dumps({
|
||||
"agent_id": "chinese-agent",
|
||||
"confidence": 0.87,
|
||||
"reason": "消息包含语文相关内容"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
elif "英语" in prompt_lower or "单词" in prompt_lower or "语法" in prompt_lower:
|
||||
return json.dumps({
|
||||
"agent_id": "english-agent",
|
||||
"confidence": 0.86,
|
||||
"reason": "消息包含英语相关内容"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
else:
|
||||
return json.dumps({
|
||||
"agent_id": "math-agent",
|
||||
"confidence": 0.5,
|
||||
"reason": "无法明确判断,使用默认 Agent"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
class LLMClientFactory:
|
||||
"""LLM 客户端工厂"""
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
provider: str = "mock",
|
||||
**kwargs
|
||||
) -> BaseLLMClient:
|
||||
"""创建 LLM 客户端
|
||||
|
||||
Args:
|
||||
provider: 提供商名称 (openai, azure, anthropic, local, mock)
|
||||
**kwargs: 客户端配置参数
|
||||
|
||||
Returns:
|
||||
LLM 客户端实例
|
||||
"""
|
||||
provider = provider.lower()
|
||||
|
||||
if provider == "openai":
|
||||
return OpenAIClient(**kwargs)
|
||||
|
||||
elif provider == "azure":
|
||||
return AzureOpenAIClient(**kwargs)
|
||||
|
||||
elif provider == "anthropic":
|
||||
return AnthropicClient(**kwargs)
|
||||
|
||||
elif provider == "local":
|
||||
return LocalLLMClient(**kwargs)
|
||||
|
||||
elif provider == "mock":
|
||||
return MockLLMClient()
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的 LLM 提供商: {provider}")
|
||||
|
||||
@staticmethod
|
||||
def create_from_env() -> BaseLLMClient:
|
||||
"""从环境变量创建 LLM 客户端
|
||||
|
||||
环境变量:
|
||||
- LLM_PROVIDER: 提供商名称
|
||||
- OPENAI_API_KEY: OpenAI API 密钥
|
||||
- AZURE_OPENAI_API_KEY: Azure OpenAI API 密钥
|
||||
- ANTHROPIC_API_KEY: Anthropic API 密钥
|
||||
|
||||
Returns:
|
||||
LLM 客户端实例
|
||||
"""
|
||||
provider = os.getenv("LLM_PROVIDER", "mock")
|
||||
|
||||
logger.info(f"从环境变量创建 LLM 客户端: {provider}")
|
||||
|
||||
return LLMClientFactory.create(provider)
|
||||
685
api/app/services/llm_router.py
Normal file
685
api/app/services/llm_router.py
Normal file
@@ -0,0 +1,685 @@
|
||||
"""基于 LLM 的智能路由器 - 混合策略"""
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.models import ModelConfig, AgentConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class LLMRouter:
|
||||
"""基于 LLM 的智能路由器
|
||||
|
||||
混合策略:
|
||||
1. 先用关键词快速筛选(置信度 > 0.8 直接返回)
|
||||
2. 对于模糊情况(置信度 0.3-0.8),调用 LLM 辅助
|
||||
3. 对于完全不匹配(置信度 < 0.3),调用 LLM
|
||||
4. 缓存 LLM 结果,减少重复调用
|
||||
"""
|
||||
|
||||
# 主题切换信号
|
||||
SWITCH_SIGNALS = [
|
||||
"换个话题", "另外", "还有", "对了",
|
||||
"那这个呢", "再问一个", "顺便问下",
|
||||
"我想问", "帮我", "请问", "换一个"
|
||||
]
|
||||
|
||||
# 延续信号
|
||||
CONTINUATION_SIGNALS = [
|
||||
"继续", "还是", "也", "同样", "类似",
|
||||
"这个", "那个", "它", "他", "她", "呢"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
state_manager: ConversationStateManager,
|
||||
routing_rules: List[Dict[str, Any]],
|
||||
sub_agents: Dict[str, Any],
|
||||
routing_model_config: Optional[ModelConfig] = None,
|
||||
use_llm: bool = True
|
||||
):
|
||||
"""初始化 LLM 路由器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
state_manager: 会话状态管理器
|
||||
routing_rules: 路由规则列表
|
||||
sub_agents: 子 Agent 配置字典
|
||||
routing_model_config: 用于路由的模型配置(可选)
|
||||
use_llm: 是否启用 LLM(默认 True)
|
||||
"""
|
||||
self.db = db
|
||||
self.state_manager = state_manager
|
||||
self.routing_rules = routing_rules
|
||||
self.sub_agents = sub_agents
|
||||
self.routing_model_config = routing_model_config
|
||||
self.use_llm = use_llm and routing_model_config is not None
|
||||
|
||||
# 配置参数
|
||||
self.min_confidence_for_switch = 0.7
|
||||
self.max_same_agent_turns = 10
|
||||
self.keyword_high_confidence_threshold = 0.8 # 关键词高置信度阈值
|
||||
self.keyword_low_confidence_threshold = 0.3 # 关键词低置信度阈值
|
||||
|
||||
# 缓存配置
|
||||
self.cache_enabled = True
|
||||
self.cache_size = 1000
|
||||
|
||||
async def route(
|
||||
self,
|
||||
message: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
force_new: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""智能路由(混合策略)
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
conversation_id: 会话 ID
|
||||
force_new: 是否强制重新路由
|
||||
|
||||
Returns:
|
||||
路由结果
|
||||
"""
|
||||
logger.info(
|
||||
f"开始 LLM 智能路由",
|
||||
extra={
|
||||
"message_length": len(message),
|
||||
"conversation_id": conversation_id,
|
||||
"use_llm": self.use_llm
|
||||
}
|
||||
)
|
||||
|
||||
# 1. 获取会话状态
|
||||
state = None
|
||||
if conversation_id and not force_new:
|
||||
state = self.state_manager.get_state(conversation_id)
|
||||
|
||||
# 2. 检测主题切换
|
||||
topic_changed = self._detect_topic_change(message, state)
|
||||
|
||||
# 3. 提取当前主题
|
||||
topic = await self._extract_topic_with_llm(message) if self.use_llm else self._extract_topic(message)
|
||||
|
||||
# 4. 选择路由策略
|
||||
if force_new:
|
||||
agent_id, confidence, method = await self._route_with_hybrid(message)
|
||||
strategy = "force_new"
|
||||
reason = "用户强制重新路由"
|
||||
|
||||
elif not state or not state.get("current_agent_id"):
|
||||
agent_id, confidence, method = await self._route_with_hybrid(message)
|
||||
strategy = "new_conversation"
|
||||
reason = "新会话,首次路由"
|
||||
|
||||
elif topic_changed:
|
||||
agent_id, confidence, method = await self._route_with_hybrid(message)
|
||||
strategy = "topic_changed"
|
||||
reason = f"检测到主题切换: {state.get('last_topic')} -> {topic}"
|
||||
|
||||
elif state.get("same_agent_turns", 0) >= self.max_same_agent_turns:
|
||||
agent_id, confidence, method = await self._route_with_hybrid(message)
|
||||
strategy = "max_turns_reached"
|
||||
reason = f"同一 Agent 已使用 {state['same_agent_turns']} 轮"
|
||||
|
||||
else:
|
||||
current_agent_id = state["current_agent_id"]
|
||||
should_continue, continue_confidence = self._should_continue_current_agent(
|
||||
message,
|
||||
current_agent_id
|
||||
)
|
||||
|
||||
if should_continue:
|
||||
agent_id = current_agent_id
|
||||
confidence = continue_confidence
|
||||
method = "keyword"
|
||||
strategy = "continue_current"
|
||||
reason = "消息在当前 Agent 能力范围内"
|
||||
else:
|
||||
new_agent_id, new_confidence, method = await self._route_with_hybrid(message)
|
||||
|
||||
if new_confidence > continue_confidence + self.min_confidence_for_switch:
|
||||
agent_id = new_agent_id
|
||||
confidence = new_confidence
|
||||
strategy = "switch_agent"
|
||||
reason = f"新 Agent 置信度更高: {new_confidence:.2f} vs {continue_confidence:.2f}"
|
||||
else:
|
||||
agent_id = current_agent_id
|
||||
confidence = continue_confidence
|
||||
method = "keyword"
|
||||
strategy = "keep_current"
|
||||
reason = "置信度差距不足以切换 Agent"
|
||||
|
||||
# 5. 更新会话状态
|
||||
if conversation_id:
|
||||
self.state_manager.update_state(
|
||||
conversation_id,
|
||||
agent_id,
|
||||
message,
|
||||
topic,
|
||||
confidence
|
||||
)
|
||||
|
||||
result = {
|
||||
"agent_id": agent_id,
|
||||
"confidence": confidence,
|
||||
"strategy": strategy,
|
||||
"topic": topic,
|
||||
"topic_changed": topic_changed,
|
||||
"reason": reason,
|
||||
"routing_method": method # "keyword", "llm", "hybrid"
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"路由完成",
|
||||
extra={
|
||||
"agent_id": agent_id,
|
||||
"strategy": strategy,
|
||||
"confidence": confidence,
|
||||
"method": method
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _route_with_hybrid(self, message: str) -> Tuple[str, float, str]:
|
||||
"""混合路由策略
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
(agent_id, confidence, method)
|
||||
"""
|
||||
# 1. 先用关键词匹配
|
||||
keyword_agent_id, keyword_confidence = self._route_with_keywords(message)
|
||||
|
||||
# 2. 判断是否需要 LLM
|
||||
if not self.use_llm or not self.routing_model_config:
|
||||
# 不使用 LLM,直接返回关键词结果
|
||||
return keyword_agent_id, keyword_confidence, "keyword"
|
||||
|
||||
if keyword_confidence >= self.keyword_high_confidence_threshold:
|
||||
# 关键词置信度很高,直接返回
|
||||
logger.info(f"关键词置信度高 ({keyword_confidence:.2f}),跳过 LLM")
|
||||
return keyword_agent_id, keyword_confidence, "keyword"
|
||||
|
||||
# 3. 使用 LLM 辅助决策
|
||||
logger.info(f"关键词置信度较低 ({keyword_confidence:.2f}),调用 LLM")
|
||||
llm_agent_id, llm_confidence = await self._route_with_llm(message)
|
||||
|
||||
# 4. 综合决策
|
||||
if llm_confidence > keyword_confidence:
|
||||
# LLM 置信度更高
|
||||
final_confidence = llm_confidence * 0.7 + keyword_confidence * 0.3
|
||||
return llm_agent_id, final_confidence, "llm"
|
||||
else:
|
||||
# 关键词置信度更高或相当
|
||||
final_confidence = keyword_confidence * 0.7 + llm_confidence * 0.3
|
||||
return keyword_agent_id, final_confidence, "hybrid"
|
||||
|
||||
def _route_with_keywords(self, message: str) -> Tuple[str, float]:
|
||||
"""基于关键词的路由
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
(agent_id, confidence)
|
||||
"""
|
||||
best_agent_id = None
|
||||
best_score = 0.0
|
||||
|
||||
for rule in self.routing_rules:
|
||||
score = self._calculate_rule_score(message, rule)
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_agent_id = rule.get("target_agent_id")
|
||||
|
||||
if not best_agent_id or best_score < 0.3:
|
||||
best_agent_id = self._get_default_agent_id()
|
||||
best_score = 0.5
|
||||
|
||||
return best_agent_id, best_score
|
||||
|
||||
async def _route_with_llm(self, message: str) -> Tuple[str, float]:
|
||||
"""基于 LLM 的路由
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
(agent_id, confidence)
|
||||
"""
|
||||
# 检查缓存
|
||||
if self.cache_enabled:
|
||||
cached_result = self._get_cached_llm_result(message)
|
||||
if cached_result:
|
||||
logger.info("使用缓存的 LLM 路由结果")
|
||||
return cached_result
|
||||
|
||||
# 构建 prompt
|
||||
prompt = self._build_routing_prompt(message)
|
||||
|
||||
try:
|
||||
# 调用 LLM
|
||||
response = await self._call_llm(prompt)
|
||||
|
||||
# 解析结果
|
||||
agent_id, confidence = self._parse_llm_response(response)
|
||||
|
||||
# 缓存结果
|
||||
if self.cache_enabled:
|
||||
self._cache_llm_result(message, agent_id, confidence)
|
||||
|
||||
return agent_id, confidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 路由失败: {str(e)}")
|
||||
# 降级到关键词路由
|
||||
return self._route_with_keywords(message)
|
||||
|
||||
def _build_routing_prompt(self, message: str) -> str:
|
||||
"""构建 LLM 路由 prompt
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
prompt 字符串
|
||||
"""
|
||||
# 构建 Agent 描述
|
||||
agent_descriptions = []
|
||||
for agent_id, agent_data in self.sub_agents.items():
|
||||
# 获取 Agent 信息
|
||||
agent_info = agent_data.get("info", {})
|
||||
agent_config = agent_data.get("config")
|
||||
|
||||
# 查找该 Agent 的路由规则
|
||||
rules = [r for r in self.routing_rules if r.get("target_agent_id") == agent_id]
|
||||
|
||||
# 构建描述
|
||||
name = agent_info.get("name", "未命名 Agent")
|
||||
role = agent_info.get("role", "")
|
||||
capabilities = agent_info.get("capabilities", [])
|
||||
|
||||
desc_parts = [f"- agent_id: {agent_id}", f" 名称: {name}"]
|
||||
|
||||
if role:
|
||||
desc_parts.append(f" 角色: {role}")
|
||||
|
||||
# 从路由规则获取关键词
|
||||
if rules:
|
||||
rule = rules[0]
|
||||
keywords = rule.get("keywords", [])
|
||||
if keywords:
|
||||
desc_parts.append(f" 关键词: {', '.join(keywords[:5])}")
|
||||
|
||||
# 从 Agent 信息获取能力
|
||||
if capabilities:
|
||||
desc_parts.append(f" 擅长: {', '.join(capabilities[:5])}")
|
||||
|
||||
agent_descriptions.append("\n".join(desc_parts))
|
||||
|
||||
agents_text = "\n\n".join(agent_descriptions)
|
||||
|
||||
# 如果没有 Agent 描述,添加警告
|
||||
if not agents_text:
|
||||
agents_text = "(警告:没有可用的 Agent 信息)"
|
||||
|
||||
# 提取所有可用的 agent_id
|
||||
available_agent_ids = list(self.sub_agents.keys())
|
||||
agent_ids_text = ", ".join(available_agent_ids)
|
||||
|
||||
prompt = f"""你是一个智能路由助手,需要根据用户的消息,选择最合适的 Agent 来处理。
|
||||
|
||||
可用的 Agent:
|
||||
{agents_text}
|
||||
|
||||
用户消息:"{message}"
|
||||
|
||||
**重要**:你必须从以下 agent_id 中选择一个:{agent_ids_text}
|
||||
|
||||
请分析这条消息,选择最合适的 Agent。
|
||||
|
||||
要求:
|
||||
1. 仔细理解消息的意图和主题
|
||||
2. 从上面列出的 agent_id 中选择最匹配的一个
|
||||
3. 给出置信度(0-1 之间的小数)
|
||||
4. agent_id 必须是上面列出的其中一个,不能自己编造
|
||||
|
||||
请以 JSON 格式返回:
|
||||
{{
|
||||
"agent_id": "从上面列表中选择的 agent_id",
|
||||
"confidence": 0.95,
|
||||
"reason": "选择理由"
|
||||
}}
|
||||
"""
|
||||
return prompt
|
||||
|
||||
async def _call_llm(self, prompt: str) -> str:
|
||||
"""调用 LLM API(使用系统的 RedBearLLM)
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
|
||||
Returns:
|
||||
LLM 响应
|
||||
"""
|
||||
if not self.routing_model_config:
|
||||
raise Exception("路由模型配置未设置")
|
||||
|
||||
try:
|
||||
# 使用系统的 RedBearLLM 来调用模型
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.models import ModelApiKey, ModelType
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == self.routing_model_config.id,
|
||||
ModelApiKey.is_active == True
|
||||
).first()
|
||||
|
||||
if not api_key_config:
|
||||
raise Exception("路由模型没有可用的 API Key")
|
||||
|
||||
# 打印供应商信息
|
||||
logger.info(
|
||||
f"LLM 路由使用模型",
|
||||
extra={
|
||||
"provider": api_key_config.provider,
|
||||
"model_name": api_key_config.model_name,
|
||||
"api_base": api_key_config.api_base,
|
||||
"model_config_id": str(self.routing_model_config.id)
|
||||
}
|
||||
)
|
||||
|
||||
# 创建 RedBearModelConfig
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=api_key_config.model_name,
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
temperature=0.3,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例 - Provider: {api_key_config.provider}, Model: {api_key_config.model_name}")
|
||||
|
||||
# 创建 LLM 实例
|
||||
llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
||||
|
||||
# 调用模型
|
||||
response = await llm.ainvoke(prompt)
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
return response.content
|
||||
else:
|
||||
return str(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 路由调用失败: {str(e)}")
|
||||
# 降级到关键词路由
|
||||
raise
|
||||
|
||||
|
||||
|
||||
def _parse_llm_response(self, response: str) -> Tuple[str, float]:
|
||||
"""解析 LLM 响应
|
||||
|
||||
Args:
|
||||
response: LLM 响应文本
|
||||
|
||||
Returns:
|
||||
(agent_id, confidence)
|
||||
"""
|
||||
try:
|
||||
# 提取 JSON
|
||||
json_match = re.search(r'\{[^}]+\}', response)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
agent_id = result.get("agent_id")
|
||||
confidence = float(result.get("confidence", 0.5))
|
||||
|
||||
# 验证 agent_id 是否有效
|
||||
if agent_id not in self.sub_agents:
|
||||
logger.warning(f"LLM 返回的 agent_id 无效: {agent_id}")
|
||||
agent_id = self._get_default_agent_id()
|
||||
confidence = 0.5
|
||||
|
||||
return agent_id, confidence
|
||||
else:
|
||||
raise ValueError("无法从响应中提取 JSON")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析 LLM 响应失败: {str(e)}")
|
||||
return self._get_default_agent_id(), 0.5
|
||||
|
||||
def _get_cached_llm_result(self, message: str) -> Optional[Tuple[str, float]]:
|
||||
"""获取缓存的 LLM 结果
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
缓存的结果或 None
|
||||
"""
|
||||
# TODO: 实现真正的缓存机制(使用 Redis 或内存字典)
|
||||
return None
|
||||
|
||||
def _cache_llm_result(self, message: str, agent_id: str, confidence: float):
|
||||
"""缓存 LLM 结果
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
agent_id: Agent ID
|
||||
confidence: 置信度
|
||||
"""
|
||||
# lru_cache 会自动处理缓存
|
||||
pass
|
||||
|
||||
async def _extract_topic_with_llm(self, message: str) -> str:
|
||||
"""使用 LLM 提取主题
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
主题名称
|
||||
"""
|
||||
if not self.routing_model_config:
|
||||
return self._extract_topic(message)
|
||||
|
||||
prompt = f"""请分析以下消息的主题,从这些选项中选择一个:
|
||||
数学、物理、化学、语文、英语、历史、作业、学习规划、订单、退款、账户、支付、其他
|
||||
|
||||
消息:"{message}"
|
||||
|
||||
只返回主题名称,不要其他内容。
|
||||
"""
|
||||
|
||||
try:
|
||||
response = await self._call_llm(prompt)
|
||||
topic = response.strip()
|
||||
|
||||
# 验证主题
|
||||
valid_topics = [
|
||||
"数学", "物理", "化学", "语文", "英语", "历史",
|
||||
"作业", "学习规划", "订单", "退款", "账户", "支付", "其他"
|
||||
]
|
||||
|
||||
if topic in valid_topics:
|
||||
return topic
|
||||
else:
|
||||
return self._extract_topic(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM 提取主题失败: {str(e)}")
|
||||
return self._extract_topic(message)
|
||||
|
||||
# 以下方法与 SmartRouter 相同
|
||||
|
||||
def _detect_topic_change(
|
||||
self,
|
||||
message: str,
|
||||
state: Optional[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""检测主题是否切换"""
|
||||
if not state or not state.get("last_topic"):
|
||||
return False
|
||||
|
||||
for signal in self.SWITCH_SIGNALS:
|
||||
if signal in message:
|
||||
logger.info(f"检测到主题切换信号: {signal}")
|
||||
return True
|
||||
|
||||
current_topic = self._extract_topic(message)
|
||||
last_topic = state.get("last_topic")
|
||||
|
||||
if current_topic != last_topic and current_topic != "其他":
|
||||
logger.info(f"主题变化: {last_topic} -> {current_topic}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _should_continue_current_agent(
|
||||
self,
|
||||
message: str,
|
||||
current_agent_id: str
|
||||
) -> Tuple[bool, float]:
|
||||
"""判断是否应该继续使用当前 Agent"""
|
||||
has_continuation_signal = any(
|
||||
signal in message
|
||||
for signal in self.CONTINUATION_SIGNALS
|
||||
)
|
||||
|
||||
current_score = self._calculate_agent_score(message, current_agent_id)
|
||||
|
||||
if has_continuation_signal and current_score > 0.3:
|
||||
return True, min(current_score + 0.2, 1.0)
|
||||
|
||||
if current_score > 0.6:
|
||||
return True, current_score
|
||||
|
||||
return False, current_score
|
||||
|
||||
def _calculate_rule_score(
|
||||
self,
|
||||
message: str,
|
||||
rule: Dict[str, Any]
|
||||
) -> float:
|
||||
"""计算规则匹配分数"""
|
||||
score = 0.0
|
||||
message_lower = message.lower()
|
||||
|
||||
keywords = rule.get("keywords", [])
|
||||
if keywords:
|
||||
matched_keywords = sum(
|
||||
1 for keyword in keywords
|
||||
if keyword.lower() in message_lower
|
||||
)
|
||||
keyword_score = matched_keywords / len(keywords)
|
||||
score += keyword_score * 0.6
|
||||
|
||||
patterns = rule.get("patterns", [])
|
||||
if patterns:
|
||||
matched_patterns = sum(
|
||||
1 for pattern in patterns
|
||||
if re.search(pattern, message, re.IGNORECASE)
|
||||
)
|
||||
pattern_score = matched_patterns / len(patterns)
|
||||
score += pattern_score * 0.3
|
||||
|
||||
exclude_keywords = rule.get("exclude_keywords", [])
|
||||
if exclude_keywords:
|
||||
has_exclude = any(
|
||||
keyword.lower() in message_lower
|
||||
for keyword in exclude_keywords
|
||||
)
|
||||
if has_exclude:
|
||||
score *= 0.5
|
||||
|
||||
min_keyword_count = rule.get("min_keyword_count", 0)
|
||||
if keywords and min_keyword_count > 0:
|
||||
matched_count = sum(
|
||||
1 for keyword in keywords
|
||||
if keyword.lower() in message_lower
|
||||
)
|
||||
if matched_count < min_keyword_count:
|
||||
score *= 0.7
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def _calculate_agent_score(
|
||||
self,
|
||||
message: str,
|
||||
agent_id: str
|
||||
) -> float:
|
||||
"""计算 Agent 对消息的匹配分数"""
|
||||
agent_rules = [
|
||||
rule for rule in self.routing_rules
|
||||
if rule.get("target_agent_id") == agent_id
|
||||
]
|
||||
|
||||
if not agent_rules:
|
||||
return 0.0
|
||||
|
||||
max_score = max(
|
||||
self._calculate_rule_score(message, rule)
|
||||
for rule in agent_rules
|
||||
)
|
||||
|
||||
return max_score
|
||||
|
||||
def _extract_topic(self, message: str) -> str:
|
||||
"""提取消息主题(关键词方式)"""
|
||||
topic_keywords = {
|
||||
"数学": ["数学", "方程", "计算", "求解", "x", "y", "函数", "几何"],
|
||||
"物理": ["物理", "力", "速度", "加速度", "能量", "功率", "电路"],
|
||||
"化学": ["化学", "方程式", "反应", "元素", "分子", "原子", "化合物"],
|
||||
"语文": ["语文", "古诗", "作文", "阅读", "文言文", "诗词"],
|
||||
"英语": ["英语", "单词", "语法", "翻译", "时态", "句型"],
|
||||
"历史": ["历史", "朝代", "事件", "人物", "战争", "革命"],
|
||||
"作业": ["作业", "批改", "检查", "评分", "反馈"],
|
||||
"学习规划": ["计划", "规划", "方法", "技巧", "时间", "安排"],
|
||||
"订单": ["订单", "发货", "物流", "配送", "快递"],
|
||||
"退款": ["退款", "退货", "售后", "换货", "维修"],
|
||||
"账户": ["账户", "密码", "登录", "注册", "绑定"],
|
||||
"支付": ["支付", "付款", "充值", "余额", "优惠券"]
|
||||
}
|
||||
|
||||
message_lower = message.lower()
|
||||
|
||||
topic_scores = {}
|
||||
for topic, keywords in topic_keywords.items():
|
||||
matched = sum(
|
||||
1 for keyword in keywords
|
||||
if keyword in message_lower
|
||||
)
|
||||
if matched > 0:
|
||||
topic_scores[topic] = matched
|
||||
|
||||
if topic_scores:
|
||||
best_topic = max(topic_scores.items(), key=lambda x: x[1])[0]
|
||||
return best_topic
|
||||
|
||||
return "其他"
|
||||
|
||||
def _get_default_agent_id(self) -> str:
|
||||
"""获取默认 Agent ID"""
|
||||
if self.routing_rules:
|
||||
return self.routing_rules[0].get("target_agent_id")
|
||||
|
||||
if self.sub_agents:
|
||||
return list(self.sub_agents.keys())[0]
|
||||
|
||||
return "default-agent"
|
||||
1035
api/app/services/memory_agent_service.py
Normal file
1035
api/app/services/memory_agent_service.py
Normal file
File diff suppressed because it is too large
Load Diff
595
api/app/services/memory_dashboard_service.py
Normal file
595
api/app/services/memory_dashboard_service.py
Normal file
@@ -0,0 +1,595 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
import uuid
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.memory_increment_model import MemoryIncrement
|
||||
|
||||
from app.repositories import (
|
||||
app_repository,
|
||||
end_user_repository,
|
||||
memory_increment_repository,
|
||||
knowledge_repository
|
||||
)
|
||||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||||
from app.schemas.memory_increment_schema import MemoryIncrement as MemoryIncrementSchema
|
||||
from app.schemas.app_schema import App as AppSchema
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
|
||||
# 获取业务逻辑专用日志器
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def get_workspace_end_users(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> List[EndUser]:
|
||||
"""获取工作空间的所有宿主"""
|
||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 查询应用(ORM)并转换为 Pydantic 模型
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
apps = [AppSchema.model_validate(h) for h in apps_orm]
|
||||
app_ids = [app.id for app in apps]
|
||||
end_users = []
|
||||
for app_id in app_ids:
|
||||
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
|
||||
end_users.extend([EndUserSchema.model_validate(h) for h in end_user_orm_list])
|
||||
|
||||
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return end_users
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取工作空间宿主列表失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_memory_increment(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
limit: int,
|
||||
current_user: User
|
||||
) -> List[MemoryIncrementSchema]:
|
||||
"""获取工作空间的记忆增量"""
|
||||
business_logger.info(f"获取工作空间记忆增量: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 查询记忆增量
|
||||
memory_increment_orm_list = memory_increment_repository.get_memory_increments_by_workspace_id(db, workspace_id, limit)
|
||||
memory_increment = [MemoryIncrementSchema.model_validate(m) for m in memory_increment_orm_list]
|
||||
|
||||
business_logger.info(f"成功获取 {len(memory_increment)} 条记忆增量记录")
|
||||
return memory_increment
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取工作空间记忆增量失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_api_increment(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> int:
|
||||
"""获取工作空间的API调用增量"""
|
||||
business_logger.info(f"获取工作空间API调用增量: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 查询API调用增量
|
||||
api_increment = 856
|
||||
|
||||
business_logger.info(f"成功获取 {api_increment} API调用增量")
|
||||
return api_increment
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取工作空间API调用增量失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def write_workspace_total_memory(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> int:
|
||||
"""写入工作空间的记忆总量"""
|
||||
business_logger.info(f"写入工作空间记忆总量: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 模拟记忆总量
|
||||
total_num = 1024
|
||||
|
||||
# 写入记忆总量
|
||||
memory_increment_repository.write_memory_increment(db, workspace_id, total_num)
|
||||
|
||||
business_logger.info(f"成功写入记忆总量 {total_num}")
|
||||
return total_num
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"写入工作空间记忆总量失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_memory_list(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User,
|
||||
limit: int = 7
|
||||
) -> dict:
|
||||
"""
|
||||
获取工作空间的记忆列表(整合接口)
|
||||
|
||||
整合以下三个接口的数据:
|
||||
1. total_memory - 工作空间记忆总量
|
||||
2. memory_increment - 工作空间记忆增量
|
||||
3. hosts - 工作空间宿主列表
|
||||
"""
|
||||
business_logger.info(f"获取工作空间记忆列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
result = {}
|
||||
|
||||
try:
|
||||
# 1. 获取记忆总量
|
||||
try:
|
||||
total_memory = write_workspace_total_memory(db, workspace_id, current_user)
|
||||
result["total_memory"] = total_memory
|
||||
business_logger.info(f"成功获取记忆总量: {total_memory}")
|
||||
except Exception as e:
|
||||
business_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||
result["total_memory"] = 0.0
|
||||
|
||||
# 2. 获取记忆增量
|
||||
try:
|
||||
memory_increment = get_workspace_memory_increment(db, workspace_id, limit, current_user)
|
||||
result["memory_increment"] = memory_increment
|
||||
business_logger.info(f"成功获取 {len(memory_increment)} 条记忆增量记录")
|
||||
except Exception as e:
|
||||
business_logger.warning(f"获取记忆增量失败: {str(e)}")
|
||||
result["memory_increment"] = []
|
||||
|
||||
# 3. 获取宿主列表
|
||||
try:
|
||||
hosts = get_workspace_end_users(db, workspace_id, current_user)
|
||||
result["hosts"] = hosts
|
||||
business_logger.info(f"成功获取 {len(hosts)} 个宿主记录")
|
||||
except Exception as e:
|
||||
business_logger.warning(f"获取宿主列表失败: {str(e)}")
|
||||
result["hosts"] = []
|
||||
|
||||
business_logger.info(f"成功获取工作空间记忆列表")
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取工作空间记忆列表失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_total_end_users(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> dict:
|
||||
"""
|
||||
获取用户列表的总用户数
|
||||
"""
|
||||
business_logger.info(f"获取用户列表的总用户数: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 复用原有的 get_workspace_end_users 逻辑
|
||||
end_users = get_workspace_end_users(db, workspace_id, current_user)
|
||||
|
||||
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return {
|
||||
"total_num": len(end_users),
|
||||
"online_num": len(end_users)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取用户列表失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_workspace_total_memory_count(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User,
|
||||
end_user_id: str = None
|
||||
) -> dict:
|
||||
"""
|
||||
获取工作空间的记忆总量(通过聚合所有host的记忆数)
|
||||
|
||||
逻辑:
|
||||
1. 从 memory_list 获取所有 host_id
|
||||
2. 对每个 host_id 调用 search_all 获取 total
|
||||
3. 将所有 total 求和返回
|
||||
"""
|
||||
business_logger.info(f"获取工作空间记忆总量: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. 获取所有 hosts
|
||||
hosts = get_workspace_end_users(db, workspace_id, current_user)
|
||||
business_logger.info(f"获取到 {len(hosts)} 个宿主")
|
||||
|
||||
if not hosts:
|
||||
business_logger.warning("未找到任何宿主,返回0")
|
||||
return {
|
||||
"total_memory_count": 0,
|
||||
"host_count": 0,
|
||||
"details": []
|
||||
}
|
||||
|
||||
# 2. 对每个 host_id 调用 search_all 获取 total
|
||||
from app.services import memory_storage_service
|
||||
|
||||
total_count = 0
|
||||
details = []
|
||||
|
||||
# 如果提供了 end_user_id,只查询该用户
|
||||
if end_user_id:
|
||||
search_result = await memory_storage_service.search_all(end_user_id=end_user_id)
|
||||
return {
|
||||
"total_memory_count": search_result.get("total", 0),
|
||||
"host_count": 1,
|
||||
"details": [{"end_user_id": end_user_id, "count": search_result.get("total", 0)}]
|
||||
}
|
||||
|
||||
for host in hosts:
|
||||
try:
|
||||
end_user_id_str = str(host.id)
|
||||
|
||||
search_result = await memory_storage_service.search_all(
|
||||
end_user_id=end_user_id_str
|
||||
)
|
||||
|
||||
host_total = search_result.get("total", 0)
|
||||
total_count += host_total
|
||||
|
||||
details.append({
|
||||
"end_user_id": end_user_id_str,
|
||||
"count": host_total
|
||||
})
|
||||
|
||||
business_logger.debug(f"EndUser {end_user_id_str} 记忆数: {host_total}")
|
||||
|
||||
except Exception as e:
|
||||
business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}")
|
||||
# 失败的 host 记为 0
|
||||
details.append({
|
||||
"end_user_id": str(host.id),
|
||||
"count": 0
|
||||
})
|
||||
|
||||
result = {
|
||||
"total_memory_count": total_count,
|
||||
"host_count": len(hosts),
|
||||
"details": details
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取工作空间记忆总量: {total_count} (来自 {len(hosts)} 个宿主)")
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取工作空间记忆总量失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
# ======== RAG 相关服务 ========
|
||||
def get_rag_total_doc(
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> int:
|
||||
"""
|
||||
根据当前用户所在的workspace_id查询konwledges表所有doc_num的总和
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
business_logger.info(f"获取RAG总文档数: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
total_doc = knowledge_repository.get_total_doc_num_by_workspace(db, workspace_id)
|
||||
business_logger.info(f"成功获取RAG总文档数: {total_doc}")
|
||||
return total_doc
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取RAG总文档数失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_rag_total_chunk(
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> int:
|
||||
"""
|
||||
根据当前用户所在的workspace_id查询konwledges表所有chunk_num的总和
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
business_logger.info(f"获取RAG总chunk数: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
total_chunk = knowledge_repository.get_total_chunk_num_by_workspace(db, workspace_id)
|
||||
business_logger.info(f"成功获取RAG总chunk数: {total_chunk}")
|
||||
return total_chunk
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取RAG总chunk数失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_rag_total_kb(
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> int:
|
||||
"""
|
||||
根据当前用户所在的workspace_id查询konwledges表所有不同id的数量
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id)
|
||||
business_logger.info(f"成功获取RAG总知识库数: {total_kb}")
|
||||
return total_kb
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_current_user_total_chunk(
|
||||
end_user_id: str,
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> int:
|
||||
"""
|
||||
计算documents表中file_name=='end_user_id'+'.txt'的所有记录chunk_num的总和
|
||||
"""
|
||||
business_logger.info(f"获取用户总chunk数: end_user_id={end_user_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
from app.models.document_model import Document
|
||||
from sqlalchemy import func
|
||||
|
||||
# 构造文件名
|
||||
file_name = f"{end_user_id}.txt"
|
||||
|
||||
# 查询并求和
|
||||
total_chunk = db.query(func.sum(Document.chunk_num)).filter(
|
||||
Document.file_name == file_name
|
||||
).scalar() or 0
|
||||
|
||||
business_logger.info(f"成功获取用户总chunk数: {total_chunk} (file_name={file_name})")
|
||||
return int(total_chunk)
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取用户总chunk数失败: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_rag_content(
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> dict:
|
||||
"""
|
||||
先在documents表中查询file_name=='end_user_id'+'.txt'的id和kb_id,
|
||||
然后调用/chunks/{kb_id}/{document_id}/chunks接口的相关代码获取所有内容,
|
||||
接着对获取的内容进行提取,只要page_content的内容,
|
||||
最后返回数据
|
||||
"""
|
||||
business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
from app.models.document_model import Document
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
|
||||
# 1. 构造文件名
|
||||
file_name = f"{end_user_id}.txt"
|
||||
|
||||
# 2. 查询documents表获取id和kb_id
|
||||
documents = db.query(Document).filter(
|
||||
Document.file_name == file_name
|
||||
).all()
|
||||
|
||||
if not documents:
|
||||
business_logger.warning(f"未找到文件: {file_name}")
|
||||
return {
|
||||
"total": 0,
|
||||
"contents": []
|
||||
}
|
||||
|
||||
business_logger.info(f"找到 {len(documents)} 个文档记录")
|
||||
|
||||
# 3. 获取所有chunks的page_content
|
||||
all_contents = []
|
||||
total_chunks = 0
|
||||
|
||||
for document in documents:
|
||||
try:
|
||||
# 获取知识库信息
|
||||
kb = knowledge_repository.get_knowledge_by_id(db, document.kb_id)
|
||||
if not kb:
|
||||
business_logger.warning(f"知识库不存在: kb_id={document.kb_id}")
|
||||
continue
|
||||
|
||||
# 初始化向量服务
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=kb)
|
||||
|
||||
# 获取该文档的所有chunks(分页获取)
|
||||
page = 1
|
||||
pagesize = 100 # 每页100条
|
||||
|
||||
while True:
|
||||
total, items = vector_service.search_by_segment(
|
||||
document_id=str(document.id),
|
||||
query=None,
|
||||
pagesize=pagesize,
|
||||
page=page,
|
||||
asc=True
|
||||
)
|
||||
|
||||
if not items:
|
||||
break
|
||||
|
||||
# 提取page_content
|
||||
for item in items:
|
||||
all_contents.append(item.page_content)
|
||||
total_chunks += 1
|
||||
|
||||
# # 如果达到limit限制,直接返回
|
||||
# if limit > 0 and total_chunks >= limit:
|
||||
# business_logger.info(f"已达到limit限制: {limit}")
|
||||
# return {
|
||||
# "total": total_chunks,
|
||||
# "contents": all_contents[:limit]
|
||||
# }
|
||||
|
||||
# 检查是否还有下一页
|
||||
if page * pagesize >= total:
|
||||
break
|
||||
|
||||
page += 1
|
||||
|
||||
business_logger.info(f"文档 {document.id} 获取了 {len(items)} 个chunks")
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 4. 返回结果
|
||||
result = {
|
||||
"total": total_chunks,
|
||||
"contents": all_contents[:limit] if limit > 0 else all_contents
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取RAG内容: total={total_chunks}, 返回={len(result['contents'])} 条")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取RAG内容失败: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_chunk_summary_and_tags(
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
max_tags: int,
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> dict:
|
||||
"""
|
||||
获取chunk的总结、标签和人物形象
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID
|
||||
limit: 返回的chunk数量限制
|
||||
max_tags: 最大标签数量
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
包含summary、tags和personas的字典
|
||||
"""
|
||||
business_logger.info(f"获取chunk摘要、标签和人物形象: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. 获取chunk内容
|
||||
rag_content = get_rag_content(end_user_id, limit, db, current_user)
|
||||
chunks = rag_content.get("contents", [])
|
||||
|
||||
if not chunks:
|
||||
business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}")
|
||||
return {
|
||||
"summary": "暂无内容",
|
||||
"tags": [],
|
||||
"personas": []
|
||||
}
|
||||
|
||||
# 2. 导入RAG工具函数
|
||||
from app.core.rag_utils import generate_chunk_summary, extract_chunk_tags, extract_chunk_persona
|
||||
|
||||
# 3. 并发生成摘要、提取标签和人物形象
|
||||
import asyncio
|
||||
summary_task = generate_chunk_summary(chunks, max_chunks=limit)
|
||||
tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit)
|
||||
personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit)
|
||||
|
||||
summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task)
|
||||
|
||||
# 4. 格式化标签数据
|
||||
tags = [{"tag": tag, "frequency": freq} for tag, freq in tags_with_freq]
|
||||
|
||||
result = {
|
||||
"summary": summary,
|
||||
"tags": tags,
|
||||
"personas": personas
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取chunk摘要、{len(tags)} 个标签和 {len(personas)} 个人物形象")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取chunk摘要、标签和人物形象失败: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def get_chunk_insight(
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> dict:
|
||||
"""
|
||||
获取chunk的洞察分析
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID
|
||||
limit: 返回的chunk数量限制
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
包含insight的字典
|
||||
"""
|
||||
business_logger.info(f"获取chunk洞察: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. 获取chunk内容
|
||||
rag_content = get_rag_content(end_user_id, limit, db, current_user)
|
||||
chunks = rag_content.get("contents", [])
|
||||
|
||||
if not chunks:
|
||||
business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}")
|
||||
return {
|
||||
"insight": "暂无足够数据生成洞察报告"
|
||||
}
|
||||
|
||||
# 2. 导入RAG工具函数
|
||||
from app.core.rag_utils import generate_chunk_insight
|
||||
|
||||
# 3. 生成洞察
|
||||
insight = await generate_chunk_insight(chunks, max_chunks=limit)
|
||||
|
||||
result = {
|
||||
"insight": insight
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取chunk洞察")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取chunk洞察失败: end_user_id={end_user_id} - {str(e)}")
|
||||
raise
|
||||
582
api/app/services/memory_konwledges_server.py
Normal file
582
api/app/services/memory_konwledges_server.py
Normal file
@@ -0,0 +1,582 @@
|
||||
# 修改 memory_konwledges_server.py 文件
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas import file_schema, document_schema
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
|
||||
from app.models.document_model import Document
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.user_model import User
|
||||
from app.schemas.file_schema import CustomTextFileCreate
|
||||
from app.services import document_service, file_service, knowledge_service
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.schemas.file_schema import CustomTextFileCreate
|
||||
from app.db import get_db
|
||||
# 创建一个简单的用户类用于测试
|
||||
api_logger = get_api_logger()
|
||||
|
||||
class ChunkCreate(BaseModel):
|
||||
content: str
|
||||
class SimpleUser:
|
||||
def __init__(self, user_id: str):
|
||||
# 确保ID是UUID类型
|
||||
self.id = user_id
|
||||
self.username = user_id
|
||||
|
||||
'''解析'''
|
||||
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
|
||||
"""
|
||||
解析指定文档
|
||||
|
||||
Args:
|
||||
document_id: 文档ID
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
dict: 包含任务ID的结果字典
|
||||
|
||||
Raises:
|
||||
HTTPException: 当文档、文件或知识库不存在时抛出异常
|
||||
"""
|
||||
|
||||
try:
|
||||
# 1. 检查文档是否存在
|
||||
api_logger.debug(f"检查文档是否存在: {document_id}")
|
||||
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
|
||||
|
||||
if not db_document:
|
||||
api_logger.warning(f"文档不存在或无访问权限: document_id={document_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文档不存在或无访问权限"
|
||||
)
|
||||
|
||||
# 2. 检查文件是否存在
|
||||
api_logger.debug(f"检查文件是否存在: {db_document.file_id}")
|
||||
db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
|
||||
|
||||
if not db_file:
|
||||
api_logger.warning(f"文件不存在或无访问权限: file_id={db_document.file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文件不存在或无访问权限"
|
||||
)
|
||||
|
||||
# 3. 构建文件路径:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
|
||||
file_path = os.path.join(
|
||||
settings.FILE_PATH,
|
||||
str(db_file.kb_id),
|
||||
str(db_file.parent_id),
|
||||
f"{db_file.id}{db_file.file_ext}"
|
||||
)
|
||||
|
||||
# 4. 检查文件是否存在于磁盘上
|
||||
if not os.path.exists(file_path):
|
||||
api_logger.warning(f"文件未找到(可能已被删除): file_path={file_path}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文件未找到(可能已被删除)"
|
||||
)
|
||||
|
||||
# 5. 获取知识库信息
|
||||
api_logger.info(f"获取知识库详情: knowledge_id={db_document.kb_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id,
|
||||
current_user=current_user)
|
||||
if not db_knowledge:
|
||||
api_logger.warning(f"知识库不存在或访问被拒绝: knowledge_id={db_document.kb_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="知识库不存在或访问被拒绝"
|
||||
)
|
||||
|
||||
# 6. 发送解析任务到Celery后台队列
|
||||
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
|
||||
|
||||
result = {
|
||||
"task_id": task.id
|
||||
}
|
||||
|
||||
api_logger.info(f"文档解析任务已接受: document_id={document_id}, task_id={task.id}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
'''获取块ID'''
|
||||
async def get_document_chunks(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
pagesize: int = 20,
|
||||
keywords: Optional[str] = None,
|
||||
db: Session = None,
|
||||
current_user: User = None
|
||||
):
|
||||
"""
|
||||
分页查询文档块列表
|
||||
|
||||
Args:
|
||||
kb_id: 知识库ID
|
||||
document_id: 文档ID
|
||||
page: 页码,默认为1
|
||||
pagesize: 每页大小,默认为20
|
||||
keywords: 用于匹配块内容的关键字
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
dict: 包含分页数据的响应结果
|
||||
|
||||
Raises:
|
||||
HTTPException: 当知识库不存在或查询失败时抛出异常
|
||||
"""
|
||||
api_logger.info(
|
||||
f"分页查询文档块列表: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 参数验证
|
||||
if page < 1 or pagesize < 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="分页参数必须大于0"
|
||||
)
|
||||
|
||||
# 获取知识库信息
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="知识库不存在或访问被拒绝"
|
||||
)
|
||||
|
||||
# 执行分页查询
|
||||
try:
|
||||
api_logger.debug(f"开始执行文档块查询")
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
total, items = vector_service.search_by_segment(
|
||||
document_id=str(document_id),
|
||||
query=keywords,
|
||||
pagesize=pagesize,
|
||||
page=page,
|
||||
asc=True
|
||||
)
|
||||
api_logger.info(f"文档块查询成功: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
api_logger.error(f"文档块查询失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"查询失败: {str(e)}"
|
||||
)
|
||||
|
||||
# 构造响应结果
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
|
||||
return success(data=result, msg="文档块列表查询成功")
|
||||
|
||||
'''查找文档ID'''
|
||||
def find_document_id_by_kb_and_filename(
|
||||
db: Session,
|
||||
kb_id: str,
|
||||
file_name: str
|
||||
) -> str | None:
|
||||
"""
|
||||
通过 kb_id 和 file_name 在 documents 表中查找对应的 ID
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
kb_id: 知识库ID
|
||||
file_name: 文件名
|
||||
|
||||
Returns:
|
||||
str | None: 找到的 document ID,未找到返回 None
|
||||
"""
|
||||
try:
|
||||
# 查询 documents 表
|
||||
document = db.query(Document).filter(
|
||||
Document.kb_id == kb_id,
|
||||
Document.file_name == file_name
|
||||
).first()
|
||||
|
||||
if document:
|
||||
print(f"找到文档: ID={document.id}, kb_id={kb_id}, file_name={file_name}")
|
||||
return str(document.id)
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
'''获取知识库ID'''
|
||||
def find_documents_by_kb_id(
|
||||
db: Session,
|
||||
kb_id: str,
|
||||
limit: int = 10
|
||||
) -> list[dict]:
|
||||
"""
|
||||
通过 kb_id 查找所有相关文档
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
kb_id: 知识库ID
|
||||
limit: 返回结果数量限制
|
||||
|
||||
Returns:
|
||||
list[dict]: 文档列表,包含 id, file_name, created_at 等信息
|
||||
"""
|
||||
try:
|
||||
documents = db.query(Document).filter(
|
||||
Document.kb_id == kb_id
|
||||
).limit(limit).all()
|
||||
|
||||
result = []
|
||||
for doc in documents:
|
||||
result.append({
|
||||
"id": str(doc.id),
|
||||
"file_name": doc.file_name,
|
||||
"file_ext": doc.file_ext,
|
||||
"file_size": doc.file_size,
|
||||
"created_at": doc.created_at.isoformat() if doc.created_at else None,
|
||||
"status": getattr(doc, 'status', None)
|
||||
})
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return []
|
||||
|
||||
''''上传文件'''
|
||||
async def memory_konwledges_up(
|
||||
kb_id: str,
|
||||
parent_id: str,
|
||||
create_data: file_schema.CustomTextFileCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: SimpleUser = None, # 修改为SimpleUser
|
||||
):
|
||||
# 如果没有提供current_user,则创建一个默认的
|
||||
if current_user is None:
|
||||
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
|
||||
|
||||
content_bytes = create_data.content.encode('utf-8')
|
||||
file_size = len(content_bytes)
|
||||
print(f"file size: {file_size} byte")
|
||||
|
||||
if file_size == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The content is empty."
|
||||
)
|
||||
|
||||
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
|
||||
)
|
||||
|
||||
upload_file = file_schema.FileCreate(
|
||||
kb_id=kb_id,
|
||||
created_by=current_user.id, # 现在是UUID类型
|
||||
parent_id=parent_id,
|
||||
file_name=f"{create_data.title}.txt",
|
||||
file_ext=".txt",
|
||||
file_size=file_size,
|
||||
)
|
||||
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
|
||||
|
||||
# Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension}
|
||||
# 使用 settings.FILE_PATH 确保与 parse_document_by_id 一致
|
||||
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
|
||||
|
||||
# 确保目录存在
|
||||
Path(save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
save_path = os.path.join(save_dir, f"{db_file.id}.txt")
|
||||
|
||||
# Save file
|
||||
with open(save_path, "wb") as f:
|
||||
f.write(content_bytes)
|
||||
|
||||
# Verify whether the file has been saved successfully
|
||||
if not os.path.exists(save_path):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="File save failed"
|
||||
)
|
||||
|
||||
# Create a document
|
||||
create_document_data = document_schema.DocumentCreate(
|
||||
kb_id=kb_id,
|
||||
created_by=current_user.id,
|
||||
file_id=db_file.id,
|
||||
file_name=db_file.file_name,
|
||||
file_ext=db_file.file_ext,
|
||||
file_size=db_file.file_size,
|
||||
file_meta={},
|
||||
parser_id="naive",
|
||||
parser_config={
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 128,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": "false"
|
||||
}
|
||||
)
|
||||
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
|
||||
|
||||
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
|
||||
|
||||
'''添加新块'''
|
||||
|
||||
|
||||
async def create_document_chunk(
|
||||
kb_id: uuid.UUID,
|
||||
document_id: uuid.UUID,
|
||||
create_data: ChunkCreate,
|
||||
db: Session,
|
||||
current_user: User
|
||||
):
|
||||
"""
|
||||
创建文档块
|
||||
|
||||
Args:
|
||||
kb_id: 知识库ID
|
||||
document_id: 文档ID
|
||||
create_data: 创建数据
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
dict: 包含创建的文档块信息的成功响应
|
||||
|
||||
Raises:
|
||||
HTTPException: 当知识库或文档不存在时抛出相应异常
|
||||
"""
|
||||
api_logger.info(
|
||||
f"创建文档块请求: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}")
|
||||
|
||||
# 1. 获取知识库信息
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="知识库不存在或访问被拒绝"
|
||||
)
|
||||
|
||||
# 2. 获取文档信息
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
if not db_document:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="文档不存在或您无访问权限"
|
||||
)
|
||||
|
||||
# 3. 初始化向量服务
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
# 4. 获取排序ID(处理索引不存在的情况)
|
||||
sort_id = 0
|
||||
try:
|
||||
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
|
||||
if items:
|
||||
sort_id = items[0].metadata["sort_id"]
|
||||
except Exception as e:
|
||||
# 如果索引不存在,从 0 开始
|
||||
error_msg = str(e)
|
||||
if "index_not_found_exception" in error_msg or "no such index" in error_msg:
|
||||
api_logger.warning(f"索引不存在,将从 sort_id=0 开始: {error_msg}")
|
||||
sort_id = 0
|
||||
else:
|
||||
# 其他错误则抛出
|
||||
api_logger.error(f"查询文档块失败: {error_msg}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"查询文档块失败: {error_msg}"
|
||||
)
|
||||
|
||||
sort_id = sort_id + 1
|
||||
|
||||
# 5. 创建文档块
|
||||
doc_id = uuid.uuid4().hex
|
||||
metadata = {
|
||||
"doc_id": doc_id,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(document_id),
|
||||
"knowledge_id": str(kb_id),
|
||||
"sort_id": sort_id,
|
||||
"status": 1,
|
||||
}
|
||||
chunk = DocumentChunk(page_content=create_data.content, metadata=metadata)
|
||||
|
||||
# 6. 存储向量化的文档块(这会自动创建索引如果不存在)
|
||||
try:
|
||||
vector_service.add_chunks([chunk])
|
||||
except Exception as e:
|
||||
api_logger.error(f"添加文档块到向量库失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"添加文档块到向量库失败: {str(e)}"
|
||||
)
|
||||
|
||||
# 7. 更新 chunk_num
|
||||
db_document.chunk_num += 1
|
||||
db.commit()
|
||||
|
||||
return success(data=chunk, msg="文档块创建成功")
|
||||
|
||||
async def write_rag(group_id, message, user_rag_memory_id):
|
||||
"""
|
||||
将消息写入 RAG 知识库
|
||||
|
||||
Args:
|
||||
group_id: 组ID,用作文件标题
|
||||
message: 消息内容
|
||||
user_rag_memory_id: 知识库ID(必须是有效的UUID)
|
||||
|
||||
Returns:
|
||||
写入结果
|
||||
|
||||
Raises:
|
||||
HTTPException: 当参数无效或操作失败时
|
||||
"""
|
||||
# 验证 user_rag_memory_id 是否为有效的 UUID
|
||||
if not user_rag_memory_id:
|
||||
api_logger.error("user_rag_memory_id 为空,无法执行 RAG 写入操作")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="知识库ID不能为空"
|
||||
)
|
||||
|
||||
try:
|
||||
# 尝试将字符串转换为 UUID 以验证格式
|
||||
kb_uuid = uuid.UUID(user_rag_memory_id)
|
||||
except (ValueError, AttributeError) as e:
|
||||
api_logger.error(f"user_rag_memory_id 不是有效的UUID: {user_rag_memory_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"知识库ID格式无效: {user_rag_memory_id}"
|
||||
)
|
||||
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
try:
|
||||
create_data = CustomTextFileCreate(title=group_id, content=message)
|
||||
current_user = SimpleUser(user_rag_memory_id)
|
||||
# 检查文档是否已存在
|
||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{group_id}.txt")
|
||||
print('======',document)
|
||||
api_logger.info(f"查找文档结果: document_id={document}")
|
||||
if document is not None:
|
||||
# 文档已存在,直接添加新块
|
||||
api_logger.info(f"文档已存在,添加新块: document_id={document}")
|
||||
|
||||
create_chunks = ChunkCreate(content=message)
|
||||
result = await create_document_chunk(
|
||||
kb_id=kb_uuid,
|
||||
document_id=uuid.UUID(document),
|
||||
create_data=create_chunks,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
return result
|
||||
else:
|
||||
# 文档不存在,创建新文档
|
||||
api_logger.info(f"文档不存在,创建新文档: group_id={group_id}")
|
||||
result = await memory_konwledges_up(
|
||||
kb_id=user_rag_memory_id,
|
||||
parent_id=user_rag_memory_id,
|
||||
create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
await parse_document_by_id(document, db=db, current_user=current_user)
|
||||
return result
|
||||
finally:
|
||||
# 确保数据库会话被关闭
|
||||
db.close()
|
||||
# 在异步环境中调用示例
|
||||
|
||||
|
||||
async def example_usage():
|
||||
|
||||
# 获取数据库会话
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
# 创建 CustomTextFileCreate 对象
|
||||
title = '2f6ff1eb-50c7-4765-8e89-e4566be19122'
|
||||
create_data = CustomTextFileCreate(
|
||||
title=title,
|
||||
content="莫扎特在巴黎经历母亲去世后返回萨尔茨堡,他随后创作的交响曲主题是否与格鲁克在维也纳推动的“改革歌剧”理念存在共通之处?贝多芬早年曾师从海顿,而海顿又受雇于埃斯特哈齐家族——这种师承体系如何影响了当时欧洲宫廷音乐的传承结构?斯卡拉歌剧院选择萨列里的歌剧作为开幕演出,是否与当时米兰政治环境和奥地利宫廷影响有关?"
|
||||
)
|
||||
|
||||
# 创建用户对象
|
||||
current_user = SimpleUser("6243c125-9420-402c-bbb5-d1977811ac96")
|
||||
|
||||
# 上传文件
|
||||
result = await memory_konwledges_up(
|
||||
kb_id="c71df60a-36a6-4759-a2ce-101e3087b401",
|
||||
parent_id="c71df60a-36a6-4759-a2ce-101e3087b401",
|
||||
create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
print(result)
|
||||
#找到document_id
|
||||
|
||||
# 使用刚创建的文档ID进行解析
|
||||
document = find_document_id_by_kb_and_filename(db=db, kb_id="c71df60a-36a6-4759-a2ce-101e3087b401", file_name=f"{title}.txt")
|
||||
print('====',document)
|
||||
res___=await parse_document_by_id(document, db=db, current_user=current_user)
|
||||
print(res___)
|
||||
|
||||
# result='e8cf9ace-d1a9-4af2-b0c4-3fc94f4f8042'
|
||||
# document_id='d22e8173-50d0-4e10-a7de-aa638ef893bc'
|
||||
#
|
||||
# '''更新块'''
|
||||
#
|
||||
# new_content = "这是新的 chunk 内容,用来覆盖原来的内容"
|
||||
# # 构造 ChunkUpdate 对象
|
||||
# update_data = ChunkCreate(content=new_content)
|
||||
# updated_chunk = await create_document_chunk(
|
||||
# kb_id= result,
|
||||
# document_id=document_id,
|
||||
# create_data= update_data,
|
||||
# db=db,
|
||||
# current_user=current_user
|
||||
# )
|
||||
# print(updated_chunk)
|
||||
return '','',''
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# asyncio.run(example_usage())
|
||||
asyncio.run(write_rag('1111','22222',"c71df60a-36a6-4759-a2ce-101e3087b401"))
|
||||
568
api/app/services/memory_storage_service.py
Normal file
568
api/app/services/memory_storage_service.py
Normal file
@@ -0,0 +1,568 @@
|
||||
"""
|
||||
Memory Storage Service
|
||||
|
||||
Handles business logic for memory storage operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import os
|
||||
import json
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
ConfigPilotRun,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Load environment variables for Neo4j connector
|
||||
load_dotenv()
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
|
||||
|
||||
class MemoryStorageService:
|
||||
"""Service for memory storage operations"""
|
||||
|
||||
def __init__(self):
|
||||
logger.info("MemoryStorageService initialized")
|
||||
|
||||
async def get_storage_info(self) -> dict:
|
||||
"""
|
||||
Example wrapper method - retrieves storage information
|
||||
|
||||
Args:
|
||||
|
||||
Returns:
|
||||
Storage information dictionary
|
||||
"""
|
||||
logger.info(f"Getting storage info ")
|
||||
|
||||
# Empty wrapper - implement your logic here
|
||||
result = {
|
||||
"status": "active",
|
||||
"message": "This is an example wrapper"
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"""Service layer for config params CRUD.
|
||||
|
||||
The DB connection is optional; when absent, methods return a failure
|
||||
response containing an SQL preview to aid integration.
|
||||
"""
|
||||
|
||||
def __init__(self, db_conn: Optional[Any] = None) -> None:
|
||||
self.db_conn = db_conn
|
||||
|
||||
# --- Driver compatibility helpers ---
|
||||
@staticmethod
|
||||
def _is_pgsql_conn(conn: Any) -> bool: # 判断是否为 PostgreSQL 连接
|
||||
mod = type(conn).__module__
|
||||
return ("psycopg2" in mod) or ("psycopg" in mod)
|
||||
|
||||
@staticmethod
|
||||
def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式"""
|
||||
from datetime import datetime
|
||||
|
||||
for item in data_list:
|
||||
for field in ['created_at', 'updated_at']:
|
||||
if field in item and item[field] is not None:
|
||||
value = item[field]
|
||||
dt = None
|
||||
|
||||
# 如果是 datetime 对象,直接使用
|
||||
if isinstance(value, datetime):
|
||||
dt = value
|
||||
# 如果是字符串,先解析
|
||||
elif isinstance(value, str):
|
||||
try:
|
||||
dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
|
||||
except Exception:
|
||||
pass # 保持原值
|
||||
|
||||
# 转换为 YYYYMMDDHHmmss 格式
|
||||
if dt:
|
||||
item[field] = dt.strftime('%Y%m%d%H%M%S')
|
||||
|
||||
return data_list
|
||||
|
||||
# --- Create ---
|
||||
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
# 如果workspace_id存在且模型字段未全部指定,则自动获取
|
||||
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
|
||||
configs = self._get_workspace_configs(params.workspace_id)
|
||||
if configs is None:
|
||||
raise ValueError(f"工作空间不存在: workspace_id={params.workspace_id}")
|
||||
|
||||
# 只在未指定时填充(允许手动覆盖)
|
||||
if not params.llm_id:
|
||||
params.llm_id = configs.get('llm')
|
||||
if not params.embedding_id:
|
||||
params.embedding_id = configs.get('embedding')
|
||||
if not params.rerank_id:
|
||||
params.rerank_id = configs.get('rerank')
|
||||
|
||||
query, qparams = DataConfigRepository.build_insert(params)
|
||||
cur = self.db_conn.cursor()
|
||||
# PostgreSQL 使用 psycopg2 的命名参数格式
|
||||
cur.execute(query, qparams)
|
||||
self.db_conn.commit()
|
||||
return {"affected": getattr(cur, "rowcount", None)}
|
||||
|
||||
def _get_workspace_configs(self, workspace_id) -> Optional[Dict[str, Any]]:
|
||||
"""获取工作空间模型配置(内部方法,便于测试)"""
|
||||
from app.db import SessionLocal
|
||||
from app.repositories.workspace_repository import get_workspace_models_configs
|
||||
|
||||
db_session = SessionLocal()
|
||||
try:
|
||||
return get_workspace_models_configs(db_session, workspace_id)
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
# --- Delete ---
|
||||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置名称)
|
||||
query, qparams = DataConfigRepository.build_delete(key)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
affected = getattr(cur, "rowcount", None)
|
||||
self.db_conn.commit()
|
||||
# 如果没有任何行被删除,抛出异常
|
||||
if not affected:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": affected}
|
||||
|
||||
# --- Update ---
|
||||
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
|
||||
query, qparams = DataConfigRepository.build_update(update)
|
||||
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
affected = getattr(cur, "rowcount", None)
|
||||
self.db_conn.commit()
|
||||
if not affected:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": affected}
|
||||
|
||||
|
||||
|
||||
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
|
||||
query, qparams = DataConfigRepository.build_update_extracted(update)
|
||||
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
affected = getattr(cur, "rowcount", None)
|
||||
self.db_conn.commit()
|
||||
if not affected:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": affected}
|
||||
|
||||
|
||||
# --- Forget config params ---
|
||||
def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]: # 保存遗忘引擎的配置
|
||||
query, qparams = DataConfigRepository.build_update_forget(update)
|
||||
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
affected = getattr(cur, "rowcount", None)
|
||||
self.db_conn.commit()
|
||||
if not affected:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": affected}
|
||||
|
||||
# --- Read ---
|
||||
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取配置参数
|
||||
query, qparams = DataConfigRepository.build_select_extracted(key)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
raise ValueError("未找到配置")
|
||||
# Map row to dict (DB-API cursor description available for many drivers)
|
||||
columns = [desc[0] for desc in cur.description]
|
||||
raw = {columns[i]: row[i] for i in range(len(columns))}
|
||||
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
||||
data_list = self._convert_timestamps_to_format([raw])
|
||||
return data_list[0] if data_list else raw
|
||||
|
||||
def get_forget(self, key: ConfigKey) -> Dict[str, Any]: # 获取配置参数
|
||||
query, qparams = DataConfigRepository.build_select_forget(key)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
raise ValueError("未找到配置")
|
||||
# Map row to dict (DB-API cursor description available for many drivers)
|
||||
columns = [desc[0] for desc in cur.description]
|
||||
raw = {columns[i]: row[i] for i in range(len(columns))}
|
||||
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
||||
data_list = self._convert_timestamps_to_format([raw])
|
||||
return data_list[0] if data_list else raw
|
||||
|
||||
# --- Read All ---
|
||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||
query, qparams = DataConfigRepository.build_select_all(workspace_id)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
rows = cur.fetchall()
|
||||
# 如果没有查询到任何配置,返回空列表(这是正常情况,不应抛出异常)
|
||||
if not rows:
|
||||
return []
|
||||
# Map rows to list of dicts
|
||||
columns = [desc[0] for desc in cur.description]
|
||||
data_list = [dict(zip(columns, row)) for row in rows]
|
||||
# 将 UUID 转换为字符串,将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
||||
for item in data_list:
|
||||
if 'workspace_id' in item and item['workspace_id'] is not None:
|
||||
item['workspace_id'] = str(item['workspace_id'])
|
||||
return self._convert_timestamps_to_format(data_list)
|
||||
|
||||
|
||||
async def pilot_run(self, payload: ConfigPilotRun) -> Dict[str, Any]:
|
||||
"""
|
||||
选择策略与内存覆写与同步版保持一致:优先 payload.config_id,其次 dbrun.json;两者皆无时报错。
|
||||
支持 dialogue_text 参数用于试运行模式。
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
|
||||
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
logger.info(f"[PILOT_RUN] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
|
||||
# 应用内存覆写并刷新常量(在导入主管线前)
|
||||
# 注意:仅在内存中覆写配置,不修改 runtime.json 文件
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
|
||||
# 导入并 await 主管线(使用当前 ASGI 事件循环)
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
from app.core.memory.utils.self_reflexion_utils import reflexion
|
||||
|
||||
logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
|
||||
logger.info("[PILOT_RUN] pipeline_main completed")
|
||||
|
||||
# 调用自我反思
|
||||
# data = [
|
||||
# {
|
||||
# "data": {
|
||||
# "id": "1",
|
||||
# "statement": "张明现在在谷歌工作。",
|
||||
# "group_id": "1",
|
||||
# "chunk_id": "10",
|
||||
# "created_at": "2023-01-01",
|
||||
# "expired_at": "2023-01-02",
|
||||
# "valid_at": "2023-01-01",
|
||||
# "invalid_at": "2023-01-02",
|
||||
# "entity_ids": []
|
||||
# },
|
||||
# "conflict": True,
|
||||
# "conflict_memory": {
|
||||
# "id": "1",
|
||||
# "statement": "张明现在在清华大学当讲师。",
|
||||
# "group_id": "1",
|
||||
# "chunk_id": "1",
|
||||
# "created_at": "2019-12-01T19:15:05.213210",
|
||||
# "expired_at": None,
|
||||
# "valid_at": None,
|
||||
# "invalid_at": None,
|
||||
# "entity_ids": []
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
from app.core.memory.utils.config.get_example_data import get_example_data
|
||||
data = get_example_data()
|
||||
reflexion_result = await reflexion(data)
|
||||
|
||||
# 读取输出,使用全局配置路径
|
||||
from app.core.config import settings
|
||||
result_path = settings.get_memory_output_path("extracted_result.json")
|
||||
if not os.path.isfile(result_path):
|
||||
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
|
||||
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None
|
||||
return {
|
||||
"config_id": cid,
|
||||
"time_log": os.path.join(project_root, "time.log"),
|
||||
"extracted_result": extracted_result,
|
||||
}
|
||||
|
||||
|
||||
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
|
||||
# Ensure env for connector (e.g., NEO4J_PASSWORD)
|
||||
load_dotenv()
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
|
||||
|
||||
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_DIALOGUE,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
data = {"search_for": "dialogue", "num": result[0]["num"]}
|
||||
return data
|
||||
|
||||
|
||||
async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_CHUNK,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
data = {"search_for": "chunk", "num": result[0]["num"]}
|
||||
return data
|
||||
|
||||
|
||||
async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_STATEMENT,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
data = {"search_for": "statement", "num": result[0]["num"]}
|
||||
return data
|
||||
|
||||
|
||||
async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_ENTITY,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
data = {"search_for": "entity", "num": result[0]["num"]}
|
||||
return data
|
||||
|
||||
|
||||
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_ALL,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
|
||||
# 检查结果是否为空或长度不足
|
||||
if not result or len(result) < 4:
|
||||
data = {
|
||||
"total": 0,
|
||||
"counts": {
|
||||
"dialogue": 0,
|
||||
"chunk": 0,
|
||||
"statement": 0,
|
||||
"entity": 0,
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
data = {
|
||||
"total": result[-1]["Count"],
|
||||
"counts": {
|
||||
"dialogue": result[0]["Count"],
|
||||
"chunk": result[1]["Count"],
|
||||
"statement": result[2]["Count"],
|
||||
"entity": result[3]["Count"],
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""统一知识库类型分布接口。
|
||||
|
||||
聚合 dialogue/chunk/statement/entity 四类计数,返回统一的分布结构,便于前端一次性消费。
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_ALL,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
|
||||
# 检查结果是否为空或长度不足
|
||||
if not result or len(result) < 4:
|
||||
data = {
|
||||
"total": 0,
|
||||
"distribution": [
|
||||
{"type": "dialogue", "count": 0},
|
||||
{"type": "chunk", "count": 0},
|
||||
{"type": "statement", "count": 0},
|
||||
{"type": "entity", "count": 0},
|
||||
]
|
||||
}
|
||||
return data
|
||||
|
||||
total = result[-1]["Count"]
|
||||
distribution = [
|
||||
{"type": "dialogue", "count": result[0]["Count"]},
|
||||
{"type": "chunk", "count": result[1]["Count"]},
|
||||
{"type": "statement", "count": result[2]["Count"]},
|
||||
{"type": "entity", "count": result[3]["Count"]},
|
||||
]
|
||||
|
||||
data = {"total": total, "distribution": distribution}
|
||||
return data
|
||||
|
||||
|
||||
async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_DETIALS,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_EDGES,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""搜索所有实体之间的关系网络(group 维度)。"""
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
# 对source_node 和 target_node 的 fact_summary进行截取,只截取前三条的内容(需要提取前三条“来源”)
|
||||
for item in result:
|
||||
source_fact = item["sourceNode"]["fact_summary"]
|
||||
target_fact = item["targetNode"]["fact_summary"]
|
||||
# 截取前三条“来源”
|
||||
item["sourceNode"]["fact_summary"] = source_fact.split("\n")[:4] if source_fact else []
|
||||
item["targetNode"]["fact_summary"] = target_fact.split("\n")[:4] if target_fact else []
|
||||
# 与现有返回风格保持一致,携带搜索类型、数量与详情
|
||||
data = {
|
||||
"search_for": "entity_graph",
|
||||
"num": len(result),
|
||||
"detials": result,
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_hot_memory_tags(end_user_id: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取热门记忆标签,按数量排序并返回前N个
|
||||
"""
|
||||
# 获取更多标签供LLM筛选(获取limit*4个标签)
|
||||
raw_limit = limit * 4
|
||||
tags = await get_hot_memory_tags(end_user_id, limit=raw_limit)
|
||||
|
||||
# 按频率降序排序(虽然数据库已经排序,但为了确保正确性再次排序)
|
||||
sorted_tags = sorted(tags, key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 只返回前limit个
|
||||
top_tags = sorted_tags[:limit]
|
||||
|
||||
return [{"name": t, "frequency": f} for t, f in top_tags]
|
||||
|
||||
|
||||
async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
insight = MemoryInsight(end_user_id)
|
||||
report = await insight.generate_insight_report()
|
||||
await insight.close()
|
||||
data = {"report": report}
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
||||
stats, _msg = get_recent_activity_stats()
|
||||
total = (
|
||||
stats.get("chunk_count", 0)
|
||||
+ stats.get("statements_count", 0)
|
||||
+ stats.get("triplet_entities_count", 0)
|
||||
+ stats.get("triplet_relations_count", 0)
|
||||
+ stats.get("temporal_count", 0)
|
||||
)
|
||||
# 精简:仅提供“最新一次活动多久前”
|
||||
latest_relative = None
|
||||
try:
|
||||
info = stats.get("log_path", "")
|
||||
idx = info.rfind("最新:")
|
||||
if idx != -1:
|
||||
latest_path = info[idx + 3 :].strip()
|
||||
if latest_path and os.path.exists(latest_path):
|
||||
import time
|
||||
diff = max(0.0, time.time() - os.path.getmtime(latest_path))
|
||||
m = int(diff // 60)
|
||||
if m < 1:
|
||||
latest_relative = "刚刚"
|
||||
elif m < 60:
|
||||
latest_relative = f"{m}分钟前"
|
||||
else:
|
||||
h = int(m // 60)
|
||||
latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
data = {"total": total, "stats": stats, "latest_relative": latest_relative}
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
summary = await generate_user_summary(end_user_id)
|
||||
data = {"summary": summary}
|
||||
return data
|
||||
160
api/app/services/model_parameter_merger.py
Normal file
160
api/app/services/model_parameter_merger.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
模型参数合并器
|
||||
|
||||
用于合并 ModelConfig 和 AgentConfig 中的模型参数,
|
||||
AgentConfig 中的参数优先级更高,可以覆盖 ModelConfig 的默认参数。
|
||||
"""
|
||||
from typing import Dict, Any, Optional
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ModelParameterMerger:
|
||||
"""模型参数合并器"""
|
||||
|
||||
@staticmethod
|
||||
def merge_parameters(
|
||||
model_config_params: Optional[Dict[str, Any]],
|
||||
agent_config_params: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并模型配置参数和 Agent 配置参数
|
||||
|
||||
优先级:agent_config_params > model_config_params > 默认值
|
||||
|
||||
Args:
|
||||
model_config_params: ModelConfig.config 中的参数
|
||||
agent_config_params: AgentConfig.model_parameters 中的参数
|
||||
|
||||
Returns:
|
||||
合并后的参数字典
|
||||
|
||||
Example:
|
||||
>>> model_params = {"temperature": 0.5, "max_tokens": 1000}
|
||||
>>> agent_params = {"temperature": 0.8}
|
||||
>>> merged = ModelParameterMerger.merge_parameters(model_params, agent_params)
|
||||
>>> merged
|
||||
{"temperature": 0.8, "max_tokens": 1000}
|
||||
"""
|
||||
# 默认参数
|
||||
default_params = {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
"top_p": 1.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"n": 1,
|
||||
"stop": None
|
||||
}
|
||||
|
||||
# 合并参数:默认值 -> 模型配置 -> Agent 配置
|
||||
merged = default_params.copy()
|
||||
|
||||
# 应用模型配置参数
|
||||
if model_config_params:
|
||||
for key in default_params.keys():
|
||||
if key in model_config_params:
|
||||
merged[key] = model_config_params[key]
|
||||
|
||||
# 应用 Agent 配置参数(优先级最高)
|
||||
if agent_config_params:
|
||||
for key in default_params.keys():
|
||||
if key in agent_config_params and agent_config_params[key] is not None:
|
||||
merged[key] = agent_config_params[key]
|
||||
|
||||
# 移除 None 值
|
||||
merged = {k: v for k, v in merged.items() if v is not None}
|
||||
|
||||
logger.debug(
|
||||
f"参数合并完成",
|
||||
extra={
|
||||
"model_params": model_config_params,
|
||||
"agent_params": agent_config_params,
|
||||
"merged": merged
|
||||
}
|
||||
)
|
||||
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def get_effective_parameters(
|
||||
model_config: Optional[Any],
|
||||
agent_config: Optional[Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取有效的模型参数(从 ORM 对象中提取并合并)
|
||||
|
||||
Args:
|
||||
model_config: ModelConfig ORM 对象
|
||||
agent_config: AgentConfig ORM 对象
|
||||
|
||||
Returns:
|
||||
合并后的参数字典
|
||||
"""
|
||||
# 提取模型配置参数
|
||||
model_params = None
|
||||
if model_config and hasattr(model_config, 'config'):
|
||||
model_params = model_config.config
|
||||
|
||||
# 提取 Agent 配置参数
|
||||
agent_params = None
|
||||
if agent_config and hasattr(agent_config, 'model_parameters'):
|
||||
agent_params = agent_config.model_parameters
|
||||
|
||||
return ModelParameterMerger.merge_parameters(model_params, agent_params)
|
||||
|
||||
@staticmethod
|
||||
def format_for_llm_call(parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
格式化参数用于 LLM API 调用
|
||||
|
||||
不同的 LLM 提供商可能需要不同的参数格式,
|
||||
这个方法可以根据需要进行转换。
|
||||
|
||||
Args:
|
||||
parameters: 合并后的参数字典
|
||||
|
||||
Returns:
|
||||
格式化后的参数字典
|
||||
"""
|
||||
# 基本格式化(可以根据不同提供商扩展)
|
||||
formatted = parameters.copy()
|
||||
|
||||
# 确保参数在有效范围内
|
||||
if "temperature" in formatted:
|
||||
formatted["temperature"] = max(0.0, min(2.0, formatted["temperature"]))
|
||||
|
||||
if "max_tokens" in formatted:
|
||||
formatted["max_tokens"] = max(1, min(32000, formatted["max_tokens"]))
|
||||
|
||||
if "top_p" in formatted:
|
||||
formatted["top_p"] = max(0.0, min(1.0, formatted["top_p"]))
|
||||
|
||||
if "frequency_penalty" in formatted:
|
||||
formatted["frequency_penalty"] = max(-2.0, min(2.0, formatted["frequency_penalty"]))
|
||||
|
||||
if "presence_penalty" in formatted:
|
||||
formatted["presence_penalty"] = max(-2.0, min(2.0, formatted["presence_penalty"]))
|
||||
|
||||
if "n" in formatted:
|
||||
formatted["n"] = max(1, min(10, formatted["n"]))
|
||||
|
||||
return formatted
|
||||
|
||||
|
||||
def merge_model_parameters(
|
||||
model_config_params: Optional[Dict[str, Any]],
|
||||
agent_config_params: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
合并模型参数的便捷函数
|
||||
|
||||
Args:
|
||||
model_config_params: ModelConfig.config 中的参数
|
||||
agent_config_params: AgentConfig.model_parameters 中的参数
|
||||
|
||||
Returns:
|
||||
合并后的参数字典
|
||||
"""
|
||||
return ModelParameterMerger.merge_parameters(model_config_params, agent_config_params)
|
||||
409
api/app/services/model_service.py
Normal file
409
api/app/services/model_service.py
Normal file
@@ -0,0 +1,409 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
import math
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
|
||||
from app.schemas import model_schema
|
||||
from app.schemas.model_schema import (
|
||||
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
|
||||
ModelConfigQuery, ModelStats
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ModelConfigService:
|
||||
"""模型配置服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_id(db: Session, model_id: uuid.UUID) -> ModelConfig:
|
||||
"""根据ID获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_id(db, model_id)
|
||||
if not model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def get_model_list(db: Session, query: ModelConfigQuery) -> PageData:
|
||||
"""获取模型配置列表"""
|
||||
models, total = ModelConfigRepository.get_list(db, query)
|
||||
pages = math.ceil(total / query.pagesize) if total > 0 else 0
|
||||
|
||||
return PageData(
|
||||
page=PageMeta(
|
||||
page=query.page,
|
||||
pagesize=query.pagesize,
|
||||
total=total,
|
||||
hasnext=query.page < pages
|
||||
),
|
||||
items=[model_schema.ModelConfig.model_validate(model) for model in models]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_name(db, name)
|
||||
if not model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def search_models_by_name(db: Session, name: str, limit: int = 10) -> List[ModelConfig]:
|
||||
"""按名称模糊匹配获取模型配置列表"""
|
||||
return ModelConfigRepository.search_by_name(db, name, limit)
|
||||
|
||||
@staticmethod
|
||||
async def validate_model_config(
|
||||
db: Session,
|
||||
*,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello"
|
||||
) -> Dict[str, Any]:
|
||||
"""验证模型配置是否有效
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
model_name: 模型名称
|
||||
provider: 提供商
|
||||
api_key: API密钥
|
||||
api_base: API基础URL
|
||||
model_type: 模型类型 (llm/chat/embedding/rerank)
|
||||
test_message: 测试消息
|
||||
|
||||
Returns:
|
||||
Dict: 验证结果
|
||||
"""
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
import traceback
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
# 根据模型类型选择不同的验证方式
|
||||
model_type_lower = model_type.lower()
|
||||
|
||||
if model_type_lower in ["llm", "chat"]:
|
||||
# LLM/Chat 模型验证 - 统一使用字符串输入
|
||||
llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT)
|
||||
response = await llm.ainvoke(test_message)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
content = response.content if hasattr(response, 'content') else str(response)
|
||||
usage = None
|
||||
if hasattr(response, 'usage_metadata'):
|
||||
usage = {
|
||||
"input_tokens": getattr(response.usage_metadata, 'input_tokens', 0),
|
||||
"output_tokens": getattr(response.usage_metadata, 'output_tokens', 0),
|
||||
"total_tokens": getattr(response.usage_metadata, 'total_tokens', 0)
|
||||
}
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": f"{model_type.upper()} 模型配置验证成功",
|
||||
"response": content,
|
||||
"elapsed_time": elapsed_time,
|
||||
"usage": usage,
|
||||
"error": None
|
||||
}
|
||||
|
||||
elif model_type_lower == "embedding":
|
||||
# Embedding 模型验证(在线程中运行同步方法)
|
||||
embedding = RedBearEmbeddings(model_config)
|
||||
test_texts = [test_message, "测试文本"]
|
||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "Embedding 模型配置验证成功",
|
||||
"response": f"成功生成 {len(vectors)} 个向量,维度: {len(vectors[0]) if vectors else 0}",
|
||||
"elapsed_time": elapsed_time,
|
||||
"usage": {
|
||||
"input_tokens": len(test_message),
|
||||
"vector_count": len(vectors),
|
||||
"vector_dimension": len(vectors[0]) if vectors else 0
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
elif model_type_lower == "rerank":
|
||||
# Rerank 模型验证(在线程中运行同步方法)
|
||||
rerank = RedBearRerank(model_config)
|
||||
query = test_message
|
||||
documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"]
|
||||
results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "Rerank 模型配置验证成功",
|
||||
"response": f"成功对 {len(documents)} 个文档进行重排序,返回 top {len(results) if results else 0} 结果",
|
||||
"elapsed_time": elapsed_time,
|
||||
"usage": {
|
||||
"query_length": len(query),
|
||||
"document_count": len(documents),
|
||||
"result_count": len(results) if results else 0
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"message": "不支持的模型类型",
|
||||
"response": None,
|
||||
"elapsed_time": None,
|
||||
"usage": None,
|
||||
"error": f"不支持的模型类型: {model_type}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 提取详细的错误信息
|
||||
error_message = str(e)
|
||||
error_type = type(e).__name__
|
||||
print("=========error_message:",error_message.lower())
|
||||
# 特殊处理常见的错误类型
|
||||
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
|
||||
# 区域/国家限制(适用于所有提供商)
|
||||
error_message = "区域限制: 该模型在当前区域或国家/地区不可用,请检查提供商的服务区域限制"
|
||||
elif "ValidationException" in error_type or "ValidationException" in error_message:
|
||||
# 其他验证错误
|
||||
if "access denied" in error_message.lower():
|
||||
error_message = "访问被拒绝: 请检查 API 凭证和权限配置"
|
||||
else:
|
||||
error_message = f"验证失败: {error_message}"
|
||||
elif "AuthenticationError" in error_type or "authentication" in error_message.lower():
|
||||
error_message = "认证失败: API Key 无效或已过期"
|
||||
elif "RateLimitError" in error_type or "rate limit" in error_message.lower():
|
||||
error_message = "请求频率限制: 已超过 API 调用限制"
|
||||
elif "InvalidRequestError" in error_type or "invalid request" in error_message.lower():
|
||||
error_message = f"无效请求: {error_message}"
|
||||
elif "model_copy" in error_message:
|
||||
error_message = "模型消息格式错误: 请确保使用正确的模型类型(LLM/Chat)"
|
||||
|
||||
# 记录详细错误日志
|
||||
logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}")
|
||||
logger.error(f"错误详情: {error_message}")
|
||||
logger.debug(f"完整堆栈: {traceback.format_exc()}")
|
||||
|
||||
return {
|
||||
"valid": False,
|
||||
"message": f"{model_type.upper()} 模型配置验证失败",
|
||||
"response": None,
|
||||
"elapsed_time": None,
|
||||
"usage": None,
|
||||
"error": error_message,
|
||||
"error_type": error_type
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def create_model(db: Session, model_data: ModelConfigCreate) -> ModelConfig:
|
||||
"""创建模型配置"""
|
||||
# 检查名称是否已存在
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
# 验证配置
|
||||
if not model_data.skip_validation and model_data.api_keys:
|
||||
api_key_data = model_data.api_keys
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_data.type, # 传递模型类型
|
||||
test_message="Hello"
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
# 事务处理
|
||||
api_key_data = model_data.api_keys
|
||||
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush() # 获取生成的 ID
|
||||
|
||||
if api_key_data:
|
||||
api_key_create_schema = ModelApiKeyCreate(
|
||||
model_config_id=model.id,
|
||||
**api_key_data.dict()
|
||||
)
|
||||
ModelApiKeyRepository.create(db, api_key_create_schema)
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate) -> ModelConfig:
|
||||
"""更新模型配置"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model = ModelConfigRepository.update(db, model_id, model_data)
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def delete_model(db: Session, model_id: uuid.UUID) -> bool:
|
||||
"""删除模型配置"""
|
||||
if not ModelConfigRepository.get_by_id(db, model_id):
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
success = ModelConfigRepository.delete(db, model_id)
|
||||
db.commit()
|
||||
return success
|
||||
|
||||
@staticmethod
|
||||
def get_model_stats(db: Session) -> ModelStats:
|
||||
"""获取模型统计信息"""
|
||||
stats_data = ModelConfigRepository.get_stats(db)
|
||||
return ModelStats(
|
||||
total_models=stats_data["total_models"],
|
||||
active_models=stats_data["active_models"],
|
||||
llm_count=stats_data["llm_count"],
|
||||
embedding_count=stats_data["embedding_count"],
|
||||
rerank_count=stats_data["rerank_count"],
|
||||
provider_stats=stats_data["provider_stats"]
|
||||
)
|
||||
|
||||
|
||||
class ModelApiKeyService:
|
||||
"""模型API Key服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_api_key_by_id(db: Session, api_key_id: uuid.UUID) -> ModelApiKey:
|
||||
"""根据ID获取API Key"""
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> List[ModelApiKey]:
|
||||
"""根据模型配置ID获取API Key列表"""
|
||||
if not ModelConfigRepository.get_by_id(db, model_config_id):
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
|
||||
|
||||
@staticmethod
|
||||
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
|
||||
"""创建API Key"""
|
||||
model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type, # 传递模型类型
|
||||
test_message="Hello"
|
||||
)
|
||||
print(validation_result)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
api_key = ModelApiKeyRepository.create(db, api_key_data)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
async def update_api_key(db: Session, api_key_id: uuid.UUID, api_key_data: ModelApiKeyUpdate) -> ModelApiKey:
|
||||
"""更新API Key"""
|
||||
existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not existing_api_key:
|
||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||
|
||||
# 获取关联的模型配置以获取模型类型
|
||||
model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id)
|
||||
if not model_config:
|
||||
raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type, # 传递模型类型
|
||||
test_message="Hello"
|
||||
)
|
||||
print(validation_result)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
f"模型配置验证失败: {validation_result['error']}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
def delete_api_key(db: Session, api_key_id: uuid.UUID) -> bool:
|
||||
"""删除API Key"""
|
||||
if not ModelApiKeyRepository.get_by_id(db, api_key_id):
|
||||
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
|
||||
|
||||
success = ModelApiKeyRepository.delete(db, api_key_id)
|
||||
db.commit()
|
||||
return success
|
||||
|
||||
@staticmethod
|
||||
def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]:
|
||||
"""获取可用的API Key(按优先级和负载均衡)"""
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active=True)
|
||||
if not api_keys:
|
||||
return None
|
||||
return min(api_keys, key=lambda x: int(x.usage_count or "0"))
|
||||
|
||||
@staticmethod
|
||||
def record_api_key_usage(db: Session, api_key_id: uuid.UUID) -> bool:
|
||||
"""记录API Key使用"""
|
||||
success = ModelApiKeyRepository.update_usage(db, api_key_id)
|
||||
if success:
|
||||
db.commit()
|
||||
return success
|
||||
191
api/app/services/multi_agent_config_converter.py
Normal file
191
api/app/services/multi_agent_config_converter.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
多智能体配置格式转换器
|
||||
用于将 Pydantic 模型转换为数据库存储格式
|
||||
"""
|
||||
from typing import Dict, Any, Optional, List
|
||||
import uuid
|
||||
from app.schemas.multi_agent_schema import (
|
||||
SubAgentConfig,
|
||||
RoutingRule,
|
||||
ExecutionConfig,
|
||||
MultiAgentConfigCreate,
|
||||
MultiAgentConfigUpdate,
|
||||
)
|
||||
|
||||
|
||||
class MultiAgentConfigConverter:
|
||||
"""多智能体配置格式转换器"""
|
||||
|
||||
@staticmethod
|
||||
def to_storage_format(config: MultiAgentConfigCreate | MultiAgentConfigUpdate) -> Dict[str, Any]:
|
||||
"""
|
||||
将配置对象转换为数据库存储格式
|
||||
|
||||
Args:
|
||||
config: MultiAgentConfigCreate 或 MultiAgentConfigUpdate 对象
|
||||
|
||||
Returns:
|
||||
包含数据库字段的字典
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# 1. 子 Agent 配置
|
||||
if hasattr(config, 'sub_agents') and config.sub_agents:
|
||||
result["sub_agents"] = [
|
||||
MultiAgentConfigConverter._convert_uuid_to_str(agent.model_dump())
|
||||
for agent in config.sub_agents
|
||||
]
|
||||
|
||||
# 2. 路由规则配置
|
||||
if hasattr(config, 'routing_rules') and config.routing_rules:
|
||||
result["routing_rules"] = [
|
||||
MultiAgentConfigConverter._convert_uuid_to_str(rule.model_dump())
|
||||
for rule in config.routing_rules
|
||||
]
|
||||
|
||||
# 3. 执行配置
|
||||
if hasattr(config, 'execution_config') and config.execution_config:
|
||||
result["execution_config"] = MultiAgentConfigConverter._convert_uuid_to_str(
|
||||
config.execution_config.model_dump()
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def from_storage_format(
|
||||
sub_agents: Optional[List[Dict[str, Any]]],
|
||||
routing_rules: Optional[List[Dict[str, Any]]],
|
||||
execution_config: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将数据库存储格式转换为 Pydantic 对象
|
||||
|
||||
Args:
|
||||
sub_agents: 子 Agent 配置列表
|
||||
routing_rules: 路由规则配置列表
|
||||
execution_config: 执行配置
|
||||
|
||||
Returns:
|
||||
包含 Pydantic 对象的字典
|
||||
"""
|
||||
result = {
|
||||
"sub_agents": [],
|
||||
"routing_rules": [],
|
||||
"execution_config": None,
|
||||
}
|
||||
|
||||
# 1. 解析子 Agent 配置
|
||||
if sub_agents:
|
||||
result["sub_agents"] = [
|
||||
SubAgentConfig(**agent_data)
|
||||
for agent_data in sub_agents
|
||||
]
|
||||
|
||||
# 2. 解析路由规则配置
|
||||
if routing_rules:
|
||||
result["routing_rules"] = [
|
||||
RoutingRule(**rule_data)
|
||||
for rule_data in routing_rules
|
||||
]
|
||||
else:
|
||||
# 提供默认的空路由规则
|
||||
result["routing_rules"] = []
|
||||
|
||||
# 3. 解析执行配置
|
||||
if execution_config:
|
||||
result["execution_config"] = ExecutionConfig(**execution_config)
|
||||
else:
|
||||
# 提供默认的执行配置
|
||||
result["execution_config"] = ExecutionConfig(
|
||||
max_iterations=10,
|
||||
timeout=300,
|
||||
enable_parallel=False,
|
||||
error_handling="stop"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _convert_uuid_to_str(obj: Any) -> Any:
|
||||
"""
|
||||
递归转换对象中的所有 UUID 为字符串
|
||||
|
||||
Args:
|
||||
obj: 要转换的对象(dict, list, UUID 等)
|
||||
|
||||
Returns:
|
||||
转换后的对象
|
||||
"""
|
||||
if isinstance(obj, uuid.UUID):
|
||||
return str(obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: MultiAgentConfigConverter._convert_uuid_to_str(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [MultiAgentConfigConverter._convert_uuid_to_str(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
@staticmethod
|
||||
def enrich_with_published_configs(
|
||||
sub_agents: List[Dict[str, Any]],
|
||||
get_published_config_func
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
为子 Agent 配置添加发布的 config_id
|
||||
|
||||
Args:
|
||||
sub_agents: 子 Agent 配置列表
|
||||
get_published_config_func: 获取发布配置的函数
|
||||
|
||||
Returns:
|
||||
增强后的子 Agent 配置列表
|
||||
"""
|
||||
enriched_agents = []
|
||||
|
||||
for agent in sub_agents:
|
||||
agent_copy = agent.copy()
|
||||
|
||||
# 获取该 Agent 当前发布的配置
|
||||
if 'agent_id' in agent:
|
||||
try:
|
||||
agent_id = uuid.UUID(agent['agent_id']) if isinstance(agent['agent_id'], str) else agent['agent_id']
|
||||
published_config = get_published_config_func(agent_id)
|
||||
|
||||
if published_config:
|
||||
agent_copy['published_config_id'] = str(published_config.get('id')) if isinstance(published_config, dict) else None
|
||||
except Exception as e:
|
||||
# 如果获取失败,记录但不中断
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
logger.warning(f"获取 Agent {agent.get('agent_id')} 的发布配置失败: {e}")
|
||||
|
||||
enriched_agents.append(agent_copy)
|
||||
|
||||
return enriched_agents
|
||||
|
||||
@staticmethod
|
||||
def create_default_template(app_id: uuid.UUID) -> Dict[str, Any]:
|
||||
"""
|
||||
创建默认的多智能体配置模板
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
默认配置模板
|
||||
"""
|
||||
return {
|
||||
"app_id": str(app_id),
|
||||
"master_agent_id": None,
|
||||
"orchestration_mode": "sequential",
|
||||
"sub_agents": [],
|
||||
"routing_rules": [],
|
||||
"execution_config": {
|
||||
"max_iterations": 10,
|
||||
"timeout": 300,
|
||||
"enable_parallel": False,
|
||||
"error_handling": "stop"
|
||||
},
|
||||
"aggregation_strategy": "last",
|
||||
"is_active": False
|
||||
}
|
||||
1116
api/app/services/multi_agent_orchestrator.py
Normal file
1116
api/app/services/multi_agent_orchestrator.py
Normal file
File diff suppressed because it is too large
Load Diff
630
api/app/services/multi_agent_service.py
Normal file
630
api/app/services/multi_agent_service.py
Normal file
@@ -0,0 +1,630 @@
|
||||
"""多 Agent 配置管理服务"""
|
||||
import uuid
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, desc
|
||||
|
||||
from app.models import MultiAgentConfig, App, AgentConfig
|
||||
from app.schemas.multi_agent_schema import (
|
||||
MultiAgentConfigCreate,
|
||||
MultiAgentConfigUpdate,
|
||||
MultiAgentRunRequest
|
||||
)
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.core.exceptions import ResourceNotFoundException, BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models import AppRelease
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
def convert_uuids_to_str(obj: Any) -> Any:
|
||||
"""递归转换对象中的所有 UUID 为字符串
|
||||
|
||||
Args:
|
||||
obj: 要转换的对象(dict, list, UUID 等)
|
||||
|
||||
Returns:
|
||||
转换后的对象
|
||||
"""
|
||||
if isinstance(obj, uuid.UUID):
|
||||
return str(obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: convert_uuids_to_str(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [convert_uuids_to_str(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class MultiAgentService:
|
||||
"""多 Agent 配置管理服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def create_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
data: MultiAgentConfigCreate,
|
||||
created_by: uuid.UUID
|
||||
) -> MultiAgentConfig:
|
||||
"""创建多 Agent 配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
data: 配置数据
|
||||
created_by: 创建者 ID
|
||||
|
||||
Returns:
|
||||
多 Agent 配置
|
||||
"""
|
||||
# 1. 验证应用存在
|
||||
app = self.db.get(App, app_id)
|
||||
if not app:
|
||||
raise ResourceNotFoundException("应用", str(app_id))
|
||||
|
||||
# 2. 检查是否已有有效配置
|
||||
existing = self.db.scalars(
|
||||
select(MultiAgentConfig)
|
||||
.where(
|
||||
MultiAgentConfig.app_id == app_id,
|
||||
MultiAgentConfig.is_active == True
|
||||
)
|
||||
.order_by(MultiAgentConfig.updated_at.desc())
|
||||
).first()
|
||||
if existing:
|
||||
raise BusinessException("应用已有多 Agent 配置", BizCode.DUPLICATE_RESOURCE)
|
||||
|
||||
# 3. 验证主 Agent 存在
|
||||
master_agent = self.db.get(AgentConfig, data.master_agent_id)
|
||||
if not master_agent:
|
||||
raise ResourceNotFoundException("主 Agent", str(data.master_agent_id))
|
||||
|
||||
# 4. 验证子 Agent 存在
|
||||
for sub_agent in data.sub_agents:
|
||||
agent = self.db.get(AgentConfig, sub_agent.agent_id)
|
||||
if not agent:
|
||||
raise ResourceNotFoundException("子 Agent", str(sub_agent.agent_id))
|
||||
|
||||
# 5. 创建配置(转换 UUID 为字符串以支持 JSON 序列化)
|
||||
sub_agents_data = [convert_uuids_to_str(sub_agent.model_dump()) for sub_agent in data.sub_agents]
|
||||
routing_rules_data = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None
|
||||
|
||||
# 处理 execution_config(可能是 None、字典或 Pydantic 模型)
|
||||
if data.execution_config is None:
|
||||
execution_config_data = {}
|
||||
elif isinstance(data.execution_config, dict):
|
||||
execution_config_data = convert_uuids_to_str(data.execution_config)
|
||||
else:
|
||||
execution_config_data = convert_uuids_to_str(data.execution_config.model_dump())
|
||||
|
||||
config = MultiAgentConfig(
|
||||
app_id=app_id,
|
||||
master_agent_id=data.master_agent_id,
|
||||
master_agent_name=data.master_agent_name,
|
||||
orchestration_mode=data.orchestration_mode,
|
||||
sub_agents=sub_agents_data,
|
||||
routing_rules=routing_rules_data,
|
||||
execution_config=execution_config_data,
|
||||
aggregation_strategy=data.aggregation_strategy
|
||||
)
|
||||
|
||||
self.db.add(config)
|
||||
self.db.commit()
|
||||
self.db.refresh(config)
|
||||
|
||||
logger.info(
|
||||
f"创建多 Agent 配置成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"app_id": str(app_id),
|
||||
"mode": data.orchestration_mode,
|
||||
"sub_agent_count": len(data.sub_agents)
|
||||
}
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def get_config(self, app_id: uuid.UUID) -> Optional[MultiAgentConfig]:
|
||||
"""获取多 Agent 配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
多 Agent 配置,如果不存在返回 None
|
||||
"""
|
||||
return self.db.scalars(
|
||||
select(MultiAgentConfig)
|
||||
.where(
|
||||
MultiAgentConfig.app_id == app_id,
|
||||
MultiAgentConfig.is_active == True
|
||||
)
|
||||
.order_by(MultiAgentConfig.updated_at.desc())
|
||||
).first()
|
||||
|
||||
def get_multi_agent_configs(self, app_id: uuid.UUID) -> Optional[dict]:
|
||||
"""通过 app_id 获取最新有效的多智能体配置,并将 agent_id 转换为 app_id
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
转换后的配置字典,如果不存在返回 None
|
||||
"""
|
||||
config = self.get_config(app_id)
|
||||
if not config:
|
||||
return None
|
||||
|
||||
# 转换 master_agent_id (release_id) 为 app_id
|
||||
master_release = self.db.get(AppRelease, config.master_agent_id)
|
||||
master_app_id = master_release.app_id if master_release else config.master_agent_id
|
||||
|
||||
# 转换 sub_agents 中的 agent_id (release_id) 为 app_id
|
||||
converted_sub_agents = []
|
||||
for sub_agent in config.sub_agents:
|
||||
sub_agent_copy = sub_agent.copy()
|
||||
release_id = sub_agent.get("agent_id")
|
||||
if release_id:
|
||||
try:
|
||||
release_id_uuid = uuid.UUID(release_id) if isinstance(release_id, str) else release_id
|
||||
sub_release = self.db.get(AppRelease, release_id_uuid)
|
||||
if sub_release:
|
||||
sub_agent_copy["agent_id"] = str(sub_release.app_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"转换 sub_agent agent_id 失败: {release_id}, 错误: {str(e)}")
|
||||
converted_sub_agents.append(sub_agent_copy)
|
||||
|
||||
# 构建返回的配置字典
|
||||
return {
|
||||
"id": config.id,
|
||||
"app_id": config.app_id,
|
||||
"master_agent_id": master_app_id,
|
||||
"master_agent_name": config.master_agent_name,
|
||||
"orchestration_mode": config.orchestration_mode,
|
||||
"sub_agents": converted_sub_agents,
|
||||
"routing_rules": config.routing_rules,
|
||||
"execution_config": config.execution_config,
|
||||
"aggregation_strategy": config.aggregation_strategy,
|
||||
"is_active": config.is_active,
|
||||
"created_at": config.created_at,
|
||||
"updated_at": config.updated_at
|
||||
}
|
||||
|
||||
def get_published_config_by_agent_id(self, agent_id: uuid.UUID) -> Optional[dict]:
|
||||
"""通过 agent_id 获取当前发布版本的完整配置
|
||||
|
||||
Args:
|
||||
agent_id: Agent 配置 ID
|
||||
|
||||
Returns:
|
||||
当前发布版本的配置字典,如果没有发布版本则返回 None
|
||||
"""
|
||||
from app.models import AppRelease
|
||||
|
||||
# 查询 Agent 配置
|
||||
agent_config = self.db.get(AgentConfig, agent_id)
|
||||
if not agent_config:
|
||||
logger.warning(f"Agent 配置不存在: {agent_id}")
|
||||
return None
|
||||
|
||||
# 获取关联的应用
|
||||
app = self.db.get(App, agent_config.app_id)
|
||||
if not app or not app.current_release_id:
|
||||
logger.warning(f"应用未发布或不存在: app_id={agent_config.app_id}")
|
||||
return None
|
||||
|
||||
# 获取当前发布版本
|
||||
release = self.db.get(AppRelease, app.current_release_id)
|
||||
if not release:
|
||||
logger.warning(f"发布版本不存在: release_id={app.current_release_id}")
|
||||
return None
|
||||
|
||||
# 从发布版本的 config 中获取完整配置
|
||||
# config 是一个 JSON 对象,包含了发布时的配置快照
|
||||
config_data = release.config
|
||||
if config_data and isinstance(config_data, dict):
|
||||
return config_data
|
||||
|
||||
return None
|
||||
|
||||
def get_published_by_agent_id(self, agent_id: uuid.UUID) -> Optional[AppRelease]:
|
||||
"""通过 agent_id 获取当前发布版本的完整配置
|
||||
|
||||
Args:
|
||||
agent_id: Agent 配置 ID
|
||||
|
||||
Returns:
|
||||
当前发布版本的配置字典,如果没有发布版本则返回 None
|
||||
"""
|
||||
|
||||
# 获取关联的应用
|
||||
app = self.db.get(App, agent_id)
|
||||
if not app or not app.current_release_id:
|
||||
logger.warning(f"应用未发布或不存在: app_id={agent_id}")
|
||||
return None
|
||||
|
||||
# 获取当前发布版本
|
||||
release = self.db.get(AppRelease, app.current_release_id)
|
||||
if not release:
|
||||
logger.warning(f"发布版本不存在: release_id={app.current_release_id}")
|
||||
return None
|
||||
return release
|
||||
|
||||
def update_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
data: MultiAgentConfigUpdate
|
||||
) -> MultiAgentConfig:
|
||||
"""更新多 Agent 配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
data: 更新数据
|
||||
|
||||
Returns:
|
||||
更新后的配置
|
||||
"""
|
||||
config = self.get_config(app_id)
|
||||
if not config:
|
||||
# 1. 验证应用存在
|
||||
app = self.db.get(App, app_id)
|
||||
if not app:
|
||||
raise ResourceNotFoundException("应用", str(app_id))
|
||||
|
||||
# 2. 验证主 Agent 存在并获取发布版本 ID
|
||||
master_app_release = self.get_published_by_agent_id(data.master_agent_id)
|
||||
if not master_app_release:
|
||||
raise ResourceNotFoundException("主 Agent 未发布或不存在", str(data.master_agent_id))
|
||||
|
||||
# 使用发布版本 ID
|
||||
data.master_agent_id = master_app_release.id
|
||||
|
||||
# 3. 验证子 Agent 存在并获取发布版本 ID
|
||||
for sub_agent in data.sub_agents:
|
||||
agent_app_release = self.get_published_by_agent_id(sub_agent.agent_id)
|
||||
if not agent_app_release:
|
||||
raise ResourceNotFoundException("子 Agent 未发布或不存在", str(sub_agent.agent_id))
|
||||
|
||||
# 使用发布版本 ID
|
||||
sub_agent.agent_id = agent_app_release.id
|
||||
|
||||
|
||||
# 5. 创建配置(转换 UUID 为字符串以支持 JSON 序列化)
|
||||
sub_agents_data = [convert_uuids_to_str(sub_agent.model_dump()) for sub_agent in data.sub_agents]
|
||||
# routing_rules_data = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None
|
||||
|
||||
# 处理 execution_config(可能是 None、字典或 Pydantic 模型)
|
||||
if data.execution_config is None:
|
||||
execution_config_data = {}
|
||||
elif isinstance(data.execution_config, dict):
|
||||
execution_config_data = convert_uuids_to_str(data.execution_config)
|
||||
else:
|
||||
execution_config_data = convert_uuids_to_str(data.execution_config.model_dump())
|
||||
|
||||
config = MultiAgentConfig(
|
||||
app_id=app_id,
|
||||
master_agent_id=data.master_agent_id,
|
||||
master_agent_name=data.master_agent_name,
|
||||
orchestration_mode=data.orchestration_mode,
|
||||
sub_agents=sub_agents_data,
|
||||
# routing_rules=routing_rules_data,
|
||||
execution_config=execution_config_data,
|
||||
aggregation_strategy=data.aggregation_strategy
|
||||
)
|
||||
|
||||
self.db.add(config)
|
||||
self.db.commit()
|
||||
self.db.refresh(config)
|
||||
|
||||
logger.info(
|
||||
f"创建多 Agent 配置成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"app_id": str(app_id),
|
||||
"mode": data.orchestration_mode,
|
||||
"sub_agent_count": len(data.sub_agents)
|
||||
}
|
||||
)
|
||||
return config
|
||||
# raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
||||
|
||||
# 更新字段
|
||||
if data.master_agent_id is not None:
|
||||
# 验证主 Agent 存在
|
||||
# 3. 验证主 Agent 存在并获取发布配置
|
||||
master_app_release = self.get_published_by_agent_id(data.master_agent_id)
|
||||
if not master_app_release:
|
||||
raise ResourceNotFoundException("主 Agent 未发布或", str(data.master_agent_id))
|
||||
|
||||
config.master_agent_id = master_app_release.id
|
||||
|
||||
if data.master_agent_name is not None:
|
||||
config.master_agent_name = data.master_agent_name
|
||||
|
||||
if data.orchestration_mode is not None:
|
||||
config.orchestration_mode = data.orchestration_mode
|
||||
|
||||
if data.sub_agents is not None:
|
||||
# 验证子 Agent 存在,并获取其发布的 config_id
|
||||
updated_sub_agents = []
|
||||
for sub_agent in data.sub_agents:
|
||||
agent_app_release = self.get_published_by_agent_id(sub_agent.agent_id)
|
||||
if not agent_app_release:
|
||||
raise ResourceNotFoundException("子 Agent 未发布或", str(sub_agent.agent_id))
|
||||
sub_agent.agent_id = agent_app_release.id
|
||||
sub_agent_dict = convert_uuids_to_str(sub_agent.model_dump())
|
||||
updated_sub_agents.append(sub_agent_dict)
|
||||
|
||||
config.sub_agents = updated_sub_agents
|
||||
|
||||
# if data.routing_rules is not None:
|
||||
# config.routing_rules = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None
|
||||
|
||||
if data.execution_config is None:
|
||||
execution_config_data = {}
|
||||
elif isinstance(data.execution_config, dict):
|
||||
execution_config_data = convert_uuids_to_str(data.execution_config)
|
||||
else:
|
||||
execution_config_data = convert_uuids_to_str(data.execution_config.model_dump())
|
||||
|
||||
if data.aggregation_strategy is not None:
|
||||
config.aggregation_strategy = data.aggregation_strategy
|
||||
|
||||
if data.is_active is not None:
|
||||
config.is_active = data.is_active
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(config)
|
||||
|
||||
logger.info(
|
||||
f"更新多 Agent 配置成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"app_id": str(app_id)
|
||||
}
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def delete_config(self, app_id: uuid.UUID) -> None:
|
||||
"""删除多 Agent 配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
"""
|
||||
config = self.get_config(app_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
||||
|
||||
self.db.delete(config)
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"删除多 Agent 配置成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"app_id": str(app_id)
|
||||
}
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
request: MultiAgentRunRequest
|
||||
) -> dict:
|
||||
"""运行多 Agent 任务
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
request: 运行请求
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
# 1. 获取配置
|
||||
config = self.get_config(app_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
||||
|
||||
if not config.is_active:
|
||||
raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED)
|
||||
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
# 3. 执行任务
|
||||
result = await orchestrator.execute(
|
||||
message=request.message,
|
||||
conversation_id=request.conversation_id,
|
||||
user_id=request.user_id,
|
||||
variables=request.variables,
|
||||
use_llm_routing=getattr(request, 'use_llm_routing', True), # 默认启用 LLM 路由
|
||||
web_search=getattr(request, 'web_search', False), # 网络搜索参数
|
||||
memory=getattr(request, 'memory', True) # 记忆功能参数
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
request: MultiAgentRunRequest,
|
||||
storage_type :str,
|
||||
user_rag_memory_id :str
|
||||
):
|
||||
"""运行多 Agent 任务(流式返回)
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
request: 运行请求
|
||||
|
||||
Yields:
|
||||
SSE 格式的事件流
|
||||
"""
|
||||
# 1. 获取配置
|
||||
config = self.get_config(app_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
||||
|
||||
if not config.is_active:
|
||||
raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED)
|
||||
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
# 3. 流式执行任务
|
||||
async for event in orchestrator.execute_stream(
|
||||
message=request.message,
|
||||
conversation_id=request.conversation_id,
|
||||
user_id=request.user_id,
|
||||
variables=request.variables,
|
||||
use_llm_routing=getattr(request, 'use_llm_routing', True),
|
||||
web_search=getattr(request, 'web_search', False), # 网络搜索参数
|
||||
memory=getattr(request, 'memory', True) , # 记忆功能参数
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
def add_sub_agent(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
agent_id: uuid.UUID,
|
||||
name: str,
|
||||
role: Optional[str] = None,
|
||||
priority: int = 1,
|
||||
capabilities: Optional[List[str]] = None
|
||||
) -> MultiAgentConfig:
|
||||
"""添加子 Agent
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
agent_id: Agent ID
|
||||
name: Agent 名称
|
||||
role: 角色描述
|
||||
priority: 优先级
|
||||
capabilities: 能力列表
|
||||
|
||||
Returns:
|
||||
更新后的配置
|
||||
"""
|
||||
config = self.get_config(app_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
||||
|
||||
# 验证 Agent 存在
|
||||
agent = self.db.get(AgentConfig, agent_id)
|
||||
if not agent:
|
||||
raise ResourceNotFoundException("Agent", str(agent_id))
|
||||
|
||||
# 检查是否已存在
|
||||
for sub_agent in config.sub_agents:
|
||||
if sub_agent["agent_id"] == str(agent_id):
|
||||
raise BusinessException("Agent 已存在于配置中", BizCode.DUPLICATE_RESOURCE)
|
||||
|
||||
# 添加子 Agent
|
||||
new_sub_agent = {
|
||||
"agent_id": str(agent_id),
|
||||
"name": name,
|
||||
"role": role,
|
||||
"priority": priority,
|
||||
"capabilities": capabilities or []
|
||||
}
|
||||
|
||||
config.sub_agents.append(new_sub_agent)
|
||||
|
||||
# 标记为已修改
|
||||
self.db.add(config)
|
||||
self.db.commit()
|
||||
self.db.refresh(config)
|
||||
|
||||
logger.info(
|
||||
f"添加子 Agent 成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"agent_id": str(agent_id),
|
||||
"agent_name": name
|
||||
}
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def remove_sub_agent(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
agent_id: uuid.UUID
|
||||
) -> MultiAgentConfig:
|
||||
"""移除子 Agent
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
更新后的配置
|
||||
"""
|
||||
config = self.get_config(app_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
||||
|
||||
# 查找并移除
|
||||
original_count = len(config.sub_agents)
|
||||
config.sub_agents = [
|
||||
sub_agent for sub_agent in config.sub_agents
|
||||
if sub_agent["agent_id"] != str(agent_id)
|
||||
]
|
||||
|
||||
if len(config.sub_agents) == original_count:
|
||||
raise ResourceNotFoundException("子 Agent", str(agent_id))
|
||||
|
||||
# 标记为已修改
|
||||
self.db.add(config)
|
||||
self.db.commit()
|
||||
self.db.refresh(config)
|
||||
|
||||
logger.info(
|
||||
f"移除子 Agent 成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"agent_id": str(agent_id)
|
||||
}
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def list_configs(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
pagesize: int = 20
|
||||
) -> Tuple[List[MultiAgentConfig], int]:
|
||||
"""列出多 Agent 配置
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间 ID
|
||||
page: 页码
|
||||
pagesize: 每页数量
|
||||
|
||||
Returns:
|
||||
配置列表和总数
|
||||
"""
|
||||
# 构建查询
|
||||
stmt = (
|
||||
select(MultiAgentConfig)
|
||||
.join(App)
|
||||
.where(App.workspace_id == workspace_id)
|
||||
.order_by(desc(MultiAgentConfig.created_at))
|
||||
)
|
||||
|
||||
# 总数
|
||||
count_stmt = stmt.with_only_columns(MultiAgentConfig.id)
|
||||
total = len(self.db.execute(count_stmt).all())
|
||||
|
||||
# 分页
|
||||
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
||||
configs = list(self.db.scalars(stmt).all())
|
||||
|
||||
return configs, total
|
||||
444
api/app/services/release_share_service.py
Normal file
444
api/app/services/release_share_service.py
Normal file
@@ -0,0 +1,444 @@
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models import ReleaseShare, AppRelease, App, AgentConfig
|
||||
from app.repositories.release_share_repository import ReleaseShareRepository
|
||||
from app.core.share_utils import (
|
||||
generate_share_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
build_share_url,
|
||||
generate_embed_code
|
||||
)
|
||||
from app.core.exceptions import ResourceNotFoundException, BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas import release_share_schema
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ReleaseShareService:
|
||||
"""发布版本分享服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.repo = ReleaseShareRepository(db)
|
||||
|
||||
def create_or_update_share(
|
||||
self,
|
||||
release_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
data: release_share_schema.ReleaseShareCreate,
|
||||
base_url: Optional[str] = None
|
||||
) -> ReleaseShare:
|
||||
"""创建或更新分享配置
|
||||
|
||||
Args:
|
||||
release_id: 发布版本 ID
|
||||
user_id: 用户 ID
|
||||
workspace_id: 工作空间 ID
|
||||
data: 分享配置数据
|
||||
base_url: 基础 URL(用于生成完整的分享链接)
|
||||
|
||||
Returns:
|
||||
分享配置
|
||||
"""
|
||||
# 验证发布版本存在且属于该工作空间
|
||||
release = self._get_release_or_404(release_id)
|
||||
self._validate_release_access(release, workspace_id)
|
||||
|
||||
# 检查是否已存在分享配置
|
||||
existing_share = self.repo.get_by_release_id(release_id)
|
||||
|
||||
if existing_share:
|
||||
# 更新现有配置
|
||||
return self._update_share_internal(existing_share, data)
|
||||
else:
|
||||
# 创建新配置
|
||||
return self._create_share_internal(release, user_id, data)
|
||||
|
||||
def _create_share_internal(
|
||||
self,
|
||||
release: AppRelease,
|
||||
user_id: uuid.UUID,
|
||||
data: release_share_schema.ReleaseShareCreate
|
||||
) -> ReleaseShare:
|
||||
"""内部方法:创建分享配置"""
|
||||
# 生成唯一的 share_token
|
||||
share_token = self._generate_unique_token()
|
||||
|
||||
# 处理密码
|
||||
password_hash = None
|
||||
if data.require_password and data.password:
|
||||
password_hash = hash_password(data.password)
|
||||
|
||||
# 创建分享配置
|
||||
share = ReleaseShare(
|
||||
release_id=release.id,
|
||||
app_id=release.app_id,
|
||||
is_enabled=data.is_enabled,
|
||||
share_token=share_token,
|
||||
require_password=data.require_password,
|
||||
password_hash=password_hash,
|
||||
allow_embed=data.allow_embed,
|
||||
embed_domains=data.embed_domains or [],
|
||||
created_by=user_id
|
||||
)
|
||||
|
||||
share = self.repo.create(share)
|
||||
|
||||
logger.info(
|
||||
f"创建分享配置",
|
||||
extra={
|
||||
"share_id": str(share.id),
|
||||
"release_id": str(release.id),
|
||||
"app_id": str(release.app_id),
|
||||
"share_token": share_token
|
||||
}
|
||||
)
|
||||
|
||||
return share
|
||||
|
||||
def _update_share_internal(
|
||||
self,
|
||||
share: ReleaseShare,
|
||||
data: release_share_schema.ReleaseShareUpdate
|
||||
) -> ReleaseShare:
|
||||
"""内部方法:更新分享配置"""
|
||||
if data.is_enabled is not None:
|
||||
share.is_enabled = data.is_enabled
|
||||
|
||||
if data.require_password is not None:
|
||||
share.require_password = data.require_password
|
||||
|
||||
if data.password is not None:
|
||||
if data.password:
|
||||
share.password_hash = hash_password(data.password)
|
||||
else:
|
||||
share.password_hash = None
|
||||
|
||||
if data.allow_embed is not None:
|
||||
share.allow_embed = data.allow_embed
|
||||
|
||||
if data.embed_domains is not None:
|
||||
share.embed_domains = data.embed_domains or []
|
||||
|
||||
share = self.repo.update(share)
|
||||
|
||||
logger.info(
|
||||
f"更新分享配置",
|
||||
extra={
|
||||
"share_id": str(share.id),
|
||||
"release_id": str(share.release_id)
|
||||
}
|
||||
)
|
||||
|
||||
return share
|
||||
|
||||
def update_share(
|
||||
self,
|
||||
release_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
data: release_share_schema.ReleaseShareUpdate
|
||||
) -> ReleaseShare:
|
||||
"""更新分享配置
|
||||
|
||||
Args:
|
||||
release_id: 发布版本 ID
|
||||
workspace_id: 工作空间 ID
|
||||
data: 更新数据
|
||||
|
||||
Returns:
|
||||
更新后的分享配置
|
||||
"""
|
||||
# 验证发布版本
|
||||
release = self._get_release_or_404(release_id)
|
||||
self._validate_release_access(release, workspace_id)
|
||||
|
||||
# 获取分享配置
|
||||
share = self.repo.get_by_release_id(release_id)
|
||||
if not share:
|
||||
raise ResourceNotFoundException("分享配置", str(release_id))
|
||||
|
||||
return self._update_share_internal(share, data)
|
||||
|
||||
def get_share(
|
||||
self,
|
||||
release_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
base_url: Optional[str] = None
|
||||
) -> Optional[release_share_schema.ReleaseShare]:
|
||||
"""获取分享配置
|
||||
|
||||
Args:
|
||||
release_id: 发布版本 ID
|
||||
workspace_id: 工作空间 ID
|
||||
base_url: 基础 URL
|
||||
|
||||
Returns:
|
||||
分享配置 Schema
|
||||
"""
|
||||
# 验证发布版本
|
||||
release = self._get_release_or_404(release_id)
|
||||
self._validate_release_access(release, workspace_id)
|
||||
|
||||
share = self.repo.get_by_release_id(release_id)
|
||||
if not share:
|
||||
return None
|
||||
|
||||
return self._convert_to_schema(share, base_url)
|
||||
|
||||
def delete_share(
|
||||
self,
|
||||
release_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> None:
|
||||
"""删除(禁用)分享配置
|
||||
|
||||
Args:
|
||||
release_id: 发布版本 ID
|
||||
workspace_id: 工作空间 ID
|
||||
"""
|
||||
# 验证发布版本
|
||||
release = self._get_release_or_404(release_id)
|
||||
self._validate_release_access(release, workspace_id)
|
||||
|
||||
share = self.repo.get_by_release_id(release_id)
|
||||
if not share:
|
||||
raise ResourceNotFoundException("分享配置", str(release_id))
|
||||
|
||||
self.repo.delete(share)
|
||||
|
||||
logger.info(
|
||||
f"删除分享配置",
|
||||
extra={
|
||||
"share_id": str(share.id),
|
||||
"release_id": str(release_id)
|
||||
}
|
||||
)
|
||||
|
||||
def regenerate_token(
|
||||
self,
|
||||
release_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> ReleaseShare:
|
||||
"""重新生成分享 token(旧链接失效)
|
||||
|
||||
Args:
|
||||
release_id: 发布版本 ID
|
||||
workspace_id: 工作空间 ID
|
||||
|
||||
Returns:
|
||||
更新后的分享配置
|
||||
"""
|
||||
# 验证发布版本
|
||||
release = self._get_release_or_404(release_id)
|
||||
self._validate_release_access(release, workspace_id)
|
||||
|
||||
share = self.repo.get_by_release_id(release_id)
|
||||
if not share:
|
||||
raise ResourceNotFoundException("分享配置", str(release_id))
|
||||
|
||||
# 生成新 token
|
||||
old_token = share.share_token
|
||||
share.share_token = self._generate_unique_token()
|
||||
share = self.repo.update(share)
|
||||
|
||||
logger.info(
|
||||
f"重新生成分享 token",
|
||||
extra={
|
||||
"share_id": str(share.id),
|
||||
"old_token": old_token,
|
||||
"new_token": share.share_token
|
||||
}
|
||||
)
|
||||
|
||||
return share
|
||||
|
||||
def get_shared_release_info(
|
||||
self,
|
||||
share_token: str,
|
||||
password: Optional[str] = None
|
||||
) -> release_share_schema.SharedReleaseInfo:
|
||||
"""获取公开分享的发布版本信息
|
||||
|
||||
Args:
|
||||
share_token: 分享 token
|
||||
password: 访问密码(如果需要)
|
||||
|
||||
Returns:
|
||||
分享的发布版本信息
|
||||
"""
|
||||
# 获取分享配置
|
||||
share = self.repo.get_by_share_token(share_token)
|
||||
if not share:
|
||||
raise ResourceNotFoundException("分享链接", share_token)
|
||||
|
||||
# 检查是否启用
|
||||
if not share.is_enabled:
|
||||
raise BusinessException("该分享链接已禁用", BizCode.SHARE_DISABLED)
|
||||
|
||||
# 验证密码
|
||||
is_password_verified = False
|
||||
if share.require_password:
|
||||
if not password:
|
||||
# 需要密码但未提供,返回基本信息
|
||||
release = self.db.get(AppRelease, share.release_id)
|
||||
return release_share_schema.SharedReleaseInfo(
|
||||
app_name=release.name,
|
||||
app_description=release.description,
|
||||
app_icon=release.icon,
|
||||
app_type=release.type,
|
||||
version=release.version,
|
||||
release_notes=release.release_notes,
|
||||
published_at=int(release.published_at.timestamp() * 1000),
|
||||
config={},
|
||||
require_password=True,
|
||||
is_password_verified=False,
|
||||
allow_embed=share.allow_embed
|
||||
)
|
||||
|
||||
# 验证密码
|
||||
if not share.password_hash or not verify_password(password, share.password_hash):
|
||||
raise BusinessException("密码错误", BizCode.INVALID_PASSWORD)
|
||||
|
||||
is_password_verified = True
|
||||
|
||||
# 获取发布版本详细信息
|
||||
release = self.db.get(AppRelease, share.release_id)
|
||||
if not release:
|
||||
raise ResourceNotFoundException("发布版本", str(share.release_id))
|
||||
|
||||
# 异步更新访问统计(不阻塞响应)
|
||||
try:
|
||||
self.repo.increment_view_count(share.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"更新访问统计失败: {str(e)}")
|
||||
|
||||
# 返回完整信息
|
||||
return release_share_schema.SharedReleaseInfo(
|
||||
app_name=release.name,
|
||||
app_description=release.description,
|
||||
app_icon=release.icon,
|
||||
app_type=release.type,
|
||||
version=release.version,
|
||||
release_notes=release.release_notes,
|
||||
published_at=int(release.published_at.timestamp() * 1000),
|
||||
config=release.config or {},
|
||||
require_password=share.require_password,
|
||||
is_password_verified=is_password_verified,
|
||||
allow_embed=share.allow_embed
|
||||
)
|
||||
|
||||
def verify_password(
|
||||
self,
|
||||
share_token: str,
|
||||
password: str
|
||||
) -> bool:
|
||||
"""验证分享密码
|
||||
|
||||
Args:
|
||||
share_token: 分享 token
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
是否验证成功
|
||||
"""
|
||||
share = self.repo.get_by_share_token(share_token)
|
||||
if not share:
|
||||
raise ResourceNotFoundException("分享链接", share_token)
|
||||
|
||||
if not share.is_enabled:
|
||||
raise BusinessException("该分享链接已禁用", BizCode.SHARE_DISABLED)
|
||||
|
||||
if not share.require_password:
|
||||
return True
|
||||
|
||||
if not share.password_hash:
|
||||
return False
|
||||
|
||||
return verify_password(password, share.password_hash)
|
||||
|
||||
def get_embed_code(
|
||||
self,
|
||||
share_token: str,
|
||||
width: str = "100%",
|
||||
height: str = "600px",
|
||||
base_url: Optional[str] = None
|
||||
) -> release_share_schema.EmbedCode:
|
||||
"""获取嵌入代码
|
||||
|
||||
Args:
|
||||
share_token: 分享 token
|
||||
width: 宽度
|
||||
height: 高度
|
||||
base_url: 基础 URL
|
||||
|
||||
Returns:
|
||||
嵌入代码
|
||||
"""
|
||||
share = self.repo.get_by_share_token(share_token)
|
||||
if not share:
|
||||
raise ResourceNotFoundException("分享链接", share_token)
|
||||
|
||||
if not share.is_enabled:
|
||||
raise BusinessException("该分享链接已禁用", BizCode.SHARE_DISABLED)
|
||||
|
||||
if not share.allow_embed:
|
||||
raise BusinessException("该分享不允许嵌入", BizCode.EMBED_NOT_ALLOWED)
|
||||
|
||||
embed_data = generate_embed_code(share_token, width, height, base_url)
|
||||
return release_share_schema.EmbedCode(**embed_data)
|
||||
|
||||
def _generate_unique_token(self, max_attempts: int = 10) -> str:
|
||||
"""生成唯一的分享 token"""
|
||||
for _ in range(max_attempts):
|
||||
token = generate_share_token()
|
||||
if not self.repo.token_exists(token):
|
||||
return token
|
||||
|
||||
raise BusinessException("生成唯一 token 失败,请重试", BizCode.INTERNAL_ERROR)
|
||||
|
||||
def _get_release_or_404(self, release_id: uuid.UUID) -> AppRelease:
|
||||
"""获取发布版本或抛出 404"""
|
||||
release = self.db.get(AppRelease, release_id)
|
||||
if not release:
|
||||
raise ResourceNotFoundException("发布版本", str(release_id))
|
||||
return release
|
||||
|
||||
def _validate_release_access(self, release: AppRelease, workspace_id: uuid.UUID) -> None:
|
||||
"""验证发布版本访问权限"""
|
||||
app = self.db.get(App, release.app_id)
|
||||
if not app:
|
||||
raise ResourceNotFoundException("应用", str(release.app_id))
|
||||
|
||||
if app.workspace_id != workspace_id:
|
||||
raise BusinessException("无权访问该发布版本", BizCode.PERMISSION_DENIED)
|
||||
|
||||
def _convert_to_schema(
|
||||
self,
|
||||
share: ReleaseShare,
|
||||
base_url: Optional[str] = None
|
||||
) -> release_share_schema.ReleaseShare:
|
||||
"""转换为 Schema"""
|
||||
share_url = build_share_url(share.share_token, base_url)
|
||||
|
||||
return release_share_schema.ReleaseShare(
|
||||
id=share.id,
|
||||
release_id=share.release_id,
|
||||
app_id=share.app_id,
|
||||
is_enabled=share.is_enabled,
|
||||
share_token=share.share_token,
|
||||
share_url=share_url,
|
||||
require_password=share.require_password,
|
||||
allow_embed=share.allow_embed,
|
||||
embed_domains=share.embed_domains or [],
|
||||
view_count=share.view_count,
|
||||
last_accessed_at=share.last_accessed_at,
|
||||
created_at=share.created_at,
|
||||
updated_at=share.updated_at
|
||||
)
|
||||
160
api/app/services/session_service.py
Normal file
160
api/app/services/session_service.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from typing import Optional
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from app.aioRedis import aio_redis_set, aio_redis_get, aio_redis_delete
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""用户会话管理服务"""
|
||||
|
||||
@staticmethod
|
||||
def _get_user_session_key(username: str) -> str:
|
||||
"""获取用户会话的Redis键"""
|
||||
return f"user_session:{username}"
|
||||
|
||||
@staticmethod
|
||||
def _get_token_blacklist_key(token_id: str) -> str:
|
||||
"""获取token黑名单的Redis键"""
|
||||
return f"token_blacklist:{token_id}"
|
||||
|
||||
@staticmethod
|
||||
async def set_user_active_session(username: str, token_id: str, expires_at: datetime) -> None:
|
||||
"""设置用户的活跃会话
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
token_id: token的唯一标识
|
||||
expires_at: token过期时间
|
||||
"""
|
||||
if not settings.ENABLE_SINGLE_SESSION:
|
||||
return
|
||||
|
||||
session_key = SessionService._get_user_session_key(username)
|
||||
session_data = {
|
||||
"token_id": token_id,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"expires_at": expires_at.isoformat()
|
||||
}
|
||||
|
||||
# 计算过期时间(秒)
|
||||
expire_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds())
|
||||
if expire_seconds > 0:
|
||||
await aio_redis_set(session_key, session_data, expire=expire_seconds)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_active_session(username: str) -> Optional[dict]:
|
||||
"""获取用户的活跃会话
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
|
||||
Returns:
|
||||
会话数据字典或None
|
||||
"""
|
||||
if not settings.ENABLE_SINGLE_SESSION:
|
||||
return None
|
||||
|
||||
session_key = SessionService._get_user_session_key(username)
|
||||
session_data = await aio_redis_get(session_key)
|
||||
|
||||
if session_data:
|
||||
try:
|
||||
return json.loads(session_data) if isinstance(session_data, str) else session_data
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_old_session(username: str, new_token_id: str) -> None:
|
||||
"""使旧会话失效
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
new_token_id: 新token的ID
|
||||
"""
|
||||
if not settings.ENABLE_SINGLE_SESSION:
|
||||
return
|
||||
|
||||
# 获取当前活跃会话
|
||||
current_session = await SessionService.get_user_active_session(username)
|
||||
|
||||
if current_session and current_session.get("token_id") != new_token_id:
|
||||
# 将旧token加入黑名单
|
||||
old_token_id = current_session.get("token_id")
|
||||
if old_token_id:
|
||||
await SessionService.blacklist_token(old_token_id)
|
||||
|
||||
@staticmethod
|
||||
async def blacklist_token(token_id: str, expire_seconds: int = None) -> None:
|
||||
"""将token加入黑名单
|
||||
|
||||
Args:
|
||||
token_id: token的唯一标识
|
||||
expire_seconds: 黑名单过期时间(秒),默认为refresh token的过期时间
|
||||
"""
|
||||
if expire_seconds is None:
|
||||
# 默认使用refresh token的过期时间
|
||||
expire_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
|
||||
|
||||
blacklist_key = SessionService._get_token_blacklist_key(token_id)
|
||||
await aio_redis_set(blacklist_key, "blacklisted", expire=expire_seconds)
|
||||
|
||||
@staticmethod
|
||||
async def is_token_blacklisted(token_id: str) -> bool:
|
||||
"""检查token是否在黑名单中
|
||||
|
||||
Args:
|
||||
token_id: token的唯一标识
|
||||
|
||||
Returns:
|
||||
True如果token在黑名单中,否则False
|
||||
"""
|
||||
if not settings.ENABLE_SINGLE_SESSION:
|
||||
return False
|
||||
|
||||
blacklist_key = SessionService._get_token_blacklist_key(token_id)
|
||||
result = await aio_redis_get(blacklist_key)
|
||||
return result is not None
|
||||
|
||||
@staticmethod
|
||||
async def clear_user_session(username: str) -> None:
|
||||
"""清除用户会话
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
"""
|
||||
session_key = SessionService._get_user_session_key(username)
|
||||
await aio_redis_delete(session_key)
|
||||
|
||||
@staticmethod
|
||||
async def invalidate_all_user_tokens(user_id: str) -> None:
|
||||
"""使用户的所有 tokens 失效(用于密码重置等场景)
|
||||
|
||||
通过在 Redis 中设置一个用户级别的失效标记来实现。
|
||||
所有在此时间点之前签发的 tokens 都将被视为无效。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
"""
|
||||
invalidation_key = f"user_token_invalidation:{user_id}"
|
||||
current_time = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# 设置失效时间戳,过期时间为 refresh token 的最大有效期
|
||||
expire_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
|
||||
await aio_redis_set(invalidation_key, current_time, expire=expire_seconds)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_token_invalidation_time(user_id: str) -> Optional[str]:
|
||||
"""获取用户 token 失效时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
失效时间的 ISO 格式字符串,如果没有失效记录则返回 None
|
||||
"""
|
||||
invalidation_key = f"user_token_invalidation:{user_id}"
|
||||
result = await aio_redis_get(invalidation_key)
|
||||
return result if result else None
|
||||
759
api/app/services/shared_chat_service.py
Normal file
759
api/app/services/shared_chat_service.py
Normal file
@@ -0,0 +1,759 @@
|
||||
"""基于分享链接的聊天服务"""
|
||||
import uuid
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
from app.repositories import knowledge_repository
|
||||
import json
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class SharedChatService:
|
||||
"""基于分享链接的聊天服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.conversation_service = ConversationService(db)
|
||||
self.share_service = ReleaseShareService(db)
|
||||
|
||||
def _get_release_by_share_token(
|
||||
self,
|
||||
share_token: str,
|
||||
password: Optional[str] = None
|
||||
) -> tuple[ReleaseShare, AppRelease]:
|
||||
"""通过 share_token 获取发布版本"""
|
||||
# 获取分享配置
|
||||
share = self.share_service.repo.get_by_share_token(share_token)
|
||||
if not share:
|
||||
raise ResourceNotFoundException("分享链接", share_token)
|
||||
|
||||
# 验证分享是否启用
|
||||
if not share.is_enabled:
|
||||
raise BusinessException("该分享链接已被禁用", BizCode.SHARE_DISABLED)
|
||||
|
||||
# 验证密码
|
||||
if share.require_password:
|
||||
if not password:
|
||||
raise BusinessException("需要提供访问密码", BizCode.PASSWORD_REQUIRED)
|
||||
|
||||
if not self.share_service.verify_password(share_token, password):
|
||||
raise BusinessException("访问密码错误", BizCode.INVALID_PASSWORD)
|
||||
|
||||
# 获取发布版本
|
||||
release = self.db.get(AppRelease, share.release_id)
|
||||
if not release:
|
||||
raise ResourceNotFoundException("发布版本", str(share.release_id))
|
||||
|
||||
# 更新访问统计
|
||||
try:
|
||||
self.share_service.repo.increment_view_count(share.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"更新访问统计失败: {str(e)}")
|
||||
|
||||
return share, release
|
||||
|
||||
def create_or_get_conversation(
|
||||
self,
|
||||
share_token: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
password: Optional[str] = None
|
||||
) -> Conversation:
|
||||
"""创建或获取会话"""
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
# 如果提供了 conversation_id,尝试获取现有会话
|
||||
if conversation_id:
|
||||
try:
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=release.app.workspace_id
|
||||
)
|
||||
|
||||
# 验证会话是否属于该应用
|
||||
if conversation.app_id != release.app_id:
|
||||
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
|
||||
|
||||
return conversation
|
||||
except ResourceNotFoundException:
|
||||
logger.warning(
|
||||
f"会话不存在,将创建新会话",
|
||||
extra={"conversation_id": str(conversation_id)}
|
||||
)
|
||||
|
||||
# 创建新会话(使用发布版本的配置)
|
||||
conversation = self.conversation_service.create_conversation(
|
||||
app_id=release.app_id,
|
||||
workspace_id=release.app.workspace_id,
|
||||
user_id=user_id,
|
||||
is_draft=False, # 分享链接使用发布版本
|
||||
config_snapshot=release.config
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"为分享链接创建新会话",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"share_token": share_token,
|
||||
"release_id": str(release.id)
|
||||
}
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
# 获取发布版本和配置
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
# 获取 Agent 配置
|
||||
config = release.config or {}
|
||||
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = release.default_model_config_id
|
||||
if not model_config_id:
|
||||
raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 获取模型配置
|
||||
from app.models import ModelConfig
|
||||
model_config = self.db.get(ModelConfig, model_config_id)
|
||||
if not model_config:
|
||||
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
||||
|
||||
# 获取 API Key
|
||||
stmt = (
|
||||
select(ModelApiKey)
|
||||
.where(
|
||||
ModelApiKey.model_config_id == model_config_id,
|
||||
ModelApiKey.is_active == True
|
||||
)
|
||||
.order_by(ModelApiKey.priority.desc())
|
||||
.limit(1)
|
||||
)
|
||||
api_key_obj = self.db.scalars(stmt).first()
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 获取或创建会话
|
||||
conversation = self.create_or_get_conversation(
|
||||
share_token=share_token,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
password=password
|
||||
)
|
||||
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.get("system_prompt", "你是一个专业的AI助手")
|
||||
if variables:
|
||||
system_prompt_rendered = render_prompt_message(
|
||||
system_prompt,
|
||||
PromptMessageRole.USER,
|
||||
variables
|
||||
)
|
||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
||||
if knowledge_retrieval:
|
||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||
if kb_ids:
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
# 添加长期记忆工具
|
||||
|
||||
if memory==True:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
web_tools=config.get("tools")
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled",False)
|
||||
if web_search==True:
|
||||
if web_search_enable==True:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.get("model_parameters", {})
|
||||
|
||||
# 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_obj.model_name,
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
history = []
|
||||
memory_config={"enabled":True,'max_history':10}
|
||||
if memory_config.get("enabled"):
|
||||
messages = self.conversation_service.get_messages(
|
||||
conversation_id=conversation.id,
|
||||
limit=memory_config.get("max_history", 10)
|
||||
)
|
||||
history = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 调用 Agent
|
||||
result = await agent.chat(
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.save_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
user_message=message,
|
||||
assistant_message=result["content"]
|
||||
)
|
||||
# self.conversation_service.add_message(
|
||||
# conversation_id=conversation.id,
|
||||
# role="user",
|
||||
# content=message
|
||||
# )
|
||||
|
||||
# self.conversation_service.add_message(
|
||||
# conversation_id=conversation.id,
|
||||
# role="assistant",
|
||||
# content=result["content"],
|
||||
# meta_data={
|
||||
# "model": api_key_obj.model_name,
|
||||
# "usage": result.get("usage", {})
|
||||
# }
|
||||
# )
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"conversation_id": conversation.id,
|
||||
"message": result["content"],
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}),
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
import json
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
memory_config = {"enabled": memory, "memory_content": "17", "max_history": 10}
|
||||
|
||||
try:
|
||||
# 获取发布版本和配置
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
# 获取 Agent 配置
|
||||
config = release.config or {}
|
||||
agent_config_data = config.get("agent_config", {})
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = release.default_model_config_id
|
||||
if not model_config_id:
|
||||
raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 获取模型配置
|
||||
from app.models import ModelConfig
|
||||
model_config = self.db.get(ModelConfig, model_config_id)
|
||||
if not model_config:
|
||||
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
||||
|
||||
# 获取 API Key
|
||||
stmt = (
|
||||
select(ModelApiKey)
|
||||
.where(
|
||||
ModelApiKey.model_config_id == model_config_id,
|
||||
ModelApiKey.is_active == True
|
||||
)
|
||||
.order_by(ModelApiKey.priority.desc())
|
||||
.limit(1)
|
||||
)
|
||||
api_key_obj = self.db.scalars(stmt).first()
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 获取或创建会话
|
||||
conversation = self.create_or_get_conversation(
|
||||
share_token=share_token,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
password=password
|
||||
)
|
||||
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.get("system_prompt", "你是一个专业的AI助手")
|
||||
if variables:
|
||||
system_prompt_rendered = render_prompt_message(
|
||||
system_prompt,
|
||||
PromptMessageRole.USER,
|
||||
variables
|
||||
)
|
||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
||||
if knowledge_retrieval:
|
||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||
if kb_ids:
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
# 添加长期记忆工具
|
||||
if memory:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
web_tools = config.get("tools")
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.get("model_parameters", {})
|
||||
|
||||
# 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_obj.model_name,
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
history = []
|
||||
memory_config = {"enabled": True, 'max_history': 10}
|
||||
if memory_config.get("enabled"):
|
||||
messages = self.conversation_service.get_messages(
|
||||
conversation_id=conversation.id,
|
||||
limit=memory_config.get("max_history", 10)
|
||||
)
|
||||
history = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 流式调用 Agent
|
||||
full_content = ""
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="assistant",
|
||||
content=full_content,
|
||||
meta_data={
|
||||
"model": api_key_obj.model_name,
|
||||
"usage": {}
|
||||
}
|
||||
)
|
||||
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
logger.info(
|
||||
f"流式聊天完成",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"elapsed_time": elapsed_time,
|
||||
"message_length": len(full_content)
|
||||
}
|
||||
)
|
||||
|
||||
except (GeneratorExit, asyncio.CancelledError):
|
||||
# 生成器被关闭或任务被取消,正常退出
|
||||
logger.debug("流式聊天被中断")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
|
||||
# 发送错误事件
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
def get_conversation_messages(
|
||||
self,
|
||||
share_token: str,
|
||||
conversation_id: uuid.UUID,
|
||||
password: Optional[str] = None
|
||||
) -> Conversation:
|
||||
"""获取会话消息"""
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
# 获取会话
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=release.app.workspace_id
|
||||
)
|
||||
|
||||
# 验证会话是否属于该应用
|
||||
if conversation.app_id != release.app_id:
|
||||
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
|
||||
|
||||
return conversation
|
||||
|
||||
def list_conversations(
|
||||
self,
|
||||
share_token: str,
|
||||
user_id: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 20
|
||||
) -> tuple[list[Conversation], int]:
|
||||
"""列出会话"""
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
conversations, total = self.conversation_service.list_conversations(
|
||||
app_id=release.app_id,
|
||||
workspace_id=release.app.workspace_id,
|
||||
user_id=user_id,
|
||||
is_draft=False, # 只显示发布版本的会话
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
return conversations, total
|
||||
|
||||
async def multi_agent_chat(
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""多 Agent 聊天(非流式)"""
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
# 获取发布版本和配置
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
# 获取或创建会话
|
||||
conversation = self.create_or_get_conversation(
|
||||
share_token=share_token,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
password=password
|
||||
)
|
||||
|
||||
# 获取多 Agent 配置
|
||||
multi_agent_config = self.db.query(MultiAgentConfig).filter(
|
||||
MultiAgentConfig.app_id == release.app_id,
|
||||
MultiAgentConfig.is_active == True
|
||||
).first()
|
||||
|
||||
if not multi_agent_config:
|
||||
raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 构建多 Agent 运行请求
|
||||
from app.schemas.multi_agent_schema import MultiAgentRunRequest
|
||||
|
||||
multi_agent_request = MultiAgentRunRequest(
|
||||
message=message,
|
||||
conversation_id=conversation.id,
|
||||
user_id=user_id,
|
||||
variables=variables,
|
||||
use_llm_routing=True,
|
||||
web_search=web_search,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# 使用多 Agent 服务执行
|
||||
multi_agent_service = MultiAgentService(self.db)
|
||||
result = await multi_agent_service.run(
|
||||
app_id=release.app_id,
|
||||
request=multi_agent_request
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="assistant",
|
||||
content=result.get("message", ""),
|
||||
meta_data={
|
||||
"mode": result.get("mode"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"sub_results": result.get("sub_results")
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"conversation_id": conversation.id,
|
||||
"message": result.get("message", ""),
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
},
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
async def multi_agent_chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""多 Agent 聊天(流式)"""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
try:
|
||||
# 获取发布版本和配置
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
# 获取或创建会话
|
||||
conversation = self.create_or_get_conversation(
|
||||
share_token=share_token,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
password=password
|
||||
)
|
||||
|
||||
# 获取多 Agent 配置
|
||||
multi_agent_config = self.db.query(MultiAgentConfig).filter(
|
||||
MultiAgentConfig.app_id == release.app_id,
|
||||
MultiAgentConfig.is_active == True
|
||||
).first()
|
||||
|
||||
if not multi_agent_config:
|
||||
raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 获取 storage_type 和 user_rag_memory_id
|
||||
workspace_id = release.app.workspace_id
|
||||
storage_type = 'neo4j' # 默认值
|
||||
user_rag_memory_id = ''
|
||||
|
||||
try:
|
||||
# 获取工作空间的存储类型(不需要用户权限检查,因为是公开分享)
|
||||
from app.models import Workspace
|
||||
workspace = self.db.get(Workspace, workspace_id)
|
||||
if workspace and workspace.storage_type:
|
||||
storage_type = workspace.storage_type
|
||||
|
||||
# 获取 USER_RAG_MERORY 知识库 ID
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=self.db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 storage_type 或 user_rag_memory_id 失败,使用默认值: {str(e)}")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 构建多 Agent 运行请求
|
||||
from app.schemas.multi_agent_schema import MultiAgentRunRequest
|
||||
|
||||
multi_agent_request = MultiAgentRunRequest(
|
||||
message=message,
|
||||
conversation_id=conversation.id,
|
||||
user_id=user_id,
|
||||
variables=variables,
|
||||
use_llm_routing=True,
|
||||
web_search=web_search,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
# 使用多 Agent 服务流式执行
|
||||
multi_agent_service = MultiAgentService(self.db)
|
||||
full_content = ""
|
||||
|
||||
async for event in multi_agent_service.run_stream(
|
||||
app_id=release.app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
# 直接转发事件
|
||||
yield event
|
||||
|
||||
# 尝试提取内容(用于保存)
|
||||
if "data:" in event:
|
||||
try:
|
||||
data_line = event.split("data: ", 1)[1].strip()
|
||||
data = json.loads(data_line)
|
||||
if "content" in data:
|
||||
full_content += data["content"]
|
||||
except:
|
||||
pass
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="assistant",
|
||||
content=full_content,
|
||||
meta_data={
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"多 Agent 流式聊天完成",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"elapsed_time": elapsed_time,
|
||||
"message_length": len(full_content)
|
||||
}
|
||||
)
|
||||
|
||||
except (GeneratorExit, asyncio.CancelledError):
|
||||
# 生成器被关闭或任务被取消,正常退出
|
||||
logger.debug("多 Agent 流式聊天被中断")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"多 Agent 流式聊天失败: {str(e)}", exc_info=True)
|
||||
# 发送错误事件
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
426
api/app/services/smart_router.py
Normal file
426
api/app/services/smart_router.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""智能路由器 - 解决多轮对话路由错乱"""
|
||||
import re
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class SmartRouter:
|
||||
"""智能路由器
|
||||
|
||||
核心功能:
|
||||
1. 检测主题切换
|
||||
2. 判断是否应该继续使用当前 Agent
|
||||
3. 智能选择最合适的 Agent
|
||||
4. 支持强制重新路由
|
||||
"""
|
||||
|
||||
# 主题切换信号
|
||||
SWITCH_SIGNALS = [
|
||||
"换个话题", "另外", "还有", "对了",
|
||||
"那这个呢", "再问一个", "顺便问下",
|
||||
"我想问", "帮我", "请问", "换一个"
|
||||
]
|
||||
|
||||
# 延续信号
|
||||
CONTINUATION_SIGNALS = [
|
||||
"继续", "还是", "也", "同样", "类似",
|
||||
"这个", "那个", "它", "他", "她", "呢"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_manager: ConversationStateManager,
|
||||
routing_rules: List[Dict[str, Any]],
|
||||
sub_agents: Dict[str, Any]
|
||||
):
|
||||
"""初始化智能路由器
|
||||
|
||||
Args:
|
||||
state_manager: 会话状态管理器
|
||||
routing_rules: 路由规则列表
|
||||
sub_agents: 子 Agent 配置字典
|
||||
"""
|
||||
self.state_manager = state_manager
|
||||
self.routing_rules = routing_rules
|
||||
self.sub_agents = sub_agents
|
||||
|
||||
# 配置参数
|
||||
self.min_confidence_for_switch = 0.7 # 切换 Agent 的最小置信度
|
||||
self.max_same_agent_turns = 10 # 同一 Agent 最大连续轮数
|
||||
|
||||
async def route(
|
||||
self,
|
||||
message: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
force_new: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""智能路由
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
conversation_id: 会话 ID
|
||||
force_new: 是否强制重新路由(忽略历史)
|
||||
|
||||
Returns:
|
||||
路由结果 {
|
||||
"agent_id": str,
|
||||
"confidence": float,
|
||||
"strategy": str,
|
||||
"topic": str,
|
||||
"topic_changed": bool,
|
||||
"reason": str
|
||||
}
|
||||
"""
|
||||
logger.info(
|
||||
f"开始智能路由",
|
||||
extra={
|
||||
"message_length": len(message),
|
||||
"conversation_id": conversation_id,
|
||||
"force_new": force_new
|
||||
}
|
||||
)
|
||||
|
||||
# 1. 获取会话状态
|
||||
state = None
|
||||
if conversation_id and not force_new:
|
||||
state = self.state_manager.get_state(conversation_id)
|
||||
|
||||
# 2. 检测主题切换
|
||||
topic_changed = self._detect_topic_change(message, state)
|
||||
|
||||
# 3. 提取当前主题
|
||||
topic = self._extract_topic(message)
|
||||
|
||||
# 4. 选择路由策略
|
||||
if force_new:
|
||||
# 强制重新路由
|
||||
agent_id, confidence = self._route_from_scratch(message)
|
||||
strategy = "force_new"
|
||||
reason = "用户强制重新路由"
|
||||
|
||||
elif not state or not state.get("current_agent_id"):
|
||||
# 新会话,从头路由
|
||||
agent_id, confidence = self._route_from_scratch(message)
|
||||
strategy = "new_conversation"
|
||||
reason = "新会话,首次路由"
|
||||
|
||||
elif topic_changed:
|
||||
# 主题切换,重新路由
|
||||
agent_id, confidence = self._route_from_scratch(message)
|
||||
strategy = "topic_changed"
|
||||
reason = f"检测到主题切换: {state.get('last_topic')} -> {topic}"
|
||||
|
||||
elif state.get("same_agent_turns", 0) >= self.max_same_agent_turns:
|
||||
# 同一 Agent 使用太久,强制重新评估
|
||||
agent_id, confidence = self._route_from_scratch(message)
|
||||
strategy = "max_turns_reached"
|
||||
reason = f"同一 Agent 已使用 {state['same_agent_turns']} 轮"
|
||||
|
||||
else:
|
||||
# 检查是否应该继续使用当前 Agent
|
||||
current_agent_id = state["current_agent_id"]
|
||||
should_continue, continue_confidence = self._should_continue_current_agent(
|
||||
message,
|
||||
current_agent_id
|
||||
)
|
||||
|
||||
if should_continue:
|
||||
# 继续使用当前 Agent
|
||||
agent_id = current_agent_id
|
||||
confidence = continue_confidence
|
||||
strategy = "continue_current"
|
||||
reason = "消息在当前 Agent 能力范围内"
|
||||
else:
|
||||
# 重新路由
|
||||
new_agent_id, new_confidence = self._route_from_scratch(message)
|
||||
|
||||
# 只有新 Agent 的置信度明显更高时才切换
|
||||
if new_confidence > continue_confidence + self.min_confidence_for_switch:
|
||||
agent_id = new_agent_id
|
||||
confidence = new_confidence
|
||||
strategy = "switch_agent"
|
||||
reason = f"新 Agent 置信度更高: {new_confidence:.2f} vs {continue_confidence:.2f}"
|
||||
else:
|
||||
# 置信度差距不大,继续使用当前 Agent
|
||||
agent_id = current_agent_id
|
||||
confidence = continue_confidence
|
||||
strategy = "keep_current"
|
||||
reason = "置信度差距不足以切换 Agent"
|
||||
|
||||
# 5. 更新会话状态
|
||||
if conversation_id:
|
||||
self.state_manager.update_state(
|
||||
conversation_id,
|
||||
agent_id,
|
||||
message,
|
||||
topic,
|
||||
confidence
|
||||
)
|
||||
|
||||
result = {
|
||||
"agent_id": agent_id,
|
||||
"confidence": confidence,
|
||||
"strategy": strategy,
|
||||
"topic": topic,
|
||||
"topic_changed": topic_changed,
|
||||
"reason": reason
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"路由完成",
|
||||
extra={
|
||||
"agent_id": agent_id,
|
||||
"strategy": strategy,
|
||||
"confidence": confidence,
|
||||
"topic": topic
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _detect_topic_change(
|
||||
self,
|
||||
message: str,
|
||||
state: Optional[Dict[str, Any]]
|
||||
) -> bool:
|
||||
"""检测主题是否切换
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
state: 会话状态
|
||||
|
||||
Returns:
|
||||
是否切换主题
|
||||
"""
|
||||
if not state or not state.get("last_topic"):
|
||||
return False
|
||||
|
||||
# 检查明确的切换信号
|
||||
for signal in self.SWITCH_SIGNALS:
|
||||
if signal in message:
|
||||
logger.info(f"检测到主题切换信号: {signal}")
|
||||
return True
|
||||
|
||||
# 比较主题
|
||||
current_topic = self._extract_topic(message)
|
||||
last_topic = state.get("last_topic")
|
||||
|
||||
if current_topic != last_topic and current_topic != "其他":
|
||||
logger.info(f"主题变化: {last_topic} -> {current_topic}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _should_continue_current_agent(
|
||||
self,
|
||||
message: str,
|
||||
current_agent_id: str
|
||||
) -> Tuple[bool, float]:
|
||||
"""判断是否应该继续使用当前 Agent
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
current_agent_id: 当前 Agent ID
|
||||
|
||||
Returns:
|
||||
(是否继续, 置信度)
|
||||
"""
|
||||
# 检查延续信号
|
||||
has_continuation_signal = any(
|
||||
signal in message
|
||||
for signal in self.CONTINUATION_SIGNALS
|
||||
)
|
||||
|
||||
# 计算当前 Agent 对消息的匹配度
|
||||
current_score = self._calculate_agent_score(message, current_agent_id)
|
||||
|
||||
# 如果有延续信号且匹配度不太低,继续使用
|
||||
if has_continuation_signal and current_score > 0.3:
|
||||
return True, min(current_score + 0.2, 1.0)
|
||||
|
||||
# 如果匹配度高,继续使用
|
||||
if current_score > 0.6:
|
||||
return True, current_score
|
||||
|
||||
return False, current_score
|
||||
|
||||
def _route_from_scratch(self, message: str) -> Tuple[str, float]:
|
||||
"""从头开始路由(不考虑历史)
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
(Agent ID, 置信度)
|
||||
"""
|
||||
best_agent_id = None
|
||||
best_score = 0.0
|
||||
|
||||
# 遍历所有路由规则
|
||||
for rule in self.routing_rules:
|
||||
score = self._calculate_rule_score(message, rule)
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_agent_id = rule.get("target_agent_id")
|
||||
|
||||
# 如果没有匹配的规则,使用默认 Agent
|
||||
if not best_agent_id or best_score < 0.3:
|
||||
best_agent_id = self._get_default_agent_id()
|
||||
best_score = 0.5
|
||||
logger.warning(f"未找到匹配规则,使用默认 Agent: {best_agent_id}")
|
||||
|
||||
return best_agent_id, best_score
|
||||
|
||||
def _calculate_rule_score(
|
||||
self,
|
||||
message: str,
|
||||
rule: Dict[str, Any]
|
||||
) -> float:
|
||||
"""计算规则匹配分数
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
rule: 路由规则
|
||||
|
||||
Returns:
|
||||
匹配分数 (0-1)
|
||||
"""
|
||||
score = 0.0
|
||||
message_lower = message.lower()
|
||||
|
||||
# 1. 关键词匹配 (权重 0.6)
|
||||
keywords = rule.get("keywords", [])
|
||||
if keywords:
|
||||
matched_keywords = sum(
|
||||
1 for keyword in keywords
|
||||
if keyword.lower() in message_lower
|
||||
)
|
||||
keyword_score = matched_keywords / len(keywords)
|
||||
score += keyword_score * 0.6
|
||||
|
||||
# 2. 正则匹配 (权重 0.3)
|
||||
patterns = rule.get("patterns", [])
|
||||
if patterns:
|
||||
matched_patterns = sum(
|
||||
1 for pattern in patterns
|
||||
if re.search(pattern, message, re.IGNORECASE)
|
||||
)
|
||||
pattern_score = matched_patterns / len(patterns)
|
||||
score += pattern_score * 0.3
|
||||
|
||||
# 3. 排除关键词 (负分)
|
||||
exclude_keywords = rule.get("exclude_keywords", [])
|
||||
if exclude_keywords:
|
||||
has_exclude = any(
|
||||
keyword.lower() in message_lower
|
||||
for keyword in exclude_keywords
|
||||
)
|
||||
if has_exclude:
|
||||
score *= 0.5 # 减半
|
||||
|
||||
# 4. 最小关键词数量要求
|
||||
min_keyword_count = rule.get("min_keyword_count", 0)
|
||||
if keywords and min_keyword_count > 0:
|
||||
matched_count = sum(
|
||||
1 for keyword in keywords
|
||||
if keyword.lower() in message_lower
|
||||
)
|
||||
if matched_count < min_keyword_count:
|
||||
score *= 0.7 # 惩罚
|
||||
|
||||
return min(score, 1.0)
|
||||
|
||||
def _calculate_agent_score(
|
||||
self,
|
||||
message: str,
|
||||
agent_id: str
|
||||
) -> float:
|
||||
"""计算 Agent 对消息的匹配分数
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
匹配分数 (0-1)
|
||||
"""
|
||||
# 找到该 Agent 对应的所有规则
|
||||
agent_rules = [
|
||||
rule for rule in self.routing_rules
|
||||
if rule.get("target_agent_id") == agent_id
|
||||
]
|
||||
|
||||
if not agent_rules:
|
||||
return 0.0
|
||||
|
||||
# 返回最高分数
|
||||
max_score = max(
|
||||
self._calculate_rule_score(message, rule)
|
||||
for rule in agent_rules
|
||||
)
|
||||
|
||||
return max_score
|
||||
|
||||
def _extract_topic(self, message: str) -> str:
|
||||
"""提取消息主题
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
主题名称
|
||||
"""
|
||||
# 主题关键词映射
|
||||
topic_keywords = {
|
||||
"数学": ["数学", "方程", "计算", "求解", "x", "y", "函数", "几何"],
|
||||
"物理": ["物理", "力", "速度", "加速度", "能量", "功率", "电路"],
|
||||
"化学": ["化学", "方程式", "反应", "元素", "分子", "原子", "化合物"],
|
||||
"语文": ["语文", "古诗", "作文", "阅读", "文言文", "诗词"],
|
||||
"英语": ["英语", "单词", "语法", "翻译", "时态", "句型"],
|
||||
"历史": ["历史", "朝代", "事件", "人物", "战争", "革命"],
|
||||
"作业": ["作业", "批改", "检查", "评分", "反馈"],
|
||||
"学习规划": ["计划", "规划", "方法", "技巧", "时间", "安排"],
|
||||
"订单": ["订单", "发货", "物流", "配送", "快递"],
|
||||
"退款": ["退款", "退货", "售后", "换货", "维修"],
|
||||
"账户": ["账户", "密码", "登录", "注册", "绑定"],
|
||||
"支付": ["支付", "付款", "充值", "余额", "优惠券"]
|
||||
}
|
||||
|
||||
message_lower = message.lower()
|
||||
|
||||
# 统计每个主题的匹配度
|
||||
topic_scores = {}
|
||||
for topic, keywords in topic_keywords.items():
|
||||
matched = sum(
|
||||
1 for keyword in keywords
|
||||
if keyword in message_lower
|
||||
)
|
||||
if matched > 0:
|
||||
topic_scores[topic] = matched
|
||||
|
||||
# 返回匹配度最高的主题
|
||||
if topic_scores:
|
||||
best_topic = max(topic_scores.items(), key=lambda x: x[1])[0]
|
||||
return best_topic
|
||||
|
||||
return "其他"
|
||||
|
||||
def _get_default_agent_id(self) -> str:
|
||||
"""获取默认 Agent ID
|
||||
|
||||
Returns:
|
||||
默认 Agent ID
|
||||
"""
|
||||
# 优先使用第一个路由规则的 Agent
|
||||
if self.routing_rules:
|
||||
return self.routing_rules[0].get("target_agent_id")
|
||||
|
||||
# 否则使用第一个子 Agent
|
||||
if self.sub_agents:
|
||||
return list(self.sub_agents.keys())[0]
|
||||
|
||||
return "default-agent"
|
||||
52
api/app/services/task_service.py
Normal file
52
api/app/services/task_service.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from app.celery_app import celery_app
|
||||
|
||||
def create_processing_task(item_data: dict) -> str:
|
||||
"""
|
||||
Sends a task to the Celery queue to process an item.
|
||||
|
||||
:param item_data: The dictionary representation of the item.
|
||||
:return: The ID of the created task.
|
||||
"""
|
||||
task = celery_app.send_task("tasks.process_item", args=[item_data])
|
||||
return task.id
|
||||
|
||||
def get_task_result(task_id: str) -> dict:
|
||||
"""
|
||||
Checks the status and result of a Celery task.
|
||||
|
||||
:param task_id: The ID of the task to check.
|
||||
:return: A dictionary with the task's status and result (if ready).
|
||||
"""
|
||||
result = celery_app.AsyncResult(task_id)
|
||||
|
||||
if result.ready():
|
||||
return {"status": result.status, "result": result.get()}
|
||||
|
||||
return {"status": result.status}
|
||||
def get_task_memory_read_result(task_id: str) -> dict:
|
||||
"""
|
||||
Checks the status and result of a memory read task.
|
||||
|
||||
:param task_id: The ID of the task to check.
|
||||
:return: A dictionary with the task's status and result (if ready).
|
||||
"""
|
||||
result = celery_app.AsyncResult(task_id)
|
||||
|
||||
if result.ready():
|
||||
return {"status": result.status, "result": result.get()}
|
||||
|
||||
return {"status": result.status}
|
||||
|
||||
def get_task_memory_write_result(task_id: str) -> dict:
|
||||
"""
|
||||
Checks the status and result of a memory write task.
|
||||
|
||||
:param task_id: The ID of the task to check.
|
||||
:return: A dictionary with the task's status and result (if ready).
|
||||
"""
|
||||
result = celery_app.AsyncResult(task_id)
|
||||
|
||||
if result.ready():
|
||||
return {"status": result.status, "result": result.get()}
|
||||
|
||||
return {"status": result.status}
|
||||
220
api/app/services/tenant_service.py
Normal file
220
api/app/services/tenant_service.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories.tenant_repository import TenantRepository
|
||||
from app.repositories.user_repository import UserRepository
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.schemas.tenant_schema import (
|
||||
TenantCreate, TenantUpdate, Tenant, TenantQuery, TenantList
|
||||
)
|
||||
from app.schemas.user_schema import User
|
||||
from app.schemas.workspace_schema import WorkspaceCreate
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.models.user_model import User as UserModel
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
# 获取业务逻辑专用日志器
|
||||
business_logger = get_business_logger()
|
||||
|
||||
class TenantService:
|
||||
"""租户业务逻辑层"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.tenant_repo = TenantRepository(db)
|
||||
self.user_repo = UserRepository(db)
|
||||
self.workspace_repo = WorkspaceRepository(db)
|
||||
|
||||
def create_tenant(self, tenant_data: TenantCreate) -> Tenants:
|
||||
"""创建租户"""
|
||||
# 检查租户名称是否已存在
|
||||
existing_tenant = self.tenant_repo.get_tenant_by_name(tenant_data.name)
|
||||
if existing_tenant:
|
||||
raise BusinessException(f"租户名称 '{tenant_data.name}' 已存在", code=BizCode.DUPLICATE_NAME)
|
||||
|
||||
try:
|
||||
tenant = self.tenant_repo.create_tenant(tenant_data)
|
||||
business_logger.info(f"创建租户成功: {tenant.name} (ID: {tenant.id})")
|
||||
return tenant
|
||||
except Exception as e:
|
||||
business_logger.error(f"创建租户失败: {str(e)}")
|
||||
raise BusinessException(f"创建租户失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
def create_tenant_and_assign_user(self, tenant_data: TenantCreate, user_id: uuid.UUID) -> Tenants:
|
||||
"""创建租户并分配用户"""
|
||||
try:
|
||||
# 创建租户
|
||||
tenant = self.create_tenant(tenant_data)
|
||||
|
||||
# 将用户分配给租户
|
||||
success = self.user_repo.assign_user_to_tenant(user_id, tenant.id)
|
||||
if not success:
|
||||
raise BusinessException("分配用户到租户失败", code=BizCode.STATE_CONFLICT)
|
||||
|
||||
business_logger.info(f"创建租户并分配用户成功: {tenant.name}")
|
||||
return tenant
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"创建租户和分配用户失败: {str(e)}")
|
||||
self.db.rollback()
|
||||
raise BusinessException(f"创建租户失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
def get_tenant(self, tenant_id: uuid.UUID) -> Optional[Tenants]:
|
||||
"""获取租户"""
|
||||
return self.tenant_repo.get_tenant_by_id(tenant_id)
|
||||
|
||||
def get_tenant_by_name(self, name: str) -> Optional[Tenants]:
|
||||
"""根据名称获取租户"""
|
||||
return self.tenant_repo.get_tenant_by_name(name)
|
||||
|
||||
def get_tenants(self, query: TenantQuery) -> TenantList:
|
||||
"""获取租户列表"""
|
||||
skip = (query.page - 1) * query.size
|
||||
|
||||
tenants = self.tenant_repo.get_tenants(
|
||||
skip=skip,
|
||||
limit=query.size,
|
||||
is_active=query.is_active,
|
||||
search=query.search
|
||||
)
|
||||
|
||||
total = self.tenant_repo.count_tenants(
|
||||
is_active=query.is_active,
|
||||
search=query.search
|
||||
)
|
||||
|
||||
pages = (total + query.size - 1) // query.size
|
||||
|
||||
return TenantList(
|
||||
items=[Tenant.model_validate(tenant) for tenant in tenants],
|
||||
total=total,
|
||||
page=query.page,
|
||||
size=query.size,
|
||||
pages=pages
|
||||
)
|
||||
|
||||
def update_tenant(self, tenant_id: uuid.UUID, tenant_data: TenantUpdate) -> Optional[Tenants]:
|
||||
"""更新租户"""
|
||||
# 如果更新名称,检查是否重复
|
||||
if tenant_data.name:
|
||||
existing_tenant = self.tenant_repo.get_tenant_by_name(tenant_data.name)
|
||||
if existing_tenant and existing_tenant.id != tenant_id:
|
||||
raise BusinessException(f"租户名称 '{tenant_data.name}' 已存在", code=BizCode.DUPLICATE_NAME)
|
||||
|
||||
try:
|
||||
tenant = self.tenant_repo.update_tenant(tenant_id, tenant_data)
|
||||
if tenant:
|
||||
business_logger.info(f"更新租户成功: {tenant.name} (ID: {tenant.id})")
|
||||
return tenant
|
||||
except Exception as e:
|
||||
business_logger.error(f"更新租户失败: {str(e)}")
|
||||
raise BusinessException(f"更新租户失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
def delete_tenant(self, tenant_id: uuid.UUID) -> bool:
|
||||
"""删除租户"""
|
||||
try:
|
||||
# 检查租户是否存在
|
||||
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
|
||||
if not tenant:
|
||||
return False
|
||||
|
||||
# 检查是否有关联的用户
|
||||
users = self.tenant_repo.get_tenant_users(tenant_id)
|
||||
if users:
|
||||
raise BusinessException("无法删除租户,存在关联的用户", code=BizCode.STATE_CONFLICT)
|
||||
|
||||
# 检查是否有关联的工作空间
|
||||
workspaces = self.workspace_repo.get_workspaces_by_tenant(tenant_id)
|
||||
if workspaces:
|
||||
raise BusinessException("无法删除租户,存在关联的工作空间", code=BizCode.STATE_CONFLICT)
|
||||
|
||||
success = self.tenant_repo.delete_tenant(tenant_id)
|
||||
if success:
|
||||
business_logger.info(f"删除租户成功: {tenant.name} (ID: {tenant.id})")
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"删除租户失败: {str(e)}")
|
||||
raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
# 租户用户管理
|
||||
def get_tenant_users(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[UserModel]:
|
||||
"""获取租户下的用户列表"""
|
||||
return self.user_repo.get_users_by_tenant(
|
||||
tenant_id=tenant_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
search=search
|
||||
)
|
||||
|
||||
def count_tenant_users(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
is_active: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> int:
|
||||
"""统计租户下的用户数量"""
|
||||
return self.user_repo.count_users_by_tenant(
|
||||
tenant_id=tenant_id,
|
||||
is_active=is_active,
|
||||
search=search
|
||||
)
|
||||
|
||||
def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
|
||||
"""将用户分配给租户"""
|
||||
# 检查租户是否存在
|
||||
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
|
||||
if not tenant:
|
||||
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
|
||||
|
||||
try:
|
||||
success = self.user_repo.assign_user_to_tenant(user_id, tenant_id)
|
||||
if success:
|
||||
business_logger.info(f"分配用户到租户成功: 用户ID {user_id}, 租户ID {tenant_id}")
|
||||
return success
|
||||
except Exception as e:
|
||||
business_logger.error(f"分配用户到租户失败: {str(e)}")
|
||||
raise BusinessException(f"分配用户到租户失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
def get_user_tenant(self, user_id: uuid.UUID) -> Optional[Tenants]:
|
||||
"""获取用户所属的租户"""
|
||||
return self.tenant_repo.get_user_tenant(user_id)
|
||||
|
||||
def remove_user_from_tenant(self, user_id: uuid.UUID) -> bool:
|
||||
"""将用户从租户中移除(设置tenant_id为None)"""
|
||||
try:
|
||||
user = self.user_repo.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
success = self.user_repo.assign_user_to_tenant(user_id, None)
|
||||
if success:
|
||||
business_logger.info(f"移除用户租户关联成功: 用户ID {user_id}")
|
||||
return success
|
||||
except Exception as e:
|
||||
business_logger.error(f"移除用户租户关联失败: {str(e)}")
|
||||
raise BusinessException(f"移除用户租户关联失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
def get_users_without_tenant(
|
||||
self,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None
|
||||
) -> List[UserModel]:
|
||||
"""获取没有租户的用户列表"""
|
||||
return self.user_repo.get_users_without_tenant(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active
|
||||
)
|
||||
617
api/app/services/upload_service.py
Normal file
617
api/app/services/upload_service.py
Normal file
@@ -0,0 +1,617 @@
|
||||
"""
|
||||
Upload Service for Generic File Upload System
|
||||
Handles file upload, storage, access, deletion, and metadata updates.
|
||||
"""
|
||||
import os
|
||||
import uuid
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import UploadFile
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.generic_file_model import GenericFile
|
||||
from app.repositories.generic_file_repository import GenericFileRepository
|
||||
from app.core.upload_enums import UploadContext
|
||||
from app.core.storage_strategy import StrategyFactory
|
||||
from app.core.validators.file_validator import FileValidator
|
||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.uow import IUnitOfWork
|
||||
from app.core.compensation import CompensationHandler
|
||||
|
||||
# Get logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FileNotFoundError(BusinessException):
|
||||
"""Exception raised when file is not found."""
|
||||
def __init__(self, file_id: uuid.UUID):
|
||||
super().__init__(
|
||||
f"文件 {file_id} 不存在",
|
||||
code=BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
class FileAccessDeniedError(BusinessException):
|
||||
"""Exception raised when file access is denied."""
|
||||
def __init__(self, file_id: uuid.UUID):
|
||||
super().__init__(
|
||||
f"无权访问文件 {file_id}",
|
||||
code=BizCode.FORBIDDEN
|
||||
)
|
||||
|
||||
|
||||
class FileStorageError(BusinessException):
|
||||
"""Exception raised when file storage fails."""
|
||||
def __init__(self, reason: str):
|
||||
super().__init__(
|
||||
f"文件存储失败: {reason}",
|
||||
code=BizCode.INTERNAL_ERROR
|
||||
)
|
||||
|
||||
|
||||
class FileReferencedError(BusinessException):
|
||||
"""Exception raised when trying to delete a referenced file."""
|
||||
def __init__(self, file_id: uuid.UUID, reference_count: int):
|
||||
super().__init__(
|
||||
f"文件 {file_id} 被 {reference_count} 个资源引用,无法删除",
|
||||
code=BizCode.BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
class UploadResult:
|
||||
"""Result of a file upload operation."""
|
||||
def __init__(self, success: bool, file_id: Optional[uuid.UUID] = None,
|
||||
file_name: str = "", error: Optional[str] = None):
|
||||
self.success = success
|
||||
self.file_id = file_id
|
||||
self.file_name = file_name
|
||||
self.error = error
|
||||
|
||||
|
||||
class UploadService:
|
||||
"""
|
||||
Service for handling file uploads and management.
|
||||
Coordinates validation, storage, and database operations.
|
||||
Uses Unit of Work pattern for transaction management.
|
||||
"""
|
||||
|
||||
def __init__(self, uow: IUnitOfWork = None):
|
||||
self.validator = FileValidator()
|
||||
self.uow = uow
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
file: UploadFile,
|
||||
context: UploadContext,
|
||||
metadata: Optional[Dict[str, Any]],
|
||||
current_user: User,
|
||||
db: Session = None
|
||||
) -> GenericFile:
|
||||
"""
|
||||
Upload a single file using Unit of Work pattern with compensation transactions.
|
||||
|
||||
Args:
|
||||
file: The uploaded file
|
||||
context: Upload context (avatar, app_icon, etc.)
|
||||
metadata: Additional metadata for the file
|
||||
current_user: The user uploading the file
|
||||
db: Optional database session (for backward compatibility)
|
||||
|
||||
Returns:
|
||||
GenericFile: The created file record
|
||||
|
||||
Raises:
|
||||
FileSizeExceededError: If file size exceeds limit
|
||||
FileTypeNotAllowedError: If file type is not allowed
|
||||
EmptyFileError: If file is empty
|
||||
FileStorageError: If file storage fails
|
||||
"""
|
||||
logger.info(f"Starting file upload: filename={file.filename}, context={context}, user={current_user.id}")
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
# Get storage strategy for this context
|
||||
strategy = StrategyFactory.get_strategy(context)
|
||||
upload_policy = strategy.get_upload_policy()
|
||||
|
||||
# Validate file against upload policy
|
||||
logger.debug(f"Validating file: {file.filename}")
|
||||
self.validator.validate_and_raise(file, upload_policy)
|
||||
|
||||
# Generate file ID
|
||||
file_id = uuid.uuid4()
|
||||
|
||||
# Extract file information
|
||||
filename = file.filename or "unknown"
|
||||
file_extension = ""
|
||||
if "." in filename:
|
||||
file_extension = "." + filename.rsplit(".", 1)[1].lower()
|
||||
|
||||
# Get file size
|
||||
file.file.seek(0, 2)
|
||||
file_size = file.file.tell()
|
||||
file.file.seek(0)
|
||||
|
||||
# Get storage path
|
||||
storage_path = strategy.get_storage_path(
|
||||
tenant_id=current_user.tenant_id,
|
||||
file_id=file_id,
|
||||
file_extension=file_extension,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
logger.debug(f"Storage path: {storage_path}")
|
||||
|
||||
# Use Unit of Work pattern with compensation handler
|
||||
compensation = CompensationHandler()
|
||||
|
||||
try:
|
||||
# Use provided UoW or create a new one for backward compatibility
|
||||
if self.uow:
|
||||
uow = self.uow
|
||||
should_manage_context = False
|
||||
else:
|
||||
# Backward compatibility: use provided db session
|
||||
if db:
|
||||
# Create a temporary UoW wrapper for the existing session
|
||||
from app.core.uow import SqlAlchemyUnitOfWork
|
||||
uow = SqlAlchemyUnitOfWork(lambda: db)
|
||||
uow._session = db
|
||||
uow.files = GenericFileRepository(db)
|
||||
should_manage_context = False
|
||||
else:
|
||||
raise FileStorageError("Either uow or db session must be provided")
|
||||
|
||||
# 1. Save physical file
|
||||
self._save_physical_file(file, storage_path)
|
||||
|
||||
# Register compensation: delete physical file if database operation fails
|
||||
compensation.register(lambda: self._delete_physical_file(storage_path))
|
||||
|
||||
# 2. Generate access URL
|
||||
access_url = None
|
||||
if context in [UploadContext.AVATAR, UploadContext.APP_ICON]:
|
||||
access_url = f"{settings.FILE_ACCESS_URL_PREFIX}/{file_id}"
|
||||
|
||||
# 3. Create file data
|
||||
file_data = {
|
||||
"id": file_id,
|
||||
"tenant_id": current_user.tenant_id,
|
||||
"created_by": current_user.id,
|
||||
"file_name": filename,
|
||||
"file_ext": file_extension,
|
||||
"file_size": file_size,
|
||||
"mime_type": file.content_type,
|
||||
"context": context.value,
|
||||
"storage_path": str(storage_path),
|
||||
"file_metadata": metadata,
|
||||
"status": "active",
|
||||
"is_public": metadata.get("is_public", False),
|
||||
"access_url": access_url,
|
||||
"reference_count": 0,
|
||||
}
|
||||
|
||||
# 4. Create database record
|
||||
db_file = uow.files.create_file(file_data)
|
||||
|
||||
# 5. Commit transaction (only if we're managing the session)
|
||||
if should_manage_context:
|
||||
uow.commit()
|
||||
elif db:
|
||||
db.commit()
|
||||
|
||||
# Success - clear compensation operations
|
||||
compensation.clear()
|
||||
|
||||
logger.info(f"File upload completed successfully: {filename} (ID: {file_id})")
|
||||
return db_file
|
||||
|
||||
except Exception as e:
|
||||
# Execute compensation operations
|
||||
compensation.execute()
|
||||
|
||||
# Rollback if we're managing the session
|
||||
if db:
|
||||
db.rollback()
|
||||
|
||||
logger.error(f"File upload failed: {str(e)}")
|
||||
raise FileStorageError(f"文件上传失败: {str(e)}")
|
||||
|
||||
def _save_physical_file(self, file: UploadFile, storage_path: Path):
|
||||
"""
|
||||
Save physical file to filesystem.
|
||||
|
||||
Args:
|
||||
file: The uploaded file
|
||||
storage_path: Path where file should be saved
|
||||
|
||||
Raises:
|
||||
FileStorageError: If file save fails
|
||||
"""
|
||||
try:
|
||||
# Create directory if it doesn't exist
|
||||
storage_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save file
|
||||
with open(storage_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
logger.info(f"File saved to filesystem: {storage_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save file to filesystem: {str(e)}")
|
||||
raise FileStorageError(f"无法保存文件到磁盘: {str(e)}")
|
||||
|
||||
def _delete_physical_file(self, storage_path: Path):
|
||||
"""
|
||||
Delete physical file (compensation operation).
|
||||
|
||||
Args:
|
||||
storage_path: Path of file to delete
|
||||
"""
|
||||
try:
|
||||
if os.path.exists(storage_path):
|
||||
os.remove(storage_path)
|
||||
logger.info(f"补偿操作:删除文件 {storage_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"删除文件失败: {e}")
|
||||
|
||||
def _restore_file_from_backup(self, backup_path: Path, original_path: Path):
|
||||
"""
|
||||
Restore file from backup (compensation operation).
|
||||
|
||||
Args:
|
||||
backup_path: Path of backup file
|
||||
original_path: Path where file should be restored
|
||||
"""
|
||||
try:
|
||||
if backup_path.exists():
|
||||
shutil.copy2(backup_path, original_path)
|
||||
logger.info(f"补偿操作:从备份恢复文件 {original_path}")
|
||||
# Clean up backup after restoration
|
||||
os.remove(backup_path)
|
||||
logger.debug(f"补偿操作:删除备份文件 {backup_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"恢复文件失败: {e}")
|
||||
|
||||
def upload_files_batch(
|
||||
self,
|
||||
files: List[UploadFile],
|
||||
context: UploadContext,
|
||||
metadata: Optional[Dict[str, Any]],
|
||||
current_user: User,
|
||||
db: Session = None
|
||||
) -> List[UploadResult]:
|
||||
"""
|
||||
Upload multiple files in batch.
|
||||
Individual file failures do not affect other files.
|
||||
|
||||
Args:
|
||||
files: List of uploaded files
|
||||
context: Upload context (avatar, app_icon, etc.)
|
||||
metadata: Additional metadata for the files
|
||||
current_user: The user uploading the files
|
||||
db: Optional database session (for backward compatibility)
|
||||
|
||||
Returns:
|
||||
List[UploadResult]: List of upload results for each file
|
||||
|
||||
Raises:
|
||||
BusinessException: If batch size exceeds limit
|
||||
"""
|
||||
logger.info(f"Starting batch upload: {len(files)} files, context={context}, user={current_user.id}")
|
||||
|
||||
# Validate batch size
|
||||
MAX_BATCH_SIZE = 20
|
||||
if len(files) > MAX_BATCH_SIZE:
|
||||
raise BusinessException(
|
||||
f"批量上传文件数量不能超过 {MAX_BATCH_SIZE} 个",
|
||||
code=BizCode.BAD_REQUEST,
|
||||
context={
|
||||
"file_count": len(files),
|
||||
"max_batch_size": MAX_BATCH_SIZE,
|
||||
"user_id": str(current_user.id),
|
||||
"tenant_id": str(current_user.tenant_id),
|
||||
"context": context
|
||||
}
|
||||
)
|
||||
|
||||
results = []
|
||||
|
||||
for file in files:
|
||||
try:
|
||||
# Upload each file independently
|
||||
db_file = self.upload_file(file, context, metadata, current_user, db)
|
||||
|
||||
results.append(UploadResult(
|
||||
success=True,
|
||||
file_id=db_file.id,
|
||||
file_name=file.filename or "unknown",
|
||||
error=None
|
||||
))
|
||||
|
||||
logger.info(f"Batch upload success: {file.filename}")
|
||||
|
||||
except Exception as e:
|
||||
# Log error but continue with other files
|
||||
logger.error(f"Batch upload failed for {file.filename}: {str(e)}")
|
||||
|
||||
results.append(UploadResult(
|
||||
success=False,
|
||||
file_id=None,
|
||||
file_name=file.filename or "unknown",
|
||||
error=str(e)
|
||||
))
|
||||
|
||||
logger.info(f"Batch upload completed: {sum(1 for r in results if r.success)}/{len(files)} successful")
|
||||
return results
|
||||
|
||||
def get_file(
|
||||
self,
|
||||
file_id: uuid.UUID,
|
||||
current_user: User,
|
||||
db: Session = None
|
||||
) -> GenericFile:
|
||||
"""
|
||||
Get a file by ID with permission validation.
|
||||
|
||||
Args:
|
||||
file_id: UUID of the file
|
||||
current_user: The user requesting the file
|
||||
db: Optional database session (for backward compatibility)
|
||||
|
||||
Returns:
|
||||
GenericFile: The file record
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
FileAccessDeniedError: If user doesn't have permission
|
||||
"""
|
||||
logger.debug(f"Getting file: file_id={file_id}, user={current_user.id}")
|
||||
|
||||
# Use UoW or provided db session
|
||||
if self.uow:
|
||||
with self.uow:
|
||||
file = self.uow.files.get_file_by_id(file_id)
|
||||
elif db:
|
||||
repository = GenericFileRepository(db)
|
||||
file = repository.get_file_by_id(file_id)
|
||||
else:
|
||||
raise FileStorageError("Either uow or db session must be provided")
|
||||
|
||||
if not file:
|
||||
logger.warning(f"File not found: {file_id}")
|
||||
raise FileNotFoundError(file_id)
|
||||
|
||||
# Check permissions using permission service
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
|
||||
subject = Subject.from_user(current_user)
|
||||
resource = Resource.from_file(file)
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
Action.READ,
|
||||
resource,
|
||||
error_message=f"无权访问文件 {file_id}"
|
||||
)
|
||||
except PermissionDeniedException:
|
||||
logger.warning(f"Access denied: file_id={file_id}, user={current_user.id}")
|
||||
raise FileAccessDeniedError(file_id)
|
||||
|
||||
logger.debug(f"File access granted: {file.file_name}")
|
||||
return file
|
||||
|
||||
def delete_file(
|
||||
self,
|
||||
file_id: uuid.UUID,
|
||||
current_user: User,
|
||||
db: Session = None
|
||||
) -> None:
|
||||
"""
|
||||
Delete a file (both physical file and database record) using UoW pattern with compensation.
|
||||
|
||||
This method uses compensation transactions to ensure data consistency:
|
||||
1. Delete physical file first
|
||||
2. Register compensation to restore file if DB deletion fails
|
||||
3. Delete database record
|
||||
4. Commit transaction
|
||||
5. Clear compensation on success
|
||||
|
||||
Args:
|
||||
file_id: UUID of the file to delete
|
||||
current_user: The user requesting deletion
|
||||
db: Optional database session (for backward compatibility)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
FileAccessDeniedError: If user doesn't have permission
|
||||
FileReferencedError: If file is still referenced
|
||||
FileStorageError: If deletion fails
|
||||
"""
|
||||
logger.info(f"Deleting file: file_id={file_id}, user={current_user.id}")
|
||||
|
||||
# Get file and check permissions
|
||||
if self.uow:
|
||||
with self.uow:
|
||||
file = self.uow.files.get_file_by_id(file_id)
|
||||
elif db:
|
||||
repository = GenericFileRepository(db)
|
||||
file = repository.get_file_by_id(file_id)
|
||||
else:
|
||||
raise FileStorageError("Either uow or db session must be provided")
|
||||
|
||||
if not file:
|
||||
logger.warning(f"File not found for deletion: {file_id}")
|
||||
raise FileNotFoundError(file_id)
|
||||
|
||||
# Check permissions using permission service
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
|
||||
subject = Subject.from_user(current_user)
|
||||
resource = Resource.from_file(file)
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
Action.DELETE,
|
||||
resource,
|
||||
error_message=f"无权删除文件 {file_id}"
|
||||
)
|
||||
except PermissionDeniedException:
|
||||
logger.warning(f"Delete access denied: file_id={file_id}, user={current_user.id}")
|
||||
raise FileAccessDeniedError(file_id)
|
||||
|
||||
# Check reference count
|
||||
if file.reference_count > 0:
|
||||
logger.warning(f"Cannot delete referenced file: file_id={file_id}, references={file.reference_count}")
|
||||
raise FileReferencedError(file_id, file.reference_count)
|
||||
|
||||
# Store storage path and file content for potential restoration
|
||||
storage_path = Path(file.storage_path)
|
||||
backup_path = None
|
||||
|
||||
# Use compensation handler for atomic deletion
|
||||
compensation = CompensationHandler()
|
||||
|
||||
try:
|
||||
# 1. Backup and delete physical file first
|
||||
if storage_path.exists():
|
||||
# Create backup in temp location
|
||||
backup_path = storage_path.parent / f".backup_{file_id}{storage_path.suffix}"
|
||||
shutil.copy2(storage_path, backup_path)
|
||||
logger.debug(f"Created backup: {backup_path}")
|
||||
|
||||
# Delete original file
|
||||
os.remove(storage_path)
|
||||
logger.info(f"Physical file deleted: {storage_path}")
|
||||
|
||||
# Register compensation: restore file from backup if DB deletion fails
|
||||
compensation.register(lambda: self._restore_file_from_backup(backup_path, storage_path))
|
||||
else:
|
||||
logger.warning(f"Physical file not found: {storage_path}")
|
||||
|
||||
# 2. Delete database record (soft delete)
|
||||
if self.uow:
|
||||
with self.uow:
|
||||
self.uow.files.delete_file(file_id)
|
||||
self.uow.commit()
|
||||
elif db:
|
||||
repository = GenericFileRepository(db)
|
||||
repository.delete_file(file_id)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"File record deleted successfully: {file.file_name} (ID: {file_id})")
|
||||
|
||||
# 3. Success - clear compensations and remove backup
|
||||
compensation.clear()
|
||||
if backup_path and backup_path.exists():
|
||||
os.remove(backup_path)
|
||||
logger.debug(f"Removed backup: {backup_path}")
|
||||
|
||||
except Exception as e:
|
||||
# Execute compensation to restore file
|
||||
compensation.execute()
|
||||
|
||||
# Rollback database if using db session
|
||||
if db:
|
||||
db.rollback()
|
||||
|
||||
logger.error(f"Failed to delete file: {str(e)}")
|
||||
raise FileStorageError(f"无法删除文件: {str(e)}")
|
||||
|
||||
def update_file_metadata(
|
||||
self,
|
||||
file_id: uuid.UUID,
|
||||
update_data: Dict[str, Any],
|
||||
current_user: User,
|
||||
db: Session = None
|
||||
) -> GenericFile:
|
||||
"""
|
||||
Update file metadata using UoW pattern.
|
||||
|
||||
Args:
|
||||
file_id: UUID of the file to update
|
||||
update_data: Dictionary containing fields to update
|
||||
current_user: The user requesting the update
|
||||
db: Optional database session (for backward compatibility)
|
||||
|
||||
Returns:
|
||||
GenericFile: The updated file record
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
FileAccessDeniedError: If user doesn't have permission
|
||||
"""
|
||||
logger.info(f"Updating file metadata: file_id={file_id}, user={current_user.id}")
|
||||
|
||||
# Get file and check permissions
|
||||
if self.uow:
|
||||
with self.uow:
|
||||
file = self.uow.files.get_file_by_id(file_id)
|
||||
elif db:
|
||||
repository = GenericFileRepository(db)
|
||||
file = repository.get_file_by_id(file_id)
|
||||
else:
|
||||
raise FileStorageError("Either uow or db session must be provided")
|
||||
|
||||
if not file:
|
||||
logger.warning(f"File not found for update: {file_id}")
|
||||
raise FileNotFoundError(file_id)
|
||||
|
||||
# Check permissions using permission service
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
|
||||
subject = Subject.from_user(current_user)
|
||||
resource = Resource.from_file(file)
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
Action.UPDATE,
|
||||
resource,
|
||||
error_message=f"无权更新文件 {file_id}"
|
||||
)
|
||||
except PermissionDeniedException:
|
||||
logger.warning(f"Update access denied: file_id={file_id}, user={current_user.id}")
|
||||
raise FileAccessDeniedError(file_id)
|
||||
|
||||
# Filter allowed fields for update
|
||||
# Users can only update: file_name, file_metadata, is_public
|
||||
allowed_fields = ["file_name", "file_metadata", "is_public"]
|
||||
filtered_update_data = {
|
||||
key: value for key, value in update_data.items()
|
||||
if key in allowed_fields
|
||||
}
|
||||
|
||||
if not filtered_update_data:
|
||||
logger.warning(f"No valid fields to update for file: {file_id}")
|
||||
return file
|
||||
|
||||
# Update file metadata
|
||||
try:
|
||||
if self.uow:
|
||||
with self.uow:
|
||||
updated_file = self.uow.files.update_file(file_id, filtered_update_data)
|
||||
self.uow.commit()
|
||||
elif db:
|
||||
repository = GenericFileRepository(db)
|
||||
updated_file = repository.update_file(file_id, filtered_update_data)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"File metadata updated successfully: {file.file_name} (ID: {file_id})")
|
||||
return updated_file
|
||||
|
||||
except Exception as e:
|
||||
if db:
|
||||
db.rollback()
|
||||
logger.error(f"Failed to update file metadata: {str(e)}")
|
||||
raise FileStorageError(f"无法更新文件元数据: {str(e)}")
|
||||
570
api/app/services/user_service.py
Normal file
570
api/app/services/user_service.py
Normal file
@@ -0,0 +1,570 @@
|
||||
import datetime
|
||||
import secrets
|
||||
import string
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.repositories import user_repository
|
||||
from app.schemas.user_schema import UserCreate
|
||||
from app.schemas.tenant_schema import TenantCreate
|
||||
from app.services.tenant_service import TenantService
|
||||
from app.services.session_service import SessionService
|
||||
from app.core.security import get_password_hash, verify_password
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||
from app.core.error_codes import BizCode
|
||||
# from app.services import workspace_service
|
||||
# from app.schemas.workspace_schema import WorkspaceCreate
|
||||
|
||||
# 获取业务逻辑专用日志器
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def create_initial_superuser(db: Session):
|
||||
business_logger.info("检查并创建初始超级用户")
|
||||
|
||||
superuser = user_repository.get_superuser(db)
|
||||
if superuser:
|
||||
business_logger.info("超级用户已存在,跳过创建")
|
||||
return
|
||||
|
||||
user_in = UserCreate(
|
||||
username=settings.FIRST_SUPERUSER_USERNAME,
|
||||
email=settings.FIRST_SUPERUSER_EMAIL,
|
||||
password=settings.FIRST_SUPERUSER_PASSWORD,
|
||||
)
|
||||
|
||||
try:
|
||||
business_logger.debug("开始创建初始租户")
|
||||
# Create a default tenant for the superuser
|
||||
default_tenant = TenantCreate(
|
||||
name=f"{user_in.username}'s Tenant",
|
||||
description=f"Default tenant for {user_in.username}",
|
||||
)
|
||||
# Create tenant service and create tenant with user assignment
|
||||
tenant_service = TenantService(db)
|
||||
tenant = tenant_service.create_tenant(default_tenant)
|
||||
db.flush()
|
||||
business_logger.debug("开始创建初始超级用户")
|
||||
|
||||
hashed_password = get_password_hash(user_in.password)
|
||||
superuser = user_repository.create_user(
|
||||
db=db, user=user_in, hashed_password=hashed_password, is_superuser=True,
|
||||
tenant_id=tenant.id
|
||||
)
|
||||
db.commit()
|
||||
db.refresh(superuser)
|
||||
business_logger.info(f"初始超级用户创建成功: {superuser.username} (ID: {superuser.id})")
|
||||
return superuser
|
||||
except Exception as e:
|
||||
business_logger.error(f"初始超级用户创建失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(
|
||||
f"初始超级用户创建失败: {str(e)}",
|
||||
code=BizCode.DB_ERROR,
|
||||
context={"username": username, "email": email},
|
||||
cause=e
|
||||
)
|
||||
|
||||
|
||||
def create_user(db: Session, user: UserCreate) -> User:
|
||||
business_logger.info(f"创建用户: {user.username}, email: {user.email}")
|
||||
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
business_logger.debug(f"检查用户名是否已存在: {user.username}")
|
||||
db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
|
||||
if db_user_by_username:
|
||||
business_logger.warning(f"用户名已存在: {user.username}")
|
||||
raise BusinessException(
|
||||
"用户名已存在",
|
||||
code=BizCode.DUPLICATE_NAME,
|
||||
context={"username": user.username, "email": user.email}
|
||||
)
|
||||
|
||||
# 检查邮箱是否已注册
|
||||
business_logger.debug(f"检查邮箱是否已注册: {user.email}")
|
||||
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
|
||||
if db_user_by_email:
|
||||
business_logger.warning(f"邮箱已注册: {user.email}")
|
||||
raise BusinessException(
|
||||
"邮箱已注册",
|
||||
code=BizCode.DUPLICATE_NAME,
|
||||
context={"email": user.email, "username": user.username}
|
||||
)
|
||||
|
||||
# 创建普通用户,需要有默认租户
|
||||
business_logger.debug(f"开始创建用户: {user.username}")
|
||||
hashed_password = get_password_hash(user.password)
|
||||
|
||||
# 获取默认租户(第一个活跃租户)
|
||||
from app.repositories.tenant_repository import TenantRepository
|
||||
tenant_repo = TenantRepository(db)
|
||||
tenants = tenant_repo.get_tenants(skip=0, limit=1, is_active=True)
|
||||
|
||||
if not tenants:
|
||||
business_logger.error("系统中没有可用的租户")
|
||||
raise BusinessException(
|
||||
"系统配置错误:没有可用的租户",
|
||||
code=BizCode.TENANT_NOT_FOUND,
|
||||
context={"username": user.username, "email": user.email}
|
||||
)
|
||||
|
||||
default_tenant = tenants[0]
|
||||
|
||||
new_user = user_repository.create_user(
|
||||
db=db, user=user, hashed_password=hashed_password,
|
||||
tenant_id=default_tenant.id, is_superuser=False
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
business_logger.info(f"用户创建成功: {new_user.username} (ID: {new_user.id})")
|
||||
return new_user
|
||||
except Exception as e:
|
||||
business_logger.error(f"用户创建失败: {user.username} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(
|
||||
f"用户创建失败: {user.username} - {str(e)}",
|
||||
code=BizCode.DB_ERROR,
|
||||
context={"username": user.username, "email": user.email},
|
||||
cause=e
|
||||
)
|
||||
|
||||
|
||||
def create_superuser(db: Session, user: UserCreate, current_user: User) -> User:
|
||||
business_logger.info(f"创建超级管理员: {user.username}, email: {user.email}")
|
||||
|
||||
# 检查当前用户是否为超级管理员
|
||||
from app.core.permissions import permission_service, Subject
|
||||
|
||||
subject = Subject.from_user(current_user)
|
||||
try:
|
||||
permission_service.check_superuser(
|
||||
subject,
|
||||
error_message="只有超级管理员才能创建超级管理员用户"
|
||||
)
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"非超级管理员尝试创建超级管理员用户: {user.username}")
|
||||
raise BusinessException(
|
||||
str(e),
|
||||
code=BizCode.FORBIDDEN,
|
||||
context={
|
||||
"current_user_id": str(current_user.id),
|
||||
"current_user_username": current_user.username,
|
||||
"target_username": user.username
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
business_logger.debug(f"检查用户名是否已存在: {user.username}")
|
||||
db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
|
||||
if db_user_by_username:
|
||||
business_logger.warning(f"用户名已存在: {user.username}")
|
||||
raise BusinessException(
|
||||
"用户名已存在",
|
||||
code=BizCode.DUPLICATE_NAME,
|
||||
context={
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"created_by": str(current_user.id)
|
||||
}
|
||||
)
|
||||
|
||||
# 检查邮箱是否已注册
|
||||
business_logger.debug(f"检查邮箱是否已注册: {user.email}")
|
||||
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
|
||||
if db_user_by_email:
|
||||
business_logger.warning(f"邮箱已注册: {user.email}")
|
||||
raise BusinessException(
|
||||
"邮箱已注册",
|
||||
code=BizCode.DUPLICATE_NAME,
|
||||
context={
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
"created_by": str(current_user.id)
|
||||
}
|
||||
)
|
||||
|
||||
# 创建超级管理员用户并加入当前用户的租户
|
||||
business_logger.debug(f"开始创建超级管理员: {user.username}")
|
||||
hashed_password = get_password_hash(user.password)
|
||||
|
||||
new_user = user_repository.create_user(
|
||||
db=db, user=user, hashed_password=hashed_password,
|
||||
tenant_id=current_user.tenant_id, is_superuser=True
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
business_logger.info(f"超级管理员创建成功: {new_user.username} (ID: {new_user.id}), 已加入租户: {current_user.tenant_id}")
|
||||
return new_user
|
||||
except Exception as e:
|
||||
business_logger.error(f"超级管理员创建失败: {user.username} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(
|
||||
f"超级管理员创建失败: {user.username} - {str(e)}",
|
||||
code=BizCode.DB_ERROR,
|
||||
context={
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"created_by": str(current_user.id),
|
||||
"tenant_id": str(current_user.tenant_id)
|
||||
},
|
||||
cause=e
|
||||
)
|
||||
|
||||
|
||||
def deactivate_user(db: Session, user_id_to_deactivate: uuid.UUID, current_user: User) -> User:
|
||||
business_logger.info(f"停用用户: user_id={user_id_to_deactivate}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 查找用户
|
||||
business_logger.debug(f"查找待停用用户: {user_id_to_deactivate}")
|
||||
db_user = user_repository.get_user_by_id(db, user_id=user_id_to_deactivate)
|
||||
if not db_user:
|
||||
business_logger.warning(f"用户不存在: {user_id_to_deactivate}")
|
||||
raise BusinessException(
|
||||
"用户不存在",
|
||||
code=BizCode.USER_NOT_FOUND,
|
||||
context={"user_id": str(user_id_to_deactivate)}
|
||||
)
|
||||
|
||||
# 权限检查 using permission service
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
|
||||
subject = Subject.from_user(current_user)
|
||||
resource = Resource.from_user(db_user)
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
Action.DEACTIVATE,
|
||||
resource,
|
||||
error_message="没有权限停用该用户"
|
||||
)
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"权限不足: 用户 {current_user.username} 尝试停用用户 {user_id_to_deactivate}")
|
||||
raise BusinessException(
|
||||
str(e),
|
||||
code=BizCode.FORBIDDEN,
|
||||
context={
|
||||
"current_user_id": str(current_user.id),
|
||||
"current_user_username": current_user.username,
|
||||
"target_user_id": str(user_id_to_deactivate)
|
||||
}
|
||||
)
|
||||
# 检查用户类型,如果是超级管理员,判断一下不是唯一的一个
|
||||
if db_user.is_superuser:
|
||||
is_only_superuser = user_repository.check_superuser_only(db)
|
||||
if is_only_superuser:
|
||||
business_logger.warning(f"停用超级管理员用户: {db_user.username} (ID: {user_id_to_deactivate})")
|
||||
raise BusinessException(
|
||||
"不能停用唯一的超级管理员用户",
|
||||
code=BizCode.FORBIDDEN,
|
||||
context={
|
||||
"user_id": str(user_id_to_deactivate),
|
||||
"username": db_user.username
|
||||
}
|
||||
)
|
||||
|
||||
# 停用用户
|
||||
business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})")
|
||||
db_user.is_active = False
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
business_logger.info(f"用户停用成功: {db_user.username} (ID: {user_id_to_deactivate})")
|
||||
return db_user
|
||||
except Exception as e:
|
||||
business_logger.error(f"用户停用失败: user_id={user_id_to_deactivate} - {str(e)}")
|
||||
db.rollback()
|
||||
if isinstance(e, BusinessException):
|
||||
raise e
|
||||
raise BusinessException(f"{str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
def activate_user(db: Session, user_id_to_activate: uuid.UUID, current_user: User) -> User:
|
||||
business_logger.info(f"激活用户: user_id={user_id_to_activate}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 查找用户
|
||||
business_logger.debug(f"查找待激活用户: {user_id_to_activate}")
|
||||
db_user = user_repository.get_user_by_id(db, user_id=user_id_to_activate)
|
||||
if not db_user:
|
||||
business_logger.warning(f"用户不存在: {user_id_to_activate}")
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 权限检查 using permission service
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
|
||||
subject = Subject.from_user(current_user)
|
||||
resource = Resource.from_user(db_user)
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
Action.ACTIVATE,
|
||||
resource,
|
||||
error_message="没有权限激活该用户"
|
||||
)
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"权限不足: 用户 {current_user.username} 尝试激活用户 {user_id_to_activate}")
|
||||
raise BusinessException(str(e), code=BizCode.FORBIDDEN)
|
||||
|
||||
# 激活用户
|
||||
business_logger.debug(f"执行用户激活: {db_user.username} (ID: {user_id_to_activate})")
|
||||
db_user.is_active = True
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
business_logger.info(f"用户激活成功: {db_user.username} (ID: {user_id_to_activate})")
|
||||
return db_user
|
||||
except Exception as e:
|
||||
business_logger.error(f"用户激活失败: user_id={user_id_to_activate} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"用户激活失败: user_id={user_id_to_activate} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
def get_user(db: Session, user_id: uuid.UUID, current_user: User) -> User:
|
||||
business_logger.info(f"获取用户信息: user_id={user_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 查找用户
|
||||
business_logger.debug(f"查找用户: {user_id}")
|
||||
db_user = user_repository.get_user_by_id(db, user_id=user_id)
|
||||
if not db_user:
|
||||
business_logger.warning(f"用户不存在: {user_id}")
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 权限检查 using permission service
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
|
||||
subject = Subject.from_user(current_user)
|
||||
resource = Resource.from_user(db_user)
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
Action.READ,
|
||||
resource,
|
||||
error_message="没有权限获取该用户信息"
|
||||
)
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"权限不足: 用户 {current_user.username} 尝试获取用户 {user_id} 信息")
|
||||
raise BusinessException(str(e), code=BizCode.FORBIDDEN)
|
||||
|
||||
# 返回用户信息
|
||||
business_logger.debug(f"返回用户信息: {db_user.username} (ID: {user_id})")
|
||||
return db_user
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取用户信息失败: user_id={user_id} - {str(e)}")
|
||||
raise BusinessException(f"获取用户信息失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
def get_tenant_superusers(db: Session, current_user: User, include_inactive: bool = True) -> list[User]:
|
||||
"""获取当前租户下的超管账号列表"""
|
||||
business_logger.info(f"获取租户超管列表: tenant_id={current_user.tenant_id}, 请求者: {current_user.username}, include_inactive={include_inactive}")
|
||||
|
||||
try:
|
||||
# 检查当前用户是否有权限查看(只有超管才能查看超管列表)
|
||||
from app.core.permissions import permission_service, Subject
|
||||
|
||||
subject = Subject.from_user(current_user)
|
||||
try:
|
||||
permission_service.check_superuser(
|
||||
subject,
|
||||
error_message="只有超级管理员才能查看超管列表"
|
||||
)
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"非超级管理员尝试查看超管列表: {current_user.username}")
|
||||
raise BusinessException(str(e), code=BizCode.FORBIDDEN)
|
||||
|
||||
# 检查用户是否有租户
|
||||
if not current_user.tenant_id:
|
||||
business_logger.warning(f"用户没有租户信息: {current_user.username}")
|
||||
raise BusinessException("用户没有租户信息", code=BizCode.TENANT_NOT_FOUND)
|
||||
|
||||
# 获取租户下的超管列表
|
||||
business_logger.debug(f"查询租户超管: tenant_id={current_user.tenant_id}, include_inactive={include_inactive}")
|
||||
is_active_filter = None if include_inactive else True
|
||||
superusers = user_repository.get_superusers_by_tenant(
|
||||
db=db,
|
||||
tenant_id=current_user.tenant_id,
|
||||
is_active=is_active_filter
|
||||
)
|
||||
|
||||
business_logger.info(f"租户超管查询成功: tenant_id={current_user.tenant_id}, count={len(superusers)}")
|
||||
return superusers
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取租户超管列表失败: tenant_id={current_user.tenant_id} - {str(e)}")
|
||||
raise BusinessException(f"获取租户超管列表失败: tenant_id={current_user.tenant_id} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
def update_last_login_time(db: Session, user_id: uuid.UUID) -> User:
|
||||
"""更新用户的最后登录时间"""
|
||||
business_logger.info(f"更新用户最后登录时间: user_id={user_id}")
|
||||
|
||||
try:
|
||||
# 获取用户
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
business_logger.warning(f"用户不存在: {user_id}")
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 更新最后登录时间
|
||||
db_user.last_login_at = datetime.datetime.now()
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
business_logger.info(f"用户最后登录时间更新成功: {db_user.username} (ID: {user_id})")
|
||||
return db_user
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"更新用户最后登录时间失败: user_id={user_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
async def change_password(db: Session, user_id: uuid.UUID, old_password: str, new_password: str, current_user: User) -> User:
|
||||
"""普通用户修改自己的密码"""
|
||||
business_logger.info(f"用户修改密码请求: user_id={user_id}, current_user={current_user.id}")
|
||||
|
||||
# 检查权限:只能修改自己的密码
|
||||
if current_user.id != user_id:
|
||||
business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only change your own password"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取用户
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
business_logger.warning(f"用户不存在: {user_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
|
||||
# 验证旧密码
|
||||
if not verify_password(old_password, db_user.hashed_password):
|
||||
business_logger.warning(f"用户旧密码验证失败: {user_id}")
|
||||
raise BusinessException("当前密码不正确", code=BizCode.VALIDATION_FAILED)
|
||||
|
||||
# 更新密码
|
||||
db_user.hashed_password = get_password_hash(new_password)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
# 使所有旧 tokens 失效
|
||||
await SessionService.invalidate_all_user_tokens(str(user_id))
|
||||
|
||||
business_logger.info(f"用户密码修改成功: {db_user.username} (ID: {user_id})")
|
||||
return db_user
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"修改用户密码失败: user_id={user_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"修改用户密码失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_password: str = None, current_user: User = None) -> tuple[User, str]:
|
||||
"""
|
||||
超级管理员修改指定用户的密码
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
target_user_id: 目标用户ID
|
||||
new_password: 新密码,如果为None则自动生成随机密码
|
||||
current_user: 当前用户(超级管理员)
|
||||
|
||||
Returns:
|
||||
tuple[User, str]: (更新后的用户对象, 实际使用的密码)
|
||||
"""
|
||||
business_logger.info(f"管理员修改用户密码请求: admin={current_user.id}, target_user={target_user_id}")
|
||||
|
||||
# 检查权限:只有超级管理员可以修改他人密码
|
||||
from app.core.permissions import permission_service, Subject
|
||||
|
||||
subject = Subject.from_user(current_user)
|
||||
try:
|
||||
permission_service.check_superuser(
|
||||
subject,
|
||||
error_message="只有超级管理员可以修改他人密码"
|
||||
)
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"非超管用户尝试修改他人密码: current_user={current_user.id}")
|
||||
raise BusinessException(str(e), code=BizCode.FORBIDDEN)
|
||||
|
||||
try:
|
||||
# 获取目标用户
|
||||
target_user = user_repository.get_user_by_id(db=db, user_id=target_user_id)
|
||||
if not target_user:
|
||||
business_logger.warning(f"目标用户不存在: {target_user_id}")
|
||||
raise BusinessException("目标用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查租户权限:超管只能修改同租户用户的密码
|
||||
if current_user.tenant_id != target_user.tenant_id:
|
||||
business_logger.warning(f"跨租户密码修改尝试: admin_tenant={current_user.tenant_id}, target_tenant={target_user.tenant_id}")
|
||||
raise BusinessException("不可跨租户修改用户密码", code=BizCode.FORBIDDEN)
|
||||
|
||||
# 如果没有提供新密码,则生成随机密码
|
||||
actual_password = new_password if new_password else generate_random_password()
|
||||
|
||||
# 更新密码
|
||||
target_user.hashed_password = get_password_hash(actual_password)
|
||||
db.commit()
|
||||
db.refresh(target_user)
|
||||
|
||||
# 使所有旧 tokens 失效
|
||||
await SessionService.invalidate_all_user_tokens(str(target_user_id))
|
||||
|
||||
password_type = "指定密码" if new_password else "随机生成密码"
|
||||
business_logger.info(f"管理员修改用户密码成功: admin={current_user.username}, target={target_user.username} (ID: {target_user_id}), 类型={password_type}")
|
||||
return target_user, actual_password
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
|
||||
def generate_random_password(length: int = 12) -> str:
|
||||
"""
|
||||
生成随机密码
|
||||
|
||||
Args:
|
||||
length: 密码长度,默认12位
|
||||
|
||||
Returns:
|
||||
str: 生成的随机密码
|
||||
"""
|
||||
# 确保密码包含大小写字母、数字和特殊字符
|
||||
lowercase = string.ascii_lowercase
|
||||
uppercase = string.ascii_uppercase
|
||||
digits = string.digits
|
||||
special_chars = "!@#$%^&*"
|
||||
|
||||
# 确保至少包含每种字符类型
|
||||
password = [
|
||||
secrets.choice(lowercase),
|
||||
secrets.choice(uppercase),
|
||||
secrets.choice(digits),
|
||||
secrets.choice(special_chars)
|
||||
]
|
||||
|
||||
# 填充剩余长度
|
||||
all_chars = lowercase + uppercase + digits + special_chars
|
||||
for _ in range(length - 4):
|
||||
password.append(secrets.choice(all_chars))
|
||||
|
||||
# 打乱顺序
|
||||
secrets.SystemRandom().shuffle(password)
|
||||
|
||||
return ''.join(password)
|
||||
776
api/app/services/workspace_service.py
Normal file
776
api/app/services/workspace_service.py
Normal file
@@ -0,0 +1,776 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import secrets
|
||||
import hashlib
|
||||
import datetime
|
||||
from fastapi import HTTPException, status
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.workspace_model import Workspace, WorkspaceRole, WorkspaceInvite, InviteStatus, WorkspaceMember
|
||||
from app.schemas.workspace_schema import (
|
||||
WorkspaceCreate,
|
||||
WorkspaceUpdate,
|
||||
WorkspaceInviteCreate,
|
||||
WorkspaceInviteResponse,
|
||||
InviteValidateResponse,
|
||||
InviteAcceptRequest,
|
||||
WorkspaceMemberUpdate
|
||||
)
|
||||
from app.repositories import workspace_repository
|
||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.config import settings
|
||||
from app.services import user_service
|
||||
from os import getenv
|
||||
# 获取业务逻辑专用日志器
|
||||
business_logger = get_business_logger()
|
||||
import os #
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
def switch_workspace(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
):
|
||||
"""切换工作空间"""
|
||||
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
|
||||
|
||||
# 检查用户是否为成员或超级管理员
|
||||
_check_workspace_member_permission(db, workspace_id, user)
|
||||
|
||||
# 更新当前用户的工作空间上下文
|
||||
try:
|
||||
user.current_workspace_id = workspace_id
|
||||
db.commit()
|
||||
business_logger.info(f"用户 {user.username} 成功切换工作空间为 {workspace_id}")
|
||||
return
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"切换工作空间失败 - 工作空间: {workspace_id}, 错误: {str(e)}")
|
||||
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
def delete_workspace_member(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
member_id: uuid.UUID,
|
||||
user: User,
|
||||
):
|
||||
"""删除工作空间成员"""
|
||||
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
||||
if not workspace_member:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
|
||||
if workspace_member.workspace_id != workspace_id:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
|
||||
try:
|
||||
workspace_member.is_active = False
|
||||
workspace_member.user.current_workspace_id = None
|
||||
db.commit()
|
||||
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
|
||||
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
|
||||
"""获取当前用户参与的所有工作空间"""
|
||||
business_logger.debug(f"获取用户工作空间列表: {user.username} (ID: {user.id})")
|
||||
workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
|
||||
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
|
||||
return workspaces
|
||||
|
||||
|
||||
def _create_workspace_only(
|
||||
db: Session, workspace: WorkspaceCreate, owner: User
|
||||
) -> Workspace:
|
||||
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
|
||||
|
||||
try:
|
||||
# Create the workspace without adding any members
|
||||
business_logger.debug(f"创建工作空间: {workspace.name}")
|
||||
db_workspace = workspace_repository.create_workspace(
|
||||
db=db, workspace=workspace, tenant_id=owner.tenant_id
|
||||
)
|
||||
business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {owner.username}")
|
||||
return db_workspace
|
||||
except Exception as e:
|
||||
business_logger.error(f"创建工作空间失败: {workspace.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
def create_workspace(
|
||||
db: Session, workspace: WorkspaceCreate, user: User
|
||||
) -> Workspace:
|
||||
business_logger.info(
|
||||
f"创建工作空间: {workspace.name}, 创建者: {user.username}, "
|
||||
f"storage_type: {workspace.storage_type}"
|
||||
)
|
||||
llm=workspace.llm
|
||||
embedding=workspace.embedding
|
||||
rerank=workspace.rerank
|
||||
try:
|
||||
# Create the workspace without adding any members
|
||||
business_logger.debug(f"创建工作空间: {workspace.name}")
|
||||
db_workspace = workspace_repository.create_workspace(
|
||||
db=db, workspace=workspace, tenant_id=user.tenant_id
|
||||
)
|
||||
business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}")
|
||||
db.commit()
|
||||
db.refresh(db_workspace)
|
||||
|
||||
# 如果 storage_type 是 "rag",自动创建知识库
|
||||
if workspace.storage_type == "rag":
|
||||
business_logger.info(
|
||||
f"检测到 storage_type 为 'rag',开始为工作空间 "
|
||||
f"{db_workspace.id} 创建知识库"
|
||||
)
|
||||
try:
|
||||
import os
|
||||
from app.schemas.knowledge_schema import KnowledgeCreate
|
||||
from app.models.knowledge_model import KnowledgeType, PermissionType
|
||||
from app.repositories import knowledge_repository
|
||||
|
||||
# 创建知识库数据
|
||||
knowledge_data = KnowledgeCreate(
|
||||
workspace_id=db_workspace.id,
|
||||
created_by=user.id,
|
||||
parent_id=db_workspace.id,
|
||||
name="USER_RAG_MERORY",
|
||||
description=f"工作空间 {workspace.name} 的默认知识库",
|
||||
avatar='',
|
||||
type=KnowledgeType.General,
|
||||
permission_id=PermissionType.Private,
|
||||
embedding_id=uuid.UUID(getenv('KB_embedding_id')) if None else embedding,
|
||||
reranker_id=uuid.UUID(getenv('KB_reranker_id')) if None else rerank,
|
||||
llm_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
|
||||
image2text_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
|
||||
parser_config={
|
||||
"layout_recognize": "DeepDOC",
|
||||
"chunk_token_num": 256,
|
||||
"delimiter": "\n",
|
||||
"auto_keywords": 0,
|
||||
"auto_questions": 0,
|
||||
"html4excel": False
|
||||
}
|
||||
)
|
||||
|
||||
# 直接使用 repository 创建知识库,避免 service 层的额外逻辑
|
||||
db_knowledge = knowledge_repository.create_knowledge(
|
||||
db=db,
|
||||
knowledge=knowledge_data
|
||||
)
|
||||
db.commit()
|
||||
business_logger.info(
|
||||
f"为工作空间 {db_workspace.id} 自动创建知识库成功: "
|
||||
f"{db_knowledge.name} (ID: {db_knowledge.id})"
|
||||
)
|
||||
except Exception as kb_error:
|
||||
business_logger.error(
|
||||
f"为工作空间 {db_workspace.id} 创建知识库失败: {str(kb_error)}"
|
||||
)
|
||||
db.rollback()
|
||||
raise BusinessException(
|
||||
f"工作空间创建成功,但知识库创建失败: {str(kb_error)}",
|
||||
BizCode.INTERNAL_ERROR
|
||||
)
|
||||
|
||||
return db_workspace
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"工作空间创建失败: {workspace.name} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def update_workspace(
|
||||
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
|
||||
) -> Workspace:
|
||||
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
|
||||
db_workspace = _check_workspace_admin_permission(db,workspace_id,user)
|
||||
try:
|
||||
# 更新工作空间
|
||||
business_logger.debug(f"执行工作空间更新: {db_workspace.name} (ID: {workspace_id})")
|
||||
update_data = workspace_in.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_workspace, field, value)
|
||||
|
||||
db.add(db_workspace)
|
||||
db.commit()
|
||||
db.refresh(db_workspace)
|
||||
business_logger.info(f"工作空间更新成功: {db_workspace.name} (ID: {workspace_id})")
|
||||
return db_workspace
|
||||
except Exception as e:
|
||||
business_logger.error(f"工作空间更新失败: workspace_id={workspace_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_members(
|
||||
db: Session, workspace_id: uuid.UUID, user: User
|
||||
) -> List[WorkspaceMember]:
|
||||
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
|
||||
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
|
||||
# 查找工作空间
|
||||
business_logger.debug(f"查找工作空间: {workspace_id}")
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||
if not workspace:
|
||||
business_logger.warning(f"工作空间不存在: {workspace_id}")
|
||||
raise BusinessException(
|
||||
message="Workspace not found",
|
||||
code=BizCode.WORKSPACE_NOT_FOUND
|
||||
)
|
||||
|
||||
# 权限检查:工作空间成员或超级管理员可以查看成员列表
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
member = workspace_repository.get_member_in_workspace(
|
||||
db=db, user_id=user.id, workspace_id=workspace_id
|
||||
)
|
||||
workspace_memberships = {workspace_id} if member else set()
|
||||
|
||||
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
||||
resource = Resource.from_workspace(workspace)
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
Action.READ,
|
||||
resource,
|
||||
error_message=f"用户 {user.username} 没有查看工作空间 {workspace_id} 成员列表的权限"
|
||||
)
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(
|
||||
f"权限不足: 用户 {user.username} 尝试获取工作空间 {workspace_id} 成员列表"
|
||||
)
|
||||
raise BusinessException(str(e), BizCode.WORKSPACE_ACCESS_DENIED)
|
||||
|
||||
# 查询成员并预加载 user/workspace 关系
|
||||
members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
|
||||
business_logger.info(f"工作空间成员数量: {len(members)} - workspace_id={workspace_id}")
|
||||
return members
|
||||
|
||||
|
||||
|
||||
# ==================== 邀请相关服务方法 ====================
|
||||
|
||||
def _generate_invite_token() -> tuple[str, str]:
|
||||
"""生成邀请令牌和其哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (原始令牌, 令牌哈希)
|
||||
"""
|
||||
# 生成32字节的随机令牌
|
||||
token = secrets.token_urlsafe(32)
|
||||
# 生成令牌的SHA256哈希
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
return token, token_hash
|
||||
|
||||
|
||||
def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, user: User) -> Workspace | None:
|
||||
"""检查用户是否为工作空间成员或超级管理员(使用统一权限服务)"""
|
||||
# 获取工作空间信息
|
||||
db_workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||
if not db_workspace:
|
||||
raise BusinessException(
|
||||
message="Workspace not found",
|
||||
code=BizCode.WORKSPACE_NOT_FOUND
|
||||
)
|
||||
|
||||
# 使用统一权限服务检查访问权限
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
|
||||
# 获取用户的工作空间成员关系
|
||||
member = workspace_repository.get_member_in_workspace(
|
||||
db=db, user_id=user.id, workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 任何成员都有访问权限
|
||||
workspace_memberships = {workspace_id} if member else set()
|
||||
|
||||
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
||||
resource = Resource.from_workspace(db_workspace)
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
Action.READ,
|
||||
resource,
|
||||
error_message=f"用户 {user.username} 不是工作空间 {workspace_id} 的成员"
|
||||
)
|
||||
business_logger.debug(f"用户 {user.username} 是工作空间 {workspace_id} 的成员或超级管理员")
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"权限不足: 用户 {user.username} 尝试访问工作空间 {workspace_id}")
|
||||
raise BusinessException(str(e), BizCode.WORKSPACE_NO_ACCESS)
|
||||
return db_workspace
|
||||
|
||||
|
||||
def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user: User) -> Workspace | None:
|
||||
"""检查用户是否有工作空间管理员权限(使用统一权限服务)"""
|
||||
# 获取工作空间信息
|
||||
db_workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||
if not db_workspace:
|
||||
raise BusinessException(
|
||||
message="Workspace not found",
|
||||
code=BizCode.WORKSPACE_NOT_FOUND
|
||||
)
|
||||
|
||||
# 使用统一权限服务检查管理权限
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
|
||||
# 获取用户的工作空间成员关系
|
||||
member = workspace_repository.get_member_in_workspace(
|
||||
db=db, user_id=user.id, workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 只有 manager 才有管理权限
|
||||
workspace_memberships = {workspace_id} if (member and member.role == WorkspaceRole.manager) else set()
|
||||
|
||||
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
||||
resource = Resource.from_workspace(db_workspace)
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
Action.MANAGE,
|
||||
resource,
|
||||
error_message=f"用户 {user.username} 没有管理工作空间 {workspace_id} 的权限"
|
||||
)
|
||||
business_logger.debug(f"用户 {user.username} 有权限管理工作空间 {workspace_id}")
|
||||
except PermissionDeniedException as e:
|
||||
business_logger.warning(f"权限不足: 用户 {user.username} 尝试管理工作空间 {workspace_id}")
|
||||
raise BusinessException(str(e), BizCode.WORKSPACE_ACCESS_DENIED)
|
||||
return db_workspace
|
||||
|
||||
|
||||
def create_workspace_invite(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
user: User
|
||||
) -> WorkspaceInviteResponse:
|
||||
"""创建工作空间邀请"""
|
||||
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
|
||||
|
||||
try:
|
||||
# 检查权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
if settings.ENABLE_SINGLE_WORKSPACE:
|
||||
# 检查被邀请用户是否已经在工作空间中
|
||||
from app.repositories import user_repository
|
||||
invited_user = user_repository.get_user_by_email(db, invite_data.email)
|
||||
|
||||
if invited_user:
|
||||
# 用户存在,检查是否已经是工作空间成员
|
||||
existing_member = workspace_repository.get_member_in_workspace(
|
||||
db=db,
|
||||
user_id=invited_user.id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if existing_member:
|
||||
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
|
||||
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
|
||||
|
||||
# 检查是否已有待处理的邀请
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
existing_invite = invite_repo.get_pending_invite_by_email_and_workspace(
|
||||
email=invite_data.email,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
invite_token = None
|
||||
if existing_invite:
|
||||
business_logger.info(f"邮箱 {invite_data.email} 在工作空间 {workspace_id} 已有待处理邀请,返回现有邀请")
|
||||
# 生成新的邀请链接(重新生成令牌)
|
||||
token, token_hash = _generate_invite_token()
|
||||
existing_invite.token_hash = token_hash
|
||||
existing_invite.updated_at = datetime.datetime.now()
|
||||
db.commit()
|
||||
db.refresh(existing_invite)
|
||||
invite_token = token
|
||||
else:
|
||||
# 生成邀请令牌
|
||||
token, token_hash = _generate_invite_token()
|
||||
# 创建邀请
|
||||
db_invite = invite_repo.create_invite(
|
||||
workspace_id=workspace_id,
|
||||
invite_data=invite_data,
|
||||
token_hash=token_hash,
|
||||
created_by_user_id=user.id
|
||||
)
|
||||
db.commit()
|
||||
db.refresh(db_invite)
|
||||
invite_token = token
|
||||
|
||||
invite_obj = existing_invite or db_invite
|
||||
business_logger.info(f"工作空间邀请创建成功: invite_id={invite_obj.id}, email={invite_data.email}")
|
||||
|
||||
# 构造响应
|
||||
response = WorkspaceInviteResponse.model_validate(invite_obj)
|
||||
response.invite_token = invite_token
|
||||
return response
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_invites(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
status: Optional[InviteStatus] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> List[WorkspaceInviteResponse]:
|
||||
"""获取工作空间邀请列表"""
|
||||
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
|
||||
|
||||
# 检查工作空间是否存在
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||
if not workspace:
|
||||
raise BusinessException("工作空间不存在", BizCode.WORKSPACE_NOT_FOUND)
|
||||
|
||||
# 检查权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
|
||||
# 获取邀请列表
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
invites = invite_repo.get_workspace_invites(
|
||||
workspace_id=workspace_id,
|
||||
status=status,
|
||||
limit=limit,
|
||||
offset=offset
|
||||
)
|
||||
|
||||
return [WorkspaceInviteResponse.model_validate(invite) for invite in invites]
|
||||
|
||||
|
||||
def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
||||
"""验证邀请令牌"""
|
||||
business_logger.info(f"验证邀请令牌")
|
||||
|
||||
# 生成令牌哈希
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
# 查找邀请
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
invite = invite_repo.get_invite_by_token_hash(token_hash)
|
||||
|
||||
if not invite:
|
||||
business_logger.warning(f"邀请令牌无效")
|
||||
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||
|
||||
# 检查邀请状态和过期时间
|
||||
now = datetime.datetime.now()
|
||||
is_expired = invite.expires_at < now or invite.status != InviteStatus.pending
|
||||
is_valid = not is_expired
|
||||
|
||||
# 获取工作空间信息
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
||||
|
||||
business_logger.info(f"邀请令牌验证完成: valid={is_valid}, expired={is_expired}")
|
||||
|
||||
return InviteValidateResponse(
|
||||
workspace_name=workspace.name,
|
||||
workspace_id=invite.workspace_id,
|
||||
email=invite.email,
|
||||
role=WorkspaceRole(invite.role),
|
||||
is_expired=is_expired,
|
||||
is_valid=is_valid
|
||||
)
|
||||
|
||||
|
||||
def accept_workspace_invite(
|
||||
db: Session,
|
||||
accept_request: InviteAcceptRequest,
|
||||
user: User
|
||||
) -> dict:
|
||||
"""接受工作空间邀请"""
|
||||
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
|
||||
|
||||
try:
|
||||
from app.core.config import settings
|
||||
|
||||
# 生成令牌哈希
|
||||
token_hash = hashlib.sha256(accept_request.token.encode()).hexdigest()
|
||||
|
||||
# 查找邀请
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
invite = invite_repo.get_invite_by_token_hash(token_hash)
|
||||
|
||||
if not invite:
|
||||
business_logger.warning(f"邀请令牌无效")
|
||||
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||
|
||||
# 检查邀请状态
|
||||
if invite.status != InviteStatus.pending:
|
||||
business_logger.warning(f"邀请已被处理: status={invite.status}")
|
||||
raise BusinessException(f"邀请已被{invite.status}", BizCode.WORKSPACE_INVITE_INVALID)
|
||||
|
||||
# 检查过期时间
|
||||
now = datetime.datetime.now()
|
||||
if invite.expires_at < now:
|
||||
business_logger.warning(f"邀请已过期")
|
||||
# 标记为过期
|
||||
invite_repo.update_invite_status(invite.id, InviteStatus.expired)
|
||||
raise BusinessException("邀请已过期", BizCode.WORKSPACE_INVITE_EXPIRED)
|
||||
|
||||
# 检查邮箱是否匹配
|
||||
if invite.email != user.email:
|
||||
business_logger.warning(f"邮箱不匹配: invite_email={invite.email}, user_email={user.email}")
|
||||
raise BusinessException("邮箱与邀请邮箱不匹配", BizCode.FORBIDDEN)
|
||||
|
||||
# 如果启用单工作空间模式,检查用户是否已有工作空间
|
||||
if settings.ENABLE_SINGLE_WORKSPACE:
|
||||
user_workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
|
||||
if user_workspaces:
|
||||
business_logger.warning(f"单工作空间模式下用户已有工作空间: user={user.username}")
|
||||
raise BusinessException("用户只能加入一个工作空间", BizCode.FORBIDDEN)
|
||||
|
||||
# 检查用户是否已经是工作空间成员
|
||||
existing_member = workspace_repository.get_member_in_workspace(
|
||||
db=db,
|
||||
user_id=user.id,
|
||||
workspace_id=invite.workspace_id
|
||||
)
|
||||
|
||||
if existing_member:
|
||||
business_logger.info(f"用户已是工作空间成员,更新邀请状态")
|
||||
invite_repo.update_invite_status(
|
||||
invite.id,
|
||||
InviteStatus.accepted,
|
||||
accepted_at=now
|
||||
)
|
||||
db.commit()
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
||||
return {
|
||||
"message": "You are already a member of this workspace",
|
||||
"workspace": workspace
|
||||
}
|
||||
|
||||
# 将角色映射到工作空间角色(现在直接使用相同的角色)
|
||||
workspace_role = invite.role
|
||||
|
||||
# 添加用户到工作空间
|
||||
workspace_repository.add_member_to_workspace(
|
||||
db=db,
|
||||
user_id=user.id,
|
||||
workspace_id=invite.workspace_id,
|
||||
role=workspace_role
|
||||
)
|
||||
|
||||
# 标记邀请为已接受
|
||||
invite_repo.update_invite_status(
|
||||
invite.id,
|
||||
InviteStatus.accepted,
|
||||
accepted_at=now
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
# 获取工作空间信息
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
|
||||
|
||||
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
|
||||
|
||||
return {
|
||||
"message": "Successfully joined the workspace",
|
||||
"workspace": workspace,
|
||||
"role": workspace_role
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"接受工作空间邀请失败: user={user.username} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def revoke_workspace_invite(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
invite_id: uuid.UUID,
|
||||
user: User
|
||||
) -> dict:
|
||||
"""撤销工作空间邀请"""
|
||||
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
|
||||
|
||||
try:
|
||||
# 检查权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
|
||||
# 撤销邀请
|
||||
invite_repo = WorkspaceInviteRepository(db)
|
||||
invite = invite_repo.revoke_invite(invite_id)
|
||||
|
||||
if not invite:
|
||||
business_logger.warning(f"邀请不存在: invite_id={invite_id}")
|
||||
raise BusinessException("邀请不存在", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||
|
||||
if invite.workspace_id != workspace_id:
|
||||
business_logger.warning(f"邀请不属于指定工作空间: invite_id={invite_id}, workspace_id={workspace_id}")
|
||||
raise BusinessException("邀请不属于指定工作空间", BizCode.BAD_REQUEST)
|
||||
|
||||
db.commit()
|
||||
business_logger.info(f"工作空间邀请撤销成功: invite_id={invite_id}")
|
||||
return {"message": "邀请撤销成功"}
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"撤销工作空间邀请失败: invite_id={invite_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def update_workspace_member_roles(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
user: User,
|
||||
) -> List[WorkspaceMember]:
|
||||
"""更新工作空间成员角色"""
|
||||
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
|
||||
|
||||
# 检查管理员权限
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
|
||||
# 获取所有当前成员
|
||||
all_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
|
||||
member_map = {m.id: m for m in all_members}
|
||||
|
||||
# 验证和业务规则检查
|
||||
update_ids = set()
|
||||
for upd in updates:
|
||||
# 检查成员是否存在
|
||||
if upd.id not in member_map:
|
||||
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
|
||||
member = member_map[upd.id]
|
||||
|
||||
# 检查成员是否属于该工作空间
|
||||
if member.workspace_id != workspace_id:
|
||||
raise BusinessException(f"成员 {upd.id} 不属于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
|
||||
# 不能修改自己的角色
|
||||
if member.user_id == user.id:
|
||||
raise BusinessException("不能修改自己的角色", BizCode.BAD_REQUEST)
|
||||
|
||||
update_ids.add(upd.id)
|
||||
|
||||
# 检查是否至少保留一个 manager
|
||||
current_managers = [m for m in all_members if m.role == WorkspaceRole.manager]
|
||||
managers_after_update = [
|
||||
m for m in all_members
|
||||
if m.id not in update_ids and m.role == WorkspaceRole.manager
|
||||
]
|
||||
|
||||
# 添加更新后会成为 manager 的成员
|
||||
for upd in updates:
|
||||
if upd.role == WorkspaceRole.manager:
|
||||
managers_after_update.append(member_map[upd.id])
|
||||
|
||||
if len(managers_after_update) == 0:
|
||||
raise BusinessException("工作空间至少需要一个管理员", BizCode.BAD_REQUEST)
|
||||
|
||||
# 执行更新
|
||||
try:
|
||||
for upd in updates:
|
||||
workspace_repository.update_member_role_by_id(
|
||||
db=db,
|
||||
id=upd.id,
|
||||
role=upd.role,
|
||||
)
|
||||
business_logger.debug(f"更新成员 {upd.id} 角色为 {upd.role}")
|
||||
|
||||
db.commit()
|
||||
|
||||
# 重新获取更新后的成员列表
|
||||
updated_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
|
||||
business_logger.info(f"成员角色更新完成: workspace_id={workspace_id}, 更新数量={len(updates)}")
|
||||
|
||||
return updated_members
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
business_logger.error(f"更新工作空间成员角色失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise BusinessException(f"更新成员角色失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
def get_workspace_storage_type(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
) -> Optional[str]:
|
||||
"""获取工作空间的存储类型
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID
|
||||
user: 当前用户
|
||||
|
||||
Returns:
|
||||
storage_type: 存储类型字符串,如果未设置则返回 None
|
||||
"""
|
||||
business_logger.info(f"用户 {user.username} 请求获取工作空间 {workspace_id} 的存储类型")
|
||||
|
||||
# 检查用户是否有权限访问该工作空间
|
||||
_check_workspace_member_permission(db, workspace_id, user)
|
||||
|
||||
# 查询工作空间
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||
if not workspace:
|
||||
business_logger.error(f"工作空间不存在: workspace_id={workspace_id}")
|
||||
raise BusinessException(
|
||||
code=BizCode.WORKSPACE_NOT_FOUND,
|
||||
message="工作空间不存在"
|
||||
)
|
||||
|
||||
business_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {workspace.storage_type}")
|
||||
return workspace.storage_type
|
||||
|
||||
|
||||
def get_workspace_models_configs(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
user: User,
|
||||
) -> Optional[dict]:
|
||||
"""获取工作空间的模型配置(llm, embedding, rerank)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID
|
||||
user: 当前用户
|
||||
|
||||
Returns:
|
||||
dict: 包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None
|
||||
"""
|
||||
business_logger.info(f"用户 {user.username} 请求获取工作空间 {workspace_id} 的模型配置")
|
||||
|
||||
# 检查用户是否有权限访问该工作空间
|
||||
_check_workspace_member_permission(db, workspace_id, user)
|
||||
|
||||
# 查询工作空间模型配置
|
||||
configs = workspace_repository.get_workspace_models_configs(db=db, workspace_id=workspace_id)
|
||||
|
||||
if configs is None:
|
||||
business_logger.error(f"工作空间不存在: workspace_id={workspace_id}")
|
||||
raise BusinessException(
|
||||
code=BizCode.WORKSPACE_NOT_FOUND,
|
||||
message="工作空间不存在"
|
||||
)
|
||||
|
||||
business_logger.info(
|
||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||
)
|
||||
return configs
|
||||
Reference in New Issue
Block a user