From 8887600f7d6e2781cab4a2acfed4e8ceb5c870ad Mon Sep 17 00:00:00 2001 From: wwq Date: Mon, 20 Apr 2026 19:01:06 +0800 Subject: [PATCH] refactor(knowledge_service): refactor model binding logic into generic function - Extract duplicate model binding logic into `_get_model_by_name_or_fallback`. - Implement logic to prioritize workspace default configuration, falling back to the tenant's first available model if not found. - Simplify binding code for embedding, rerank, and LLM models. --- api/app/services/knowledge_service.py | 47 +++++++++++++++++---------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/api/app/services/knowledge_service.py b/api/app/services/knowledge_service.py index 94653db8..b1d0d77b 100644 --- a/api/app/services/knowledge_service.py +++ b/api/app/services/knowledge_service.py @@ -77,29 +77,40 @@ def create_knowledge( tenant_id = workspace.tenant_id + def _get_model_by_name_or_fallback(model_name: str | None, model_types: list, label: str): + """优先按 workspace 指定的 model name 查,找不到再 fallback 到 tenant 下第一个""" + if model_name: + model = db.query(ModelConfig).filter( + ModelConfig.tenant_id == tenant_id, + ModelConfig.name == model_name, + ModelConfig.type.in_([t.value for t in model_types]), + ModelConfig.is_active == True, + ModelConfig.is_composite == False + ).first() + if model: + business_logger.debug(f"Auto-bind {label} model from workspace default: {model.id} ({model_name})") + return model + business_logger.debug(f"Workspace default {label} model '{model_name}' not found, falling back to tenant") + models = ModelConfigRepository.get_by_type(db=db, model_types=model_types, tenant_id=tenant_id, is_active=True) + if models: + business_logger.debug(f"Auto-bind {label} model from tenant fallback: {models[0].id}") + return models[0] + return None + if not knowledge.embedding_id: - embedding_models = ModelConfigRepository.get_by_type( - db=db, model_types=[ModelType.EMBEDDING], tenant_id=tenant_id, is_active=True - ) - if embedding_models: - knowledge.embedding_id = embedding_models[0].id - business_logger.debug(f"Auto-bind embedding model: {embedding_models[0].id}") + model = _get_model_by_name_or_fallback(workspace.embedding, [ModelType.EMBEDDING], "embedding") + if model: + knowledge.embedding_id = model.id if not knowledge.reranker_id: - rerank_models = ModelConfigRepository.get_by_type( - db=db, model_types=[ModelType.RERANK], tenant_id=tenant_id, is_active=True - ) - if rerank_models: - knowledge.reranker_id = rerank_models[0].id - business_logger.debug(f"Auto-bind rerank model: {rerank_models[0].id}") + model = _get_model_by_name_or_fallback(workspace.rerank, [ModelType.RERANK], "rerank") + if model: + knowledge.reranker_id = model.id if not knowledge.llm_id: - llm_models = ModelConfigRepository.get_by_type( - db=db, model_types=[ModelType.LLM, ModelType.CHAT], tenant_id=tenant_id, is_active=True - ) - if llm_models: - knowledge.llm_id = llm_models[0].id - business_logger.debug(f"Auto-bind llm model: {llm_models[0].id}") + model = _get_model_by_name_or_fallback(workspace.llm, [ModelType.LLM, ModelType.CHAT], "llm") + if model: + knowledge.llm_id = model.id if not knowledge.image2text_id: image2text_models = db.query(ModelConfig).filter(