[MODIFY] Code optimization
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
"""API Key 管理接口 - 基于 JWT 认证"""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models.user_model import User
|
||||
@@ -10,142 +14,344 @@ from app.core.response_utils import success
|
||||
from app.schemas import api_key_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.api_key_service import ApiKeyService
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.exceptions import (
|
||||
BusinessException,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/apikeys", tags=["API Keys"])
|
||||
logger = get_business_logger()
|
||||
logger = get_api_logger()
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def create_api_key(
|
||||
data: api_key_schema.ApiKeyCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
data: api_key_schema.ApiKeyCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建 API Key
|
||||
|
||||
"""
|
||||
创建 API Key
|
||||
|
||||
- 支持三种类型:app/rag/memory
|
||||
- 创建后返回明文 API Key(仅此一次)
|
||||
- 支持设置权限范围、速率限制、配额等
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key_obj, api_key = ApiKeyService.create_api_key(
|
||||
db,
|
||||
workspace_id=workspace_id,
|
||||
user_id=current_user.id,
|
||||
data=data
|
||||
)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
return success(data=response_data, msg="API Key 创建成功")
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 创建 API Key
|
||||
api_key_obj, api_key = ApiKeyService.create_api_key(
|
||||
db,
|
||||
workspace_id=workspace_id,
|
||||
user_id=current_user.id,
|
||||
data=data
|
||||
)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
return success(data=response_data, msg="API Key 创建成功")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "create_api_key"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"创建API Key失败:{str(e)}")
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def list_api_keys(
|
||||
type: api_key_schema.ApiKeyType = Query(None),
|
||||
is_active: bool = Query(None),
|
||||
resource_id: uuid.UUID = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(10, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
type: api_key_schema.ApiKeyType = Query(None, description="按类型过滤"),
|
||||
is_active: bool = Query(True, description="按状态过滤"),
|
||||
resource_id: uuid.UUID = Query(None, description="按资源过滤"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""列出 API Keys"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
query = api_key_schema.ApiKeyQuery(
|
||||
type=type,
|
||||
is_active=is_active,
|
||||
resource_id=resource_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
result = ApiKeyService.list_api_keys(db, workspace_id, query)
|
||||
return success(data=result)
|
||||
"""
|
||||
列出 API Keys
|
||||
|
||||
- 支持多维度过滤
|
||||
- 支持分页
|
||||
- 自动按创建时间倒序
|
||||
"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
query = api_key_schema.ApiKeyQuery(
|
||||
type=type,
|
||||
is_active=is_active,
|
||||
resource_id=resource_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
result = ApiKeyService.list_api_keys(db, workspace_id, query)
|
||||
|
||||
logger.info("API Keys 查询成功", extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total_count": result.get("total", 0) if isinstance(result, dict) else 0
|
||||
})
|
||||
|
||||
return success(data=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "list_api_keys"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"API Keys 查询失败:{str(e)}")
|
||||
|
||||
|
||||
@router.get("/{api_key_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取 API Key 详情"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key))
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
logger.info("获取API Key详情成功", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "get_api_key"
|
||||
})
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key))
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "get_api_key"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"获取API Key失败: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/{api_key_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def update_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
api_key_id: uuid.UUID,
|
||||
data: api_key_schema.ApiKeyUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新 API Key"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key = ApiKeyService.update_api_key(db, api_key_id, workspace_id, data)
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
||||
"""更新 API Key配置"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_key = ApiKeyService.update_api_key(db, api_key_id, workspace_id, data)
|
||||
|
||||
logger.info("API Key 更新配置成功", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id)
|
||||
})
|
||||
|
||||
return success(data=api_key_schema.ApiKey.model_validate(api_key), msg="API Key 更新成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "update_api_key"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"更新API Key失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{api_key_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def delete_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除 API Key"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
ApiKeyService.delete_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
return success(msg="API Key 删除成功")
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
ApiKeyService.delete_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
logger.info("API Key 删除成功", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id)
|
||||
})
|
||||
|
||||
return success(msg="API Key 删除成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "delete_api_key"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"删除API Key失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{api_key_id}/regenerate", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def regenerate_api_key(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""重新生成 API Key
|
||||
"""
|
||||
重新生成 API Key
|
||||
|
||||
- 生成新的 API Key 并返回明文(仅此一次)
|
||||
- 旧的 API Key 立即失效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
return success(data=response_data, msg="API Key 重新生成成功")
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_key_obj, api_key = ApiKeyService.regenerate_api_key(db, api_key_id, workspace_id)
|
||||
|
||||
# 返回包含明文 Key 的响应
|
||||
response_data = api_key_schema.ApiKeyResponse(
|
||||
**api_key_obj.__dict__,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
logger.info("API Key 重新生成成功", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id)
|
||||
})
|
||||
|
||||
return success(data=response_data, msg="API Key 重新生成成功")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "regenerate_api_key"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"重新生成API Key失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{api_key_id}/stats", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_api_key_stats(
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
api_key_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取 API Key 使用统计"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats = ApiKeyService.get_stats(db, api_key_id, workspace_id)
|
||||
|
||||
return success(data=stats)
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats = ApiKeyService.get_stats(db, api_key_id, workspace_id)
|
||||
|
||||
logger.info("API Key stats retrieved successfully", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id)
|
||||
})
|
||||
|
||||
return success(data=stats)
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "get_api_key_stats"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"获取API Key统计失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/{api_key_id}/logs", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_api_key_logs(
|
||||
api_key_id: uuid.UUID,
|
||||
start_date: Optional[datetime] = Query(None, description="开始日期"),
|
||||
end_date: Optional[datetime] = Query(None, description="结束日期"),
|
||||
status_code: Optional[int] = Query(None, description="HTTP状态码过滤"),
|
||||
endpoint: Optional[str] = Query(None, description="端点路径过滤"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取 API Key 使用日志
|
||||
|
||||
- 支持时间范围过滤
|
||||
- 支持状态码和端点过滤
|
||||
- 按时间倒序返回
|
||||
"""
|
||||
try:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证日期范围
|
||||
if start_date and end_date and start_date > end_date:
|
||||
logger.warning("开始日期晚于结束日期", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat()
|
||||
})
|
||||
raise BusinessException("开始日期不能晚于结束日期", BizCode.INVALID_PARAMETER)
|
||||
|
||||
# 验证状态码
|
||||
if status_code and (status_code < 100 or status_code > 599):
|
||||
logger.warning("查询无效的状态码", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"status_code": status_code
|
||||
})
|
||||
raise BusinessException("无效的HTTP状态码", BizCode.INVALID_PARAMETER)
|
||||
|
||||
# 构建过滤条件
|
||||
filters = {
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"status_code": status_code,
|
||||
"endpoint": endpoint
|
||||
}
|
||||
|
||||
# 调用服务层获取日志
|
||||
result = ApiKeyService.get_logs(
|
||||
db, api_key_id, workspace_id, filters, page, pagesize
|
||||
)
|
||||
|
||||
logger.info("API Key 日志查询成功", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"filters": {k: str(v) if v else None for k, v in filters.items()}
|
||||
})
|
||||
|
||||
return success(data=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}", extra={
|
||||
"api_key_id": str(api_key_id),
|
||||
"workspace_id": str(current_user.current_workspace_id),
|
||||
"user_id": str(current_user.id),
|
||||
"operation": "get_api_key_logs"
|
||||
}, exc_info=True)
|
||||
raise Exception(f"API Key 日志查询失败: {str(e)}")
|
||||
|
||||
@@ -121,7 +121,7 @@ def delete_app(
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
logger.info(
|
||||
f"用户请求删除应用",
|
||||
"用户请求删除应用",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"user_id": str(current_user.id),
|
||||
@@ -151,7 +151,7 @@ def copy_app(
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
logger.info(
|
||||
f"用户请求复制应用",
|
||||
"用户请求复制应用",
|
||||
extra={
|
||||
"source_app_id": str(app_id),
|
||||
"user_id": str(current_user.id),
|
||||
@@ -432,7 +432,7 @@ async def draft_run(
|
||||
|
||||
# 非流式返回
|
||||
logger.debug(
|
||||
f"开始非流式试运行",
|
||||
"开始非流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
@@ -456,7 +456,7 @@ async def draft_run(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"试运行返回结果",
|
||||
"试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict"
|
||||
@@ -466,11 +466,11 @@ async def draft_run(
|
||||
# 验证结果
|
||||
try:
|
||||
validated_result = app_schema.DraftRunResponse.model_validate(result)
|
||||
logger.debug(f"结果验证成功")
|
||||
logger.debug("结果验证成功")
|
||||
return success(data=validated_result)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"结果验证失败",
|
||||
"结果验证失败",
|
||||
extra={
|
||||
"error": str(e),
|
||||
"error_type": str(type(e)),
|
||||
@@ -496,7 +496,7 @@ async def draft_run(
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
f"开始多智能体流式试运行",
|
||||
"开始多智能体流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
@@ -530,7 +530,7 @@ async def draft_run(
|
||||
|
||||
# 4. 非流式返回
|
||||
logger.debug(
|
||||
f"开始多智能体非流式试运行",
|
||||
"开始多智能体非流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
@@ -542,7 +542,7 @@ async def draft_run(
|
||||
result = await multiservice.run(app_id, multi_agent_request)
|
||||
|
||||
logger.debug(
|
||||
f"多智能体试运行返回结果",
|
||||
"多智能体试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
@@ -599,7 +599,7 @@ async def draft_run_compare(
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
logger.info(
|
||||
f"多模型对比试运行",
|
||||
"多模型对比试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"model_count": len(payload.models),
|
||||
@@ -705,7 +705,7 @@ async def draft_run_compare(
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"多模型对比完成",
|
||||
"多模型对比完成",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"successful": result["successful_count"],
|
||||
|
||||
@@ -178,7 +178,7 @@ async def get_chunks(
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing document chunk query")
|
||||
api_logger.debug("Start executing document chunk query")
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
total, items = vector_service.search_by_segment(document_id=str(document_id), query=keywords, pagesize=pagesize, page=page, asc=True)
|
||||
api_logger.info(f"Document chunk query successful: total={total}, returned={len(items)} records")
|
||||
@@ -213,7 +213,9 @@ async def create_chunk(
|
||||
"""
|
||||
create chunk
|
||||
"""
|
||||
api_logger.info(f"Create chunk request: kb_id={kb_id}, document_id={document_id}, content={create_data.content}, username: {current_user.username}")
|
||||
# Obtain the actual content
|
||||
content = create_data.chunk_content
|
||||
api_logger.info(f"Create chunk request: kb_id={kb_id}, document_id={document_id}, content={content}, username: {current_user.username}")
|
||||
|
||||
# 1. Obtain knowledge base information
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
@@ -250,7 +252,7 @@ async def create_chunk(
|
||||
"sort_id": sort_id,
|
||||
"status": 1,
|
||||
}
|
||||
chunk = DocumentChunk(page_content=create_data.content, metadata=metadata)
|
||||
chunk = DocumentChunk(page_content=content, metadata=metadata)
|
||||
# 3. Segmented vector storage
|
||||
vector_service.add_chunks([chunk])
|
||||
|
||||
@@ -305,7 +307,9 @@ async def update_chunk(
|
||||
"""
|
||||
Update document chunk content
|
||||
"""
|
||||
api_logger.info(f"Update document chunk content: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, content={update_data.content}, username: {current_user.username}")
|
||||
# Obtain the actual content
|
||||
content = update_data.chunk_content
|
||||
api_logger.info(f"Update document chunk content: kb_id={kb_id}, document_id={document_id}, doc_id={doc_id}, content={content}, username: {current_user.username}")
|
||||
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=kb_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
@@ -318,7 +322,7 @@ async def update_chunk(
|
||||
total, items = vector_service.get_by_segment(doc_id=doc_id)
|
||||
if total:
|
||||
chunk = items[0]
|
||||
chunk.page_content = update_data.content
|
||||
chunk.page_content = content
|
||||
vector_service.update_by_segment(chunk)
|
||||
return success(data=chunk, msg="The document chunk has been successfully updated")
|
||||
else:
|
||||
|
||||
@@ -78,7 +78,7 @@ async def get_documents(
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing document paging query")
|
||||
api_logger.debug("Start executing document paging query")
|
||||
total, items = document_service.get_documents_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
|
||||
@@ -66,7 +66,7 @@ async def get_files(
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing file paging query")
|
||||
api_logger.debug("Start executing file paging query")
|
||||
total, items = file_service.get_files_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
|
||||
@@ -74,8 +74,6 @@ async def get_knowledges(
|
||||
filters = [
|
||||
knowledge_model.Knowledge.workspace_id == current_user.current_workspace_id
|
||||
]
|
||||
if parent_id:
|
||||
filters.append(knowledge_model.Knowledge.parent_id == parent_id)
|
||||
|
||||
# Keyword search (fuzzy matching of knowledge base name)
|
||||
if keywords:
|
||||
@@ -91,9 +89,14 @@ async def get_knowledges(
|
||||
filters.append(knowledge_model.Knowledge.id.in_(kb_ids.split(',')))
|
||||
else:
|
||||
filters.append(knowledge_model.Knowledge.status != 2)
|
||||
if parent_id:
|
||||
filters.append(knowledge_model.Knowledge.parent_id == parent_id)
|
||||
else:
|
||||
filters.append(knowledge_model.Knowledge.parent_id == current_user.current_workspace_id)
|
||||
filters.append(knowledge_model.Knowledge.permission_id != knowledge_model.PermissionType.Memory)
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing knowledge base paging query")
|
||||
api_logger.debug("Start executing knowledge base paging query")
|
||||
total, items = knowledge_service.get_knowledges_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
|
||||
@@ -58,7 +58,7 @@ async def get_knowledgeshares(
|
||||
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug(f"Start executing knowledge base sharing and paging query")
|
||||
api_logger.debug("Start executing knowledge base sharing and paging query")
|
||||
total, items = knowledgeshare_service.get_knowledgeshares_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
|
||||
@@ -54,7 +54,7 @@ def validate_config_id(config_id: int, db: Session) -> int:
|
||||
ValueError: If config_id is None, invalid, or doesn't exist in database
|
||||
"""
|
||||
if config_id is None:
|
||||
api_logger.info(f"config_id is required but was not provided")
|
||||
api_logger.info("config_id is required but was not provided")
|
||||
config_id = os.getenv('config_id')
|
||||
if config_id is None:
|
||||
raise ValueError("config_id is required but was not provided")
|
||||
@@ -257,7 +257,7 @@ async def write_server(
|
||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning(f"workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
|
||||
@@ -2,11 +2,13 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
from app.repositories.end_user_repository import update_end_user_other_name
|
||||
import uuid
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.memory_agent_schema import End_User_Information
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.app_schema import App as AppSchema
|
||||
|
||||
@@ -41,6 +43,56 @@ def get_workspace_total_end_users(
|
||||
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
|
||||
return success(data=total_end_users, msg="用户数量获取成功")
|
||||
|
||||
@router.post("/update/end_users", response_model=ApiResponse)
|
||||
async def update_workspace_end_users(
|
||||
user_input: End_User_Information,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
更新工作空间的宿主信息
|
||||
"""
|
||||
username = user_input.end_user_name # 要更新的用户名
|
||||
end_user_input_id = user_input.id # 宿主ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息")
|
||||
api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}")
|
||||
|
||||
try:
|
||||
# 导入更新函数
|
||||
from app.repositories.end_user_repository import update_end_user_other_name
|
||||
import uuid
|
||||
|
||||
# 转换 end_user_id 为 UUID 类型
|
||||
end_user_uuid = uuid.UUID(end_user_input_id)
|
||||
|
||||
# 直接更新数据库中的 other_name 字段
|
||||
updated_count = update_end_user_other_name(
|
||||
db=db,
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=username
|
||||
)
|
||||
|
||||
api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}")
|
||||
|
||||
return success(
|
||||
data={
|
||||
"updated_count": updated_count,
|
||||
"end_user_id": end_user_input_id,
|
||||
"updated_other_name": username
|
||||
},
|
||||
msg=f"成功更新 {updated_count} 个宿主的信息"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新宿主信息失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"更新宿主信息失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@router.get("/end_users", response_model=ApiResponse)
|
||||
async def get_workspace_end_users(
|
||||
@@ -53,6 +105,8 @@ async def get_workspace_end_users(
|
||||
返回格式与原 memory_list 接口中的 end_users 字段相同
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
db=db,
|
||||
@@ -61,14 +115,21 @@ async def get_workspace_end_users(
|
||||
)
|
||||
result = []
|
||||
for end_user in end_users:
|
||||
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
|
||||
memory_num = await memory_storage_service.search_all(str(end_user.id))
|
||||
memory_num = {}
|
||||
if current_workspace_type == "neo4j":
|
||||
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
|
||||
memory_num = await memory_storage_service.search_all(str(end_user.id))
|
||||
elif current_workspace_type == "rag":
|
||||
memory_num = {
|
||||
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
|
||||
}
|
||||
result.append(
|
||||
{
|
||||
'end_user':end_user,
|
||||
'memory_num':memory_num
|
||||
}
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
@@ -203,7 +264,7 @@ def get_workspace_memory_list(
|
||||
current_user=current_user,
|
||||
limit=limit
|
||||
)
|
||||
api_logger.info(f"成功获取记忆列表")
|
||||
api_logger.info("成功获取记忆列表")
|
||||
return success(data=memory_list, msg="记忆列表获取成功")
|
||||
|
||||
|
||||
@@ -354,7 +415,7 @@ async def get_chunk_insight(
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取chunk洞察")
|
||||
api_logger.info("成功获取chunk洞察")
|
||||
return success(data=data, msg="chunk洞察获取成功")
|
||||
|
||||
|
||||
@@ -469,7 +530,7 @@ async def dashboard_data(
|
||||
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info(f"成功获取neo4j_data")
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
|
||||
# 如果 storage_type 为 'rag',获取 rag_data
|
||||
elif storage_type == 'rag':
|
||||
@@ -503,9 +564,9 @@ async def dashboard_data(
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
|
||||
result["rag_data"] = rag_data
|
||||
api_logger.info(f"成功获取rag_data")
|
||||
api_logger.info("成功获取rag_data")
|
||||
|
||||
api_logger.info(f"成功获取dashboard整合数据")
|
||||
api_logger.info("成功获取dashboard整合数据")
|
||||
return success(data=result, msg="Dashboard数据获取成功")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from typing import Optional
|
||||
import os
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends, Query, UploadFile
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -62,7 +65,7 @@ async def get_storage_info(
|
||||
Returns:
|
||||
Storage information
|
||||
"""
|
||||
api_logger.info(f"Storage info requested ")
|
||||
api_logger.info("Storage info requested ")
|
||||
try:
|
||||
result = await memory_storage_service.get_storage_info()
|
||||
return success(data=result)
|
||||
@@ -139,6 +142,7 @@ def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
|
||||
def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -151,7 +155,7 @@ def create_config(
|
||||
try:
|
||||
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
|
||||
payload.workspace_id = workspace_id
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.create(payload)
|
||||
return success(data=result, msg="创建成功")
|
||||
except Exception as e:
|
||||
@@ -163,6 +167,7 @@ def create_config(
|
||||
def delete_config(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -173,7 +178,7 @@ def delete_config(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.delete(ConfigParamsDelete(config_id=config_id))
|
||||
return success(data=result, msg="删除成功")
|
||||
except Exception as e:
|
||||
@@ -184,6 +189,7 @@ def delete_config(
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -194,7 +200,7 @@ def update_config(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.update(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
@@ -206,6 +212,7 @@ def update_config(
|
||||
def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -216,7 +223,7 @@ def update_config_extracted(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.update_extracted(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
@@ -229,6 +236,7 @@ def update_config_extracted(
|
||||
def update_config_forget(
|
||||
payload: ConfigUpdateForget,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -239,7 +247,7 @@ def update_config_forget(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.update_forget(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
@@ -251,6 +259,7 @@ def update_config_forget(
|
||||
def read_config_extracted(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -261,7 +270,7 @@ def read_config_extracted(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.get_extracted(ConfigKey(config_id=config_id))
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
@@ -272,6 +281,7 @@ def read_config_extracted(
|
||||
def read_config_forget(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -282,7 +292,7 @@ def read_config_forget(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = svc.get_forget(ConfigKey(config_id=config_id))
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
@@ -292,6 +302,7 @@ def read_config_forget(
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -302,7 +313,7 @@ def read_all_config(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
# 传递 workspace_id 进行过滤(保持为 UUID 类型)
|
||||
result = svc.get_all(workspace_id=workspace_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
@@ -315,6 +326,7 @@ def read_all_config(
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
api_logger.info(f"Pilot run requested: config_id={payload.config_id}, dialogue_text_length={len(payload.dialogue_text)}")
|
||||
|
||||
@@ -330,7 +342,7 @@ async def pilot_run(
|
||||
return fail(BizCode.INTERNAL_ERROR, "配置加载异常", str(e))
|
||||
|
||||
try:
|
||||
svc = DataConfigService(get_db_conn())
|
||||
svc = DataConfigService(db)
|
||||
result = await svc.pilot_run(payload)
|
||||
return success(data=result, msg="试运行完成")
|
||||
except ValueError as e:
|
||||
@@ -475,13 +487,13 @@ async def search_for_entity_graph(
|
||||
|
||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_api(
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Hot memory tags requested for end_user_id: {end_user_id}")
|
||||
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}")
|
||||
try:
|
||||
result = await analytics_hot_memory_tags(end_user_id, limit)
|
||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
|
||||
@@ -46,7 +46,8 @@ def get_model_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取模型配置列表
|
||||
@@ -55,7 +56,7 @@ def get_model_list(
|
||||
- 单个:?type=LLM
|
||||
- 多个:?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}")
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
query = model_schema.ModelConfigQuery(
|
||||
@@ -69,7 +70,7 @@ def get_model_list(
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query)
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
@@ -81,16 +82,17 @@ def get_model_list(
|
||||
@router.get("/{model_id}", response_model=ApiResponse)
|
||||
def get_model_by_id(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
根据ID获取模型配置
|
||||
"""
|
||||
api_logger.info(f"获取模型配置请求: model_id={model_id}")
|
||||
api_logger.info(f"获取模型配置请求: model_id={model_id}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始获取模型配置: model_id={model_id}")
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置获取成功: {result_orm.name}")
|
||||
|
||||
# 将ORM对象转换为Pydantic模型
|
||||
@@ -116,11 +118,11 @@ async def create_model(
|
||||
- 验证失败时会抛出异常,不会创建配置
|
||||
- 可通过 skip_validation=true 跳过验证
|
||||
"""
|
||||
api_logger.info(f"创建模型配置请求: {model_data.name}, 用户: {current_user.username}")
|
||||
api_logger.info(f"创建模型配置请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始创建模型配置: {model_data.name}")
|
||||
result_orm = await ModelConfigService.create_model(db=db, model_data=model_data)
|
||||
result_orm = await ModelConfigService.create_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置创建成功: {result_orm.name} (ID: {result_orm.id})")
|
||||
|
||||
# 将ORM对象转换为Pydantic模型
|
||||
@@ -142,11 +144,11 @@ def update_model(
|
||||
"""
|
||||
更新模型配置
|
||||
"""
|
||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
result_orm = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data)
|
||||
result_orm = ModelConfigService.update_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
# 将ORM对象转换为Pydantic模型
|
||||
@@ -167,11 +169,11 @@ def delete_model(
|
||||
"""
|
||||
删除模型配置
|
||||
"""
|
||||
api_logger.info(f"删除模型配置请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
api_logger.info(f"删除模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始删除模型配置: model_id={model_id}")
|
||||
ModelConfigService.delete_model(db=db, model_id=model_id)
|
||||
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置删除成功: model_id={model_id}")
|
||||
return success(msg="模型配置删除成功")
|
||||
except Exception as e:
|
||||
|
||||
@@ -158,7 +158,7 @@ async def run_multi_agent(
|
||||
|
||||
@router.post(
|
||||
"/{app_id}/multi-agent/test-routing",
|
||||
summary="测试智能路由"
|
||||
summary="测试智能路由(支持 Master Agent 模式)"
|
||||
)
|
||||
async def test_routing(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
@@ -166,19 +166,20 @@ async def test_routing(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""测试智能路由功能
|
||||
"""测试智能路由功能(重构版 - 支持 Master Agent)
|
||||
|
||||
支持三种路由模式:
|
||||
- keyword: 仅使用关键词路由
|
||||
- llm: 使用 LLM 路由(需要提供 routing_model_id)
|
||||
- hybrid: 混合路由(关键词 + LLM)
|
||||
- master_agent: 使用 Master Agent 决策(推荐)
|
||||
- llm_router: 使用旧 LLM 路由器(向后兼容)
|
||||
- rule_only: 仅使用规则路由(最快)
|
||||
|
||||
参数:
|
||||
- message: 测试消息
|
||||
- conversation_id: 会话 ID(可选)
|
||||
- routing_model_id: 路由模型 ID(可选,用于 LLM 路由)
|
||||
- routing_model_id: 路由模型 ID(可选)
|
||||
- use_llm: 是否启用 LLM(默认 False)
|
||||
- keyword_threshold: 关键词置信度阈值(默认 0.8)
|
||||
- force_new: 是否强制重新路由(默认 False)
|
||||
"""
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.services.llm_router import LLMRouter
|
||||
@@ -276,7 +277,161 @@ async def test_routing(
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{app_id}/",
|
||||
"/{app_id}/multi-agent/test-master-agent",
|
||||
summary="测试 Master Agent 决策"
|
||||
)
|
||||
async def test_master_agent(
|
||||
app_id: uuid.UUID = Path(..., description="应用 ID"),
|
||||
request: multi_agent_schema.RoutingTestRequest = ...,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""测试 Master Agent 的路由决策能力
|
||||
|
||||
这个接口专门用于测试新的 Master Agent 路由器,
|
||||
可以看到 Master Agent 的完整决策过程。
|
||||
|
||||
返回信息包括:
|
||||
- 选中的 Agent
|
||||
- 置信度
|
||||
- 决策理由
|
||||
- 是否需要协作
|
||||
- 路由策略(master_agent / rule_fast_path / fallback)
|
||||
"""
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.services.master_agent_router import MasterAgentRouter
|
||||
from app.models import ModelConfig
|
||||
|
||||
# 1. 获取多 Agent 配置
|
||||
service = MultiAgentService(db)
|
||||
config = service.get_config(app_id)
|
||||
|
||||
if not config:
|
||||
return success(
|
||||
data=None,
|
||||
msg="应用未配置多 Agent,无法测试"
|
||||
)
|
||||
|
||||
# 2. 加载 Master Agent
|
||||
from app.models import AppRelease, App
|
||||
|
||||
master_release = db.get(AppRelease, config.master_agent_id)
|
||||
if not master_release:
|
||||
return success(
|
||||
data=None,
|
||||
msg=f"Master Agent 发布版本不存在: {config.master_agent_id}"
|
||||
)
|
||||
|
||||
# 获取应用信息
|
||||
app = db.get(App, master_release.app_id)
|
||||
if not app:
|
||||
return success(
|
||||
data=None,
|
||||
msg=f"应用不存在: {master_release.app_id}"
|
||||
)
|
||||
|
||||
# 创建 Master Agent 代理对象
|
||||
class AgentConfigProxy:
|
||||
def __init__(self, release, app, config_data):
|
||||
self.id = release.id
|
||||
self.app_id = release.app_id
|
||||
self.app = app
|
||||
self.name = release.name
|
||||
self.description = release.description
|
||||
self.system_prompt = config_data.get("system_prompt")
|
||||
self.default_model_config_id = release.default_model_config_id
|
||||
|
||||
config_data = master_release.config or {}
|
||||
master_agent_config = AgentConfigProxy(master_release, app, config_data)
|
||||
|
||||
# 3. 获取 Master Agent 的模型配置
|
||||
master_model_config = db.get(ModelConfig, master_agent_config.default_model_config_id)
|
||||
if not master_model_config:
|
||||
return success(
|
||||
data=None,
|
||||
msg=f"Master Agent 模型配置不存在: {master_agent_config.default_model_config_id}"
|
||||
)
|
||||
|
||||
# 4. 准备子 Agent 信息
|
||||
sub_agents = {}
|
||||
for sub_agent_info in config.sub_agents:
|
||||
agent_id = sub_agent_info["agent_id"]
|
||||
|
||||
# 加载子 Agent
|
||||
sub_release = db.get(AppRelease, uuid.UUID(agent_id))
|
||||
if sub_release:
|
||||
sub_app = db.get(App, sub_release.app_id)
|
||||
sub_config_data = sub_release.config or {}
|
||||
sub_agent_config = AgentConfigProxy(sub_release, sub_app, sub_config_data)
|
||||
|
||||
sub_agents[agent_id] = {
|
||||
"config": sub_agent_config,
|
||||
"info": sub_agent_info
|
||||
}
|
||||
|
||||
# 5. 初始化 Master Agent 路由器
|
||||
state_manager = ConversationStateManager()
|
||||
router = MasterAgentRouter(
|
||||
db=db,
|
||||
master_agent_config=master_agent_config,
|
||||
master_model_config=master_model_config,
|
||||
sub_agents=sub_agents,
|
||||
state_manager=state_manager,
|
||||
enable_rule_fast_path=True
|
||||
)
|
||||
|
||||
# 6. 执行路由决策
|
||||
try:
|
||||
decision = await router.route(
|
||||
message=request.message,
|
||||
conversation_id=str(request.conversation_id) if request.conversation_id else None,
|
||||
variables=None
|
||||
)
|
||||
|
||||
# 7. 获取选中的 Agent 信息
|
||||
agent_id = decision["selected_agent_id"]
|
||||
agent_info = sub_agents.get(agent_id, {}).get("info", {})
|
||||
|
||||
# 8. 构建响应
|
||||
response_data = {
|
||||
"message": request.message,
|
||||
"master_agent": {
|
||||
"name": master_agent_config.name,
|
||||
"model": master_model_config.name
|
||||
},
|
||||
"decision": {
|
||||
"selected_agent_id": agent_id,
|
||||
"selected_agent_name": agent_info.get("name", "未知"),
|
||||
"selected_agent_role": agent_info.get("role", ""),
|
||||
"confidence": decision["confidence"],
|
||||
"reasoning": decision.get("reasoning", ""),
|
||||
"topic": decision.get("topic", ""),
|
||||
"strategy": decision["strategy"],
|
||||
"routing_method": decision.get("routing_method", ""),
|
||||
"need_collaboration": decision.get("need_collaboration", False),
|
||||
"collaboration_agents": decision.get("collaboration_agents", [])
|
||||
},
|
||||
"config_info": {
|
||||
"total_sub_agents": len(sub_agents),
|
||||
"enable_rule_fast_path": True
|
||||
}
|
||||
}
|
||||
|
||||
return success(
|
||||
data=response_data,
|
||||
msg="Master Agent 决策测试成功"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Master Agent 决策测试失败: {str(e)}")
|
||||
return success(
|
||||
data=None,
|
||||
msg=f"测试失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{app_id}/multi-agent/batch-test-routing",
|
||||
summary="批量测试智能路由"
|
||||
)
|
||||
async def batch_test_routing(
|
||||
|
||||
@@ -5,9 +5,10 @@ import uuid
|
||||
import hashlib
|
||||
import time
|
||||
import jwt
|
||||
from app.services import task_service, workspace_service
|
||||
from typing import Optional, Dict
|
||||
from functools import wraps
|
||||
|
||||
from app.dependencies import get_current_superuser, get_current_user, get_current_tenant, workspace_access_guard, cur_workspace_access_guard
|
||||
from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -21,8 +22,10 @@ from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.repositories.app_repository import AppRepository
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.repositories import knowledge_repository
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -95,7 +98,7 @@ def get_access_token(
|
||||
access_token = create_access_token(user_id, share_token)
|
||||
|
||||
logger.info(
|
||||
f"生成访问 token",
|
||||
"生成访问 token",
|
||||
extra={
|
||||
"share_token": share_token,
|
||||
"user_id": user_id
|
||||
@@ -270,7 +273,7 @@ def get_conversation(
|
||||
async def chat(
|
||||
payload: conversation_schema.ChatRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""发送消息并获取回复
|
||||
|
||||
@@ -313,6 +316,45 @@ async def chat(
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
)
|
||||
|
||||
|
||||
appid=share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(App.id == appid).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
db=db,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if storage_type is None:
|
||||
storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
# 获取应用类型
|
||||
app_type = release.app.type if release.app else None
|
||||
|
||||
@@ -339,7 +381,7 @@ async def chat(
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"参数验证完成",
|
||||
"参数验证完成",
|
||||
extra={
|
||||
"share_token": share_token,
|
||||
"app_type": app_type,
|
||||
@@ -365,7 +407,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -388,7 +432,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
@@ -403,7 +449,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -426,7 +474,9 @@ async def chat(
|
||||
variables=payload.variables,
|
||||
password=password,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result))
|
||||
|
||||
@@ -6,7 +6,7 @@ from app.db import get_db
|
||||
from app.core.response_utils import success
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
router = APIRouter(prefix="/v1/apps", tags=["V1 - App API"])
|
||||
router = APIRouter(prefix="/apps", tags=["V1 - App API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get(f"/llm/{{model_id}}", response_model=ApiResponse)
|
||||
@router.get("/llm/{model_id}", response_model=ApiResponse)
|
||||
def test_llm(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
@@ -62,7 +62,7 @@ Answer: Let's think step by step."""
|
||||
raise
|
||||
|
||||
|
||||
@router.get(f"/embedding/{{model_id}}", response_model=ApiResponse)
|
||||
@router.get("/embedding/{model_id}", response_model=ApiResponse)
|
||||
def test_embedding(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
@@ -96,7 +96,7 @@ def test_embedding(
|
||||
return success(msg="测试LLM成功")
|
||||
|
||||
|
||||
@router.get(f"/rerank/{{model_id}}", response_model=ApiResponse)
|
||||
@router.get("/rerank/{model_id}", response_model=ApiResponse)
|
||||
def test_rerank(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db)
|
||||
|
||||
@@ -73,7 +73,7 @@ def get_workspaces(
|
||||
if not include_current and current_user.current_workspace_id:
|
||||
workspaces = [w for w in workspaces if w.id != current_user.current_workspace_id]
|
||||
api_logger.debug(
|
||||
f"过滤掉当前工作空间",
|
||||
"过滤掉当前工作空间",
|
||||
extra={"current_workspace_id": str(current_user.current_workspace_id)}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user