File size: 412 Bytes
019483f
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform


# Label names for the classifier outputs; get_model sizes the model's
# classification head with len(CLASSES).
# NOTE(review): presumably logit index i corresponds to CLASSES[i] — confirm
# against the training / label-encoding code before reordering this list.
CLASSES = ["Healthy", "Resistant", "Susceptible"]


def get_model(model_name, *, pretrained=True):
    """Build a timm classifier for CLASSES and its matching preprocessing transform.

    Args:
        model_name: Name of a timm architecture, forwarded to
            ``timm.create_model``.
        pretrained: Whether timm loads its pretrained weights for the model.
            Defaults to ``True``, which matches the previous behaviour. Pass
            ``False`` when the weights will be restored from a local
            checkpoint afterwards, so that no pretrained weights are fetched.

    Returns:
        A ``(model, image_transform)`` tuple. ``model`` has a classification
        head sized to ``len(CLASSES)``. ``image_transform`` is built by
        ``create_transform`` from the data config resolved for ``model``.
    """
    # num_classes differs from the pretrained head, so timm builds a new
    # classifier of size len(CLASSES) on top of the (optionally pretrained)
    # backbone.
    model = timm.create_model(
        model_name, pretrained=pretrained, num_classes=len(CLASSES)
    )
    # An empty args dict means every preprocessing setting (input size,
    # mean/std, interpolation, crop) is taken from the model's own config,
    # keeping preprocessing consistent with the architecture/weights.
    config = resolve_data_config({}, model=model)
    # No is_training flag is passed, so create_transform uses its default
    # (is_training=False) and returns the evaluation-style transform.
    image_transform = create_transform(**config)

    return model, image_transform