feat(memory): add protected memory config deletion with end-user safeguards

- Add force parameter to delete_config endpoint for controlled deletion of in-use configs
- Implement MemoryConfigService.delete_config with protection against deleting default configs
- Add validation to prevent deletion of configs with connected end-users unless force=True
- Reorganize controller imports to remove duplicates and improve maintainability
- Clean up unused database connection management code from memory_storage_controller
- Add detailed docstring to delete_config endpoint explaining protection mechanisms
- Update error handling with specific BizCode.RESOURCE_IN_USE for configs in active use
- Add comprehensive logging for deletion attempts, warnings, and affected users
- Refactor ConfigParamsDelete schema usage to use MemoryConfigService directly
- Improve API response structure with affected_users count and force_required flag
This commit is contained in:
Ke Sun
2026-01-28 12:02:35 +08:00
parent d9fa9039bb
commit 42b59a644d
13 changed files with 823 additions and 188 deletions

View File

@@ -24,9 +24,11 @@ from . import (
memory_episodic_controller,
memory_explicit_controller,
memory_forget_controller,
memory_perceptual_controller,
memory_reflection_controller,
memory_short_term_controller,
memory_storage_controller,
memory_working_controller,
model_controller,
multi_agent_controller,
prompt_optimizer_controller,
@@ -41,10 +43,6 @@ from . import (
user_memory_controllers,
workflow_controller,
workspace_controller,
memory_forget_controller,
home_page_controller,
memory_perceptual_controller,
memory_working_controller,
)
# 创建管理端 API 路由器

View File

@@ -1,7 +1,10 @@
import os
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
@@ -11,7 +14,6 @@ from app.models.user_model import User
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigPilotRun,
ConfigUpdate,
ConfigUpdateExtracted,
@@ -31,9 +33,6 @@ from app.services.memory_storage_service import (
search_entity,
search_statement,
)
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
# Get API logger
api_logger = get_api_logger()
@@ -70,68 +69,9 @@ async def get_storage_info(
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
# --- DB connection dependency ---
_CONN: Optional[object] = None
"""PostgreSQL 连接生成与管理(使用 psycopg2"""
# 这个可以转移,可能是已经有的
# PostgreSQL 数据库连接
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
host = os.getenv("DB_HOST")
user = os.getenv("DB_USER")
password = os.getenv("DB_PASSWORD")
database = os.getenv("DB_NAME")
port_str = os.getenv("DB_PORT")
try:
import psycopg2 # type: ignore
port = int(port_str) if port_str else 5432
conn = psycopg2.connect(
host=host or "localhost",
port=port,
user=user,
password=password,
dbname=database,
)
# 设置自动提交,避免显式事务管理
conn.autocommit = True
# 设置会话时区为中国标准时间Asia/Shanghai便于直接以本地时区展示
try:
cur = conn.cursor()
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
cur.close()
except Exception:
# 时区设置失败不影响连接,仅记录但不抛出
pass
return conn
except Exception as e:
try:
print(f"[PostgreSQL] 连接失败: {e}")
except Exception:
pass
return None
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
global _CONN
if _CONN is None:
_CONN = _make_pgsql_conn()
return _CONN
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
"""Close and recreate the global DB connection."""
global _CONN
try:
if _CONN:
try:
_CONN.close()
except Exception:
pass
_CONN = _make_pgsql_conn()
return _CONN is not None
except Exception:
_CONN = None
return False
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
@@ -139,7 +79,7 @@ def create_config(
payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -162,9 +102,20 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config(
config_id: UUID,
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
"""删除记忆配置(带终端用户保护)
- 检查是否为默认配置,默认配置不允许删除
- 检查是否有终端用户连接到该配置
- 如果有连接且 force=False返回警告
- 如果 force=True清除终端用户引用后删除配置
Query Parameters:
force: 设置为 true 可强制删除(即使有终端用户正在使用)
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -172,21 +123,62 @@ def delete_config(
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
f"config_id={config_id}, force={force}"
)
try:
svc = DataConfigService(db)
result = svc.delete(ConfigParamsDelete(config_id=config_id))
return success(data=result, msg="删除成功")
# 使用带保护的删除服务
from app.services.memory_config_service import MemoryConfigService
config_service = MemoryConfigService(db)
result = config_service.delete_config(config_id=config_id, force=force)
if result["status"] == "error":
api_logger.warning(
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
)
return fail(
code=BizCode.FORBIDDEN,
msg=result["message"],
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
)
if result["status"] == "warning":
api_logger.warning(
f"记忆配置正在使用,无法删除: config_id={config_id}, "
f"connected_count={result['connected_count']}"
)
return fail(
code=BizCode.RESOURCE_IN_USE,
msg=result["message"],
data={
"connected_count": result["connected_count"],
"force_required": result["force_required"]
}
)
api_logger.info(
f"记忆配置删除成功: config_id={config_id}, "
f"affected_users={result['affected_users']}"
)
return success(
msg=result["message"],
data={"affected_users": result["affected_users"]}
)
except Exception as e:
api_logger.error(f"Delete config failed: {str(e)}")
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
def update_config(
payload: ConfigUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -209,7 +201,7 @@ def update_config_extracted(
payload: ConfigUpdateExtracted,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -236,7 +228,7 @@ def read_config_extracted(
config_id: UUID | int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -257,7 +249,7 @@ def read_config_extracted(
def read_all_config(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -297,9 +289,8 @@ async def pilot_run(
},
)
"""
以下为搜索与分析接口,直接挂载到同一 router统一响应为 ApiResponse。
"""
# ==================== Search & Analytics ====================
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution(
@@ -439,8 +430,9 @@ async def get_hot_memory_tags_api(
try:
# 尝试从Redis缓存获取
from app.aioRedis import aio_redis_get, aio_redis_set
import json
from app.aioRedis import aio_redis_get, aio_redis_set
cached_result = await aio_redis_get(cache_key)
if cached_result:

View File

@@ -46,6 +46,7 @@ class BizCode(IntEnum):
RESOURCE_ALREADY_EXISTS = 5002
VERSION_ALREADY_EXISTS = 5003
STATE_CONFLICT = 5004
RESOURCE_IN_USE = 5005
# 应用发布6xxx
PUBLISH_FAILED = 6001
@@ -125,6 +126,7 @@ HTTP_MAPPING = {
BizCode.RESOURCE_ALREADY_EXISTS: 409,
BizCode.VERSION_ALREADY_EXISTS: 409,
BizCode.STATE_CONFLICT: 409,
BizCode.RESOURCE_IN_USE: 409,
BizCode.PUBLISH_FAILED: 500,
BizCode.NO_DRAFT_TO_PUBLISH: 400,
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,

View File

@@ -108,7 +108,6 @@ class DimensionAnalyzer:
# Create dimension portrait
portrait = DimensionPortrait(
user_id=user_id,
creativity=dimension_scores["creativity"],
aesthetic=dimension_scores["aesthetic"],
technology=dimension_scores["technology"],
@@ -220,7 +219,7 @@ class DimensionAnalyzer:
"""Create an empty dimension portrait when no data is available.
Args:
user_id: Target user ID
user_id: Target user ID (used for logging only)
Returns:
Empty DimensionPortrait
@@ -228,7 +227,6 @@ class DimensionAnalyzer:
current_time = datetime.now()
return DimensionPortrait(
user_id=user_id,
creativity=self._create_default_dimension_score("creativity"),
aesthetic=self._create_default_dimension_score("aesthetic"),
technology=self._create_default_dimension_score("technology"),

View File

@@ -7,7 +7,7 @@ providing percentage distribution that totals 100%.
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
@@ -133,7 +133,6 @@ class InterestAnalyzer:
# Create interest area distribution
distribution = InterestAreaDistribution(
user_id=user_id,
tech=interest_categories["tech"],
lifestyle=interest_categories["lifestyle"],
music=interest_categories["music"],
@@ -251,7 +250,7 @@ class InterestAnalyzer:
"""Create an empty interest distribution when no data is available.
Args:
user_id: Target user ID
user_id: Target user ID (used for logging only)
Returns:
Empty InterestAreaDistribution with equal percentages
@@ -259,15 +258,15 @@ class InterestAnalyzer:
current_time = datetime.now()
equal_percentage = 25.0 # 100% / 4 categories
default_category = lambda name: InterestCategory(
category_name=name,
percentage=equal_percentage,
evidence=["Insufficient data for analysis"],
trending_direction=None
)
def default_category(name: str) -> InterestCategory:
return InterestCategory(
category_name=name,
percentage=equal_percentage,
evidence=["Insufficient data for analysis"],
trending_direction=None
)
return InterestAreaDistribution(
user_id=user_id,
tech=default_category("tech"),
lifestyle=default_category("lifestyle"),
music=default_category("music"),

View File

@@ -1,11 +1,12 @@
import datetime
import uuid
from app.db import Base
from sqlalchemy import BigInteger, Column, DateTime, ForeignKey, Integer, String, Text
from sqlalchemy import Column, DateTime, ForeignKey, String, Text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from app.db import Base
class EndUser(Base):
__tablename__ = "end_users"
@@ -21,7 +22,13 @@ class EndUser(Base):
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
# Memory config association - updated lazily during conversation
memory_config_id = Column(Integer, ForeignKey("data_config.config_id"), nullable=True, index=True, comment="关联的记忆配置ID")
memory_config_id = Column(
UUID(as_uuid=True),
ForeignKey("memory_config.config_id"),
nullable=True,
index=True,
comment="关联的记忆配置ID"
)
# 用户基本信息字段
position = Column(String, nullable=True, comment="职位")

View File

@@ -1,6 +1,8 @@
import datetime
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
from sqlalchemy import Boolean, Column, DateTime, Float, Integer, String
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
@@ -38,6 +40,7 @@ class MemoryConfig(Base):
# 状态配置
state = Column(Boolean, default=False, comment="配置使用状态")
is_default = Column(Boolean, default=False, comment="是否为工作空间默认配置")
# 分块策略
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")

View File

@@ -1,13 +1,13 @@
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
import datetime
import uuid
from typing import List, Optional
from app.models.end_user_model import EndUser
from app.models.app_model import App
from app.models.workspace_model import Workspace
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.workspace_model import Workspace
# 获取数据库专用日志器
db_logger = get_db_logger()
@@ -264,6 +264,179 @@ class EndUserRepository:
db_logger.error(f"查询活动工作空间时出错: {str(e)}")
raise
def update_memory_config_id(self, end_user_id: uuid.UUID, memory_config_id: uuid.UUID) -> bool:
"""更新终端用户的 memory_config_id懒更新
Args:
end_user_id: 终端用户ID
memory_config_id: 记忆配置ID
Returns:
bool: 更新成功返回True否则返回False
"""
try:
updated_count = (
self.db.query(EndUser)
.filter(EndUser.id == end_user_id)
.update(
{EndUser.memory_config_id: memory_config_id},
synchronize_session=False
)
)
self.db.commit()
if updated_count > 0:
db_logger.debug(f"成功更新终端用户 {end_user_id} 的 memory_config_id: {memory_config_id}")
return True
else:
db_logger.warning(f"未找到终端用户 {end_user_id},无法更新 memory_config_id")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"更新终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}")
raise
def get_memory_config_id(self, end_user_id: uuid.UUID) -> Optional[uuid.UUID]:
"""获取终端用户的 memory_config_id。
Args:
end_user_id: 终端用户ID
Returns:
Optional[uuid.UUID]: memory_config_id 或 None
"""
try:
end_user = (
self.db.query(EndUser)
.filter(EndUser.id == end_user_id)
.first()
)
if end_user and end_user.memory_config_id:
db_logger.debug(f"获取终端用户 {end_user_id} 的 memory_config_id: {end_user.memory_config_id}")
return end_user.memory_config_id
return None
except Exception as e:
self.db.rollback()
db_logger.error(f"获取终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}")
raise
def batch_update_memory_config_id(
self,
app_id: uuid.UUID,
memory_config_id: uuid.UUID
) -> int:
"""批量更新应用下所有终端用户的 memory_config_id
Args:
app_id: 应用ID
memory_config_id: 新的记忆配置ID
Returns:
int: 更新的行数
"""
try:
from sqlalchemy import update
stmt = (
update(EndUser)
.where(EndUser.app_id == app_id)
.values(memory_config_id=memory_config_id)
)
result = self.db.execute(stmt)
self.db.commit()
updated_count = result.rowcount
db_logger.info(
f"批量更新终端用户记忆配置: app_id={app_id}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
)
return updated_count
except Exception as e:
self.db.rollback()
db_logger.error(
f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
f"memory_config_id={memory_config_id}, error={str(e)}"
)
raise
def count_by_memory_config_id(
self,
memory_config_id: uuid.UUID
) -> int:
"""统计使用指定记忆配置的终端用户数量
Args:
memory_config_id: 记忆配置ID
Returns:
int: 使用该配置的终端用户数量
"""
try:
from sqlalchemy import func, select
stmt = (
select(func.count(EndUser.id))
.where(EndUser.memory_config_id == memory_config_id)
)
count = self.db.execute(stmt).scalar() or 0
db_logger.debug(f"统计记忆配置使用数: memory_config_id={memory_config_id}, count={count}")
return count
except Exception as e:
self.db.rollback()
db_logger.error(f"统计记忆配置使用数时出错: memory_config_id={memory_config_id}, error={str(e)}")
raise
def clear_memory_config_id(
self,
memory_config_id: uuid.UUID
) -> int:
"""清除所有使用指定记忆配置的终端用户的 memory_config_id
将 memory_config_id 设置为 NULL
Args:
memory_config_id: 要清除的记忆配置ID
Returns:
int: 清除的行数
"""
try:
from sqlalchemy import update
stmt = (
update(EndUser)
.where(EndUser.memory_config_id == memory_config_id)
.values(memory_config_id=None)
)
result = self.db.execute(stmt)
self.db.commit()
cleared_count = result.rowcount
db_logger.warning(
f"清除终端用户记忆配置引用: memory_config_id={memory_config_id}, "
f"cleared_count={cleared_count}"
)
return cleared_count
except Exception as e:
self.db.rollback()
db_logger.error(
f"清除终端用户记忆配置引用时出错: memory_config_id={memory_config_id}, "
f"error={str(e)}"
)
raise
def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
"""根据应用ID查询宿主返回 EndUser ORM 列表)"""
repo = EndUserRepository(db)
@@ -315,3 +488,32 @@ def get_all_active_workspaces(db: Session) -> List[uuid.UUID]:
"""获取所有活动工作空间的ID"""
repo = EndUserRepository(db)
return repo.get_all_active_workspaces()
def update_memory_config_id(db: Session, end_user_id: uuid.UUID, memory_config_id: uuid.UUID) -> bool:
"""更新终端用户的 memory_config_id懒更新
Args:
db: 数据库会话
end_user_id: 终端用户ID
memory_config_id: 记忆配置ID
Returns:
bool: 更新成功返回True否则返回False
"""
repo = EndUserRepository(db)
return repo.update_memory_config_id(end_user_id, memory_config_id)
def get_memory_config_id(db: Session, end_user_id: uuid.UUID) -> Optional[uuid.UUID]:
"""获取终端用户的 memory_config_id。
Args:
db: 数据库会话
end_user_id: 终端用户ID
Returns:
Optional[uuid.UUID]: memory_config_id 或 None
"""
repo = EndUserRepository(db)
return repo.get_memory_config_id(end_user_id)

View File

@@ -112,7 +112,6 @@ class DimensionPortraitResponse(BaseModel):
"""Four-dimension personality portrait."""
model_config = ConfigDict(from_attributes=True)
user_id: str
creativity: DimensionScoreResponse
aesthetic: DimensionScoreResponse
technology: DimensionScoreResponse
@@ -140,7 +139,6 @@ class InterestAreaDistributionResponse(BaseModel):
"""Distribution of user interests across four areas."""
model_config = ConfigDict(from_attributes=True)
user_id: str
tech: InterestCategoryResponse
lifestyle: InterestCategoryResponse
music: InterestCategoryResponse
@@ -184,7 +182,6 @@ class UserProfileResponse(BaseModel):
"""Comprehensive user profile."""
model_config = ConfigDict(from_attributes=True)
user_id: str
preference_tags: List[PreferenceTagResponse]
dimension_portrait: DimensionPortraitResponse
interest_area_distribution: InterestAreaDistributionResponse
@@ -226,7 +223,6 @@ class UserMemorySummary(BaseModel):
model_config = ConfigDict(from_attributes=True)
summary_id: str
user_id: str
user_content: str
timestamp: datetime.datetime
confidence_score: float = Field(ge=0.0, le=1.0)
@@ -241,7 +237,6 @@ class SummaryAnalysisResult(BaseModel):
"""Result of analyzing memory summaries."""
model_config = ConfigDict(from_attributes=True)
user_id: str
preferences: List[PreferenceTagResponse]
dimension_evidence: Dict[str, List[str]]
interest_evidence: Dict[str, List[str]]
@@ -273,7 +268,6 @@ class GenerateProfileRequest(BaseModel):
class CompleteProfileResponse(BaseModel):
"""完整用户画像响应(包含所有模块)"""
user_id: str
preferences: List[PreferenceTagResponse]
portrait: DimensionPortraitResponse
interest_areas: InterestAreaDistributionResponse

View File

@@ -9,28 +9,35 @@
"""
import datetime
import uuid
from typing import Optional, List, Dict, Any, Tuple, Annotated
from typing import Annotated, Any, Dict, List, Optional, Tuple
from fastapi import Depends
from sqlalchemy import select, func, or_, and_
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import (
ResourceNotFoundException,
BusinessException,
ResourceNotFoundException,
)
from app.core.logging_config import get_business_logger
from app.core.workflow.validator import WorkflowValidator
from app.db import get_db
from app.models import App, AgentConfig, AppRelease, MultiAgentConfig, WorkflowConfig
from app.models import (
AgentConfig,
App,
AppRelease,
AppShare,
MultiAgentConfig,
WorkflowConfig,
Workspace,
)
from app.models.app_model import AppStatus, AppType
from app.repositories.app_repository import get_apps_by_id
from app.repositories.workflow_repository import WorkflowConfigRepository
from app.schemas import app_schema
from app.schemas.workflow_schema import WorkflowConfigUpdate
from app.services.agent_config_converter import AgentConfigConverter
from app.models import AppShare, Workspace
from app.services.model_service import ModelApiKeyService
from app.services.workflow_service import WorkflowService
from app.utils.app_config_utils import model_parameters_to_dict
@@ -136,9 +143,10 @@ class AppService:
return app
def _check_workflow_config(self, app_id: uuid.UUID):
from app.models import WorkflowConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
from app.models import ModelConfig, WorkflowConfig
# 2. 获取 Agent 配置
stmt = select(WorkflowConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
@@ -154,9 +162,10 @@ class AppService:
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
def _check_agent_config(self, app_id: uuid.UUID):
from app.models import AgentConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
from app.models import AgentConfig, ModelConfig
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
@@ -326,10 +335,10 @@ class AppService:
"""
# 将 Dict 转换为 MultiAgentConfigCreate
from app.schemas.multi_agent_schema import (
ExecutionConfig,
MultiAgentConfigCreate,
SubAgentConfig,
RoutingRule,
ExecutionConfig
SubAgentConfig,
)
# 转换 sub_agents
@@ -1167,6 +1176,138 @@ class AppService:
return default_config
# ==================== 记忆配置提取方法 ====================
def _extract_memory_config_id(
self,
app_type: str,
config: Dict[str, Any]
) -> Optional[uuid.UUID]:
"""从发布配置中提取 memory_config_id根据应用类型分发
Args:
app_type: 应用类型 (agent, workflow, multi_agent)
config: 发布配置字典
Returns:
Optional[uuid.UUID]: 提取的 memory_config_id如果不存在则返回 None
"""
if app_type == AppType.AGENT:
return self._extract_memory_config_id_from_agent(config)
elif app_type == AppType.WORKFLOW:
return self._extract_memory_config_id_from_workflow(config)
elif app_type == AppType.MULTI_AGENT:
# Multi-agent 暂不支持记忆配置提取
logger.debug(f"多智能体应用暂不支持记忆配置提取: app_type={app_type}")
return None
else:
logger.warning(f"不支持的应用类型,无法提取记忆配置: app_type={app_type}")
return None
def _extract_memory_config_id_from_agent(
self,
config: Dict[str, Any]
) -> Optional[uuid.UUID]:
"""从 Agent 应用配置中提取 memory_config_id
路径: config.memory.memory_content
Args:
config: Agent 配置字典
Returns:
Optional[uuid.UUID]: 记忆配置ID如果不存在则返回 None
"""
try:
memory_content = config.get("memory", {}).get("memory_content")
if memory_content:
# 处理字符串和 UUID 两种情况
if isinstance(memory_content, str):
return uuid.UUID(memory_content)
elif isinstance(memory_content, uuid.UUID):
return memory_content
else:
logger.warning(
f"Agent 配置中 memory_content 格式无效: type={type(memory_content)}, "
f"value={memory_content}"
)
return None
except (ValueError, TypeError) as e:
logger.warning(
f"Agent 配置中 memory_content 格式无效: error={str(e)}, "
f"memory_content={memory_content}"
)
return None
def _extract_memory_config_id_from_workflow(
self,
config: Dict[str, Any]
) -> Optional[uuid.UUID]:
"""从 Workflow 应用配置中提取 memory_config_id
扫描工作流节点,查找 MemoryRead 或 MemoryWrite 节点。
返回第一个找到的记忆节点的 config_id。
Args:
config: Workflow 配置字典
Returns:
Optional[uuid.UUID]: 记忆配置ID如果不存在则返回 None
"""
nodes = config.get("nodes", [])
for node in nodes:
node_type = node.get("type", "")
# 检查是否为记忆节点
if node_type in ["MemoryRead", "MemoryWrite"]:
config_id = node.get("config", {}).get("config_id")
if config_id:
try:
# 处理字符串和 UUID 两种情况
if isinstance(config_id, str):
return uuid.UUID(config_id)
elif isinstance(config_id, uuid.UUID):
return config_id
else:
logger.warning(
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
f"node_type={node_type}, type={type(config_id)}"
)
except (ValueError, TypeError) as e:
logger.warning(
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
f"node_type={node_type}, error={str(e)}"
)
logger.debug("工作流配置中未找到记忆节点")
return None
def _update_endusers_memory_config(
self,
app_id: uuid.UUID,
memory_config_id: uuid.UUID
) -> int:
"""批量更新应用下所有终端用户的 memory_config_id
Args:
app_id: 应用ID
memory_config_id: 新的记忆配置ID
Returns:
int: 更新的终端用户数量
"""
from app.repositories.end_user_repository import EndUserRepository
repo = EndUserRepository(self.db)
updated_count = repo.batch_update_memory_config_id(
app_id=app_id,
memory_config_id=memory_config_id
)
return updated_count
# ==================== 应用发布管理 ====================
def publish(
@@ -1309,6 +1450,15 @@ class AppService:
self.db.add(release)
self.db.flush() # 先 flush确保 release 已插入数据库
# 提取记忆配置ID并更新终端用户
memory_config_id = self._extract_memory_config_id(app.type, config)
if memory_config_id:
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
logger.info(
f"发布时更新终端用户记忆配置: app_id={app_id}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
)
# 更新当前发布版本指针
app.current_release_id = release.id
app.status = AppStatus.ACTIVE
@@ -1424,6 +1574,15 @@ class AppService:
)
raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}")
# 提取记忆配置ID并更新终端用户
memory_config_id = self._extract_memory_config_id(release.type, release.config)
if memory_config_id:
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
logger.info(
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
)
app.current_release_id = release.id
app.updated_at = datetime.datetime.now()
@@ -1839,8 +1998,8 @@ class AppService:
Returns:
Dict: 对比结果
"""
from app.services.draft_run_service import DraftRunService
from app.models import ModelConfig
from app.services.draft_run_service import DraftRunService
logger.info(
"多模型对比试运行",
@@ -1938,8 +2097,8 @@ class AppService:
Yields:
str: SSE 格式的事件数据
"""
from app.services.draft_run_service import DraftRunService
from app.models import ModelConfig
from app.services.draft_run_service import DraftRunService
logger.info(
"多模型对比流式试运行",

View File

@@ -9,10 +9,15 @@ import os
import re
import time
import uuid
from uuid import UUID
from typing import Any, AsyncGenerator, Dict, List, Optional
from uuid import UUID
import redis
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
@@ -37,11 +42,6 @@ from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
try:
from app.core.memory.utils.log.audit_logger import audit_logger
@@ -732,8 +732,12 @@ class MemoryAgentService:
)
# 导入必要的模块
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm
from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
summary_llm,
)
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
)
# 构建状态对象
state = {
@@ -1144,9 +1148,8 @@ class MemoryAgentService:
# LogStreamer uses context manager for file handling, so cleanup is automatic
def get_end_user_memory_config_id(end_user_id: str, db: Session) -> Optional[int]:
"""
快速获取终端用户的 memory_config_id直接从 end_user 表读取)
def get_end_user_memory_config_id(end_user_id: str, db: Session) -> Optional[uuid.UUID]:
"""快速获取终端用户的 memory_config_id直接从 end_user 表读取)。
如果 end_user 已有缓存的 memory_config_id直接返回
否则返回 None调用方应使用 get_end_user_connected_config 获取完整配置。
@@ -1156,14 +1159,16 @@ def get_end_user_memory_config_id(end_user_id: str, db: Session) -> Optional[int
db: 数据库会话
Returns:
memory_config_id 或 None
Optional[uuid.UUID]: memory_config_id 或 None
"""
from app.models.end_user_model import EndUser
from app.repositories.end_user_repository import get_memory_config_id
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if end_user and end_user.memory_config_id:
return end_user.memory_config_id
return None
try:
end_user_uuid = uuid.UUID(end_user_id) if isinstance(end_user_id, str) else end_user_id
return get_memory_config_id(db, end_user_uuid)
except (ValueError, TypeError) as e:
logger.warning(f"Invalid end_user_id format: {end_user_id}, error: {e}")
return None
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
@@ -1185,9 +1190,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
Raises:
ValueError: 当终端用户不存在或应用未发布时
"""
from app.models.app_model import App
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from sqlalchemy import select
logger.info(f"Getting connected config for end_user: {end_user_id}")
@@ -1200,22 +1205,25 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
app_id = end_user.app_id
logger.debug(f"Found end_user app_id: {app_id}")
# 2. 获取应用的最新发布版本
stmt = (
select(AppRelease)
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
.order_by(AppRelease.version.desc())
)
latest_release = db.scalars(stmt).first()
# 2. 获取应用的当前发布版本(通过 apps.current_release_id
app = db.query(App).filter(App.id == app_id).first()
if not app:
logger.warning(f"App not found: {app_id}")
raise ValueError(f"应用不存在: {app_id}")
if not latest_release:
logger.warning(f"No active release found for app: {app_id}")
if not app.current_release_id:
logger.warning(f"No current release for app: {app_id}")
raise ValueError(f"应用未发布: {app_id}")
logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")
current_release = db.query(AppRelease).filter(AppRelease.id == app.current_release_id).first()
if not current_release:
logger.warning(f"Current release not found: {app.current_release_id}")
raise ValueError(f"应用发布版本不存在: {app.current_release_id}")
logger.debug(f"Found current release: version={current_release.version}, id={current_release.id}")
# 3. 从 config 中提取 memory_config_id
config = latest_release.config or {}
config = current_release.config or {}
# 如果 config 是字符串,解析为字典
if isinstance(config, str):
@@ -1223,27 +1231,17 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
try:
config = json.loads(config)
except json.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
logger.warning(f"Failed to parse config JSON for release {current_release.id}")
config = {}
memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
# 4. 更新 end_user 的 memory_config_id懒更新
if memory_config_id is not None and end_user.memory_config_id != memory_config_id:
try:
end_user.memory_config_id = memory_config_id
db.commit()
logger.debug(f"Updated end_user memory_config_id: {end_user_id} -> {memory_config_id}")
except Exception as e:
db.rollback()
logger.warning(f"Failed to update end_user memory_config_id: {e}")
result = {
"end_user_id": str(end_user_id),
"app_id": str(app_id),
"release_id": str(latest_release.id),
"release_version": latest_release.version,
"release_id": str(current_release.id),
"release_version": current_release.version,
"memory_config_id": memory_config_id
}
@@ -1272,10 +1270,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
...
}
"""
from sqlalchemy import select
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.memory_config_model import MemoryConfig
from sqlalchemy import select
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")

View File

@@ -6,29 +6,33 @@ This service eliminates code duplication between MemoryAgentService and MemorySt
"""
import time
import uuid
from datetime import datetime
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
from typing import TYPE_CHECKING, Optional
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.logging_config import get_config_logger, get_logger
from app.core.validators.memory_config_validators import (
validate_and_resolve_model_id,
validate_embedding_model,
validate_model_exists_and_active,
)
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.schemas.memory_config_schema import (
ConfigurationError,
InvalidConfigError,
MemoryConfig,
ModelInactiveError,
ModelNotFoundError,
)
from sqlalchemy.orm import Session
from uuid import UUID
if TYPE_CHECKING:
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
logger = get_logger(__name__)
config_logger = get_config_logger()
import uuid
def _validate_config_id(config_id, db: Session = None):
"""Validate configuration ID format (supports both UUID and integer)."""
@@ -320,11 +324,12 @@ class MemoryConfigService:
Returns:
Dict with model configuration including api_key, base_url, etc.
"""
from fastapi import status
from fastapi.exceptions import HTTPException
from app.core.config import settings
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
from fastapi import status
from fastapi.exceptions import HTTPException
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
if not config:
@@ -353,11 +358,12 @@ class MemoryConfigService:
Returns:
Dict with embedder configuration including api_key, base_url, etc.
"""
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
from fastapi import status
from fastapi.exceptions import HTTPException
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
if not config:
logger.warning(f"Embedding model ID {embedding_id} not found")
@@ -438,3 +444,207 @@ class MemoryConfigService:
"pruning_scene": memory_config.pruning_scene,
"pruning_threshold": memory_config.pruning_threshold,
}
def _get_workspace_default_config(
self,
workspace_id: UUID
) -> Optional["MemoryConfigModel"]:
"""Get workspace default memory config.
Returns the config marked as default for the workspace. If no explicit
default exists, falls back to the first active config ordered by creation time.
Args:
workspace_id: Workspace ID
Returns:
Optional[MemoryConfigModel]: Default config or None if no configs exist
"""
from sqlalchemy import select
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
# First, try to find the explicitly marked default config
stmt = (
select(MemoryConfigModel)
.where(
MemoryConfigModel.workspace_id == workspace_id,
MemoryConfigModel.is_default.is_(True),
MemoryConfigModel.state.is_(True),
)
.limit(1)
)
config = self.db.scalars(stmt).first()
if config:
return config
# Fallback: get the oldest active config if no explicit default
stmt = (
select(MemoryConfigModel)
.where(
MemoryConfigModel.workspace_id == workspace_id,
MemoryConfigModel.state.is_(True),
)
.order_by(MemoryConfigModel.created_at.asc())
.limit(1)
)
config = self.db.scalars(stmt).first()
if not config:
logger.warning(
"No active memory config found for workspace fallback",
extra={"workspace_id": str(workspace_id)}
)
return config
def get_config_with_fallback(
self,
end_user_id: UUID,
workspace_id: UUID
) -> Optional["MemoryConfigModel"]:
"""Get memory config for end user with fallback to workspace default.
Implements graceful degradation: if the end user's assigned config
doesn't exist, falls back to the workspace's default active config.
Args:
end_user_id: End user ID
workspace_id: Workspace ID for fallback lookup
Returns:
Optional[MemoryConfigModel]: Memory config or None if no fallback available
"""
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(self.db)
end_user = end_user_repo.get_by_id(end_user_id)
if not end_user or not end_user.memory_config_id:
logger.debug(
"End user has no memory config assigned",
extra={"end_user_id": str(end_user_id)}
)
return self._get_workspace_default_config(workspace_id)
config = self.db.get(MemoryConfigModel, end_user.memory_config_id)
if config:
return config
logger.warning(
"Memory config not found, falling back to workspace default",
extra={
"end_user_id": str(end_user_id),
"missing_config_id": str(end_user.memory_config_id),
"workspace_id": str(workspace_id)
}
)
fallback_config = self._get_workspace_default_config(workspace_id)
if fallback_config:
logger.info(
"Using fallback memory config",
extra={
"end_user_id": str(end_user_id),
"fallback_config_id": str(fallback_config.config_id)
}
)
return fallback_config
def delete_config(
self,
config_id: UUID,
force: bool = False
) -> dict:
"""Delete memory config with protection against in-use configs.
Implements delete protection: prevents accidental deletion of configs
that are actively being used by end users or marked as default.
Args:
config_id: Memory config ID to delete
force: If True, delete even if end users are connected
Returns:
Dict with status, message, and affected_users count
Raises:
ResourceNotFoundException: If config doesn't exist
"""
from app.core.exceptions import ResourceNotFoundException
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
config = self.db.get(MemoryConfigModel, config_id)
if not config:
raise ResourceNotFoundException("MemoryConfig", str(config_id))
# Check if this is the default config - default configs cannot be deleted
if config.is_default:
logger.warning(
"Attempted to delete default memory config",
extra={"config_id": str(config_id)}
)
return {
"status": "error",
"message": "默认配置不允许删除",
"is_default": True
}
# TODO: add back delete warning
# # Count connected end users
# end_user_repo = EndUserRepository(self.db)
# connected_count = end_user_repo.count_by_memory_config_id(config_id)
# if connected_count > 0 and not force:
# logger.warning(
# "Attempted to delete memory config with connected end users",
# extra={
# "config_id": str(config_id),
# "connected_count": connected_count
# }
# )
# return {
# "status": "warning",
# "message": f"Cannot delete memory config: {connected_count} end users are using it",
# "connected_count": connected_count,
# "force_required": True
# }
# # Force delete: clear end user references first
# if connected_count > 0 and force:
# cleared_count = end_user_repo.clear_memory_config_id(config_id)
# logger.warning(
# "Force deleting memory config",
# extra={
# "config_id": str(config_id),
# "cleared_end_users": cleared_count
# }
# )
connected_count = 0
self.db.delete(config)
self.db.commit()
logger.info(
"Memory config deleted",
extra={
"config_id": str(config_id),
"force": force,
"affected_users": connected_count
}
)
return {
"status": "success",
"message": "Memory config deleted successfully",
"affected_users": connected_count
}

View File

@@ -5,32 +5,36 @@ import uuid
from os import getenv
from typing import List, Optional
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, PermissionDeniedException
from app.core.logging_config import get_business_logger
from app.models.user_model import User
from app.schemas.workspace_schema import (
WorkspaceModelsUpdate,
from app.models.workspace_model import (
InviteStatus,
Workspace,
WorkspaceMember,
WorkspaceRole,
)
from sqlalchemy.orm import Session
from app.models.workspace_model import Workspace, WorkspaceRole, InviteStatus, WorkspaceMember
from app.repositories import workspace_repository
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
from app.schemas.workspace_schema import (
InviteAcceptRequest,
InviteValidateResponse,
WorkspaceCreate,
WorkspaceUpdate,
WorkspaceInviteCreate,
WorkspaceInviteResponse,
InviteValidateResponse,
InviteAcceptRequest,
WorkspaceMemberUpdate
WorkspaceMemberUpdate,
WorkspaceModelsUpdate,
WorkspaceUpdate,
)
# 获取业务逻辑专用日志器
business_logger = get_business_logger()
from dotenv import load_dotenv
load_dotenv()
def switch_workspace(
db: Session,
@@ -127,6 +131,27 @@ def create_workspace(
db.commit()
db.refresh(db_workspace)
# Create default memory config for the workspace (only for neo4j storage types)
if workspace.storage_type == 'neo4j':
try:
_create_default_memory_config(
db=db,
workspace_id=db_workspace.id,
workspace_name=db_workspace.name,
llm_id=llm,
embedding_id=embedding,
rerank_id=rerank,
)
business_logger.info(
f"为工作空间 {db_workspace.id} 创建默认记忆配置成功"
)
except Exception as mc_error:
business_logger.error(
f"为工作空间 {db_workspace.id} 创建默认记忆配置失败: {str(mc_error)}"
)
# Don't fail workspace creation if memory config creation fails
# The workspace can still function without a default memory config
# 如果 storage_type 是 "rag",自动创建知识库
if workspace.storage_type == "rag":
business_logger.info(
@@ -286,7 +311,7 @@ def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, use
)
# 使用统一权限服务检查访问权限
from app.core.permissions import permission_service, Subject, Resource, Action
from app.core.permissions import Action, Resource, Subject, permission_service
# 获取用户的工作空间成员关系
member = workspace_repository.get_member_in_workspace(
@@ -324,7 +349,7 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
)
# 使用统一权限服务检查管理权限
from app.core.permissions import permission_service, Subject, Resource, Action
from app.core.permissions import Action, Resource, Subject, permission_service
# 获取用户的工作空间成员关系
member = workspace_repository.get_member_in_workspace(
@@ -852,3 +877,50 @@ def update_workspace_models_configs(
business_logger.error(f"工作空间模型配置更新失败: workspace_id={workspace_id} - {str(e)}")
db.rollback()
raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR)
def _create_default_memory_config(
db: Session,
workspace_id: uuid.UUID,
workspace_name: str,
llm_id: Optional[uuid.UUID] = None,
embedding_id: Optional[uuid.UUID] = None,
rerank_id: Optional[uuid.UUID] = None,
) -> None:
"""Create a default memory config for a newly created workspace.
Args:
db: Database session
workspace_id: The workspace ID
workspace_name: The workspace name (used for config naming)
llm_id: Optional LLM model ID
embedding_id: Optional embedding model ID
rerank_id: Optional rerank model ID
"""
from app.models.memory_config_model import MemoryConfig
config_id = uuid.uuid4()
default_config = MemoryConfig(
config_id=config_id,
config_name=f"{workspace_name} 默认配置",
config_desc="工作空间创建时自动生成的默认记忆配置",
workspace_id=workspace_id,
llm_id=str(llm_id) if llm_id else None,
embedding_id=str(embedding_id) if embedding_id else None,
rerank_id=str(rerank_id) if rerank_id else None,
state=True, # Active by default
is_default=True, # Mark as workspace default
)
db.add(default_config)
db.commit()
business_logger.info(
"Created default memory config for workspace",
extra={
"workspace_id": str(workspace_id),
"config_id": str(config_id),
"config_name": default_config.config_name,
}
)