[MODIFY] Code optimization
This commit is contained in:
@@ -27,7 +27,7 @@ class AgentRegistry:
|
||||
self._cache[str(agent.id)] = agent_info
|
||||
|
||||
logger.info(
|
||||
f"Agent 注册成功",
|
||||
"Agent 注册成功",
|
||||
extra={
|
||||
"agent_id": str(agent.id),
|
||||
"name": agent.app.name,
|
||||
@@ -92,7 +92,7 @@ class AgentRegistry:
|
||||
agents = self.db.scalars(stmt).all()
|
||||
|
||||
logger.debug(
|
||||
f"Agent 发现",
|
||||
"Agent 发现",
|
||||
extra={
|
||||
"query": query,
|
||||
"domain": domain,
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.agents import create_agent, AgentState
|
||||
from langchain.agents.middleware import before_model
|
||||
from langchain.tools import tool
|
||||
from langchain_core.messages import RemoveMessage
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.services.api_resquests_server import send_message, model, retrieval
|
||||
|
||||
|
||||
class config(BaseModel):
|
||||
template_str:str
|
||||
params:dict
|
||||
model_configs: List[dict] = []
|
||||
history_memory:bool
|
||||
knowledge_base:bool
|
||||
|
||||
class RemoryInput(BaseModel):
|
||||
question: str
|
||||
end_user_id: str
|
||||
search_switch:str
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
end_user_id: str
|
||||
message: str
|
||||
search_switch:str
|
||||
kb_ids: List[str] = []
|
||||
similarity_threshold:float
|
||||
vector_similarity_weight:float
|
||||
top_k:int
|
||||
hybrid:bool
|
||||
token:str
|
||||
|
||||
class RetrievalInput(BaseModel):
|
||||
message: str
|
||||
kb_ids: List[str] = []
|
||||
similarity_threshold: float
|
||||
vector_similarity_weight: float
|
||||
top_k: int
|
||||
hybrid: bool
|
||||
token: str
|
||||
|
||||
async def tool_Retrieval(req):
|
||||
tool_result = retrieval_search.invoke({
|
||||
"message":req.message, "kb_ids":req.kb_ids,
|
||||
"similarity_threshold":req.similarity_threshold, "vector_similarity_weight":req.vector_similarity_weight,
|
||||
"top_k":req.top_k, "hybrid":req.hybrid, "token":req.token
|
||||
})
|
||||
return tool_result
|
||||
async def tool_memory(req):
|
||||
tool_result = remory_sk.invoke({
|
||||
"question": req.message,
|
||||
"end_user_id": req.end_user_id,
|
||||
"search_switch": req.search_switch
|
||||
})
|
||||
return tool_result
|
||||
|
||||
|
||||
@before_model
|
||||
# ========== 消息剪枝中间件 ==========
|
||||
def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
"""保留前1条 + 最近3~4条消息"""
|
||||
messages = state["messages"]
|
||||
if len(messages) <= 10:
|
||||
return None
|
||||
first_msg = messages[0]
|
||||
recent_messages = messages[-10:] if len(messages) % 2 == 0 else messages[-11:]
|
||||
new_messages = [first_msg] + recent_messages
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
RemoveMessage(id=REMOVE_ALL_MESSAGES),
|
||||
*new_messages
|
||||
]
|
||||
}
|
||||
|
||||
##-----------历史记忆------------
|
||||
@ tool(args_schema=RemoryInput)
|
||||
def remory_sk(question: str, end_user_id: str, search_switch: str):
|
||||
"""
|
||||
条件调用工具:
|
||||
- 仅当 question 是疑问句时调用 send_message
|
||||
- 根据 end_user_id 和 search_switch 参数决定是否执行检索
|
||||
|
||||
Args:
|
||||
question: 用户的提问内容
|
||||
end_user_id: 用户唯一标识符
|
||||
search_switch: 搜索开关,控制是否执行检索
|
||||
|
||||
Returns:
|
||||
检索结果或空字符串
|
||||
"""
|
||||
# 移除关于 configurable 的描述,避免混淆
|
||||
if not end_user_id or end_user_id == '123':
|
||||
print("警告: 无效的 user_id 参数")
|
||||
return ''
|
||||
|
||||
if search_switch in ['on', 'off'] or not search_switch:
|
||||
print("警告: 无效的 search_switch 参数")
|
||||
return ''
|
||||
return send_message(end_user_id, question, '[]', search_switch)
|
||||
|
||||
#-------------检索------------
|
||||
|
||||
|
||||
@ tool(args_schema=RetrievalInput)
|
||||
def retrieval_search(message,kb_ids,similarity_threshold,vector_similarity_weight,top_k,hybrid,token):
|
||||
'''检索'''
|
||||
search=retrieval(message,kb_ids,similarity_threshold,vector_similarity_weight,top_k,hybrid,token)
|
||||
return search
|
||||
async def create_dynamic_agent(model_name: str,model_id:str,PROMPT:str,token:str):
|
||||
"""根据模型名动态创建代理"""
|
||||
model_name, api_key, api_base=await model(model_id,token)
|
||||
llm = ChatOpenAI(model=model_name, base_url=api_base, temperature=0.2,api_key=api_key)
|
||||
memory = InMemorySaver()
|
||||
return create_agent(
|
||||
llm,
|
||||
tools=[remory_sk,retrieval_search],
|
||||
middleware=[trim_messages],
|
||||
checkpointer=memory,
|
||||
system_prompt=PROMPT
|
||||
)
|
||||
@@ -80,7 +80,7 @@ def create_agent_discovery_tool(registry: AgentRegistry, workspace_id: uuid.UUID
|
||||
result += "\n"
|
||||
|
||||
logger.info(
|
||||
f"Agent 发现成功",
|
||||
"Agent 发现成功",
|
||||
extra={
|
||||
"query": query,
|
||||
"domain": domain,
|
||||
@@ -91,7 +91,7 @@ def create_agent_discovery_tool(registry: AgentRegistry, workspace_id: uuid.UUID
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent 发现失败", extra={"error": str(e)})
|
||||
logger.error("Agent 发现失败", extra={"error": str(e)})
|
||||
return f"发现 Agent 失败: {str(e)}"
|
||||
|
||||
return discover_agents
|
||||
@@ -138,7 +138,7 @@ def create_agent_invocation_tool(
|
||||
if workspace and workspace.storage_type:
|
||||
storage_type = workspace.storage_type
|
||||
logger.debug(
|
||||
f"获取工作空间存储类型成功",
|
||||
"获取工作空间存储类型成功",
|
||||
extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"storage_type": storage_type
|
||||
@@ -146,7 +146,7 @@ def create_agent_invocation_tool(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"获取工作空间存储类型失败,使用默认值 neo4j",
|
||||
"获取工作空间存储类型失败,使用默认值 neo4j",
|
||||
extra={"workspace_id": str(workspace_id), "error": str(e)}
|
||||
)
|
||||
|
||||
@@ -161,7 +161,7 @@ def create_agent_invocation_tool(
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
logger.debug(
|
||||
f"获取 RAG 知识库成功",
|
||||
"获取 RAG 知识库成功",
|
||||
extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"knowledge_id": user_rag_memory_id
|
||||
@@ -169,13 +169,13 @@ def create_agent_invocation_tool(
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"未找到名为 'USER_RAG_MEMORY' 的知识库,将使用 neo4j 存储",
|
||||
"未找到名为 'USER_RAG_MEMORY' 的知识库,将使用 neo4j 存储",
|
||||
extra={"workspace_id": str(workspace_id)}
|
||||
)
|
||||
storage_type = 'neo4j'
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"获取 RAG 知识库失败,将使用 neo4j 存储",
|
||||
"获取 RAG 知识库失败,将使用 neo4j 存储",
|
||||
extra={"workspace_id": str(workspace_id), "error": str(e)}
|
||||
)
|
||||
storage_type = 'neo4j'
|
||||
@@ -226,12 +226,12 @@ def create_agent_invocation_tool(
|
||||
# 6. 获取 Agent 配置
|
||||
agent_config = db.get(AgentConfig, agent_uuid)
|
||||
if not agent_config:
|
||||
return f"Agent 配置不存在"
|
||||
return "Agent 配置不存在"
|
||||
|
||||
# 7. 获取模型配置
|
||||
model_config = db.get(ModelConfig, agent_config.default_model_config_id)
|
||||
if not model_config:
|
||||
return f"Agent 模型配置不存在"
|
||||
return "Agent 模型配置不存在"
|
||||
|
||||
# 8. 创建调用记录
|
||||
invocation = AgentInvocation(
|
||||
@@ -249,7 +249,7 @@ def create_agent_invocation_tool(
|
||||
db.refresh(invocation)
|
||||
|
||||
logger.info(
|
||||
f"Agent 调用开始",
|
||||
"Agent 调用开始",
|
||||
extra={
|
||||
"invocation_id": str(invocation.id),
|
||||
"caller_agent_id": str(current_agent_id),
|
||||
@@ -286,7 +286,7 @@ def create_agent_invocation_tool(
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Agent 调用成功",
|
||||
"Agent 调用成功",
|
||||
extra={
|
||||
"invocation_id": str(invocation.id),
|
||||
"caller_agent_id": str(current_agent_id),
|
||||
@@ -306,7 +306,7 @@ def create_agent_invocation_tool(
|
||||
db.commit()
|
||||
|
||||
logger.error(
|
||||
f"Agent 调用失败",
|
||||
"Agent 调用失败",
|
||||
extra={
|
||||
"invocation_id": str(invocation.id),
|
||||
"caller_agent_id": str(current_agent_id),
|
||||
@@ -319,7 +319,7 @@ def create_agent_invocation_tool(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Agent 调用异常",
|
||||
"Agent 调用异常",
|
||||
extra={
|
||||
"caller_agent_id": str(current_agent_id),
|
||||
"callee_agent_id": agent_id,
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
"""API Key Service"""
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional, Tuple, List
|
||||
import time
|
||||
import uuid
|
||||
import datetime
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.models.api_key_model import ApiKey, ApiKeyType
|
||||
from app.repositories.api_key_repository import ApiKeyRepository
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
from app.models.api_key_model import ApiKey
|
||||
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
|
||||
from app.schemas import api_key_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.core.api_key_utils import generate_api_key
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.api_key_utils import generate_api_key, hash_api_key, validate_resource_binding
|
||||
from app.core.exceptions import (
|
||||
BusinessException,
|
||||
)
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
@@ -19,81 +25,108 @@ logger = get_business_logger()
|
||||
|
||||
class ApiKeyService:
|
||||
"""API Key 业务逻辑服务"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def create_api_key(
|
||||
db: Session,
|
||||
*,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyCreate
|
||||
db: Session,
|
||||
*,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyCreate
|
||||
) -> Tuple[ApiKey, str]:
|
||||
"""创建 API Key
|
||||
|
||||
"""
|
||||
创建 API Key
|
||||
Returns:
|
||||
Tuple[ApiKey, str]: (API Key 对象, API Key 明文)
|
||||
"""
|
||||
# 生成 API Key
|
||||
api_key, key_hash, key_prefix = generate_api_key(data.type)
|
||||
|
||||
# 创建数据
|
||||
api_key_data = {
|
||||
"id": uuid.uuid4(),
|
||||
"name": data.name,
|
||||
"description": data.description,
|
||||
"key_prefix": key_prefix,
|
||||
"key_hash": key_hash,
|
||||
"type": data.type,
|
||||
"scopes": data.scopes,
|
||||
"workspace_id": workspace_id,
|
||||
"resource_id": data.resource_id,
|
||||
"resource_type": data.resource_type,
|
||||
"rate_limit": data.rate_limit,
|
||||
"quota_limit": data.quota_limit,
|
||||
"expires_at": data.expires_at,
|
||||
"created_by": user_id,
|
||||
"created_at": datetime.datetime.now(),
|
||||
"updated_at": datetime.datetime.now(),
|
||||
}
|
||||
|
||||
api_key_obj = ApiKeyRepository.create(db, api_key_data)
|
||||
db.commit()
|
||||
db.refresh(api_key_obj)
|
||||
|
||||
logger.info(f"API Key 创建成功", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
"name": data.name,
|
||||
"type": data.type
|
||||
})
|
||||
|
||||
return api_key_obj, api_key
|
||||
|
||||
try:
|
||||
# 验证资源绑定
|
||||
if data.resource_type or data.resource_id:
|
||||
is_valid, error_msg = validate_resource_binding(
|
||||
data.resource_type, str(data.resource_id) if data.resource_id else None
|
||||
)
|
||||
if not is_valid:
|
||||
raise BusinessException(error_msg, BizCode.API_KEY_INVALID_RESOURCE)
|
||||
|
||||
existing = db.scalar(
|
||||
select(ApiKey).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.name == data.name,
|
||||
ApiKey.is_active
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||
|
||||
# 生成 API Key
|
||||
api_key, key_hash, key_prefix = generate_api_key(data.type)
|
||||
|
||||
# 创建数据
|
||||
api_key_data = {
|
||||
"id": uuid.uuid4(),
|
||||
"name": data.name,
|
||||
"description": data.description,
|
||||
"key_prefix": key_prefix,
|
||||
"key_hash": key_hash,
|
||||
"type": data.type,
|
||||
"scopes": data.scopes,
|
||||
"workspace_id": workspace_id,
|
||||
"resource_id": data.resource_id,
|
||||
"resource_type": data.resource_type,
|
||||
"rate_limit": data.rate_limit or 10,
|
||||
"daily_request_limit": data.daily_request_limit or 10000,
|
||||
"quota_limit": data.quota_limit,
|
||||
"expires_at": data.expires_at,
|
||||
"created_by": user_id,
|
||||
}
|
||||
|
||||
api_key_obj = ApiKeyRepository.create(db, api_key_data)
|
||||
db.commit()
|
||||
|
||||
logger.info("API Key 创建成功", extra={
|
||||
"api_key_id": str(api_key_obj.id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"api_key_name": data.name,
|
||||
"type": data.type
|
||||
})
|
||||
|
||||
return api_key_obj, api_key
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"API Key 创建失败: {e}", extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"api_key_name": getattr(data, 'name', 'unknown'),
|
||||
"error": str(e)
|
||||
})
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> ApiKey:
|
||||
"""获取 API Key"""
|
||||
api_key = ApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException("API Key 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
|
||||
|
||||
if api_key.workspace_id != workspace_id:
|
||||
raise BusinessException("无权访问此 API Key", BizCode.FORBIDDEN)
|
||||
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_api_keys(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
query: api_key_schema.ApiKeyQuery
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
query: api_key_schema.ApiKeyQuery
|
||||
) -> PageData:
|
||||
"""列出 API Keys"""
|
||||
items, total = ApiKeyRepository.list_by_workspace(db, workspace_id, query)
|
||||
pages = math.ceil(total / query.pagesize) if total > 0 else 0
|
||||
|
||||
|
||||
return PageData(
|
||||
page=PageMeta(
|
||||
page=query.page,
|
||||
@@ -103,52 +136,69 @@ class ApiKeyService:
|
||||
),
|
||||
items=[api_key_schema.ApiKey.model_validate(item) for item in items]
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def update_api_key(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyUpdate
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyUpdate
|
||||
) -> ApiKey:
|
||||
"""更新 API Key"""
|
||||
"""更新 API Key配置"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
|
||||
# 检查名称重复
|
||||
if data.name and data.name != api_key.name:
|
||||
existing = db.scalar(
|
||||
select(ApiKey).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.name == data.name,
|
||||
ApiKey.is_active,
|
||||
ApiKey.id != api_key_id
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
ApiKeyRepository.update(db, api_key_id, update_data)
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
logger.info(f"API Key 更新成功", extra={"api_key_id": str(api_key_id)})
|
||||
|
||||
logger.info("API Key 更新成功", extra={"api_key_id": str(api_key_id)})
|
||||
return api_key
|
||||
|
||||
|
||||
@staticmethod
|
||||
def delete_api_key(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> bool:
|
||||
"""删除 API Key"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
|
||||
ApiKeyRepository.delete(db, api_key_id)
|
||||
db.commit()
|
||||
|
||||
logger.info(f"API Key 删除成功", extra={"api_key_id": str(api_key_id)})
|
||||
|
||||
logger.info("API Key 删除成功", extra={"api_key_id": str(api_key_id)})
|
||||
return True
|
||||
|
||||
|
||||
@staticmethod
|
||||
def regenerate_api_key(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> Tuple[ApiKey, str]:
|
||||
"""重新生成 API Key"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
|
||||
# 检查 API Key 是否激活
|
||||
if not api_key.is_active:
|
||||
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
|
||||
|
||||
# 生成新的 API Key
|
||||
new_api_key, key_hash, key_prefix = generate_api_key(ApiKeyType(api_key.type))
|
||||
|
||||
new_api_key, key_hash, key_prefix = generate_api_key(api_key_schema.ApiKeyType(api_key.type))
|
||||
|
||||
# 更新
|
||||
ApiKeyRepository.update(db, api_key_id, {
|
||||
"key_hash": key_hash,
|
||||
@@ -156,18 +206,201 @@ class ApiKeyService:
|
||||
})
|
||||
db.commit()
|
||||
db.refresh(api_key)
|
||||
|
||||
logger.info(f"API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
|
||||
|
||||
logger.info("API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
|
||||
return api_key, new_api_key
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_stats(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> api_key_schema.ApiKeyStats:
|
||||
"""获取使用统计"""
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
|
||||
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
|
||||
return api_key_schema.ApiKeyStats(**stats_data)
|
||||
|
||||
@staticmethod
|
||||
def get_logs(
|
||||
db: Session,
|
||||
api_key_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
filters: dict,
|
||||
page: int,
|
||||
pagesize: int
|
||||
) -> PageData:
|
||||
"""获取 API Key 使用日志"""
|
||||
# 验证 API Key 权限
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
items, total = ApiKeyLogRepository.list_by_api_key(
|
||||
db, api_key_id, filters, page, pagesize
|
||||
)
|
||||
|
||||
# 计算分页信息
|
||||
pages = math.ceil(total / pagesize) if total > 0 else 0
|
||||
|
||||
return PageData(
|
||||
page=PageMeta(
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=page < pages
|
||||
),
|
||||
items=[api_key_schema.ApiKeyLog.model_validate(item) for item in items]
|
||||
)
|
||||
|
||||
|
||||
class RateLimiterService:
|
||||
def __init__(self):
|
||||
self.redis = aio_redis
|
||||
|
||||
async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
|
||||
"""
|
||||
检查QPS限制
|
||||
Returns:
|
||||
(is_allowed, rate_limit_info)
|
||||
"""
|
||||
key = f"rate_limit:qps:{api_key_id}"
|
||||
async with self.redis.pipeline() as pipe:
|
||||
pipe.incr(key)
|
||||
pipe.expire(key, 1) # 1 秒过期
|
||||
results = await pipe.execute()
|
||||
|
||||
current = results[0]
|
||||
remaining = max(0, limit - current)
|
||||
reset_time = int(time.time()) + 1
|
||||
|
||||
return current <= limit, {
|
||||
"limit": limit,
|
||||
"remaining": remaining,
|
||||
"reset": reset_time
|
||||
}
|
||||
|
||||
async def check_daily_requests(
|
||||
self,
|
||||
api_key_id: uuid.UUID,
|
||||
limit: int
|
||||
) -> Tuple[bool, dict]:
|
||||
"""检查日调用量限制"""
|
||||
today = datetime.now().strftime("%Y%m%d")
|
||||
key = f"rate_limit:daily:{api_key_id}:{today}"
|
||||
|
||||
now = datetime.now()
|
||||
tomorrow_0 = (now + timedelta(days=1)).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
expire_seconds = int((tomorrow_0 - now).total_seconds())
|
||||
|
||||
async with self.redis.pipeline() as pipe:
|
||||
pipe.incr(key)
|
||||
pipe.expire(key, expire_seconds, nx=True)
|
||||
results = await pipe.execute()
|
||||
|
||||
current = results[0]
|
||||
remaining = max(0, limit - current)
|
||||
reset_time = int(tomorrow_0.timestamp())
|
||||
|
||||
return current <= limit, {
|
||||
"limit": limit,
|
||||
"remaining": remaining,
|
||||
"reset": reset_time
|
||||
}
|
||||
|
||||
async def check_all_limits(
|
||||
self,
|
||||
api_key: ApiKey
|
||||
) -> Tuple[bool, str, dict]:
|
||||
"""
|
||||
检查所有限制
|
||||
Returns:
|
||||
(is_allowed, error_message, rate_limit_headers)
|
||||
"""
|
||||
# Check QPS
|
||||
qps_ok, qps_info = await self.check_qps(
|
||||
api_key.id,
|
||||
api_key.rate_limit
|
||||
)
|
||||
if not qps_ok:
|
||||
return False, "QPS limit exceeded", {
|
||||
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
|
||||
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
|
||||
"X-RateLimit-Reset": str(qps_info["reset"])
|
||||
}
|
||||
|
||||
# Check daily requests
|
||||
daily_ok, daily_info = await self.check_daily_requests(
|
||||
api_key.id,
|
||||
api_key.daily_request_limit
|
||||
)
|
||||
if not daily_ok:
|
||||
return False, "Daily request limit exceeded", {
|
||||
"X-RateLimit-Limit-Day": str(daily_info["limit"]),
|
||||
"X-RateLimit-Remaining-Day": str(daily_info["remaining"]),
|
||||
"X-RateLimit-Reset": str(daily_info["reset"])
|
||||
}
|
||||
|
||||
# All checks passed
|
||||
headers = {
|
||||
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
|
||||
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
|
||||
"X-RateLimit-Limit-Day": str(daily_info["limit"]),
|
||||
"X-RateLimit-Remaining-Day": str(daily_info["remaining"]),
|
||||
"X-RateLimit-Reset": str(daily_info["reset"])
|
||||
}
|
||||
return True, "", headers
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def validate_api_key(
|
||||
db: Session,
|
||||
api_key: str
|
||||
) -> Optional[ApiKey]:
|
||||
"""
|
||||
验证API Key 有效性
|
||||
|
||||
检查:
|
||||
1. Key hash 是否存在
|
||||
2. is_active 是否为true
|
||||
3. expires_at 是否未过期
|
||||
4. quota 是否未超限
|
||||
"""
|
||||
key_hash = hash_api_key(api_key)
|
||||
api_key_obj = ApiKeyRepository.get_by_hash(db, key_hash)
|
||||
|
||||
if not api_key_obj:
|
||||
return None
|
||||
|
||||
if not api_key_obj.is_active:
|
||||
return None
|
||||
|
||||
if api_key_obj.expires_at and datetime.now() > api_key_obj.expires_at:
|
||||
return None
|
||||
|
||||
if api_key_obj.quota_limit and api_key_obj.quota_used >= api_key_obj.quota_limit:
|
||||
return None
|
||||
|
||||
return api_key_obj
|
||||
|
||||
@staticmethod
|
||||
def check_scope(api_key: ApiKey, required_scope: str) -> bool:
|
||||
"""检查权限范围"""
|
||||
return required_scope in api_key.scopes
|
||||
|
||||
@staticmethod
|
||||
def check_resource(
|
||||
api_key: ApiKey,
|
||||
resource_type: str,
|
||||
resource_id: uuid.UUID
|
||||
) -> bool:
|
||||
"""检查资源绑定"""
|
||||
if not api_key.resource_id:
|
||||
return True
|
||||
|
||||
return (
|
||||
api_key.resource_type == resource_type and
|
||||
api_key.resource_id == resource_id
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ class AppService:
|
||||
"""
|
||||
if workspace_id is not None and app.workspace_id != workspace_id:
|
||||
logger.warning(
|
||||
f"工作空间访问被拒",
|
||||
"工作空间访问被拒",
|
||||
extra={"app_id": str(app.id), "workspace_id": str(workspace_id)}
|
||||
)
|
||||
raise BusinessException("应用不在指定工作空间中", BizCode.WORKSPACE_NO_ACCESS)
|
||||
@@ -103,7 +103,7 @@ class AppService:
|
||||
"""
|
||||
if not self._check_app_accessible(app, workspace_id):
|
||||
logger.warning(
|
||||
f"应用访问被拒",
|
||||
"应用访问被拒",
|
||||
extra={"app_id": str(app.id), "workspace_id": str(workspace_id)}
|
||||
)
|
||||
raise BusinessException("应用不可访问", BizCode.WORKSPACE_NO_ACCESS)
|
||||
@@ -122,7 +122,7 @@ class AppService:
|
||||
"""
|
||||
app = self.db.get(App, app_id)
|
||||
if not app:
|
||||
logger.warning(f"应用不存在", extra={"app_id": str(app_id)})
|
||||
logger.warning("应用不存在", extra={"app_id": str(app_id)})
|
||||
raise ResourceNotFoundException("应用", str(app_id))
|
||||
return app
|
||||
|
||||
@@ -257,7 +257,7 @@ class AppService:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"多智能体配置检查通过",
|
||||
"多智能体配置检查通过",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"master_agent_id": str(multi_agent_config.master_agent_id),
|
||||
@@ -295,7 +295,7 @@ class AppService:
|
||||
updated_at=now,
|
||||
)
|
||||
self.db.add(agent_cfg)
|
||||
logger.debug(f"Agent 配置已创建", extra={"app_id": str(app_id)})
|
||||
logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)})
|
||||
|
||||
def _create_multi_agent_config(
|
||||
self,
|
||||
@@ -380,7 +380,7 @@ class AppService:
|
||||
updated_at=now,
|
||||
)
|
||||
self.db.add(multi_agent_cfg)
|
||||
logger.debug(f"多 Agent 配置已创建", extra={"app_id": str(app_id), "mode": config.orchestration_mode})
|
||||
logger.debug("多 Agent 配置已创建", extra={"app_id": str(app_id), "mode": config.orchestration_mode})
|
||||
|
||||
def _get_next_version(self, app_id: uuid.UUID) -> int:
|
||||
"""获取下一个版本号
|
||||
@@ -474,7 +474,7 @@ class AppService:
|
||||
BusinessException: 当创建失败时
|
||||
"""
|
||||
logger.info(
|
||||
f"创建应用",
|
||||
"创建应用",
|
||||
extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
@@ -511,12 +511,12 @@ class AppService:
|
||||
self.db.commit()
|
||||
self.db.refresh(app)
|
||||
|
||||
logger.info(f"应用创建成功", extra={"app_id": str(app.id), "app_name": app.name})
|
||||
logger.info("应用创建成功", extra={"app_id": str(app.id), "app_name": app.name})
|
||||
return app
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"应用创建失败", extra={"app_name": data.name, "error": str(e)})
|
||||
logger.error("应用创建失败", extra={"app_name": data.name, "error": str(e)})
|
||||
raise BusinessException(f"应用创建失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
|
||||
|
||||
def update_app(
|
||||
@@ -540,7 +540,7 @@ class AppService:
|
||||
ResourceNotFoundException: 当应用不存在时
|
||||
BusinessException: 当应用不在指定工作空间时
|
||||
"""
|
||||
logger.info(f"更新应用", extra={"app_id": str(app_id)})
|
||||
logger.info("更新应用", extra={"app_id": str(app_id)})
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
@@ -556,9 +556,9 @@ class AppService:
|
||||
app.updated_at = datetime.datetime.now()
|
||||
self.db.commit()
|
||||
self.db.refresh(app)
|
||||
logger.info(f"应用更新成功", extra={"app_id": str(app_id)})
|
||||
logger.info("应用更新成功", extra={"app_id": str(app_id)})
|
||||
else:
|
||||
logger.debug(f"应用无变更", extra={"app_id": str(app_id)})
|
||||
logger.debug("应用无变更", extra={"app_id": str(app_id)})
|
||||
|
||||
return app
|
||||
|
||||
@@ -578,17 +578,17 @@ class AppService:
|
||||
ResourceNotFoundException: 当应用不存在时
|
||||
BusinessException: 当应用不在指定工作空间时
|
||||
"""
|
||||
logger.info(f"删除应用", extra={"app_id": str(app_id)})
|
||||
logger.info("删除应用", extra={"app_id": str(app_id)})
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
self._validate_workspace_access(app, workspace_id)
|
||||
|
||||
# 删除应用(级联删除相关数据)
|
||||
self.db.delete(app)
|
||||
# 逻辑删除应用
|
||||
app.is_active = False
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"应用删除成功",
|
||||
"应用删除成功",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"app_name": app.name,
|
||||
@@ -619,7 +619,7 @@ class AppService:
|
||||
ResourceNotFoundException: 当源应用不存在时
|
||||
BusinessException: 当复制失败时
|
||||
"""
|
||||
logger.info(f"复制应用", extra={"source_app_id": str(app_id)})
|
||||
logger.info("复制应用", extra={"source_app_id": str(app_id)})
|
||||
|
||||
try:
|
||||
# 获取源应用
|
||||
@@ -682,7 +682,7 @@ class AppService:
|
||||
self.db.refresh(new_app)
|
||||
|
||||
logger.info(
|
||||
f"应用复制成功",
|
||||
"应用复制成功",
|
||||
extra={
|
||||
"source_app_id": str(app_id),
|
||||
"new_app_id": str(new_app.id),
|
||||
@@ -695,7 +695,7 @@ class AppService:
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(
|
||||
f"应用复制失败",
|
||||
"应用复制失败",
|
||||
extra={"source_app_id": str(app_id), "error": str(e)}
|
||||
)
|
||||
raise BusinessException(f"应用复制失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
|
||||
@@ -734,7 +734,7 @@ class AppService:
|
||||
from app.models import AppShare
|
||||
|
||||
logger.debug(
|
||||
f"查询应用列表",
|
||||
"查询应用列表",
|
||||
extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"include_shared": include_shared,
|
||||
@@ -745,6 +745,7 @@ class AppService:
|
||||
|
||||
# 构建查询条件
|
||||
filters = []
|
||||
filters.append(App.is_active == True)
|
||||
if type:
|
||||
filters.append(App.type == type)
|
||||
if visibility:
|
||||
@@ -791,7 +792,7 @@ class AppService:
|
||||
items = list(self.db.scalars(stmt).all())
|
||||
|
||||
logger.debug(
|
||||
f"应用列表查询完成",
|
||||
"应用列表查询完成",
|
||||
extra={"total": total, "returned": len(items), "include_shared": include_shared}
|
||||
)
|
||||
return items, int(total)
|
||||
@@ -819,7 +820,7 @@ class AppService:
|
||||
ResourceNotFoundException: 当应用不存在时
|
||||
BusinessException: 当应用类型不支持或不在指定工作空间时
|
||||
"""
|
||||
logger.info(f"更新 Agent 配置", extra={"app_id": str(app_id)})
|
||||
logger.info("更新 Agent 配置", extra={"app_id": str(app_id)})
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
|
||||
@@ -841,7 +842,7 @@ class AppService:
|
||||
updated_at=now,
|
||||
)
|
||||
self.db.add(agent_cfg)
|
||||
logger.debug(f"创建新的 Agent 配置", extra={"app_id": str(app_id)})
|
||||
logger.debug("创建新的 Agent 配置", extra={"app_id": str(app_id)})
|
||||
|
||||
# 转换为存储格式
|
||||
storage_data = AgentConfigConverter.to_storage_format(data)
|
||||
@@ -867,7 +868,7 @@ class AppService:
|
||||
self.db.commit()
|
||||
self.db.refresh(agent_cfg)
|
||||
|
||||
logger.info(f"Agent 配置更新成功", extra={"app_id": str(app_id)})
|
||||
logger.info("Agent 配置更新成功", extra={"app_id": str(app_id)})
|
||||
return agent_cfg
|
||||
|
||||
def get_agent_config(
|
||||
@@ -891,7 +892,7 @@ class AppService:
|
||||
ResourceNotFoundException: 当应用不存在时
|
||||
BusinessException: 当应用类型不支持或不可访问时
|
||||
"""
|
||||
logger.debug(f"获取 Agent 配置", extra={"app_id": str(app_id)})
|
||||
logger.debug("获取 Agent 配置", extra={"app_id": str(app_id)})
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
|
||||
@@ -908,7 +909,7 @@ class AppService:
|
||||
return config
|
||||
|
||||
# 返回默认配置模板(不保存到数据库)
|
||||
logger.debug(f"配置不存在,返回默认模板", extra={"app_id": str(app_id)})
|
||||
logger.debug("配置不存在,返回默认模板", extra={"app_id": str(app_id)})
|
||||
return self._create_default_agent_config(app_id)
|
||||
|
||||
def _create_default_agent_config(self, app_id: uuid.UUID) -> AgentConfig:
|
||||
@@ -981,7 +982,7 @@ class AppService:
|
||||
ResourceNotFoundException: 当应用不存在时
|
||||
BusinessException: 当应用缺少配置或不在指定工作空间时
|
||||
"""
|
||||
logger.info(f"发布应用", extra={"app_id": str(app_id), "publisher_id": str(publisher_id)})
|
||||
logger.info("发布应用", extra={"app_id": str(app_id), "publisher_id": str(publisher_id)})
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
# 检查应用归属
|
||||
@@ -1039,7 +1040,7 @@ class AppService:
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"多智能体应用发布配置准备完成",
|
||||
"多智能体应用发布配置准备完成",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"master_agent_id": str(multi_agent_cfg.master_agent_id),
|
||||
@@ -1083,7 +1084,7 @@ class AppService:
|
||||
self.db.refresh(release)
|
||||
|
||||
logger.info(
|
||||
f"应用发布成功",
|
||||
"应用发布成功",
|
||||
extra={"app_id": str(app_id), "version": version, "release_id": str(release.id)}
|
||||
)
|
||||
return release
|
||||
@@ -1107,7 +1108,7 @@ class AppService:
|
||||
ResourceNotFoundException: 当应用不存在时
|
||||
BusinessException: 当应用不可访问时
|
||||
"""
|
||||
logger.debug(f"获取当前发布版本", extra={"app_id": str(app_id)})
|
||||
logger.debug("获取当前发布版本", extra={"app_id": str(app_id)})
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
# 只读操作,允许访问共享应用
|
||||
@@ -1137,7 +1138,7 @@ class AppService:
|
||||
ResourceNotFoundException: 当应用不存在时
|
||||
BusinessException: 当应用不可访问时
|
||||
"""
|
||||
logger.debug(f"列出发布版本", extra={"app_id": str(app_id)})
|
||||
logger.debug("列出发布版本", extra={"app_id": str(app_id)})
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
# 只读操作,允许访问共享应用
|
||||
@@ -1171,7 +1172,7 @@ class AppService:
|
||||
ResourceNotFoundException: 当应用或版本不存在时
|
||||
BusinessException: 当应用不在指定工作空间时
|
||||
"""
|
||||
logger.info(f"回滚应用", extra={"app_id": str(app_id), "version": version})
|
||||
logger.info("回滚应用", extra={"app_id": str(app_id), "version": version})
|
||||
|
||||
app = self._get_app_or_404(app_id)
|
||||
self._validate_app_accessible(app, workspace_id)
|
||||
@@ -1184,7 +1185,7 @@ class AppService:
|
||||
|
||||
if not release:
|
||||
logger.warning(
|
||||
f"发布版本不存在",
|
||||
"发布版本不存在",
|
||||
extra={"app_id": str(app_id), "version": version}
|
||||
)
|
||||
raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}")
|
||||
@@ -1196,7 +1197,7 @@ class AppService:
|
||||
self.db.refresh(release)
|
||||
|
||||
logger.info(
|
||||
f"应用回滚成功",
|
||||
"应用回滚成功",
|
||||
extra={"app_id": str(app_id), "version": version, "release_id": str(release.id)}
|
||||
)
|
||||
return release
|
||||
@@ -1229,7 +1230,7 @@ class AppService:
|
||||
from app.models import AppShare, Workspace
|
||||
|
||||
logger.info(
|
||||
f"分享应用",
|
||||
"分享应用",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"target_workspaces": [str(wid) for wid in target_workspace_ids],
|
||||
@@ -1268,7 +1269,7 @@ class AppService:
|
||||
|
||||
if existing_share:
|
||||
logger.debug(
|
||||
f"应用已分享到该工作空间,跳过",
|
||||
"应用已分享到该工作空间,跳过",
|
||||
extra={"app_id": str(app_id), "target_workspace_id": str(target_ws_id)}
|
||||
)
|
||||
shares.append(existing_share)
|
||||
@@ -1288,14 +1289,14 @@ class AppService:
|
||||
shares.append(share)
|
||||
|
||||
logger.debug(
|
||||
f"创建分享记录",
|
||||
"创建分享记录",
|
||||
extra={"app_id": str(app_id), "target_workspace_id": str(target_ws_id)}
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"应用分享成功",
|
||||
"应用分享成功",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"shared_count": len(shares),
|
||||
@@ -1326,7 +1327,7 @@ class AppService:
|
||||
from app.models import AppShare
|
||||
|
||||
logger.info(
|
||||
f"取消应用分享",
|
||||
"取消应用分享",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"target_workspace_id": str(target_workspace_id)
|
||||
@@ -1346,7 +1347,7 @@ class AppService:
|
||||
|
||||
if not share:
|
||||
logger.warning(
|
||||
f"分享记录不存在",
|
||||
"分享记录不存在",
|
||||
extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id)}
|
||||
)
|
||||
raise ResourceNotFoundException(
|
||||
@@ -1359,7 +1360,7 @@ class AppService:
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"应用分享已取消",
|
||||
"应用分享已取消",
|
||||
extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id)}
|
||||
)
|
||||
|
||||
@@ -1384,7 +1385,7 @@ class AppService:
|
||||
"""
|
||||
from app.models import AppShare
|
||||
|
||||
logger.debug(f"列出应用分享记录", extra={"app_id": str(app_id)})
|
||||
logger.debug("列出应用分享记录", extra={"app_id": str(app_id)})
|
||||
|
||||
# 验证应用
|
||||
app = self._get_app_or_404(app_id)
|
||||
@@ -1398,7 +1399,7 @@ class AppService:
|
||||
shares = list(self.db.scalars(stmt).all())
|
||||
|
||||
logger.debug(
|
||||
f"应用分享记录查询完成",
|
||||
"应用分享记录查询完成",
|
||||
extra={"app_id": str(app_id), "count": len(shares)}
|
||||
)
|
||||
|
||||
@@ -1435,7 +1436,7 @@ class AppService:
|
||||
"""
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger.info(f"试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
|
||||
logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
|
||||
|
||||
# 1. 验证应用
|
||||
app = self._get_app_or_404(app_id)
|
||||
@@ -1464,7 +1465,7 @@ class AppService:
|
||||
|
||||
# 4. 调用试运行服务
|
||||
logger.debug(
|
||||
f"准备调用试运行服务",
|
||||
"准备调用试运行服务",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"model": model_config.name,
|
||||
@@ -1485,7 +1486,7 @@ class AppService:
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"试运行服务返回结果",
|
||||
"试运行服务返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict",
|
||||
@@ -1495,7 +1496,7 @@ class AppService:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"试运行完成",
|
||||
"试运行完成",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
@@ -1534,7 +1535,7 @@ class AppService:
|
||||
"""
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger.info(f"流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
|
||||
logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
|
||||
|
||||
# 1. 验证应用
|
||||
app = self._get_app_or_404(app_id)
|
||||
@@ -1609,7 +1610,7 @@ class AppService:
|
||||
from app.models import ModelConfig
|
||||
|
||||
logger.info(
|
||||
f"多模型对比试运行",
|
||||
"多模型对比试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"model_count": len(models),
|
||||
@@ -1666,7 +1667,7 @@ class AppService:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"多模型对比完成",
|
||||
"多模型对比完成",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"successful": result["successful_count"],
|
||||
@@ -1708,7 +1709,7 @@ class AppService:
|
||||
from app.models import ModelConfig
|
||||
|
||||
logger.info(
|
||||
f"多模型对比流式试运行",
|
||||
"多模型对比流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"model_count": len(models)
|
||||
@@ -1765,7 +1766,7 @@ class AppService:
|
||||
yield event
|
||||
|
||||
logger.info(
|
||||
f"多模型对比流式完成",
|
||||
"多模型对比流式完成",
|
||||
extra={"app_id": str(app_id)}
|
||||
)
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ def register_user_with_invite(
|
||||
# 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit)
|
||||
invite_accept = InviteAcceptRequest(token=invite_token)
|
||||
workspace_service.accept_workspace_invite(db, invite_accept, user)
|
||||
logger.info(f"用户接受邀请成功")
|
||||
logger.info("用户接受邀请成功")
|
||||
|
||||
# 重新查询用户对象以确保获取最新状态
|
||||
from app.repositories import user_repository
|
||||
@@ -200,7 +200,7 @@ def bind_workspace_with_invite(
|
||||
# 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit)
|
||||
invite_accept = InviteAcceptRequest(token=invite_token)
|
||||
workspace_service.accept_workspace_invite(db, invite_accept, user)
|
||||
logger.info(f"用户接受邀请成功")
|
||||
logger.info("用户接受邀请成功")
|
||||
|
||||
# 重新查询用户对象以确保获取最新状态
|
||||
from app.repositories import user_repository
|
||||
|
||||
@@ -42,7 +42,7 @@ class ConversationService:
|
||||
self.db.refresh(conversation)
|
||||
|
||||
logger.info(
|
||||
f"创建会话成功",
|
||||
"创建会话成功",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"app_id": str(app_id),
|
||||
@@ -201,7 +201,7 @@ class ConversationService:
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"保存会话消息成功",
|
||||
"保存会话消息成功",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"user_message_length": len(user_message),
|
||||
@@ -221,7 +221,7 @@ class ConversationService:
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"删除会话成功",
|
||||
"删除会话成功",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"workspace_id": str(workspace_id)
|
||||
|
||||
@@ -74,7 +74,7 @@ class ConversationStateManager:
|
||||
state["same_agent_turns"] = 0
|
||||
|
||||
logger.info(
|
||||
f"Agent 切换",
|
||||
"Agent 切换",
|
||||
extra={
|
||||
"conversation_id": conversation_id,
|
||||
"from": state["current_agent_id"],
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -88,7 +88,7 @@ class LLMRouter:
|
||||
路由结果
|
||||
"""
|
||||
logger.info(
|
||||
f"开始 LLM 智能路由",
|
||||
"开始 LLM 智能路由",
|
||||
extra={
|
||||
"message_length": len(message),
|
||||
"conversation_id": conversation_id,
|
||||
@@ -177,7 +177,7 @@ class LLMRouter:
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"路由完成",
|
||||
"路由完成",
|
||||
extra={
|
||||
"agent_id": agent_id,
|
||||
"strategy": strategy,
|
||||
@@ -393,7 +393,7 @@ class LLMRouter:
|
||||
|
||||
# 打印供应商信息
|
||||
logger.info(
|
||||
f"LLM 路由使用模型",
|
||||
"LLM 路由使用模型",
|
||||
extra={
|
||||
"provider": api_key_config.provider,
|
||||
"model_name": api_key_config.model_name,
|
||||
@@ -680,6 +680,6 @@ class LLMRouter:
|
||||
return self.routing_rules[0].get("target_agent_id")
|
||||
|
||||
if self.sub_agents:
|
||||
return list(self.sub_agents.keys())[0]
|
||||
return next(iter(self.sub_agents.keys()))
|
||||
|
||||
return "default-agent"
|
||||
|
||||
593
api/app/services/master_agent_router.py
Normal file
593
api/app/services/master_agent_router.py
Normal file
@@ -0,0 +1,593 @@
|
||||
"""Master Agent 路由器 - 让 Master Agent 真正成为决策中心"""
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.models import ModelConfig, AgentConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MasterAgentRouter:
|
||||
"""Master Agent 路由器
|
||||
|
||||
让 Master Agent 作为"大脑",负责:
|
||||
1. 分析用户意图
|
||||
2. 选择最合适的 Sub Agent
|
||||
3. 决定是否需要多 Agent 协作
|
||||
4. 管理会话上下文
|
||||
|
||||
优势:
|
||||
- 更智能的决策(基于完整上下文)
|
||||
- 减少 LLM 调用次数
|
||||
- 架构更清晰(Master Agent 真正起作用)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
master_agent_config: AgentConfig,
|
||||
master_model_config: ModelConfig,
|
||||
sub_agents: Dict[str, Any],
|
||||
state_manager: ConversationStateManager,
|
||||
enable_rule_fast_path: bool = True
|
||||
):
|
||||
"""初始化 Master Agent 路由器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
master_agent_config: Master Agent 配置
|
||||
master_model_config: Master Agent 使用的模型配置
|
||||
sub_agents: 子 Agent 配置字典
|
||||
state_manager: 会话状态管理器
|
||||
enable_rule_fast_path: 是否启用规则快速路径(性能优化)
|
||||
"""
|
||||
self.db = db
|
||||
self.master_agent_config = master_agent_config
|
||||
self.master_model_config = master_model_config
|
||||
self.sub_agents = sub_agents
|
||||
self.state_manager = state_manager
|
||||
self.enable_rule_fast_path = enable_rule_fast_path
|
||||
|
||||
logger.info(
|
||||
"Master Agent 路由器初始化",
|
||||
extra={
|
||||
"master_agent": master_agent_config.name,
|
||||
"sub_agent_count": len(sub_agents),
|
||||
"enable_rule_fast_path": enable_rule_fast_path
|
||||
}
|
||||
)
|
||||
|
||||
async def route(
|
||||
self,
|
||||
message: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""智能路由决策
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
conversation_id: 会话 ID
|
||||
variables: 变量参数
|
||||
|
||||
Returns:
|
||||
路由决策结果
|
||||
"""
|
||||
logger.info(
|
||||
"开始 Master Agent 路由",
|
||||
extra={
|
||||
"message_length": len(message),
|
||||
"conversation_id": conversation_id
|
||||
}
|
||||
)
|
||||
|
||||
# 1. 获取会话状态
|
||||
state = None
|
||||
if conversation_id:
|
||||
state = self.state_manager.get_state(conversation_id)
|
||||
|
||||
# 2. 尝试规则快速路径(可选的性能优化)
|
||||
if self.enable_rule_fast_path:
|
||||
rule_result = self._try_rule_fast_path(message, state)
|
||||
if rule_result:
|
||||
logger.info(
|
||||
"规则快速路径命中",
|
||||
extra={
|
||||
"agent_id": rule_result["selected_agent_id"],
|
||||
"confidence": rule_result["confidence"]
|
||||
}
|
||||
)
|
||||
|
||||
# 更新会话状态
|
||||
if conversation_id:
|
||||
self.state_manager.update_state(
|
||||
conversation_id,
|
||||
rule_result["selected_agent_id"],
|
||||
message,
|
||||
rule_result.get("topic"),
|
||||
rule_result["confidence"]
|
||||
)
|
||||
|
||||
return rule_result
|
||||
|
||||
# 3. 调用 Master Agent 做决策
|
||||
decision = await self._master_agent_decide(message, state, variables)
|
||||
|
||||
# 4. 更新会话状态
|
||||
if conversation_id:
|
||||
self.state_manager.update_state(
|
||||
conversation_id,
|
||||
decision["selected_agent_id"],
|
||||
message,
|
||||
decision.get("topic"),
|
||||
decision["confidence"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Master Agent 路由完成",
|
||||
extra={
|
||||
"agent_id": decision["selected_agent_id"],
|
||||
"strategy": decision["strategy"],
|
||||
"confidence": decision["confidence"]
|
||||
}
|
||||
)
|
||||
|
||||
return decision
|
||||
|
||||
def _try_rule_fast_path(
|
||||
self,
|
||||
message: str,
|
||||
state: Optional[Dict[str, Any]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""尝试规则快速路径(性能优化)
|
||||
|
||||
对于明确的关键词匹配,直接返回结果,不调用 Master Agent
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
state: 会话状态
|
||||
|
||||
Returns:
|
||||
如果命中规则返回决策结果,否则返回 None
|
||||
"""
|
||||
# 定义高置信度关键词规则
|
||||
high_confidence_rules = [
|
||||
{
|
||||
"keywords": ["数学", "方程", "计算", "求解"],
|
||||
"agent_role": "数学",
|
||||
"confidence_threshold": 0.9
|
||||
},
|
||||
{
|
||||
"keywords": ["物理", "力学", "电路", "光学"],
|
||||
"agent_role": "物理",
|
||||
"confidence_threshold": 0.9
|
||||
},
|
||||
{
|
||||
"keywords": ["订单", "发货", "物流", "快递"],
|
||||
"agent_role": "订单",
|
||||
"confidence_threshold": 0.9
|
||||
},
|
||||
{
|
||||
"keywords": ["退款", "退货", "售后"],
|
||||
"agent_role": "退款",
|
||||
"confidence_threshold": 0.9
|
||||
}
|
||||
]
|
||||
|
||||
message_lower = message.lower()
|
||||
|
||||
for rule in high_confidence_rules:
|
||||
matched_keywords = [kw for kw in rule["keywords"] if kw in message_lower]
|
||||
|
||||
if matched_keywords:
|
||||
confidence = len(matched_keywords) / len(rule["keywords"])
|
||||
|
||||
if confidence >= rule["confidence_threshold"]:
|
||||
# 查找对应的 agent
|
||||
for agent_id, agent_data in self.sub_agents.items():
|
||||
agent_info = agent_data.get("info", {})
|
||||
if agent_info.get("role") == rule["agent_role"]:
|
||||
return {
|
||||
"selected_agent_id": agent_id,
|
||||
"confidence": confidence,
|
||||
"strategy": "rule_fast_path",
|
||||
"reasoning": f"关键词匹配: {', '.join(matched_keywords)}",
|
||||
"topic": rule["agent_role"],
|
||||
"need_collaboration": False,
|
||||
"routing_method": "rule"
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
async def _master_agent_decide(
|
||||
self,
|
||||
message: str,
|
||||
state: Optional[Dict[str, Any]],
|
||||
variables: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""让 Master Agent 做路由决策
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
state: 会话状态
|
||||
variables: 变量参数
|
||||
|
||||
Returns:
|
||||
决策结果
|
||||
"""
|
||||
# 1. 构建决策 prompt
|
||||
prompt = self._build_decision_prompt(message, state, variables)
|
||||
|
||||
# 2. 调用 Master Agent 的 LLM
|
||||
try:
|
||||
response = await self._call_master_agent_llm(prompt)
|
||||
|
||||
# 3. 解析决策
|
||||
decision = self._parse_decision(response)
|
||||
|
||||
# 4. 验证决策
|
||||
decision = self._validate_decision(decision)
|
||||
|
||||
return decision
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Master Agent 决策失败: {str(e)}")
|
||||
# 降级到默认 agent
|
||||
return self._get_fallback_decision(message)
|
||||
|
||||
def _build_decision_prompt(
|
||||
self,
|
||||
message: str,
|
||||
state: Optional[Dict[str, Any]],
|
||||
variables: Optional[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建 Master Agent 的决策 prompt
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
state: 会话状态
|
||||
variables: 变量参数
|
||||
|
||||
Returns:
|
||||
prompt 字符串
|
||||
"""
|
||||
# 1. 构建 Sub Agent 描述(简化版,提升性能)
|
||||
agent_descriptions = []
|
||||
for agent_id, agent_data in self.sub_agents.items():
|
||||
agent_info = agent_data.get("info", {})
|
||||
|
||||
name = agent_info.get("name", "未命名")
|
||||
role = agent_info.get("role", "")
|
||||
capabilities = agent_info.get("capabilities", [])
|
||||
|
||||
# 简化格式:一行描述
|
||||
desc = f"- {agent_id}: {name}"
|
||||
if role:
|
||||
desc += f" ({role})"
|
||||
if capabilities:
|
||||
desc += f" - {', '.join(capabilities[:3])}" # 只取前3个能力
|
||||
|
||||
agent_descriptions.append(desc)
|
||||
|
||||
agents_text = "\n".join(agent_descriptions)
|
||||
|
||||
# 2. 构建会话上下文
|
||||
context_text = ""
|
||||
if state:
|
||||
current_agent = state.get("current_agent_id")
|
||||
last_topic = state.get("last_topic")
|
||||
same_turns = state.get("same_agent_turns", 0)
|
||||
|
||||
if current_agent:
|
||||
context_text = f"""
|
||||
当前会话上下文:
|
||||
- 当前使用的 Agent: {current_agent}
|
||||
- 上一个主题: {last_topic}
|
||||
- 连续使用轮数: {same_turns}
|
||||
"""
|
||||
|
||||
# 获取第一个可用的 agent_id 作为示例
|
||||
example_agent_id = next(iter(self.sub_agents.keys())) if self.sub_agents else "agent_id"
|
||||
|
||||
# 3. 构建完整 prompt(简化版,提升性能)
|
||||
prompt = f"""路由任务:分析问题并选择合适的 Agent。
|
||||
|
||||
可用 Agent:
|
||||
{agents_text}
|
||||
{context_text}
|
||||
问题:"{message}"
|
||||
|
||||
返回 JSON 格式决策:
|
||||
|
||||
**情况1:单一问题(最常见)**
|
||||
{{"selected_agent_id": "{example_agent_id}", "confidence": 0.9, "need_collaboration": false, "reasoning": "选择理由"}}
|
||||
|
||||
**情况2:需要拆分成多个独立子问题**
|
||||
当用户问题包含多个完全独立的子问题时使用(如"写诗+做数学题")。
|
||||
必须提供 sub_questions 数组,每个子问题必须指定 agent_id:
|
||||
{{"selected_agent_id": "{example_agent_id}", "confidence": 0.9, "need_collaboration": true, "need_decomposition": true,
|
||||
"sub_questions": [
|
||||
{{"question": "具体的子问题1", "agent_id": "{example_agent_id}", "order": 1, "depends_on": []}},
|
||||
{{"question": "具体的子问题2", "agent_id": "{example_agent_id}", "order": 2, "depends_on": []}}
|
||||
],
|
||||
"collaboration_strategy": "decomposition", "reasoning": "问题包含X个独立子问题"}}
|
||||
|
||||
**情况3:需要多个Agent协作分析同一问题**
|
||||
{{"selected_agent_id": "{example_agent_id}", "confidence": 0.9, "need_collaboration": true,
|
||||
"collaboration_agents": [{{"agent_id": "{example_agent_id}", "role": "primary", "task": "主要任务", "order": 1}}],
|
||||
"collaboration_strategy": "sequential", "reasoning": "需要多角度分析"}}
|
||||
|
||||
重要规则:
|
||||
1. selected_agent_id 必须从上面的可用 Agent 列表中选择
|
||||
2. 如果选择情况2(拆分),sub_questions 数组不能为空,必须包含具体的子问题
|
||||
3. 每个子问题的 agent_id 也必须从可用列表中选择
|
||||
4. depends_on 表示依赖关系(如 [1] 表示依赖第1个子问题的结果)
|
||||
5. 大多数情况应该选择情况1(单一Agent),只有明确需要时才拆分或协作
|
||||
6. 只做路由决策,不要回答问题内容
|
||||
|
||||
请返回 JSON:"""
|
||||
|
||||
return prompt
|
||||
|
||||
async def _call_master_agent_llm(self, prompt: str) -> str:
|
||||
"""调用 Master Agent 的 LLM
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
|
||||
Returns:
|
||||
LLM 响应
|
||||
"""
|
||||
try:
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.models import ModelApiKey, ModelType
|
||||
|
||||
# 获取 API Key 配置
|
||||
api_key_config = self.db.query(ModelApiKey).filter(
|
||||
ModelApiKey.model_config_id == self.master_model_config.id,
|
||||
ModelApiKey.is_active == True
|
||||
).first()
|
||||
|
||||
if not api_key_config:
|
||||
raise Exception("Master Agent 模型没有可用的 API Key")
|
||||
|
||||
logger.info(
|
||||
"调用 Master Agent LLM",
|
||||
extra={
|
||||
"provider": api_key_config.provider,
|
||||
"model_name": api_key_config.model_name
|
||||
}
|
||||
)
|
||||
|
||||
# 创建 RedBearModelConfig
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=api_key_config.model_name,
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
temperature=0.3, # 决策任务使用较低温度
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
# 创建 LLM 实例
|
||||
llm = RedBearLLM(model_config, type=ModelType.CHAT)
|
||||
|
||||
# 调用模型
|
||||
response = await llm.ainvoke(prompt)
|
||||
|
||||
# 提取响应内容
|
||||
if hasattr(response, 'content'):
|
||||
return response.content
|
||||
else:
|
||||
return str(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Master Agent LLM 调用失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def _parse_decision(self, response: str) -> Dict[str, Any]:
|
||||
"""解析 Master Agent 的决策
|
||||
|
||||
Args:
|
||||
response: LLM 响应
|
||||
|
||||
Returns:
|
||||
决策字典
|
||||
"""
|
||||
try:
|
||||
# 提取 JSON
|
||||
json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', response, re.DOTALL)
|
||||
if json_match:
|
||||
decision = json.loads(json_match.group())
|
||||
|
||||
# 添加默认值
|
||||
decision.setdefault("confidence", 0.8)
|
||||
decision.setdefault("strategy", "master_agent")
|
||||
decision.setdefault("routing_method", "master_agent")
|
||||
decision.setdefault("need_collaboration", False)
|
||||
decision.setdefault("collaboration_agents", [])
|
||||
|
||||
return decision
|
||||
else:
|
||||
raise ValueError("无法从响应中提取 JSON")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Master Agent 决策失败: {str(e)}")
|
||||
logger.debug(f"原始响应: {response}")
|
||||
raise
|
||||
|
||||
def _validate_decision(self, decision: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""验证决策的有效性
|
||||
|
||||
Args:
|
||||
decision: 决策字典
|
||||
|
||||
Returns:
|
||||
验证后的决策
|
||||
"""
|
||||
# 验证 agent_id
|
||||
selected_agent_id = decision.get("selected_agent_id")
|
||||
if selected_agent_id not in self.sub_agents:
|
||||
logger.warning(f"Master Agent 返回的 agent_id 无效: {selected_agent_id}")
|
||||
# 使用默认 agent
|
||||
decision["selected_agent_id"] = self._get_default_agent_id()
|
||||
decision["confidence"] = 0.5
|
||||
decision["reasoning"] = "原始选择无效,使用默认 Agent"
|
||||
|
||||
# 验证 confidence
|
||||
confidence = decision.get("confidence", 0.8)
|
||||
if not isinstance(confidence, (int, float)) or confidence < 0 or confidence > 1:
|
||||
decision["confidence"] = 0.8
|
||||
|
||||
# 验证协作 agents
|
||||
if decision.get("need_collaboration"):
|
||||
# 检查是否是问题拆分模式
|
||||
if decision.get("need_decomposition") or decision.get("sub_questions"):
|
||||
# 问题拆分模式
|
||||
sub_questions = decision.get("sub_questions", [])
|
||||
|
||||
# 验证每个子问题
|
||||
valid_sub_questions = []
|
||||
for sub_q in sub_questions:
|
||||
if isinstance(sub_q, dict):
|
||||
agent_id = sub_q.get("agent_id")
|
||||
question = sub_q.get("question")
|
||||
|
||||
if agent_id in self.sub_agents and question:
|
||||
# 确保有必要的字段
|
||||
sub_q.setdefault("order", len(valid_sub_questions) + 1)
|
||||
sub_q.setdefault("depends_on", [])
|
||||
valid_sub_questions.append(sub_q)
|
||||
else:
|
||||
# 记录验证失败的原因
|
||||
logger.warning(
|
||||
"子问题验证失败",
|
||||
extra={
|
||||
"agent_id": agent_id,
|
||||
"has_question": bool(question),
|
||||
"agent_exists": agent_id in self.sub_agents if agent_id else False,
|
||||
"available_agents": list(self.sub_agents.keys())
|
||||
}
|
||||
)
|
||||
|
||||
decision["sub_questions"] = valid_sub_questions
|
||||
|
||||
# 如果所有子问题都验证失败,降级处理
|
||||
if not valid_sub_questions and sub_questions:
|
||||
logger.warning(
|
||||
"所有子问题验证失败,降级到单 Agent 模式",
|
||||
extra={
|
||||
"original_sub_question_count": len(sub_questions),
|
||||
"available_agents": list(self.sub_agents.keys())
|
||||
}
|
||||
)
|
||||
# 降级:取消协作标记,使用默认 Agent
|
||||
decision["need_collaboration"] = False
|
||||
decision["need_decomposition"] = False
|
||||
decision["collaboration_strategy"] = None
|
||||
# 选择第一个可用的 Agent
|
||||
if self.sub_agents:
|
||||
first_agent_id = next(iter(self.sub_agents.keys()))
|
||||
decision["selected_agent_id"] = first_agent_id
|
||||
logger.info(f"降级使用默认 Agent: {first_agent_id}")
|
||||
|
||||
# 设置协作策略为 decomposition
|
||||
decision["collaboration_strategy"] = "decomposition"
|
||||
|
||||
logger.info(
|
||||
"问题拆分决策验证完成",
|
||||
extra={
|
||||
"sub_question_count": len(valid_sub_questions),
|
||||
"strategy": "decomposition"
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 普通协作模式
|
||||
collaboration_agents = decision.get("collaboration_agents", [])
|
||||
|
||||
# 如果是简单列表格式,转换为详细格式
|
||||
if collaboration_agents and isinstance(collaboration_agents[0], str):
|
||||
collaboration_agents = [
|
||||
{
|
||||
"agent_id": agent_id,
|
||||
"role": "primary" if i == 0 else "secondary",
|
||||
"task": "协作处理",
|
||||
"order": i + 1
|
||||
}
|
||||
for i, agent_id in enumerate(collaboration_agents)
|
||||
]
|
||||
|
||||
# 验证每个协作 agent
|
||||
valid_agents = []
|
||||
for agent_info in collaboration_agents:
|
||||
if isinstance(agent_info, dict):
|
||||
agent_id = agent_info.get("agent_id")
|
||||
if agent_id in self.sub_agents:
|
||||
# 确保有必要的字段
|
||||
agent_info.setdefault("role", "secondary")
|
||||
agent_info.setdefault("task", "协作处理")
|
||||
agent_info.setdefault("order", len(valid_agents) + 1)
|
||||
valid_agents.append(agent_info)
|
||||
elif isinstance(agent_info, str) and agent_info in self.sub_agents:
|
||||
valid_agents.append({
|
||||
"agent_id": agent_info,
|
||||
"role": "secondary",
|
||||
"task": "协作处理",
|
||||
"order": len(valid_agents) + 1
|
||||
})
|
||||
|
||||
decision["collaboration_agents"] = valid_agents
|
||||
|
||||
# 设置默认协作策略
|
||||
if not decision.get("collaboration_strategy"):
|
||||
decision["collaboration_strategy"] = "sequential"
|
||||
|
||||
logger.info(
|
||||
"协作决策验证完成",
|
||||
extra={
|
||||
"collaboration_agent_count": len(valid_agents),
|
||||
"strategy": decision.get("collaboration_strategy")
|
||||
}
|
||||
)
|
||||
|
||||
return decision
|
||||
|
||||
def _get_fallback_decision(self, message: str) -> Dict[str, Any]:
|
||||
"""获取降级决策(当 Master Agent 失败时)
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
降级决策
|
||||
"""
|
||||
default_agent_id = self._get_default_agent_id()
|
||||
|
||||
return {
|
||||
"selected_agent_id": default_agent_id,
|
||||
"confidence": 0.5,
|
||||
"strategy": "fallback",
|
||||
"reasoning": "Master Agent 决策失败,使用默认 Agent",
|
||||
"topic": "未知",
|
||||
"need_collaboration": False,
|
||||
"collaboration_agents": [],
|
||||
"routing_method": "fallback"
|
||||
}
|
||||
|
||||
def _get_default_agent_id(self) -> str:
|
||||
"""获取默认 Agent ID
|
||||
|
||||
Returns:
|
||||
默认 Agent ID
|
||||
"""
|
||||
if self.sub_agents:
|
||||
# 返回第一个 agent
|
||||
return next(iter(self.sub_agents.keys()))
|
||||
|
||||
return "default-agent"
|
||||
@@ -29,8 +29,7 @@ from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.db import get_db
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# TODO 后续更新
|
||||
# from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.schemas.memory_storage_schema import ApiResponse, ok, fail
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
@@ -697,7 +696,7 @@ class MemoryAgentService:
|
||||
logger.info(f"知识库类型统计成功 (workspace_id={current_workspace_id}): {result}")
|
||||
else:
|
||||
# 没有提供 workspace_id,所有知识库类型返回 0
|
||||
logger.info(f"未提供 workspace_id,知识库类型统计全部为 0")
|
||||
logger.info("未提供 workspace_id,知识库类型统计全部为 0")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"知识库类型统计失败: {e}")
|
||||
@@ -720,7 +719,7 @@ class MemoryAgentService:
|
||||
end_users = []
|
||||
for app_id in app_ids:
|
||||
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
|
||||
end_users.extend([EndUserSchema.model_validate(h) for h in end_user_orm_list])
|
||||
end_users.extend(h for h in end_user_orm_list)
|
||||
|
||||
# 统计所有宿主的 Chunk 总数
|
||||
total_chunks = 0
|
||||
@@ -742,7 +741,7 @@ class MemoryAgentService:
|
||||
else:
|
||||
# 没有 workspace_id 时,返回 0
|
||||
result["memory"] = 0
|
||||
logger.info(f"未提供 workspace_id,memory 统计为 0")
|
||||
logger.info("未提供 workspace_id,memory 统计为 0")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Neo4j memory统计失败: {e}", exc_info=True)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
from fastapi import HTTPException
|
||||
|
||||
@@ -24,6 +24,30 @@ from app.core.logging_config import get_business_logger
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def get_current_workspace_type(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> Optional[str]:
|
||||
"""获取当前工作空间类型"""
|
||||
business_logger.info(f"获取工作空间类型: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
from app.repositories.workspace_repository import get_workspace_by_id
|
||||
|
||||
workspace = get_workspace_by_id(db, workspace_id)
|
||||
if not workspace:
|
||||
business_logger.warning(f"工作空间不存在: workspace_id={workspace_id}")
|
||||
return None
|
||||
|
||||
business_logger.info(f"成功获取工作空间类型: {workspace.storage_type}")
|
||||
return workspace.storage_type
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取工作空间类型失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_end_users(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
@@ -169,7 +193,7 @@ def get_workspace_memory_list(
|
||||
business_logger.warning(f"获取宿主列表失败: {str(e)}")
|
||||
result["hosts"] = []
|
||||
|
||||
business_logger.info(f"成功获取工作空间记忆列表")
|
||||
business_logger.info("成功获取工作空间记忆列表")
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
@@ -587,7 +611,7 @@ async def get_chunk_insight(
|
||||
"insight": insight
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取chunk洞察")
|
||||
business_logger.info("成功获取chunk洞察")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -168,7 +168,7 @@ async def get_document_chunks(
|
||||
|
||||
# 执行分页查询
|
||||
try:
|
||||
api_logger.debug(f"开始执行文档块查询")
|
||||
api_logger.debug("开始执行文档块查询")
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
total, items = vector_service.search_by_segment(
|
||||
document_id=str(document_id),
|
||||
@@ -516,67 +516,18 @@ async def write_rag(group_id, message, user_rag_memory_id):
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
await parse_document_by_id(document, db=db, current_user=current_user)
|
||||
# 重新查询刚创建的文档ID
|
||||
new_document_id = find_document_id_by_kb_and_filename(
|
||||
db=db,
|
||||
kb_id=user_rag_memory_id,
|
||||
file_name=f"{group_id}.txt"
|
||||
)
|
||||
|
||||
if new_document_id:
|
||||
await parse_document_by_id(new_document_id, db=db, current_user=current_user)
|
||||
else:
|
||||
api_logger.error(f"创建文档后无法找到文档ID: group_id={group_id}")
|
||||
return result
|
||||
finally:
|
||||
# 确保数据库会话被关闭
|
||||
db.close()
|
||||
# 在异步环境中调用示例
|
||||
|
||||
|
||||
async def example_usage():
|
||||
|
||||
# 获取数据库会话
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
# 创建 CustomTextFileCreate 对象
|
||||
title = '2f6ff1eb-50c7-4765-8e89-e4566be19122'
|
||||
create_data = CustomTextFileCreate(
|
||||
title=title,
|
||||
content="莫扎特在巴黎经历母亲去世后返回萨尔茨堡,他随后创作的交响曲主题是否与格鲁克在维也纳推动的“改革歌剧”理念存在共通之处?贝多芬早年曾师从海顿,而海顿又受雇于埃斯特哈齐家族——这种师承体系如何影响了当时欧洲宫廷音乐的传承结构?斯卡拉歌剧院选择萨列里的歌剧作为开幕演出,是否与当时米兰政治环境和奥地利宫廷影响有关?"
|
||||
)
|
||||
|
||||
# 创建用户对象
|
||||
current_user = SimpleUser("6243c125-9420-402c-bbb5-d1977811ac96")
|
||||
|
||||
# 上传文件
|
||||
result = await memory_konwledges_up(
|
||||
kb_id="c71df60a-36a6-4759-a2ce-101e3087b401",
|
||||
parent_id="c71df60a-36a6-4759-a2ce-101e3087b401",
|
||||
create_data=create_data,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
print(result)
|
||||
#找到document_id
|
||||
|
||||
# 使用刚创建的文档ID进行解析
|
||||
document = find_document_id_by_kb_and_filename(db=db, kb_id="c71df60a-36a6-4759-a2ce-101e3087b401", file_name=f"{title}.txt")
|
||||
print('====',document)
|
||||
res___=await parse_document_by_id(document, db=db, current_user=current_user)
|
||||
print(res___)
|
||||
|
||||
# result='e8cf9ace-d1a9-4af2-b0c4-3fc94f4f8042'
|
||||
# document_id='d22e8173-50d0-4e10-a7de-aa638ef893bc'
|
||||
#
|
||||
# '''更新块'''
|
||||
#
|
||||
# new_content = "这是新的 chunk 内容,用来覆盖原来的内容"
|
||||
# # 构造 ChunkUpdate 对象
|
||||
# update_data = ChunkCreate(content=new_content)
|
||||
# updated_chunk = await create_document_chunk(
|
||||
# kb_id= result,
|
||||
# document_id=document_id,
|
||||
# create_data= update_data,
|
||||
# db=db,
|
||||
# current_user=current_user
|
||||
# )
|
||||
# print(updated_chunk)
|
||||
return '','',''
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# asyncio.run(example_usage())
|
||||
asyncio.run(write_rag('1111','22222',"c71df60a-36a6-4759-a2ce-101e3087b401"))
|
||||
db.close()
|
||||
@@ -7,9 +7,12 @@ Handles business logic for memory storage operations.
|
||||
from typing import Dict, List, Optional, Any
|
||||
import os
|
||||
import json
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.core.logging_config import get_logger
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
@@ -23,11 +26,10 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# TODO 后续更新
|
||||
# from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
# from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
# from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
# from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -52,7 +54,7 @@ class MemoryStorageService:
|
||||
Returns:
|
||||
Storage information dictionary
|
||||
"""
|
||||
logger.info(f"Getting storage info ")
|
||||
logger.info("Getting storage info ")
|
||||
|
||||
# Empty wrapper - implement your logic here
|
||||
result = {
|
||||
@@ -65,30 +67,28 @@ class MemoryStorageService:
|
||||
class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"""Service layer for config params CRUD.
|
||||
|
||||
The DB connection is optional; when absent, methods return a failure
|
||||
response containing an SQL preview to aid integration.
|
||||
使用 SQLAlchemy ORM 进行数据库操作。
|
||||
"""
|
||||
|
||||
def __init__(self, db_conn: Optional[Any] = None) -> None:
|
||||
self.db_conn = db_conn
|
||||
|
||||
# --- Driver compatibility helpers ---
|
||||
@staticmethod
|
||||
def _is_pgsql_conn(conn: Any) -> bool: # 判断是否为 PostgreSQL 连接
|
||||
mod = type(conn).__module__
|
||||
return ("psycopg2" in mod) or ("psycopg" in mod)
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
"""初始化服务
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
@staticmethod
|
||||
def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式"""
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
for item in data_list:
|
||||
for field in ['created_at', 'updated_at']:
|
||||
if field in item and item[field] is not None:
|
||||
value = item[field]
|
||||
dt = None
|
||||
|
||||
|
||||
# 如果是 datetime 对象,直接使用
|
||||
if isinstance(value, datetime):
|
||||
dt = value
|
||||
@@ -98,24 +98,21 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
|
||||
except Exception:
|
||||
pass # 保持原值
|
||||
|
||||
|
||||
# 转换为 YYYYMMDDHHmmss 格式
|
||||
if dt:
|
||||
item[field] = dt.strftime('%Y%m%d%H%M%S')
|
||||
|
||||
|
||||
return data_list
|
||||
|
||||
# --- Create ---
|
||||
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
# 如果workspace_id存在且模型字段未全部指定,则自动获取
|
||||
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
|
||||
configs = self._get_workspace_configs(params.workspace_id)
|
||||
if configs is None:
|
||||
raise ValueError(f"工作空间不存在: workspace_id={params.workspace_id}")
|
||||
|
||||
|
||||
# 只在未指定时填充(允许手动覆盖)
|
||||
if not params.llm_id:
|
||||
params.llm_id = configs.get('llm')
|
||||
@@ -123,19 +120,16 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
params.embedding_id = configs.get('embedding')
|
||||
if not params.rerank_id:
|
||||
params.rerank_id = configs.get('rerank')
|
||||
|
||||
query, qparams = DataConfigRepository.build_insert(params)
|
||||
cur = self.db_conn.cursor()
|
||||
# PostgreSQL 使用 psycopg2 的命名参数格式
|
||||
cur.execute(query, qparams)
|
||||
self.db_conn.commit()
|
||||
return {"affected": getattr(cur, "rowcount", None)}
|
||||
|
||||
|
||||
config = DataConfigRepository.create(self.db, params)
|
||||
self.db.commit()
|
||||
return {"affected": 1, "config_id": config.config_id}
|
||||
|
||||
def _get_workspace_configs(self, workspace_id) -> Optional[Dict[str, Any]]:
|
||||
"""获取工作空间模型配置(内部方法,便于测试)"""
|
||||
from app.db import SessionLocal
|
||||
from app.repositories.workspace_repository import get_workspace_models_configs
|
||||
|
||||
|
||||
db_session = SessionLocal()
|
||||
try:
|
||||
return get_workspace_models_configs(db_session, workspace_id)
|
||||
@@ -143,121 +137,91 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
db_session.close()
|
||||
|
||||
# --- Delete ---
|
||||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置名称)
|
||||
query, qparams = DataConfigRepository.build_delete(key)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
affected = getattr(cur, "rowcount", None)
|
||||
self.db_conn.commit()
|
||||
# 如果没有任何行被删除,抛出异常
|
||||
if not affected:
|
||||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
|
||||
success = DataConfigRepository.delete(self.db, key.config_id)
|
||||
if not success:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": affected}
|
||||
return {"affected": 1}
|
||||
|
||||
# --- Update ---
|
||||
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
|
||||
query, qparams = DataConfigRepository.build_update(update)
|
||||
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
affected = getattr(cur, "rowcount", None)
|
||||
self.db_conn.commit()
|
||||
if not affected:
|
||||
config = DataConfigRepository.update(self.db, update)
|
||||
if not config:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": affected}
|
||||
|
||||
|
||||
return {"affected": 1}
|
||||
|
||||
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
|
||||
query, qparams = DataConfigRepository.build_update_extracted(update)
|
||||
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
affected = getattr(cur, "rowcount", None)
|
||||
self.db_conn.commit()
|
||||
if not affected:
|
||||
config = DataConfigRepository.update_extracted(self.db, update)
|
||||
if not config:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": affected}
|
||||
return {"affected": 1}
|
||||
|
||||
|
||||
# --- Forget config params ---
|
||||
def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]: # 保存遗忘引擎的配置
|
||||
query, qparams = DataConfigRepository.build_update_forget(update)
|
||||
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
affected = getattr(cur, "rowcount", None)
|
||||
self.db_conn.commit()
|
||||
if not affected:
|
||||
config = DataConfigRepository.update_forget(self.db, update)
|
||||
if not config:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": affected}
|
||||
|
||||
return {"affected": 1}
|
||||
|
||||
# --- Read ---
|
||||
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取配置参数
|
||||
query, qparams = DataConfigRepository.build_select_extracted(key)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
|
||||
result = DataConfigRepository.get_extracted_config(self.db, key.config_id)
|
||||
if not result:
|
||||
raise ValueError("未找到配置")
|
||||
# Map row to dict (DB-API cursor description available for many drivers)
|
||||
columns = [desc[0] for desc in cur.description]
|
||||
raw = {columns[i]: row[i] for i in range(len(columns))}
|
||||
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
||||
data_list = self._convert_timestamps_to_format([raw])
|
||||
return data_list[0] if data_list else raw
|
||||
return result
|
||||
|
||||
def get_forget(self, key: ConfigKey) -> Dict[str, Any]: # 获取配置参数
|
||||
query, qparams = DataConfigRepository.build_select_forget(key)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
def get_forget(self, key: ConfigKey) -> Dict[str, Any]: # 获取遗忘配置参数
|
||||
result = DataConfigRepository.get_forget_config(self.db, key.config_id)
|
||||
if not result:
|
||||
raise ValueError("未找到配置")
|
||||
# Map row to dict (DB-API cursor description available for many drivers)
|
||||
columns = [desc[0] for desc in cur.description]
|
||||
raw = {columns[i]: row[i] for i in range(len(columns))}
|
||||
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
||||
data_list = self._convert_timestamps_to_format([raw])
|
||||
return data_list[0] if data_list else raw
|
||||
return result
|
||||
|
||||
# --- Read All ---
|
||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||
query, qparams = DataConfigRepository.build_select_all(workspace_id)
|
||||
if self.db_conn is None:
|
||||
raise ConnectionError("数据库连接未配置")
|
||||
configs = DataConfigRepository.get_all(self.db, workspace_id)
|
||||
|
||||
cur = self.db_conn.cursor()
|
||||
cur.execute(query, qparams)
|
||||
rows = cur.fetchall()
|
||||
# 如果没有查询到任何配置,返回空列表(这是正常情况,不应抛出异常)
|
||||
if not rows:
|
||||
return []
|
||||
# Map rows to list of dicts
|
||||
columns = [desc[0] for desc in cur.description]
|
||||
data_list = [dict(zip(columns, row)) for row in rows]
|
||||
# 将 UUID 转换为字符串,将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
||||
for item in data_list:
|
||||
if 'workspace_id' in item and item['workspace_id'] is not None:
|
||||
item['workspace_id'] = str(item['workspace_id'])
|
||||
# 将 ORM 对象转换为字典列表
|
||||
data_list = []
|
||||
for config in configs:
|
||||
config_dict = {
|
||||
"config_id": config.config_id,
|
||||
"config_name": config.config_name,
|
||||
"config_desc": config.config_desc,
|
||||
"workspace_id": str(config.workspace_id) if config.workspace_id else None,
|
||||
"group_id": config.group_id,
|
||||
"user_id": config.user_id,
|
||||
"apply_id": config.apply_id,
|
||||
"llm_id": config.llm_id,
|
||||
"embedding_id": config.embedding_id,
|
||||
"rerank_id": config.rerank_id,
|
||||
"llm": config.llm,
|
||||
"enable_llm_dedup_blockwise": config.enable_llm_dedup_blockwise,
|
||||
"enable_llm_disambiguation": config.enable_llm_disambiguation,
|
||||
"deep_retrieval": config.deep_retrieval,
|
||||
"t_type_strict": config.t_type_strict,
|
||||
"t_name_strict": config.t_name_strict,
|
||||
"t_overall": config.t_overall,
|
||||
"state": config.state,
|
||||
"chunker_strategy": config.chunker_strategy,
|
||||
"pruning_enabled": config.pruning_enabled,
|
||||
"pruning_scene": config.pruning_scene,
|
||||
"pruning_threshold": config.pruning_threshold,
|
||||
"enable_self_reflexion": config.enable_self_reflexion,
|
||||
"iteration_period": config.iteration_period,
|
||||
"reflexion_range": config.reflexion_range,
|
||||
"baseline": config.baseline,
|
||||
"statement_granularity": config.statement_granularity,
|
||||
"include_dialogue_context": config.include_dialogue_context,
|
||||
"max_context": config.max_context,
|
||||
"lambda_time": config.lambda_time,
|
||||
"lambda_mem": config.lambda_mem,
|
||||
"offset": config.offset,
|
||||
"created_at": config.created_at,
|
||||
"updated_at": config.updated_at,
|
||||
}
|
||||
data_list.append(config_dict)
|
||||
|
||||
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
|
||||
return self._convert_timestamps_to_format(data_list)
|
||||
|
||||
|
||||
@@ -296,7 +260,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
# 应用内存覆写并刷新常量(在导入主管线前)
|
||||
# 注意:仅在内存中覆写配置,不修改 runtime.json 文件
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
@@ -308,7 +272,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
|
||||
logger.info("[PILOT_RUN] pipeline_main completed")
|
||||
|
||||
|
||||
# 调用自我反思
|
||||
# data = [
|
||||
# {
|
||||
@@ -346,10 +310,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
result_path = settings.get_memory_output_path("extracted_result.json")
|
||||
if not os.path.isfile(result_path):
|
||||
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
|
||||
|
||||
|
||||
with open(result_path, "r", encoding="utf-8") as rf:
|
||||
extracted_result = json.load(rf)
|
||||
|
||||
|
||||
extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None
|
||||
return {
|
||||
"config_id": cid,
|
||||
@@ -405,7 +369,7 @@ async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
DataConfigRepository.SEARCH_FOR_ALL,
|
||||
group_id=end_user_id,
|
||||
)
|
||||
|
||||
|
||||
# 检查结果是否为空或长度不足
|
||||
if not result or len(result) < 4:
|
||||
data = {
|
||||
@@ -418,7 +382,7 @@ async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
|
||||
data = {
|
||||
"total": result[-1]["Count"],
|
||||
"counts": {
|
||||
@@ -504,14 +468,27 @@ async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, An
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_hot_memory_tags(end_user_id: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
async def analytics_hot_memory_tags(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取热门记忆标签,按数量排序并返回前N个
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 获取更多标签供LLM筛选(获取limit*4个标签)
|
||||
raw_limit = limit * 4
|
||||
tags = await get_hot_memory_tags(end_user_id, limit=raw_limit)
|
||||
from app.services.memory_dashboard_service import get_workspace_end_users
|
||||
end_users = get_workspace_end_users(db, workspace_id, current_user)
|
||||
|
||||
tags = []
|
||||
for end_user in end_users:
|
||||
tag = await get_hot_memory_tags(str(end_user.id), limit=raw_limit)
|
||||
if tag:
|
||||
# 将每个用户的标签列表展平到总列表中
|
||||
tags.extend(tag)
|
||||
|
||||
# 按频率降序排序(虽然数据库已经排序,但为了确保正确性再次排序)
|
||||
sorted_tags = sorted(tags, key=lambda x: x[1], reverse=True)
|
||||
|
||||
|
||||
@@ -53,13 +53,13 @@ class ModelParameterMerger:
|
||||
|
||||
# 应用模型配置参数
|
||||
if model_config_params:
|
||||
for key in default_params.keys():
|
||||
for key in default_params:
|
||||
if key in model_config_params:
|
||||
merged[key] = model_config_params[key]
|
||||
|
||||
# 应用 Agent 配置参数(优先级最高)
|
||||
if agent_config_params:
|
||||
for key in default_params.keys():
|
||||
for key in default_params:
|
||||
if key in agent_config_params and agent_config_params[key] is not None:
|
||||
merged[key] = agent_config_params[key]
|
||||
|
||||
@@ -67,7 +67,7 @@ class ModelParameterMerger:
|
||||
merged = {k: v for k, v in merged.items() if v is not None}
|
||||
|
||||
logger.debug(
|
||||
f"参数合并完成",
|
||||
"参数合并完成",
|
||||
extra={
|
||||
"model_params": model_config_params,
|
||||
"agent_params": agent_config_params,
|
||||
|
||||
@@ -24,17 +24,17 @@ class ModelConfigService:
|
||||
"""模型配置服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_id(db: Session, model_id: uuid.UUID) -> ModelConfig:
|
||||
def get_model_by_id(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据ID获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_id(db, model_id)
|
||||
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def get_model_list(db: Session, query: ModelConfigQuery) -> PageData:
|
||||
def get_model_list(db: Session, query: ModelConfigQuery, tenant_id: uuid.UUID | None = None) -> PageData:
|
||||
"""获取模型配置列表"""
|
||||
models, total = ModelConfigRepository.get_list(db, query)
|
||||
models, total = ModelConfigRepository.get_list(db, query, tenant_id=tenant_id)
|
||||
pages = math.ceil(total / query.pagesize) if total > 0 else 0
|
||||
|
||||
return PageData(
|
||||
@@ -48,17 +48,17 @@ class ModelConfigService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str) -> ModelConfig:
|
||||
def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_name(db, name)
|
||||
model = ModelConfigRepository.get_by_name(db, name, tenant_id=tenant_id)
|
||||
if not model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def search_models_by_name(db: Session, name: str, limit: int = 10) -> List[ModelConfig]:
|
||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
|
||||
"""按名称模糊匹配获取模型配置列表"""
|
||||
return ModelConfigRepository.search_by_name(db, name, limit)
|
||||
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
|
||||
|
||||
@staticmethod
|
||||
async def validate_model_config(
|
||||
@@ -220,10 +220,10 @@ class ModelConfigService:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def create_model(db: Session, model_data: ModelConfigCreate) -> ModelConfig:
|
||||
async def create_model(db: Session, model_data: ModelConfigCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建模型配置"""
|
||||
# 检查名称是否已存在
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name):
|
||||
# 检查名称是否已存在(同租户内)
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
# 验证配置
|
||||
@@ -247,6 +247,8 @@ class ModelConfigService:
|
||||
# 事务处理
|
||||
api_key_data = model_data.api_keys
|
||||
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
|
||||
# 添加租户ID
|
||||
model_config_data["tenant_id"] = tenant_id
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush() # 获取生成的 ID
|
||||
@@ -263,28 +265,28 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate) -> ModelConfig:
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""更新模型配置"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id)
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model = ModelConfigRepository.update(db, model_id, model_data)
|
||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def delete_model(db: Session, model_id: uuid.UUID) -> bool:
|
||||
def delete_model(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool:
|
||||
"""删除模型配置"""
|
||||
if not ModelConfigRepository.get_by_id(db, model_id):
|
||||
if not ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id):
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
success = ModelConfigRepository.delete(db, model_id)
|
||||
success = ModelConfigRepository.delete(db, model_id, tenant_id=tenant_id)
|
||||
db.commit()
|
||||
return success
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -116,7 +116,7 @@ class MultiAgentService:
|
||||
self.db.refresh(config)
|
||||
|
||||
logger.info(
|
||||
f"创建多 Agent 配置成功",
|
||||
"创建多 Agent 配置成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"app_id": str(app_id),
|
||||
@@ -320,7 +320,7 @@ class MultiAgentService:
|
||||
self.db.refresh(config)
|
||||
|
||||
logger.info(
|
||||
f"创建多 Agent 配置成功",
|
||||
"创建多 Agent 配置成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"app_id": str(app_id),
|
||||
@@ -363,12 +363,12 @@ class MultiAgentService:
|
||||
# if data.routing_rules is not None:
|
||||
# config.routing_rules = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None
|
||||
|
||||
if data.execution_config is None:
|
||||
execution_config_data = {}
|
||||
elif isinstance(data.execution_config, dict):
|
||||
execution_config_data = convert_uuids_to_str(data.execution_config)
|
||||
else:
|
||||
execution_config_data = convert_uuids_to_str(data.execution_config.model_dump())
|
||||
if data.execution_config is not None:
|
||||
if isinstance(data.execution_config, dict):
|
||||
execution_config_data = convert_uuids_to_str(data.execution_config)
|
||||
else:
|
||||
execution_config_data = convert_uuids_to_str(data.execution_config.model_dump())
|
||||
config.execution_config = execution_config_data
|
||||
|
||||
if data.aggregation_strategy is not None:
|
||||
config.aggregation_strategy = data.aggregation_strategy
|
||||
@@ -380,7 +380,7 @@ class MultiAgentService:
|
||||
self.db.refresh(config)
|
||||
|
||||
logger.info(
|
||||
f"更新多 Agent 配置成功",
|
||||
"更新多 Agent 配置成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"app_id": str(app_id)
|
||||
@@ -399,11 +399,12 @@ class MultiAgentService:
|
||||
if not config:
|
||||
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
|
||||
|
||||
self.db.delete(config)
|
||||
# 逻辑删除多 Agent 配置
|
||||
config.is_active = False
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"删除多 Agent 配置成功",
|
||||
"删除多 Agent 配置成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"app_id": str(app_id)
|
||||
@@ -542,7 +543,7 @@ class MultiAgentService:
|
||||
self.db.refresh(config)
|
||||
|
||||
logger.info(
|
||||
f"添加子 Agent 成功",
|
||||
"添加子 Agent 成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"agent_id": str(agent_id),
|
||||
@@ -586,7 +587,7 @@ class MultiAgentService:
|
||||
self.db.refresh(config)
|
||||
|
||||
logger.info(
|
||||
f"移除子 Agent 成功",
|
||||
"移除子 Agent 成功",
|
||||
extra={
|
||||
"config_id": str(config.id),
|
||||
"agent_id": str(agent_id)
|
||||
|
||||
@@ -92,7 +92,7 @@ class ReleaseShareService:
|
||||
share = self.repo.create(share)
|
||||
|
||||
logger.info(
|
||||
f"创建分享配置",
|
||||
"创建分享配置",
|
||||
extra={
|
||||
"share_id": str(share.id),
|
||||
"release_id": str(release.id),
|
||||
@@ -130,7 +130,7 @@ class ReleaseShareService:
|
||||
share = self.repo.update(share)
|
||||
|
||||
logger.info(
|
||||
f"更新分享配置",
|
||||
"更新分享配置",
|
||||
extra={
|
||||
"share_id": str(share.id),
|
||||
"release_id": str(share.release_id)
|
||||
@@ -214,7 +214,7 @@ class ReleaseShareService:
|
||||
self.repo.delete(share)
|
||||
|
||||
logger.info(
|
||||
f"删除分享配置",
|
||||
"删除分享配置",
|
||||
extra={
|
||||
"share_id": str(share.id),
|
||||
"release_id": str(release_id)
|
||||
@@ -249,7 +249,7 @@ class ReleaseShareService:
|
||||
share = self.repo.update(share)
|
||||
|
||||
logger.info(
|
||||
f"重新生成分享 token",
|
||||
"重新生成分享 token",
|
||||
extra={
|
||||
"share_id": str(share.id),
|
||||
"old_token": old_token,
|
||||
|
||||
@@ -4,7 +4,7 @@ import time
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
@@ -16,6 +16,8 @@ from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
from app.repositories import knowledge_repository
|
||||
import json
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -88,7 +90,7 @@ class SharedChatService:
|
||||
return conversation
|
||||
except ResourceNotFoundException:
|
||||
logger.warning(
|
||||
f"会话不存在,将创建新会话",
|
||||
"会话不存在,将创建新会话",
|
||||
extra={"conversation_id": str(conversation_id)}
|
||||
)
|
||||
|
||||
@@ -102,7 +104,7 @@ class SharedChatService:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"为分享链接创建新会话",
|
||||
"为分享链接创建新会话",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"share_token": share_token,
|
||||
@@ -121,17 +123,24 @@ class SharedChatService:
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
actual_config_id = None
|
||||
config_id=actual_config_id
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
@@ -199,10 +208,11 @@ class SharedChatService:
|
||||
tools.append(kb_tool)
|
||||
|
||||
# 添加长期记忆工具
|
||||
|
||||
memory_flag=False
|
||||
if memory==True:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_flag=True
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
@@ -234,6 +244,7 @@ class SharedChatService:
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
@@ -254,7 +265,11 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
@@ -280,6 +295,7 @@ class SharedChatService:
|
||||
# )
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
return {
|
||||
"conversation_id": conversation.id,
|
||||
@@ -301,7 +317,9 @@ class SharedChatService:
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
memory: bool = True,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
@@ -312,6 +330,9 @@ class SharedChatService:
|
||||
import json
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
@@ -381,9 +402,11 @@ class SharedChatService:
|
||||
tools.append(kb_tool)
|
||||
|
||||
# 添加长期记忆工具
|
||||
memory_flag=False
|
||||
if memory:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_flag = True
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
@@ -440,7 +463,11 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
@@ -464,13 +491,14 @@ class SharedChatService:
|
||||
"usage": {}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
logger.info(
|
||||
f"流式聊天完成",
|
||||
"流式聊天完成",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -539,13 +567,19 @@ class SharedChatService:
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""多 Agent 聊天(非流式)"""
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
@@ -609,6 +643,8 @@ class SharedChatService:
|
||||
"sub_results": result.get("sub_results")
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
return {
|
||||
"conversation_id": conversation.id,
|
||||
@@ -630,11 +666,16 @@ class SharedChatService:
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""多 Agent 聊天(流式)"""
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
@@ -741,13 +782,14 @@ class SharedChatService:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"多 Agent 流式聊天完成",
|
||||
"多 Agent 流式聊天完成",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"elapsed_time": elapsed_time,
|
||||
"message_length": len(full_content)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
except (GeneratorExit, asyncio.CancelledError):
|
||||
# 生成器被关闭或任务被取消,正常退出
|
||||
|
||||
@@ -75,7 +75,7 @@ class SmartRouter:
|
||||
}
|
||||
"""
|
||||
logger.info(
|
||||
f"开始智能路由",
|
||||
"开始智能路由",
|
||||
extra={
|
||||
"message_length": len(message),
|
||||
"conversation_id": conversation_id,
|
||||
@@ -170,7 +170,7 @@ class SmartRouter:
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"路由完成",
|
||||
"路由完成",
|
||||
extra={
|
||||
"agent_id": agent_id,
|
||||
"strategy": strategy,
|
||||
@@ -421,6 +421,6 @@ class SmartRouter:
|
||||
|
||||
# 否则使用第一个子 Agent
|
||||
if self.sub_agents:
|
||||
return list(self.sub_agents.keys())[0]
|
||||
return next(iter(self.sub_agents.keys()))
|
||||
|
||||
return "default-agent"
|
||||
|
||||
@@ -64,7 +64,7 @@ def create_initial_superuser(db: Session):
|
||||
raise BusinessException(
|
||||
f"初始超级用户创建失败: {str(e)}",
|
||||
code=BizCode.DB_ERROR,
|
||||
context={"username": username, "email": email},
|
||||
context={"username": user_in.username, "email": user_in.email},
|
||||
cause=e
|
||||
)
|
||||
|
||||
@@ -423,7 +423,7 @@ def update_last_login_time(db: Session, user_id: uuid.UUID) -> User:
|
||||
business_logger.info(f"用户最后登录时间更新成功: {db_user.username} (ID: {user_id})")
|
||||
return db_user
|
||||
|
||||
except HTTPException:
|
||||
except (BusinessException, PermissionDeniedException):
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"更新用户最后登录时间失败: user_id={user_id} - {str(e)}")
|
||||
@@ -438,19 +438,14 @@ async def change_password(db: Session, user_id: uuid.UUID, old_password: str, ne
|
||||
# 检查权限:只能修改自己的密码
|
||||
if current_user.id != user_id:
|
||||
business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only change your own password"
|
||||
)
|
||||
raise PermissionDeniedException("You can only change your own password")
|
||||
|
||||
try:
|
||||
# 获取用户
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
business_logger.warning(f"用户不存在: {user_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
|
||||
)
|
||||
raise BusinessException("User not found", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 验证旧密码
|
||||
if not verify_password(old_password, db_user.hashed_password):
|
||||
|
||||
@@ -148,7 +148,7 @@ def create_workspace(
|
||||
description=f"工作空间 {workspace.name} 的默认知识库",
|
||||
avatar='',
|
||||
type=KnowledgeType.General,
|
||||
permission_id=PermissionType.Private,
|
||||
permission_id=PermissionType.Memory,
|
||||
embedding_id=uuid.UUID(getenv('KB_embedding_id')) if None else embedding,
|
||||
reranker_id=uuid.UUID(getenv('KB_reranker_id')) if None else rerank,
|
||||
llm_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
|
||||
@@ -459,7 +459,7 @@ def get_workspace_invites(
|
||||
|
||||
def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
||||
"""验证邀请令牌"""
|
||||
business_logger.info(f"验证邀请令牌")
|
||||
business_logger.info("验证邀请令牌")
|
||||
|
||||
# 生成令牌哈希
|
||||
token_hash = hashlib.sha256(token.encode()).hexdigest()
|
||||
@@ -469,7 +469,7 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
|
||||
invite = invite_repo.get_invite_by_token_hash(token_hash)
|
||||
|
||||
if not invite:
|
||||
business_logger.warning(f"邀请令牌无效")
|
||||
business_logger.warning("邀请令牌无效")
|
||||
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||
|
||||
# 检查邀请状态和过期时间
|
||||
@@ -511,7 +511,7 @@ def accept_workspace_invite(
|
||||
invite = invite_repo.get_invite_by_token_hash(token_hash)
|
||||
|
||||
if not invite:
|
||||
business_logger.warning(f"邀请令牌无效")
|
||||
business_logger.warning("邀请令牌无效")
|
||||
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
|
||||
|
||||
# 检查邀请状态
|
||||
@@ -522,7 +522,7 @@ def accept_workspace_invite(
|
||||
# 检查过期时间
|
||||
now = datetime.datetime.now()
|
||||
if invite.expires_at < now:
|
||||
business_logger.warning(f"邀请已过期")
|
||||
business_logger.warning("邀请已过期")
|
||||
# 标记为过期
|
||||
invite_repo.update_invite_status(invite.id, InviteStatus.expired)
|
||||
raise BusinessException("邀请已过期", BizCode.WORKSPACE_INVITE_EXPIRED)
|
||||
@@ -547,7 +547,7 @@ def accept_workspace_invite(
|
||||
)
|
||||
|
||||
if existing_member:
|
||||
business_logger.info(f"用户已是工作空间成员,更新邀请状态")
|
||||
business_logger.info("用户已是工作空间成员,更新邀请状态")
|
||||
invite_repo.update_invite_status(
|
||||
invite.id,
|
||||
InviteStatus.accepted,
|
||||
@@ -739,6 +739,34 @@ def get_workspace_storage_type(
|
||||
return workspace.storage_type
|
||||
|
||||
|
||||
def get_workspace_storage_type_without_auth(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
) -> Optional[str]:
|
||||
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
storage_type: 存储类型字符串,如果未设置则返回 None
|
||||
"""
|
||||
business_logger.info(f"获取工作空间 {workspace_id} 的存储类型(无权限验证)")
|
||||
|
||||
# 查询工作空间
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||
if not workspace:
|
||||
business_logger.error(f"工作空间不存在: workspace_id={workspace_id}")
|
||||
raise BusinessException(
|
||||
code=BizCode.WORKSPACE_NOT_FOUND,
|
||||
message="工作空间不存在"
|
||||
)
|
||||
|
||||
business_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {workspace.storage_type}")
|
||||
return workspace.storage_type
|
||||
|
||||
|
||||
def get_workspace_models_configs(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
|
||||
Reference in New Issue
Block a user