diff --git a/api/app/tasks.py b/api/app/tasks.py index 3a9da918..676ce71d 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -40,6 +40,7 @@ from app.models.file_model import File from app.models.knowledge_model import Knowledge from app.schemas import file_schema, document_schema from app.services.memory_agent_service import MemoryAgentService +from app.utils.config_utils import resolve_config_id @celery_app.task(name="tasks.process_item") @@ -905,7 +906,8 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s actual_config_id = None if config_id: try: - actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id + with get_db_context() as db: + actual_config_id = resolve_config_id(config_id, db) except (ValueError, AttributeError): # If conversion fails, leave as None and try to resolve pass @@ -981,14 +983,13 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, end_user_id: str, message: list[dict], config_id: str, storage_type: str, user_rag_memory_id: str, +def write_message_task(self, end_user_id: str, message: list[dict], config_id: str | int, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. - Args: end_user_id: Group ID for the memory agent (also used as end_user_id) message: Message to write - config_id: Configuration ID as string (will be converted to UUID) + config_id: Configuration ID (can be UUID string, integer, or config_id_old) storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID language: 语言类型 ("zh" 中文, "en" 英文) @@ -1002,24 +1003,28 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s from app.core.logging_config import get_logger logger = get_logger(__name__) - logger.info( - f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}, language={language}") + logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id} (type: {type(config_id).__name__}), storage_type={storage_type}, language={language}") start_time = time.time() - # Convert config_id string to UUID + # Convert config_id to UUID actual_config_id = None + if config_id: try: - actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id + with get_db_context() as db: + actual_config_id = resolve_config_id(config_id, db) + print(100*'-') + print(actual_config_id) + print(100*'-') logger.info( f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") except (ValueError, AttributeError) as e: - logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id}, error: {e}") + logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id} (type: {type(config_id).__name__}), error: {e}") return { "status": "FAILURE", - "error": f"Invalid config_id format: {config_id}", + "error": f"Invalid config_id format: {config_id} - {str(e)}", "end_user_id": end_user_id, - "config_id": config_id, + "config_id": str(config_id), "elapsed_time": 0.0, "task_id": self.request.id } diff --git a/api/app/utils/config_utils.py b/api/app/utils/config_utils.py index 55cfe8a3..eee5c233 100644 --- a/api/app/utils/config_utils.py +++ b/api/app/utils/config_utils.py @@ -28,33 +28,7 @@ def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID: if isinstance(config_id, UUID): return config_id - # 2. 如果是字符串类型 - if isinstance(config_id, str): - config_id_stripped = config_id.strip() - - # 2.1 尝试解析为 UUID(标准 UUID 字符串长度为 36) - try: - return uuid_module.UUID(config_id_stripped) - except ValueError: - pass - - # 2.2 尝试解析为整数(用于查询 config_id_old) - try: - old_id = int(config_id_stripped) - if old_id > 0: - memory_config = db.query(MemoryConfig).filter( - MemoryConfig.config_id_old == old_id - ).first() - if not memory_config: - raise ValueError(f"未找到 config_id_old={old_id} 对应的配置") - return memory_config.config_id - except ValueError: - pass - - # 2.3 无法解析的字符串格式 - raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)") - - # 3. 如果是整数类型,通过 config_id_old 查找 + # 2. 如果是整数类型,通过 config_id_old 查找 if isinstance(config_id, int): if config_id <= 0: raise ValueError(f"config_id 必须是正整数: {config_id}") @@ -67,6 +41,34 @@ def resolve_config_id(config_id: UUID | int | str, db: Session) -> UUID: raise ValueError(f"未找到 config_id_old={config_id} 对应的配置") return memory_config.config_id + + # 3. 如果是字符串类型 + if isinstance(config_id, str): + config_id_stripped = config_id.strip() + + # 3.1 先尝试解析为整数(用于查询 config_id_old) + # 这样可以处理 "17" 这样的字符串 + try: + old_id = int(config_id_stripped) + if old_id > 0: + memory_config = db.query(MemoryConfig).filter( + MemoryConfig.config_id_old == old_id + ).first() + if not memory_config: + raise ValueError(f"未找到 config_id_old={old_id} 对应的配置") + return memory_config.config_id + except ValueError: + # 不是整数,继续尝试 UUID + pass + + # 3.2 尝试解析为 UUID + try: + return uuid_module.UUID(config_id_stripped) + except ValueError: + pass + + # 3.3 无法解析的字符串格式 + raise ValueError(f"无效的 config_id 格式: '{config_id}'(必须是 UUID 或正整数)") # 4. 不支持的类型 raise ValueError(f"不支持的 config_id 类型: {type(config_id).__name__}")