Mila_Global_Moth_Classifier / model_inference.py
adityajain07's picture
Upload folder using huggingface_hub
d6c6696 verified
import json
import PIL
import timm
import torch
from torchvision import transforms
class ModelInference:
    """Single-image species classifier built on a timm backbone.

    The class loads a fine-tuned checkpoint plus two JSON lookup tables and
    exposes :meth:`predict`, which maps a PIL image to a dictionary of
    ``species name -> softmax probability`` for the top-k classes.
    """

    # Supported ``model_type`` values mapped to the timm architecture name and
    # any extra keyword arguments that architecture needs.
    _TIMM_MODELS = {
        "resnet50": ("resnet50", {}),
        "timm_resnet50": ("resnet50", {}),
        "timm_convnext-t": ("convnext_tiny_in22k", {}),
        "timm_convnext-b": ("convnext_base_in22k", {}),
        "efficientnetv2-b3": ("tf_efficientnetv2_b3", {}),
        "timm_mobilenetv3large": ("mobilenetv3_large_100", {}),
        "timm_vit-b16-128": ("vit_base_patch16_224_in21k", {"img_size": 128}),
    }

    def __init__(
        self,
        model_path: str,
        model_type: str,
        category_map_json: str,
        categ_to_name_map_json: str,
        device: str,
        input_size: int = 128,
        topk: int = 10,
    ):
        """Load the lookup tables and the model, and switch it to eval mode.

        Args:
            model_path: Path to the checkpoint holding the model state dict.
            model_type: One of the keys of ``_TIMM_MODELS``.
            category_map_json: JSON file mapping category -> class index.
            categ_to_name_map_json: JSON file mapping category -> display name.
            device: Torch device string the model and inputs are moved to.
            input_size: Side length (pixels) images are resized to.
            topk: Number of predictions to return; ``0`` returns all classes.

        Raises:
            RuntimeError: If ``model_type`` is not supported.
        """
        self.device = device
        self.topk = topk
        self.input_size = input_size
        self.model_type = model_type
        self.image = None
        self.id2categ = self._load_category_map(category_map_json)
        self.categ2name = self._load_categ_to_name_map(categ_to_name_map_json)
        # Every parameter and buffer is overwritten by the (strict) checkpoint
        # load, so downloading ImageNet weights first would only add a
        # pointless network dependency.
        self.model = self._load_model(
            model_path, num_classes=len(self.id2categ), pretrained=False
        )
        self.model.eval()

    def _load_categ_to_name_map(self, categ_to_name_map_json: str) -> dict:
        """Return the category -> display-name mapping stored in a JSON file."""
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(categ_to_name_map_json, "r", encoding="utf-8") as f:
            return json.load(f)

    def _load_category_map(self, category_map_json: str) -> dict:
        """Return the class-index -> category mapping.

        The JSON file stores ``category -> index``; it is inverted here so a
        predicted index can be looked up directly.
        """
        with open(category_map_json, "r", encoding="utf-8") as f:
            categories_map = json.load(f)
        return {index: categ for categ, index in categories_map.items()}

    def _pad_to_square(self):
        """Return a ``Pad`` transform that makes ``self.image`` square.

        Padding is added only on the bottom (landscape) or on the right
        (portrait). ``self.image`` must have been set beforehand.
        """
        width, height = self.image.size
        # Pad takes [left, top, right, bottom].
        if height < width:
            padding = [0, 0, 0, width - height]
        elif height > width:
            padding = [0, 0, height - width, 0]
        else:
            padding = [0, 0, 0, 0]
        return transforms.Pad(padding=padding)

    def get_transforms(self):
        """Build the preprocessing pipeline for the image in ``self.image``.

        The pipeline pads to a square, converts to a tensor, resizes to
        ``input_size`` x ``input_size`` and normalizes with the three-channel
        mean/std given below.
        """
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        return transforms.Compose(
            [
                self._pad_to_square(),
                transforms.ToTensor(),
                transforms.Resize((self.input_size, self.input_size), antialias=True),
                transforms.Normalize(mean, std),
            ]
        )

    def _load_model(self, model_path: str, num_classes: int, pretrained: bool = True):
        """Create the architecture for ``self.model_type`` and load its weights.

        Args:
            model_path: Path to the checkpoint holding the model state dict.
            num_classes: Size of the classification head.
            pretrained: Forwarded to ``timm.create_model``.

        Returns:
            The model on ``self.device``, wrapped in ``DataParallel`` when
            that is usable (see below).

        Raises:
            RuntimeError: If ``self.model_type`` is not supported.
        """
        try:
            arch_name, extra_kwargs = self._TIMM_MODELS[self.model_type]
        except KeyError:
            raise RuntimeError(f"Model {self.model_type} not implemented") from None

        model = timm.create_model(
            arch_name, pretrained=pretrained, num_classes=num_classes, **extra_kwargs
        )

        # Load model weights.
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        device = torch.device(self.device)
        model.load_state_dict(torch.load(model_path, map_location=device))

        # Parallelize inference if multiple GPUs are available. DataParallel
        # requires the module to live on the first CUDA device, so wrapping a
        # model destined for the CPU (or another GPU index) would make every
        # forward pass raise.
        targets_first_gpu = device.type == "cuda" and device.index in (None, 0)
        if targets_first_gpu and torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)

        return model.to(self.device)

    def _top_predictions(self, probabilities) -> dict:
        """Map the top-k entries of a ``(1, num_classes)`` CPU tensor to names.

        Returns a dict of ``species name -> probability`` (numpy scalar) in
        descending order of probability.
        """
        num_classes = probabilities.shape[1]
        # topk=0 means get all predictions; also clamp oversized requests.
        if self.topk == 0 or self.topk > num_classes:
            k = num_classes
        else:
            k = self.topk
        top = torch.topk(probabilities, k)
        values = top.values.numpy()[0]
        indices = top.indices.numpy()[0]

        pred_results = {}
        for idx, value in zip(indices, values):
            categ = self.id2categ[int(idx)]
            pred_results[self.categ2name[categ]] = value
        return pred_results

    # The annotation is a string because only ``import PIL`` is present; the
    # ``PIL.Image`` submodule is not guaranteed to be loaded at definition time.
    def predict(self, image: "PIL.Image.Image"):
        """Classify one image.

        Args:
            image: The PIL image to classify.

        Returns:
            Dict of ``species name -> softmax probability`` for the top-k
            classes, most probable first.
        """
        # Normalization uses three channels; RGBA, grayscale or palette
        # images would otherwise fail inside Normalize.
        if image.mode != "RGB":
            image = image.convert("RGB")

        with torch.no_grad():
            # Process the image for prediction
            self.image = image  # read by get_transforms() / _pad_to_square()
            preprocess = self.get_transforms()
            batch = preprocess(image).to(self.device).unsqueeze(0)

            # Model prediction on the image
            logits = self.model(batch)
            probabilities = torch.nn.functional.softmax(logits, dim=1).cpu()

        # Process the results
        return self._top_predictions(probabilities)