"""Build timm image classifiers together with their matching input transforms."""

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Output labels. get_model sizes the classifier head to len(CLASSES) outputs;
# presumably output index i corresponds to CLASSES[i] -- confirm with the
# training/labelling code.
CLASSES = ["Healthy", "Resistant", "Susceptible"]


def get_model(model_name: str, *, pretrained: bool = True):
    """Create a timm classifier sized for ``CLASSES`` plus its preprocessing transform.

    Args:
        model_name: Architecture name passed straight to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``True``,
            which preserves the original hard-coded behaviour; pass ``False``
            to skip loading pretrained weights (e.g. in offline environments).

    Returns:
        A ``(model, image_transform)`` tuple. ``model`` has
        ``num_classes=len(CLASSES)``. ``image_transform`` is produced by
        ``create_transform`` from the data config timm resolves for this
        model, so images are preprocessed with the model's own settings.
        ``create_transform`` is called with its defaults -- presumably the
        inference (non-training) pipeline; confirm against the timm version
        in use.
    """
    model = timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=len(CLASSES),
    )
    # An empty override dict means every preprocessing setting is taken from
    # the model's own config rather than from user-supplied arguments.
    config = resolve_data_config({}, model=model)
    image_transform = create_transform(**config)
    return model, image_transform