File size: 647 Bytes
dbfc835 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch
from torchvision import models, transforms
from PIL import Image
import json
def load_classes():
    """Load the class labels from the bundled JSON file.

    Reads ``utils/imagenet-simple-labels.json``, resolved relative to the
    current working directory, and decodes it with ``json.load``.

    Returns:
        The decoded JSON document (presumably a list of label strings
        indexed by class id -- confirm against the data file).

    Raises:
        FileNotFoundError: If the label file is not present.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale-derived default text encoding.
    with open('utils/imagenet-simple-labels.json', encoding='utf-8') as f:
        return json.load(f)
# Labels decoded by the first lookup and reused afterwards; None until then.
_labels_cache = None


def class_id_to_label(i):
    """Return the label stored at index ``i`` of the class-label table.

    The table is obtained from ``load_classes()`` on the first call and kept
    in a module-level cache, so repeated lookups do not re-read and re-parse
    the label file each time.

    Args:
        i: Index into the label table (plain subscript semantics apply, so
            whatever ``labels[i]`` accepts is accepted here).

    Returns:
        The entry at position ``i`` of the loaded labels.

    Raises:
        IndexError: If ``i`` is outside the range of the label table.
    """
    global _labels_cache
    if _labels_cache is None:
        _labels_cache = load_classes()
    return _labels_cache[i]
def load_model():
    """Create a pretrained MobileNetV2 and switch it to evaluation mode.

    Returns:
        The ``models.mobilenet_v2`` instance after ``eval()`` has been
        called on it.
    """
    # NOTE(review): recent torchvision releases prefer a `weights=` argument
    # over `pretrained=` -- confirm the pinned version before changing this.
    net = models.mobilenet_v2(pretrained=True)
    # nn.Module.eval() returns the module itself, so it can be returned directly.
    return net.eval()
def transform_image(img):
    """Convert an image into a normalized tensor for the model.

    The pipeline resizes the image to 224x224, converts it to a tensor with
    ``transforms.ToTensor`` and normalizes each of the three channels with
    the mean/std values given below.

    Args:
        img: The input image. PIL images that are not in ``"RGB"`` mode
            (e.g. grayscale, RGBA, palette) are converted to RGB first;
            any other input is passed to the pipeline unchanged.

    Returns:
        The tensor produced by the transform pipeline.
    """
    # Normalize is configured with three-channel statistics, so a 1- or
    # 4-channel PIL image would fail after ToTensor. Bring it to RGB first.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(img)
|