import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
# Class labels the classifier predicts; the length of this list sets the size
# of the model's output head in get_model().
# NOTE(review): presumably output index i corresponds to CLASSES[i] — confirm
# against the label encoding used during training.
CLASSES = ["Healthy", "Resistant", "Susceptible"]
def get_model(model_name: str, *, pretrained: bool = True):
    """Build a timm classifier for ``CLASSES`` and its matching image transform.

    Args:
        model_name: Architecture name passed straight to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``True``,
            which is what this function always did; pass ``False`` to skip
            loading pretrained weights (e.g. when a checkpoint is loaded
            afterwards).

    Returns:
        A ``(model, image_transform)`` tuple: the model created with
        ``len(CLASSES)`` output classes, and the transform produced by
        ``create_transform`` from the model's resolved data config.
    """
    model = timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=len(CLASSES),
    )

    # Resolve the preprocessing settings from the model itself (no overrides
    # are supplied) so the transform matches what this architecture expects.
    config = resolve_data_config({}, model=model)
    image_transform = create_transform(**config)

    return model, image_transform