File size: 5,307 Bytes
d6c6696 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import json
import PIL
import timm
import torch
from torchvision import transforms
class ModelInference:
    """Single-image classifier built on a timm architecture with fine-tuned weights.

    Args:
        model_path: Path to a state dict loadable by ``torch.load``; it must
            match the architecture selected by ``model_type``.
        model_type: Key of ``_TIMM_ARCHITECTURES`` selecting the timm model.
        category_map_json: JSON object mapping category label -> class index.
        categ_to_name_map_json: JSON object mapping category label -> the
            name used as a key in ``predict`` results.
        device: Torch device specifier, e.g. ``"cpu"``, ``"cuda"``, ``"cuda:1"``.
        input_size: Side length the square-padded image is resized to.
        topk: Number of predictions ``predict`` returns; 0 (or any value
            larger than the number of classes) returns every class.

    Raises:
        RuntimeError: If ``model_type`` is not a supported architecture.
    """

    # model_type -> (timm architecture name, extra kwargs for timm.create_model).
    # "resnet50" and "timm_resnet50" are both accepted and build the same model.
    _TIMM_ARCHITECTURES = {
        "resnet50": ("resnet50", {}),
        "timm_resnet50": ("resnet50", {}),
        "timm_convnext-t": ("convnext_tiny_in22k", {}),
        "timm_convnext-b": ("convnext_base_in22k", {}),
        "efficientnetv2-b3": ("tf_efficientnetv2_b3", {}),
        "timm_mobilenetv3large": ("mobilenetv3_large_100", {}),
        "timm_vit-b16-128": ("vit_base_patch16_224_in21k", {"img_size": 128}),
    }

    def __init__(
        self,
        model_path: str,
        model_type: str,
        category_map_json: str,
        categ_to_name_map_json: str,
        device: str,
        input_size: int = 128,
        topk: int = 10,
    ):
        self.device = device
        self.topk = topk
        self.input_size = input_size
        self.model_type = model_type
        # Image currently being processed; _pad_to_square sizes its padding from it.
        self.image = None
        self.id2categ = self._load_category_map(category_map_json)
        self.categ2name = self._load_categ_to_name_map(categ_to_name_map_json)
        # pretrained=False: every weight is immediately replaced by the strict
        # load_state_dict in _load_model, so downloading timm's pretrained
        # weights would only cost time and require network access.
        self.model = self._load_model(
            model_path, num_classes=len(self.id2categ), pretrained=False
        )
        self.model.eval()

    def _load_categ_to_name_map(self, categ_to_name_map_json: str) -> dict:
        """Load the category label -> name mapping from a JSON file."""
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(categ_to_name_map_json, "r", encoding="utf-8") as f:
            return json.load(f)

    def _load_category_map(self, category_map_json: str) -> dict:
        """Load a label -> index JSON map and return it inverted (index -> label)."""
        with open(category_map_json, "r", encoding="utf-8") as f:
            categories_map = json.load(f)
        return {index: categ for categ, index in categories_map.items()}

    def _pad_to_square(self):
        """Return a Pad transform that makes ``self.image`` square.

        Only the shorter side is padded, on its bottom or right edge, so the
        original pixels keep their top-left anchoring.
        """
        width, height = self.image.size
        # transforms.Pad with a 4-element list means [left, top, right, bottom].
        pad_right = max(height - width, 0)
        pad_bottom = max(width - height, 0)
        return transforms.Pad(padding=[0, 0, pad_right, pad_bottom])

    def get_transforms(self):
        """Build the preprocessing pipeline for ``self.image``.

        Pipeline: pad to square, convert to tensor, resize to
        ``input_size`` x ``input_size``, then normalize with the ImageNet
        mean/std constants below.
        """
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        return transforms.Compose(
            [
                self._pad_to_square(),
                transforms.ToTensor(),
                transforms.Resize((self.input_size, self.input_size), antialias=True),
                transforms.Normalize(mean, std),
            ]
        )

    def _load_model(self, model_path: str, num_classes: int, pretrained: bool = True):
        """Create the timm model for ``self.model_type`` and load its weights.

        Args:
            model_path: State dict file passed to ``torch.load``.
            num_classes: Size of the classifier head.
            pretrained: Forwarded to ``timm.create_model``. The loaded state
                dict overwrites all weights either way.

        Returns:
            The model moved to ``self.device``. It is wrapped in
            ``torch.nn.DataParallel`` when several GPUs are available and the
            target is the default CUDA device.

        Raises:
            RuntimeError: If ``self.model_type`` is not supported.
        """
        try:
            arch, extra_kwargs = self._TIMM_ARCHITECTURES[self.model_type]
        except KeyError:
            raise RuntimeError(f"Model {self.model_type} not implemented") from None
        model = timm.create_model(
            arch, pretrained=pretrained, num_classes=num_classes, **extra_kwargs
        )

        target = torch.device(self.device)
        # NOTE(review): torch.load unpickles the file; only pass trusted
        # checkpoints (or use weights_only=True where the torch version allows).
        model.load_state_dict(torch.load(model_path, map_location=target))

        # DataParallel requires the wrapped module's parameters to live on
        # device_ids[0] (cuda:0 by default). Wrapping a CPU or cuda:N model
        # would make every forward pass raise, so only parallelize there.
        if (
            target.type == "cuda"
            and target.index in (None, 0)
            and torch.cuda.device_count() > 1
        ):
            model = torch.nn.DataParallel(model)
        return model.to(target)

    def predict(self, image: PIL.Image.Image) -> dict:
        """Classify one image.

        Args:
            image: Input image; it is also stored on ``self.image`` because
                the padding transform is derived from it.

        Returns:
            Dict mapping name -> softmax probability (numpy scalar), inserted
            from highest to lowest probability. It has ``topk`` entries, or
            all classes when ``topk`` is 0 or exceeds the class count.
        """
        with torch.no_grad():
            self.image = image
            preprocess = self.get_transforms()
            batch = preprocess(image).to(self.device).unsqueeze(0)

            probs = torch.nn.functional.softmax(self.model(batch), dim=1).cpu()

            num_classes = probs.shape[1]
            # topk == 0 means "return every class".
            k = num_classes if self.topk == 0 or self.topk > num_classes else self.topk
            top = torch.topk(probs, k)
            values = top.values.numpy()[0]
            indices = top.indices.numpy()[0]

            # NOTE(review): if two categories share a name, the later (lower)
            # probability overwrites the earlier one, as in the original code.
            return {
                self.categ2name[self.id2categ[idx]]: value
                for idx, value in zip(indices, values)
            }
|