nabeelraza's picture
Add: initial code
019483f
raw
history blame contribute delete
412 Bytes
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
# Label names for the classifier. get_model sizes the model's output layer
# with len(CLASSES). NOTE(review): the order presumably matches the label
# indices used in training (index i <-> output logit i) -- confirm before
# reordering.
CLASSES = [
    "Healthy",
    "Resistant",
    "Susceptible",
]
def get_model(model_name: str, *, pretrained: bool = True):
    """Create a timm model sized for CLASSES, plus its matching image transform.

    Args:
        model_name: Architecture name passed straight to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``True``,
            which matches the previous hard-coded behaviour. Pass ``False``
            when the caller will immediately load its own fine-tuned
            checkpoint, so the upstream pretrained weights are not fetched.

    Returns:
        A ``(model, image_transform)`` tuple:
        - ``model``: the timm model, created with ``num_classes=len(CLASSES)``.
        - ``image_transform``: built by ``create_transform`` from the data
          config that ``resolve_data_config`` derives from that model.
    """
    model = timm.create_model(
        model_name,
        pretrained=pretrained,
        # Output size follows CLASSES so the logits line up with the label list.
        num_classes=len(CLASSES),
    )
    # Empty args dict: no overrides, so preprocessing settings come from the
    # model's own data config.
    config = resolve_data_config({}, model=model)
    # No is_training flag is passed, so create_transform uses its default.
    image_transform = create_transform(**config)
    # NOTE(review): the model is returned without calling .eval(); callers
    # presumably switch modes themselves before inference -- confirm.
    return model, image_transform