feat: Add base project structure with API and web components

This commit is contained in:
Ke Sun
2025-12-02 20:28:01 +08:00
parent f3de6d6cc9
commit c1adc62ec6
817 changed files with 111226 additions and 106 deletions

View File

View File

@@ -0,0 +1,116 @@
"""
Agent 配置格式转换器
用于将 Pydantic 模型转换为数据库存储格式
"""
from typing import Dict, Any, Optional
from app.schemas.app_schema import (
KnowledgeRetrievalConfig,
MemoryConfig,
VariableDefinition,
ToolConfig,
AgentConfigCreate,
AgentConfigUpdate,
)
class AgentConfigConverter:
"""Agent 配置格式转换器"""
@staticmethod
def to_storage_format(config: AgentConfigCreate | AgentConfigUpdate) -> Dict[str, Any]:
"""
将配置对象转换为数据库存储格式
Args:
config: AgentConfigCreate 或 AgentConfigUpdate 对象
Returns:
包含数据库字段的字典
"""
result = {}
# 1. 模型参数配置
if hasattr(config, 'model_parameters') and config.model_parameters:
result["model_parameters"] = config.model_parameters.model_dump()
# 2. 知识库检索配置
if config.knowledge_retrieval:
result["knowledge_retrieval"] = config.knowledge_retrieval.model_dump()
# 3. 记忆配置
if hasattr(config, 'memory') and config.memory:
result["memory"] = config.memory.model_dump()
# 4. 变量配置
if hasattr(config, 'variables') and config.variables:
result["variables"] = [var.model_dump() for var in config.variables]
# 5. 工具配置
if hasattr(config, 'tools') and config.tools:
result["tools"] = {
name: tool.model_dump()
for name, tool in config.tools.items()
}
return result
@staticmethod
def from_storage_format(
model_parameters: Optional[Dict[str, Any]],
knowledge_retrieval: Optional[Dict[str, Any]],
memory: Optional[Dict[str, Any]],
variables: Optional[list],
tools: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
"""
将数据库存储格式转换为 Pydantic 对象
Args:
model_parameters: 模型参数配置
knowledge_retrieval: 知识库检索配置
memory: 记忆配置
variables: 变量配置
tools: 工具配置
Returns:
包含 Pydantic 对象的字典
"""
result = {
"model_parameters": None,
"knowledge_retrieval": None,
"memory": MemoryConfig(enabled=True),
"variables": [],
"tools": {},
}
# 1. 解析模型参数配置
if model_parameters:
from app.schemas.app_schema import ModelParameters
result["model_parameters"] = ModelParameters(**model_parameters)
# 2. 解析知识库检索配置
if knowledge_retrieval:
result["knowledge_retrieval"] = KnowledgeRetrievalConfig(**knowledge_retrieval)
else:
# 提供默认的知识库配置(空列表)
result["knowledge_retrieval"] = KnowledgeRetrievalConfig(
knowledge_bases=[],
merge_strategy="weighted"
)
# 3. 解析记忆配置
if memory:
result["memory"] = MemoryConfig(**memory)
# 4. 解析变量配置
if variables:
result["variables"] = [VariableDefinition(**var) for var in variables]
# 5. 解析工具配置
if tools:
result["tools"] = {
name: ToolConfig(**tool_data)
for name, tool_data in tools.items()
}
return result

View File

@@ -0,0 +1,38 @@
"""
Agent 配置辅助函数
用于增强 AgentConfig 对象,添加解析后的字段
"""
from app.models import AgentConfig
from app.services.agent_config_converter import AgentConfigConverter
def enrich_agent_config(agent_cfg: AgentConfig) -> AgentConfig:
"""
增强 AgentConfig 对象,添加解析后的配置字段
Args:
agent_cfg: AgentConfig ORM 对象
Returns:
增强后的 AgentConfig 对象(添加了解析字段)
"""
if not agent_cfg:
return agent_cfg
# 解析数据库存储格式
parsed = AgentConfigConverter.from_storage_format(
model_parameters=agent_cfg.model_parameters,
knowledge_retrieval=agent_cfg.knowledge_retrieval,
memory=agent_cfg.memory,
variables=agent_cfg.variables,
tools=agent_cfg.tools,
)
# 将解析后的字段添加到对象上(用于序列化)
agent_cfg.model_parameters = parsed["model_parameters"]
agent_cfg.knowledge_retrieval = parsed["knowledge_retrieval"]
agent_cfg.memory = parsed["memory"]
agent_cfg.variables = parsed["variables"]
agent_cfg.tools = parsed["tools"]
return agent_cfg

View File

@@ -0,0 +1,191 @@
"""Agent 注册表服务"""
import uuid
from typing import Optional, List, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import select, or_, and_
from app.models import AgentConfig, App
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class AgentRegistry:
"""Agent 注册表 - 管理所有可用的 Agent"""
def __init__(self, db: Session):
self.db = db
self._cache: Dict[str, Dict[str, Any]] = {}
def register_agent(self, agent: AgentConfig) -> None:
"""注册 Agent 到系统
Args:
agent: Agent 配置对象
"""
agent_info = self._to_agent_info(agent)
self._cache[str(agent.id)] = agent_info
logger.info(
f"Agent 注册成功",
extra={
"agent_id": str(agent.id),
"name": agent.app.name,
"domain": agent.agent_domain
}
)
def discover_agents(
self,
query: Optional[str] = None,
domain: Optional[str] = None,
capabilities: Optional[List[str]] = None,
workspace_id: Optional[uuid.UUID] = None
) -> List[Dict[str, Any]]:
"""发现可用的 Agent
Args:
query: 搜索关键词
domain: 专业领域
capabilities: 所需能力列表
workspace_id: 工作空间ID权限过滤
Returns:
匹配的 Agent 列表
"""
# 构建查询
stmt = select(AgentConfig).join(App).where(
AgentConfig.is_active == True,
App.is_active == True
)
# 工作空间过滤(同工作空间或公开)
if workspace_id:
stmt = stmt.where(
or_(
App.workspace_id == workspace_id,
App.visibility == "public"
)
)
# 领域过滤
if domain:
stmt = stmt.where(AgentConfig.agent_domain == domain)
# 能力过滤
if capabilities:
# PostgreSQL JSON 数组包含查询
for cap in capabilities:
stmt = stmt.where(
AgentConfig.capabilities.contains([cap])
)
# 关键词搜索
if query:
stmt = stmt.where(
or_(
App.name.ilike(f"%{query}%"),
App.description.ilike(f"%{query}%")
)
)
agents = self.db.scalars(stmt).all()
logger.debug(
f"Agent 发现",
extra={
"query": query,
"domain": domain,
"capabilities": capabilities,
"found_count": len(agents)
}
)
return [self._to_agent_info(agent) for agent in agents]
def get_agent(self, agent_id: uuid.UUID) -> Optional[Dict[str, Any]]:
"""获取 Agent 信息
Args:
agent_id: Agent ID
Returns:
Agent 信息字典,如果不存在返回 None
"""
agent_id_str = str(agent_id)
# 先查缓存
if agent_id_str in self._cache:
return self._cache[agent_id_str]
# 查数据库
agent = self.db.get(AgentConfig, agent_id)
if agent and agent.is_active:
agent_info = self._to_agent_info(agent)
self._cache[agent_id_str] = agent_info
return agent_info
return None
def _to_agent_info(self, agent: AgentConfig) -> Dict[str, Any]:
"""转换为 Agent 信息字典
Args:
agent: Agent 配置对象
Returns:
Agent 信息字典
"""
return {
"id": str(agent.id),
"name": agent.app.name,
"description": agent.app.description,
"domain": agent.agent_domain,
"role": agent.agent_role,
"capabilities": agent.capabilities or [],
"tools": list(agent.tools.keys()) if agent.tools else [],
"knowledge_bases": self._extract_kb_ids(agent),
"system_prompt": self._truncate_prompt(agent.system_prompt),
"status": "active" if agent.is_active else "inactive",
"workspace_id": str(agent.app.workspace_id),
"visibility": agent.app.visibility
}
def _extract_kb_ids(self, agent: AgentConfig) -> List[str]:
"""提取知识库 ID 列表
Args:
agent: Agent 配置对象
Returns:
知识库 ID 列表
"""
if not agent.knowledge_retrieval:
return []
kb_config = agent.knowledge_retrieval
knowledge_bases = kb_config.get("knowledge_bases", [])
return [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
def _truncate_prompt(self, prompt: Optional[str], max_length: int = 200) -> Optional[str]:
"""截断提示词
Args:
prompt: 提示词
max_length: 最大长度
Returns:
截断后的提示词
"""
if not prompt:
return None
if len(prompt) <= max_length:
return prompt
return prompt[:max_length] + "..."
def clear_cache(self) -> None:
"""清空缓存"""
self._cache.clear()
logger.debug("Agent 注册表缓存已清空")

View File

@@ -0,0 +1,130 @@
from typing import Any, List
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import InMemorySaver
from pydantic import BaseModel
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model
from langchain.tools import tool
from langchain_core.messages import RemoveMessage
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from app.services.api_resquests_server import send_message, model, retrieval
class config(BaseModel):
template_str:str
params:dict
model_configs: List[dict] = []
history_memory:bool
knowledge_base:bool
class RemoryInput(BaseModel):
question: str
end_user_id: str
search_switch:str
class ChatRequest(BaseModel):
end_user_id: str
message: str
search_switch:str
kb_ids: List[str] = []
similarity_threshold:float
vector_similarity_weight:float
top_k:int
hybrid:bool
token:str
class RetrievalInput(BaseModel):
message: str
kb_ids: List[str] = []
similarity_threshold: float
vector_similarity_weight: float
top_k: int
hybrid: bool
token: str
async def tool_Retrieval(req):
tool_result = retrieval_search.invoke({
"message":req.message, "kb_ids":req.kb_ids,
"similarity_threshold":req.similarity_threshold, "vector_similarity_weight":req.vector_similarity_weight,
"top_k":req.top_k, "hybrid":req.hybrid, "token":req.token
})
return tool_result
async def tool_memory(req):
tool_result = remory_sk.invoke({
"question": req.message,
"end_user_id": req.end_user_id,
"search_switch": req.search_switch
})
return tool_result
@before_model
# ========== 消息剪枝中间件 ==========
def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
"""保留前1条 + 最近3~4条消息"""
messages = state["messages"]
if len(messages) <= 10:
return None
first_msg = messages[0]
recent_messages = messages[-10:] if len(messages) % 2 == 0 else messages[-11:]
new_messages = [first_msg] + recent_messages
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages
]
}
##-----------历史记忆------------
@ tool(args_schema=RemoryInput)
def remory_sk(question: str, end_user_id: str, search_switch: str):
"""
条件调用工具:
- 仅当 question 是疑问句时调用 send_message
- 根据 end_user_id 和 search_switch 参数决定是否执行检索
Args:
question: 用户的提问内容
end_user_id: 用户唯一标识符
search_switch: 搜索开关,控制是否执行检索
Returns:
检索结果或空字符串
"""
# 移除关于 configurable 的描述,避免混淆
if not end_user_id or end_user_id == '123':
print("警告: 无效的 user_id 参数")
return ''
if search_switch in ['on', 'off'] or not search_switch:
print("警告: 无效的 search_switch 参数")
return ''
return send_message(end_user_id, question, '[]', search_switch)
#-------------检索------------
@ tool(args_schema=RetrievalInput)
def retrieval_search(message,kb_ids,similarity_threshold,vector_similarity_weight,top_k,hybrid,token):
'''检索'''
search=retrieval(message,kb_ids,similarity_threshold,vector_similarity_weight,top_k,hybrid,token)
return search
async def create_dynamic_agent(model_name: str,model_id:str,PROMPT:str,token:str):
"""根据模型名动态创建代理"""
model_name, api_key, api_base=await model(model_id,token)
llm = ChatOpenAI(model=model_name, base_url=api_base, temperature=0.2,api_key=api_key)
memory = InMemorySaver()
return create_agent(
llm,
tools=[remory_sk,retrieval_search],
middleware=[trim_messages],
checkpointer=memory,
system_prompt=PROMPT
)

View File

@@ -0,0 +1,331 @@
"""Agent 发现和调用工具"""
import uuid
import time
import datetime
from typing import Optional, Dict, Any, List
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.models import AgentConfig, ModelConfig, AgentInvocation
from app.services.agent_registry import AgentRegistry
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.repositories import workspace_repository, knowledge_repository
logger = get_business_logger()
# ==================== Agent 发现工具 ====================
class AgentDiscoveryInput(BaseModel):
"""Agent 发现工具输入参数"""
query: Optional[str] = Field(None, description="搜索关键词,如:'客服''技术支持'")
domain: Optional[str] = Field(None, description="专业领域,如:'customer_service''technical_support'")
capabilities: Optional[List[str]] = Field(None, description="所需能力列表,如:['退货处理', '订单查询']")
def create_agent_discovery_tool(registry: AgentRegistry, workspace_id: uuid.UUID):
"""创建 Agent 发现工具
Args:
registry: Agent 注册表
workspace_id: 当前工作空间 ID
Returns:
Agent 发现工具
"""
@tool(args_schema=AgentDiscoveryInput)
def discover_agents(
query: Optional[str] = None,
domain: Optional[str] = None,
capabilities: Optional[List[str]] = None
) -> str:
"""发现系统中可用的 Agent。当需要找到能够处理特定任务的 Agent 时使用此工具。
Args:
query: 搜索关键词(如:"客服""技术支持"
domain: 专业领域(如:"customer_service""technical_support"
capabilities: 所需能力(如:["退货处理", "订单查询"]
Returns:
可用 Agent 的列表和描述
"""
try:
agents = registry.discover_agents(
query=query,
domain=domain,
capabilities=capabilities,
workspace_id=workspace_id
)
if not agents:
return "未找到匹配的 Agent"
# 格式化输出
result = f"找到 {len(agents)} 个可用的 Agent\n\n"
for i, agent in enumerate(agents, 1):
result += f"{i}. {agent['name']}\n"
result += f" ID: {agent['id']}\n"
if agent['description']:
result += f" 描述: {agent['description']}\n"
if agent['domain']:
result += f" 领域: {agent['domain']}\n"
if agent['capabilities']:
result += f" 能力: {', '.join(agent['capabilities'])}\n"
if agent['tools']:
result += f" 工具: {', '.join(agent['tools'])}\n"
result += "\n"
logger.info(
f"Agent 发现成功",
extra={
"query": query,
"domain": domain,
"found_count": len(agents)
}
)
return result
except Exception as e:
logger.error(f"Agent 发现失败", extra={"error": str(e)})
return f"发现 Agent 失败: {str(e)}"
return discover_agents
# ==================== Agent 调用工具 ====================
class AgentInvocationInput(BaseModel):
"""Agent 调用工具输入参数"""
agent_id: str = Field(..., description="要调用的 Agent ID通过 discover_agents 工具获取)")
message: str = Field(..., description="发送给 Agent 的消息或任务描述")
context: Optional[Dict[str, Any]] = Field(None, description="可选的上下文信息(如:用户信息、历史记录等)")
def create_agent_invocation_tool(
db: Session,
registry: AgentRegistry,
workspace_id: uuid.UUID,
current_agent_id: uuid.UUID,
conversation_id: Optional[uuid.UUID] = None,
parent_invocation_id: Optional[uuid.UUID] = None,
invocation_chain: Optional[List[uuid.UUID]] = None
):
"""创建 Agent 调用工具
Args:
db: 数据库会话
registry: Agent 注册表
workspace_id: 当前工作空间 ID
current_agent_id: 当前 Agent ID
conversation_id: 会话 ID
parent_invocation_id: 父调用 ID
invocation_chain: 调用链(用于检测循环调用)
Returns:
Agent 调用工具
"""
# 1. 获取工作空间的 storage_type
storage_type = 'neo4j' # 默认值
user_rag_memory_id = None
try:
workspace = workspace_repository.get_workspace_by_id(db, workspace_id)
if workspace and workspace.storage_type:
storage_type = workspace.storage_type
logger.debug(
f"获取工作空间存储类型成功",
extra={
"workspace_id": str(workspace_id),
"storage_type": storage_type
}
)
except Exception as e:
logger.warning(
f"获取工作空间存储类型失败,使用默认值 neo4j",
extra={"workspace_id": str(workspace_id), "error": str(e)}
)
# 2. 如果 storage_type 是 rag获取知识库 ID
if storage_type == 'rag':
try:
knowledge = knowledge_repository.get_knowledge_by_name(
db=db,
name="USER_RAG_MEMORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
logger.debug(
f"获取 RAG 知识库成功",
extra={
"workspace_id": str(workspace_id),
"knowledge_id": user_rag_memory_id
}
)
else:
logger.warning(
f"未找到名为 'USER_RAG_MEMORY' 的知识库,将使用 neo4j 存储",
extra={"workspace_id": str(workspace_id)}
)
storage_type = 'neo4j'
except Exception as e:
logger.warning(
f"获取 RAG 知识库失败,将使用 neo4j 存储",
extra={"workspace_id": str(workspace_id), "error": str(e)}
)
storage_type = 'neo4j'
if invocation_chain is None:
invocation_chain = []
@tool(args_schema=AgentInvocationInput)
async def invoke_agent(
agent_id: str,
message: str,
context: Optional[Dict[str, Any]] = None
) -> str:
"""调用另一个 Agent 来处理任务。当当前 Agent 无法处理某个任务,需要其他专业 Agent 帮助时使用。
Args:
agent_id: 要调用的 Agent ID通过 discover_agents 工具获取)
message: 发送给 Agent 的消息或任务描述
context: 可选的上下文信息(如:用户信息、历史记录等)
Returns:
被调用 Agent 的响应结果
"""
try:
# 1. 验证 Agent 存在
agent_uuid = uuid.UUID(agent_id)
agent_info = registry.get_agent(agent_uuid)
if not agent_info:
return f"Agent {agent_id} 不存在"
# 2. 验证权限(同工作空间或公开)
if agent_info["workspace_id"] != str(workspace_id) and agent_info["visibility"] != "public":
return f"无权访问 Agent {agent_info['name']}"
# 3. 防止自己调用自己
if agent_id == str(current_agent_id):
return "不能调用自己"
# 4. 防止循环调用
if agent_uuid in invocation_chain:
return f"检测到循环调用:{agent_info['name']} 已在调用链中"
# 5. 检查调用深度
max_depth = 5
if len(invocation_chain) >= max_depth:
return f"调用深度超过限制(最大 {max_depth} 层)"
# 6. 获取 Agent 配置
agent_config = db.get(AgentConfig, agent_uuid)
if not agent_config:
return f"Agent 配置不存在"
# 7. 获取模型配置
model_config = db.get(ModelConfig, agent_config.default_model_config_id)
if not model_config:
return f"Agent 模型配置不存在"
# 8. 创建调用记录
invocation = AgentInvocation(
caller_agent_id=current_agent_id,
callee_agent_id=agent_uuid,
conversation_id=conversation_id,
parent_invocation_id=parent_invocation_id,
input_message=message,
context=context,
status="running",
started_at=datetime.datetime.now()
)
db.add(invocation)
db.commit()
db.refresh(invocation)
logger.info(
f"Agent 调用开始",
extra={
"invocation_id": str(invocation.id),
"caller_agent_id": str(current_agent_id),
"callee_agent_id": agent_id,
"depth": len(invocation_chain)
}
)
start_time = time.time()
try:
# 9. 调用 Agent
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run(
agent_config=agent_config,
model_config=model_config,
message=message,
workspace_id=workspace_id,
variables=context or {},
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
elapsed_time = time.time() - start_time
# 10. 更新调用记录
invocation.status = "completed"
invocation.output_message = result["message"]
invocation.completed_at = datetime.datetime.now()
invocation.elapsed_time = elapsed_time
invocation.token_usage = result.get("usage", {})
db.commit()
logger.info(
f"Agent 调用成功",
extra={
"invocation_id": str(invocation.id),
"caller_agent_id": str(current_agent_id),
"callee_agent_id": agent_id,
"elapsed_time": elapsed_time
}
)
return result["message"]
except Exception as e:
# 更新调用记录为失败
invocation.status = "failed"
invocation.error_message = str(e)
invocation.completed_at = datetime.datetime.now()
invocation.elapsed_time = time.time() - start_time
db.commit()
logger.error(
f"Agent 调用失败",
extra={
"invocation_id": str(invocation.id),
"caller_agent_id": str(current_agent_id),
"callee_agent_id": agent_id,
"error": str(e)
}
)
raise
except Exception as e:
logger.error(
f"Agent 调用异常",
extra={
"caller_agent_id": str(current_agent_id),
"callee_agent_id": agent_id,
"error": str(e)
}
)
return f"调用 Agent 失败: {str(e)}"
return invoke_agent

View File

@@ -0,0 +1,173 @@
"""API Key Service"""
from sqlalchemy.orm import Session
from typing import Optional, Tuple, List
import uuid
import datetime
import math
from app.models.api_key_model import ApiKey, ApiKeyType
from app.repositories.api_key_repository import ApiKeyRepository
from app.schemas import api_key_schema
from app.schemas.response_schema import PageData, PageMeta
from app.core.api_key_utils import generate_api_key
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ApiKeyService:
"""API Key 业务逻辑服务"""
@staticmethod
def create_api_key(
db: Session,
*,
workspace_id: uuid.UUID,
user_id: uuid.UUID,
data: api_key_schema.ApiKeyCreate
) -> Tuple[ApiKey, str]:
"""创建 API Key
Returns:
Tuple[ApiKey, str]: (API Key 对象, API Key 明文)
"""
# 生成 API Key
api_key, key_hash, key_prefix = generate_api_key(data.type)
# 创建数据
api_key_data = {
"id": uuid.uuid4(),
"name": data.name,
"description": data.description,
"key_prefix": key_prefix,
"key_hash": key_hash,
"type": data.type,
"scopes": data.scopes,
"workspace_id": workspace_id,
"resource_id": data.resource_id,
"resource_type": data.resource_type,
"rate_limit": data.rate_limit,
"quota_limit": data.quota_limit,
"expires_at": data.expires_at,
"created_by": user_id,
"created_at": datetime.datetime.now(),
"updated_at": datetime.datetime.now(),
}
api_key_obj = ApiKeyRepository.create(db, api_key_data)
db.commit()
db.refresh(api_key_obj)
logger.info(f"API Key 创建成功", extra={
"api_key_id": str(api_key_obj.id),
"name": data.name,
"type": data.type
})
return api_key_obj, api_key
@staticmethod
def get_api_key(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
) -> ApiKey:
"""获取 API Key"""
api_key = ApiKeyRepository.get_by_id(db, api_key_id)
if not api_key:
raise BusinessException("API Key 不存在", BizCode.NOT_FOUND)
if api_key.workspace_id != workspace_id:
raise BusinessException("无权访问此 API Key", BizCode.FORBIDDEN)
return api_key
@staticmethod
def list_api_keys(
db: Session,
workspace_id: uuid.UUID,
query: api_key_schema.ApiKeyQuery
) -> PageData:
"""列出 API Keys"""
items, total = ApiKeyRepository.list_by_workspace(db, workspace_id, query)
pages = math.ceil(total / query.pagesize) if total > 0 else 0
return PageData(
page=PageMeta(
page=query.page,
pagesize=query.pagesize,
total=total,
hasnext=query.page < pages
),
items=[api_key_schema.ApiKey.model_validate(item) for item in items]
)
@staticmethod
def update_api_key(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID,
data: api_key_schema.ApiKeyUpdate
) -> ApiKey:
"""更新 API Key"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
update_data = data.model_dump(exclude_unset=True)
ApiKeyRepository.update(db, api_key_id, update_data)
db.commit()
db.refresh(api_key)
logger.info(f"API Key 更新成功", extra={"api_key_id": str(api_key_id)})
return api_key
@staticmethod
def delete_api_key(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
) -> bool:
"""删除 API Key"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
ApiKeyRepository.delete(db, api_key_id)
db.commit()
logger.info(f"API Key 删除成功", extra={"api_key_id": str(api_key_id)})
return True
@staticmethod
def regenerate_api_key(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
) -> Tuple[ApiKey, str]:
"""重新生成 API Key"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
# 生成新的 API Key
new_api_key, key_hash, key_prefix = generate_api_key(ApiKeyType(api_key.type))
# 更新
ApiKeyRepository.update(db, api_key_id, {
"key_hash": key_hash,
"key_prefix": key_prefix
})
db.commit()
db.refresh(api_key)
logger.info(f"API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
return api_key, new_api_key
@staticmethod
def get_stats(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
) -> api_key_schema.ApiKeyStats:
"""获取使用统计"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
return api_key_schema.ApiKeyStats(**stats_data)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,262 @@
from sqlalchemy.orm import Session
from typing import Optional, Tuple, Union
import jwt
import time
from app.models.user_model import User
from app.repositories import user_repository
from app.core.security import verify_password
from app.core.config import settings
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
# Token 配置
TOKEN_SECRET_KEY = settings.SECRET_KEY
TOKEN_ALGORITHM = "HS256"
def authenticate_user(db: Session, email: str, password: str) -> Optional[User]:
"""
Authenticates a user.
:param db: The database session.
:param email: The email.
:param password: The password.
:return: The user object if authentication is successful, otherwise None.
"""
user = user_repository.get_user_by_email(db, email=email)
if not user:
return None # User not found
if not user.is_active:
return None # User is inactive
if not verify_password(password, user.hashed_password):
return None # Incorrect password
return user # Authentication successful
def authenticate_user_with_status(db: Session, email: str, password: str) -> Tuple[bool, Optional[User], str]:
"""
认证用户并返回详细状态(用于需要区分不同失败原因的场景)
:param db: 数据库会话
:param email: 用户邮箱
:param password: 用户密码
:return: (认证成功, 用户对象, 状态消息)
状态消息: "success", "user_not_found", "user_inactive", "password_incorrect"
"""
from app.core.logging_config import get_auth_logger
logger = get_auth_logger()
# 查找用户
user = user_repository.get_user_by_email(db, email=email)
if not user:
logger.warning(f"用户不存在: {email}")
return (False, None, "user_not_found")
# 检查用户状态
if not user.is_active:
logger.warning(f"用户未激活: {email}")
return (False, user, "user_inactive")
# 验证密码
if not verify_password(password, user.hashed_password):
logger.warning(f"密码错误: {email}")
return (False, user, "password_incorrect")
logger.info(f"用户认证成功: {email}")
return (True, user, "success")
def authenticate_user_or_raise(db: Session, email: str, password: str) -> User:
"""
认证用户,失败时抛出异常(推荐使用)
:param db: 数据库会话
:param email: 用户邮箱
:param password: 用户密码
:return: 用户对象
:raises BusinessException: 认证失败时抛出
"""
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.logging_config import get_auth_logger
logger = get_auth_logger()
# 查找用户
user = user_repository.get_user_by_email(db, email=email)
if not user:
logger.warning(f"用户不存在: {email}")
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
# 检查用户状态
if not user.is_active:
logger.warning(f"用户未激活: {email}")
raise BusinessException("用户未激活", code=BizCode.USER_NOT_FOUND)
# 验证密码
if not verify_password(password, user.hashed_password):
logger.warning(f"密码错误: {email}")
raise BusinessException("密码错误", code=BizCode.PASSWORD_ERROR)
logger.info(f"用户认证成功: {email}")
return user
def get_user_by_username(db: Session, username: str) -> Optional[User]:
"""
Get a user by username.
:param db: The database session.
:param username: The username.
:return: The user object if found, otherwise None.
"""
return user_repository.get_user_by_username(db, username=username)
def get_user_by_id(db: Session, user_id: str) -> Optional[User]:
"""
Get a user by user_id.
:param db: The database session.
:param user_id: The user id (UUID string).
:return: The user object if found, otherwise None.
"""
return user_repository.get_user_by_id(db, user_id=user_id)
def register_user_with_invite(
db: Session,
email: str,
password: str,
invite_token: str,
workspace_id: str
) -> User:
"""
使用邀请码注册新用户并加入工作空间
:param db: 数据库会话
:param email: 用户邮箱
:param password: 用户密码
:param invite_token: 邀请令牌
:param workspace_id: 工作空间ID
:return: 创建的用户对象
"""
from app.schemas.user_schema import UserCreate
from app.schemas.workspace_schema import InviteAcceptRequest
from app.services import user_service, workspace_service
from app.core.logging_config import get_business_logger
logger = get_business_logger()
logger.info(f"使用邀请码注册用户: {email}")
try:
# 创建用户
user_create = UserCreate(
email=email,
password=password,
username=email.split('@')[0]
)
user = user_service.create_user(db=db, user=user_create)
logger.info(f"用户创建成功: {user.email} (ID: {user.id})")
# 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit
invite_accept = InviteAcceptRequest(token=invite_token)
workspace_service.accept_workspace_invite(db, invite_accept, user)
logger.info(f"用户接受邀请成功")
# 重新查询用户对象以确保获取最新状态
from app.repositories import user_repository
user = user_repository.get_user_by_id(db, str(user.id))
# 设置当前工作空间
user.current_workspace_id = workspace_id
db.commit()
db.refresh(user)
logger.info(f"用户注册并加入工作空间成功: {user.email}, workspace_id: {user.current_workspace_id}")
return user
except Exception as e:
db.rollback()
logger.error(f"注册用户失败: {email} - {str(e)}")
raise
def bind_workspace_with_invite(
db: Session,
user: User,
invite_token: str,
workspace_id: str
) -> User:
from app.schemas.user_schema import UserCreate
from app.schemas.workspace_schema import InviteAcceptRequest
from app.services import user_service, workspace_service
from app.core.logging_config import get_business_logger
logger = get_business_logger()
try:
# 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit
invite_accept = InviteAcceptRequest(token=invite_token)
workspace_service.accept_workspace_invite(db, invite_accept, user)
logger.info(f"用户接受邀请成功")
# 重新查询用户对象以确保获取最新状态
from app.repositories import user_repository
user = user_repository.get_user_by_id(db, str(user.id))
# 设置当前工作空间
user.current_workspace_id = workspace_id
db.commit()
db.refresh(user)
return user
except Exception as e:
db.rollback()
logger.error(f"绑定工作空间失败: user={user.email} - {str(e)}")
raise
def create_access_token(user_id: str, share_token: str) -> str:
"""创建访问 token
Token 不设置过期时间,只要 share_token 有效token 就有效
Args:
user_id: 用户 ID
share_token: 分享 token
Returns:
JWT token
"""
payload = {
"user_id": user_id,
"share_token": share_token,
"iat": int(time.time()) # 签发时间
}
token = jwt.encode(payload, TOKEN_SECRET_KEY, algorithm=TOKEN_ALGORITHM)
return token
def decode_access_token(token: str) -> dict:
"""解码访问 token
Args:
token: JWT token
Returns:
包含 user_id 和 share_token 的字典
Raises:
BusinessException: token 无效
"""
try:
payload = jwt.decode(token, TOKEN_SECRET_KEY, algorithms=[TOKEN_ALGORITHM])
return {
"user_id": payload["user_id"],
"share_token": payload["share_token"]
}
except jwt.InvalidTokenError:
raise BusinessException("无效的访问 token", BizCode.INVALID_TOKEN)

View File

@@ -0,0 +1,229 @@
"""会话服务"""
import uuid
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from sqlalchemy import select, desc
from app.models import Conversation, Message
from app.core.exceptions import ResourceNotFoundException, BusinessException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ConversationService:
"""会话服务"""
def __init__(self, db: Session):
self.db = db
def create_conversation(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
user_id: Optional[str] = None,
title: Optional[str] = None,
is_draft: bool = False,
config_snapshot: Optional[dict] = None
) -> Conversation:
"""创建会话"""
conversation = Conversation(
app_id=app_id,
workspace_id=workspace_id,
user_id=user_id,
title=title or "新会话",
is_draft=is_draft,
config_snapshot=config_snapshot
)
self.db.add(conversation)
self.db.commit()
self.db.refresh(conversation)
logger.info(
f"创建会话成功",
extra={
"conversation_id": str(conversation.id),
"app_id": str(app_id),
"workspace_id": str(workspace_id),
"is_draft": is_draft
}
)
return conversation
def get_conversation(
self,
conversation_id: uuid.UUID,
workspace_id: Optional[uuid.UUID] = None
) -> Conversation:
"""获取会话"""
stmt = select(Conversation).where(Conversation.id == conversation_id)
if workspace_id:
stmt = stmt.where(Conversation.workspace_id == workspace_id)
conversation = self.db.scalars(stmt).first()
if not conversation:
raise ResourceNotFoundException("会话", str(conversation_id))
return conversation
def list_conversations(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
user_id: Optional[str] = None,
is_draft: Optional[bool] = None,
page: int = 1,
pagesize: int = 20
) -> Tuple[List[Conversation], int]:
"""列出会话"""
stmt = select(Conversation).where(
Conversation.app_id == app_id,
Conversation.workspace_id == workspace_id,
Conversation.is_active == True
)
if user_id:
stmt = stmt.where(Conversation.user_id == user_id)
if is_draft is not None:
stmt = stmt.where(Conversation.is_draft == is_draft)
# 总数
count_stmt = stmt.with_only_columns(Conversation.id)
total = len(self.db.execute(count_stmt).all())
# 分页
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
conversations = list(self.db.scalars(stmt).all())
return conversations, total
def add_message(
self,
conversation_id: uuid.UUID,
role: str,
content: str,
meta_data: Optional[dict] = None
) -> Message:
"""添加消息"""
message = Message(
conversation_id=conversation_id,
role=role,
content=content,
meta_data=meta_data
)
self.db.add(message)
# 更新会话的消息计数和更新时间
conversation = self.get_conversation(conversation_id)
conversation.message_count += 1
# 如果是第一条用户消息,可以用它作为标题
if conversation.message_count == 1 and role == "user":
conversation.title = content[:50] + ("..." if len(content) > 50 else "")
self.db.commit()
self.db.refresh(message)
return message
def get_messages(
self,
conversation_id: uuid.UUID,
limit: Optional[int] = None
) -> List[Message]:
"""获取会话消息"""
stmt = select(Message).where(
Message.conversation_id == conversation_id
).order_by(Message.created_at)
if limit:
stmt = stmt.limit(limit)
messages = list(self.db.scalars(stmt).all())
return messages
def get_conversation_history(
self,
conversation_id: uuid.UUID,
max_history: Optional[int] = None
) -> List[dict]:
"""获取会话历史消息
Args:
conversation_id: 会话ID
max_history: 最大历史消息数量
Returns:
List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...]
"""
messages = self.get_messages(conversation_id, limit=max_history)
# 转换为字典格式
history = [
{
"role": msg.role,
"content": msg.content
}
for msg in messages
]
return history
def save_conversation_messages(
self,
conversation_id: uuid.UUID,
user_message: str,
assistant_message: str
):
"""保存会话消息(用户消息和助手回复)"""
# 添加用户消息
self.add_message(
conversation_id=conversation_id,
role="user",
content=user_message
)
# 添加助手消息
self.add_message(
conversation_id=conversation_id,
role="assistant",
content=assistant_message
)
logger.debug(
f"保存会话消息成功",
extra={
"conversation_id": str(conversation_id),
"user_message_length": len(user_message),
"assistant_message_length": len(assistant_message)
}
)
def delete_conversation(
self,
conversation_id: uuid.UUID,
workspace_id: uuid.UUID
):
"""删除会话(软删除)"""
conversation = self.get_conversation(conversation_id, workspace_id)
conversation.is_active = False
self.db.commit()
logger.info(
f"删除会话成功",
extra={
"conversation_id": str(conversation_id),
"workspace_id": str(workspace_id)
}
)

View File

@@ -0,0 +1,261 @@
"""会话状态管理器 - 解决多轮对话路由错乱"""
import json
from typing import Optional, Dict, Any, List
from datetime import datetime
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ConversationStateManager:
"""会话状态管理器
用于管理多轮对话中的会话状态,包括:
- 当前使用的 Agent
- 路由历史
- 主题追踪
- Agent 切换统计
"""
def __init__(self, storage_backend: Optional[Any] = None):
"""初始化状态管理器
Args:
storage_backend: 存储后端Redis/内存等)
"""
self.storage = storage_backend or InMemoryStorage()
self.ttl = 3600 # 1小时过期
def get_state(self, conversation_id: str) -> Dict[str, Any]:
"""获取会话状态
Args:
conversation_id: 会话 ID
Returns:
会话状态字典
"""
state = self.storage.get(f"conv_state:{conversation_id}")
if not state:
logger.info(f"创建新会话状态: {conversation_id}")
return self._create_new_state(conversation_id)
return state
def update_state(
self,
conversation_id: str,
agent_id: str,
message: str,
topic: Optional[str] = None,
confidence: float = 1.0
) -> Dict[str, Any]:
"""更新会话状态
Args:
conversation_id: 会话 ID
agent_id: 当前 Agent ID
message: 用户消息
topic: 消息主题
confidence: 路由置信度
Returns:
更新后的状态
"""
state = self.get_state(conversation_id)
# 检测 Agent 切换
agent_changed = False
if state["current_agent_id"] and state["current_agent_id"] != agent_id:
agent_changed = True
state["switch_count"] += 1
state["previous_agent_id"] = state["current_agent_id"]
state["same_agent_turns"] = 0
logger.info(
f"Agent 切换",
extra={
"conversation_id": conversation_id,
"from": state["current_agent_id"],
"to": agent_id,
"switch_count": state["switch_count"]
}
)
else:
state["same_agent_turns"] += 1
# 更新当前 Agent
state["current_agent_id"] = agent_id
state["last_message"] = message
state["last_topic"] = topic
state["updated_at"] = datetime.now().isoformat()
# 添加到历史
history_item = {
"message": message[:100], # 截断长消息
"agent_id": agent_id,
"topic": topic,
"confidence": confidence,
"agent_changed": agent_changed,
"timestamp": datetime.now().isoformat()
}
state["routing_history"].append(history_item)
# 保持最近 10 条历史
if len(state["routing_history"]) > 10:
state["routing_history"] = state["routing_history"][-10:]
# 保存状态
self.storage.set(
f"conv_state:{conversation_id}",
state,
ttl=self.ttl
)
return state
def clear_state(self, conversation_id: str) -> None:
"""清除会话状态
Args:
conversation_id: 会话 ID
"""
self.storage.delete(f"conv_state:{conversation_id}")
logger.info(f"清除会话状态: {conversation_id}")
def get_routing_history(
self,
conversation_id: str,
limit: int = 10
) -> List[Dict[str, Any]]:
"""获取路由历史
Args:
conversation_id: 会话 ID
limit: 返回数量限制
Returns:
路由历史列表
"""
state = self.get_state(conversation_id)
history = state.get("routing_history", [])
return history[-limit:] if history else []
def get_statistics(self, conversation_id: str) -> Dict[str, Any]:
"""获取会话统计信息
Args:
conversation_id: 会话 ID
Returns:
统计信息
"""
state = self.get_state(conversation_id)
history = state.get("routing_history", [])
# 统计各 Agent 使用次数
agent_usage = {}
for item in history:
agent_id = item["agent_id"]
agent_usage[agent_id] = agent_usage.get(agent_id, 0) + 1
# 统计主题分布
topic_distribution = {}
for item in history:
topic = item.get("topic", "未知")
topic_distribution[topic] = topic_distribution.get(topic, 0) + 1
return {
"conversation_id": conversation_id,
"total_turns": len(history),
"switch_count": state.get("switch_count", 0),
"current_agent_id": state.get("current_agent_id"),
"same_agent_turns": state.get("same_agent_turns", 0),
"agent_usage": agent_usage,
"topic_distribution": topic_distribution,
"created_at": state.get("created_at"),
"updated_at": state.get("updated_at")
}
def _create_new_state(self, conversation_id: str) -> Dict[str, Any]:
"""创建新的会话状态
Args:
conversation_id: 会话 ID
Returns:
新的状态字典
"""
state = {
"conversation_id": conversation_id,
"current_agent_id": None,
"previous_agent_id": None,
"routing_history": [],
"last_message": None,
"last_topic": None,
"switch_count": 0,
"same_agent_turns": 0,
"created_at": datetime.now().isoformat(),
"updated_at": datetime.now().isoformat()
}
# 保存初始状态
self.storage.set(
f"conv_state:{conversation_id}",
state,
ttl=self.ttl
)
return state
class InMemoryStorage:
"""内存存储后端(用于开发和测试)"""
def __init__(self):
self._storage: Dict[str, Dict[str, Any]] = {}
def get(self, key: str) -> Optional[Dict[str, Any]]:
"""获取数据"""
return self._storage.get(key)
def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None:
"""设置数据"""
self._storage[key] = value
def delete(self, key: str) -> None:
"""删除数据"""
if key in self._storage:
del self._storage[key]
def clear(self) -> None:
"""清空所有数据"""
self._storage.clear()
class RedisStorage:
"""Redis 存储后端(用于生产环境)"""
def __init__(self, redis_client):
"""初始化 Redis 存储
Args:
redis_client: Redis 客户端实例
"""
self.redis = redis_client
def get(self, key: str) -> Optional[Dict[str, Any]]:
"""获取数据"""
data = self.redis.get(key)
if data:
return json.loads(data)
return None
def set(self, key: str, value: Dict[str, Any], ttl: int = 3600) -> None:
"""设置数据"""
self.redis.setex(key, ttl, json.dumps(value))
def delete(self, key: str) -> None:
"""删除数据"""
self.redis.delete(key)

View File

@@ -0,0 +1,85 @@
import uuid
from sqlalchemy.orm import Session
from app.models.user_model import User
from app.models.document_model import Document
from app.schemas.document_schema import DocumentCreate, DocumentUpdate
from app.repositories import document_repository
from app.core.logging_config import get_business_logger
# Obtain a dedicated logger for business logic
business_logger = get_business_logger()
def get_documents_paginated(
db: Session,
current_user: User,
filters: list,
page: int,
pagesize: int,
orderby: str = None,
desc: bool = False
) -> tuple[int, list]:
business_logger.debug(f"Query document in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}")
try:
total, items = document_repository.get_documents_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc
)
business_logger.info(f"The document paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}")
return total, items
except Exception as e:
business_logger.error(f"Querying document pagination failed: username={current_user.username} - {str(e)}")
raise
def create_document(
db: Session, document: DocumentCreate, current_user: User
) -> Document:
business_logger.info(f"Create a document: {document.file_name}, creator: {current_user.username}")
try:
document.created_by = current_user.id
db_document = document_repository.create_document(
db=db, document=document
)
business_logger.info(f"The document has been successfully created: {document.file_name} (ID: {db_document.id}), creator: {current_user.username}")
return db_document
except Exception as e:
business_logger.error(f"Failed to create a document: {document.file_name} - {str(e)}")
raise
def get_document_by_id(db: Session, document_id: uuid.UUID, current_user: User) -> Document | None:
business_logger.debug(f"Query document based on ID: document_id={document_id}, username: {current_user.username}")
try:
document = document_repository.get_document_by_id(db=db, document_id=document_id)
if document:
business_logger.info(f"document query successful: {document.file_name} (ID: {document_id})")
else:
business_logger.warning(f"document does not exist: document_id={document_id}")
return document
except Exception as e:
business_logger.error(f"Failed to query the document based on the ID: document_id={document_id} - {str(e)}")
raise
def reset_documents_progress_by_kb_id(db: Session, kb_id: uuid.UUID, current_user: User) -> int:
business_logger.debug(f"Reset the processing progress of all documents under the specified knowledge base: kb_id=={kb_id}, username: {current_user.username}")
return document_repository.reset_documents_progress_by_kb_id(db=db, kb_id=kb_id)
def delete_document_by_id(db: Session, document_id: uuid.UUID, current_user: User) -> None:
business_logger.info(f"Delete document: document_id={document_id}, operator: {current_user.username}")
try:
document_repository.delete_document_by_id(db=db, document_id=document_id)
business_logger.info(f"document deleted successfully: document_id={document_id}, operator: {current_user.username}")
except Exception as e:
business_logger.error(f"Failed to delete document: document_id={document_id} - {str(e)}")
raise

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,87 @@
import uuid
from sqlalchemy.orm import Session
from app.models.user_model import User
from app.models.file_model import File
from app.schemas.file_schema import FileCreate, FileUpdate
from app.repositories import file_repository
from app.core.logging_config import get_business_logger
# Obtain a dedicated logger for business logic
business_logger = get_business_logger()
def get_files_paginated(
db: Session,
current_user: User,
filters: list,
page: int,
pagesize: int,
orderby: str = None,
desc: bool = False
) -> tuple[int, list]:
business_logger.debug(f"Query file in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}")
try:
total, items = file_repository.get_files_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc
)
business_logger.info(f"The file paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}")
return total, items
except Exception as e:
business_logger.error(f"Querying file pagination failed: username={current_user.username} - {str(e)}")
raise
def create_file(
db: Session, file: FileCreate, current_user: User
) -> File:
business_logger.info(f"Create a file: {file.file_name}, creator: {current_user.username}")
try:
file.created_by = current_user.id
if file.parent_id is None:
file.parent_id = file.kb_id
db_file = file_repository.create_file(
db=db, file=file
)
business_logger.info(f"The file has been successfully created: {file.file_name} (ID: {db_file.id}), creator: {current_user.username}")
return db_file
except Exception as e:
business_logger.error(f"Failed to create a file: {file.file_name} - {str(e)}")
raise
def get_file_by_id(db: Session, file_id: uuid.UUID) -> File | None:
business_logger.debug(f"Query file based on ID: file_id={file_id}")
try:
file = file_repository.get_file_by_id(db=db, file_id=file_id)
if file:
business_logger.info(f"file query successful: {file.file_name} (ID: {file_id})")
else:
business_logger.warning(f"file does not exist: file_id={file_id}")
return file
except Exception as e:
business_logger.error(f"Failed to query the file based on the ID: file_id={file_id} - {str(e)}")
raise
def get_files_by_parent_id(db: Session, parent_id: uuid.UUID | None, current_user: User) -> list | None:
business_logger.debug(f"Query file based on folder ID: parent_id={parent_id}, username: {current_user.username}")
return file_repository.get_files_by_parent_id(db=db, parent_id=parent_id)
def delete_file_by_id(db: Session, file_id: uuid.UUID, current_user: User) -> None:
business_logger.info(f"Delete file: file_id={file_id}, operator: {current_user.username}")
try:
file_repository.delete_file_by_id(db=db, file_id=file_id)
business_logger.info(f"file_id deleted successfully: file_id={file_id}, operator: {current_user.username}")
except Exception as e:
business_logger.error(f"Failed to delete file: file_id={file_id} - {str(e)}")
raise

