feat(memory): Add memory config API controller and end user info endpoints

- Create new memory_config_api_controller.py for dedicated memory configuration management
- Add /end_user/info GET endpoint to retrieve end user information (aliases, metadata)
- Add /end_user/info/update POST endpoint to update end user details
- Move /memory/configs endpoint from memory_api_controller to memory_config_api_controller
- Extract _get_current_user helper function to build user context from API key auth
- Support optional app_id parameter in end user creation with UUID validation
- Update service controller imports with alphabetical ordering and multi-line formatting
- Register memory_config_api_controller router in service module initialization
- Refactor memory_api_controller imports for consistency and clarity
This commit is contained in:
Ke Sun
2026-04-01 15:06:26 +08:00
parent 99ff07ccac
commit 7ce29019f7
7 changed files with 219 additions and 175 deletions

View File

@@ -4,7 +4,17 @@
认证方式: API Key 认证方式: API Key
""" """
from fastapi import APIRouter from fastapi import APIRouter
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller
from . import (
app_api_controller,
end_user_api_controller,
memory_api_controller,
memory_config_api_controller,
rag_api_chunk_controller,
rag_api_document_controller,
rag_api_file_controller,
rag_api_knowledge_controller,
)
# 创建 V1 API 路由器 # 创建 V1 API 路由器
service_router = APIRouter() service_router = APIRouter()
@@ -17,5 +27,6 @@ service_router.include_router(rag_api_file_controller.router)
service_router.include_router(rag_api_chunk_controller.router) service_router.include_router(rag_api_chunk_controller.router)
service_router.include_router(memory_api_controller.router) service_router.include_router(memory_api_controller.router)
service_router.include_router(end_user_api_controller.router) service_router.include_router(end_user_api_controller.router)
service_router.include_router(memory_config_api_controller.router)
__all__ = ["service_router"] __all__ = ["service_router"]

View File

