File size: 539 Bytes
da43dfa
 
 
 
 
 
 
 
 
f2fab8d
da43dfa
f2fab8d
da43dfa
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torchvision
from torch import nn


class MobileNetV3Large(nn.Module):
    """Inference wrapper around torchvision's MobileNetV3-Large with a custom head.

    The backbone is built without pretrained weights, its last classifier
    layer is replaced by one producing ``num_classes`` logits, and all weights
    are then restored from a local checkpoint. The module is left in eval mode.
    """

    def __init__(self, ckpt, num_classes: int, device="cpu") -> None:
        """Build the network and load its weights from ``ckpt``.

        Args:
            ckpt: Anything accepted by ``torch.load`` (e.g. a path or a
                file-like object). It must hold a state_dict whose keys match
                torchvision's ``mobilenet_v3_large`` exactly (strict loading),
                i.e. one saved from the inner model rather than this wrapper.
            num_classes: Number of outputs of the replaced classifier layer.
                It must match the head shape stored in the checkpoint.
            device: Device the finished module is moved to. Defaults to
                ``"cpu"``, matching the previous hard-coded behavior.
        """
        super().__init__()
        # NOTE: torchvision >= 0.13 deprecates ``pretrained`` in favour of
        # ``weights=None``; kept here for compatibility with older releases.
        self.model = torchvision.models.mobilenet_v3_large(pretrained=False)
        # Replace the final classifier layer before loading, so the
        # checkpoint's head shape lines up with the module being filled.
        old_head = self.model.classifier[3]
        self.model.classifier[3] = nn.Linear(old_head.in_features, num_classes)
        # Deserialize onto the CPU first so checkpoints saved on a GPU still
        # load on CPU-only hosts; the module is moved to ``device`` afterwards.
        # SECURITY: torch.load is pickle-based -- only load trusted checkpoints
        # (consider weights_only=True where the installed torch supports it).
        state_dict = torch.load(ckpt, map_location=torch.device("cpu"))
        self.model.load_state_dict(state_dict)
        self.to(device)
        # Put the wrapper itself (and, recursively, the inner model) in eval
        # mode so ``self.training`` agrees with the submodule's mode.
        self.eval()

    def forward(self, x):
        """Return the output of the wrapped MobileNetV3 for input ``x``.

        NOTE(review): presumably ``x`` is an image batch shaped (N, 3, H, W)
        on the same device as the model -- confirm against callers.
        """
        return self.model(x)