View File

@@ -0,0 +1,126 @@
import uuid
from sqlalchemy.orm import Session
from app.models.user_model import User
from app.models.knowledge_model import Knowledge
from app.schemas.knowledge_schema import KnowledgeCreate, KnowledgeUpdate
from app.repositories import knowledge_repository
from app.core.logging_config import get_business_logger
# Obtain a dedicated logger for business logic
business_logger = get_business_logger()
def get_knowledges_paginated(
db: Session,
current_user: User,
filters: list,
page: int,
pagesize: int,
orderby: str = None,
desc: bool = False
) -> tuple[int, list]:
business_logger.debug(f"Query knowledge base in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}")
try:
total, items = knowledge_repository.get_knowledges_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc
)
business_logger.info(f"The knowledge base paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}")
return total, items
except Exception as e:
business_logger.error(f"Querying knowledge base pagination failed: username={current_user.username} - {str(e)}")
raise
def get_chunded_knowledgeids(
db: Session,
current_user: User,
filters: list
) -> list:
business_logger.debug(f"Query the list of vectorized knowledge base IDs: username={current_user.username}")
try:
items = knowledge_repository.get_chunded_knowledgeids(
db=db,
filters=filters
)
business_logger.info(f"Querying the vectorized knowledge base id list succeeded: username={current_user.username} count={len(items)}")
return items
except Exception as e:
business_logger.error(f"Querying the vectorized knowledge base id list failed: username={current_user.username} - {str(e)}")
raise
def create_knowledge(
db: Session, knowledge: KnowledgeCreate, current_user: User
) -> Knowledge:
business_logger.info(f"Create a knowledge base: {knowledge.name}, creator: {current_user.username}")
try:
knowledge.created_by = current_user.id
if knowledge.workspace_id is None:
knowledge.workspace_id = current_user.current_workspace_id
if knowledge.parent_id is None:
knowledge.parent_id = knowledge.workspace_id
business_logger.debug(f"Start creating the knowledge base: {knowledge.name}")
db_knowledge = knowledge_repository.create_knowledge(
db=db, knowledge=knowledge
)
business_logger.info(f"The knowledge base has been successfully created: {knowledge.name} (ID: {db_knowledge.id}), creator: {current_user.username}")
return db_knowledge
except Exception as e:
business_logger.error(f"Failed to create a knowledge base: {knowledge.name} - {str(e)}")
raise
def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID, current_user: User) -> Knowledge | None:
business_logger.debug(f"Query knowledge base based on ID: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=knowledge_id)
if knowledge:
business_logger.info(f"knowledge base query successful: {knowledge.name} (ID: {knowledge_id})")
else:
business_logger.warning(f"knowledge base does not exist: knowledge_id={knowledge_id}")
return knowledge
except Exception as e:
business_logger.error(f"Failed to query the knowledge base based on the ID: knowledge_id={knowledge_id} - {str(e)}")
raise
def get_knowledge_by_name(db: Session, name: str, current_user: User) -> Knowledge | None:
business_logger.debug(f"Query knowledge base based on name: name={name}, username: {current_user.username}")
try:
knowledge = knowledge_repository.get_knowledge_by_name(db=db, name=name, workspace_id=current_user.current_workspace_id)
if knowledge:
business_logger.info(f"knowledge base query successful: {name} (ID: {knowledge.id})")
else:
business_logger.warning(f"knowledge base does not exist: name={name}")
return knowledge
except Exception as e:
business_logger.error(f"Failed to query the knowledge base based on the name: name={name} - {str(e)}")
raise
def delete_knowledge_by_id(db: Session, knowledge_id: uuid.UUID, current_user: User) -> None:
business_logger.info(f"Delete knowledge base: knowledge_id={knowledge_id}, operator: {current_user.username}")
try:
# First, query the knowledge base information for logging purposes
knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=knowledge_id)
if knowledge:
business_logger.debug(f"Execute knowledge base deletion: {knowledge.name} (ID: {knowledge_id})")
else:
business_logger.warning(f"The knowledge base to be deleted does not exist: knowledge_id={knowledge_id}")
knowledge_repository.delete_knowledge_by_id(db=db, knowledge_id=knowledge_id)
business_logger.info(f"knowledge base record deleted successfully: knowledge_id={knowledge_id}, operator: {current_user.username}")
except Exception as e:
business_logger.error(f"Failed to delete knowledge base: knowledge_id={knowledge_id} - {str(e)}")
raise

View File

@@ -0,0 +1,108 @@
import uuid
from sqlalchemy.orm import Session
from app.models.user_model import User
from app.models.knowledgeshare_model import KnowledgeShare
from app.schemas.knowledgeshare_schema import KnowledgeShareCreate
from app.repositories import knowledgeshare_repository
from app.core.logging_config import get_business_logger
# Obtain a dedicated logger for business logic
business_logger = get_business_logger()
def get_knowledgeshares_paginated(
db: Session,
current_user: User,
filters: list,
page: int,
pagesize: int,
orderby: str = None,
desc: bool = False
) -> tuple[int, list]:
business_logger.debug(f"Query knowledge base sharing in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}")
try:
total, items = knowledgeshare_repository.get_knowledgeshares_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc
)
business_logger.info(f"The knowledge base sharing paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}")
return total, items
except Exception as e:
business_logger.error(f"Querying knowledge base sharing pagination failed: username={current_user.username} - {str(e)}")
raise
def get_source_kb_ids_by_target_kb_id(
db: Session,
current_user: User,
filters: list
) -> list:
business_logger.debug(f"Query the original knowledge base id list by sharing the knowledge base: username={current_user.username}")
try:
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
)
business_logger.info(f"Successfully queried the original knowledge base ID list by sharing the knowledge base: username={current_user.username} count={len(items)}")
return items
except Exception as e:
business_logger.error(f"Failed to query the original knowledge base ID list through knowledge base sharing: username={current_user.username} - {str(e)}")
raise
def create_knowledgeshare(
db: Session, knowledgeshare: KnowledgeShareCreate, current_user: User
) -> KnowledgeShare:
business_logger.info(f"Create a knowledge base sharing: creator: {current_user.username}")
try:
knowledgeshare.source_workspace_id = current_user.current_workspace_id
knowledgeshare.shared_by = current_user.id
business_logger.debug("Start creating a knowledge base sharing")
db_knowledgeshare = knowledgeshare_repository.create_knowledgeshare(
db=db, knowledgeshare=knowledgeshare
)
business_logger.info(f"knowledge base sharing created successfully: (ID: {db_knowledgeshare.id}), creator: {current_user.username}")
return db_knowledgeshare
except Exception as e:
business_logger.error(f"Failed to create a knowledge base sharing - {str(e)}")
raise
def get_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID, current_user: User) -> KnowledgeShare | None:
business_logger.debug(f"Query knowledge base sharing based on ID: knowledgeshare_id={knowledgeshare_id}, username: {current_user.username}")
try:
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=knowledgeshare_id)
if knowledgeshare:
business_logger.info(f"knowledge base sharing query successful: (ID: {knowledgeshare_id})")
else:
business_logger.warning(f"knowledge base sharing does not exist: knowledgeshare_id={knowledgeshare_id}")
return knowledgeshare
except Exception as e:
business_logger.error(f"Failed to query the knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
raise
def delete_knowledgeshare_by_id(db: Session, knowledgeshare_id: uuid.UUID, current_user: User) -> None:
business_logger.info(f"Delete knowledge base sharing: knowledgeshare_id={knowledgeshare_id}, operator: {current_user.username}")
try:
# First, query the knowledge base sharing information for logging purposes
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=knowledgeshare_id)
if knowledgeshare:
business_logger.debug(f"Execute knowledge base sharing deletion: (ID: {knowledgeshare_id})")
else:
business_logger.warning(f"The knowledge base sharing does not exist: knowledgeshare_id={knowledgeshare_id}")
knowledgeshare_repository.delete_knowledgeshare_by_id(db=db, knowledgeshare_id=knowledgeshare_id)
business_logger.info(f"knowledge base sharing deleted successfully: knowledgeshare_id={knowledgeshare_id}, operator: {current_user.username}")
except Exception as e:
business_logger.error(f"Failed to delete knowledge base sharing: knowledgeshare_id={knowledgeshare_id} - {str(e)}")
raise

View File

@@ -0,0 +1,51 @@
import requests
import json
from dotenv import load_dotenv
import os
# 加载.env文件
load_dotenv()
# 读取web_search环境变量
web_search_value = os.getenv('web_search')
def Search(query):
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
api_key = web_search_value
payload = json.dumps({
"messages": [
{
"role": "user",
"content": query
}
], #搜索输入
"edition":"standard", #搜索版本。默认为standard。可选值standard完整版本。lite标准版本对召回规模和精排条数简化后的版本时延表现更好效果略弱于完整版。
"search_source": "baidu_search_v2", #使用的搜索引擎版本
"resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态网页top_k最大取值为50视频top_k最大为10图片top_k最大为30阿拉丁top_k最大为5
"search_filter": {
"range": {
"page_time": {
"gte": "now-1w/d", #时间查询参数,大于或等于
"lt": "now/d", #时间查询参数,小于
"gt": "", #时间查询参数,大于
"lte": "" #时间查询参数,小于或等于
}
}
},
"block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表
"search_recency_filter":"week", #根据网页发布时间进行筛选可填值为week,month,semiyear,year
"enable_full_content":True #是否输出网页完整原文
}, ensure_ascii=False)
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}'
}
response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json()
content=[]
for i in response['references']:
title=i['title']
snippet=i['snippet']
content.append(title+';'+snippet)
content=''.join(content)
return content

View File

@@ -0,0 +1,340 @@
"""LLM 客户端适配器 - 支持多种 LLM 提供商"""
import os
import json
from typing import Optional, Dict, Any
from abc import ABC, abstractmethod
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class BaseLLMClient(ABC):
"""LLM 客户端基类"""
@abstractmethod
async def chat(self, prompt: str, **kwargs) -> str:
"""发送聊天请求
Args:
prompt: 提示词
**kwargs: 其他参数
Returns:
LLM 响应文本
"""
pass
class OpenAIClient(BaseLLMClient):
"""OpenAI 客户端"""
def __init__(
self,
api_key: Optional[str] = None,
model: str = "gpt-3.5-turbo",
base_url: Optional[str] = None
):
"""初始化 OpenAI 客户端
Args:
api_key: API 密钥
model: 模型名称
base_url: API 基础 URL可选用于兼容其他服务
"""
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.model = model
self.base_url = base_url
if not self.api_key:
raise ValueError("OpenAI API key 未配置")
try:
from openai import AsyncOpenAI
self.client = AsyncOpenAI(
api_key=self.api_key,
base_url=self.base_url
)
except ImportError:
raise ImportError("请安装 openai 库: pip install openai")
async def chat(self, prompt: str, **kwargs) -> str:
"""发送聊天请求
Args:
prompt: 提示词
**kwargs: 其他参数temperature, max_tokens 等)
Returns:
LLM 响应文本
"""
try:
response = await self.client.chat.completions.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=kwargs.get("temperature", 0.3),
max_tokens=kwargs.get("max_tokens", 500)
)
return response.choices[0].message.content
except Exception as e:
logger.error(f"OpenAI API 调用失败: {str(e)}")
raise
class AzureOpenAIClient(BaseLLMClient):
"""Azure OpenAI 客户端"""
def __init__(
self,
api_key: Optional[str] = None,
endpoint: Optional[str] = None,
deployment_name: Optional[str] = None,
api_version: str = "2024-02-15-preview"
):
"""初始化 Azure OpenAI 客户端
Args:
api_key: API 密钥
endpoint: Azure 端点
deployment_name: 部署名称
api_version: API 版本
"""
self.api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
self.endpoint = endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
self.deployment_name = deployment_name or os.getenv("AZURE_OPENAI_DEPLOYMENT")
self.api_version = api_version
if not all([self.api_key, self.endpoint, self.deployment_name]):
raise ValueError("Azure OpenAI 配置不完整")
try:
from openai import AsyncAzureOpenAI
self.client = AsyncAzureOpenAI(
api_key=self.api_key,
azure_endpoint=self.endpoint,
api_version=self.api_version
)
except ImportError:
raise ImportError("请安装 openai 库: pip install openai")
async def chat(self, prompt: str, **kwargs) -> str:
"""发送聊天请求"""
try:
response = await self.client.chat.completions.create(
model=self.deployment_name,
messages=[{"role": "user", "content": prompt}],
temperature=kwargs.get("temperature", 0.3),
max_tokens=kwargs.get("max_tokens", 500)
)
return response.choices[0].message.content
except Exception as e:
logger.error(f"Azure OpenAI API 调用失败: {str(e)}")
raise
class AnthropicClient(BaseLLMClient):
"""Anthropic Claude 客户端"""
def __init__(
self,
api_key: Optional[str] = None,
model: str = "claude-3-sonnet-20240229"
):
"""初始化 Anthropic 客户端
Args:
api_key: API 密钥
model: 模型名称
"""
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
self.model = model
if not self.api_key:
raise ValueError("Anthropic API key 未配置")
try:
from anthropic import AsyncAnthropic
self.client = AsyncAnthropic(api_key=self.api_key)
except ImportError:
raise ImportError("请安装 anthropic 库: pip install anthropic")
async def chat(self, prompt: str, **kwargs) -> str:
"""发送聊天请求"""
try:
response = await self.client.messages.create(
model=self.model,
max_tokens=kwargs.get("max_tokens", 500),
temperature=kwargs.get("temperature", 0.3),
messages=[{"role": "user", "content": prompt}]
)
return response.content[0].text
except Exception as e:
logger.error(f"Anthropic API 调用失败: {str(e)}")
raise
class LocalLLMClient(BaseLLMClient):
"""本地 LLM 客户端(通过 HTTP API"""
def __init__(
self,
base_url: str = "http://localhost:8000",
model: str = "local-model"
):
"""初始化本地 LLM 客户端
Args:
base_url: API 基础 URL
model: 模型名称
"""
self.base_url = base_url
self.model = model
try:
import httpx
self.client = httpx.AsyncClient(timeout=30.0)
except ImportError:
raise ImportError("请安装 httpx 库: pip install httpx")
async def chat(self, prompt: str, **kwargs) -> str:
"""发送聊天请求"""
try:
response = await self.client.post(
f"{self.base_url}/v1/chat/completions",
json={
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
"temperature": kwargs.get("temperature", 0.3),
"max_tokens": kwargs.get("max_tokens", 500)
}
)
response.raise_for_status()
data = response.json()
return data["choices"][0]["message"]["content"]
except Exception as e:
logger.error(f"本地 LLM API 调用失败: {str(e)}")
raise
class MockLLMClient(BaseLLMClient):
"""模拟 LLM 客户端(用于测试)"""
def __init__(self):
"""初始化模拟客户端"""
self.call_count = 0
async def chat(self, prompt: str, **kwargs) -> str:
"""发送聊天请求(返回模拟结果)"""
self.call_count += 1
logger.info(f"模拟 LLM 调用 (第 {self.call_count} 次)")
# 简单的规则匹配
prompt_lower = prompt.lower()
if "数学" in prompt_lower or "方程" in prompt_lower or "计算" in prompt_lower:
return json.dumps({
"agent_id": "math-agent",
"confidence": 0.9,
"reason": "消息包含数学相关内容"
}, ensure_ascii=False)
elif "化学" in prompt_lower or "反应" in prompt_lower or "元素" in prompt_lower:
return json.dumps({
"agent_id": "chemistry-agent",
"confidence": 0.85,
"reason": "消息包含化学相关内容"
}, ensure_ascii=False)
elif "物理" in prompt_lower or "" in prompt_lower or "速度" in prompt_lower:
return json.dumps({
"agent_id": "physics-agent",
"confidence": 0.88,
"reason": "消息包含物理相关内容"
}, ensure_ascii=False)
elif "语文" in prompt_lower or "古诗" in prompt_lower or "作文" in prompt_lower:
return json.dumps({
"agent_id": "chinese-agent",
"confidence": 0.87,
"reason": "消息包含语文相关内容"
}, ensure_ascii=False)
elif "英语" in prompt_lower or "单词" in prompt_lower or "语法" in prompt_lower:
return json.dumps({
"agent_id": "english-agent",
"confidence": 0.86,
"reason": "消息包含英语相关内容"
}, ensure_ascii=False)
else:
return json.dumps({
"agent_id": "math-agent",
"confidence": 0.5,
"reason": "无法明确判断,使用默认 Agent"
}, ensure_ascii=False)
class LLMClientFactory:
"""LLM 客户端工厂"""
@staticmethod
def create(
provider: str = "mock",
**kwargs
) -> BaseLLMClient:
"""创建 LLM 客户端
Args:
provider: 提供商名称 (openai, azure, anthropic, local, mock)
**kwargs: 客户端配置参数
Returns:
LLM 客户端实例
"""
provider = provider.lower()
if provider == "openai":
return OpenAIClient(**kwargs)
elif provider == "azure":
return AzureOpenAIClient(**kwargs)
elif provider == "anthropic":
return AnthropicClient(**kwargs)
elif provider == "local":
return LocalLLMClient(**kwargs)
elif provider == "mock":
return MockLLMClient()
else:
raise ValueError(f"不支持的 LLM 提供商: {provider}")
@staticmethod
def create_from_env() -> BaseLLMClient:
"""从环境变量创建 LLM 客户端
环境变量:
- LLM_PROVIDER: 提供商名称
- OPENAI_API_KEY: OpenAI API 密钥
- AZURE_OPENAI_API_KEY: Azure OpenAI API 密钥
- ANTHROPIC_API_KEY: Anthropic API 密钥
Returns:
LLM 客户端实例
"""
provider = os.getenv("LLM_PROVIDER", "mock")
logger.info(f"从环境变量创建 LLM 客户端: {provider}")
return LLMClientFactory.create(provider)

View File