@@ -5,6 +5,7 @@ import uuid
from fastapi import APIRouter, Body, Depends, Request from fastapi import APIRouter, Body, Depends, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.controllers import user_memory_controllers
from app.core.api_key_auth import require_api_key from app.core.api_key_auth import require_api_key
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
@@ -13,13 +14,31 @@ from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
from app.schemas.api_key_schema import ApiKeyAuth from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.end_user_info_schema import EndUserInfoUpdate
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
from app.services import api_key_service
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"]) router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
logger = get_business_logger() logger = get_business_logger()
def _get_current_user(api_key_auth: ApiKeyAuth, db: Session):
"""Build a current_user object from API key auth
Args:
api_key_auth: Validated API key auth info
db: Database session
Returns:
User object with current_workspace_id set
"""
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return current_user
@router.post("/create") @router.post("/create")
@require_api_key(scopes=["memory"]) @require_api_key(scopes=["memory"])
async def create_end_user( async def create_end_user(
@@ -37,6 +56,7 @@ async def create_end_user(
Optionally accepts a memory_config_id to connect the end user to a specific Optionally accepts a memory_config_id to connect the end user to a specific
memory configuration. If not provided, falls back to the workspace default config. memory configuration. If not provided, falls back to the workspace default config.
Optionally accepts an app_id to bind the end user to a specific app.
""" """
body = await request.json() body = await request.json()
payload = CreateEndUserRequest(**body) payload = CreateEndUserRequest(**body)
@@ -71,9 +91,20 @@ async def create_end_user(
else: else:
logger.warning(f"No default memory config found for workspace: {workspace_id}") logger.warning(f"No default memory config found for workspace: {workspace_id}")
# Resolve app_id: explicit from payload, otherwise None
app_id = None
if payload.app_id:
try:
app_id = uuid.UUID(payload.app_id)
except ValueError:
raise BusinessException(
f"Invalid app_id format: {payload.app_id}",
BizCode.INVALID_PARAMETER
)
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
end_user = end_user_repo.get_or_create_end_user_with_config( end_user = end_user_repo.get_or_create_end_user_with_config(
app_id=api_key_auth.resource_id, app_id=app_id,
workspace_id=workspace_id, workspace_id=workspace_id,
other_id=payload.other_id, other_id=payload.other_id,
memory_config_id=memory_config_id, memory_config_id=memory_config_id,
@@ -90,3 +121,50 @@ async def create_end_user(
} }
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
@router.get("/info")
@require_api_key(scopes=["memory"])
async def get_end_user_info(
request: Request,
end_user_id: str,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Get end user info.
Retrieves the info record (aliases, meta_data, etc.) for the specified end user.
Delegates to the manager-side controller for shared logic.
"""
current_user = _get_current_user(api_key_auth, db)
return await user_memory_controllers.get_end_user_info(
end_user_id=end_user_id,
current_user=current_user,
db=db,
)
@router.post("/info/update")
@require_api_key(scopes=["memory"])
async def update_end_user_info(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(None, description="Request body"),
):
"""
Update end user info.
Updates the info record (other_name, aliases, meta_data) for the specified end user.
Delegates to the manager-side controller for shared logic.
"""
body = await request.json()
payload = EndUserInfoUpdate(**body)
current_user = _get_current_user(api_key_auth, db)
return await user_memory_controllers.update_end_user_info(
info_update=payload,
current_user=current_user,
db=db,
)

View File

@@ -1,22 +1,20 @@
"""Memory 服务接口 - 基于 API Key 认证""" """Memory 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Body, Depends, Request
from sqlalchemy.orm import Session
from app.core.api_key_auth import require_api_key from app.core.api_key_auth import require_api_key
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.response_utils import success from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.schemas.api_key_schema import ApiKeyAuth from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_api_schema import ( from app.schemas.memory_api_schema import (
CreateEndUserRequest,
CreateEndUserResponse,
ListConfigsResponse,
MemoryReadRequest, MemoryReadRequest,
MemoryReadResponse, MemoryReadResponse,
MemoryWriteRequest, MemoryWriteRequest,
MemoryWriteResponse, MemoryWriteResponse,
) )
from app.services.memory_api_service import MemoryAPIService from app.services.memory_api_service import MemoryAPIService
from fastapi import APIRouter, Body, Depends, Request
from sqlalchemy.orm import Session
router = APIRouter(prefix="/memory", tags=["V1 - Memory API"]) router = APIRouter(prefix="/memory", tags=["V1 - Memory API"])
logger = get_business_logger() logger = get_business_logger()
@@ -91,55 +89,3 @@ async def read_memory_api_service(
logger.info(f"Memory read successful for end_user: {payload.end_user_id}") logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully") return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
@router.get("/configs")
@require_api_key(scopes=["memory"])
async def list_memory_configs(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
List all memory configs for the workspace.
Returns all available memory configurations associated with the authorized workspace.
"""
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)
result = memory_api_service.list_memory_configs(
workspace_id=api_key_auth.workspace_id,
)
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
@router.post("/end_users")
@require_api_key(scopes=["memory"])
async def create_end_user(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
Create an end user.
Creates a new end user for the authorized workspace.
If an end user with the same other_id already exists, returns the existing one.
"""
body = await request.json()
payload = CreateEndUserRequest(**body)
logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)
result = memory_api_service.create_end_user(
workspace_id=api_key_auth.workspace_id,
other_id=payload.other_id,
)
logger.info(f"End user ready: {result['id']}")
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")

View File

@@ -0,0 +1,39 @@
"""Memory Config 服务接口 - 基于 API Key 认证"""
from fastapi import APIRouter, Depends, Request
from sqlalchemy.orm import Session
from app.core.api_key_auth import require_api_key
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_api_schema import ListConfigsResponse
from app.services.memory_api_service import MemoryAPIService
router = APIRouter(prefix="/memory_config", tags=["V1 - Memory Config API"])
logger = get_business_logger()
@router.get("/configs")
@require_api_key(scopes=["memory"])
async def list_memory_configs(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
List all memory configs for the workspace.
Returns all available memory configurations associated with the authorized workspace.
"""
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)
result = memory_api_service.list_memory_configs(
workspace_id=api_key_auth.workspace_id,
)
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")

View File

@@ -4,11 +4,6 @@
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。 本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
""" """
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
from app.core.memory.storage_services.search.search_strategy import ( from app.core.memory.storage_services.search.search_strategy import (
@@ -29,115 +24,87 @@ __all__ = [
# ============================================================================ # ============================================================================
# 向后兼容的函数式API # 向后兼容的函数式API (DEPRECATED - 未被使用)
# ============================================================================ # ============================================================================
# 为了兼容旧代码,提供与 src/search.py 相同的函数式接口 # 所有调用方均直接使用 app.core.memory.src.search.run_hybrid_search
# 保留注释以备参考
# async def run_hybrid_search(
async def run_hybrid_search( # query_text: str,
query_text: str, # search_type: str = "hybrid",
search_type: str = "hybrid", # end_user_id: str | None = None,
end_user_id: str | None = None, # apply_id: str | None = None,
apply_id: str | None = None, # user_id: str | None = None,
user_id: str | None = None, # limit: int = 50,
limit: int = 50, # include: list[str] | None = None,
include: list[str] | None = None, # alpha: float = 0.6,
alpha: float = 0.6, # use_forgetting_curve: bool = False,
use_forgetting_curve: bool = False, # memory_config: "MemoryConfig" = None,
memory_config: "MemoryConfig" = None, # **kwargs
**kwargs # ) -> dict:
) -> dict: # """运行混合搜索向后兼容的函数式API"""
"""运行混合搜索向后兼容的函数式API # from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
# from app.core.models.base import RedBearModelConfig
这是一个向后兼容的包装函数将旧的函数式API转换为新的基于类的API。 # from app.db import get_db_context
# from app.repositories.neo4j.neo4j_connector import Neo4jConnector
Args: # from app.services.memory_config_service import MemoryConfigService
query_text: 查询文本 #
search_type: 搜索类型("hybrid", "keyword", "semantic" # if not memory_config:
end_user_id: 组ID过滤 # raise ValueError("memory_config is required for search")
apply_id: 应用ID过滤 #
user_id: 用户ID过滤 # connector = Neo4jConnector()
limit: 每个类别的最大结果数 # with get_db_context() as db:
include: 要包含的搜索类别列表 # config_service = MemoryConfigService(db)
alpha: BM25分数权重0.0-1.0 # embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
use_forgetting_curve: 是否使用遗忘曲线 # embedder_config = RedBearModelConfig(**embedder_config_dict)
memory_config: MemoryConfig object containing embedding_model_id # embedder_client = OpenAIEmbedderClient(embedder_config)
**kwargs: 其他参数 #
# try:
Returns: # if search_type == "keyword":
dict: 搜索结果字典格式与旧API兼容 # strategy = KeywordSearchStrategy(connector=connector)
""" # elif search_type == "semantic":
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient # strategy = SemanticSearchStrategy(
from app.core.models.base import RedBearModelConfig # connector=connector,
from app.db import get_db_context # embedder_client=embedder_client
from app.repositories.neo4j.neo4j_connector import Neo4jConnector # )
from app.services.memory_config_service import MemoryConfigService # else:
# strategy = HybridSearchStrategy(
if not memory_config: # connector=connector,
raise ValueError("memory_config is required for search") # embedder_client=embedder_client,
# alpha=alpha,
# 初始化客户端 # use_forgetting_curve=use_forgetting_curve
connector = Neo4jConnector() # )
with get_db_context() as db: #
config_service = MemoryConfigService(db) # result = await strategy.search(
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id)) # query_text=query_text,
embedder_config = RedBearModelConfig(**embedder_config_dict) # end_user_id=end_user_id,
embedder_client = OpenAIEmbedderClient(embedder_config) # limit=limit,
# include=include,
try: # alpha=alpha,
# 根据搜索类型选择策略 # use_forgetting_curve=use_forgetting_curve,
if search_type == "keyword": # **kwargs
strategy = KeywordSearchStrategy(connector=connector) # )
elif search_type == "semantic": #
strategy = SemanticSearchStrategy( # result_dict = result.to_dict()
connector=connector, #
embedder_client=embedder_client # output_path = kwargs.get('output_path', 'search_results.json')
) # if output_path:
else: # hybrid # import json
strategy = HybridSearchStrategy( # import os
connector=connector, # from datetime import datetime
embedder_client=embedder_client, #
alpha=alpha, # try:
use_forgetting_curve=use_forgetting_curve # out_dir = os.path.dirname(output_path)
) # if out_dir:
# os.makedirs(out_dir, exist_ok=True)
# 执行搜索 # with open(output_path, "w", encoding="utf-8") as f:
result = await strategy.search( # json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str)
query_text=query_text, # print(f"Search results saved to {output_path}")
end_user_id=end_user_id, # except Exception as e:
limit=limit, # print(f"Error saving search results: {e}")
include=include, # return result_dict
alpha=alpha, #
use_forgetting_curve=use_forgetting_curve, # finally:
**kwargs # await connector.close()
) #
# __all__.append("run_hybrid_search")
# 转换为旧格式
result_dict = result.to_dict()
# 保存到文件如果指定了output_path
output_path = kwargs.get('output_path', 'search_results.json')
if output_path:
import json
import os
from datetime import datetime
try:
# 确保目录存在
out_dir = os.path.dirname(output_path)
if out_dir:
os.makedirs(out_dir, exist_ok=True)
# 保存结果
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str)
print(f"Search results saved to {output_path}")
except Exception as e:
print(f"Error saving search results: {e}")
return result_dict
finally:
await connector.close()
__all__.append("run_hybrid_search")

View File

@@ -141,10 +141,12 @@ class CreateEndUserRequest(BaseModel):
other_id: External user identifier (required) other_id: External user identifier (required)
other_name: Display name for the end user other_name: Display name for the end user
memory_config_id: Optional memory config ID. If not provided, uses workspace default. memory_config_id: Optional memory config ID. If not provided, uses workspace default.
app_id: Optional app ID to bind the end user to.
""" """
other_id: str = Field(..., description="External user identifier (required)") other_id: str = Field(..., description="External user identifier (required)")
other_name: Optional[str] = Field("", description="Display name") other_name: Optional[str] = Field("", description="Display name")
memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.") memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.")
app_id: Optional[str] = Field(None, description="App ID to bind the end user to")
@field_validator("other_id") @field_validator("other_id")
@classmethod @classmethod

View File

@@ -8,6 +8,8 @@ This service validates inputs and delegates to MemoryAgentService for core memor
import uuid import uuid
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, ResourceNotFoundException from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.logging_config import get_logger from app.core.logging_config import get_logger
@@ -15,7 +17,6 @@ from app.models.app_model import App
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.schemas.memory_config_schema import ConfigurationError from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from sqlalchemy.orm import Session
logger = get_logger(__name__) logger = get_logger(__name__)