feat(memory): add protected memory config deletion with end-user safeguards
- Add force parameter to delete_config endpoint for controlled deletion of in-use configs - Implement MemoryConfigService.delete_config with protection against deleting default configs - Add validation to prevent deletion of configs with connected end-users unless force=True - Reorganize controller imports to remove duplicates and improve maintainability - Clean up unused database connection management code from memory_storage_controller - Add detailed docstring to delete_config endpoint explaining protection mechanisms - Update error handling with specific BizCode.RESOURCE_IN_USE for configs in active use - Add comprehensive logging for deletion attempts, warnings, and affected users - Refactor ConfigParamsDelete schema usage to use MemoryConfigService directly - Improve API response structure with affected_users count and force_required flag
This commit is contained in:
@@ -24,9 +24,11 @@ from . import (
|
||||
memory_episodic_controller,
|
||||
memory_explicit_controller,
|
||||
memory_forget_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_reflection_controller,
|
||||
memory_short_term_controller,
|
||||
memory_storage_controller,
|
||||
memory_working_controller,
|
||||
model_controller,
|
||||
multi_agent_controller,
|
||||
prompt_optimizer_controller,
|
||||
@@ -41,10 +43,6 @@ from . import (
|
||||
user_memory_controllers,
|
||||
workflow_controller,
|
||||
workspace_controller,
|
||||
memory_forget_controller,
|
||||
home_page_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_working_controller,
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
@@ -11,7 +14,6 @@ from app.models.user_model import User
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
@@ -31,9 +33,6 @@ from app.services.memory_storage_service import (
|
||||
search_entity,
|
||||
search_statement,
|
||||
)
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -70,68 +69,9 @@ async def get_storage_info(
|
||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||
|
||||
|
||||
# --- DB connection dependency ---
|
||||
_CONN: Optional[object] = None
|
||||
|
||||
|
||||
"""PostgreSQL 连接生成与管理(使用 psycopg2)。"""
|
||||
# 这个可以转移,可能是已经有的
|
||||
# PostgreSQL 数据库连接
|
||||
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
|
||||
host = os.getenv("DB_HOST")
|
||||
user = os.getenv("DB_USER")
|
||||
password = os.getenv("DB_PASSWORD")
|
||||
database = os.getenv("DB_NAME")
|
||||
port_str = os.getenv("DB_PORT")
|
||||
try:
|
||||
import psycopg2 # type: ignore
|
||||
port = int(port_str) if port_str else 5432
|
||||
conn = psycopg2.connect(
|
||||
host=host or "localhost",
|
||||
port=port,
|
||||
user=user,
|
||||
password=password,
|
||||
dbname=database,
|
||||
)
|
||||
# 设置自动提交,避免显式事务管理
|
||||
conn.autocommit = True
|
||||
# 设置会话时区为中国标准时间(Asia/Shanghai),便于直接以本地时区展示
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
|
||||
cur.close()
|
||||
except Exception:
|
||||
# 时区设置失败不影响连接,仅记录但不抛出
|
||||
pass
|
||||
return conn
|
||||
except Exception as e:
|
||||
try:
|
||||
print(f"[PostgreSQL] 连接失败: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
|
||||
global _CONN
|
||||
if _CONN is None:
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN
|
||||
|
||||
|
||||
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
|
||||
"""Close and recreate the global DB connection."""
|
||||
global _CONN
|
||||
try:
|
||||
if _CONN:
|
||||
try:
|
||||
_CONN.close()
|
||||
except Exception:
|
||||
pass
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN is not None
|
||||
except Exception:
|
||||
_CONN = None
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@@ -139,7 +79,7 @@ def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -162,9 +102,20 @@ def create_config(
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: UUID,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
"""删除记忆配置(带终端用户保护)
|
||||
|
||||
- 检查是否为默认配置,默认配置不允许删除
|
||||
- 检查是否有终端用户连接到该配置
|
||||
- 如果有连接且 force=False,返回警告
|
||||
- 如果 force=True,清除终端用户引用后删除配置
|
||||
|
||||
Query Parameters:
|
||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -172,21 +123,62 @@ def delete_config(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
|
||||
f"config_id={config_id}, force={force}"
|
||||
)
|
||||
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.delete(ConfigParamsDelete(config_id=config_id))
|
||||
return success(data=result, msg="删除成功")
|
||||
# 使用带保护的删除服务
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
result = config_service.delete_config(config_id=config_id, force=force)
|
||||
|
||||
if result["status"] == "error":
|
||||
api_logger.warning(
|
||||
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.FORBIDDEN,
|
||||
msg=result["message"],
|
||||
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
|
||||
)
|
||||
|
||||
if result["status"] == "warning":
|
||||
api_logger.warning(
|
||||
f"记忆配置正在使用,无法删除: config_id={config_id}, "
|
||||
f"connected_count={result['connected_count']}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.RESOURCE_IN_USE,
|
||||
msg=result["message"],
|
||||
data={
|
||||
"connected_count": result["connected_count"],
|
||||
"force_required": result["force_required"]
|
||||
}
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"记忆配置删除成功: config_id={config_id}, "
|
||||
f"affected_users={result['affected_users']}"
|
||||
)
|
||||
return success(
|
||||
msg=result["message"],
|
||||
data={"affected_users": result["affected_users"]}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Delete config failed: {str(e)}")
|
||||
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -209,7 +201,7 @@ def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -236,7 +228,7 @@ def read_config_extracted(
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -257,7 +249,7 @@ def read_config_extracted(
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -297,9 +289,8 @@ async def pilot_run(
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||
"""
|
||||
|
||||
# ==================== Search & Analytics ====================
|
||||
|
||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||
async def get_kb_type_distribution(
|
||||
@@ -439,8 +430,9 @@ async def get_hot_memory_tags_api(
|
||||
|
||||
try:
|
||||
# 尝试从Redis缓存获取
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
import json
|
||||
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
cached_result = await aio_redis_get(cache_key)
|
||||
if cached_result:
|
||||
|
||||
@@ -46,6 +46,7 @@ class BizCode(IntEnum):
|
||||
RESOURCE_ALREADY_EXISTS = 5002
|
||||
VERSION_ALREADY_EXISTS = 5003
|
||||
STATE_CONFLICT = 5004
|
||||
RESOURCE_IN_USE = 5005
|
||||
|
||||
# 应用发布(6xxx)
|
||||
PUBLISH_FAILED = 6001
|
||||
@@ -125,6 +126,7 @@ HTTP_MAPPING = {
|
||||
BizCode.RESOURCE_ALREADY_EXISTS: 409,
|
||||
BizCode.VERSION_ALREADY_EXISTS: 409,
|
||||
BizCode.STATE_CONFLICT: 409,
|
||||
BizCode.RESOURCE_IN_USE: 409,
|
||||
BizCode.PUBLISH_FAILED: 500,
|
||||
BizCode.NO_DRAFT_TO_PUBLISH: 400,
|
||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
|
||||
|
||||
@@ -108,7 +108,6 @@ class DimensionAnalyzer:
|
||||
|
||||
# Create dimension portrait
|
||||
portrait = DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=dimension_scores["creativity"],
|
||||
aesthetic=dimension_scores["aesthetic"],
|
||||
technology=dimension_scores["technology"],
|
||||
@@ -220,7 +219,7 @@ class DimensionAnalyzer:
|
||||
"""Create an empty dimension portrait when no data is available.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_id: Target user ID (used for logging only)
|
||||
|
||||
Returns:
|
||||
Empty DimensionPortrait
|
||||
@@ -228,7 +227,6 @@ class DimensionAnalyzer:
|
||||
current_time = datetime.now()
|
||||
|
||||
return DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=self._create_default_dimension_score("creativity"),
|
||||
aesthetic=self._create_default_dimension_score("aesthetic"),
|
||||
technology=self._create_default_dimension_score("technology"),
|
||||
|
||||
@@ -7,7 +7,7 @@ providing percentage distribution that totals 100%.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
@@ -133,7 +133,6 @@ class InterestAnalyzer:
|
||||
|
||||
# Create interest area distribution
|
||||
distribution = InterestAreaDistribution(
|
||||
user_id=user_id,
|
||||
tech=interest_categories["tech"],
|
||||
lifestyle=interest_categories["lifestyle"],
|
||||
music=interest_categories["music"],
|
||||
@@ -251,7 +250,7 @@ class InterestAnalyzer:
|
||||
"""Create an empty interest distribution when no data is available.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_id: Target user ID (used for logging only)
|
||||
|
||||
Returns:
|
||||
Empty InterestAreaDistribution with equal percentages
|
||||
@@ -259,15 +258,15 @@ class InterestAnalyzer:
|
||||
current_time = datetime.now()
|
||||
equal_percentage = 25.0 # 100% / 4 categories
|
||||
|
||||
default_category = lambda name: InterestCategory(
|
||||
category_name=name,
|
||||
percentage=equal_percentage,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
trending_direction=None
|
||||
)
|
||||
def default_category(name: str) -> InterestCategory:
|
||||
return InterestCategory(
|
||||
category_name=name,
|
||||
percentage=equal_percentage,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
trending_direction=None
|
||||
)
|
||||
|
||||
return InterestAreaDistribution(
|
||||
user_id=user_id,
|
||||
tech=default_category("tech"),
|
||||
lifestyle=default_category("lifestyle"),
|
||||
music=default_category("music"),
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from app.db import Base
|
||||
from sqlalchemy import BigInteger, Column, DateTime, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class EndUser(Base):
|
||||
__tablename__ = "end_users"
|
||||
@@ -21,7 +22,13 @@ class EndUser(Base):
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||
|
||||
# Memory config association - updated lazily during conversation
|
||||
memory_config_id = Column(Integer, ForeignKey("data_config.config_id"), nullable=True, index=True, comment="关联的记忆配置ID")
|
||||
memory_config_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey("memory_config.config_id"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="关联的记忆配置ID"
|
||||
)
|
||||
|
||||
# 用户基本信息字段
|
||||
position = Column(String, nullable=True, comment="职位")
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import datetime
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float
|
||||
|
||||
from sqlalchemy import Boolean, Column, DateTime, Float, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
@@ -38,6 +40,7 @@ class MemoryConfig(Base):
|
||||
|
||||
# 状态配置
|
||||
state = Column(Boolean, default=False, comment="配置使用状态")
|
||||
is_default = Column(Boolean, default=False, comment="是否为工作空间默认配置")
|
||||
|
||||
# 分块策略
|
||||
chunker_strategy = Column(String, default="RecursiveChunker", comment="分块策略")
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.app_model import App
|
||||
from app.models.workspace_model import Workspace
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
@@ -264,6 +264,179 @@ class EndUserRepository:
|
||||
db_logger.error(f"查询活动工作空间时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_memory_config_id(self, end_user_id: uuid.UUID, memory_config_id: uuid.UUID) -> bool:
|
||||
"""更新终端用户的 memory_config_id(懒更新)。
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config_id: 记忆配置ID
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
updated_count = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.update(
|
||||
{EndUser.memory_config_id: memory_config_id},
|
||||
synchronize_session=False
|
||||
)
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
if updated_count > 0:
|
||||
db_logger.debug(f"成功更新终端用户 {end_user_id} 的 memory_config_id: {memory_config_id}")
|
||||
return True
|
||||
else:
|
||||
db_logger.warning(f"未找到终端用户 {end_user_id},无法更新 memory_config_id")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_memory_config_id(self, end_user_id: uuid.UUID) -> Optional[uuid.UUID]:
|
||||
"""获取终端用户的 memory_config_id。
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
Optional[uuid.UUID]: memory_config_id 或 None
|
||||
"""
|
||||
try:
|
||||
end_user = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.first()
|
||||
)
|
||||
if end_user and end_user.memory_config_id:
|
||||
db_logger.debug(f"获取终端用户 {end_user_id} 的 memory_config_id: {end_user.memory_config_id}")
|
||||
return end_user.memory_config_id
|
||||
return None
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"获取终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def batch_update_memory_config_id(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""批量更新应用下所有终端用户的 memory_config_id
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
memory_config_id: 新的记忆配置ID
|
||||
|
||||
Returns:
|
||||
int: 更新的行数
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import update
|
||||
|
||||
stmt = (
|
||||
update(EndUser)
|
||||
.where(EndUser.app_id == app_id)
|
||||
.values(memory_config_id=memory_config_id)
|
||||
)
|
||||
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
|
||||
db_logger.info(
|
||||
f"批量更新终端用户记忆配置: app_id={app_id}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
)
|
||||
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(
|
||||
f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
|
||||
f"memory_config_id={memory_config_id}, error={str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
def count_by_memory_config_id(
|
||||
self,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""统计使用指定记忆配置的终端用户数量
|
||||
|
||||
Args:
|
||||
memory_config_id: 记忆配置ID
|
||||
|
||||
Returns:
|
||||
int: 使用该配置的终端用户数量
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import func, select
|
||||
|
||||
stmt = (
|
||||
select(func.count(EndUser.id))
|
||||
.where(EndUser.memory_config_id == memory_config_id)
|
||||
)
|
||||
|
||||
count = self.db.execute(stmt).scalar() or 0
|
||||
|
||||
db_logger.debug(f"统计记忆配置使用数: memory_config_id={memory_config_id}, count={count}")
|
||||
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"统计记忆配置使用数时出错: memory_config_id={memory_config_id}, error={str(e)}")
|
||||
raise
|
||||
|
||||
def clear_memory_config_id(
|
||||
self,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""清除所有使用指定记忆配置的终端用户的 memory_config_id
|
||||
|
||||
将 memory_config_id 设置为 NULL
|
||||
|
||||
Args:
|
||||
memory_config_id: 要清除的记忆配置ID
|
||||
|
||||
Returns:
|
||||
int: 清除的行数
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import update
|
||||
|
||||
stmt = (
|
||||
update(EndUser)
|
||||
.where(EndUser.memory_config_id == memory_config_id)
|
||||
.values(memory_config_id=None)
|
||||
)
|
||||
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
|
||||
cleared_count = result.rowcount
|
||||
|
||||
db_logger.warning(
|
||||
f"清除终端用户记忆配置引用: memory_config_id={memory_config_id}, "
|
||||
f"cleared_count={cleared_count}"
|
||||
)
|
||||
|
||||
return cleared_count
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(
|
||||
f"清除终端用户记忆配置引用时出错: memory_config_id={memory_config_id}, "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
|
||||
"""根据应用ID查询宿主(返回 EndUser ORM 列表)"""
|
||||
repo = EndUserRepository(db)
|
||||
@@ -315,3 +488,32 @@ def get_all_active_workspaces(db: Session) -> List[uuid.UUID]:
|
||||
"""获取所有活动工作空间的ID"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.get_all_active_workspaces()
|
||||
|
||||
|
||||
def update_memory_config_id(db: Session, end_user_id: uuid.UUID, memory_config_id: uuid.UUID) -> bool:
|
||||
"""更新终端用户的 memory_config_id(懒更新)。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 终端用户ID
|
||||
memory_config_id: 记忆配置ID
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.update_memory_config_id(end_user_id, memory_config_id)
|
||||
|
||||
|
||||
def get_memory_config_id(db: Session, end_user_id: uuid.UUID) -> Optional[uuid.UUID]:
|
||||
"""获取终端用户的 memory_config_id。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
Optional[uuid.UUID]: memory_config_id 或 None
|
||||
"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.get_memory_config_id(end_user_id)
|
||||
|
||||
@@ -112,7 +112,6 @@ class DimensionPortraitResponse(BaseModel):
|
||||
"""Four-dimension personality portrait."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
user_id: str
|
||||
creativity: DimensionScoreResponse
|
||||
aesthetic: DimensionScoreResponse
|
||||
technology: DimensionScoreResponse
|
||||
@@ -140,7 +139,6 @@ class InterestAreaDistributionResponse(BaseModel):
|
||||
"""Distribution of user interests across four areas."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
user_id: str
|
||||
tech: InterestCategoryResponse
|
||||
lifestyle: InterestCategoryResponse
|
||||
music: InterestCategoryResponse
|
||||
@@ -184,7 +182,6 @@ class UserProfileResponse(BaseModel):
|
||||
"""Comprehensive user profile."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
user_id: str
|
||||
preference_tags: List[PreferenceTagResponse]
|
||||
dimension_portrait: DimensionPortraitResponse
|
||||
interest_area_distribution: InterestAreaDistributionResponse
|
||||
@@ -226,7 +223,6 @@ class UserMemorySummary(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
summary_id: str
|
||||
user_id: str
|
||||
user_content: str
|
||||
timestamp: datetime.datetime
|
||||
confidence_score: float = Field(ge=0.0, le=1.0)
|
||||
@@ -241,7 +237,6 @@ class SummaryAnalysisResult(BaseModel):
|
||||
"""Result of analyzing memory summaries."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
user_id: str
|
||||
preferences: List[PreferenceTagResponse]
|
||||
dimension_evidence: Dict[str, List[str]]
|
||||
interest_evidence: Dict[str, List[str]]
|
||||
@@ -273,7 +268,6 @@ class GenerateProfileRequest(BaseModel):
|
||||
|
||||
class CompleteProfileResponse(BaseModel):
|
||||
"""完整用户画像响应(包含所有模块)"""
|
||||
user_id: str
|
||||
preferences: List[PreferenceTagResponse]
|
||||
portrait: DimensionPortraitResponse
|
||||
interest_areas: InterestAreaDistributionResponse
|
||||
|
||||
@@ -9,28 +9,35 @@
|
||||
"""
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Optional, List, Dict, Any, Tuple, Annotated
|
||||
from typing import Annotated, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy import select, func, or_, and_
|
||||
from sqlalchemy import and_, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import (
|
||||
ResourceNotFoundException,
|
||||
BusinessException,
|
||||
ResourceNotFoundException,
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.workflow.validator import WorkflowValidator
|
||||
from app.db import get_db
|
||||
from app.models import App, AgentConfig, AppRelease, MultiAgentConfig, WorkflowConfig
|
||||
from app.models import (
|
||||
AgentConfig,
|
||||
App,
|
||||
AppRelease,
|
||||
AppShare,
|
||||
MultiAgentConfig,
|
||||
WorkflowConfig,
|
||||
Workspace,
|
||||
)
|
||||
from app.models.app_model import AppStatus, AppType
|
||||
from app.repositories.app_repository import get_apps_by_id
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
from app.services.agent_config_converter import AgentConfigConverter
|
||||
from app.models import AppShare, Workspace
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.utils.app_config_utils import model_parameters_to_dict
|
||||
@@ -136,9 +143,10 @@ class AppService:
|
||||
return app
|
||||
|
||||
def _check_workflow_config(self, app_id: uuid.UUID):
|
||||
from app.models import WorkflowConfig, ModelConfig
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models import ModelConfig, WorkflowConfig
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(WorkflowConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = self.db.scalars(stmt).first()
|
||||
@@ -154,9 +162,10 @@ class AppService:
|
||||
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
def _check_agent_config(self, app_id: uuid.UUID):
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = self.db.scalars(stmt).first()
|
||||
@@ -326,10 +335,10 @@ class AppService:
|
||||
"""
|
||||
# 将 Dict 转换为 MultiAgentConfigCreate
|
||||
from app.schemas.multi_agent_schema import (
|
||||
ExecutionConfig,
|
||||
MultiAgentConfigCreate,
|
||||
SubAgentConfig,
|
||||
RoutingRule,
|
||||
ExecutionConfig
|
||||
SubAgentConfig,
|
||||
)
|
||||
|
||||
# 转换 sub_agents
|
||||
@@ -1167,6 +1176,138 @@ class AppService:
|
||||
|
||||
return default_config
|
||||
|
||||
# ==================== 记忆配置提取方法 ====================
|
||||
|
||||
def _extract_memory_config_id(
|
||||
self,
|
||||
app_type: str,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[uuid.UUID]:
|
||||
"""从发布配置中提取 memory_config_id(根据应用类型分发)
|
||||
|
||||
Args:
|
||||
app_type: 应用类型 (agent, workflow, multi_agent)
|
||||
config: 发布配置字典
|
||||
|
||||
Returns:
|
||||
Optional[uuid.UUID]: 提取的 memory_config_id,如果不存在则返回 None
|
||||
"""
|
||||
if app_type == AppType.AGENT:
|
||||
return self._extract_memory_config_id_from_agent(config)
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
return self._extract_memory_config_id_from_workflow(config)
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# Multi-agent 暂不支持记忆配置提取
|
||||
logger.debug(f"多智能体应用暂不支持记忆配置提取: app_type={app_type}")
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"不支持的应用类型,无法提取记忆配置: app_type={app_type}")
|
||||
return None
|
||||
|
||||
def _extract_memory_config_id_from_agent(
|
||||
self,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[uuid.UUID]:
|
||||
"""从 Agent 应用配置中提取 memory_config_id
|
||||
|
||||
路径: config.memory.memory_content
|
||||
|
||||
Args:
|
||||
config: Agent 配置字典
|
||||
|
||||
Returns:
|
||||
Optional[uuid.UUID]: 记忆配置ID,如果不存在则返回 None
|
||||
"""
|
||||
try:
|
||||
memory_content = config.get("memory", {}).get("memory_content")
|
||||
if memory_content:
|
||||
# 处理字符串和 UUID 两种情况
|
||||
if isinstance(memory_content, str):
|
||||
return uuid.UUID(memory_content)
|
||||
elif isinstance(memory_content, uuid.UUID):
|
||||
return memory_content
|
||||
else:
|
||||
logger.warning(
|
||||
f"Agent 配置中 memory_content 格式无效: type={type(memory_content)}, "
|
||||
f"value={memory_content}"
|
||||
)
|
||||
return None
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"Agent 配置中 memory_content 格式无效: error={str(e)}, "
|
||||
f"memory_content={memory_content}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _extract_memory_config_id_from_workflow(
|
||||
self,
|
||||
config: Dict[str, Any]
|
||||
) -> Optional[uuid.UUID]:
|
||||
"""从 Workflow 应用配置中提取 memory_config_id
|
||||
|
||||
扫描工作流节点,查找 MemoryRead 或 MemoryWrite 节点。
|
||||
返回第一个找到的记忆节点的 config_id。
|
||||
|
||||
Args:
|
||||
config: Workflow 配置字典
|
||||
|
||||
Returns:
|
||||
Optional[uuid.UUID]: 记忆配置ID,如果不存在则返回 None
|
||||
"""
|
||||
nodes = config.get("nodes", [])
|
||||
|
||||
for node in nodes:
|
||||
node_type = node.get("type", "")
|
||||
|
||||
# 检查是否为记忆节点
|
||||
if node_type in ["MemoryRead", "MemoryWrite"]:
|
||||
config_id = node.get("config", {}).get("config_id")
|
||||
|
||||
if config_id:
|
||||
try:
|
||||
# 处理字符串和 UUID 两种情况
|
||||
if isinstance(config_id, str):
|
||||
return uuid.UUID(config_id)
|
||||
elif isinstance(config_id, uuid.UUID):
|
||||
return config_id
|
||||
else:
|
||||
logger.warning(
|
||||
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
|
||||
f"node_type={node_type}, type={type(config_id)}"
|
||||
)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(
|
||||
f"工作流记忆节点 config_id 格式无效: node_id={node.get('id')}, "
|
||||
f"node_type={node_type}, error={str(e)}"
|
||||
)
|
||||
|
||||
logger.debug("工作流配置中未找到记忆节点")
|
||||
return None
|
||||
|
||||
def _update_endusers_memory_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""批量更新应用下所有终端用户的 memory_config_id
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
memory_config_id: 新的记忆配置ID
|
||||
|
||||
Returns:
|
||||
int: 更新的终端用户数量
|
||||
"""
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
repo = EndUserRepository(self.db)
|
||||
updated_count = repo.batch_update_memory_config_id(
|
||||
app_id=app_id,
|
||||
memory_config_id=memory_config_id
|
||||
)
|
||||
|
||||
return updated_count
|
||||
|
||||
# ==================== 应用发布管理 ====================
|
||||
|
||||
def publish(
|
||||
@@ -1309,6 +1450,15 @@ class AppService:
|
||||
self.db.add(release)
|
||||
self.db.flush() # 先 flush,确保 release 已插入数据库
|
||||
|
||||
# 提取记忆配置ID并更新终端用户
|
||||
memory_config_id = self._extract_memory_config_id(app.type, config)
|
||||
if memory_config_id:
|
||||
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
|
||||
logger.info(
|
||||
f"发布时更新终端用户记忆配置: app_id={app_id}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
)
|
||||
|
||||
# 更新当前发布版本指针
|
||||
app.current_release_id = release.id
|
||||
app.status = AppStatus.ACTIVE
|
||||
@@ -1424,6 +1574,15 @@ class AppService:
|
||||
)
|
||||
raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}")
|
||||
|
||||
# 提取记忆配置ID并更新终端用户
|
||||
memory_config_id = self._extract_memory_config_id(release.type, release.config)
|
||||
if memory_config_id:
|
||||
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
|
||||
logger.info(
|
||||
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
)
|
||||
|
||||
app.current_release_id = release.id
|
||||
app.updated_at = datetime.datetime.now()
|
||||
|
||||
@@ -1839,8 +1998,8 @@ class AppService:
|
||||
Returns:
|
||||
Dict: 对比结果
|
||||
"""
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.models import ModelConfig
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger.info(
|
||||
"多模型对比试运行",
|
||||
@@ -1938,8 +2097,8 @@ class AppService:
|
||||
Yields:
|
||||
str: SSE 格式的事件数据
|
||||
"""
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.models import ModelConfig
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger.info(
|
||||
"多模型对比流式试运行",
|
||||
|
||||
@@ -9,10 +9,15 @@ import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
@@ -37,11 +42,6 @@ from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
@@ -732,8 +732,12 @@ class MemoryAgentService:
|
||||
)
|
||||
|
||||
# 导入必要的模块
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm
|
||||
from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||
summary_llm,
|
||||
)
|
||||
from app.core.memory.agent.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
)
|
||||
|
||||
# 构建状态对象
|
||||
state = {
|
||||
@@ -1144,9 +1148,8 @@ class MemoryAgentService:
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
|
||||
def get_end_user_memory_config_id(end_user_id: str, db: Session) -> Optional[int]:
|
||||
"""
|
||||
快速获取终端用户的 memory_config_id(直接从 end_user 表读取)
|
||||
def get_end_user_memory_config_id(end_user_id: str, db: Session) -> Optional[uuid.UUID]:
|
||||
"""快速获取终端用户的 memory_config_id(直接从 end_user 表读取)。
|
||||
|
||||
如果 end_user 已有缓存的 memory_config_id,直接返回;
|
||||
否则返回 None,调用方应使用 get_end_user_connected_config 获取完整配置。
|
||||
@@ -1156,14 +1159,16 @@ def get_end_user_memory_config_id(end_user_id: str, db: Session) -> Optional[int
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
memory_config_id 或 None
|
||||
Optional[uuid.UUID]: memory_config_id 或 None
|
||||
"""
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.repositories.end_user_repository import get_memory_config_id
|
||||
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if end_user and end_user.memory_config_id:
|
||||
return end_user.memory_config_id
|
||||
return None
|
||||
try:
|
||||
end_user_uuid = uuid.UUID(end_user_id) if isinstance(end_user_id, str) else end_user_id
|
||||
return get_memory_config_id(db, end_user_uuid)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.warning(f"Invalid end_user_id format: {end_user_id}, error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
|
||||
@@ -1185,9 +1190,9 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
Raises:
|
||||
ValueError: 当终端用户不存在或应用未发布时
|
||||
"""
|
||||
from app.models.app_model import App
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
@@ -1200,22 +1205,25 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
app_id = end_user.app_id
|
||||
logger.debug(f"Found end_user app_id: {app_id}")
|
||||
|
||||
# 2. 获取该应用的最新发布版本
|
||||
stmt = (
|
||||
select(AppRelease)
|
||||
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
|
||||
.order_by(AppRelease.version.desc())
|
||||
)
|
||||
latest_release = db.scalars(stmt).first()
|
||||
# 2. 获取应用的当前发布版本(通过 apps.current_release_id)
|
||||
app = db.query(App).filter(App.id == app_id).first()
|
||||
if not app:
|
||||
logger.warning(f"App not found: {app_id}")
|
||||
raise ValueError(f"应用不存在: {app_id}")
|
||||
|
||||
if not latest_release:
|
||||
logger.warning(f"No active release found for app: {app_id}")
|
||||
if not app.current_release_id:
|
||||
logger.warning(f"No current release for app: {app_id}")
|
||||
raise ValueError(f"应用未发布: {app_id}")
|
||||
|
||||
logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")
|
||||
current_release = db.query(AppRelease).filter(AppRelease.id == app.current_release_id).first()
|
||||
if not current_release:
|
||||
logger.warning(f"Current release not found: {app.current_release_id}")
|
||||
raise ValueError(f"应用发布版本不存在: {app.current_release_id}")
|
||||
|
||||
logger.debug(f"Found current release: version={current_release.version}, id={current_release.id}")
|
||||
|
||||
# 3. 从 config 中提取 memory_config_id
|
||||
config = latest_release.config or {}
|
||||
config = current_release.config or {}
|
||||
|
||||
# 如果 config 是字符串,解析为字典
|
||||
if isinstance(config, str):
|
||||
@@ -1223,27 +1231,17 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
try:
|
||||
config = json.loads(config)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
|
||||
logger.warning(f"Failed to parse config JSON for release {current_release.id}")
|
||||
config = {}
|
||||
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
# 4. 更新 end_user 的 memory_config_id(懒更新)
|
||||
if memory_config_id is not None and end_user.memory_config_id != memory_config_id:
|
||||
try:
|
||||
end_user.memory_config_id = memory_config_id
|
||||
db.commit()
|
||||
logger.debug(f"Updated end_user memory_config_id: {end_user_id} -> {memory_config_id}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.warning(f"Failed to update end_user memory_config_id: {e}")
|
||||
|
||||
result = {
|
||||
"end_user_id": str(end_user_id),
|
||||
"app_id": str(app_id),
|
||||
"release_id": str(latest_release.id),
|
||||
"release_version": latest_release.version,
|
||||
"release_id": str(current_release.id),
|
||||
"release_version": current_release.version,
|
||||
"memory_config_id": memory_config_id
|
||||
}
|
||||
|
||||
@@ -1272,10 +1270,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
...
|
||||
}
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
|
||||
|
||||
|
||||
@@ -6,29 +6,33 @@ This service eliminates code duplication between MemoryAgentService and MemorySt
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.validators.memory_config_validators import (
|
||||
validate_and_resolve_model_id,
|
||||
validate_embedding_model,
|
||||
validate_model_exists_and_active,
|
||||
)
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.schemas.memory_config_schema import (
|
||||
ConfigurationError,
|
||||
InvalidConfigError,
|
||||
MemoryConfig,
|
||||
ModelInactiveError,
|
||||
ModelNotFoundError,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
import uuid
|
||||
|
||||
|
||||
def _validate_config_id(config_id, db: Session = None):
|
||||
"""Validate configuration ID format (supports both UUID and integer)."""
|
||||
@@ -320,11 +324,12 @@ class MemoryConfigService:
|
||||
Returns:
|
||||
Dict with model configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from app.core.config import settings
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
|
||||
if not config:
|
||||
@@ -353,11 +358,12 @@ class MemoryConfigService:
|
||||
Returns:
|
||||
Dict with embedder configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
|
||||
if not config:
|
||||
logger.warning(f"Embedding model ID {embedding_id} not found")
|
||||
@@ -438,3 +444,207 @@ class MemoryConfigService:
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
}
|
||||
|
||||
def _get_workspace_default_config(
|
||||
self,
|
||||
workspace_id: UUID
|
||||
) -> Optional["MemoryConfigModel"]:
|
||||
"""Get workspace default memory config.
|
||||
|
||||
Returns the config marked as default for the workspace. If no explicit
|
||||
default exists, falls back to the first active config ordered by creation time.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace ID
|
||||
|
||||
Returns:
|
||||
Optional[MemoryConfigModel]: Default config or None if no configs exist
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
|
||||
# First, try to find the explicitly marked default config
|
||||
stmt = (
|
||||
select(MemoryConfigModel)
|
||||
.where(
|
||||
MemoryConfigModel.workspace_id == workspace_id,
|
||||
MemoryConfigModel.is_default.is_(True),
|
||||
MemoryConfigModel.state.is_(True),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
config = self.db.scalars(stmt).first()
|
||||
|
||||
if config:
|
||||
return config
|
||||
|
||||
# Fallback: get the oldest active config if no explicit default
|
||||
stmt = (
|
||||
select(MemoryConfigModel)
|
||||
.where(
|
||||
MemoryConfigModel.workspace_id == workspace_id,
|
||||
MemoryConfigModel.state.is_(True),
|
||||
)
|
||||
.order_by(MemoryConfigModel.created_at.asc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
config = self.db.scalars(stmt).first()
|
||||
|
||||
if not config:
|
||||
logger.warning(
|
||||
"No active memory config found for workspace fallback",
|
||||
extra={"workspace_id": str(workspace_id)}
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def get_config_with_fallback(
|
||||
self,
|
||||
end_user_id: UUID,
|
||||
workspace_id: UUID
|
||||
) -> Optional["MemoryConfigModel"]:
|
||||
"""Get memory config for end user with fallback to workspace default.
|
||||
|
||||
Implements graceful degradation: if the end user's assigned config
|
||||
doesn't exist, falls back to the workspace's default active config.
|
||||
|
||||
Args:
|
||||
end_user_id: End user ID
|
||||
workspace_id: Workspace ID for fallback lookup
|
||||
|
||||
Returns:
|
||||
Optional[MemoryConfigModel]: Memory config or None if no fallback available
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
end_user_repo = EndUserRepository(self.db)
|
||||
end_user = end_user_repo.get_by_id(end_user_id)
|
||||
|
||||
if not end_user or not end_user.memory_config_id:
|
||||
logger.debug(
|
||||
"End user has no memory config assigned",
|
||||
extra={"end_user_id": str(end_user_id)}
|
||||
)
|
||||
return self._get_workspace_default_config(workspace_id)
|
||||
|
||||
config = self.db.get(MemoryConfigModel, end_user.memory_config_id)
|
||||
|
||||
if config:
|
||||
return config
|
||||
|
||||
logger.warning(
|
||||
"Memory config not found, falling back to workspace default",
|
||||
extra={
|
||||
"end_user_id": str(end_user_id),
|
||||
"missing_config_id": str(end_user.memory_config_id),
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
|
||||
fallback_config = self._get_workspace_default_config(workspace_id)
|
||||
|
||||
if fallback_config:
|
||||
logger.info(
|
||||
"Using fallback memory config",
|
||||
extra={
|
||||
"end_user_id": str(end_user_id),
|
||||
"fallback_config_id": str(fallback_config.config_id)
|
||||
}
|
||||
)
|
||||
|
||||
return fallback_config
|
||||
|
||||
def delete_config(
|
||||
self,
|
||||
config_id: UUID,
|
||||
force: bool = False
|
||||
) -> dict:
|
||||
"""Delete memory config with protection against in-use configs.
|
||||
|
||||
Implements delete protection: prevents accidental deletion of configs
|
||||
that are actively being used by end users or marked as default.
|
||||
|
||||
Args:
|
||||
config_id: Memory config ID to delete
|
||||
force: If True, delete even if end users are connected
|
||||
|
||||
Returns:
|
||||
Dict with status, message, and affected_users count
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: If config doesn't exist
|
||||
"""
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
|
||||
|
||||
config = self.db.get(MemoryConfigModel, config_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("MemoryConfig", str(config_id))
|
||||
|
||||
# Check if this is the default config - default configs cannot be deleted
|
||||
if config.is_default:
|
||||
logger.warning(
|
||||
"Attempted to delete default memory config",
|
||||
extra={"config_id": str(config_id)}
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": "默认配置不允许删除",
|
||||
"is_default": True
|
||||
}
|
||||
|
||||
# TODO: add back delete warning
|
||||
# # Count connected end users
|
||||
# end_user_repo = EndUserRepository(self.db)
|
||||
# connected_count = end_user_repo.count_by_memory_config_id(config_id)
|
||||
|
||||
# if connected_count > 0 and not force:
|
||||
# logger.warning(
|
||||
# "Attempted to delete memory config with connected end users",
|
||||
# extra={
|
||||
# "config_id": str(config_id),
|
||||
# "connected_count": connected_count
|
||||
# }
|
||||
# )
|
||||
|
||||
# return {
|
||||
# "status": "warning",
|
||||
# "message": f"Cannot delete memory config: {connected_count} end users are using it",
|
||||
# "connected_count": connected_count,
|
||||
# "force_required": True
|
||||
# }
|
||||
|
||||
# # Force delete: clear end user references first
|
||||
# if connected_count > 0 and force:
|
||||
# cleared_count = end_user_repo.clear_memory_config_id(config_id)
|
||||
|
||||
# logger.warning(
|
||||
# "Force deleting memory config",
|
||||
# extra={
|
||||
# "config_id": str(config_id),
|
||||
# "cleared_end_users": cleared_count
|
||||
# }
|
||||
# )
|
||||
connected_count = 0
|
||||
|
||||
self.db.delete(config)
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
"Memory config deleted",
|
||||
extra={
|
||||
"config_id": str(config_id),
|
||||
"force": force,
|
||||
"affected_users": connected_count
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Memory config deleted successfully",
|
||||
"affected_users": connected_count
|
||||
}
|
||||
|
||||
@@ -5,32 +5,36 @@ import uuid
|
||||
from os import getenv
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.user_model import User
|
||||
|
||||
from app.schemas.workspace_schema import (
|
||||
WorkspaceModelsUpdate,
|
||||
from app.models.workspace_model import (
|
||||
InviteStatus,
|
||||
Workspace,
|
||||
WorkspaceMember,
|
||||
WorkspaceRole,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.workspace_model import Workspace, WorkspaceRole, InviteStatus, WorkspaceMember
|
||||
from app.repositories import workspace_repository
|
||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||
from app.schemas.workspace_schema import (
|
||||
InviteAcceptRequest,
|
||||
InviteValidateResponse,
|
||||
WorkspaceCreate,
|
||||
WorkspaceUpdate,
|
||||
WorkspaceInviteCreate,
|
||||
WorkspaceInviteResponse,
|
||||
InviteValidateResponse,
|
||||
InviteAcceptRequest,
|
||||
WorkspaceMemberUpdate
|
||||
WorkspaceMemberUpdate,
|
||||
WorkspaceModelsUpdate,
|
||||
WorkspaceUpdate,
|
||||
)
|
||||
|
||||
# 获取业务逻辑专用日志器
|
||||
business_logger = get_business_logger()
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
def switch_workspace(
|
||||
db: Session,
|
||||
@@ -127,6 +131,27 @@ def create_workspace(
|
||||
db.commit()
|
||||
db.refresh(db_workspace)
|
||||
|
||||
# Create default memory config for the workspace (only for neo4j storage types)
|
||||
if workspace.storage_type == 'neo4j':
|
||||
try:
|
||||
_create_default_memory_config(
|
||||
db=db,
|
||||
workspace_id=db_workspace.id,
|
||||
workspace_name=db_workspace.name,
|
||||
llm_id=llm,
|
||||
embedding_id=embedding,
|
||||
rerank_id=rerank,
|
||||
)
|
||||
business_logger.info(
|
||||
f"为工作空间 {db_workspace.id} 创建默认记忆配置成功"
|
||||
)
|
||||
except Exception as mc_error:
|
||||
business_logger.error(
|
||||
f"为工作空间 {db_workspace.id} 创建默认记忆配置失败: {str(mc_error)}"
|
||||
)
|
||||
# Don't fail workspace creation if memory config creation fails
|
||||
# The workspace can still function without a default memory config
|
||||
|
||||
# 如果 storage_type 是 "rag",自动创建知识库
|
||||
if workspace.storage_type == "rag":
|
||||
business_logger.info(
|
||||
@@ -286,7 +311,7 @@ def _check_workspace_member_permission(db: Session, workspace_id: uuid.UUID, use
|
||||
)
|
||||
|
||||
# 使用统一权限服务检查访问权限
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
from app.core.permissions import Action, Resource, Subject, permission_service
|
||||
|
||||
# 获取用户的工作空间成员关系
|
||||
member = workspace_repository.get_member_in_workspace(
|
||||
@@ -324,7 +349,7 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
|
||||
)
|
||||
|
||||
# 使用统一权限服务检查管理权限
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
from app.core.permissions import Action, Resource, Subject, permission_service
|
||||
|
||||
# 获取用户的工作空间成员关系
|
||||
member = workspace_repository.get_member_in_workspace(
|
||||
@@ -852,3 +877,50 @@ def update_workspace_models_configs(
|
||||
business_logger.error(f"工作空间模型配置更新失败: workspace_id={workspace_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
|
||||
def _create_default_memory_config(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_name: str,
|
||||
llm_id: Optional[uuid.UUID] = None,
|
||||
embedding_id: Optional[uuid.UUID] = None,
|
||||
rerank_id: Optional[uuid.UUID] = None,
|
||||
) -> None:
|
||||
"""Create a default memory config for a newly created workspace.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
workspace_id: The workspace ID
|
||||
workspace_name: The workspace name (used for config naming)
|
||||
llm_id: Optional LLM model ID
|
||||
embedding_id: Optional embedding model ID
|
||||
rerank_id: Optional rerank model ID
|
||||
"""
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
config_id = uuid.uuid4()
|
||||
|
||||
default_config = MemoryConfig(
|
||||
config_id=config_id,
|
||||
config_name=f"{workspace_name} 默认配置",
|
||||
config_desc="工作空间创建时自动生成的默认记忆配置",
|
||||
workspace_id=workspace_id,
|
||||
llm_id=str(llm_id) if llm_id else None,
|
||||
embedding_id=str(embedding_id) if embedding_id else None,
|
||||
rerank_id=str(rerank_id) if rerank_id else None,
|
||||
state=True, # Active by default
|
||||
is_default=True, # Mark as workspace default
|
||||
)
|
||||
|
||||
db.add(default_config)
|
||||
db.commit()
|
||||
|
||||
business_logger.info(
|
||||
"Created default memory config for workspace",
|
||||
extra={
|
||||
"workspace_id": str(workspace_id),
|
||||
"config_id": str(config_id),
|
||||
"config_name": default_config.config_name,
|
||||
}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user