@@ -0,0 +1,685 @@
"""基于 LLM 的智能路由器 - 混合策略"""
import json
import re
import uuid
from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy.orm import Session
from app.services.conversation_state_manager import ConversationStateManager
from app.models import ModelConfig, AgentConfig
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class LLMRouter:
"""基于 LLM 的智能路由器
混合策略:
1. 先用关键词快速筛选(置信度 > 0.8 直接返回)
2. 对于模糊情况(置信度 0.3-0.8),调用 LLM 辅助
3. 对于完全不匹配(置信度 < 0.3),调用 LLM
4. 缓存 LLM 结果,减少重复调用
"""
# 主题切换信号
SWITCH_SIGNALS = [
"换个话题", "另外", "还有", "对了",
"那这个呢", "再问一个", "顺便问下",
"我想问", "帮我", "请问", "换一个"
]
# 延续信号
CONTINUATION_SIGNALS = [
"继续", "还是", "", "同样", "类似",
"这个", "那个", "", "", "", ""
]
def __init__(
self,
db: Session,
state_manager: ConversationStateManager,
routing_rules: List[Dict[str, Any]],
sub_agents: Dict[str, Any],
routing_model_config: Optional[ModelConfig] = None,
use_llm: bool = True
):
"""初始化 LLM 路由器
Args:
db: 数据库会话
state_manager: 会话状态管理器
routing_rules: 路由规则列表
sub_agents: 子 Agent 配置字典
routing_model_config: 用于路由的模型配置(可选)
use_llm: 是否启用 LLM默认 True
"""
self.db = db
self.state_manager = state_manager
self.routing_rules = routing_rules
self.sub_agents = sub_agents
self.routing_model_config = routing_model_config
self.use_llm = use_llm and routing_model_config is not None
# 配置参数
self.min_confidence_for_switch = 0.7
self.max_same_agent_turns = 10
self.keyword_high_confidence_threshold = 0.8 # 关键词高置信度阈值
self.keyword_low_confidence_threshold = 0.3 # 关键词低置信度阈值
# 缓存配置
self.cache_enabled = True
self.cache_size = 1000
async def route(
self,
message: str,
conversation_id: Optional[str] = None,
force_new: bool = False
) -> Dict[str, Any]:
"""智能路由(混合策略)
Args:
message: 用户消息
conversation_id: 会话 ID
force_new: 是否强制重新路由
Returns:
路由结果
"""
logger.info(
f"开始 LLM 智能路由",
extra={
"message_length": len(message),
"conversation_id": conversation_id,
"use_llm": self.use_llm
}
)
# 1. 获取会话状态
state = None
if conversation_id and not force_new:
state = self.state_manager.get_state(conversation_id)
# 2. 检测主题切换
topic_changed = self._detect_topic_change(message, state)
# 3. 提取当前主题
topic = await self._extract_topic_with_llm(message) if self.use_llm else self._extract_topic(message)
# 4. 选择路由策略
if force_new:
agent_id, confidence, method = await self._route_with_hybrid(message)
strategy = "force_new"
reason = "用户强制重新路由"
elif not state or not state.get("current_agent_id"):
agent_id, confidence, method = await self._route_with_hybrid(message)
strategy = "new_conversation"
reason = "新会话,首次路由"
elif topic_changed:
agent_id, confidence, method = await self._route_with_hybrid(message)
strategy = "topic_changed"
reason = f"检测到主题切换: {state.get('last_topic')} -> {topic}"
elif state.get("same_agent_turns", 0) >= self.max_same_agent_turns:
agent_id, confidence, method = await self._route_with_hybrid(message)
strategy = "max_turns_reached"
reason = f"同一 Agent 已使用 {state['same_agent_turns']}"
else:
current_agent_id = state["current_agent_id"]
should_continue, continue_confidence = self._should_continue_current_agent(
message,
current_agent_id
)
if should_continue:
agent_id = current_agent_id
confidence = continue_confidence
method = "keyword"
strategy = "continue_current"
reason = "消息在当前 Agent 能力范围内"
else:
new_agent_id, new_confidence, method = await self._route_with_hybrid(message)
if new_confidence > continue_confidence + self.min_confidence_for_switch:
agent_id = new_agent_id
confidence = new_confidence
strategy = "switch_agent"
reason = f"新 Agent 置信度更高: {new_confidence:.2f} vs {continue_confidence:.2f}"
else:
agent_id = current_agent_id
confidence = continue_confidence
method = "keyword"
strategy = "keep_current"
reason = "置信度差距不足以切换 Agent"
# 5. 更新会话状态
if conversation_id:
self.state_manager.update_state(
conversation_id,
agent_id,
message,
topic,
confidence
)
result = {
"agent_id": agent_id,
"confidence": confidence,
"strategy": strategy,
"topic": topic,
"topic_changed": topic_changed,
"reason": reason,
"routing_method": method # "keyword", "llm", "hybrid"
}
logger.info(
f"路由完成",
extra={
"agent_id": agent_id,
"strategy": strategy,
"confidence": confidence,
"method": method
}
)
return result
async def _route_with_hybrid(self, message: str) -> Tuple[str, float, str]:
"""混合路由策略
Args:
message: 用户消息
Returns:
(agent_id, confidence, method)
"""
# 1. 先用关键词匹配
keyword_agent_id, keyword_confidence = self._route_with_keywords(message)
# 2. 判断是否需要 LLM
if not self.use_llm or not self.routing_model_config:
# 不使用 LLM直接返回关键词结果
return keyword_agent_id, keyword_confidence, "keyword"
if keyword_confidence >= self.keyword_high_confidence_threshold:
# 关键词置信度很高,直接返回
logger.info(f"关键词置信度高 ({keyword_confidence:.2f}),跳过 LLM")
return keyword_agent_id, keyword_confidence, "keyword"
# 3. 使用 LLM 辅助决策
logger.info(f"关键词置信度较低 ({keyword_confidence:.2f}),调用 LLM")
llm_agent_id, llm_confidence = await self._route_with_llm(message)
# 4. 综合决策
if llm_confidence > keyword_confidence:
# LLM 置信度更高
final_confidence = llm_confidence * 0.7 + keyword_confidence * 0.3
return llm_agent_id, final_confidence, "llm"
else:
# 关键词置信度更高或相当
final_confidence = keyword_confidence * 0.7 + llm_confidence * 0.3
return keyword_agent_id, final_confidence, "hybrid"
def _route_with_keywords(self, message: str) -> Tuple[str, float]:
"""基于关键词的路由
Args:
message: 用户消息
Returns:
(agent_id, confidence)
"""
best_agent_id = None
best_score = 0.0
for rule in self.routing_rules:
score = self._calculate_rule_score(message, rule)
if score > best_score:
best_score = score
best_agent_id = rule.get("target_agent_id")
if not best_agent_id or best_score < 0.3:
best_agent_id = self._get_default_agent_id()
best_score = 0.5
return best_agent_id, best_score
async def _route_with_llm(self, message: str) -> Tuple[str, float]:
"""基于 LLM 的路由
Args:
message: 用户消息
Returns:
(agent_id, confidence)
"""
# 检查缓存
if self.cache_enabled:
cached_result = self._get_cached_llm_result(message)
if cached_result:
logger.info("使用缓存的 LLM 路由结果")
return cached_result
# 构建 prompt
prompt = self._build_routing_prompt(message)
try:
# 调用 LLM
response = await self._call_llm(prompt)
# 解析结果
agent_id, confidence = self._parse_llm_response(response)
# 缓存结果
if self.cache_enabled:
self._cache_llm_result(message, agent_id, confidence)
return agent_id, confidence
except Exception as e:
logger.error(f"LLM 路由失败: {str(e)}")
# 降级到关键词路由
return self._route_with_keywords(message)
def _build_routing_prompt(self, message: str) -> str:
"""构建 LLM 路由 prompt
Args:
message: 用户消息
Returns:
prompt 字符串
"""
# 构建 Agent 描述
agent_descriptions = []
for agent_id, agent_data in self.sub_agents.items():
# 获取 Agent 信息
agent_info = agent_data.get("info", {})
agent_config = agent_data.get("config")
# 查找该 Agent 的路由规则
rules = [r for r in self.routing_rules if r.get("target_agent_id") == agent_id]
# 构建描述
name = agent_info.get("name", "未命名 Agent")
role = agent_info.get("role", "")
capabilities = agent_info.get("capabilities", [])
desc_parts = [f"- agent_id: {agent_id}", f" 名称: {name}"]
if role:
desc_parts.append(f" 角色: {role}")
# 从路由规则获取关键词
if rules:
rule = rules[0]
keywords = rule.get("keywords", [])
if keywords:
desc_parts.append(f" 关键词: {', '.join(keywords[:5])}")
# 从 Agent 信息获取能力
if capabilities:
desc_parts.append(f" 擅长: {', '.join(capabilities[:5])}")
agent_descriptions.append("\n".join(desc_parts))
agents_text = "\n\n".join(agent_descriptions)
# 如果没有 Agent 描述,添加警告
if not agents_text:
agents_text = "(警告:没有可用的 Agent 信息)"
# 提取所有可用的 agent_id
available_agent_ids = list(self.sub_agents.keys())
agent_ids_text = ", ".join(available_agent_ids)
prompt = f"""你是一个智能路由助手,需要根据用户的消息,选择最合适的 Agent 来处理。
可用的 Agent
{agents_text}
用户消息:"{message}"
**重要**:你必须从以下 agent_id 中选择一个:{agent_ids_text}
请分析这条消息,选择最合适的 Agent。
要求:
1. 仔细理解消息的意图和主题
2. 从上面列出的 agent_id 中选择最匹配的一个
3. 给出置信度0-1 之间的小数)
4. agent_id 必须是上面列出的其中一个,不能自己编造
请以 JSON 格式返回:
{{
"agent_id": "从上面列表中选择的 agent_id",
"confidence": 0.95,
"reason": "选择理由"
}}
"""
return prompt
async def _call_llm(self, prompt: str) -> str:
"""调用 LLM API使用系统的 RedBearLLM
Args:
prompt: 提示词
Returns:
LLM 响应
"""
if not self.routing_model_config:
raise Exception("路由模型配置未设置")
try:
# 使用系统的 RedBearLLM 来调用模型
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.models import ModelApiKey, ModelType
# 获取 API Key 配置
api_key_config = self.db.query(ModelApiKey).filter(
ModelApiKey.model_config_id == self.routing_model_config.id,
ModelApiKey.is_active == True
).first()
if not api_key_config:
raise Exception("路由模型没有可用的 API Key")
# 打印供应商信息
logger.info(
f"LLM 路由使用模型",
extra={
"provider": api_key_config.provider,
"model_name": api_key_config.model_name,
"api_base": api_key_config.api_base,
"model_config_id": str(self.routing_model_config.id)
}
)
# 创建 RedBearModelConfig
model_config = RedBearModelConfig(
model_name=api_key_config.model_name,
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
temperature=0.3,
max_tokens=500
)
logger.debug(f"创建 LLM 实例 - Provider: {api_key_config.provider}, Model: {api_key_config.model_name}")
# 创建 LLM 实例
llm = RedBearLLM(model_config, type=ModelType.CHAT)
# 调用模型
response = await llm.ainvoke(prompt)
# 提取响应内容
if hasattr(response, 'content'):
return response.content
else:
return str(response)
except Exception as e:
logger.error(f"LLM 路由调用失败: {str(e)}")
# 降级到关键词路由
raise
def _parse_llm_response(self, response: str) -> Tuple[str, float]:
"""解析 LLM 响应
Args:
response: LLM 响应文本
Returns:
(agent_id, confidence)
"""
try:
# 提取 JSON
json_match = re.search(r'\{[^}]+\}', response)
if json_match:
result = json.loads(json_match.group())
agent_id = result.get("agent_id")
confidence = float(result.get("confidence", 0.5))
# 验证 agent_id 是否有效
if agent_id not in self.sub_agents:
logger.warning(f"LLM 返回的 agent_id 无效: {agent_id}")
agent_id = self._get_default_agent_id()
confidence = 0.5
return agent_id, confidence
else:
raise ValueError("无法从响应中提取 JSON")
except Exception as e:
logger.error(f"解析 LLM 响应失败: {str(e)}")
return self._get_default_agent_id(), 0.5
def _get_cached_llm_result(self, message: str) -> Optional[Tuple[str, float]]:
"""获取缓存的 LLM 结果
Args:
message: 用户消息
Returns:
缓存的结果或 None
"""
# TODO: 实现真正的缓存机制(使用 Redis 或内存字典)
return None
def _cache_llm_result(self, message: str, agent_id: str, confidence: float):
"""缓存 LLM 结果
Args:
message: 用户消息
agent_id: Agent ID
confidence: 置信度
"""
# lru_cache 会自动处理缓存
pass
async def _extract_topic_with_llm(self, message: str) -> str:
"""使用 LLM 提取主题
Args:
message: 用户消息
Returns:
主题名称
"""
if not self.routing_model_config:
return self._extract_topic(message)
prompt = f"""请分析以下消息的主题,从这些选项中选择一个:
数学、物理、化学、语文、英语、历史、作业、学习规划、订单、退款、账户、支付、其他
消息:"{message}"
只返回主题名称,不要其他内容。
"""
try:
response = await self._call_llm(prompt)
topic = response.strip()
# 验证主题
valid_topics = [
"数学", "物理", "化学", "语文", "英语", "历史",
"作业", "学习规划", "订单", "退款", "账户", "支付", "其他"
]
if topic in valid_topics:
return topic
else:
return self._extract_topic(message)
except Exception as e:
logger.error(f"LLM 提取主题失败: {str(e)}")
return self._extract_topic(message)
# 以下方法与 SmartRouter 相同
def _detect_topic_change(
self,
message: str,
state: Optional[Dict[str, Any]]
) -> bool:
"""检测主题是否切换"""
if not state or not state.get("last_topic"):
return False
for signal in self.SWITCH_SIGNALS:
if signal in message:
logger.info(f"检测到主题切换信号: {signal}")
return True
current_topic = self._extract_topic(message)
last_topic = state.get("last_topic")
if current_topic != last_topic and current_topic != "其他":
logger.info(f"主题变化: {last_topic} -> {current_topic}")
return True
return False
def _should_continue_current_agent(
self,
message: str,
current_agent_id: str
) -> Tuple[bool, float]:
"""判断是否应该继续使用当前 Agent"""
has_continuation_signal = any(
signal in message
for signal in self.CONTINUATION_SIGNALS
)
current_score = self._calculate_agent_score(message, current_agent_id)
if has_continuation_signal and current_score > 0.3:
return True, min(current_score + 0.2, 1.0)
if current_score > 0.6:
return True, current_score
return False, current_score
def _calculate_rule_score(
self,
message: str,
rule: Dict[str, Any]
) -> float:
"""计算规则匹配分数"""
score = 0.0
message_lower = message.lower()
keywords = rule.get("keywords", [])
if keywords:
matched_keywords = sum(
1 for keyword in keywords
if keyword.lower() in message_lower
)
keyword_score = matched_keywords / len(keywords)
score += keyword_score * 0.6
patterns = rule.get("patterns", [])
if patterns:
matched_patterns = sum(
1 for pattern in patterns
if re.search(pattern, message, re.IGNORECASE)
)
pattern_score = matched_patterns / len(patterns)
score += pattern_score * 0.3
exclude_keywords = rule.get("exclude_keywords", [])
if exclude_keywords:
has_exclude = any(
keyword.lower() in message_lower
for keyword in exclude_keywords
)
if has_exclude:
score *= 0.5
min_keyword_count = rule.get("min_keyword_count", 0)
if keywords and min_keyword_count > 0:
matched_count = sum(
1 for keyword in keywords
if keyword.lower() in message_lower
)
if matched_count < min_keyword_count:
score *= 0.7
return min(score, 1.0)
def _calculate_agent_score(
self,
message: str,
agent_id: str
) -> float:
"""计算 Agent 对消息的匹配分数"""
agent_rules = [
rule for rule in self.routing_rules
if rule.get("target_agent_id") == agent_id
]
if not agent_rules:
return 0.0
max_score = max(
self._calculate_rule_score(message, rule)
for rule in agent_rules
)
return max_score
def _extract_topic(self, message: str) -> str:
"""提取消息主题(关键词方式)"""
topic_keywords = {
"数学": ["数学", "方程", "计算", "求解", "x", "y", "函数", "几何"],
"物理": ["物理", "", "速度", "加速度", "能量", "功率", "电路"],
"化学": ["化学", "方程式", "反应", "元素", "分子", "原子", "化合物"],
"语文": ["语文", "古诗", "作文", "阅读", "文言文", "诗词"],
"英语": ["英语", "单词", "语法", "翻译", "时态", "句型"],
"历史": ["历史", "朝代", "事件", "人物", "战争", "革命"],
"作业": ["作业", "批改", "检查", "评分", "反馈"],
"学习规划": ["计划", "规划", "方法", "技巧", "时间", "安排"],
"订单": ["订单", "发货", "物流", "配送", "快递"],
"退款": ["退款", "退货", "售后", "换货", "维修"],
"账户": ["账户", "密码", "登录", "注册", "绑定"],
"支付": ["支付", "付款", "充值", "余额", "优惠券"]
}
message_lower = message.lower()
topic_scores = {}
for topic, keywords in topic_keywords.items():
matched = sum(
1 for keyword in keywords
if keyword in message_lower
)
if matched > 0:
topic_scores[topic] = matched
if topic_scores:
best_topic = max(topic_scores.items(), key=lambda x: x[1])[0]
return best_topic
return "其他"
def _get_default_agent_id(self) -> str:
"""获取默认 Agent ID"""
if self.routing_rules:
return self.routing_rules[0].get("target_agent_id")
if self.sub_agents:
return list(self.sub_agents.keys())[0]
return "default-agent"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,595 @@
from sqlalchemy.orm import Session
from typing import List
import uuid
from fastapi import HTTPException
from app.models.user_model import User
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.memory_increment_model import MemoryIncrement
from app.repositories import (
app_repository,
end_user_repository,
memory_increment_repository,
knowledge_repository
)
from app.schemas.end_user_schema import EndUser as EndUserSchema
from app.schemas.memory_increment_schema import MemoryIncrement as MemoryIncrementSchema
from app.schemas.app_schema import App as AppSchema
from app.core.logging_config import get_business_logger
# 获取业务逻辑专用日志器
business_logger = get_business_logger()
def get_workspace_end_users(
db: Session,
workspace_id: uuid.UUID,
current_user: User
) -> List[EndUser]:
"""获取工作空间的所有宿主"""
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
# 查询应用ORM并转换为 Pydantic 模型
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
apps = [AppSchema.model_validate(h) for h in apps_orm]
app_ids = [app.id for app in apps]
end_users = []
for app_id in app_ids:
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
end_users.extend([EndUserSchema.model_validate(h) for h in end_user_orm_list])
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return end_users
except HTTPException:
raise
except Exception as e:
business_logger.error(f"获取工作空间宿主列表失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspace_memory_increment(
db: Session,
workspace_id: uuid.UUID,
limit: int,
current_user: User
) -> List[MemoryIncrementSchema]:
"""获取工作空间的记忆增量"""
business_logger.info(f"获取工作空间记忆增量: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
# 查询记忆增量
memory_increment_orm_list = memory_increment_repository.get_memory_increments_by_workspace_id(db, workspace_id, limit)
memory_increment = [MemoryIncrementSchema.model_validate(m) for m in memory_increment_orm_list]
business_logger.info(f"成功获取 {len(memory_increment)} 条记忆增量记录")
return memory_increment
except HTTPException:
raise
except Exception as e:
business_logger.error(f"获取工作空间记忆增量失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspace_api_increment(
db: Session,
workspace_id: uuid.UUID,
current_user: User
) -> int:
"""获取工作空间的API调用增量"""
business_logger.info(f"获取工作空间API调用增量: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
# 查询API调用增量
api_increment = 856
business_logger.info(f"成功获取 {api_increment} API调用增量")
return api_increment
except HTTPException:
raise
except Exception as e:
business_logger.error(f"获取工作空间API调用增量失败: workspace_id={workspace_id} - {str(e)}")
raise
def write_workspace_total_memory(
db: Session,
workspace_id: uuid.UUID,
current_user: User
) -> int:
"""写入工作空间的记忆总量"""
business_logger.info(f"写入工作空间记忆总量: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
# 模拟记忆总量
total_num = 1024
# 写入记忆总量
memory_increment_repository.write_memory_increment(db, workspace_id, total_num)
business_logger.info(f"成功写入记忆总量 {total_num}")
return total_num
except HTTPException:
raise
except Exception as e:
business_logger.error(f"写入工作空间记忆总量失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspace_memory_list(
db: Session,
workspace_id: uuid.UUID,
current_user: User,
limit: int = 7
) -> dict:
"""
获取工作空间的记忆列表(整合接口)
整合以下三个接口的数据:
1. total_memory - 工作空间记忆总量
2. memory_increment - 工作空间记忆增量
3. hosts - 工作空间宿主列表
"""
business_logger.info(f"获取工作空间记忆列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
result = {}
try:
# 1. 获取记忆总量
try:
total_memory = write_workspace_total_memory(db, workspace_id, current_user)
result["total_memory"] = total_memory
business_logger.info(f"成功获取记忆总量: {total_memory}")
except Exception as e:
business_logger.warning(f"获取记忆总量失败: {str(e)}")
result["total_memory"] = 0.0
# 2. 获取记忆增量
try:
memory_increment = get_workspace_memory_increment(db, workspace_id, limit, current_user)
result["memory_increment"] = memory_increment
business_logger.info(f"成功获取 {len(memory_increment)} 条记忆增量记录")
except Exception as e:
business_logger.warning(f"获取记忆增量失败: {str(e)}")
result["memory_increment"] = []
# 3. 获取宿主列表
try:
hosts = get_workspace_end_users(db, workspace_id, current_user)
result["hosts"] = hosts
business_logger.info(f"成功获取 {len(hosts)} 个宿主记录")
except Exception as e:
business_logger.warning(f"获取宿主列表失败: {str(e)}")
result["hosts"] = []
business_logger.info(f"成功获取工作空间记忆列表")
return result
except HTTPException:
raise
except Exception as e:
business_logger.error(f"获取工作空间记忆列表失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspace_total_end_users(
db: Session,
workspace_id: uuid.UUID,
current_user: User
) -> dict:
"""
获取用户列表的总用户数
"""
business_logger.info(f"获取用户列表的总用户数: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
# 复用原有的 get_workspace_end_users 逻辑
end_users = get_workspace_end_users(db, workspace_id, current_user)
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return {
"total_num": len(end_users),
"online_num": len(end_users)
}
except HTTPException:
raise
except Exception as e:
business_logger.error(f"获取用户列表失败: workspace_id={workspace_id} - {str(e)}")
raise
async def get_workspace_total_memory_count(
db: Session,
workspace_id: uuid.UUID,
current_user: User,
end_user_id: str = None
) -> dict:
"""
获取工作空间的记忆总量通过聚合所有host的记忆数
逻辑:
1. 从 memory_list 获取所有 host_id
2. 对每个 host_id 调用 search_all 获取 total
3. 将所有 total 求和返回
"""
business_logger.info(f"获取工作空间记忆总量: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
# 1. 获取所有 hosts
hosts = get_workspace_end_users(db, workspace_id, current_user)
business_logger.info(f"获取到 {len(hosts)} 个宿主")
if not hosts:
business_logger.warning("未找到任何宿主返回0")
return {
"total_memory_count": 0,
"host_count": 0,
"details": []
}
# 2. 对每个 host_id 调用 search_all 获取 total
from app.services import memory_storage_service
total_count = 0
details = []
# 如果提供了 end_user_id只查询该用户
if end_user_id:
search_result = await memory_storage_service.search_all(end_user_id=end_user_id)
return {
"total_memory_count": search_result.get("total", 0),
"host_count": 1,
"details": [{"end_user_id": end_user_id, "count": search_result.get("total", 0)}]
}
for host in hosts:
try:
end_user_id_str = str(host.id)
search_result = await memory_storage_service.search_all(
end_user_id=end_user_id_str
)
host_total = search_result.get("total", 0)
total_count += host_total
details.append({
"end_user_id": end_user_id_str,
"count": host_total
})
business_logger.debug(f"EndUser {end_user_id_str} 记忆数: {host_total}")
except Exception as e:
business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}")
# 失败的 host 记为 0
details.append({
"end_user_id": str(host.id),
"count": 0
})
result = {
"total_memory_count": total_count,
"host_count": len(hosts),
"details": details
}
business_logger.info(f"成功获取工作空间记忆总量: {total_count} (来自 {len(hosts)} 个宿主)")
return result
except HTTPException:
raise
except Exception as e:
business_logger.error(f"获取工作空间记忆总量失败: workspace_id={workspace_id} - {str(e)}")
raise
# ======== RAG 相关服务 ========
def get_rag_total_doc(
db: Session,
current_user: User
) -> int:
"""
根据当前用户所在的workspace_id查询konwledges表所有doc_num的总和
"""
workspace_id = current_user.current_workspace_id
business_logger.info(f"获取RAG总文档数: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
total_doc = knowledge_repository.get_total_doc_num_by_workspace(db, workspace_id)
business_logger.info(f"成功获取RAG总文档数: {total_doc}")
return total_doc
except Exception as e:
business_logger.error(f"获取RAG总文档数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_rag_total_chunk(
db: Session,
current_user: User
) -> int:
"""
根据当前用户所在的workspace_id查询konwledges表所有chunk_num的总和
"""
workspace_id = current_user.current_workspace_id
business_logger.info(f"获取RAG总chunk数: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
total_chunk = knowledge_repository.get_total_chunk_num_by_workspace(db, workspace_id)
business_logger.info(f"成功获取RAG总chunk数: {total_chunk}")
return total_chunk
except Exception as e:
business_logger.error(f"获取RAG总chunk数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_rag_total_kb(
db: Session,
current_user: User
) -> int:
"""
根据当前用户所在的workspace_id查询konwledges表所有不同id的数量
"""
workspace_id = current_user.current_workspace_id
business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id)
business_logger.info(f"成功获取RAG总知识库数: {total_kb}")
return total_kb
except Exception as e:
business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_current_user_total_chunk(
end_user_id: str,
db: Session,
current_user: User
) -> int:
"""
计算documents表中file_name=='end_user_id'+'.txt'的所有记录chunk_num的总和
"""
business_logger.info(f"获取用户总chunk数: end_user_id={end_user_id}, 操作者: {current_user.username}")
try:
from app.models.document_model import Document
from sqlalchemy import func
# 构造文件名
file_name = f"{end_user_id}.txt"
# 查询并求和
total_chunk = db.query(func.sum(Document.chunk_num)).filter(
Document.file_name == file_name
).scalar() or 0
business_logger.info(f"成功获取用户总chunk数: {total_chunk} (file_name={file_name})")
return int(total_chunk)
except Exception as e:
business_logger.error(f"获取用户总chunk数失败: end_user_id={end_user_id} - {str(e)}")
raise
def get_rag_content(
end_user_id: str,
limit: int,
db: Session,
current_user: User
) -> dict:
"""
先在documents表中查询file_name=='end_user_id'+'.txt'的id和kb_id,
然后调用/chunks/{kb_id}/{document_id}/chunks接口的相关代码获取所有内容
接着对获取的内容进行提取只要page_content的内容
最后返回数据
"""
business_logger.info(f"获取RAG内容: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
try:
from app.models.document_model import Document
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
# 1. 构造文件名
file_name = f"{end_user_id}.txt"
# 2. 查询documents表获取id和kb_id
documents = db.query(Document).filter(
Document.file_name == file_name
).all()
if not documents:
business_logger.warning(f"未找到文件: {file_name}")
return {
"total": 0,
"contents": []
}
business_logger.info(f"找到 {len(documents)} 个文档记录")
# 3. 获取所有chunks的page_content
all_contents = []
total_chunks = 0
for document in documents:
try:
# 获取知识库信息
kb = knowledge_repository.get_knowledge_by_id(db, document.kb_id)
if not kb:
business_logger.warning(f"知识库不存在: kb_id={document.kb_id}")
continue
# 初始化向量服务
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=kb)
# 获取该文档的所有chunks分页获取
page = 1
pagesize = 100 # 每页100条
while True:
total, items = vector_service.search_by_segment(
document_id=str(document.id),
query=None,
pagesize=pagesize,
page=page,
asc=True
)
if not items:
break
# 提取page_content
for item in items:
all_contents.append(item.page_content)
total_chunks += 1
# # 如果达到limit限制直接返回
# if limit > 0 and total_chunks >= limit:
# business_logger.info(f"已达到limit限制: {limit}")
# return {
# "total": total_chunks,
# "contents": all_contents[:limit]
# }
# 检查是否还有下一页
if page * pagesize >= total:
break
page += 1
business_logger.info(f"文档 {document.id} 获取了 {len(items)} 个chunks")
except Exception as e:
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
continue
# 4. 返回结果
result = {
"total": total_chunks,
"contents": all_contents[:limit] if limit > 0 else all_contents
}
business_logger.info(f"成功获取RAG内容: total={total_chunks}, 返回={len(result['contents'])}")
return result
except Exception as e:
business_logger.error(f"获取RAG内容失败: end_user_id={end_user_id} - {str(e)}")
raise
async def get_chunk_summary_and_tags(
end_user_id: str,
limit: int,
max_tags: int,
db: Session,
current_user: User
) -> dict:
"""
获取chunk的总结、标签和人物形象
Args:
end_user_id: 宿主ID
limit: 返回的chunk数量限制
max_tags: 最大标签数量
db: 数据库会话
current_user: 当前用户
Returns:
包含summary、tags和personas的字典
"""
business_logger.info(f"获取chunk摘要、标签和人物形象: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
try:
# 1. 获取chunk内容
rag_content = get_rag_content(end_user_id, limit, db, current_user)
chunks = rag_content.get("contents", [])
if not chunks:
business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}")
return {
"summary": "暂无内容",
"tags": [],
"personas": []
}
# 2. 导入RAG工具函数
from app.core.rag_utils import generate_chunk_summary, extract_chunk_tags, extract_chunk_persona
# 3. 并发生成摘要、提取标签和人物形象
import asyncio
summary_task = generate_chunk_summary(chunks, max_chunks=limit)
tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit)
personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit)
summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task)
# 4. 格式化标签数据
tags = [{"tag": tag, "frequency": freq} for tag, freq in tags_with_freq]
result = {
"summary": summary,
"tags": tags,
"personas": personas
}
business_logger.info(f"成功获取chunk摘要、{len(tags)} 个标签和 {len(personas)} 个人物形象")
return result
except Exception as e:
business_logger.error(f"获取chunk摘要、标签和人物形象失败: end_user_id={end_user_id} - {str(e)}")
raise
async def get_chunk_insight(
end_user_id: str,
limit: int,
db: Session,
current_user: User
) -> dict:
"""
获取chunk的洞察分析
Args:
end_user_id: 宿主ID
limit: 返回的chunk数量限制
db: 数据库会话
current_user: 当前用户
Returns:
包含insight的字典
"""
business_logger.info(f"获取chunk洞察: end_user_id={end_user_id}, limit={limit}, 操作者: {current_user.username}")
try:
# 1. 获取chunk内容
rag_content = get_rag_content(end_user_id, limit, db, current_user)
chunks = rag_content.get("contents", [])
if not chunks:
business_logger.warning(f"未找到chunk内容: end_user_id={end_user_id}")
return {
"insight": "暂无足够数据生成洞察报告"
}
# 2. 导入RAG工具函数
from app.core.rag_utils import generate_chunk_insight
# 3. 生成洞察
insight = await generate_chunk_insight(chunks, max_chunks=limit)
result = {
"insight": insight
}
business_logger.info(f"成功获取chunk洞察")
return result
except Exception as e:
business_logger.error(f"获取chunk洞察失败: end_user_id={end_user_id} - {str(e)}")
raise

View File

@@ -0,0 +1,582 @@
# 修改 memory_konwledges_server.py 文件
import asyncio
import os
import re
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success
from app.db import get_db
from app.schemas import file_schema, document_schema
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from app.models.document_model import Document
import uuid
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.core.config import settings
from app.models.user_model import User
from app.schemas.file_schema import CustomTextFileCreate
from app.services import document_service, file_service, knowledge_service
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.schemas.file_schema import CustomTextFileCreate
from app.db import get_db
# 创建一个简单的用户类用于测试
api_logger = get_api_logger()
class ChunkCreate(BaseModel):
content: str
class SimpleUser:
def __init__(self, user_id: str):
# 确保ID是UUID类型
self.id = user_id
self.username = user_id
'''解析'''
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
"""
解析指定文档
Args:
document_id: 文档ID
db: 数据库会话
current_user: 当前用户
Returns:
dict: 包含任务ID的结果字典
Raises:
HTTPException: 当文档、文件或知识库不存在时抛出异常
"""
try:
# 1. 检查文档是否存在
api_logger.debug(f"检查文档是否存在: {document_id}")
db_document = document_service.get_document_by_id(db, document_id=document_id, current_user=current_user)
if not db_document:
api_logger.warning(f"文档不存在或无访问权限: document_id={document_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="文档不存在或无访问权限"
)
# 2. 检查文件是否存在
api_logger.debug(f"检查文件是否存在: {db_document.file_id}")
db_file = file_service.get_file_by_id(db, file_id=db_document.file_id)
if not db_file:
api_logger.warning(f"文件不存在或无访问权限: file_id={db_document.file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="文件不存在或无访问权限"
)
# 3. 构建文件路径:/files/{kb_id}/{parent_id}/{file.id}{file.file_ext}
file_path = os.path.join(
settings.FILE_PATH,
str(db_file.kb_id),
str(db_file.parent_id),
f"{db_file.id}{db_file.file_ext}"
)
# 4. 检查文件是否存在于磁盘上
if not os.path.exists(file_path):
api_logger.warning(f"文件未找到(可能已被删除): file_path={file_path}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="文件未找到(可能已被删除)"
)
# 5. 获取知识库信息
api_logger.info(f"获取知识库详情: knowledge_id={db_document.kb_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=db_document.kb_id,
current_user=current_user)
if not db_knowledge:
api_logger.warning(f"知识库不存在或访问被拒绝: knowledge_id={db_document.kb_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="知识库不存在或访问被拒绝"
)
# 6. 发送解析任务到Celery后台队列
task = celery_app.send_task("app.core.rag.tasks.parse_document", args=[file_path, document_id])
result = {
"task_id": task.id
}
api_logger.info(f"文档解析任务已接受: document_id={document_id}, task_id={task.id}")
return result
except Exception as e:
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
raise
'''获取块ID'''
async def get_document_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
page: int = 1,
pagesize: int = 20,
keywords: Optional[str] = None,
db: Session = None,
current_user: User = None
):
"""
分页查询文档块列表
Args:
kb_id: 知识库ID
document_id: 文档ID
page: 页码默认为1
pagesize: 每页大小默认为20
keywords: 用于匹配块内容的关键字
db: 数据库会话
current_user: 当前用户
Returns:
dict: 包含分页数据的响应结果
Raises:
HTTPException: 当知识库不存在或查询失败时抛出异常
"""
api_logger.info(
f"分页查询文档块列表: kb_id={kb_id}, document_id={document_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 参数验证
if page < 1 or pagesize < 1:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="分页参数必须大于0"
)
# 获取知识库信息
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="知识库不存在或访问被拒绝"
)
# 执行分页查询
try:
api_logger.debug(f"开始执行文档块查询")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
total, items = vector_service.search_by_segment(
document_id=str(document_id),
query=keywords,
pagesize=pagesize,
page=page,
asc=True
)
api_logger.info(f"文档块查询成功: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"文档块查询失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"查询失败: {str(e)}"
)
# 构造响应结果
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=result, msg="文档块列表查询成功")
'''查找文档ID'''
def find_document_id_by_kb_and_filename(
db: Session,
kb_id: str,
file_name: str
) -> str | None:
"""
通过 kb_id 和 file_name 在 documents 表中查找对应的 ID
Args:
db: 数据库会话
kb_id: 知识库ID
file_name: 文件名
Returns:
str | None: 找到的 document ID未找到返回 None
"""
try:
# 查询 documents 表
document = db.query(Document).filter(
Document.kb_id == kb_id,
Document.file_name == file_name
).first()
if document:
print(f"找到文档: ID={document.id}, kb_id={kb_id}, file_name={file_name}")
return str(document.id)
else:
return None
except Exception as e:
return None
'''获取知识库ID'''
def find_documents_by_kb_id(
db: Session,
kb_id: str,
limit: int = 10
) -> list[dict]:
"""
通过 kb_id 查找所有相关文档
Args:
db: 数据库会话
kb_id: 知识库ID
limit: 返回结果数量限制
Returns:
list[dict]: 文档列表,包含 id, file_name, created_at 等信息
"""
try:
documents = db.query(Document).filter(
Document.kb_id == kb_id
).limit(limit).all()
result = []
for doc in documents:
result.append({
"id": str(doc.id),
"file_name": doc.file_name,
"file_ext": doc.file_ext,
"file_size": doc.file_size,
"created_at": doc.created_at.isoformat() if doc.created_at else None,
"status": getattr(doc, 'status', None)
})
return result
except Exception as e:
return []
''''上传文件'''
async def memory_konwledges_up(
kb_id: str,
parent_id: str,
create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db),
current_user: SimpleUser = None, # 修改为SimpleUser
):
# 如果没有提供current_user则创建一个默认的
if current_user is None:
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes)
print(f"file size: {file_size} byte")
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The content is empty."
)
# If the file size exceeds 50MB (50 * 1024 * 1024 bytes)
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The content size exceeds the {settings.MAX_FILE_SIZE}byte limit"
)
upload_file = file_schema.FileCreate(
kb_id=kb_id,
created_by=current_user.id, # 现在是UUID类型
parent_id=parent_id,
file_name=f"{create_data.title}.txt",
file_ext=".txt",
file_size=file_size,
)
db_file = file_service.create_file(db=db, file=upload_file, current_user=current_user)
# Construct a save path/files/{kb_id}/{parent_id}/{file.id}{file_extension}
# 使用 settings.FILE_PATH 确保与 parse_document_by_id 一致
save_dir = os.path.join(settings.FILE_PATH, str(kb_id), str(parent_id))
# 确保目录存在
Path(save_dir).mkdir(parents=True, exist_ok=True)
save_path = os.path.join(save_dir, f"{db_file.id}.txt")
# Save file
with open(save_path, "wb") as f:
f.write(content_bytes)
# Verify whether the file has been saved successfully
if not os.path.exists(save_path):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="File save failed"
)
# Create a document
create_document_data = document_schema.DocumentCreate(
kb_id=kb_id,
created_by=current_user.id,
file_id=db_file.id,
file_name=db_file.file_name,
file_ext=db_file.file_ext,
file_size=db_file.file_size,
file_meta={},
parser_id="naive",
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 128,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": "false"
}
)
db_document = document_service.create_document(db=db, document=create_document_data, current_user=current_user)
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
'''添加新块'''
async def create_document_chunk(
kb_id: uuid.UUID,
document_id: uuid.UUID,
create_data: ChunkCreate,
db: Session,
current_user: User
):
"""
创建文档块
Args:
kb_id: 知识库ID
document_id: 文档ID
create_data: 创建数据
db: 数据库会话
current_user: 当前用户
Returns:
dict: 包含创建的文档块信息的成功响应
Raises:
HTTPException: 当知识库或文档不存在时抛出相应异常
"""
api_logger.info(
f"创建文档块请求: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}")
# 1. 获取知识库信息
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
if not db_knowledge:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="知识库不存在或访问被拒绝"
)
# 2. 获取文档信息
db_document = db.query(Document).filter(Document.id == document_id).first()
if not db_document:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="文档不存在或您无访问权限"
)
# 3. 初始化向量服务
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
# 4. 获取排序ID处理索引不存在的情况
sort_id = 0
try:
total, items = vector_service.search_by_segment(document_id=str(document_id), pagesize=1, page=1, asc=False)
if items:
sort_id = items[0].metadata["sort_id"]
except Exception as e:
# 如果索引不存在,从 0 开始
error_msg = str(e)
if "index_not_found_exception" in error_msg or "no such index" in error_msg:
api_logger.warning(f"索引不存在,将从 sort_id=0 开始: {error_msg}")
sort_id = 0
else:
# 其他错误则抛出
api_logger.error(f"查询文档块失败: {error_msg}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"查询文档块失败: {error_msg}"
)
sort_id = sort_id + 1
# 5. 创建文档块
doc_id = uuid.uuid4().hex
metadata = {
"doc_id": doc_id,
"file_id": str(db_document.file_id),
"file_name": db_document.file_name,
"file_created_at": int(db_document.created_at.timestamp() * 1000),
"document_id": str(document_id),
"knowledge_id": str(kb_id),
"sort_id": sort_id,
"status": 1,
}
chunk = DocumentChunk(page_content=create_data.content, metadata=metadata)
# 6. 存储向量化的文档块(这会自动创建索引如果不存在)
try:
vector_service.add_chunks([chunk])
except Exception as e:
api_logger.error(f"添加文档块到向量库失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"添加文档块到向量库失败: {str(e)}"
)
# 7. 更新 chunk_num
db_document.chunk_num += 1
db.commit()
return success(data=chunk, msg="文档块创建成功")
async def write_rag(group_id, message, user_rag_memory_id):
"""
将消息写入 RAG 知识库
Args:
group_id: 组ID用作文件标题
message: 消息内容
user_rag_memory_id: 知识库ID必须是有效的UUID
Returns:
写入结果
Raises:
HTTPException: 当参数无效或操作失败时
"""
# 验证 user_rag_memory_id 是否为有效的 UUID
if not user_rag_memory_id:
api_logger.error("user_rag_memory_id 为空,无法执行 RAG 写入操作")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="知识库ID不能为空"
)
try:
# 尝试将字符串转换为 UUID 以验证格式
kb_uuid = uuid.UUID(user_rag_memory_id)
except (ValueError, AttributeError) as e:
api_logger.error(f"user_rag_memory_id 不是有效的UUID: {user_rag_memory_id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"知识库ID格式无效: {user_rag_memory_id}"
)
db_gen = get_db()
db = next(db_gen)
try:
create_data = CustomTextFileCreate(title=group_id, content=message)
current_user = SimpleUser(user_rag_memory_id)
# 检查文档是否已存在
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{group_id}.txt")
print('======',document)
api_logger.info(f"查找文档结果: document_id={document}")
if document is not None:
# 文档已存在,直接添加新块
api_logger.info(f"文档已存在,添加新块: document_id={document}")
create_chunks = ChunkCreate(content=message)
result = await create_document_chunk(
kb_id=kb_uuid,
document_id=uuid.UUID(document),
create_data=create_chunks,
db=db,
current_user=current_user
)
return result
else:
# 文档不存在,创建新文档
api_logger.info(f"文档不存在,创建新文档: group_id={group_id}")
result = await memory_konwledges_up(
kb_id=user_rag_memory_id,
parent_id=user_rag_memory_id,
create_data=create_data,
db=db,
current_user=current_user
)
await parse_document_by_id(document, db=db, current_user=current_user)
return result
finally:
# 确保数据库会话被关闭
db.close()
# 在异步环境中调用示例
async def example_usage():
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 创建 CustomTextFileCreate 对象
title = '2f6ff1eb-50c7-4765-8e89-e4566be19122'
create_data = CustomTextFileCreate(
title=title,
content="莫扎特在巴黎经历母亲去世后返回萨尔茨堡,他随后创作的交响曲主题是否与格鲁克在维也纳推动的“改革歌剧”理念存在共通之处?贝多芬早年曾师从海顿,而海顿又受雇于埃斯特哈齐家族——这种师承体系如何影响了当时欧洲宫廷音乐的传承结构?斯卡拉歌剧院选择萨列里的歌剧作为开幕演出,是否与当时米兰政治环境和奥地利宫廷影响有关?"
)
# 创建用户对象
current_user = SimpleUser("6243c125-9420-402c-bbb5-d1977811ac96")
# 上传文件
result = await memory_konwledges_up(
kb_id="c71df60a-36a6-4759-a2ce-101e3087b401",
parent_id="c71df60a-36a6-4759-a2ce-101e3087b401",
create_data=create_data,
db=db,
current_user=current_user
)
print(result)
#找到document_id
# 使用刚创建的文档ID进行解析
document = find_document_id_by_kb_and_filename(db=db, kb_id="c71df60a-36a6-4759-a2ce-101e3087b401", file_name=f"{title}.txt")
print('====',document)
res___=await parse_document_by_id(document, db=db, current_user=current_user)
print(res___)
# result='e8cf9ace-d1a9-4af2-b0c4-3fc94f4f8042'
# document_id='d22e8173-50d0-4e10-a7de-aa638ef893bc'
#
# '''更新块'''
#
# new_content = "这是新的 chunk 内容,用来覆盖原来的内容"
# # 构造 ChunkUpdate 对象
# update_data = ChunkCreate(content=new_content)
# updated_chunk = await create_document_chunk(
# kb_id= result,
# document_id=document_id,
# create_data= update_data,
# db=db,
# current_user=current_user
# )
# print(updated_chunk)
return '','',''
if __name__ == "__main__":
# asyncio.run(example_usage())
asyncio.run(write_rag('1111','22222',"c71df60a-36a6-4759-a2ce-101e3087b401"))

