refactor(knowledge_service): refactor model binding logic into generic function

- Extract duplicate model binding logic into `_get_model_by_name_or_fallback`.
- Implement logic to prioritize workspace default configuration, falling back to the tenant's first available model if not found.
- Simplify binding code for embedding, rerank, and LLM models.
This commit is contained in:
wwq
2026-04-20 19:01:06 +08:00
parent df6eb74b28
commit 8887600f7d

View File

@@ -77,29 +77,40 @@ def create_knowledge(
tenant_id = workspace.tenant_id
def _get_model_by_name_or_fallback(model_name: str | None, model_types: list, label: str):
"""优先按 workspace 指定的 model name 查,找不到再 fallback 到 tenant 下第一个"""
if model_name:
model = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id,
ModelConfig.name == model_name,
ModelConfig.type.in_([t.value for t in model_types]),
ModelConfig.is_active == True,
ModelConfig.is_composite == False
).first()
if model:
business_logger.debug(f"Auto-bind {label} model from workspace default: {model.id} ({model_name})")
return model
business_logger.debug(f"Workspace default {label} model '{model_name}' not found, falling back to tenant")
models = ModelConfigRepository.get_by_type(db=db, model_types=model_types, tenant_id=tenant_id, is_active=True)
if models:
business_logger.debug(f"Auto-bind {label} model from tenant fallback: {models[0].id}")
return models[0]
return None
if not knowledge.embedding_id:
embedding_models = ModelConfigRepository.get_by_type(
db=db, model_types=[ModelType.EMBEDDING], tenant_id=tenant_id, is_active=True
)
if embedding_models:
knowledge.embedding_id = embedding_models[0].id
business_logger.debug(f"Auto-bind embedding model: {embedding_models[0].id}")
model = _get_model_by_name_or_fallback(workspace.embedding, [ModelType.EMBEDDING], "embedding")
if model:
knowledge.embedding_id = model.id
if not knowledge.reranker_id:
rerank_models = ModelConfigRepository.get_by_type(
db=db, model_types=[ModelType.RERANK], tenant_id=tenant_id, is_active=True
)
if rerank_models:
knowledge.reranker_id = rerank_models[0].id
business_logger.debug(f"Auto-bind rerank model: {rerank_models[0].id}")
model = _get_model_by_name_or_fallback(workspace.rerank, [ModelType.RERANK], "rerank")
if model:
knowledge.reranker_id = model.id
if not knowledge.llm_id:
llm_models = ModelConfigRepository.get_by_type(
db=db, model_types=[ModelType.LLM, ModelType.CHAT], tenant_id=tenant_id, is_active=True
)
if llm_models:
knowledge.llm_id = llm_models[0].id
business_logger.debug(f"Auto-bind llm model: {llm_models[0].id}")
model = _get_model_by_name_or_fallback(workspace.llm, [ModelType.LLM, ModelType.CHAT], "llm")
if model:
knowledge.llm_id = model.id
if not knowledge.image2text_id:
image2text_models = db.query(ModelConfig).filter(