File size: 412 Bytes
019483f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
# Label names for the classifier. get_model sizes the model's output head
# from len(CLASSES). Presumably index i here names output i of the model;
# confirm against the training / inference code that decodes predictions.
CLASSES = [
    "Healthy",
    "Resistant",
    "Susceptible",
]
def get_model(model_name, *, pretrained=True):
    """Build a timm classifier for CLASSES together with its input transform.

    The model is created with ``timm.create_model`` using ``num_classes =
    len(CLASSES)``, so its classification head has one output per entry in
    CLASSES. The preprocessing transform is derived from the model's own
    data config (input size, normalization, interpolation, ...) via
    ``resolve_data_config`` and ``create_transform``.

    Args:
        model_name: Name of a timm architecture, passed straight to
            ``timm.create_model``.
        pretrained: Whether timm should load its pretrained weights for the
            backbone. Defaults to True, matching the previous hard-coded
            behavior. Pass False to skip the download, e.g. when a
            fine-tuned checkpoint will be loaded afterwards.

    Returns:
        A ``(model, image_transform)`` tuple.
    """
    model = timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=len(CLASSES),
    )
    # Empty overrides dict: take every preprocessing setting from the model's
    # own pretrained/default config.
    config = resolve_data_config({}, model=model)
    # is_training is not passed, so create_transform uses its default.
    # NOTE(review): in timm this default is the evaluation-style transform
    # (no augmentation); confirm against the installed timm version.
    image_transform = create_transform(**config)
    return model, image_transform
|