View File

@@ -0,0 +1,568 @@
"""
Memory Storage Service
Handles business logic for memory storage operations.
"""
from typing import Dict, List, Optional, Any
import os
import json
from dotenv import load_dotenv
from app.core.logging_config import get_logger
from app.schemas.memory_storage_schema import (
ConfigFilter,
ConfigPilotRun,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
)
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.core.memory.analytics.user_summary import generate_user_summary
from app.repositories.data_config_repository import DataConfigRepository
logger = get_logger(__name__)
# Load environment variables for Neo4j connector
load_dotenv()
_neo4j_connector = Neo4jConnector()
class MemoryStorageService:
"""Service for memory storage operations"""
def __init__(self):
logger.info("MemoryStorageService initialized")
async def get_storage_info(self) -> dict:
"""
Example wrapper method - retrieves storage information
Args:
Returns:
Storage information dictionary
"""
logger.info(f"Getting storage info ")
# Empty wrapper - implement your logic here
result = {
"status": "active",
"message": "This is an example wrapper"
}
return result
class DataConfigService: # 数据配置服务类PostgreSQL
"""Service layer for config params CRUD.
The DB connection is optional; when absent, methods return a failure
response containing an SQL preview to aid integration.
"""
def __init__(self, db_conn: Optional[Any] = None) -> None:
self.db_conn = db_conn
# --- Driver compatibility helpers ---
@staticmethod
def _is_pgsql_conn(conn: Any) -> bool: # 判断是否为 PostgreSQL 连接
mod = type(conn).__module__
return ("psycopg2" in mod) or ("psycopg" in mod)
@staticmethod
def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式"""
from datetime import datetime
for item in data_list:
for field in ['created_at', 'updated_at']:
if field in item and item[field] is not None:
value = item[field]
dt = None
# 如果是 datetime 对象,直接使用
if isinstance(value, datetime):
dt = value
# 如果是字符串,先解析
elif isinstance(value, str):
try:
dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
except Exception:
pass # 保持原值
# 转换为 YYYYMMDDHHmmss 格式
if dt:
item[field] = dt.strftime('%Y%m%d%H%M%S')
return data_list
# --- Create ---
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
# 如果workspace_id存在且模型字段未全部指定则自动获取
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
configs = self._get_workspace_configs(params.workspace_id)
if configs is None:
raise ValueError(f"工作空间不存在: workspace_id={params.workspace_id}")
# 只在未指定时填充(允许手动覆盖)
if not params.llm_id:
params.llm_id = configs.get('llm')
if not params.embedding_id:
params.embedding_id = configs.get('embedding')
if not params.rerank_id:
params.rerank_id = configs.get('rerank')
query, qparams = DataConfigRepository.build_insert(params)
cur = self.db_conn.cursor()
# PostgreSQL 使用 psycopg2 的命名参数格式
cur.execute(query, qparams)
self.db_conn.commit()
return {"affected": getattr(cur, "rowcount", None)}
def _get_workspace_configs(self, workspace_id) -> Optional[Dict[str, Any]]:
"""获取工作空间模型配置(内部方法,便于测试)"""
from app.db import SessionLocal
from app.repositories.workspace_repository import get_workspace_models_configs
db_session = SessionLocal()
try:
return get_workspace_models_configs(db_session, workspace_id)
finally:
db_session.close()
# --- Delete ---
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置名称)
query, qparams = DataConfigRepository.build_delete(key)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
affected = getattr(cur, "rowcount", None)
self.db_conn.commit()
# 如果没有任何行被删除,抛出异常
if not affected:
raise ValueError("未找到配置")
return {"affected": affected}
# --- Update ---
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
query, qparams = DataConfigRepository.build_update(update)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
affected = getattr(cur, "rowcount", None)
self.db_conn.commit()
if not affected:
raise ValueError("未找到配置")
return {"affected": affected}
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
query, qparams = DataConfigRepository.build_update_extracted(update)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
affected = getattr(cur, "rowcount", None)
self.db_conn.commit()
if not affected:
raise ValueError("未找到配置")
return {"affected": affected}
# --- Forget config params ---
def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]: # 保存遗忘引擎的配置
query, qparams = DataConfigRepository.build_update_forget(update)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
affected = getattr(cur, "rowcount", None)
self.db_conn.commit()
if not affected:
raise ValueError("未找到配置")
return {"affected": affected}
# --- Read ---
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取配置参数
query, qparams = DataConfigRepository.build_select_extracted(key)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
row = cur.fetchone()
if not row:
raise ValueError("未找到配置")
# Map row to dict (DB-API cursor description available for many drivers)
columns = [desc[0] for desc in cur.description]
raw = {columns[i]: row[i] for i in range(len(columns))}
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
data_list = self._convert_timestamps_to_format([raw])
return data_list[0] if data_list else raw
def get_forget(self, key: ConfigKey) -> Dict[str, Any]: # 获取配置参数
query, qparams = DataConfigRepository.build_select_forget(key)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
row = cur.fetchone()
if not row:
raise ValueError("未找到配置")
# Map row to dict (DB-API cursor description available for many drivers)
columns = [desc[0] for desc in cur.description]
raw = {columns[i]: row[i] for i in range(len(columns))}
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
data_list = self._convert_timestamps_to_format([raw])
return data_list[0] if data_list else raw
# --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
query, qparams = DataConfigRepository.build_select_all(workspace_id)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
rows = cur.fetchall()
# 如果没有查询到任何配置,返回空列表(这是正常情况,不应抛出异常)
if not rows:
return []
# Map rows to list of dicts
columns = [desc[0] for desc in cur.description]
data_list = [dict(zip(columns, row)) for row in rows]
# 将 UUID 转换为字符串,将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
for item in data_list:
if 'workspace_id' in item and item['workspace_id'] is not None:
item['workspace_id'] = str(item['workspace_id'])
return self._convert_timestamps_to_format(data_list)
async def pilot_run(self, payload: ConfigPilotRun) -> Dict[str, Any]:
"""
选择策略与内存覆写与同步版保持一致:优先 payload.config_id其次 dbrun.json两者皆无时报错。
支持 dialogue_text 参数用于试运行模式。
"""
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
cid: Optional[str] = payload_cid if payload_cid else None
if not cid and os.path.isfile(dbrun_path):
try:
with open(dbrun_path, "r", encoding="utf-8") as f:
dbrun = json.load(f)
if isinstance(dbrun, dict):
sel = dbrun.get("selections", {})
if isinstance(sel, dict):
fallback_cid = str(sel.get("config_id") or "").strip()
cid = fallback_cid or None
except Exception:
cid = None
if not cid:
raise ValueError("未提供 payload.config_id且 dbrun.json 未设置 selections.config_id禁止启动试运行")
# 验证 dialogue_text 必须提供
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
logger.info(f"[PILOT_RUN] Received dialogue_text length: {len(dialogue_text)}, preview: {dialogue_text[:100]}")
if not dialogue_text:
raise ValueError("试运行模式必须提供 dialogue_text 参数")
# 应用内存覆写并刷新常量(在导入主管线前)
# 注意:仅在内存中覆写配置,不修改 runtime.json 文件
from app.core.memory.utils.config.definitions import reload_configuration_from_database
ok_override = reload_configuration_from_database(cid)
if not ok_override:
raise RuntimeError("运行时覆写失败config_id 无效或刷新常量失败")
# 导入并 await 主管线(使用当前 ASGI 事件循环)
from app.core.memory.main import main as pipeline_main
from app.core.memory.utils.self_reflexion_utils import reflexion
logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
logger.info("[PILOT_RUN] pipeline_main completed")
# 调用自我反思
# data = [
# {
# "data": {
# "id": "1",
# "statement": "张明现在在谷歌工作。",
# "group_id": "1",
# "chunk_id": "10",
# "created_at": "2023-01-01",
# "expired_at": "2023-01-02",
# "valid_at": "2023-01-01",
# "invalid_at": "2023-01-02",
# "entity_ids": []
# },
# "conflict": True,
# "conflict_memory": {
# "id": "1",
# "statement": "张明现在在清华大学当讲师。",
# "group_id": "1",
# "chunk_id": "1",
# "created_at": "2019-12-01T19:15:05.213210",
# "expired_at": None,
# "valid_at": None,
# "invalid_at": None,
# "entity_ids": []
# }
# }
# ]
from app.core.memory.utils.config.get_example_data import get_example_data
data = get_example_data()
reflexion_result = await reflexion(data)
# 读取输出,使用全局配置路径
from app.core.config import settings
result_path = settings.get_memory_output_path("extracted_result.json")
if not os.path.isfile(result_path):
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
with open(result_path, "r", encoding="utf-8") as rf:
extracted_result = json.load(rf)
extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None
return {
"config_id": cid,
"time_log": os.path.join(project_root, "time.log"),
"extracted_result": extracted_result,
}
# -------------------- Neo4j Search & Analytics (fused from data_search_service.py) --------------------
# Ensure env for connector (e.g., NEO4J_PASSWORD)
load_dotenv()
_neo4j_connector = Neo4jConnector()
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_DIALOGUE,
group_id=end_user_id,
)
data = {"search_for": "dialogue", "num": result[0]["num"]}
return data
async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_CHUNK,
group_id=end_user_id,
)
data = {"search_for": "chunk", "num": result[0]["num"]}
return data
async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_STATEMENT,
group_id=end_user_id,
)
data = {"search_for": "statement", "num": result[0]["num"]}
return data
async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ENTITY,
group_id=end_user_id,
)
data = {"search_for": "entity", "num": result[0]["num"]}
return data
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ALL,
group_id=end_user_id,
)
# 检查结果是否为空或长度不足
if not result or len(result) < 4:
data = {
"total": 0,
"counts": {
"dialogue": 0,
"chunk": 0,
"statement": 0,
"entity": 0,
},
}
return data
data = {
"total": result[-1]["Count"],
"counts": {
"dialogue": result[0]["Count"],
"chunk": result[1]["Count"],
"statement": result[2]["Count"],
"entity": result[3]["Count"],
},
}
return data
async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, Any]:
"""统一知识库类型分布接口。
聚合 dialogue/chunk/statement/entity 四类计数,返回统一的分布结构,便于前端一次性消费。
"""
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ALL,
group_id=end_user_id,
)
# 检查结果是否为空或长度不足
if not result or len(result) < 4:
data = {
"total": 0,
"distribution": [
{"type": "dialogue", "count": 0},
{"type": "chunk", "count": 0},
{"type": "statement", "count": 0},
{"type": "entity", "count": 0},
]
}
return data
total = result[-1]["Count"]
distribution = [
{"type": "dialogue", "count": result[0]["Count"]},
{"type": "chunk", "count": result[1]["Count"]},
{"type": "statement", "count": result[2]["Count"]},
{"type": "entity", "count": result[3]["Count"]},
]
data = {"total": total, "distribution": distribution}
return data
async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_DETIALS,
group_id=end_user_id,
)
return result
async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_EDGES,
group_id=end_user_id,
)
return result
async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]:
"""搜索所有实体之间的关系网络group 维度)。"""
result = await _neo4j_connector.execute_query(
DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
group_id=end_user_id,
)
# 对source_node 和 target_node 的 fact_summary进行截取只截取前三条的内容需要提取前三条“来源”
for item in result:
source_fact = item["sourceNode"]["fact_summary"]
target_fact = item["targetNode"]["fact_summary"]
# 截取前三条“来源”
item["sourceNode"]["fact_summary"] = source_fact.split("\n")[:4] if source_fact else []
item["targetNode"]["fact_summary"] = target_fact.split("\n")[:4] if target_fact else []
# 与现有返回风格保持一致,携带搜索类型、数量与详情
data = {
"search_for": "entity_graph",
"num": len(result),
"detials": result,
}
return data
async def analytics_hot_memory_tags(end_user_id: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]:
"""
获取热门记忆标签按数量排序并返回前N个
"""
# 获取更多标签供LLM筛选获取limit*4个标签
raw_limit = limit * 4
tags = await get_hot_memory_tags(end_user_id, limit=raw_limit)
# 按频率降序排序(虽然数据库已经排序,但为了确保正确性再次排序)
sorted_tags = sorted(tags, key=lambda x: x[1], reverse=True)
# 只返回前limit个
top_tags = sorted_tags[:limit]
return [{"name": t, "frequency": f} for t, f in top_tags]
async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> Dict[str, Any]:
insight = MemoryInsight(end_user_id)
report = await insight.generate_insight_report()
await insight.close()
data = {"report": report}
return data
async def analytics_recent_activity_stats() -> Dict[str, Any]:
stats, _msg = get_recent_activity_stats()
total = (
stats.get("chunk_count", 0)
+ stats.get("statements_count", 0)
+ stats.get("triplet_entities_count", 0)
+ stats.get("triplet_relations_count", 0)
+ stats.get("temporal_count", 0)
)
# 精简:仅提供“最新一次活动多久前”
latest_relative = None
try:
info = stats.get("log_path", "")
idx = info.rfind("最新:")
if idx != -1:
latest_path = info[idx + 3 :].strip()
if latest_path and os.path.exists(latest_path):
import time
diff = max(0.0, time.time() - os.path.getmtime(latest_path))
m = int(diff // 60)
if m < 1:
latest_relative = "刚刚"
elif m < 60:
latest_relative = f"{m}分钟前"
else:
h = int(m // 60)
latest_relative = f"{h}小时前" if h < 24 else f"{int(h // 24)}天前"
except Exception:
pass
data = {"total": total, "stats": stats, "latest_relative": latest_relative}
return data
async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, Any]:
summary = await generate_user_summary(end_user_id)
data = {"summary": summary}
return data

View File

@@ -0,0 +1,160 @@
"""
模型参数合并器
用于合并 ModelConfig 和 AgentConfig 中的模型参数,
AgentConfig 中的参数优先级更高,可以覆盖 ModelConfig 的默认参数。
"""
from typing import Dict, Any, Optional
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ModelParameterMerger:
"""模型参数合并器"""
@staticmethod
def merge_parameters(
model_config_params: Optional[Dict[str, Any]],
agent_config_params: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
"""
合并模型配置参数和 Agent 配置参数
优先级agent_config_params > model_config_params > 默认值
Args:
model_config_params: ModelConfig.config 中的参数
agent_config_params: AgentConfig.model_parameters 中的参数
Returns:
合并后的参数字典
Example:
>>> model_params = {"temperature": 0.5, "max_tokens": 1000}
>>> agent_params = {"temperature": 0.8}
>>> merged = ModelParameterMerger.merge_parameters(model_params, agent_params)
>>> merged
{"temperature": 0.8, "max_tokens": 1000}
"""
# 默认参数
default_params = {
"temperature": 0.7,
"max_tokens": 2000,
"top_p": 1.0,
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"n": 1,
"stop": None
}
# 合并参数:默认值 -> 模型配置 -> Agent 配置
merged = default_params.copy()
# 应用模型配置参数
if model_config_params:
for key in default_params.keys():
if key in model_config_params:
merged[key] = model_config_params[key]
# 应用 Agent 配置参数(优先级最高)
if agent_config_params:
for key in default_params.keys():
if key in agent_config_params and agent_config_params[key] is not None:
merged[key] = agent_config_params[key]
# 移除 None 值
merged = {k: v for k, v in merged.items() if v is not None}
logger.debug(
f"参数合并完成",
extra={
"model_params": model_config_params,
"agent_params": agent_config_params,
"merged": merged
}
)
return merged
@staticmethod
def get_effective_parameters(
model_config: Optional[Any],
agent_config: Optional[Any]
) -> Dict[str, Any]:
"""
获取有效的模型参数(从 ORM 对象中提取并合并)
Args:
model_config: ModelConfig ORM 对象
agent_config: AgentConfig ORM 对象
Returns:
合并后的参数字典
"""
# 提取模型配置参数
model_params = None
if model_config and hasattr(model_config, 'config'):
model_params = model_config.config
# 提取 Agent 配置参数
agent_params = None
if agent_config and hasattr(agent_config, 'model_parameters'):
agent_params = agent_config.model_parameters
return ModelParameterMerger.merge_parameters(model_params, agent_params)
@staticmethod
def format_for_llm_call(parameters: Dict[str, Any]) -> Dict[str, Any]:
"""
格式化参数用于 LLM API 调用
不同的 LLM 提供商可能需要不同的参数格式,
这个方法可以根据需要进行转换。
Args:
parameters: 合并后的参数字典
Returns:
格式化后的参数字典
"""
# 基本格式化(可以根据不同提供商扩展)
formatted = parameters.copy()
# 确保参数在有效范围内
if "temperature" in formatted:
formatted["temperature"] = max(0.0, min(2.0, formatted["temperature"]))
if "max_tokens" in formatted:
formatted["max_tokens"] = max(1, min(32000, formatted["max_tokens"]))
if "top_p" in formatted:
formatted["top_p"] = max(0.0, min(1.0, formatted["top_p"]))
if "frequency_penalty" in formatted:
formatted["frequency_penalty"] = max(-2.0, min(2.0, formatted["frequency_penalty"]))
if "presence_penalty" in formatted:
formatted["presence_penalty"] = max(-2.0, min(2.0, formatted["presence_penalty"]))
if "n" in formatted:
formatted["n"] = max(1, min(10, formatted["n"]))
return formatted
def merge_model_parameters(
model_config_params: Optional[Dict[str, Any]],
agent_config_params: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
"""
合并模型参数的便捷函数
Args:
model_config_params: ModelConfig.config 中的参数
agent_config_params: AgentConfig.model_parameters 中的参数
Returns:
合并后的参数字典
"""
return ModelParameterMerger.merge_parameters(model_config_params, agent_config_params)

View File

@@ -0,0 +1,409 @@
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
import uuid
import math
import time
import asyncio
from app.models.models_model import ModelConfig, ModelApiKey, ModelType
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository
from app.schemas import model_schema
from app.schemas.model_schema import (
ModelConfigCreate, ModelConfigUpdate, ModelApiKeyCreate, ModelApiKeyUpdate,
ModelConfigQuery, ModelStats
)
from app.core.logging_config import get_business_logger
from app.schemas.response_schema import PageData, PageMeta
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = get_business_logger()
class ModelConfigService:
"""模型配置服务"""
@staticmethod
def get_model_by_id(db: Session, model_id: uuid.UUID) -> ModelConfig:
"""根据ID获取模型配置"""
model = ModelConfigRepository.get_by_id(db, model_id)
if not model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return model
@staticmethod
def get_model_list(db: Session, query: ModelConfigQuery) -> PageData:
"""获取模型配置列表"""
models, total = ModelConfigRepository.get_list(db, query)
pages = math.ceil(total / query.pagesize) if total > 0 else 0
return PageData(
page=PageMeta(
page=query.page,
pagesize=query.pagesize,
total=total,
hasnext=query.page < pages
),
items=[model_schema.ModelConfig.model_validate(model) for model in models]
)
@staticmethod
def get_model_by_name(db: Session, name: str) -> ModelConfig:
"""根据名称获取模型配置"""
model = ModelConfigRepository.get_by_name(db, name)
if not model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return model
@staticmethod
def search_models_by_name(db: Session, name: str, limit: int = 10) -> List[ModelConfig]:
"""按名称模糊匹配获取模型配置列表"""
return ModelConfigRepository.search_by_name(db, name, limit)
@staticmethod
async def validate_model_config(
db: Session,
*,
model_name: str,
provider: str,
api_key: str,
api_base: Optional[str] = None,
model_type: str = "llm",
test_message: str = "Hello"
) -> Dict[str, Any]:
"""验证模型配置是否有效
Args:
db: 数据库会话
model_name: 模型名称
provider: 提供商
api_key: API密钥
api_base: API基础URL
model_type: 模型类型 (llm/chat/embedding/rerank)
test_message: 测试消息
Returns:
Dict: 验证结果
"""
from app.core.models import RedBearLLM, RedBearRerank
from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings
import traceback
try:
start_time = time.time()
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
temperature=0.7,
max_tokens=100
)
# 根据模型类型选择不同的验证方式
model_type_lower = model_type.lower()
if model_type_lower in ["llm", "chat"]:
# LLM/Chat 模型验证 - 统一使用字符串输入
llm = RedBearLLM(model_config, type=ModelType.LLM if model_type_lower == "llm" else ModelType.CHAT)
response = await llm.ainvoke(test_message)
elapsed_time = time.time() - start_time
content = response.content if hasattr(response, 'content') else str(response)
usage = None
if hasattr(response, 'usage_metadata'):
usage = {
"input_tokens": getattr(response.usage_metadata, 'input_tokens', 0),
"output_tokens": getattr(response.usage_metadata, 'output_tokens', 0),
"total_tokens": getattr(response.usage_metadata, 'total_tokens', 0)
}
return {
"valid": True,
"message": f"{model_type.upper()} 模型配置验证成功",
"response": content,
"elapsed_time": elapsed_time,
"usage": usage,
"error": None
}
elif model_type_lower == "embedding":
# Embedding 模型验证(在线程中运行同步方法)
embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"]
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
elapsed_time = time.time() - start_time
return {
"valid": True,
"message": "Embedding 模型配置验证成功",
"response": f"成功生成 {len(vectors)} 个向量,维度: {len(vectors[0]) if vectors else 0}",
"elapsed_time": elapsed_time,
"usage": {
"input_tokens": len(test_message),
"vector_count": len(vectors),
"vector_dimension": len(vectors[0]) if vectors else 0
},
"error": None
}
elif model_type_lower == "rerank":
# Rerank 模型验证(在线程中运行同步方法)
rerank = RedBearRerank(model_config)
query = test_message
documents = ["这是第一个文档", "这是第二个文档", "这是第三个文档"]
results = await asyncio.to_thread(rerank.rerank, query=query, documents=documents, top_n=3)
elapsed_time = time.time() - start_time
return {
"valid": True,
"message": "Rerank 模型配置验证成功",
"response": f"成功对 {len(documents)} 个文档进行重排序,返回 top {len(results) if results else 0} 结果",
"elapsed_time": elapsed_time,
"usage": {
"query_length": len(query),
"document_count": len(documents),
"result_count": len(results) if results else 0
},
"error": None
}
else:
return {
"valid": False,
"message": "不支持的模型类型",
"response": None,
"elapsed_time": None,
"usage": None,
"error": f"不支持的模型类型: {model_type}"
}
except Exception as e:
# 提取详细的错误信息
error_message = str(e)
error_type = type(e).__name__
print("=========error_message:",error_message.lower())
# 特殊处理常见的错误类型
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
# 区域/国家限制(适用于所有提供商)
error_message = "区域限制: 该模型在当前区域或国家/地区不可用,请检查提供商的服务区域限制"
elif "ValidationException" in error_type or "ValidationException" in error_message:
# 其他验证错误
if "access denied" in error_message.lower():
error_message = "访问被拒绝: 请检查 API 凭证和权限配置"
else:
error_message = f"验证失败: {error_message}"
elif "AuthenticationError" in error_type or "authentication" in error_message.lower():
error_message = "认证失败: API Key 无效或已过期"
elif "RateLimitError" in error_type or "rate limit" in error_message.lower():
error_message = "请求频率限制: 已超过 API 调用限制"
elif "InvalidRequestError" in error_type or "invalid request" in error_message.lower():
error_message = f"无效请求: {error_message}"
elif "model_copy" in error_message:
error_message = "模型消息格式错误: 请确保使用正确的模型类型LLM/Chat"
# 记录详细错误日志
logger.error(f"模型验证失败 - 类型: {error_type}, 模型: {model_name}, 提供商: {provider}")
logger.error(f"错误详情: {error_message}")
logger.debug(f"完整堆栈: {traceback.format_exc()}")
return {
"valid": False,
"message": f"{model_type.upper()} 模型配置验证失败",
"response": None,
"elapsed_time": None,
"usage": None,
"error": error_message,
"error_type": error_type
}
@staticmethod
async def create_model(db: Session, model_data: ModelConfigCreate) -> ModelConfig:
"""创建模型配置"""
# 检查名称是否已存在
if ModelConfigRepository.get_by_name(db, model_data.name):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# 验证配置
if not model_data.skip_validation and model_data.api_keys:
api_key_data = model_data.api_keys
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=api_key_data.model_name,
provider=api_key_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_data.type, # 传递模型类型
test_message="Hello"
)
if not validation_result["valid"]:
raise BusinessException(
f"模型配置验证失败: {validation_result['error']}",
BizCode.INVALID_PARAMETER
)
# 事务处理
api_key_data = model_data.api_keys
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
model = ModelConfigRepository.create(db, model_config_data)
db.flush() # 获取生成的 ID
if api_key_data:
api_key_create_schema = ModelApiKeyCreate(
model_config_id=model.id,
**api_key_data.dict()
)
ModelApiKeyRepository.create(db, api_key_create_schema)
db.commit()
db.refresh(model)
return model
@staticmethod
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate) -> ModelConfig:
"""更新模型配置"""
existing_model = ModelConfigRepository.get_by_id(db, model_id)
if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data)
db.commit()
db.refresh(model)
return model
@staticmethod
def delete_model(db: Session, model_id: uuid.UUID) -> bool:
"""删除模型配置"""
if not ModelConfigRepository.get_by_id(db, model_id):
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
success = ModelConfigRepository.delete(db, model_id)
db.commit()
return success
@staticmethod
def get_model_stats(db: Session) -> ModelStats:
"""获取模型统计信息"""
stats_data = ModelConfigRepository.get_stats(db)
return ModelStats(
total_models=stats_data["total_models"],
active_models=stats_data["active_models"],
llm_count=stats_data["llm_count"],
embedding_count=stats_data["embedding_count"],
rerank_count=stats_data["rerank_count"],
provider_stats=stats_data["provider_stats"]
)
class ModelApiKeyService:
"""模型API Key服务"""
@staticmethod
def get_api_key_by_id(db: Session, api_key_id: uuid.UUID) -> ModelApiKey:
"""根据ID获取API Key"""
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key:
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
return api_key
@staticmethod
def get_api_keys_by_model(db: Session, model_config_id: uuid.UUID, is_active: bool = True) -> List[ModelApiKey]:
"""根据模型配置ID获取API Key列表"""
if not ModelConfigRepository.get_by_id(db, model_config_id):
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active)
@staticmethod
async def create_api_key(db: Session, api_key_data: ModelApiKeyCreate) -> ModelApiKey:
"""创建API Key"""
model_config = ModelConfigRepository.get_by_id(db, api_key_data.model_config_id)
if not model_config:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=api_key_data.model_name,
provider=api_key_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_config.type, # 传递模型类型
test_message="Hello"
)
print(validation_result)
if not validation_result["valid"]:
raise BusinessException(
f"模型配置验证失败: {validation_result['error']}",
BizCode.INVALID_PARAMETER
)
api_key = ModelApiKeyRepository.create(db, api_key_data)
db.commit()
db.refresh(api_key)
return api_key
@staticmethod
async def update_api_key(db: Session, api_key_id: uuid.UUID, api_key_data: ModelApiKeyUpdate) -> ModelApiKey:
"""更新API Key"""
existing_api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not existing_api_key:
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
# 获取关联的模型配置以获取模型类型
model_config = ModelConfigRepository.get_by_id(db, existing_api_key.model_config_id)
if not model_config:
raise BusinessException("关联的模型配置不存在", BizCode.MODEL_NOT_FOUND)
validation_result = await ModelConfigService.validate_model_config(
db=db,
model_name=api_key_data.model_name,
provider=api_key_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_config.type, # 传递模型类型
test_message="Hello"
)
print(validation_result)
if not validation_result["valid"]:
raise BusinessException(
f"模型配置验证失败: {validation_result['error']}",
BizCode.INVALID_PARAMETER
)
api_key = ModelApiKeyRepository.update(db, api_key_id, api_key_data)
db.commit()
db.refresh(api_key)
return api_key
@staticmethod
def delete_api_key(db: Session, api_key_id: uuid.UUID) -> bool:
"""删除API Key"""
if not ModelApiKeyRepository.get_by_id(db, api_key_id):
raise BusinessException("API Key不存在", BizCode.NOT_FOUND)
success = ModelApiKeyRepository.delete(db, api_key_id)
db.commit()
return success
@staticmethod
def get_available_api_key(db: Session, model_config_id: uuid.UUID) -> Optional[ModelApiKey]:
"""获取可用的API Key按优先级和负载均衡"""
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config_id, is_active=True)
if not api_keys:
return None
return min(api_keys, key=lambda x: int(x.usage_count or "0"))
@staticmethod
def record_api_key_usage(db: Session, api_key_id: uuid.UUID) -> bool:
"""记录API Key使用"""
success = ModelApiKeyRepository.update_usage(db, api_key_id)
if success:
db.commit()
return success

View File

@@ -0,0 +1,191 @@
"""
多智能体配置格式转换器
用于将 Pydantic 模型转换为数据库存储格式
"""
from typing import Dict, Any, Optional, List
import uuid
from app.schemas.multi_agent_schema import (
SubAgentConfig,
RoutingRule,
ExecutionConfig,
MultiAgentConfigCreate,
MultiAgentConfigUpdate,
)
class MultiAgentConfigConverter:
"""多智能体配置格式转换器"""
@staticmethod
def to_storage_format(config: MultiAgentConfigCreate | MultiAgentConfigUpdate) -> Dict[str, Any]:
"""
将配置对象转换为数据库存储格式
Args:
config: MultiAgentConfigCreate 或 MultiAgentConfigUpdate 对象
Returns:
包含数据库字段的字典
"""
result = {}
# 1. 子 Agent 配置
if hasattr(config, 'sub_agents') and config.sub_agents:
result["sub_agents"] = [
MultiAgentConfigConverter._convert_uuid_to_str(agent.model_dump())
for agent in config.sub_agents
]
# 2. 路由规则配置
if hasattr(config, 'routing_rules') and config.routing_rules:
result["routing_rules"] = [
MultiAgentConfigConverter._convert_uuid_to_str(rule.model_dump())
for rule in config.routing_rules
]
# 3. 执行配置
if hasattr(config, 'execution_config') and config.execution_config:
result["execution_config"] = MultiAgentConfigConverter._convert_uuid_to_str(
config.execution_config.model_dump()
)
return result
@staticmethod
def from_storage_format(
sub_agents: Optional[List[Dict[str, Any]]],
routing_rules: Optional[List[Dict[str, Any]]],
execution_config: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
"""
将数据库存储格式转换为 Pydantic 对象
Args:
sub_agents: 子 Agent 配置列表
routing_rules: 路由规则配置列表
execution_config: 执行配置
Returns:
包含 Pydantic 对象的字典
"""
result = {
"sub_agents": [],
"routing_rules": [],
"execution_config": None,
}
# 1. 解析子 Agent 配置
if sub_agents:
result["sub_agents"] = [
SubAgentConfig(**agent_data)
for agent_data in sub_agents
]
# 2. 解析路由规则配置
if routing_rules:
result["routing_rules"] = [
RoutingRule(**rule_data)
for rule_data in routing_rules
]
else:
# 提供默认的空路由规则
result["routing_rules"] = []
# 3. 解析执行配置
if execution_config:
result["execution_config"] = ExecutionConfig(**execution_config)
else:
# 提供默认的执行配置
result["execution_config"] = ExecutionConfig(
max_iterations=10,
timeout=300,
enable_parallel=False,
error_handling="stop"
)
return result
@staticmethod
def _convert_uuid_to_str(obj: Any) -> Any:
"""
递归转换对象中的所有 UUID 为字符串
Args:
obj: 要转换的对象dict, list, UUID 等)
Returns:
转换后的对象
"""
if isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, dict):
return {k: MultiAgentConfigConverter._convert_uuid_to_str(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [MultiAgentConfigConverter._convert_uuid_to_str(item) for item in obj]
else:
return obj
@staticmethod
def enrich_with_published_configs(
sub_agents: List[Dict[str, Any]],
get_published_config_func
) -> List[Dict[str, Any]]:
"""
为子 Agent 配置添加发布的 config_id
Args:
sub_agents: 子 Agent 配置列表
get_published_config_func: 获取发布配置的函数
Returns:
增强后的子 Agent 配置列表
"""
enriched_agents = []
for agent in sub_agents:
agent_copy = agent.copy()
# 获取该 Agent 当前发布的配置
if 'agent_id' in agent:
try:
agent_id = uuid.UUID(agent['agent_id']) if isinstance(agent['agent_id'], str) else agent['agent_id']
published_config = get_published_config_func(agent_id)
if published_config:
agent_copy['published_config_id'] = str(published_config.get('id')) if isinstance(published_config, dict) else None
except Exception as e:
# 如果获取失败,记录但不中断
from app.core.logging_config import get_business_logger
logger = get_business_logger()
logger.warning(f"获取 Agent {agent.get('agent_id')} 的发布配置失败: {e}")
enriched_agents.append(agent_copy)
return enriched_agents
@staticmethod
def create_default_template(app_id: uuid.UUID) -> Dict[str, Any]:
"""
创建默认的多智能体配置模板
Args:
app_id: 应用 ID
Returns:
默认配置模板
"""
return {
"app_id": str(app_id),
"master_agent_id": None,
"orchestration_mode": "sequential",
"sub_agents": [],
"routing_rules": [],
"execution_config": {
"max_iterations": 10,
"timeout": 300,
"enable_parallel": False,
"error_handling": "stop"
},
"aggregation_strategy": "last",
"is_active": False
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,630 @@
"""多 Agent 配置管理服务"""
import uuid
from typing import Optional, List, Tuple, Any
from sqlalchemy.orm import Session
from sqlalchemy import select, desc
from app.models import MultiAgentConfig, App, AgentConfig
from app.schemas.multi_agent_schema import (
MultiAgentConfigCreate,
MultiAgentConfigUpdate,
MultiAgentRunRequest
)
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.core.exceptions import ResourceNotFoundException, BusinessException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.models import AppRelease
logger = get_business_logger()
def convert_uuids_to_str(obj: Any) -> Any:
"""递归转换对象中的所有 UUID 为字符串
Args:
obj: 要转换的对象dict, list, UUID 等)
Returns:
转换后的对象
"""
if isinstance(obj, uuid.UUID):
return str(obj)
elif isinstance(obj, dict):
return {k: convert_uuids_to_str(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_uuids_to_str(item) for item in obj]
else:
return obj
class MultiAgentService:
"""多 Agent 配置管理服务"""
def __init__(self, db: Session):
self.db = db
def create_config(
self,
app_id: uuid.UUID,
data: MultiAgentConfigCreate,
created_by: uuid.UUID
) -> MultiAgentConfig:
"""创建多 Agent 配置
Args:
app_id: 应用 ID
data: 配置数据
created_by: 创建者 ID
Returns:
多 Agent 配置
"""
# 1. 验证应用存在
app = self.db.get(App, app_id)
if not app:
raise ResourceNotFoundException("应用", str(app_id))
# 2. 检查是否已有有效配置
existing = self.db.scalars(
select(MultiAgentConfig)
.where(
MultiAgentConfig.app_id == app_id,
MultiAgentConfig.is_active == True
)
.order_by(MultiAgentConfig.updated_at.desc())
).first()
if existing:
raise BusinessException("应用已有多 Agent 配置", BizCode.DUPLICATE_RESOURCE)
# 3. 验证主 Agent 存在
master_agent = self.db.get(AgentConfig, data.master_agent_id)
if not master_agent:
raise ResourceNotFoundException("主 Agent", str(data.master_agent_id))
# 4. 验证子 Agent 存在
for sub_agent in data.sub_agents:
agent = self.db.get(AgentConfig, sub_agent.agent_id)
if not agent:
raise ResourceNotFoundException("子 Agent", str(sub_agent.agent_id))
# 5. 创建配置(转换 UUID 为字符串以支持 JSON 序列化)
sub_agents_data = [convert_uuids_to_str(sub_agent.model_dump()) for sub_agent in data.sub_agents]
routing_rules_data = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None
# 处理 execution_config可能是 None、字典或 Pydantic 模型)
if data.execution_config is None:
execution_config_data = {}
elif isinstance(data.execution_config, dict):
execution_config_data = convert_uuids_to_str(data.execution_config)
else:
execution_config_data = convert_uuids_to_str(data.execution_config.model_dump())
config = MultiAgentConfig(
app_id=app_id,
master_agent_id=data.master_agent_id,
master_agent_name=data.master_agent_name,
orchestration_mode=data.orchestration_mode,
sub_agents=sub_agents_data,
routing_rules=routing_rules_data,
execution_config=execution_config_data,
aggregation_strategy=data.aggregation_strategy
)
self.db.add(config)
self.db.commit()
self.db.refresh(config)
logger.info(
f"创建多 Agent 配置成功",
extra={
"config_id": str(config.id),
"app_id": str(app_id),
"mode": data.orchestration_mode,
"sub_agent_count": len(data.sub_agents)
}
)
return config
def get_config(self, app_id: uuid.UUID) -> Optional[MultiAgentConfig]:
"""获取多 Agent 配置
Args:
app_id: 应用 ID
Returns:
多 Agent 配置,如果不存在返回 None
"""
return self.db.scalars(
select(MultiAgentConfig)
.where(
MultiAgentConfig.app_id == app_id,
MultiAgentConfig.is_active == True
)
.order_by(MultiAgentConfig.updated_at.desc())
).first()
def get_multi_agent_configs(self, app_id: uuid.UUID) -> Optional[dict]:
"""通过 app_id 获取最新有效的多智能体配置,并将 agent_id 转换为 app_id
Args:
app_id: 应用 ID
Returns:
转换后的配置字典,如果不存在返回 None
"""
config = self.get_config(app_id)
if not config:
return None
# 转换 master_agent_id (release_id) 为 app_id
master_release = self.db.get(AppRelease, config.master_agent_id)
master_app_id = master_release.app_id if master_release else config.master_agent_id
# 转换 sub_agents 中的 agent_id (release_id) 为 app_id
converted_sub_agents = []
for sub_agent in config.sub_agents:
sub_agent_copy = sub_agent.copy()
release_id = sub_agent.get("agent_id")
if release_id:
try:
release_id_uuid = uuid.UUID(release_id) if isinstance(release_id, str) else release_id
sub_release = self.db.get(AppRelease, release_id_uuid)
if sub_release:
sub_agent_copy["agent_id"] = str(sub_release.app_id)
except Exception as e:
logger.warning(f"转换 sub_agent agent_id 失败: {release_id}, 错误: {str(e)}")
converted_sub_agents.append(sub_agent_copy)
# 构建返回的配置字典
return {
"id": config.id,
"app_id": config.app_id,
"master_agent_id": master_app_id,
"master_agent_name": config.master_agent_name,
"orchestration_mode": config.orchestration_mode,
"sub_agents": converted_sub_agents,
"routing_rules": config.routing_rules,
"execution_config": config.execution_config,
"aggregation_strategy": config.aggregation_strategy,
"is_active": config.is_active,
"created_at": config.created_at,
"updated_at": config.updated_at
}
def get_published_config_by_agent_id(self, agent_id: uuid.UUID) -> Optional[dict]:
"""通过 agent_id 获取当前发布版本的完整配置
Args:
agent_id: Agent 配置 ID
Returns:
当前发布版本的配置字典,如果没有发布版本则返回 None
"""
from app.models import AppRelease
# 查询 Agent 配置
agent_config = self.db.get(AgentConfig, agent_id)
if not agent_config:
logger.warning(f"Agent 配置不存在: {agent_id}")
return None
# 获取关联的应用
app = self.db.get(App, agent_config.app_id)
if not app or not app.current_release_id:
logger.warning(f"应用未发布或不存在: app_id={agent_config.app_id}")
return None
# 获取当前发布版本
release = self.db.get(AppRelease, app.current_release_id)
if not release:
logger.warning(f"发布版本不存在: release_id={app.current_release_id}")
return None
# 从发布版本的 config 中获取完整配置
# config 是一个 JSON 对象,包含了发布时的配置快照
config_data = release.config
if config_data and isinstance(config_data, dict):
return config_data
return None
def get_published_by_agent_id(self, agent_id: uuid.UUID) -> Optional[AppRelease]:
"""通过 agent_id 获取当前发布版本的完整配置
Args:
agent_id: Agent 配置 ID
Returns:
当前发布版本的配置字典,如果没有发布版本则返回 None
"""
# 获取关联的应用
app = self.db.get(App, agent_id)
if not app or not app.current_release_id:
logger.warning(f"应用未发布或不存在: app_id={agent_id}")
return None
# 获取当前发布版本
release = self.db.get(AppRelease, app.current_release_id)
if not release:
logger.warning(f"发布版本不存在: release_id={app.current_release_id}")
return None
return release
def update_config(
self,
app_id: uuid.UUID,
data: MultiAgentConfigUpdate
) -> MultiAgentConfig:
"""更新多 Agent 配置
Args:
app_id: 应用 ID
data: 更新数据
Returns:
更新后的配置
"""
config = self.get_config(app_id)
if not config:
# 1. 验证应用存在
app = self.db.get(App, app_id)
if not app:
raise ResourceNotFoundException("应用", str(app_id))
# 2. 验证主 Agent 存在并获取发布版本 ID
master_app_release = self.get_published_by_agent_id(data.master_agent_id)
if not master_app_release:
raise ResourceNotFoundException("主 Agent 未发布或不存在", str(data.master_agent_id))
# 使用发布版本 ID
data.master_agent_id = master_app_release.id
# 3. 验证子 Agent 存在并获取发布版本 ID
for sub_agent in data.sub_agents:
agent_app_release = self.get_published_by_agent_id(sub_agent.agent_id)
if not agent_app_release:
raise ResourceNotFoundException("子 Agent 未发布或不存在", str(sub_agent.agent_id))
# 使用发布版本 ID
sub_agent.agent_id = agent_app_release.id
# 5. 创建配置(转换 UUID 为字符串以支持 JSON 序列化)
sub_agents_data = [convert_uuids_to_str(sub_agent.model_dump()) for sub_agent in data.sub_agents]
# routing_rules_data = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None
# 处理 execution_config可能是 None、字典或 Pydantic 模型)
if data.execution_config is None:
execution_config_data = {}
elif isinstance(data.execution_config, dict):
execution_config_data = convert_uuids_to_str(data.execution_config)
else:
execution_config_data = convert_uuids_to_str(data.execution_config.model_dump())
config = MultiAgentConfig(
app_id=app_id,
master_agent_id=data.master_agent_id,
master_agent_name=data.master_agent_name,
orchestration_mode=data.orchestration_mode,
sub_agents=sub_agents_data,
# routing_rules=routing_rules_data,
execution_config=execution_config_data,
aggregation_strategy=data.aggregation_strategy
)
self.db.add(config)
self.db.commit()
self.db.refresh(config)
logger.info(
f"创建多 Agent 配置成功",
extra={
"config_id": str(config.id),
"app_id": str(app_id),
"mode": data.orchestration_mode,
"sub_agent_count": len(data.sub_agents)
}
)
return config
# raise ResourceNotFoundException("多 Agent 配置", str(app_id))
# 更新字段
if data.master_agent_id is not None:
# 验证主 Agent 存在
# 3. 验证主 Agent 存在并获取发布配置
master_app_release = self.get_published_by_agent_id(data.master_agent_id)
if not master_app_release:
raise ResourceNotFoundException("主 Agent 未发布或", str(data.master_agent_id))
config.master_agent_id = master_app_release.id
if data.master_agent_name is not None:
config.master_agent_name = data.master_agent_name
if data.orchestration_mode is not None:
config.orchestration_mode = data.orchestration_mode
if data.sub_agents is not None:
# 验证子 Agent 存在,并获取其发布的 config_id
updated_sub_agents = []
for sub_agent in data.sub_agents:
agent_app_release = self.get_published_by_agent_id(sub_agent.agent_id)
if not agent_app_release:
raise ResourceNotFoundException("子 Agent 未发布或", str(sub_agent.agent_id))
sub_agent.agent_id = agent_app_release.id
sub_agent_dict = convert_uuids_to_str(sub_agent.model_dump())
updated_sub_agents.append(sub_agent_dict)
config.sub_agents = updated_sub_agents
# if data.routing_rules is not None:
# config.routing_rules = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None
if data.execution_config is None:
execution_config_data = {}
elif isinstance(data.execution_config, dict):
execution_config_data = convert_uuids_to_str(data.execution_config)
else:
execution_config_data = convert_uuids_to_str(data.execution_config.model_dump())
if data.aggregation_strategy is not None:
config.aggregation_strategy = data.aggregation_strategy
if data.is_active is not None:
config.is_active = data.is_active
self.db.commit()
self.db.refresh(config)
logger.info(
f"更新多 Agent 配置成功",
extra={
"config_id": str(config.id),
"app_id": str(app_id)
}
)
return config
def delete_config(self, app_id: uuid.UUID) -> None:
"""删除多 Agent 配置
Args:
app_id: 应用 ID
"""
config = self.get_config(app_id)
if not config:
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
self.db.delete(config)
self.db.commit()
logger.info(
f"删除多 Agent 配置成功",
extra={
"config_id": str(config.id),
"app_id": str(app_id)
}
)
async def run(
self,
app_id: uuid.UUID,
request: MultiAgentRunRequest
) -> dict:
"""运行多 Agent 任务
Args:
app_id: 应用 ID
request: 运行请求
Returns:
执行结果
"""
# 1. 获取配置
config = self.get_config(app_id)
if not config:
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
if not config.is_active:
raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED)
# 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config)
# 3. 执行任务
result = await orchestrator.execute(
message=request.message,
conversation_id=request.conversation_id,
user_id=request.user_id,
variables=request.variables,
use_llm_routing=getattr(request, 'use_llm_routing', True), # 默认启用 LLM 路由
web_search=getattr(request, 'web_search', False), # 网络搜索参数
memory=getattr(request, 'memory', True) # 记忆功能参数
)
return result
async def run_stream(
self,
app_id: uuid.UUID,
request: MultiAgentRunRequest,
storage_type :str,
user_rag_memory_id :str
):
"""运行多 Agent 任务(流式返回)
Args:
app_id: 应用 ID
request: 运行请求
Yields:
SSE 格式的事件流
"""
# 1. 获取配置
config = self.get_config(app_id)
if not config:
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
if not config.is_active:
raise BusinessException("多 Agent 配置已禁用", BizCode.RESOURCE_DISABLED)
# 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config)
# 3. 流式执行任务
async for event in orchestrator.execute_stream(
message=request.message,
conversation_id=request.conversation_id,
user_id=request.user_id,
variables=request.variables,
use_llm_routing=getattr(request, 'use_llm_routing', True),
web_search=getattr(request, 'web_search', False), # 网络搜索参数
memory=getattr(request, 'memory', True) , # 记忆功能参数
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
def add_sub_agent(
self,
app_id: uuid.UUID,
agent_id: uuid.UUID,
name: str,
role: Optional[str] = None,
priority: int = 1,
capabilities: Optional[List[str]] = None
) -> MultiAgentConfig:
"""添加子 Agent
Args:
app_id: 应用 ID
agent_id: Agent ID
name: Agent 名称
role: 角色描述
priority: 优先级
capabilities: 能力列表
Returns:
更新后的配置
"""
config = self.get_config(app_id)
if not config:
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
# 验证 Agent 存在
agent = self.db.get(AgentConfig, agent_id)
if not agent:
raise ResourceNotFoundException("Agent", str(agent_id))
# 检查是否已存在
for sub_agent in config.sub_agents:
if sub_agent["agent_id"] == str(agent_id):
raise BusinessException("Agent 已存在于配置中", BizCode.DUPLICATE_RESOURCE)
# 添加子 Agent
new_sub_agent = {
"agent_id": str(agent_id),
"name": name,
"role": role,
"priority": priority,
"capabilities": capabilities or []
}
config.sub_agents.append(new_sub_agent)
# 标记为已修改
self.db.add(config)
self.db.commit()
self.db.refresh(config)
logger.info(
f"添加子 Agent 成功",
extra={
"config_id": str(config.id),
"agent_id": str(agent_id),
"agent_name": name
}
)
return config
def remove_sub_agent(
self,
app_id: uuid.UUID,
agent_id: uuid.UUID
) -> MultiAgentConfig:
"""移除子 Agent
Args:
app_id: 应用 ID
agent_id: Agent ID
Returns:
更新后的配置
"""
config = self.get_config(app_id)
if not config:
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
# 查找并移除
original_count = len(config.sub_agents)
config.sub_agents = [
sub_agent for sub_agent in config.sub_agents
if sub_agent["agent_id"] != str(agent_id)
]
if len(config.sub_agents) == original_count:
raise ResourceNotFoundException("子 Agent", str(agent_id))
# 标记为已修改
self.db.add(config)
self.db.commit()
self.db.refresh(config)
logger.info(
f"移除子 Agent 成功",
extra={
"config_id": str(config.id),
"agent_id": str(agent_id)
}
)
return config
def list_configs(
self,
workspace_id: uuid.UUID,
page: int = 1,
pagesize: int = 20
) -> Tuple[List[MultiAgentConfig], int]:
"""列出多 Agent 配置
Args:
workspace_id: 工作空间 ID
page: 页码
pagesize: 每页数量
Returns:
配置列表和总数
"""
# 构建查询
stmt = (
select(MultiAgentConfig)
.join(App)
.where(App.workspace_id == workspace_id)
.order_by(desc(MultiAgentConfig.created_at))
)
# 总数
count_stmt = stmt.with_only_columns(MultiAgentConfig.id)
total = len(self.db.execute(count_stmt).all())
# 分页
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
configs = list(self.db.scalars(stmt).all())
return configs, total

View File

@@ -0,0 +1,444 @@
import uuid
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from sqlalchemy import select
from app.models import ReleaseShare, AppRelease, App, AgentConfig
from app.repositories.release_share_repository import ReleaseShareRepository
from app.core.share_utils import (
generate_share_token,
hash_password,
verify_password,
build_share_url,
generate_embed_code
)
from app.core.exceptions import ResourceNotFoundException, BusinessException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.schemas import release_share_schema
logger = get_business_logger()
class ReleaseShareService:
"""发布版本分享服务"""
def __init__(self, db: Session):
self.db = db
self.repo = ReleaseShareRepository(db)
def create_or_update_share(
self,
release_id: uuid.UUID,
user_id: uuid.UUID,
workspace_id: uuid.UUID,
data: release_share_schema.ReleaseShareCreate,
base_url: Optional[str] = None
) -> ReleaseShare:
"""创建或更新分享配置
Args:
release_id: 发布版本 ID
user_id: 用户 ID
workspace_id: 工作空间 ID
data: 分享配置数据
base_url: 基础 URL用于生成完整的分享链接
Returns:
分享配置
"""
# 验证发布版本存在且属于该工作空间
release = self._get_release_or_404(release_id)
self._validate_release_access(release, workspace_id)
# 检查是否已存在分享配置
existing_share = self.repo.get_by_release_id(release_id)
if existing_share:
# 更新现有配置
return self._update_share_internal(existing_share, data)
else:
# 创建新配置
return self._create_share_internal(release, user_id, data)
def _create_share_internal(
self,
release: AppRelease,
user_id: uuid.UUID,
data: release_share_schema.ReleaseShareCreate
) -> ReleaseShare:
"""内部方法:创建分享配置"""
# 生成唯一的 share_token
share_token = self._generate_unique_token()
# 处理密码
password_hash = None
if data.require_password and data.password:
password_hash = hash_password(data.password)
# 创建分享配置
share = ReleaseShare(
release_id=release.id,
app_id=release.app_id,
is_enabled=data.is_enabled,
share_token=share_token,
require_password=data.require_password,
password_hash=password_hash,
allow_embed=data.allow_embed,
embed_domains=data.embed_domains or [],
created_by=user_id
)
share = self.repo.create(share)
logger.info(
f"创建分享配置",
extra={
"share_id": str(share.id),
"release_id": str(release.id),
"app_id": str(release.app_id),
"share_token": share_token
}
)
return share
def _update_share_internal(
self,
share: ReleaseShare,
data: release_share_schema.ReleaseShareUpdate
) -> ReleaseShare:
"""内部方法:更新分享配置"""
if data.is_enabled is not None:
share.is_enabled = data.is_enabled
if data.require_password is not None:
share.require_password = data.require_password
if data.password is not None:
if data.password:
share.password_hash = hash_password(data.password)
else:
share.password_hash = None
if data.allow_embed is not None:
share.allow_embed = data.allow_embed
if data.embed_domains is not None:
share.embed_domains = data.embed_domains or []
share = self.repo.update(share)
logger.info(
f"更新分享配置",
extra={
"share_id": str(share.id),
"release_id": str(share.release_id)
}
)
return share
def update_share(
self,
release_id: uuid.UUID,
workspace_id: uuid.UUID,
data: release_share_schema.ReleaseShareUpdate
) -> ReleaseShare:
"""更新分享配置
Args:
release_id: 发布版本 ID
workspace_id: 工作空间 ID
data: 更新数据
Returns:
更新后的分享配置
"""
# 验证发布版本
release = self._get_release_or_404(release_id)
self._validate_release_access(release, workspace_id)
# 获取分享配置
share = self.repo.get_by_release_id(release_id)
if not share:
raise ResourceNotFoundException("分享配置", str(release_id))
return self._update_share_internal(share, data)
def get_share(
self,
release_id: uuid.UUID,
workspace_id: uuid.UUID,
base_url: Optional[str] = None
) -> Optional[release_share_schema.ReleaseShare]:
"""获取分享配置
Args:
release_id: 发布版本 ID
workspace_id: 工作空间 ID
base_url: 基础 URL
Returns:
分享配置 Schema
"""
# 验证发布版本
release = self._get_release_or_404(release_id)
self._validate_release_access(release, workspace_id)
share = self.repo.get_by_release_id(release_id)
if not share:
return None
return self._convert_to_schema(share, base_url)
def delete_share(
self,
release_id: uuid.UUID,
workspace_id: uuid.UUID
) -> None:
"""删除(禁用)分享配置
Args:
release_id: 发布版本 ID
workspace_id: 工作空间 ID
"""
# 验证发布版本
release = self._get_release_or_404(release_id)
self._validate_release_access(release, workspace_id)
share = self.repo.get_by_release_id(release_id)
if not share:
raise ResourceNotFoundException("分享配置", str(release_id))
self.repo.delete(share)
logger.info(
f"删除分享配置",
extra={
"share_id": str(share.id),
"release_id": str(release_id)
}
)
def regenerate_token(
self,
release_id: uuid.UUID,
workspace_id: uuid.UUID
) -> ReleaseShare:
"""重新生成分享 token旧链接失效
Args:
release_id: 发布版本 ID
workspace_id: 工作空间 ID
Returns:
更新后的分享配置
"""
# 验证发布版本
release = self._get_release_or_404(release_id)
self._validate_release_access(release, workspace_id)
share = self.repo.get_by_release_id(release_id)
if not share:
raise ResourceNotFoundException("分享配置", str(release_id))
# 生成新 token
old_token = share.share_token
share.share_token = self._generate_unique_token()
share = self.repo.update(share)
logger.info(
f"重新生成分享 token",
extra={
"share_id": str(share.id),
"old_token": old_token,
"new_token": share.share_token
}
)
return share
def get_shared_release_info(
self,
share_token: str,
password: Optional[str] = None
) -> release_share_schema.SharedReleaseInfo:
"""获取公开分享的发布版本信息
Args:
share_token: 分享 token
password: 访问密码(如果需要)
Returns:
分享的发布版本信息
"""
# 获取分享配置
share = self.repo.get_by_share_token(share_token)
if not share:
raise ResourceNotFoundException("分享链接", share_token)
# 检查是否启用
if not share.is_enabled:
raise BusinessException("该分享链接已禁用", BizCode.SHARE_DISABLED)
# 验证密码
is_password_verified = False
if share.require_password:
if not password:
# 需要密码但未提供,返回基本信息
release = self.db.get(AppRelease, share.release_id)
return release_share_schema.SharedReleaseInfo(
app_name=release.name,
app_description=release.description,
app_icon=release.icon,
app_type=release.type,
version=release.version,
release_notes=release.release_notes,
published_at=int(release.published_at.timestamp() * 1000),
config={},
require_password=True,
is_password_verified=False,
allow_embed=share.allow_embed
)
# 验证密码
if not share.password_hash or not verify_password(password, share.password_hash):
raise BusinessException("密码错误", BizCode.INVALID_PASSWORD)
is_password_verified = True
# 获取发布版本详细信息
release = self.db.get(AppRelease, share.release_id)
if not release:
raise ResourceNotFoundException("发布版本", str(share.release_id))
# 异步更新访问统计(不阻塞响应)
try:
self.repo.increment_view_count(share.id)
except Exception as e:
logger.warning(f"更新访问统计失败: {str(e)}")
# 返回完整信息
return release_share_schema.SharedReleaseInfo(
app_name=release.name,
app_description=release.description,
app_icon=release.icon,
app_type=release.type,
version=release.version,
release_notes=release.release_notes,
published_at=int(release.published_at.timestamp() * 1000),
config=release.config or {},
require_password=share.require_password,
is_password_verified=is_password_verified,
allow_embed=share.allow_embed
)
def verify_password(
self,
share_token: str,
password: str
) -> bool:
"""验证分享密码
Args:
share_token: 分享 token
password: 密码
Returns:
是否验证成功
"""
share = self.repo.get_by_share_token(share_token)
if not share:
raise ResourceNotFoundException("分享链接", share_token)
if not share.is_enabled:
raise BusinessException("该分享链接已禁用", BizCode.SHARE_DISABLED)
if not share.require_password:
return True
if not share.password_hash:
return False
return verify_password(password, share.password_hash)
def get_embed_code(
self,
share_token: str,
width: str = "100%",
height: str = "600px",
base_url: Optional[str] = None
) -> release_share_schema.EmbedCode:
"""获取嵌入代码
Args:
share_token: 分享 token
width: 宽度
height: 高度
base_url: 基础 URL
Returns:
嵌入代码
"""
share = self.repo.get_by_share_token(share_token)
if not share:
raise ResourceNotFoundException("分享链接", share_token)
if not share.is_enabled:
raise BusinessException("该分享链接已禁用", BizCode.SHARE_DISABLED)
if not share.allow_embed:
raise BusinessException("该分享不允许嵌入", BizCode.EMBED_NOT_ALLOWED)
embed_data = generate_embed_code(share_token, width, height, base_url)
return release_share_schema.EmbedCode(**embed_data)
def _generate_unique_token(self, max_attempts: int = 10) -> str:
"""生成唯一的分享 token"""
for _ in range(max_attempts):
token = generate_share_token()
if not self.repo.token_exists(token):
return token
raise BusinessException("生成唯一 token 失败,请重试", BizCode.INTERNAL_ERROR)
def _get_release_or_404(self, release_id: uuid.UUID) -> AppRelease:
"""获取发布版本或抛出 404"""
release = self.db.get(AppRelease, release_id)
if not release:
raise ResourceNotFoundException("发布版本", str(release_id))
return release
def _validate_release_access(self, release: AppRelease, workspace_id: uuid.UUID) -> None:
"""验证发布版本访问权限"""
app = self.db.get(App, release.app_id)
if not app:
raise ResourceNotFoundException("应用", str(release.app_id))
if app.workspace_id != workspace_id:
raise BusinessException("无权访问该发布版本", BizCode.PERMISSION_DENIED)
def _convert_to_schema(
self,
share: ReleaseShare,
base_url: Optional[str] = None
) -> release_share_schema.ReleaseShare:
"""转换为 Schema"""
share_url = build_share_url(share.share_token, base_url)
return release_share_schema.ReleaseShare(
id=share.id,
release_id=share.release_id,
app_id=share.app_id,
is_enabled=share.is_enabled,
share_token=share.share_token,
share_url=share_url,
require_password=share.require_password,
allow_embed=share.allow_embed,
embed_domains=share.embed_domains or [],
view_count=share.view_count,
last_accessed_at=share.last_accessed_at,
created_at=share.created_at,
updated_at=share.updated_at
)

View File

@@ -0,0 +1,160 @@
from typing import Optional
import json
from datetime import datetime, timedelta, timezone
from app.aioRedis import aio_redis_set, aio_redis_get, aio_redis_delete
from app.core.config import settings
class SessionService:
"""用户会话管理服务"""
@staticmethod
def _get_user_session_key(username: str) -> str:
"""获取用户会话的Redis键"""
return f"user_session:{username}"
@staticmethod
def _get_token_blacklist_key(token_id: str) -> str:
"""获取token黑名单的Redis键"""
return f"token_blacklist:{token_id}"
@staticmethod
async def set_user_active_session(username: str, token_id: str, expires_at: datetime) -> None:
"""设置用户的活跃会话
Args:
username: 用户名
token_id: token的唯一标识
expires_at: token过期时间
"""
if not settings.ENABLE_SINGLE_SESSION:
return
session_key = SessionService._get_user_session_key(username)
session_data = {
"token_id": token_id,
"created_at": datetime.now(timezone.utc).isoformat(),
"expires_at": expires_at.isoformat()
}
# 计算过期时间(秒)
expire_seconds = int((expires_at - datetime.now(timezone.utc)).total_seconds())
if expire_seconds > 0:
await aio_redis_set(session_key, session_data, expire=expire_seconds)
@staticmethod
async def get_user_active_session(username: str) -> Optional[dict]:
"""获取用户的活跃会话
Args:
username: 用户名
Returns:
会话数据字典或None
"""
if not settings.ENABLE_SINGLE_SESSION:
return None
session_key = SessionService._get_user_session_key(username)
session_data = await aio_redis_get(session_key)
if session_data:
try:
return json.loads(session_data) if isinstance(session_data, str) else session_data
except json.JSONDecodeError:
return None
return None
@staticmethod
async def invalidate_old_session(username: str, new_token_id: str) -> None:
"""使旧会话失效
Args:
username: 用户名
new_token_id: 新token的ID
"""
if not settings.ENABLE_SINGLE_SESSION:
return
# 获取当前活跃会话
current_session = await SessionService.get_user_active_session(username)
if current_session and current_session.get("token_id") != new_token_id:
# 将旧token加入黑名单
old_token_id = current_session.get("token_id")
if old_token_id:
await SessionService.blacklist_token(old_token_id)
@staticmethod
async def blacklist_token(token_id: str, expire_seconds: int = None) -> None:
"""将token加入黑名单
Args:
token_id: token的唯一标识
expire_seconds: 黑名单过期时间默认为refresh token的过期时间
"""
if expire_seconds is None:
# 默认使用refresh token的过期时间
expire_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
blacklist_key = SessionService._get_token_blacklist_key(token_id)
await aio_redis_set(blacklist_key, "blacklisted", expire=expire_seconds)
@staticmethod
async def is_token_blacklisted(token_id: str) -> bool:
"""检查token是否在黑名单中
Args:
token_id: token的唯一标识
Returns:
True如果token在黑名单中否则False
"""
if not settings.ENABLE_SINGLE_SESSION:
return False
blacklist_key = SessionService._get_token_blacklist_key(token_id)
result = await aio_redis_get(blacklist_key)
return result is not None
@staticmethod
async def clear_user_session(username: str) -> None:
"""清除用户会话
Args:
username: 用户名
"""
session_key = SessionService._get_user_session_key(username)
await aio_redis_delete(session_key)
@staticmethod
async def invalidate_all_user_tokens(user_id: str) -> None:
"""使用户的所有 tokens 失效(用于密码重置等场景)
通过在 Redis 中设置一个用户级别的失效标记来实现。
所有在此时间点之前签发的 tokens 都将被视为无效。
Args:
user_id: 用户ID
"""
invalidation_key = f"user_token_invalidation:{user_id}"
current_time = datetime.now(timezone.utc).isoformat()
# 设置失效时间戳,过期时间为 refresh token 的最大有效期
expire_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 60 * 60
await aio_redis_set(invalidation_key, current_time, expire=expire_seconds)
@staticmethod
async def get_user_token_invalidation_time(user_id: str) -> Optional[str]:
"""获取用户 token 失效时间
Args:
user_id: 用户ID
Returns:
失效时间的 ISO 格式字符串,如果没有失效记录则返回 None
"""
invalidation_key = f"user_token_invalidation:{user_id}"
result = await aio_redis_get(invalidation_key)
return result if result else None

View File

@@ -0,0 +1,759 @@
"""基于分享链接的聊天服务"""
import uuid
import time
import asyncio
from typing import Optional, Dict, Any, AsyncGenerator
from sqlalchemy.orm import Session
from app.models import ReleaseShare, AppRelease, Conversation
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_web_search_tool
from app.services.release_share_service import ReleaseShareService
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig
from app.repositories import knowledge_repository
import json
logger = get_business_logger()
class SharedChatService:
"""基于分享链接的聊天服务"""
def __init__(self, db: Session):
self.db = db
self.conversation_service = ConversationService(db)
self.share_service = ReleaseShareService(db)
def _get_release_by_share_token(
self,
share_token: str,
password: Optional[str] = None
) -> tuple[ReleaseShare, AppRelease]:
"""通过 share_token 获取发布版本"""
# 获取分享配置
share = self.share_service.repo.get_by_share_token(share_token)
if not share:
raise ResourceNotFoundException("分享链接", share_token)
# 验证分享是否启用
if not share.is_enabled:
raise BusinessException("该分享链接已被禁用", BizCode.SHARE_DISABLED)
# 验证密码
if share.require_password:
if not password:
raise BusinessException("需要提供访问密码", BizCode.PASSWORD_REQUIRED)
if not self.share_service.verify_password(share_token, password):
raise BusinessException("访问密码错误", BizCode.INVALID_PASSWORD)
# 获取发布版本
release = self.db.get(AppRelease, share.release_id)
if not release:
raise ResourceNotFoundException("发布版本", str(share.release_id))
# 更新访问统计
try:
self.share_service.repo.increment_view_count(share.id)
except Exception as e:
logger.warning(f"更新访问统计失败: {str(e)}")
return share, release
def create_or_get_conversation(
self,
share_token: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
password: Optional[str] = None
) -> Conversation:
"""创建或获取会话"""
share, release = self._get_release_by_share_token(share_token, password)
# 如果提供了 conversation_id尝试获取现有会话
if conversation_id:
try:
conversation = self.conversation_service.get_conversation(
conversation_id=conversation_id,
workspace_id=release.app.workspace_id
)
# 验证会话是否属于该应用
if conversation.app_id != release.app_id:
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
return conversation
except ResourceNotFoundException:
logger.warning(
f"会话不存在,将创建新会话",
extra={"conversation_id": str(conversation_id)}
)
# 创建新会话(使用发布版本的配置)
conversation = self.conversation_service.create_conversation(
app_id=release.app_id,
workspace_id=release.app.workspace_id,
user_id=user_id,
is_draft=False, # 分享链接使用发布版本
config_snapshot=release.config
)
logger.info(
f"为分享链接创建新会话",
extra={
"conversation_id": str(conversation.id),
"share_token": share_token,
"release_id": str(release.id)
}
)
return conversation
async def chat(
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True
) -> Dict[str, Any]:
"""聊天(非流式)"""
from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.model_parameter_merger import ModelParameterMerger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
start_time = time.time()
if variables is None:
variables = {}
# 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password)
# 获取 Agent 配置
config = release.config or {}
# 获取模型配置ID
model_config_id = release.default_model_config_id
if not model_config_id:
raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING)
# 获取模型配置
from app.models import ModelConfig
model_config = self.db.get(ModelConfig, model_config_id)
if not model_config:
raise ResourceNotFoundException("模型配置", str(model_config_id))
# 获取 API Key
stmt = (
select(ModelApiKey)
.where(
ModelApiKey.model_config_id == model_config_id,
ModelApiKey.is_active == True
)
.order_by(ModelApiKey.priority.desc())
.limit(1)
)
api_key_obj = self.db.scalars(stmt).first()
if not api_key_obj:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
# 获取或创建会话
conversation = self.create_or_get_conversation(
share_token=share_token,
conversation_id=conversation_id,
user_id=user_id,
password=password
)
# 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "你是一个专业的AI助手")
if variables:
system_prompt_rendered = render_prompt_message(
system_prompt,
PromptMessageRole.USER,
variables
)
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# 准备工具列表
tools = []
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id)
tools.append(kb_tool)
# 添加长期记忆工具
if memory==True:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools=config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled",False)
if web_search==True:
if web_search_enable==True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.get("model_parameters", {})
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
)
# 加载历史消息
history = []
memory_config={"enabled":True,'max_history':10}
if memory_config.get("enabled"):
messages = self.conversation_service.get_messages(
conversation_id=conversation.id,
limit=memory_config.get("max_history", 10)
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
# 调用 Agent
result = await agent.chat(
message=message,
history=history,
context=None,
end_user_id=user_id
)
# 保存消息
self.conversation_service.save_conversation_messages(
conversation_id=conversation.id,
user_message=message,
assistant_message=result["content"]
)
# self.conversation_service.add_message(
# conversation_id=conversation.id,
# role="user",
# content=message
# )
# self.conversation_service.add_message(
# conversation_id=conversation.id,
# role="assistant",
# content=result["content"],
# meta_data={
# "model": api_key_obj.model_name,
# "usage": result.get("usage", {})
# }
# )
elapsed_time = time.time() - start_time
return {
"conversation_id": conversation.id,
"message": result["content"],
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}),
"elapsed_time": elapsed_time
}
async def chat_stream(
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
import json
start_time = time.time()
if variables is None:
variables = {}
memory_config = {"enabled": memory, "memory_content": "17", "max_history": 10}
try:
# 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password)
# 获取 Agent 配置
config = release.config or {}
agent_config_data = config.get("agent_config", {})
# 获取模型配置ID
model_config_id = release.default_model_config_id
if not model_config_id:
raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING)
# 获取模型配置
from app.models import ModelConfig
model_config = self.db.get(ModelConfig, model_config_id)
if not model_config:
raise ResourceNotFoundException("模型配置", str(model_config_id))
# 获取 API Key
stmt = (
select(ModelApiKey)
.where(
ModelApiKey.model_config_id == model_config_id,
ModelApiKey.is_active == True
)
.order_by(ModelApiKey.priority.desc())
.limit(1)
)
api_key_obj = self.db.scalars(stmt).first()
if not api_key_obj:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
# 获取或创建会话
conversation = self.create_or_get_conversation(
share_token=share_token,
conversation_id=conversation_id,
user_id=user_id,
password=password
)
# 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "你是一个专业的AI助手")
if variables:
system_prompt_rendered = render_prompt_message(
system_prompt,
PromptMessageRole.USER,
variables
)
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# 准备工具列表
tools = []
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id)
tools.append(kb_tool)
# 添加长期记忆工具
if memory:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.get("model_parameters", {})
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True
)
# 加载历史消息
history = []
memory_config = {"enabled": True, 'max_history': 10}
if memory_config.get("enabled"):
messages = self.conversation_service.get_messages(
conversation_id=conversation.id,
limit=memory_config.get("max_history", 10)
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n"
# 流式调用 Agent
full_content = ""
async for chunk in agent.chat_stream(
message=message,
history=history,
context=None,
end_user_id=user_id
):
full_content += chunk
# 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
elapsed_time = time.time() - start_time
# 保存消息
self.conversation_service.add_message(
conversation_id=conversation.id,
role="user",
content=message
)
self.conversation_service.add_message(
conversation_id=conversation.id,
role="assistant",
content=full_content,
meta_data={
"model": api_key_obj.model_name,
"usage": {}
}
)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info(
f"流式聊天完成",
extra={
"conversation_id": str(conversation.id),
"elapsed_time": elapsed_time,
"message_length": len(full_content)
}
)
except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出
logger.debug("流式聊天被中断")
raise
except Exception as e:
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
# 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
def get_conversation_messages(
self,
share_token: str,
conversation_id: uuid.UUID,
password: Optional[str] = None
) -> Conversation:
"""获取会话消息"""
share, release = self._get_release_by_share_token(share_token, password)
# 获取会话
conversation = self.conversation_service.get_conversation(
conversation_id=conversation_id,
workspace_id=release.app.workspace_id
)
# 验证会话是否属于该应用
if conversation.app_id != release.app_id:
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
return conversation
def list_conversations(
self,
share_token: str,
user_id: Optional[str] = None,
password: Optional[str] = None,
page: int = 1,
pagesize: int = 20
) -> tuple[list[Conversation], int]:
"""列出会话"""
share, release = self._get_release_by_share_token(share_token, password)
conversations, total = self.conversation_service.list_conversations(
app_id=release.app_id,
workspace_id=release.app.workspace_id,
user_id=user_id,
is_draft=False, # 只显示发布版本的会话
page=page,
pagesize=pagesize
)
return conversations, total
async def multi_agent_chat(
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True
) -> Dict[str, Any]:
"""多 Agent 聊天(非流式)"""
from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig
start_time = time.time()
if variables is None:
variables = {}
# 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password)
# 获取或创建会话
conversation = self.create_or_get_conversation(
share_token=share_token,
conversation_id=conversation_id,
user_id=user_id,
password=password
)
# 获取多 Agent 配置
multi_agent_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == release.app_id,
MultiAgentConfig.is_active == True
).first()
if not multi_agent_config:
raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 构建多 Agent 运行请求
from app.schemas.multi_agent_schema import MultiAgentRunRequest
multi_agent_request = MultiAgentRunRequest(
message=message,
conversation_id=conversation.id,
user_id=user_id,
variables=variables,
use_llm_routing=True,
web_search=web_search,
memory=memory
)
# 使用多 Agent 服务执行
multi_agent_service = MultiAgentService(self.db)
result = await multi_agent_service.run(
app_id=release.app_id,
request=multi_agent_request
)
elapsed_time = time.time() - start_time
# 保存消息
self.conversation_service.add_message(
conversation_id=conversation.id,
role="user",
content=message
)
self.conversation_service.add_message(
conversation_id=conversation.id,
role="assistant",
content=result.get("message", ""),
meta_data={
"mode": result.get("mode"),
"elapsed_time": result.get("elapsed_time"),
"sub_results": result.get("sub_results")
}
)
return {
"conversation_id": conversation.id,
"message": result.get("message", ""),
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
},
"elapsed_time": elapsed_time
}
async def multi_agent_chat_stream(
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True
) -> AsyncGenerator[str, None]:
"""多 Agent 聊天(流式)"""
start_time = time.time()
if variables is None:
variables = {}
try:
# 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password)
# 获取或创建会话
conversation = self.create_or_get_conversation(
share_token=share_token,
conversation_id=conversation_id,
user_id=user_id,
password=password
)
# 获取多 Agent 配置
multi_agent_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == release.app_id,
MultiAgentConfig.is_active == True
).first()
if not multi_agent_config:
raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 获取 storage_type 和 user_rag_memory_id
workspace_id = release.app.workspace_id
storage_type = 'neo4j' # 默认值
user_rag_memory_id = ''
try:
# 获取工作空间的存储类型(不需要用户权限检查,因为是公开分享)
from app.models import Workspace
workspace = self.db.get(Workspace, workspace_id)
if workspace and workspace.storage_type:
storage_type = workspace.storage_type
# 获取 USER_RAG_MERORY 知识库 ID
knowledge = knowledge_repository.get_knowledge_by_name(
db=self.db,
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
except Exception as e:
logger.warning(f"获取 storage_type 或 user_rag_memory_id 失败,使用默认值: {str(e)}")
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n"
# 构建多 Agent 运行请求
from app.schemas.multi_agent_schema import MultiAgentRunRequest
multi_agent_request = MultiAgentRunRequest(
message=message,
conversation_id=conversation.id,
user_id=user_id,
variables=variables,
use_llm_routing=True,
web_search=web_search,
memory=memory
)
# 使用多 Agent 服务流式执行
multi_agent_service = MultiAgentService(self.db)
full_content = ""
async for event in multi_agent_service.run_stream(
app_id=release.app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
# 直接转发事件
yield event
# 尝试提取内容(用于保存)
if "data:" in event:
try:
data_line = event.split("data: ", 1)[1].strip()
data = json.loads(data_line)
if "content" in data:
full_content += data["content"]
except:
pass
elapsed_time = time.time() - start_time
# 保存消息
self.conversation_service.add_message(
conversation_id=conversation.id,
role="user",
content=message
)
self.conversation_service.add_message(
conversation_id=conversation.id,
role="assistant",
content=full_content,
meta_data={
"elapsed_time": elapsed_time
}
)
logger.info(
f"多 Agent 流式聊天完成",
extra={
"conversation_id": str(conversation.id),
"elapsed_time": elapsed_time,
"message_length": len(full_content)
}
)
except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出
logger.debug("多 Agent 流式聊天被中断")
raise
except Exception as e:
logger.error(f"多 Agent 流式聊天失败: {str(e)}", exc_info=True)
# 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"

View File

@@ -0,0 +1,426 @@
"""智能路由器 - 解决多轮对话路由错乱"""
import re
from typing import Dict, Any, List, Optional, Tuple
from app.services.conversation_state_manager import ConversationStateManager
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class SmartRouter:
"""智能路由器
核心功能:
1. 检测主题切换
2. 判断是否应该继续使用当前 Agent
3. 智能选择最合适的 Agent
4. 支持强制重新路由
"""
# 主题切换信号
SWITCH_SIGNALS = [
"换个话题", "另外", "还有", "对了",
"那这个呢", "再问一个", "顺便问下",
"我想问", "帮我", "请问", "换一个"
]
# 延续信号
CONTINUATION_SIGNALS = [
"继续", "还是", "", "同样", "类似",
"这个", "那个", "", "", "", ""
]
def __init__(
self,
state_manager: ConversationStateManager,
routing_rules: List[Dict[str, Any]],
sub_agents: Dict[str, Any]
):
"""初始化智能路由器
Args:
state_manager: 会话状态管理器
routing_rules: 路由规则列表
sub_agents: 子 Agent 配置字典
"""
self.state_manager = state_manager
self.routing_rules = routing_rules
self.sub_agents = sub_agents
# 配置参数
self.min_confidence_for_switch = 0.7 # 切换 Agent 的最小置信度
self.max_same_agent_turns = 10 # 同一 Agent 最大连续轮数
async def route(
self,
message: str,
conversation_id: Optional[str] = None,
force_new: bool = False
) -> Dict[str, Any]:
"""智能路由
Args:
message: 用户消息
conversation_id: 会话 ID
force_new: 是否强制重新路由(忽略历史)
Returns:
路由结果 {
"agent_id": str,
"confidence": float,
"strategy": str,
"topic": str,
"topic_changed": bool,
"reason": str
}
"""
logger.info(
f"开始智能路由",
extra={
"message_length": len(message),
"conversation_id": conversation_id,
"force_new": force_new
}
)
# 1. 获取会话状态
state = None
if conversation_id and not force_new:
state = self.state_manager.get_state(conversation_id)
# 2. 检测主题切换
topic_changed = self._detect_topic_change(message, state)
# 3. 提取当前主题
topic = self._extract_topic(message)
# 4. 选择路由策略
if force_new:
# 强制重新路由
agent_id, confidence = self._route_from_scratch(message)
strategy = "force_new"
reason = "用户强制重新路由"
elif not state or not state.get("current_agent_id"):
# 新会话,从头路由
agent_id, confidence = self._route_from_scratch(message)
strategy = "new_conversation"
reason = "新会话,首次路由"
elif topic_changed:
# 主题切换,重新路由
agent_id, confidence = self._route_from_scratch(message)
strategy = "topic_changed"
reason = f"检测到主题切换: {state.get('last_topic')} -> {topic}"
elif state.get("same_agent_turns", 0) >= self.max_same_agent_turns:
# 同一 Agent 使用太久,强制重新评估
agent_id, confidence = self._route_from_scratch(message)
strategy = "max_turns_reached"
reason = f"同一 Agent 已使用 {state['same_agent_turns']}"
else:
# 检查是否应该继续使用当前 Agent
current_agent_id = state["current_agent_id"]
should_continue, continue_confidence = self._should_continue_current_agent(
message,
current_agent_id
)
if should_continue:
# 继续使用当前 Agent
agent_id = current_agent_id
confidence = continue_confidence
strategy = "continue_current"
reason = "消息在当前 Agent 能力范围内"
else:
# 重新路由
new_agent_id, new_confidence = self._route_from_scratch(message)
# 只有新 Agent 的置信度明显更高时才切换
if new_confidence > continue_confidence + self.min_confidence_for_switch:
agent_id = new_agent_id
confidence = new_confidence
strategy = "switch_agent"
reason = f"新 Agent 置信度更高: {new_confidence:.2f} vs {continue_confidence:.2f}"
else:
# 置信度差距不大,继续使用当前 Agent
agent_id = current_agent_id
confidence = continue_confidence
strategy = "keep_current"
reason = "置信度差距不足以切换 Agent"
# 5. 更新会话状态
if conversation_id:
self.state_manager.update_state(
conversation_id,
agent_id,
message,
topic,
confidence
)
result = {
"agent_id": agent_id,
"confidence": confidence,
"strategy": strategy,
"topic": topic,
"topic_changed": topic_changed,
"reason": reason
}
logger.info(
f"路由完成",
extra={
"agent_id": agent_id,
"strategy": strategy,
"confidence": confidence,
"topic": topic
}
)
return result
def _detect_topic_change(
self,
message: str,
state: Optional[Dict[str, Any]]
) -> bool:
"""检测主题是否切换
Args:
message: 用户消息
state: 会话状态
Returns:
是否切换主题
"""
if not state or not state.get("last_topic"):
return False
# 检查明确的切换信号
for signal in self.SWITCH_SIGNALS:
if signal in message:
logger.info(f"检测到主题切换信号: {signal}")
return True
# 比较主题
current_topic = self._extract_topic(message)
last_topic = state.get("last_topic")
if current_topic != last_topic and current_topic != "其他":
logger.info(f"主题变化: {last_topic} -> {current_topic}")
return True
return False
def _should_continue_current_agent(
self,
message: str,
current_agent_id: str
) -> Tuple[bool, float]:
"""判断是否应该继续使用当前 Agent
Args:
message: 用户消息
current_agent_id: 当前 Agent ID
Returns:
(是否继续, 置信度)
"""
# 检查延续信号
has_continuation_signal = any(
signal in message
for signal in self.CONTINUATION_SIGNALS
)
# 计算当前 Agent 对消息的匹配度
current_score = self._calculate_agent_score(message, current_agent_id)
# 如果有延续信号且匹配度不太低,继续使用
if has_continuation_signal and current_score > 0.3:
return True, min(current_score + 0.2, 1.0)
# 如果匹配度高,继续使用
if current_score > 0.6:
return True, current_score
return False, current_score
def _route_from_scratch(self, message: str) -> Tuple[str, float]:
"""从头开始路由(不考虑历史)
Args:
message: 用户消息
Returns:
(Agent ID, 置信度)
"""
best_agent_id = None
best_score = 0.0
# 遍历所有路由规则
for rule in self.routing_rules:
score = self._calculate_rule_score(message, rule)
if score > best_score:
best_score = score
best_agent_id = rule.get("target_agent_id")
# 如果没有匹配的规则,使用默认 Agent
if not best_agent_id or best_score < 0.3:
best_agent_id = self._get_default_agent_id()
best_score = 0.5
logger.warning(f"未找到匹配规则,使用默认 Agent: {best_agent_id}")
return best_agent_id, best_score
def _calculate_rule_score(
self,
message: str,
rule: Dict[str, Any]
) -> float:
"""计算规则匹配分数
Args:
message: 用户消息
rule: 路由规则
Returns:
匹配分数 (0-1)
"""
score = 0.0
message_lower = message.lower()
# 1. 关键词匹配 (权重 0.6)
keywords = rule.get("keywords", [])
if keywords:
matched_keywords = sum(
1 for keyword in keywords
if keyword.lower() in message_lower
)
keyword_score = matched_keywords / len(keywords)
score += keyword_score * 0.6
# 2. 正则匹配 (权重 0.3)
patterns = rule.get("patterns", [])
if patterns:
matched_patterns = sum(
1 for pattern in patterns
if re.search(pattern, message, re.IGNORECASE)
)
pattern_score = matched_patterns / len(patterns)
score += pattern_score * 0.3
# 3. 排除关键词 (负分)
exclude_keywords = rule.get("exclude_keywords", [])
if exclude_keywords:
has_exclude = any(
keyword.lower() in message_lower
for keyword in exclude_keywords
)
if has_exclude:
score *= 0.5 # 减半
# 4. 最小关键词数量要求
min_keyword_count = rule.get("min_keyword_count", 0)
if keywords and min_keyword_count > 0:
matched_count = sum(
1 for keyword in keywords
if keyword.lower() in message_lower
)
if matched_count < min_keyword_count:
score *= 0.7 # 惩罚
return min(score, 1.0)
def _calculate_agent_score(
self,
message: str,
agent_id: str
) -> float:
"""计算 Agent 对消息的匹配分数
Args:
message: 用户消息
agent_id: Agent ID
Returns:
匹配分数 (0-1)
"""
# 找到该 Agent 对应的所有规则
agent_rules = [
rule for rule in self.routing_rules
if rule.get("target_agent_id") == agent_id
]
if not agent_rules:
return 0.0
# 返回最高分数
max_score = max(
self._calculate_rule_score(message, rule)
for rule in agent_rules
)
return max_score
def _extract_topic(self, message: str) -> str:
"""提取消息主题
Args:
message: 用户消息
Returns:
主题名称
"""
# 主题关键词映射
topic_keywords = {
"数学": ["数学", "方程", "计算", "求解", "x", "y", "函数", "几何"],
"物理": ["物理", "", "速度", "加速度", "能量", "功率", "电路"],
"化学": ["化学", "方程式", "反应", "元素", "分子", "原子", "化合物"],
"语文": ["语文", "古诗", "作文", "阅读", "文言文", "诗词"],
"英语": ["英语", "单词", "语法", "翻译", "时态", "句型"],
"历史": ["历史", "朝代", "事件", "人物", "战争", "革命"],
"作业": ["作业", "批改", "检查", "评分", "反馈"],
"学习规划": ["计划", "规划", "方法", "技巧", "时间", "安排"],
"订单": ["订单", "发货", "物流", "配送", "快递"],
"退款": ["退款", "退货", "售后", "换货", "维修"],
"账户": ["账户", "密码", "登录", "注册", "绑定"],
"支付": ["支付", "付款", "充值", "余额", "优惠券"]
}
message_lower = message.lower()
# 统计每个主题的匹配度
topic_scores = {}
for topic, keywords in topic_keywords.items():
matched = sum(
1 for keyword in keywords
if keyword in message_lower
)
if matched > 0:
topic_scores[topic] = matched
# 返回匹配度最高的主题
if topic_scores:
best_topic = max(topic_scores.items(), key=lambda x: x[1])[0]
return best_topic
return "其他"
def _get_default_agent_id(self) -> str:
"""获取默认 Agent ID
Returns:
默认 Agent ID
"""
# 优先使用第一个路由规则的 Agent
if self.routing_rules:
return self.routing_rules[0].get("target_agent_id")
# 否则使用第一个子 Agent
if self.sub_agents:
return list(self.sub_agents.keys())[0]
return "default-agent"

View File

@@ -0,0 +1,52 @@
from app.celery_app import celery_app
def create_processing_task(item_data: dict) -> str:
"""
Sends a task to the Celery queue to process an item.
:param item_data: The dictionary representation of the item.
:return: The ID of the created task.
"""
task = celery_app.send_task("tasks.process_item", args=[item_data])
return task.id
def get_task_result(task_id: str) -> dict:
"""
Checks the status and result of a Celery task.
:param task_id: The ID of the task to check.
:return: A dictionary with the task's status and result (if ready).
"""
result = celery_app.AsyncResult(task_id)
if result.ready():
return {"status": result.status, "result": result.get()}
return {"status": result.status}
def get_task_memory_read_result(task_id: str) -> dict:
"""
Checks the status and result of a memory read task.
:param task_id: The ID of the task to check.
:return: A dictionary with the task's status and result (if ready).
"""
result = celery_app.AsyncResult(task_id)
if result.ready():
return {"status": result.status, "result": result.get()}
return {"status": result.status}
def get_task_memory_write_result(task_id: str) -> dict:
"""
Checks the status and result of a memory write task.
:param task_id: The ID of the task to check.
:return: A dictionary with the task's status and result (if ready).
"""
result = celery_app.AsyncResult(task_id)
if result.ready():
return {"status": result.status, "result": result.get()}
return {"status": result.status}

View File

@@ -0,0 +1,220 @@
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.logging_config import get_business_logger
from app.repositories.tenant_repository import TenantRepository
from app.repositories.user_repository import UserRepository
from app.repositories.workspace_repository import WorkspaceRepository
from app.schemas.tenant_schema import (
TenantCreate, TenantUpdate, Tenant, TenantQuery, TenantList
)
from app.schemas.user_schema import User
from app.schemas.workspace_schema import WorkspaceCreate
from app.models.tenant_model import Tenants
from app.models.user_model import User as UserModel
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
# 获取业务逻辑专用日志器
business_logger = get_business_logger()
class TenantService:
"""租户业务逻辑层"""
def __init__(self, db: Session):
self.db = db
self.tenant_repo = TenantRepository(db)
self.user_repo = UserRepository(db)
self.workspace_repo = WorkspaceRepository(db)
def create_tenant(self, tenant_data: TenantCreate) -> Tenants:
"""创建租户"""
# 检查租户名称是否已存在
existing_tenant = self.tenant_repo.get_tenant_by_name(tenant_data.name)
if existing_tenant:
raise BusinessException(f"租户名称 '{tenant_data.name}' 已存在", code=BizCode.DUPLICATE_NAME)
try:
tenant = self.tenant_repo.create_tenant(tenant_data)
business_logger.info(f"创建租户成功: {tenant.name} (ID: {tenant.id})")
return tenant
except Exception as e:
business_logger.error(f"创建租户失败: {str(e)}")
raise BusinessException(f"创建租户失败: {str(e)}", code=BizCode.DB_ERROR)
def create_tenant_and_assign_user(self, tenant_data: TenantCreate, user_id: uuid.UUID) -> Tenants:
"""创建租户并分配用户"""
try:
# 创建租户
tenant = self.create_tenant(tenant_data)
# 将用户分配给租户
success = self.user_repo.assign_user_to_tenant(user_id, tenant.id)
if not success:
raise BusinessException("分配用户到租户失败", code=BizCode.STATE_CONFLICT)
business_logger.info(f"创建租户并分配用户成功: {tenant.name}")
return tenant
except Exception as e:
business_logger.error(f"创建租户和分配用户失败: {str(e)}")
self.db.rollback()
raise BusinessException(f"创建租户失败: {str(e)}", code=BizCode.DB_ERROR)
def get_tenant(self, tenant_id: uuid.UUID) -> Optional[Tenants]:
"""获取租户"""
return self.tenant_repo.get_tenant_by_id(tenant_id)
def get_tenant_by_name(self, name: str) -> Optional[Tenants]:
"""根据名称获取租户"""
return self.tenant_repo.get_tenant_by_name(name)
def get_tenants(self, query: TenantQuery) -> TenantList:
"""获取租户列表"""
skip = (query.page - 1) * query.size
tenants = self.tenant_repo.get_tenants(
skip=skip,
limit=query.size,
is_active=query.is_active,
search=query.search
)
total = self.tenant_repo.count_tenants(
is_active=query.is_active,
search=query.search
)
pages = (total + query.size - 1) // query.size
return TenantList(
items=[Tenant.model_validate(tenant) for tenant in tenants],
total=total,
page=query.page,
size=query.size,
pages=pages
)
def update_tenant(self, tenant_id: uuid.UUID, tenant_data: TenantUpdate) -> Optional[Tenants]:
"""更新租户"""
# 如果更新名称,检查是否重复
if tenant_data.name:
existing_tenant = self.tenant_repo.get_tenant_by_name(tenant_data.name)
if existing_tenant and existing_tenant.id != tenant_id:
raise BusinessException(f"租户名称 '{tenant_data.name}' 已存在", code=BizCode.DUPLICATE_NAME)
try:
tenant = self.tenant_repo.update_tenant(tenant_id, tenant_data)
if tenant:
business_logger.info(f"更新租户成功: {tenant.name} (ID: {tenant.id})")
return tenant
except Exception as e:
business_logger.error(f"更新租户失败: {str(e)}")
raise BusinessException(f"更新租户失败: {str(e)}", code=BizCode.DB_ERROR)
def delete_tenant(self, tenant_id: uuid.UUID) -> bool:
"""删除租户"""
try:
# 检查租户是否存在
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
if not tenant:
return False
# 检查是否有关联的用户
users = self.tenant_repo.get_tenant_users(tenant_id)
if users:
raise BusinessException("无法删除租户,存在关联的用户", code=BizCode.STATE_CONFLICT)
# 检查是否有关联的工作空间
workspaces = self.workspace_repo.get_workspaces_by_tenant(tenant_id)
if workspaces:
raise BusinessException("无法删除租户,存在关联的工作空间", code=BizCode.STATE_CONFLICT)
success = self.tenant_repo.delete_tenant(tenant_id)
if success:
business_logger.info(f"删除租户成功: {tenant.name} (ID: {tenant.id})")
return success
except Exception as e:
business_logger.error(f"删除租户失败: {str(e)}")
raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR)
# 租户用户管理
def get_tenant_users(
self,
tenant_id: uuid.UUID,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> List[UserModel]:
"""获取租户下的用户列表"""
return self.user_repo.get_users_by_tenant(
tenant_id=tenant_id,
skip=skip,
limit=limit,
is_active=is_active,
search=search
)
def count_tenant_users(
self,
tenant_id: uuid.UUID,
is_active: Optional[bool] = None,
search: Optional[str] = None
) -> int:
"""统计租户下的用户数量"""
return self.user_repo.count_users_by_tenant(
tenant_id=tenant_id,
is_active=is_active,
search=search
)
def assign_user_to_tenant(self, user_id: uuid.UUID, tenant_id: uuid.UUID) -> bool:
"""将用户分配给租户"""
# 检查租户是否存在
tenant = self.tenant_repo.get_tenant_by_id(tenant_id)
if not tenant:
raise BusinessException("租户不存在", code=BizCode.TENANT_NOT_FOUND)
try:
success = self.user_repo.assign_user_to_tenant(user_id, tenant_id)
if success:
business_logger.info(f"分配用户到租户成功: 用户ID {user_id}, 租户ID {tenant_id}")
return success
except Exception as e:
business_logger.error(f"分配用户到租户失败: {str(e)}")
raise BusinessException(f"分配用户到租户失败: {str(e)}", code=BizCode.DB_ERROR)
def get_user_tenant(self, user_id: uuid.UUID) -> Optional[Tenants]:
"""获取用户所属的租户"""
return self.tenant_repo.get_user_tenant(user_id)
def remove_user_from_tenant(self, user_id: uuid.UUID) -> bool:
"""将用户从租户中移除设置tenant_id为None"""
try:
user = self.user_repo.get_user_by_id(user_id)
if not user:
return False
success = self.user_repo.assign_user_to_tenant(user_id, None)
if success:
business_logger.info(f"移除用户租户关联成功: 用户ID {user_id}")
return success
except Exception as e:
business_logger.error(f"移除用户租户关联失败: {str(e)}")
raise BusinessException(f"移除用户租户关联失败: {str(e)}", code=BizCode.DB_ERROR)
def get_users_without_tenant(
self,
skip: int = 0,
limit: int = 100,
is_active: Optional[bool] = None
) -> List[UserModel]:
"""获取没有租户的用户列表"""
return self.user_repo.get_users_without_tenant(
skip=skip,
limit=limit,
is_active=is_active
)

View File

@@ -0,0 +1,617 @@
"""
Upload Service for Generic File Upload System
Handles file upload, storage, access, deletion, and metadata updates.
"""
import os
import uuid
import shutil
from pathlib import Path
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from fastapi import UploadFile
from app.models.user_model import User
from app.models.generic_file_model import GenericFile
from app.repositories.generic_file_repository import GenericFileRepository
from app.core.upload_enums import UploadContext
from app.core.storage_strategy import StrategyFactory
from app.core.validators.file_validator import FileValidator
from app.core.exceptions import BusinessException, PermissionDeniedException
from app.core.error_codes import BizCode
from app.core.config import settings
from app.core.logging_config import get_logger
from app.core.uow import IUnitOfWork
from app.core.compensation import CompensationHandler
# Get logger
logger = get_logger(__name__)
class FileNotFoundError(BusinessException):
"""Exception raised when file is not found."""
def __init__(self, file_id: uuid.UUID):
super().__init__(
f"文件 {file_id} 不存在",
code=BizCode.NOT_FOUND
)
class FileAccessDeniedError(BusinessException):
"""Exception raised when file access is denied."""
def __init__(self, file_id: uuid.UUID):
super().__init__(
f"无权访问文件 {file_id}",
code=BizCode.FORBIDDEN
)
class FileStorageError(BusinessException):
"""Exception raised when file storage fails."""
def __init__(self, reason: str):
super().__init__(
f"文件存储失败: {reason}",
code=BizCode.INTERNAL_ERROR
)
class FileReferencedError(BusinessException):
"""Exception raised when trying to delete a referenced file."""
def __init__(self, file_id: uuid.UUID, reference_count: int):
super().__init__(
f"文件 {file_id}{reference_count} 个资源引用,无法删除",
code=BizCode.BAD_REQUEST
)
class UploadResult:
"""Result of a file upload operation."""
def __init__(self, success: bool, file_id: Optional[uuid.UUID] = None,
file_name: str = "", error: Optional[str] = None):
self.success = success
self.file_id = file_id
self.file_name = file_name
self.error = error
class UploadService:
"""
Service for handling file uploads and management.
Coordinates validation, storage, and database operations.
Uses Unit of Work pattern for transaction management.
"""
def __init__(self, uow: IUnitOfWork = None):
self.validator = FileValidator()
self.uow = uow
def upload_file(
self,
file: UploadFile,
context: UploadContext,
metadata: Optional[Dict[str, Any]],
current_user: User,
db: Session = None
) -> GenericFile:
"""
Upload a single file using Unit of Work pattern with compensation transactions.
Args:
file: The uploaded file
context: Upload context (avatar, app_icon, etc.)
metadata: Additional metadata for the file
current_user: The user uploading the file
db: Optional database session (for backward compatibility)
Returns:
GenericFile: The created file record
Raises:
FileSizeExceededError: If file size exceeds limit
FileTypeNotAllowedError: If file type is not allowed
EmptyFileError: If file is empty
FileStorageError: If file storage fails
"""
logger.info(f"Starting file upload: filename={file.filename}, context={context}, user={current_user.id}")
if metadata is None:
metadata = {}
# Get storage strategy for this context
strategy = StrategyFactory.get_strategy(context)
upload_policy = strategy.get_upload_policy()
# Validate file against upload policy
logger.debug(f"Validating file: {file.filename}")
self.validator.validate_and_raise(file, upload_policy)
# Generate file ID
file_id = uuid.uuid4()
# Extract file information
filename = file.filename or "unknown"
file_extension = ""
if "." in filename:
file_extension = "." + filename.rsplit(".", 1)[1].lower()
# Get file size
file.file.seek(0, 2)
file_size = file.file.tell()
file.file.seek(0)
# Get storage path
storage_path = strategy.get_storage_path(
tenant_id=current_user.tenant_id,
file_id=file_id,
file_extension=file_extension,
metadata=metadata
)
logger.debug(f"Storage path: {storage_path}")
# Use Unit of Work pattern with compensation handler
compensation = CompensationHandler()
try:
# Use provided UoW or create a new one for backward compatibility
if self.uow:
uow = self.uow
should_manage_context = False
else:
# Backward compatibility: use provided db session
if db:
# Create a temporary UoW wrapper for the existing session
from app.core.uow import SqlAlchemyUnitOfWork
uow = SqlAlchemyUnitOfWork(lambda: db)
uow._session = db
uow.files = GenericFileRepository(db)
should_manage_context = False
else:
raise FileStorageError("Either uow or db session must be provided")
# 1. Save physical file
self._save_physical_file(file, storage_path)
# Register compensation: delete physical file if database operation fails
compensation.register(lambda: self._delete_physical_file(storage_path))
# 2. Generate access URL
access_url = None
if context in [UploadContext.AVATAR, UploadContext.APP_ICON]:
access_url = f"{settings.FILE_ACCESS_URL_PREFIX}/{file_id}"
# 3. Create file data
file_data = {
"id": file_id,
"tenant_id": current_user.tenant_id,
"created_by": current_user.id,
"file_name": filename,
"file_ext": file_extension,
"file_size": file_size,
"mime_type": file.content_type,
"context": context.value,
"storage_path": str(storage_path),
"file_metadata": metadata,
"status": "active",
"is_public": metadata.get("is_public", False),
"access_url": access_url,
"reference_count": 0,
}
# 4. Create database record
db_file = uow.files.create_file(file_data)
# 5. Commit transaction (only if we're managing the session)
if should_manage_context:
uow.commit()
elif db:
db.commit()
# Success - clear compensation operations
compensation.clear()
logger.info(f"File upload completed successfully: {filename} (ID: {file_id})")
return db_file
except Exception as e:
# Execute compensation operations
compensation.execute()
# Rollback if we're managing the session
if db:
db.rollback()
logger.error(f"File upload failed: {str(e)}")
raise FileStorageError(f"文件上传失败: {str(e)}")
def _save_physical_file(self, file: UploadFile, storage_path: Path):
"""
Save physical file to filesystem.
Args:
file: The uploaded file
storage_path: Path where file should be saved
Raises:
FileStorageError: If file save fails
"""
try:
# Create directory if it doesn't exist
storage_path.parent.mkdir(parents=True, exist_ok=True)
# Save file
with open(storage_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
logger.info(f"File saved to filesystem: {storage_path}")
except Exception as e:
logger.error(f"Failed to save file to filesystem: {str(e)}")
raise FileStorageError(f"无法保存文件到磁盘: {str(e)}")
def _delete_physical_file(self, storage_path: Path):
"""
Delete physical file (compensation operation).
Args:
storage_path: Path of file to delete
"""
try:
if os.path.exists(storage_path):
os.remove(storage_path)
logger.info(f"补偿操作:删除文件 {storage_path}")
except Exception as e:
logger.error(f"删除文件失败: {e}")
def _restore_file_from_backup(self, backup_path: Path, original_path: Path):
"""
Restore file from backup (compensation operation).
Args:
backup_path: Path of backup file
original_path: Path where file should be restored
"""
try:
if backup_path.exists():
shutil.copy2(backup_path, original_path)
logger.info(f"补偿操作:从备份恢复文件 {original_path}")
# Clean up backup after restoration
os.remove(backup_path)
logger.debug(f"补偿操作:删除备份文件 {backup_path}")
except Exception as e:
logger.error(f"恢复文件失败: {e}")
def upload_files_batch(
self,
files: List[UploadFile],
context: UploadContext,
metadata: Optional[Dict[str, Any]],
current_user: User,
db: Session = None
) -> List[UploadResult]:
"""
Upload multiple files in batch.
Individual file failures do not affect other files.
Args:
files: List of uploaded files
context: Upload context (avatar, app_icon, etc.)
metadata: Additional metadata for the files
current_user: The user uploading the files
db: Optional database session (for backward compatibility)
Returns:
List[UploadResult]: List of upload results for each file
Raises:
BusinessException: If batch size exceeds limit
"""
logger.info(f"Starting batch upload: {len(files)} files, context={context}, user={current_user.id}")
# Validate batch size
MAX_BATCH_SIZE = 20
if len(files) > MAX_BATCH_SIZE:
raise BusinessException(
f"批量上传文件数量不能超过 {MAX_BATCH_SIZE}",
code=BizCode.BAD_REQUEST,
context={
"file_count": len(files),
"max_batch_size": MAX_BATCH_SIZE,
"user_id": str(current_user.id),
"tenant_id": str(current_user.tenant_id),
"context": context
}
)
results = []
for file in files:
try:
# Upload each file independently
db_file = self.upload_file(file, context, metadata, current_user, db)
results.append(UploadResult(
success=True,
file_id=db_file.id,
file_name=file.filename or "unknown",
error=None
))
logger.info(f"Batch upload success: {file.filename}")
except Exception as e:
# Log error but continue with other files
logger.error(f"Batch upload failed for {file.filename}: {str(e)}")
results.append(UploadResult(
success=False,
file_id=None,
file_name=file.filename or "unknown",
error=str(e)
))
logger.info(f"Batch upload completed: {sum(1 for r in results if r.success)}/{len(files)} successful")
return results
def get_file(
self,
file_id: uuid.UUID,
current_user: User,
db: Session = None
) -> GenericFile:
"""
Get a file by ID with permission validation.
Args:
file_id: UUID of the file
current_user: The user requesting the file
db: Optional database session (for backward compatibility)
Returns:
GenericFile: The file record
Raises:
FileNotFoundError: If file doesn't exist
FileAccessDeniedError: If user doesn't have permission
"""
logger.debug(f"Getting file: file_id={file_id}, user={current_user.id}")
# Use UoW or provided db session
if self.uow:
with self.uow:
file = self.uow.files.get_file_by_id(file_id)
elif db:
repository = GenericFileRepository(db)
file = repository.get_file_by_id(file_id)
else:
raise FileStorageError("Either uow or db session must be provided")
if not file:
logger.warning(f"File not found: {file_id}")
raise FileNotFoundError(file_id)
# Check permissions using permission service
from app.core.permissions import permission_service, Subject, Resource, Action
subject = Subject.from_user(current_user)
resource = Resource.from_file(file)
try:
permission_service.require_permission(
subject,
Action.READ,
resource,
error_message=f"无权访问文件 {file_id}"
)
except PermissionDeniedException:
logger.warning(f"Access denied: file_id={file_id}, user={current_user.id}")
raise FileAccessDeniedError(file_id)
logger.debug(f"File access granted: {file.file_name}")
return file
def delete_file(
self,
file_id: uuid.UUID,
current_user: User,
db: Session = None
) -> None:
"""
Delete a file (both physical file and database record) using UoW pattern with compensation.
This method uses compensation transactions to ensure data consistency:
1. Delete physical file first
2. Register compensation to restore file if DB deletion fails
3. Delete database record
4. Commit transaction
5. Clear compensation on success
Args:
file_id: UUID of the file to delete
current_user: The user requesting deletion
db: Optional database session (for backward compatibility)
Raises:
FileNotFoundError: If file doesn't exist
FileAccessDeniedError: If user doesn't have permission
FileReferencedError: If file is still referenced
FileStorageError: If deletion fails
"""
logger.info(f"Deleting file: file_id={file_id}, user={current_user.id}")
# Get file and check permissions
if self.uow:
with self.uow:
file = self.uow.files.get_file_by_id(file_id)
elif db:
repository = GenericFileRepository(db)
file = repository.get_file_by_id(file_id)
else:
raise FileStorageError("Either uow or db session must be provided")
if not file:
logger.warning(f"File not found for deletion: {file_id}")
raise FileNotFoundError(file_id)
# Check permissions using permission service
from app.core.permissions import permission_service, Subject, Resource, Action
subject = Subject.from_user(current_user)
resource = Resource.from_file(file)
try:
permission_service.require_permission(
subject,
Action.DELETE,
resource,
error_message=f"无权删除文件 {file_id}"
)
except PermissionDeniedException:
logger.warning(f"Delete access denied: file_id={file_id}, user={current_user.id}")
raise FileAccessDeniedError(file_id)
# Check reference count
if file.reference_count > 0:
logger.warning(f"Cannot delete referenced file: file_id={file_id}, references={file.reference_count}")
raise FileReferencedError(file_id, file.reference_count)
# Store storage path and file content for potential restoration
storage_path = Path(file.storage_path)
backup_path = None
# Use compensation handler for atomic deletion
compensation = CompensationHandler()
try:
# 1. Backup and delete physical file first
if storage_path.exists():
# Create backup in temp location
backup_path = storage_path.parent / f".backup_{file_id}{storage_path.suffix}"
shutil.copy2(storage_path, backup_path)
logger.debug(f"Created backup: {backup_path}")
# Delete original file
os.remove(storage_path)
logger.info(f"Physical file deleted: {storage_path}")
# Register compensation: restore file from backup if DB deletion fails
compensation.register(lambda: self._restore_file_from_backup(backup_path, storage_path))
else:
logger.warning(f"Physical file not found: {storage_path}")
# 2. Delete database record (soft delete)
if self.uow:
with self.uow:
self.uow.files.delete_file(file_id)
self.uow.commit()
elif db:
repository = GenericFileRepository(db)
repository.delete_file(file_id)
db.commit()
logger.info(f"File record deleted successfully: {file.file_name} (ID: {file_id})")
# 3. Success - clear compensations and remove backup
compensation.clear()
if backup_path and backup_path.exists():
os.remove(backup_path)
logger.debug(f"Removed backup: {backup_path}")
except Exception as e:
# Execute compensation to restore file
compensation.execute()
# Rollback database if using db session
if db:
db.rollback()
logger.error(f"Failed to delete file: {str(e)}")
raise FileStorageError(f"无法删除文件: {str(e)}")
def update_file_metadata(
self,
file_id: uuid.UUID,
update_data: Dict[str, Any],
current_user: User,
db: Session = None
) -> GenericFile:
"""
Update file metadata using UoW pattern.
Args:
file_id: UUID of the file to update
update_data: Dictionary containing fields to update
current_user: The user requesting the update
db: Optional database session (for backward compatibility)
Returns:
GenericFile: The updated file record
Raises:
FileNotFoundError: If file doesn't exist
FileAccessDeniedError: If user doesn't have permission
"""
logger.info(f"Updating file metadata: file_id={file_id}, user={current_user.id}")
# Get file and check permissions
if self.uow:
with self.uow:
file = self.uow.files.get_file_by_id(file_id)
elif db:
repository = GenericFileRepository(db)
file = repository.get_file_by_id(file_id)
else:
raise FileStorageError("Either uow or db session must be provided")
if not file:
logger.warning(f"File not found for update: {file_id}")
raise FileNotFoundError(file_id)
# Check permissions using permission service
from app.core.permissions import permission_service, Subject, Resource, Action
subject = Subject.from_user(current_user)
resource = Resource.from_file(file)
try:
permission_service.require_permission(
subject,
Action.UPDATE,
resource,
error_message=f"无权更新文件 {file_id}"
)
except PermissionDeniedException:
logger.warning(f"Update access denied: file_id={file_id}, user={current_user.id}")
raise FileAccessDeniedError(file_id)
# Filter allowed fields for update
# Users can only update: file_name, file_metadata, is_public
allowed_fields = ["file_name", "file_metadata", "is_public"]
filtered_update_data = {
key: value for key, value in update_data.items()
if key in allowed_fields
}
if not filtered_update_data:
logger.warning(f"No valid fields to update for file: {file_id}")
return file
# Update file metadata
try:
if self.uow:
with self.uow:
updated_file = self.uow.files.update_file(file_id, filtered_update_data)
self.uow.commit()
elif db:
repository = GenericFileRepository(db)
updated_file = repository.update_file(file_id, filtered_update_data)
db.commit()
logger.info(f"File metadata updated successfully: {file.file_name} (ID: {file_id})")
return updated_file
except Exception as e:
if db:
db.rollback()
logger.error(f"Failed to update file metadata: {str(e)}")
raise FileStorageError(f"无法更新文件元数据: {str(e)}")

View File

@@ -0,0 +1,570 @@
import datetime
import secrets
import string
from sqlalchemy.orm import Session
import uuid
from app.models.user_model import User
from app.repositories import user_repository
from app.schemas.user_schema import UserCreate
from app.schemas.tenant_schema import TenantCreate
from app.services.tenant_service import TenantService
from app.services.session_service import SessionService
from app.core.security import get_password_hash, verify_password
from app.core.config import settings
from app.core.logging_config import get_business_logger
from app.core.exceptions import BusinessException, PermissionDeniedException
from app.core.error_codes import BizCode
# from app.services import workspace_service
# from app.schemas.workspace_schema import WorkspaceCreate
# 获取业务逻辑专用日志器
business_logger = get_business_logger()
def create_initial_superuser(db: Session):
business_logger.info("检查并创建初始超级用户")
superuser = user_repository.get_superuser(db)
if superuser:
business_logger.info("超级用户已存在,跳过创建")
return
user_in = UserCreate(
username=settings.FIRST_SUPERUSER_USERNAME,
email=settings.FIRST_SUPERUSER_EMAIL,
password=settings.FIRST_SUPERUSER_PASSWORD,
)
try:
business_logger.debug("开始创建初始租户")
# Create a default tenant for the superuser
default_tenant = TenantCreate(
name=f"{user_in.username}'s Tenant",
description=f"Default tenant for {user_in.username}",
)
# Create tenant service and create tenant with user assignment
tenant_service = TenantService(db)
tenant = tenant_service.create_tenant(default_tenant)
db.flush()
business_logger.debug("开始创建初始超级用户")
hashed_password = get_password_hash(user_in.password)
superuser = user_repository.create_user(
db=db, user=user_in, hashed_password=hashed_password, is_superuser=True,
tenant_id=tenant.id
)
db.commit()
db.refresh(superuser)
business_logger.info(f"初始超级用户创建成功: {superuser.username} (ID: {superuser.id})")
return superuser
except Exception as e:
business_logger.error(f"初始超级用户创建失败: {str(e)}")
db.rollback()
raise BusinessException(
f"初始超级用户创建失败: {str(e)}",
code=BizCode.DB_ERROR,
context={"username": username, "email": email},
cause=e
)
def create_user(db: Session, user: UserCreate) -> User:
business_logger.info(f"创建用户: {user.username}, email: {user.email}")
try:
# 检查用户名是否已存在
business_logger.debug(f"检查用户名是否已存在: {user.username}")
db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
if db_user_by_username:
business_logger.warning(f"用户名已存在: {user.username}")
raise BusinessException(
"用户名已存在",
code=BizCode.DUPLICATE_NAME,
context={"username": user.username, "email": user.email}
)
# 检查邮箱是否已注册
business_logger.debug(f"检查邮箱是否已注册: {user.email}")
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
if db_user_by_email:
business_logger.warning(f"邮箱已注册: {user.email}")
raise BusinessException(
"邮箱已注册",
code=BizCode.DUPLICATE_NAME,
context={"email": user.email, "username": user.username}
)
# 创建普通用户,需要有默认租户
business_logger.debug(f"开始创建用户: {user.username}")
hashed_password = get_password_hash(user.password)
# 获取默认租户(第一个活跃租户)
from app.repositories.tenant_repository import TenantRepository
tenant_repo = TenantRepository(db)
tenants = tenant_repo.get_tenants(skip=0, limit=1, is_active=True)
if not tenants:
business_logger.error("系统中没有可用的租户")
raise BusinessException(
"系统配置错误:没有可用的租户",
code=BizCode.TENANT_NOT_FOUND,
context={"username": user.username, "email": user.email}
)
default_tenant = tenants[0]
new_user = user_repository.create_user(
db=db, user=user, hashed_password=hashed_password,
tenant_id=default_tenant.id, is_superuser=False
)
db.commit()
db.refresh(new_user)
business_logger.info(f"用户创建成功: {new_user.username} (ID: {new_user.id})")
return new_user
except Exception as e:
business_logger.error(f"用户创建失败: {user.username} - {str(e)}")
db.rollback()
raise BusinessException(
f"用户创建失败: {user.username} - {str(e)}",
code=BizCode.DB_ERROR,
context={"username": user.username, "email": user.email},
cause=e
)
def create_superuser(db: Session, user: UserCreate, current_user: User) -> User:
business_logger.info(f"创建超级管理员: {user.username}, email: {user.email}")
# 检查当前用户是否为超级管理员
from app.core.permissions import permission_service, Subject
subject = Subject.from_user(current_user)
try:
permission_service.check_superuser(
subject,
error_message="只有超级管理员才能创建超级管理员用户"
)
except PermissionDeniedException as e:
business_logger.warning(f"非超级管理员尝试创建超级管理员用户: {user.username}")
raise BusinessException(
str(e),
code=BizCode.FORBIDDEN,
context={
"current_user_id": str(current_user.id),
"current_user_username": current_user.username,
"target_username": user.username
}
)
try:
# 检查用户名是否已存在
business_logger.debug(f"检查用户名是否已存在: {user.username}")
db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
if db_user_by_username:
business_logger.warning(f"用户名已存在: {user.username}")
raise BusinessException(
"用户名已存在",
code=BizCode.DUPLICATE_NAME,
context={
"username": user.username,
"email": user.email,
"created_by": str(current_user.id)
}
)
# 检查邮箱是否已注册
business_logger.debug(f"检查邮箱是否已注册: {user.email}")
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
if db_user_by_email:
business_logger.warning(f"邮箱已注册: {user.email}")
raise BusinessException(
"邮箱已注册",
code=BizCode.DUPLICATE_NAME,
context={
"email": user.email,
"username": user.username,
"created_by": str(current_user.id)
}
)
# 创建超级管理员用户并加入当前用户的租户
business_logger.debug(f"开始创建超级管理员: {user.username}")
hashed_password = get_password_hash(user.password)
new_user = user_repository.create_user(
db=db, user=user, hashed_password=hashed_password,
tenant_id=current_user.tenant_id, is_superuser=True
)
db.commit()
db.refresh(new_user)
business_logger.info(f"超级管理员创建成功: {new_user.username} (ID: {new_user.id}), 已加入租户: {current_user.tenant_id}")
return new_user
except Exception as e:
business_logger.error(f"超级管理员创建失败: {user.username} - {str(e)}")
db.rollback()
raise BusinessException(
f"超级管理员创建失败: {user.username} - {str(e)}",
code=BizCode.DB_ERROR,
context={
"username": user.username,
"email": user.email,
"created_by": str(current_user.id),
"tenant_id": str(current_user.tenant_id)
},
cause=e
)
def deactivate_user(db: Session, user_id_to_deactivate: uuid.UUID, current_user: User) -> User:
business_logger.info(f"停用用户: user_id={user_id_to_deactivate}, 操作者: {current_user.username}")
try:
# 查找用户
business_logger.debug(f"查找待停用用户: {user_id_to_deactivate}")
db_user = user_repository.get_user_by_id(db, user_id=user_id_to_deactivate)
if not db_user:
business_logger.warning(f"用户不存在: {user_id_to_deactivate}")
raise BusinessException(
"用户不存在",
code=BizCode.USER_NOT_FOUND,
context={"user_id": str(user_id_to_deactivate)}
)
# 权限检查 using permission service
from app.core.permissions import permission_service, Subject, Resource, Action
subject = Subject.from_user(current_user)
resource = Resource.from_user(db_user)
try:
permission_service.require_permission(
subject,
Action.DEACTIVATE,
resource,
error_message="没有权限停用该用户"
)
except PermissionDeniedException as e:
business_logger.warning(f"权限不足: 用户 {current_user.username} 尝试停用用户 {user_id_to_deactivate}")
raise BusinessException(
str(e),
code=BizCode.FORBIDDEN,
context={
"current_user_id": str(current_user.id),
"current_user_username": current_user.username,
"target_user_id": str(user_id_to_deactivate)
}
)
# 检查用户类型,如果是超级管理员,判断一下不是唯一的一个
if db_user.is_superuser:
is_only_superuser = user_repository.check_superuser_only(db)
if is_only_superuser:
business_logger.warning(f"停用超级管理员用户: {db_user.username} (ID: {user_id_to_deactivate})")
raise BusinessException(
"不能停用唯一的超级管理员用户",
code=BizCode.FORBIDDEN,
context={
"user_id": str(user_id_to_deactivate),
"username": db_user.username
}
)
# 停用用户
business_logger.debug(f"执行用户停用: {db_user.username} (ID: {user_id_to_deactivate})")
db_user.is_active = False
db.add(db_user)
db.commit()
db.refresh(db_user)
business_logger.info(f"用户停用成功: {db_user.username} (ID: {user_id_to_deactivate})")
return db_user
except Exception as e:
business_logger.error(f"用户停用失败: user_id={user_id_to_deactivate} - {str(e)}")
db.rollback()
if isinstance(e, BusinessException):
raise e
raise BusinessException(f"{str(e)}", code=BizCode.DB_ERROR)
def activate_user(db: Session, user_id_to_activate: uuid.UUID, current_user: User) -> User:
business_logger.info(f"激活用户: user_id={user_id_to_activate}, 操作者: {current_user.username}")
try:
# 查找用户
business_logger.debug(f"查找待激活用户: {user_id_to_activate}")
db_user = user_repository.get_user_by_id(db, user_id=user_id_to_activate)
if not db_user:
business_logger.warning(f"用户不存在: {user_id_to_activate}")
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
# 权限检查 using permission service
from app.core.permissions import permission_service, Subject, Resource, Action
subject = Subject.from_user(current_user)
resource = Resource.from_user(db_user)
try:
permission_service.require_permission(
subject,
Action.ACTIVATE,
resource,
error_message="没有权限激活该用户"
)
except PermissionDeniedException as e:
business_logger.warning(f"权限不足: 用户 {current_user.username} 尝试激活用户 {user_id_to_activate}")
raise BusinessException(str(e), code=BizCode.FORBIDDEN)
# 激活用户
business_logger.debug(f"执行用户激活: {db_user.username} (ID: {user_id_to_activate})")
db_user.is_active = True
db.add(db_user)
db.commit()
db.refresh(db_user)
business_logger.info(f"用户激活成功: {db_user.username} (ID: {user_id_to_activate})")
return db_user
except Exception as e:
business_logger.error(f"用户激活失败: user_id={user_id_to_activate} - {str(e)}")
db.rollback()
raise BusinessException(f"用户激活失败: user_id={user_id_to_activate} - {str(e)}", code=BizCode.DB_ERROR)
def get_user(db: Session, user_id: uuid.UUID, current_user: User) -> User:
business_logger.info(f"获取用户信息: user_id={user_id}, 操作者: {current_user.username}")
try:
# 查找用户
business_logger.debug(f"查找用户: {user_id}")
db_user = user_repository.get_user_by_id(db, user_id=user_id)
if not db_user:
business_logger.warning(f"用户不存在: {user_id}")
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
# 权限检查 using permission service
from app.core.permissions import permission_service, Subject, Resource, Action
subject = Subject.from_user(current_user)
resource = Resource.from_user(db_user)
try:
permission_service.require_permission(
subject,
Action.READ,
resource,
error_message="没有权限获取该用户信息"
)
except PermissionDeniedException as e:
business_logger.warning(f"权限不足: 用户 {current_user.username} 尝试获取用户 {user_id} 信息")
raise BusinessException(str(e), code=BizCode.FORBIDDEN)
# 返回用户信息
business_logger.debug(f"返回用户信息: {db_user.username} (ID: {user_id})")
return db_user
except Exception as e:
business_logger.error(f"获取用户信息失败: user_id={user_id} - {str(e)}")
raise BusinessException(f"获取用户信息失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR)
def get_tenant_superusers(db: Session, current_user: User, include_inactive: bool = True) -> list[User]:
"""获取当前租户下的超管账号列表"""
business_logger.info(f"获取租户超管列表: tenant_id={current_user.tenant_id}, 请求者: {current_user.username}, include_inactive={include_inactive}")
try:
# 检查当前用户是否有权限查看(只有超管才能查看超管列表)
from app.core.permissions import permission_service, Subject
subject = Subject.from_user(current_user)
try:
permission_service.check_superuser(
subject,
error_message="只有超级管理员才能查看超管列表"
)
except PermissionDeniedException as e:
business_logger.warning(f"非超级管理员尝试查看超管列表: {current_user.username}")
raise BusinessException(str(e), code=BizCode.FORBIDDEN)
# 检查用户是否有租户
if not current_user.tenant_id:
business_logger.warning(f"用户没有租户信息: {current_user.username}")
raise BusinessException("用户没有租户信息", code=BizCode.TENANT_NOT_FOUND)
# 获取租户下的超管列表
business_logger.debug(f"查询租户超管: tenant_id={current_user.tenant_id}, include_inactive={include_inactive}")
is_active_filter = None if include_inactive else True
superusers = user_repository.get_superusers_by_tenant(
db=db,
tenant_id=current_user.tenant_id,
is_active=is_active_filter
)
business_logger.info(f"租户超管查询成功: tenant_id={current_user.tenant_id}, count={len(superusers)}")
return superusers
except Exception as e:
business_logger.error(f"获取租户超管列表失败: tenant_id={current_user.tenant_id} - {str(e)}")
raise BusinessException(f"获取租户超管列表失败: tenant_id={current_user.tenant_id} - {str(e)}", code=BizCode.DB_ERROR)
def update_last_login_time(db: Session, user_id: uuid.UUID) -> User:
"""更新用户的最后登录时间"""
business_logger.info(f"更新用户最后登录时间: user_id={user_id}")
try:
# 获取用户
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
business_logger.warning(f"用户不存在: {user_id}")
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
# 更新最后登录时间
db_user.last_login_at = datetime.datetime.now()
db.commit()
db.refresh(db_user)
business_logger.info(f"用户最后登录时间更新成功: {db_user.username} (ID: {user_id})")
return db_user
except HTTPException:
raise
except Exception as e:
business_logger.error(f"更新用户最后登录时间失败: user_id={user_id} - {str(e)}")
db.rollback()
raise
async def change_password(db: Session, user_id: uuid.UUID, old_password: str, new_password: str, current_user: User) -> User:
"""普通用户修改自己的密码"""
business_logger.info(f"用户修改密码请求: user_id={user_id}, current_user={current_user.id}")
# 检查权限:只能修改自己的密码
if current_user.id != user_id:
business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only change your own password"
)
try:
# 获取用户
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
business_logger.warning(f"用户不存在: {user_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
# 验证旧密码
if not verify_password(old_password, db_user.hashed_password):
business_logger.warning(f"用户旧密码验证失败: {user_id}")
raise BusinessException("当前密码不正确", code=BizCode.VALIDATION_FAILED)
# 更新密码
db_user.hashed_password = get_password_hash(new_password)
db.commit()
db.refresh(db_user)
# 使所有旧 tokens 失效
await SessionService.invalidate_all_user_tokens(str(user_id))
business_logger.info(f"用户密码修改成功: {db_user.username} (ID: {user_id})")
return db_user
except Exception as e:
business_logger.error(f"修改用户密码失败: user_id={user_id} - {str(e)}")
db.rollback()
raise BusinessException(f"修改用户密码失败: user_id={user_id} - {str(e)}", code=BizCode.DB_ERROR)
async def admin_change_password(db: Session, target_user_id: uuid.UUID, new_password: str = None, current_user: User = None) -> tuple[User, str]:
"""
超级管理员修改指定用户的密码
Args:
db: 数据库会话
target_user_id: 目标用户ID
new_password: 新密码如果为None则自动生成随机密码
current_user: 当前用户(超级管理员)
Returns:
tuple[User, str]: (更新后的用户对象, 实际使用的密码)
"""
business_logger.info(f"管理员修改用户密码请求: admin={current_user.id}, target_user={target_user_id}")
# 检查权限:只有超级管理员可以修改他人密码
from app.core.permissions import permission_service, Subject
subject = Subject.from_user(current_user)
try:
permission_service.check_superuser(
subject,
error_message="只有超级管理员可以修改他人密码"
)
except PermissionDeniedException as e:
business_logger.warning(f"非超管用户尝试修改他人密码: current_user={current_user.id}")
raise BusinessException(str(e), code=BizCode.FORBIDDEN)
try:
# 获取目标用户
target_user = user_repository.get_user_by_id(db=db, user_id=target_user_id)
if not target_user:
business_logger.warning(f"目标用户不存在: {target_user_id}")
raise BusinessException("目标用户不存在", code=BizCode.USER_NOT_FOUND)
# 检查租户权限:超管只能修改同租户用户的密码
if current_user.tenant_id != target_user.tenant_id:
business_logger.warning(f"跨租户密码修改尝试: admin_tenant={current_user.tenant_id}, target_tenant={target_user.tenant_id}")
raise BusinessException("不可跨租户修改用户密码", code=BizCode.FORBIDDEN)
# 如果没有提供新密码,则生成随机密码
actual_password = new_password if new_password else generate_random_password()
# 更新密码
target_user.hashed_password = get_password_hash(actual_password)
db.commit()
db.refresh(target_user)
# 使所有旧 tokens 失效
await SessionService.invalidate_all_user_tokens(str(target_user_id))
password_type = "指定密码" if new_password else "随机生成密码"
business_logger.info(f"管理员修改用户密码成功: admin={current_user.username}, target={target_user.username} (ID: {target_user_id}), 类型={password_type}")
return target_user, actual_password
except Exception as e:
business_logger.error(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}")
db.rollback()
raise BusinessException(f"管理员修改用户密码失败: admin={current_user.id}, target_user={target_user_id} - {str(e)}", code=BizCode.DB_ERROR)
def generate_random_password(length: int = 12) -> str:
"""
生成随机密码
Args:
length: 密码长度默认12位
Returns:
str: 生成的随机密码
"""
# 确保密码包含大小写字母、数字和特殊字符
lowercase = string.ascii_lowercase
uppercase = string.ascii_uppercase
digits = string.digits
special_chars = "!@#$%^&*"
# 确保至少包含每种字符类型
password = [
secrets.choice(lowercase),
secrets.choice(uppercase),
secrets.choice(digits),
secrets.choice(special_chars)
]
# 填充剩余长度
all_chars = lowercase + uppercase + digits + special_chars
for _ in range(length - 4):
password.append(secrets.choice(all_chars))
# 打乱顺序
secrets.SystemRandom().shuffle(password)
return ''.join(password)

View File

@@ -0,0 +1,776 @@
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
import secrets
import hashlib
import datetime
from fastapi import HTTPException, status
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, PermissionDeniedException
from app.models.tenant_model import Tenants
from app.models.user_model import User
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.workspace_model import Workspace, WorkspaceRole, WorkspaceInvite, InviteStatus, WorkspaceMember
from app.schemas.workspace_schema import (
WorkspaceCreate,
WorkspaceUpdate,
WorkspaceInviteCreate,
WorkspaceInviteResponse,
InviteValidateResponse,
InviteAcceptRequest,
WorkspaceMemberUpdate
)
from app.repositories import workspace_repository
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
from app.core.logging_config import get_business_logger
from app.core.config import settings
from app.services import user_service
from os import getenv
# 获取业务逻辑专用日志器
business_logger = get_business_logger()
import os #
from dotenv import load_dotenv
load_dotenv()
def switch_workspace(
db: Session,
workspace_id: uuid.UUID,
user: User,
):
"""切换工作空间"""
business_logger.debug(f"用户 {user.username} 请求切换工作空间为 {workspace_id}")
# 检查用户是否为成员或超级管理员
_check_workspace_member_permission(db, workspace_id, user)
# 更新当前用户的工作空间上下文
try:
user.current_workspace_id = workspace_id
db.commit()
business_logger.info(f"用户 {user.username} 成功切换工作空间为 {workspace_id}")
return
except Exception as e:
db.rollback()
business_logger.error(f"切换工作空间失败 - 工作空间: {workspace_id}, 错误: {str(e)}")
raise BusinessException(f"切换工作空间失败: {str(e)}", BizCode.INTERNAL_ERROR)
def delete_workspace_member(
db: Session,
workspace_id: uuid.UUID,
member_id: uuid.UUID,
user: User,
):
"""删除工作空间成员"""
business_logger.debug(f"用户 {user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
_check_workspace_admin_permission(db, workspace_id, user)
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
if not workspace_member:
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
if workspace_member.workspace_id != workspace_id:
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
try:
workspace_member.is_active = False
workspace_member.user.current_workspace_id = None
db.commit()
business_logger.info(f"用户 {user.username} 成功删除工作空间 {workspace_id} 的成员 {member_id}")
except Exception as e:
db.rollback()
business_logger.error(f"删除工作空间成员失败 - 工作空间: {workspace_id}, 成员: {member_id}, 错误: {str(e)}")
raise BusinessException(f"删除工作空间成员失败: {str(e)}", BizCode.INTERNAL_ERROR)
def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
"""获取当前用户参与的所有工作空间"""
business_logger.debug(f"获取用户工作空间列表: {user.username} (ID: {user.id})")
workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
return workspaces
def _create_workspace_only(
db: Session, workspace: WorkspaceCreate, owner: User
) -> Workspace:
business_logger.debug(f"创建工作空间: {workspace.name}, 创建者: {owner.username}")
try:
# Create the workspace without adding any members
business_logger.debug(f"创建工作空间: {workspace.name}")
db_workspace = workspace_repository.create_workspace(
db=db, workspace=workspace, tenant_id=owner.tenant_id
)
business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {owner.username}")
return db_workspace
except Exception as e:
business_logger.error(f"创建工作空间失败: {workspace.name} - {str(e)}")
raise
def create_workspace(
db: Session, workspace: WorkspaceCreate, user: User
) -> Workspace:
business_logger.info(
f"创建工作空间: {workspace.name}, 创建者: {user.username}, "
f"storage_type: {workspace.storage_type}"
)
llm=workspace.llm
embedding=workspace.embedding
rerank=workspace.rerank
try:
# Create the workspace without adding any members
business_logger.debug(f"创建工作空间: {workspace.name}")
db_workspace = workspace_repository.create_workspace(
db=db, workspace=workspace, tenant_id=user.tenant_id
)
business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}")
db.commit()
db.refresh(db_workspace)
# 如果 storage_type 是 "rag",自动创建知识库
if workspace.storage_type == "rag":
business_logger.info(
f"检测到 storage_type 为 'rag',开始为工作空间 "
f"{db_workspace.id} 创建知识库"
)
try:
import os
from app.schemas.knowledge_schema import KnowledgeCreate
from app.models.knowledge_model import KnowledgeType, PermissionType
from app.repositories import knowledge_repository
# 创建知识库数据
knowledge_data = KnowledgeCreate(
workspace_id=db_workspace.id,
created_by=user.id,
parent_id=db_workspace.id,
name="USER_RAG_MERORY",
description=f"工作空间 {workspace.name} 的默认知识库",
avatar='',
type=KnowledgeType.General,
permission_id=PermissionType.Private,
embedding_id=uuid.UUID(getenv('KB_embedding_id')) if None else embedding,
reranker_id=uuid.UUID(getenv('KB_reranker_id')) if None else rerank,
llm_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
image2text_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
parser_config={
"layout_recognize": "DeepDOC",
"chunk_token_num": 256,
"delimiter": "\n",
"auto_keywords": 0,
"auto_questions": 0,
"html4excel": False
}
)
# 直接使用 repository 创建知识库,避免 service 层的额外逻辑
db_knowledge = knowledge_repository.create_knowledge(
db=db,
knowledge=knowledge_data
)
db.commit()
business_logger.info(
f"为工作空间 {db_workspace.id} 自动创建知识库成功: "
f"{db_knowledge.name} (ID: {db_knowledge.id})"
)
except Exception as kb_error:
business_logger.error(
f"为工作空间 {db_workspace.id} 创建知识库失败: {str(kb_error)}"
)
db.rollback()
raise BusinessException(
f"工作空间创建成功,但知识库创建失败: {str(kb_error)}",
BizCode.INTERNAL_ERROR
)
return db_workspace
except Exception as e:
business_logger.error(f"工作空间创建失败: {workspace.name} - {str(e)}")
db.rollback()
raise
def update_workspace(
db: Session, workspace_id: uuid.UUID, workspace_in: WorkspaceUpdate, user: User
) -> Workspace:
business_logger.info(f"更新工作空间: workspace_id={workspace_id}, 操作者: {user.username}")
db_workspace = _check_workspace_admin_permission(db,workspace_id,user)
try:
# 更新工作空间
business_logger.debug(f"执行工作空间更新: {db_workspace.name} (ID: {workspace_id})")
update_data = workspace_in.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_workspace, field, value)
db.add(db_workspace)
db.commit()
db.refresh(db_workspace)
business_logger.info(f"工作空间更新成功: {db_workspace.name} (ID: {workspace_id})")
return db_workspace
except Exception as e:
business_logger.error(f"工作空间更新失败: workspace_id={workspace_id} - {str(e)}")
db.rollback()
raise
def get_workspace_members(
db: Session, workspace_id: uuid.UUID, user: User
) -> List[WorkspaceMember]:
"""获取某工作空间的成员列表(关系序列化由模型关系支持)"""
business_logger.info(f"获取工作空间成员: workspace_id={workspace_id}, 操作者: {user.username}")
# 查找工作空间
business_logger.debug(f"查找工作空间: {workspace_id}")
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
if not workspace:
business_logger.warning(f"工作空间不存在: {workspace_id}")
raise BusinessException(
message="Workspace not found",
code=BizCode.WORKSPACE_NOT_FOUND
)
# 权限检查:工作空间成员或超级管理员可以查看成员列表
from app.core.permissions import permission_service, Subject, Resource, Action
member = workspace_repository.get_member_in_workspace(
db=db, user_id=user.id, workspace_id=workspace_id
)
workspace_memberships = {workspace_id} if member else set()
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
resource = Resource.from_workspace(workspace)
try:
permission_service.require_permission(
subject,
Action.READ,
resource,
error_message=f"用户 {user.username} 没有查看工作空间 {workspace_id} 成员列表的权限"
)
except PermissionDeniedException as e:
business_logger.warning(
f"权限不足: 用户 {user.username} 尝试获取工作空间 {workspace_id} 成员列表"
)
raise BusinessException(str(e), BizCode.WORKSPACE_ACCESS_DENIED)
# 查询成员并预加载 user/workspace 关系
members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
business_logger.info(f"工作空间成员数量: {len(members)} - workspace_id={workspace_id}")
return members
# ==================== 邀请相关服务方法 ====================
def _generate_invite_token() -> tuple[str, str]:
"""生成邀请令牌和其哈希值
Returns:
tuple: (原始令牌, 令牌哈希)
"""
# 生成32字节的随机令牌
token = secrets.token_urlsafe(32)
# 生成令牌的SHA256哈希
token_hash = hashlib.sha256(token.encode()).hexdigest()
return token, token_hash
def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, user: User) -> Workspace | None:
"""检查用户是否为工作空间成员或超级管理员(使用统一权限服务)"""
# 获取工作空间信息
db_workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
if not db_workspace:
raise BusinessException(
message="Workspace not found",
code=BizCode.WORKSPACE_NOT_FOUND
)
# 使用统一权限服务检查访问权限
from app.core.permissions import permission_service, Subject, Resource, Action
# 获取用户的工作空间成员关系
member = workspace_repository.get_member_in_workspace(
db=db, user_id=user.id, workspace_id=workspace_id
)
# 任何成员都有访问权限
workspace_memberships = {workspace_id} if member else set()
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
resource = Resource.from_workspace(db_workspace)
try:
permission_service.require_permission(
subject,
Action.READ,
resource,
error_message=f"用户 {user.username} 不是工作空间 {workspace_id} 的成员"
)
business_logger.debug(f"用户 {user.username} 是工作空间 {workspace_id} 的成员或超级管理员")
except PermissionDeniedException as e:
business_logger.warning(f"权限不足: 用户 {user.username} 尝试访问工作空间 {workspace_id}")
raise BusinessException(str(e), BizCode.WORKSPACE_NO_ACCESS)
return db_workspace
def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user: User) -> Workspace | None:
"""检查用户是否有工作空间管理员权限(使用统一权限服务)"""
# 获取工作空间信息
db_workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
if not db_workspace:
raise BusinessException(
message="Workspace not found",
code=BizCode.WORKSPACE_NOT_FOUND
)
# 使用统一权限服务检查管理权限
from app.core.permissions import permission_service, Subject, Resource, Action
# 获取用户的工作空间成员关系
member = workspace_repository.get_member_in_workspace(
db=db, user_id=user.id, workspace_id=workspace_id
)
# 只有 manager 才有管理权限
workspace_memberships = {workspace_id} if (member and member.role == WorkspaceRole.manager) else set()
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
resource = Resource.from_workspace(db_workspace)
try:
permission_service.require_permission(
subject,
Action.MANAGE,
resource,
error_message=f"用户 {user.username} 没有管理工作空间 {workspace_id} 的权限"
)
business_logger.debug(f"用户 {user.username} 有权限管理工作空间 {workspace_id}")
except PermissionDeniedException as e:
business_logger.warning(f"权限不足: 用户 {user.username} 尝试管理工作空间 {workspace_id}")
raise BusinessException(str(e), BizCode.WORKSPACE_ACCESS_DENIED)
return db_workspace
def create_workspace_invite(
db: Session,
workspace_id: uuid.UUID,
invite_data: WorkspaceInviteCreate,
user: User
) -> WorkspaceInviteResponse:
"""创建工作空间邀请"""
business_logger.info(f"创建工作空间邀请: workspace_id={workspace_id}, email={invite_data.email}, 创建者: {user.username}")
try:
# 检查权限
_check_workspace_admin_permission(db, workspace_id, user)
if settings.ENABLE_SINGLE_WORKSPACE:
# 检查被邀请用户是否已经在工作空间中
from app.repositories import user_repository
invited_user = user_repository.get_user_by_email(db, invite_data.email)
if invited_user:
# 用户存在,检查是否已经是工作空间成员
existing_member = workspace_repository.get_member_in_workspace(
db=db,
user_id=invited_user.id,
workspace_id=workspace_id
)
if existing_member:
business_logger.warning(f"用户 {invite_data.email} 已经是工作空间成员")
raise BusinessException("该用户已经是工作空间成员", BizCode.RESOURCE_ALREADY_EXISTS)
# 检查是否已有待处理的邀请
invite_repo = WorkspaceInviteRepository(db)
existing_invite = invite_repo.get_pending_invite_by_email_and_workspace(
email=invite_data.email,
workspace_id=workspace_id
)
invite_token = None
if existing_invite:
business_logger.info(f"邮箱 {invite_data.email} 在工作空间 {workspace_id} 已有待处理邀请,返回现有邀请")
# 生成新的邀请链接(重新生成令牌)
token, token_hash = _generate_invite_token()
existing_invite.token_hash = token_hash
existing_invite.updated_at = datetime.datetime.now()
db.commit()
db.refresh(existing_invite)
invite_token = token
else:
# 生成邀请令牌
token, token_hash = _generate_invite_token()
# 创建邀请
db_invite = invite_repo.create_invite(
workspace_id=workspace_id,
invite_data=invite_data,
token_hash=token_hash,
created_by_user_id=user.id
)
db.commit()
db.refresh(db_invite)
invite_token = token
invite_obj = existing_invite or db_invite
business_logger.info(f"工作空间邀请创建成功: invite_id={invite_obj.id}, email={invite_data.email}")
# 构造响应
response = WorkspaceInviteResponse.model_validate(invite_obj)
response.invite_token = invite_token
return response
except Exception as e:
db.rollback()
business_logger.error(f"创建工作空间邀请失败: workspace_id={workspace_id}, email={invite_data.email} - {str(e)}")
raise
def get_workspace_invites(
db: Session,
workspace_id: uuid.UUID,
user: User,
status: Optional[InviteStatus] = None,
limit: int = 50,
offset: int = 0
) -> List[WorkspaceInviteResponse]:
"""获取工作空间邀请列表"""
business_logger.info(f"获取工作空间邀请列表: workspace_id={workspace_id}, 操作者: {user.username}")
# 检查工作空间是否存在
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
if not workspace:
raise BusinessException("工作空间不存在", BizCode.WORKSPACE_NOT_FOUND)
# 检查权限
_check_workspace_admin_permission(db, workspace_id, user)
# 获取邀请列表
invite_repo = WorkspaceInviteRepository(db)
invites = invite_repo.get_workspace_invites(
workspace_id=workspace_id,
status=status,
limit=limit,
offset=offset
)
return [WorkspaceInviteResponse.model_validate(invite) for invite in invites]
def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
"""验证邀请令牌"""
business_logger.info(f"验证邀请令牌")
# 生成令牌哈希
token_hash = hashlib.sha256(token.encode()).hexdigest()
# 查找邀请
invite_repo = WorkspaceInviteRepository(db)
invite = invite_repo.get_invite_by_token_hash(token_hash)
if not invite:
business_logger.warning(f"邀请令牌无效")
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
# 检查邀请状态和过期时间
now = datetime.datetime.now()
is_expired = invite.expires_at < now or invite.status != InviteStatus.pending
is_valid = not is_expired
# 获取工作空间信息
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
business_logger.info(f"邀请令牌验证完成: valid={is_valid}, expired={is_expired}")
return InviteValidateResponse(
workspace_name=workspace.name,
workspace_id=invite.workspace_id,
email=invite.email,
role=WorkspaceRole(invite.role),
is_expired=is_expired,
is_valid=is_valid
)
def accept_workspace_invite(
db: Session,
accept_request: InviteAcceptRequest,
user: User
) -> dict:
"""接受工作空间邀请"""
business_logger.info(f"接受工作空间邀请: 用户 {user.username}")
try:
from app.core.config import settings
# 生成令牌哈希
token_hash = hashlib.sha256(accept_request.token.encode()).hexdigest()
# 查找邀请
invite_repo = WorkspaceInviteRepository(db)
invite = invite_repo.get_invite_by_token_hash(token_hash)
if not invite:
business_logger.warning(f"邀请令牌无效")
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
# 检查邀请状态
if invite.status != InviteStatus.pending:
business_logger.warning(f"邀请已被处理: status={invite.status}")
raise BusinessException(f"邀请已被{invite.status}", BizCode.WORKSPACE_INVITE_INVALID)
# 检查过期时间
now = datetime.datetime.now()
if invite.expires_at < now:
business_logger.warning(f"邀请已过期")
# 标记为过期
invite_repo.update_invite_status(invite.id, InviteStatus.expired)
raise BusinessException("邀请已过期", BizCode.WORKSPACE_INVITE_EXPIRED)
# 检查邮箱是否匹配
if invite.email != user.email:
business_logger.warning(f"邮箱不匹配: invite_email={invite.email}, user_email={user.email}")
raise BusinessException("邮箱与邀请邮箱不匹配", BizCode.FORBIDDEN)
# 如果启用单工作空间模式,检查用户是否已有工作空间
if settings.ENABLE_SINGLE_WORKSPACE:
user_workspaces = workspace_repository.get_workspaces_by_user(db=db, user_id=user.id)
if user_workspaces:
business_logger.warning(f"单工作空间模式下用户已有工作空间: user={user.username}")
raise BusinessException("用户只能加入一个工作空间", BizCode.FORBIDDEN)
# 检查用户是否已经是工作空间成员
existing_member = workspace_repository.get_member_in_workspace(
db=db,
user_id=user.id,
workspace_id=invite.workspace_id
)
if existing_member:
business_logger.info(f"用户已是工作空间成员,更新邀请状态")
invite_repo.update_invite_status(
invite.id,
InviteStatus.accepted,
accepted_at=now
)
db.commit()
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
return {
"message": "You are already a member of this workspace",
"workspace": workspace
}
# 将角色映射到工作空间角色(现在直接使用相同的角色)
workspace_role = invite.role
# 添加用户到工作空间
workspace_repository.add_member_to_workspace(
db=db,
user_id=user.id,
workspace_id=invite.workspace_id,
role=workspace_role
)
# 标记邀请为已接受
invite_repo.update_invite_status(
invite.id,
InviteStatus.accepted,
accepted_at=now
)
db.commit()
# 获取工作空间信息
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=invite.workspace_id)
business_logger.info(f"用户成功加入工作空间: user={user.username}, workspace={workspace.name}, role={workspace_role}")
return {
"message": "Successfully joined the workspace",
"workspace": workspace,
"role": workspace_role
}
except Exception as e:
db.rollback()
business_logger.error(f"接受工作空间邀请失败: user={user.username} - {str(e)}")
raise
def revoke_workspace_invite(
db: Session,
workspace_id: uuid.UUID,
invite_id: uuid.UUID,
user: User
) -> dict:
"""撤销工作空间邀请"""
business_logger.info(f"撤销工作空间邀请: workspace_id={workspace_id}, invite_id={invite_id}, 操作者: {user.username}")
try:
# 检查权限
_check_workspace_admin_permission(db, workspace_id, user)
# 撤销邀请
invite_repo = WorkspaceInviteRepository(db)
invite = invite_repo.revoke_invite(invite_id)
if not invite:
business_logger.warning(f"邀请不存在: invite_id={invite_id}")
raise BusinessException("邀请不存在", BizCode.WORKSPACE_INVITE_NOT_FOUND)
if invite.workspace_id != workspace_id:
business_logger.warning(f"邀请不属于指定工作空间: invite_id={invite_id}, workspace_id={workspace_id}")
raise BusinessException("邀请不属于指定工作空间", BizCode.BAD_REQUEST)
db.commit()
business_logger.info(f"工作空间邀请撤销成功: invite_id={invite_id}")
return {"message": "邀请撤销成功"}
except Exception as e:
db.rollback()
business_logger.error(f"撤销工作空间邀请失败: invite_id={invite_id} - {str(e)}")
raise
def update_workspace_member_roles(
db: Session,
workspace_id: uuid.UUID,
updates: List[WorkspaceMemberUpdate],
user: User,
) -> List[WorkspaceMember]:
"""更新工作空间成员角色"""
business_logger.info(f"更新工作空间成员角色: workspace_id={workspace_id}, 操作者: {user.username}, 更新数量: {len(updates)}")
# 检查管理员权限
_check_workspace_admin_permission(db, workspace_id, user)
# 获取所有当前成员
all_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
member_map = {m.id: m for m in all_members}
# 验证和业务规则检查
update_ids = set()
for upd in updates:
# 检查成员是否存在
if upd.id not in member_map:
raise BusinessException(f"成员 {upd.id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
member = member_map[upd.id]
# 检查成员是否属于该工作空间
if member.workspace_id != workspace_id:
raise BusinessException(f"成员 {upd.id} 不属于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
# 不能修改自己的角色
if member.user_id == user.id:
raise BusinessException("不能修改自己的角色", BizCode.BAD_REQUEST)
update_ids.add(upd.id)
# 检查是否至少保留一个 manager
current_managers = [m for m in all_members if m.role == WorkspaceRole.manager]
managers_after_update = [
m for m in all_members
if m.id not in update_ids and m.role == WorkspaceRole.manager
]
# 添加更新后会成为 manager 的成员
for upd in updates:
if upd.role == WorkspaceRole.manager:
managers_after_update.append(member_map[upd.id])
if len(managers_after_update) == 0:
raise BusinessException("工作空间至少需要一个管理员", BizCode.BAD_REQUEST)
# 执行更新
try:
for upd in updates:
workspace_repository.update_member_role_by_id(
db=db,
id=upd.id,
role=upd.role,
)
business_logger.debug(f"更新成员 {upd.id} 角色为 {upd.role}")
db.commit()
# 重新获取更新后的成员列表
updated_members = workspace_repository.get_members_by_workspace(db=db, workspace_id=workspace_id)
business_logger.info(f"成员角色更新完成: workspace_id={workspace_id}, 更新数量={len(updates)}")
return updated_members
except Exception as e:
db.rollback()
business_logger.error(f"更新工作空间成员角色失败: workspace_id={workspace_id} - {str(e)}")
raise BusinessException(f"更新成员角色失败: {str(e)}", BizCode.INTERNAL_ERROR)
def get_workspace_storage_type(
db: Session,
workspace_id: uuid.UUID,
user: User,
) -> Optional[str]:
"""获取工作空间的存储类型
Args:
db: 数据库会话
workspace_id: 工作空间ID
user: 当前用户
Returns:
storage_type: 存储类型字符串,如果未设置则返回 None
"""
business_logger.info(f"用户 {user.username} 请求获取工作空间 {workspace_id} 的存储类型")
# 检查用户是否有权限访问该工作空间
_check_workspace_member_permission(db, workspace_id, user)
# 查询工作空间
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
if not workspace:
business_logger.error(f"工作空间不存在: workspace_id={workspace_id}")
raise BusinessException(
code=BizCode.WORKSPACE_NOT_FOUND,
message="工作空间不存在"
)
business_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {workspace.storage_type}")
return workspace.storage_type
def get_workspace_models_configs(
db: Session,
workspace_id: uuid.UUID,
user: User,
) -> Optional[dict]:
"""获取工作空间的模型配置llm, embedding, rerank
Args:
db: 数据库会话
workspace_id: 工作空间ID
user: 当前用户
Returns:
dict: 包含 llm, embedding, rerank 的字典,如果工作空间不存在则返回 None
"""
business_logger.info(f"用户 {user.username} 请求获取工作空间 {workspace_id} 的模型配置")
# 检查用户是否有权限访问该工作空间
_check_workspace_member_permission(db, workspace_id, user)
# 查询工作空间模型配置
configs = workspace_repository.get_workspace_models_configs(db=db, workspace_id=workspace_id)
if configs is None:
business_logger.error(f"工作空间不存在: workspace_id={workspace_id}")
raise BusinessException(
code=BizCode.WORKSPACE_NOT_FOUND,
message="工作空间不存在"
)
business_logger.info(
f"成功获取工作空间 {workspace_id} 的模型配置: "
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
)
return configs