From 42b59a644d667f763d034e6786495d47c336d7c9 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Wed, 28 Jan 2026 12:02:35 +0800 Subject: [PATCH] feat(memory): add protected memory config deletion with end-user safeguards - Add force parameter to delete_config endpoint for controlled deletion of in-use configs - Implement MemoryConfigService.delete_config with protection against deleting default configs - Add validation to prevent deletion of configs with connected end-users unless force=True - Reorganize controller imports to remove duplicates and improve maintainability - Clean up unused database connection management code from memory_storage_controller - Add detailed docstring to delete_config endpoint explaining protection mechanisms - Update error handling with specific BizCode.RESOURCE_IN_USE for configs in active use - Add comprehensive logging for deletion attempts, warnings, and affected users - Refactor ConfigParamsDelete schema usage to use MemoryConfigService directly - Improve API response structure with affected_users count and force_required flag --- api/app/controllers/__init__.py | 6 +- .../controllers/memory_storage_controller.py | 150 ++++++----- api/app/core/error_codes.py | 2 + .../analyzers/dimension_analyzer.py | 4 +- .../analyzers/interest_analyzer.py | 19 +- api/app/models/end_user_model.py | 13 +- api/app/models/memory_config_model.py | 5 +- api/app/repositories/end_user_repository.py | 214 +++++++++++++++- api/app/schemas/implicit_memory_schema.py | 6 - api/app/services/app_service.py | 181 +++++++++++++- api/app/services/memory_agent_service.py | 85 ++++--- api/app/services/memory_config_service.py | 232 +++++++++++++++++- api/app/services/workspace_service.py | 94 ++++++- 13 files changed, 823 insertions(+), 188 deletions(-) diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 3701f14d..f96d0b7e 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -24,9 +24,11 @@ from . import ( memory_episodic_controller, memory_explicit_controller, memory_forget_controller, + memory_perceptual_controller, memory_reflection_controller, memory_short_term_controller, memory_storage_controller, + memory_working_controller, model_controller, multi_agent_controller, prompt_optimizer_controller, @@ -41,10 +43,6 @@ from . import ( user_memory_controllers, workflow_controller, workspace_controller, - memory_forget_controller, - home_page_controller, - memory_perceptual_controller, - memory_working_controller, ) # 创建管理端 API 路由器 diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index f24d2f70..f4773bc3 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -1,7 +1,10 @@ -import os from typing import Optional from uuid import UUID +from fastapi import APIRouter, Depends, Query +from fastapi.responses import StreamingResponse +from sqlalchemy.orm import Session + from app.core.error_codes import BizCode from app.core.logging_config import get_api_logger from app.core.response_utils import fail, success @@ -11,7 +14,6 @@ from app.models.user_model import User from app.schemas.memory_storage_schema import ( ConfigKey, ConfigParamsCreate, - ConfigParamsDelete, ConfigPilotRun, ConfigUpdate, ConfigUpdateExtracted, @@ -31,9 +33,6 @@ from app.services.memory_storage_service import ( search_entity, search_statement, ) -from fastapi import APIRouter, Depends -from fastapi.responses import StreamingResponse -from sqlalchemy.orm import Session # Get API logger api_logger = get_api_logger() @@ -70,68 +69,9 @@ async def get_storage_info( return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) -# --- DB connection dependency --- -_CONN: Optional[object] = None -"""PostgreSQL 连接生成与管理(使用 psycopg2)。""" -# 这个可以转移,可能是已经有的 -# PostgreSQL 数据库连接 -def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接 - host = os.getenv("DB_HOST") - user = os.getenv("DB_USER") - password = os.getenv("DB_PASSWORD") - database = os.getenv("DB_NAME") - port_str = os.getenv("DB_PORT") - try: - import psycopg2 # type: ignore - port = int(port_str) if port_str else 5432 - conn = psycopg2.connect( - host=host or "localhost", - port=port, - user=user, - password=password, - dbname=database, - ) - # 设置自动提交,避免显式事务管理 - conn.autocommit = True - # 设置会话时区为中国标准时间(Asia/Shanghai),便于直接以本地时区展示 - try: - cur = conn.cursor() - cur.execute("SET TIME ZONE 'Asia/Shanghai'") - cur.close() - except Exception: - # 时区设置失败不影响连接,仅记录但不抛出 - pass - return conn - except Exception as e: - try: - print(f"[PostgreSQL] 连接失败: {e}") - except Exception: - pass - return None -def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接 - global _CONN - if _CONN is None: - _CONN = _make_pgsql_conn() - return _CONN - - -def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接 - """Close and recreate the global DB connection.""" - global _CONN - try: - if _CONN: - try: - _CONN.close() - except Exception: - pass - _CONN = _make_pgsql_conn() - return _CONN is not None - except Exception: - _CONN = None - return False @router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认 @@ -139,7 +79,7 @@ def create_config( payload: ConfigParamsCreate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: +) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -162,9 +102,20 @@ def create_config( @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) def delete_config( config_id: UUID, + force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: +) -> dict: + """删除记忆配置(带终端用户保护) + + - 检查是否为默认配置,默认配置不允许删除 + - 检查是否有终端用户连接到该配置 + - 如果有连接且 force=False,返回警告 + - 如果 force=True,清除终端用户引用后删除配置 + + Query Parameters: + force: 设置为 true 可强制删除(即使有终端用户正在使用) + """ workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -172,21 +123,62 @@ def delete_config( api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}") + api_logger.info( + f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: " + f"config_id={config_id}, force={force}" + ) + try: - svc = DataConfigService(db) - result = svc.delete(ConfigParamsDelete(config_id=config_id)) - return success(data=result, msg="删除成功") + # 使用带保护的删除服务 + from app.services.memory_config_service import MemoryConfigService + + config_service = MemoryConfigService(db) + result = config_service.delete_config(config_id=config_id, force=force) + + if result["status"] == "error": + api_logger.warning( + f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}" + ) + return fail( + code=BizCode.FORBIDDEN, + msg=result["message"], + data={"config_id": str(config_id), "is_default": result.get("is_default", False)} + ) + + if result["status"] == "warning": + api_logger.warning( + f"记忆配置正在使用,无法删除: config_id={config_id}, " + f"connected_count={result['connected_count']}" + ) + return fail( + code=BizCode.RESOURCE_IN_USE, + msg=result["message"], + data={ + "connected_count": result["connected_count"], + "force_required": result["force_required"] + } + ) + + api_logger.info( + f"记忆配置删除成功: config_id={config_id}, " + f"affected_users={result['affected_users']}" + ) + return success( + msg=result["message"], + data={"affected_users": result["affected_users"]} + ) + except Exception as e: - api_logger.error(f"Delete config failed: {str(e)}") + api_logger.error(f"Delete config failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e)) + @router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc def update_config( payload: ConfigUpdate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: +) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -209,7 +201,7 @@ def update_config_extracted( payload: ConfigUpdateExtracted, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: +) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -236,7 +228,7 @@ def read_config_extracted( config_id: UUID | int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: +) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -257,7 +249,7 @@ def read_config_extracted( def read_all_config( current_user: User = Depends(get_current_user), db: Session = Depends(get_db), - ) -> dict: +) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -297,9 +289,8 @@ async def pilot_run( }, ) -""" -以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。 -""" + +# ==================== Search & Analytics ==================== @router.get("/search/kb_type_distribution", response_model=ApiResponse) async def get_kb_type_distribution( @@ -439,8 +430,9 @@ async def get_hot_memory_tags_api( try: # 尝试从Redis缓存获取 - from app.aioRedis import aio_redis_get, aio_redis_set import json + + from app.aioRedis import aio_redis_get, aio_redis_set cached_result = await aio_redis_get(cache_key) if cached_result: diff --git a/api/app/core/error_codes.py b/api/app/core/error_codes.py index cb0084b7..3feae4f6 100644 --- a/api/app/core/error_codes.py +++ b/api/app/core/error_codes.py @@ -46,6 +46,7 @@ class BizCode(IntEnum): RESOURCE_ALREADY_EXISTS = 5002 VERSION_ALREADY_EXISTS = 5003 STATE_CONFLICT = 5004 + RESOURCE_IN_USE = 5005 # 应用发布(6xxx) PUBLISH_FAILED = 6001 @@ -125,6 +126,7 @@ HTTP_MAPPING = { BizCode.RESOURCE_ALREADY_EXISTS: 409, BizCode.VERSION_ALREADY_EXISTS: 409, BizCode.STATE_CONFLICT: 409, + BizCode.RESOURCE_IN_USE: 409, BizCode.PUBLISH_FAILED: 500, BizCode.NO_DRAFT_TO_PUBLISH: 400, BizCode.ROLLBACK_TARGET_NOT_FOUND: 400, diff --git a/api/app/core/memory/analytics/implicit_memory/analyzers/dimension_analyzer.py b/api/app/core/memory/analytics/implicit_memory/analyzers/dimension_analyzer.py index 521ac383..e8d728dc 100644 --- a/api/app/core/memory/analytics/implicit_memory/analyzers/dimension_analyzer.py +++ b/api/app/core/memory/analytics/implicit_memory/analyzers/dimension_analyzer.py @@ -108,7 +108,6 @@ class DimensionAnalyzer: # Create dimension portrait portrait = DimensionPortrait( - user_id=user_id, creativity=dimension_scores["creativity"], aesthetic=dimension_scores["aesthetic"], technology=dimension_scores["technology"], @@ -220,7 +219,7 @@ class DimensionAnalyzer: """Create an empty dimension portrait when no data is available. Args: - user_id: Target user ID + user_id: Target user ID (used for logging only) Returns: Empty DimensionPortrait @@ -228,7 +227,6 @@ class DimensionAnalyzer: current_time = datetime.now() return DimensionPortrait( - user_id=user_id, creativity=self._create_default_dimension_score("creativity"), aesthetic=self._create_default_dimension_score("aesthetic"), technology=self._create_default_dimension_score("technology"), diff --git a/api/app/core/memory/analytics/implicit_memory/analyzers/interest_analyzer.py b/api/app/core/memory/analytics/implicit_memory/analyzers/interest_analyzer.py index dc65d740..3b52b372 100644 --- a/api/app/core/memory/analytics/implicit_memory/analyzers/interest_analyzer.py +++ b/api/app/core/memory/analytics/implicit_memory/analyzers/interest_analyzer.py @@ -7,7 +7,7 @@ providing percentage distribution that totals 100%. import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient from app.core.memory.llm_tools.llm_client import LLMClientException @@ -133,7 +133,6 @@ class InterestAnalyzer: # Create interest area distribution distribution = InterestAreaDistribution( - user_id=user_id, tech=interest_categories["tech"], lifestyle=interest_categories["lifestyle"], music=interest_categories["music"], @@ -251,7 +250,7 @@ class InterestAnalyzer: """Create an empty interest distribution when no data is available. Args: - user_id: Target user ID + user_id: Target user ID (used for logging only) Returns: Empty InterestAreaDistribution with equal percentages @@ -259,15 +258,15 @@ class InterestAnalyzer: current_time = datetime.now() equal_percentage = 25.0 # 100% / 4 categories - default_category = lambda name: InterestCategory( - category_name=name, - percentage=equal_percentage, - evidence=["Insufficient data for analysis"], - trending_direction=None - ) + def default_category(name: str) -> InterestCategory: + return InterestCategory( + category_name=name, + percentage=equal_percentage, + evidence=["Insufficient data for analysis"], + trending_direction=None + ) return InterestAreaDistribution( - user_id=user_id, tech=default_category("tech"), lifestyle=default_category("lifestyle"), music=default_category("music"), diff --git a/api/app/models/end_user_model.py b/api/app/models/end_user_model.py index 2839b74b..eff17f7e 100644 --- a/api/app/models/end_user_model.py +++ b/api/app/models/end_user_model.py @@ -1,11 +1,12 @@ import datetime import uuid -from app.db import Base -from sqlalchemy import BigInteger, Column, DateTime, ForeignKey, Integer, String, Text +from sqlalchemy import Column, DateTime, ForeignKey, String, Text from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship +from app.db import Base + class EndUser(Base): __tablename__ = "end_users" @@ -21,7 +22,13 @@ class EndUser(Base): updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now) # Memory config association - updated lazily during conversation - memory_config_id = Column(Integer, ForeignKey("data_config.config_id"), nullable=True, index=True, comment="关联的记忆配置ID") + memory_config_id = Column( + UUID(as_uuid=True), + ForeignKey("memory_config.config_id"), + nullable=True, + index=True, + comment="关联的记忆配置ID" + ) # 用户基本信息字段 position = Column(String, nullable=True, comment="职位") diff --git a/api/app/models/memory_config_model.py b/api/app/models/memory_config_model.py index 454b1b48..816ece79 100644 --- a/api/app/models/memory_config_model.py +++ b/api/app/models/memory_config_model.py @@ -1,6 +1,8 @@ import datetime -from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float + +from sqlalchemy import Boolean, Column, DateTime, Float, Integer, String from sqlalchemy.dialects.postgresql import UUID + from app.db import Base @@ -38,6 +40,7 @@ class MemoryConfig(Base): # 状态配置 state = Column(Boolean, default=False, comment="配置使用状态") + is_default = Column(Boolean, default=False, comment="是否为工作空间默认配置") # 分块策略 chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略") diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index c7d13f8f..522a43b3 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -1,13 +1,13 @@ -from sqlalchemy.orm import Session -from typing import List, Optional -import uuid import datetime +import uuid +from typing import List, Optional -from app.models.end_user_model import EndUser -from app.models.app_model import App -from app.models.workspace_model import Workspace +from sqlalchemy.orm import Session from app.core.logging_config import get_db_logger +from app.models.app_model import App +from app.models.end_user_model import EndUser +from app.models.workspace_model import Workspace # 获取数据库专用日志器 db_logger = get_db_logger() @@ -264,6 +264,179 @@ class EndUserRepository: db_logger.error(f"查询活动工作空间时出错: {str(e)}") raise + def update_memory_config_id(self, end_user_id: uuid.UUID, memory_config_id: uuid.UUID) -> bool: + """更新终端用户的 memory_config_id(懒更新)。 + + Args: + end_user_id: 终端用户ID + memory_config_id: 记忆配置ID + + Returns: + bool: 更新成功返回True,否则返回False + """ + try: + updated_count = ( + self.db.query(EndUser) + .filter(EndUser.id == end_user_id) + .update( + {EndUser.memory_config_id: memory_config_id}, + synchronize_session=False + ) + ) + self.db.commit() + + if updated_count > 0: + db_logger.debug(f"成功更新终端用户 {end_user_id} 的 memory_config_id: {memory_config_id}") + return True + else: + db_logger.warning(f"未找到终端用户 {end_user_id},无法更新 memory_config_id") + return False + except Exception as e: + self.db.rollback() + db_logger.error(f"更新终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}") + raise + + def get_memory_config_id(self, end_user_id: uuid.UUID) -> Optional[uuid.UUID]: + """获取终端用户的 memory_config_id。 + + Args: + end_user_id: 终端用户ID + + Returns: + Optional[uuid.UUID]: memory_config_id 或 None + """ + try: + end_user = ( + self.db.query(EndUser) + .filter(EndUser.id == end_user_id) + .first() + ) + if end_user and end_user.memory_config_id: + db_logger.debug(f"获取终端用户 {end_user_id} 的 memory_config_id: {end_user.memory_config_id}") + return end_user.memory_config_id + return None + except Exception as e: + self.db.rollback() + db_logger.error(f"获取终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}") + raise + + def batch_update_memory_config_id( + self, + app_id: uuid.UUID, + memory_config_id: uuid.UUID + ) -> int: + """批量更新应用下所有终端用户的 memory_config_id + + Args: + app_id: 应用ID + memory_config_id: 新的记忆配置ID + + Returns: + int: 更新的行数 + """ + try: + from sqlalchemy import update + + stmt = ( + update(EndUser) + .where(EndUser.app_id == app_id) + .values(memory_config_id=memory_config_id) + ) + + result = self.db.execute(stmt) + self.db.commit() + + updated_count = result.rowcount + + db_logger.info( + f"批量更新终端用户记忆配置: app_id={app_id}, " + f"memory_config_id={memory_config_id}, updated_count={updated_count}" + ) + + return updated_count + + except Exception as e: + self.db.rollback() + db_logger.error( + f"批量更新终端用户记忆配置时出错: app_id={app_id}, " + f"memory_config_id={memory_config_id}, error={str(e)}" + ) + raise + + def count_by_memory_config_id( + self, + memory_config_id: uuid.UUID + ) -> int: + """统计使用指定记忆配置的终端用户数量 + + Args: + memory_config_id: 记忆配置ID + + Returns: + int: 使用该配置的终端用户数量 + """ + try: + from sqlalchemy import func, select + + stmt = ( + select(func.count(EndUser.id)) + .where(EndUser.memory_config_id == memory_config_id) + ) + + count = self.db.execute(stmt).scalar() or 0 + + db_logger.debug(f"统计记忆配置使用数: memory_config_id={memory_config_id}, count={count}") + + return count + + except Exception as e: + self.db.rollback() + db_logger.error(f"统计记忆配置使用数时出错: memory_config_id={memory_config_id}, error={str(e)}") + raise + + def clear_memory_config_id( + self, + memory_config_id: uuid.UUID + ) -> int: + """清除所有使用指定记忆配置的终端用户的 memory_config_id + + 将 memory_config_id 设置为 NULL + + Args: + memory_config_id: 要清除的记忆配置ID + + Returns: + int: 清除的行数 + """ + try: + from sqlalchemy import update + + stmt = ( + update(EndUser) + .where(EndUser.memory_config_id == memory_config_id) + .values(memory_config_id=None) + ) + + result = self.db.execute(stmt) + self.db.commit() + + cleared_count = result.rowcount + + db_logger.warning( + f"清除终端用户记忆配置引用: memory_config_id={memory_config_id}, " + f"cleared_count={cleared_count}" + ) + + return cleared_count + + except Exception as e: + self.db.rollback() + db_logger.error( + f"清除终端用户记忆配置引用时出错: memory_config_id={memory_config_id}, " + f"error={str(e)}" + ) + raise + def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]: """根据应用ID查询宿主(返回 EndUser ORM 列表)""" repo = EndUserRepository(db) @@ -315,3 +488,32 @@ def get_all_active_workspaces(db: Session) -> List[uuid.UUID]: """获取所有活动工作空间的ID""" repo = EndUserRepository(db) return repo.get_all_active_workspaces() + + +def update_memory_config_id(db: Session, end_user_id: uuid.UUID, memory_config_id: uuid.UUID) -> bool: + """更新终端用户的 memory_config_id(懒更新)。 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID + memory_config_id: 记忆配置ID + + Returns: + bool: 更新成功返回True,否则返回False + """ + repo = EndUserRepository(db) + return repo.update_memory_config_id(end_user_id, memory_config_id) + + +def get_memory_config_id(db: Session, end_user_id: uuid.UUID) -> Optional[uuid.UUID]: + """获取终端用户的 memory_config_id。 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID + + Returns: + Optional[uuid.UUID]: memory_config_id 或 None + """ + repo = EndUserRepository(db) + return repo.get_memory_config_id(end_user_id) diff --git a/api/app/schemas/implicit_memory_schema.py b/api/app/schemas/implicit_memory_schema.py index ced50b92..7ff7824d 100644 --- a/api/app/schemas/implicit_memory_schema.py +++ b/api/app/schemas/implicit_memory_schema.py @@ -112,7 +112,6 @@ class DimensionPortraitResponse(BaseModel): """Four-dimension personality portrait.""" model_config = ConfigDict(from_attributes=True) - user_id: str creativity: DimensionScoreResponse aesthetic: DimensionScoreResponse technology: DimensionScoreResponse @@ -140,7 +139,6 @@ class InterestAreaDistributionResponse(BaseModel): """Distribution of user interests across four areas.""" model_config = ConfigDict(from_attributes=True) - user_id: str tech: InterestCategoryResponse lifestyle: InterestCategoryResponse music: InterestCategoryResponse @@ -184,7 +182,6 @@ class UserProfileResponse(BaseModel): """Comprehensive user profile.""" model_config = ConfigDict(from_attributes=True) - user_id: str preference_tags: List[PreferenceTagResponse] dimension_portrait: DimensionPortraitResponse interest_area_distribution: InterestAreaDistributionResponse @@ -226,7 +223,6 @@ class UserMemorySummary(BaseModel): model_config = ConfigDict(from_attributes=True) summary_id: str - user_id: str user_content: str timestamp: datetime.datetime confidence_score: float = Field(ge=0.0, le=1.0) @@ -241,7 +237,6 @@ class SummaryAnalysisResult(BaseModel): """Result of analyzing memory summaries.""" model_config = ConfigDict(from_attributes=True) - user_id: str preferences: List[PreferenceTagResponse] dimension_evidence: Dict[str, List[str]] interest_evidence: Dict[str, List[str]] @@ -273,7 +268,6 @@ class GenerateProfileRequest(BaseModel): class CompleteProfileResponse(BaseModel): """完整用户画像响应(包含所有模块)""" - user_id: str preferences: List[PreferenceTagResponse] portrait: DimensionPortraitResponse interest_areas: InterestAreaDistributionResponse diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 7ec4bc0e..eaa9bfd6 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -9,28 +9,35 @@ """ import datetime import uuid -from typing import Optional, List, Dict, Any, Tuple, Annotated +from typing import Annotated, Any, Dict, List, Optional, Tuple from fastapi import Depends -from sqlalchemy import select, func, or_, and_ +from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import ( - ResourceNotFoundException, BusinessException, + ResourceNotFoundException, ) from app.core.logging_config import get_business_logger from app.core.workflow.validator import WorkflowValidator from app.db import get_db -from app.models import App, AgentConfig, AppRelease, MultiAgentConfig, WorkflowConfig +from app.models import ( + AgentConfig, + App, + AppRelease, + AppShare, + MultiAgentConfig, + WorkflowConfig, + Workspace, +) from app.models.app_model import AppStatus, AppType from app.repositories.app_repository import get_apps_by_id from app.repositories.workflow_repository import WorkflowConfigRepository from app.schemas import app_schema from app.schemas.workflow_schema import WorkflowConfigUpdate from app.services.agent_config_converter import AgentConfigConverter -from app.models import AppShare, Workspace from app.services.model_service import ModelApiKeyService from app.services.workflow_service import WorkflowService from app.utils.app_config_utils import model_parameters_to_dict @@ -136,9 +143,10 @@ class AppService: return app def _check_workflow_config(self, app_id: uuid.UUID): - from app.models import WorkflowConfig, ModelConfig from sqlalchemy import select + from app.core.exceptions import BusinessException + from app.models import ModelConfig, WorkflowConfig # 2. 获取 Agent 配置 stmt = select(WorkflowConfig).where(AgentConfig.app_id == app_id) agent_cfg = self.db.scalars(stmt).first() @@ -154,9 +162,10 @@ class AppService: raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) def _check_agent_config(self, app_id: uuid.UUID): - from app.models import AgentConfig, ModelConfig from sqlalchemy import select + from app.core.exceptions import BusinessException + from app.models import AgentConfig, ModelConfig # 2. 获取 Agent 配置 stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) agent_cfg = self.db.scalars(stmt).first() @@ -326,10 +335,10 @@ class AppService: """ # 将 Dict 转换为 MultiAgentConfigCreate from app.schemas.multi_agent_schema import ( + ExecutionConfig, MultiAgentConfigCreate, - SubAgentConfig, RoutingRule, - ExecutionConfig + SubAgentConfig, ) # 转换 sub_agents @@ -1167,6 +1176,138 @@ class AppService: return default_config + # ==================== 记忆配置提取方法 ==================== + + def _extract_memory_config_id( + self, + app_type: str, + config: Dict[str, Any] + ) -> Optional[uuid.UUID]: + """从发布配置中提取 memory_config_id(根据应用类型分发) + + Args: + app_type: 应用类型 (agent, workflow, multi_agent) + config: 发布配置字典 + + Returns: + Optional[uuid.UUID]: 提取的 memory_config_id,如果不存在则返回 None + """ + if app_type == AppType.AGENT: + return self._extract_memory_config_id_from_agent(config) + elif app_type == AppType.WORKFLOW: + return self._extract_memory_config_id_from_workflow(config) + elif app_type == AppType.MULTI_AGENT: + # Multi-agent 暂不支持记忆配置提取 + logger.debug(f"多智能体应用暂不支持记忆配置提取: app_type={app_type}") + return None + else: + logger.warning(f"不支持的应用类型,无法提取记忆配置: app_type={app_type}") + return None + + def _extract_memory_config_id_from_agent( + self, + config: Dict[str, Any] + ) -> Optional[uuid.UUID]: + """从 Agent 应用配置中提取 memory_config_id + + 路径: config.memory.memory_content + + Args: + config: Agent 配置字典 + + Returns: + Optional[uuid.UUID]: 记忆配置ID,如果不存在则返回 None + """ + try: + memory_content = config.get("memory", {}).get("memory_content") + if memory_content: + # 处理字符串和 UUID 两种情况 + if isinstance(memory_content, str): + return uuid.UUID(memory_content) + elif isinstance(memory_content, uuid.UUID): + return memory_content + else: + logger.warning( + f"Agent 配置中 memory_content 格式无效: type={type(memory_content)}, " + f"value={memory_content}" + ) + return None + except (ValueError, TypeError) as e: + logger.warning( + f"Agent 配置中 memory_content 格式无效: error={str(e)}, " + f"memory_content={memory_content}" + ) + return None + + def _extract_memory_config_id_from_workflow( + self, + config: Dict[str, Any] + ) -> Optional[uuid.UUID]: + """从 Workflow 应用配置中提取 memory_config_id + + 扫描工作流节点,查找 MemoryRead 或 MemoryWrite 节点。 + 返回第一个找到的记忆节点的 config_id。 + + Args: + config: Workflow 配置字典 + + Returns: + Optional[uuid.UUID]: 记忆配置ID,如果不存在则返回 None + """ + nodes = config.get("nodes", []) + + for node in nodes: + node_type = node.get("type", "") + + # 检查是否为记忆节点 + if node_type in ["MemoryRead", "MemoryWrite"]: + config_id = node.get("config", {}).get("config_id") + + if config_id: + try: + # 处理字符串和 UUID 两种情况 + if isinstance(config_id, str): + return uuid.UUID(config_id) + elif isinstance(config_id, uuid.UUID): + return config_id + else: + logger.warning( + f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, " + f"node_type={node_type}, type={type(config_id)}" + ) + except (ValueError, TypeError) as e: + logger.warning( + f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, " + f"node_type={node_type}, error={str(e)}" + ) + + logger.debug("工作流配置中未找到记忆节点") + return None + + def _update_endusers_memory_config( + self, + app_id: uuid.UUID, + memory_config_id: uuid.UUID + ) -> int: + """批量更新应用下所有终端用户的 memory_config_id + + Args: + app_id: 应用ID + memory_config_id: 新的记忆配置ID + + Returns: + int: 更新的终端用户数量 + """ + from app.repositories.end_user_repository import EndUserRepository + + repo = EndUserRepository(self.db) + updated_count = repo.batch_update_memory_config_id( + app_id=app_id, + memory_config_id=memory_config_id + ) + + return updated_count + # ==================== 应用发布管理 ==================== def publish( @@ -1309,6 +1450,15 @@ class AppService: self.db.add(release) self.db.flush() # 先 flush,确保 release 已插入数据库 + # 提取记忆配置ID并更新终端用户 + memory_config_id = self._extract_memory_config_id(app.type, config) + if memory_config_id: + updated_count = self._update_endusers_memory_config(app_id, memory_config_id) + logger.info( + f"发布时更新终端用户记忆配置: app_id={app_id}, " + f"memory_config_id={memory_config_id}, updated_count={updated_count}" + ) + # 更新当前发布版本指针 app.current_release_id = release.id app.status = AppStatus.ACTIVE @@ -1424,6 +1574,15 @@ class AppService: ) raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}") + # 提取记忆配置ID并更新终端用户 + memory_config_id = self._extract_memory_config_id(release.type, release.config) + if memory_config_id: + updated_count = self._update_endusers_memory_config(app_id, memory_config_id) + logger.info( + f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, " + f"memory_config_id={memory_config_id}, updated_count={updated_count}" + ) + app.current_release_id = release.id app.updated_at = datetime.datetime.now() @@ -1839,8 +1998,8 @@ class AppService: Returns: Dict: 对比结果 """ - from app.services.draft_run_service import DraftRunService from app.models import ModelConfig + from app.services.draft_run_service import DraftRunService logger.info( "多模型对比试运行", @@ -1938,8 +2097,8 @@ class AppService: Yields: str: SSE 格式的事件数据 """ - from app.services.draft_run_service import DraftRunService from app.models import ModelConfig + from app.services.draft_run_service import DraftRunService logger.info( "多模型对比流式试运行", diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 1d6a1cc6..610f804c 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -9,10 +9,15 @@ import os import re import time import uuid -from uuid import UUID from typing import Any, AsyncGenerator, Dict, List, Optional +from uuid import UUID import redis +from langchain_core.messages import AIMessage, HumanMessage +from pydantic import BaseModel, Field +from sqlalchemy import func +from sqlalchemy.orm import Session + from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph @@ -37,11 +42,6 @@ from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) -from langchain_core.messages import AIMessage -from langchain_core.messages import HumanMessage -from pydantic import BaseModel, Field -from sqlalchemy import func -from sqlalchemy.orm import Session try: from app.core.memory.utils.log.audit_logger import audit_logger @@ -732,8 +732,12 @@ class MemoryAgentService: ) # 导入必要的模块 - from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm - from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse + from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( + summary_llm, + ) + from app.core.memory.agent.models.summary_models import ( + RetrieveSummaryResponse, + ) # 构建状态对象 state = { @@ -1144,9 +1148,8 @@ class MemoryAgentService: # LogStreamer uses context manager for file handling, so cleanup is automatic -def get_end_user_memory_config_id(end_user_id: str, db: Session) -> Optional[int]: - """ - 快速获取终端用户的 memory_config_id(直接从 end_user 表读取) +def get_end_user_memory_config_id(end_user_id: str, db: Session) -> Optional[uuid.UUID]: + """快速获取终端用户的 memory_config_id(直接从 end_user 表读取)。 如果 end_user 已有缓存的 memory_config_id,直接返回; 否则返回 None,调用方应使用 get_end_user_connected_config 获取完整配置。 @@ -1156,14 +1159,16 @@ def get_end_user_memory_config_id(end_user_id: str, db: Session) -> Optional[int db: 数据库会话 Returns: - memory_config_id 或 None + Optional[uuid.UUID]: memory_config_id 或 None """ - from app.models.end_user_model import EndUser + from app.repositories.end_user_repository import get_memory_config_id - end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() - if end_user and end_user.memory_config_id: - return end_user.memory_config_id - return None + try: + end_user_uuid = uuid.UUID(end_user_id) if isinstance(end_user_id, str) else end_user_id + return get_memory_config_id(db, end_user_uuid) + except (ValueError, TypeError) as e: + logger.warning(f"Invalid end_user_id format: {end_user_id}, error: {e}") + return None def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]: @@ -1185,9 +1190,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An Raises: ValueError: 当终端用户不存在或应用未发布时 """ + from app.models.app_model import App from app.models.app_release_model import AppRelease from app.models.end_user_model import EndUser - from sqlalchemy import select logger.info(f"Getting connected config for end_user: {end_user_id}") @@ -1200,22 +1205,25 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An app_id = end_user.app_id logger.debug(f"Found end_user app_id: {app_id}") - # 2. 获取该应用的最新发布版本 - stmt = ( - select(AppRelease) - .where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True)) - .order_by(AppRelease.version.desc()) - ) - latest_release = db.scalars(stmt).first() + # 2. 获取应用的当前发布版本(通过 apps.current_release_id) + app = db.query(App).filter(App.id == app_id).first() + if not app: + logger.warning(f"App not found: {app_id}") + raise ValueError(f"应用不存在: {app_id}") - if not latest_release: - logger.warning(f"No active release found for app: {app_id}") + if not app.current_release_id: + logger.warning(f"No current release for app: {app_id}") raise ValueError(f"应用未发布: {app_id}") - logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}") + current_release = db.query(AppRelease).filter(AppRelease.id == app.current_release_id).first() + if not current_release: + logger.warning(f"Current release not found: {app.current_release_id}") + raise ValueError(f"应用发布版本不存在: {app.current_release_id}") + + logger.debug(f"Found current release: version={current_release.version}, id={current_release.id}") # 3. 从 config 中提取 memory_config_id - config = latest_release.config or {} + config = current_release.config or {} # 如果 config 是字符串,解析为字典 if isinstance(config, str): @@ -1223,27 +1231,17 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An try: config = json.loads(config) except json.JSONDecodeError: - logger.warning(f"Failed to parse config JSON for release {latest_release.id}") + logger.warning(f"Failed to parse config JSON for release {current_release.id}") config = {} memory_obj = config.get('memory', {}) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - # 4. 更新 end_user 的 memory_config_id(懒更新) - if memory_config_id is not None and end_user.memory_config_id != memory_config_id: - try: - end_user.memory_config_id = memory_config_id - db.commit() - logger.debug(f"Updated end_user memory_config_id: {end_user_id} -> {memory_config_id}") - except Exception as e: - db.rollback() - logger.warning(f"Failed to update end_user memory_config_id: {e}") - result = { "end_user_id": str(end_user_id), "app_id": str(app_id), - "release_id": str(latest_release.id), - "release_version": latest_release.version, + "release_id": str(current_release.id), + "release_version": current_release.version, "memory_config_id": memory_config_id } @@ -1272,10 +1270,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) ... } """ + from sqlalchemy import select + from app.models.app_release_model import AppRelease from app.models.end_user_model import EndUser from app.models.memory_config_model import MemoryConfig - from sqlalchemy import select logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users") diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index e09cf67f..561bffce 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -6,29 +6,33 @@ This service eliminates code duplication between MemoryAgentService and MemorySt """ import time +import uuid from datetime import datetime -from app.models.memory_config_model import MemoryConfig as MemoryConfigModel +from typing import TYPE_CHECKING, Optional +from uuid import UUID + from sqlalchemy import select +from sqlalchemy.orm import Session + from app.core.logging_config import get_config_logger, get_logger from app.core.validators.memory_config_validators import ( validate_and_resolve_model_id, validate_embedding_model, - validate_model_exists_and_active, ) +from app.models.memory_config_model import MemoryConfig as MemoryConfigModel from app.repositories.memory_config_repository import MemoryConfigRepository from app.schemas.memory_config_schema import ( ConfigurationError, InvalidConfigError, MemoryConfig, - ModelInactiveError, - ModelNotFoundError, ) -from sqlalchemy.orm import Session -from uuid import UUID + +if TYPE_CHECKING: + from app.models.memory_config_model import MemoryConfig as MemoryConfigModel logger = get_logger(__name__) config_logger = get_config_logger() -import uuid + def _validate_config_id(config_id, db: Session = None): """Validate configuration ID format (supports both UUID and integer).""" @@ -320,11 +324,12 @@ class MemoryConfigService: Returns: Dict with model configuration including api_key, base_url, etc. """ + from fastapi import status + from fastapi.exceptions import HTTPException + from app.core.config import settings from app.models.models_model import ModelApiKey from app.services.model_service import ModelConfigService as ModelSvc - from fastapi import status - from fastapi.exceptions import HTTPException config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id) if not config: @@ -353,11 +358,12 @@ class MemoryConfigService: Returns: Dict with embedder configuration including api_key, base_url, etc. """ - from app.models.models_model import ModelApiKey - from app.services.model_service import ModelConfigService as ModelSvc from fastapi import status from fastapi.exceptions import HTTPException + from app.models.models_model import ModelApiKey + from app.services.model_service import ModelConfigService as ModelSvc + config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id) if not config: logger.warning(f"Embedding model ID {embedding_id} not found") @@ -438,3 +444,207 @@ class MemoryConfigService: "pruning_scene": memory_config.pruning_scene, "pruning_threshold": memory_config.pruning_threshold, } + + def _get_workspace_default_config( + self, + workspace_id: UUID + ) -> Optional["MemoryConfigModel"]: + """Get workspace default memory config. + + Returns the config marked as default for the workspace. If no explicit + default exists, falls back to the first active config ordered by creation time. + + Args: + workspace_id: Workspace ID + + Returns: + Optional[MemoryConfigModel]: Default config or None if no configs exist + """ + from sqlalchemy import select + + from app.models.memory_config_model import MemoryConfig as MemoryConfigModel + + # First, try to find the explicitly marked default config + stmt = ( + select(MemoryConfigModel) + .where( + MemoryConfigModel.workspace_id == workspace_id, + MemoryConfigModel.is_default.is_(True), + MemoryConfigModel.state.is_(True), + ) + .limit(1) + ) + + config = self.db.scalars(stmt).first() + + if config: + return config + + # Fallback: get the oldest active config if no explicit default + stmt = ( + select(MemoryConfigModel) + .where( + MemoryConfigModel.workspace_id == workspace_id, + MemoryConfigModel.state.is_(True), + ) + .order_by(MemoryConfigModel.created_at.asc()) + .limit(1) + ) + + config = self.db.scalars(stmt).first() + + if not config: + logger.warning( + "No active memory config found for workspace fallback", + extra={"workspace_id": str(workspace_id)} + ) + + return config + + def get_config_with_fallback( + self, + end_user_id: UUID, + workspace_id: UUID + ) -> Optional["MemoryConfigModel"]: + """Get memory config for end user with fallback to workspace default. + + Implements graceful degradation: if the end user's assigned config + doesn't exist, falls back to the workspace's default active config. + + Args: + end_user_id: End user ID + workspace_id: Workspace ID for fallback lookup + + Returns: + Optional[MemoryConfigModel]: Memory config or None if no fallback available + """ + from app.models.memory_config_model import MemoryConfig as MemoryConfigModel + from app.repositories.end_user_repository import EndUserRepository + + end_user_repo = EndUserRepository(self.db) + end_user = end_user_repo.get_by_id(end_user_id) + + if not end_user or not end_user.memory_config_id: + logger.debug( + "End user has no memory config assigned", + extra={"end_user_id": str(end_user_id)} + ) + return self._get_workspace_default_config(workspace_id) + + config = self.db.get(MemoryConfigModel, end_user.memory_config_id) + + if config: + return config + + logger.warning( + "Memory config not found, falling back to workspace default", + extra={ + "end_user_id": str(end_user_id), + "missing_config_id": str(end_user.memory_config_id), + "workspace_id": str(workspace_id) + } + ) + + fallback_config = self._get_workspace_default_config(workspace_id) + + if fallback_config: + logger.info( + "Using fallback memory config", + extra={ + "end_user_id": str(end_user_id), + "fallback_config_id": str(fallback_config.config_id) + } + ) + + return fallback_config + + def delete_config( + self, + config_id: UUID, + force: bool = False + ) -> dict: + """Delete memory config with protection against in-use configs. + + Implements delete protection: prevents accidental deletion of configs + that are actively being used by end users or marked as default. + + Args: + config_id: Memory config ID to delete + force: If True, delete even if end users are connected + + Returns: + Dict with status, message, and affected_users count + + Raises: + ResourceNotFoundException: If config doesn't exist + """ + from app.core.exceptions import ResourceNotFoundException + from app.models.memory_config_model import MemoryConfig as MemoryConfigModel + + config = self.db.get(MemoryConfigModel, config_id) + if not config: + raise ResourceNotFoundException("MemoryConfig", str(config_id)) + + # Check if this is the default config - default configs cannot be deleted + if config.is_default: + logger.warning( + "Attempted to delete default memory config", + extra={"config_id": str(config_id)} + ) + return { + "status": "error", + "message": "默认配置不允许删除", + "is_default": True + } + + # TODO: add back delete warning + # # Count connected end users + # end_user_repo = EndUserRepository(self.db) + # connected_count = end_user_repo.count_by_memory_config_id(config_id) + + # if connected_count > 0 and not force: + # logger.warning( + # "Attempted to delete memory config with connected end users", + # extra={ + # "config_id": str(config_id), + # "connected_count": connected_count + # } + # ) + + # return { + # "status": "warning", + # "message": f"Cannot delete memory config: {connected_count} end users are using it", + # "connected_count": connected_count, + # "force_required": True + # } + + # # Force delete: clear end user references first + # if connected_count > 0 and force: + # cleared_count = end_user_repo.clear_memory_config_id(config_id) + + # logger.warning( + # "Force deleting memory config", + # extra={ + # "config_id": str(config_id), + # "cleared_end_users": cleared_count + # } + # ) + connected_count = 0 + + self.db.delete(config) + self.db.commit() + + logger.info( + "Memory config deleted", + extra={ + "config_id": str(config_id), + "force": force, + "affected_users": connected_count + } + ) + + return { + "status": "success", + "message": "Memory config deleted successfully", + "affected_users": connected_count + } diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index 34ce0610..eb0fe46b 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -5,32 +5,36 @@ import uuid from os import getenv from typing import List, Optional +from sqlalchemy.orm import Session + from app.core.config import settings from app.core.error_codes import BizCode from app.core.exceptions import BusinessException, PermissionDeniedException from app.core.logging_config import get_business_logger from app.models.user_model import User - -from app.schemas.workspace_schema import ( - WorkspaceModelsUpdate, +from app.models.workspace_model import ( + InviteStatus, + Workspace, + WorkspaceMember, + WorkspaceRole, ) -from sqlalchemy.orm import Session -from app.models.workspace_model import Workspace, WorkspaceRole, InviteStatus, WorkspaceMember from app.repositories import workspace_repository from app.repositories.workspace_invite_repository import WorkspaceInviteRepository from app.schemas.workspace_schema import ( + InviteAcceptRequest, + InviteValidateResponse, WorkspaceCreate, - WorkspaceUpdate, WorkspaceInviteCreate, WorkspaceInviteResponse, - InviteValidateResponse, - InviteAcceptRequest, - WorkspaceMemberUpdate + WorkspaceMemberUpdate, + WorkspaceModelsUpdate, + WorkspaceUpdate, ) # 获取业务逻辑专用日志器 business_logger = get_business_logger() from dotenv import load_dotenv + load_dotenv() def switch_workspace( db: Session, @@ -127,6 +131,27 @@ def create_workspace( db.commit() db.refresh(db_workspace) + # Create default memory config for the workspace (only for neo4j storage types) + if workspace.storage_type == 'neo4j': + try: + _create_default_memory_config( + db=db, + workspace_id=db_workspace.id, + workspace_name=db_workspace.name, + llm_id=llm, + embedding_id=embedding, + rerank_id=rerank, + ) + business_logger.info( + f"为工作空间 {db_workspace.id} 创建默认记忆配置成功" + ) + except Exception as mc_error: + business_logger.error( + f"为工作空间 {db_workspace.id} 创建默认记忆配置失败: {str(mc_error)}" + ) + # Don't fail workspace creation if memory config creation fails + # The workspace can still function without a default memory config + # 如果 storage_type 是 "rag",自动创建知识库 if workspace.storage_type == "rag": business_logger.info( @@ -286,7 +311,7 @@ def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, use ) # 使用统一权限服务检查访问权限 - from app.core.permissions import permission_service, Subject, Resource, Action + from app.core.permissions import Action, Resource, Subject, permission_service # 获取用户的工作空间成员关系 member = workspace_repository.get_member_in_workspace( @@ -324,7 +349,7 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user ) # 使用统一权限服务检查管理权限 - from app.core.permissions import permission_service, Subject, Resource, Action + from app.core.permissions import Action, Resource, Subject, permission_service # 获取用户的工作空间成员关系 member = workspace_repository.get_member_in_workspace( @@ -852,3 +877,50 @@ def update_workspace_models_configs( business_logger.error(f"工作空间模型配置更新失败: workspace_id={workspace_id} - {str(e)}") db.rollback() raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR) + + +def _create_default_memory_config( + db: Session, + workspace_id: uuid.UUID, + workspace_name: str, + llm_id: Optional[uuid.UUID] = None, + embedding_id: Optional[uuid.UUID] = None, + rerank_id: Optional[uuid.UUID] = None, +) -> None: + """Create a default memory config for a newly created workspace. + + Args: + db: Database session + workspace_id: The workspace ID + workspace_name: The workspace name (used for config naming) + llm_id: Optional LLM model ID + embedding_id: Optional embedding model ID + rerank_id: Optional rerank model ID + """ + from app.models.memory_config_model import MemoryConfig + + config_id = uuid.uuid4() + + default_config = MemoryConfig( + config_id=config_id, + config_name=f"{workspace_name} 默认配置", + config_desc="工作空间创建时自动生成的默认记忆配置", + workspace_id=workspace_id, + llm_id=str(llm_id) if llm_id else None, + embedding_id=str(embedding_id) if embedding_id else None, + rerank_id=str(rerank_id) if rerank_id else None, + state=True, # Active by default + is_default=True, # Mark as workspace default + ) + + db.add(default_config) + db.commit() + + business_logger.info( + "Created default memory config for workspace", + extra={ + "workspace_id": str(workspace_id), + "config_id": str(config_id), + "config_name": default_config.config_name, + } + )