"""Helpers for ImageNet classification with a pretrained MobileNetV2."""

import functools
import json
import operator

import torch
from PIL import Image
from torchvision import models, transforms

# Default location of the ImageNet label list. The path is relative, so it is
# resolved against the current working directory, not this module's directory.
LABELS_PATH = "utils/imagenet-simple-labels.json"

# Per-channel RGB mean/std expected by torchvision's ImageNet-pretrained models.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


@functools.lru_cache(maxsize=None)
def _read_labels(path):
    """Parse the JSON label list at *path* once and cache it as an immutable tuple.

    The cache is never invalidated: edits to the file after the first read are
    not picked up for the same *path* within one process.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file is not valid JSON or its top level is not a list.
    """
    with open(path, encoding="utf-8") as f:  # JSON text is UTF-8 by spec
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(
            f"expected a JSON list of labels in {path!r}, got {type(data).__name__}"
        )
    # A tuple keeps the cached value immutable; callers receive copies.
    return tuple(data)


def load_classes(path=LABELS_PATH):
    """Return the class labels as a new list whose index is the class id.

    The file is read and parsed only on the first call for a given *path*.
    Each call returns a fresh list, so mutating the result cannot affect
    later calls.

    Args:
        path: Path to a JSON file whose top level is a list of labels.
            Defaults to ``LABELS_PATH``.
    """
    return list(_read_labels(path))


def class_id_to_label(i):
    """Return the label for class id *i*.

    Args:
        i: A non-negative integer-like class id. Any object implementing
            ``__index__`` is accepted, e.g. an ``int`` or a 0-d integer tensor.

    Raises:
        TypeError: if *i* is not integer-like.
        IndexError: if *i* is negative or not smaller than the number of labels.
    """
    # Reuse the cached tuple instead of re-reading the file on every lookup.
    labels = _read_labels(LABELS_PATH)
    idx = operator.index(i)
    if idx < 0:
        # Plain list indexing would silently wrap to a label at the end.
        raise IndexError(f"class id must be non-negative, got {idx}")
    return labels[idx]


def load_model():
    """Return an ImageNet-pretrained MobileNetV2 switched to eval mode.

    torchvision may download the weights on first use.
    """
    weights_enum = getattr(models, "MobileNet_V2_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13: ``weights=`` replaces the deprecated
        # ``pretrained=True``, which maps to IMAGENET1K_V1.
        model = models.mobilenet_v2(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision without the weights enum API.
        model = models.mobilenet_v2(pretrained=True)
    model.eval()
    return model


# Built once: the pipeline has no per-call or random state, so it can be shared.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])


def transform_image(img):
    """Convert an image into a normalized float tensor for the model.

    The image is resized directly to 224x224, so its aspect ratio is not
    preserved. It is then converted to a tensor and normalized with the
    ImageNet mean/std. An RGB input yields a ``(3, 224, 224)`` tensor.

    Non-RGB PIL images are converted to RGB first. This covers RGBA, palette
    and grayscale images. Normalize uses three per-channel statistics, so
    without this step RGBA input raises an error and palette indices would be
    normalized as if they were intensities.
    """
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")
    return _TRANSFORM(img)