File size: 709 Bytes
afb90db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch, torchvision
from torch import nn

def build_effnetb1(num_classes: int = 4, dropout: float = 0.0):
    """Build a frozen, pretrained EfficientNet-B1 with a fresh classification head.

    The backbone is loaded with torchvision's ``EfficientNet_B1_Weights.DEFAULT``
    weights (torchvision may need to download them on first use). Every
    pretrained parameter is frozen, and the classifier is replaced by a new,
    trainable ``Dropout -> Linear`` head.

    Args:
        num_classes: Number of output logits produced by the new head.
            Defaults to 4, the value previously hard-coded here.
        dropout: Dropout probability applied before the new linear layer.
            Defaults to 0.0 (dropout disabled), matching the previous behaviour.

    Returns:
        A ``(model, transforms)`` tuple. ``model`` is the modified EfficientNet-B1
        with a ``name`` attribute set to ``"effnetb1"``. ``transforms`` is the
        preprocessing pipeline returned by the weights' ``transforms()`` method.
    """
    weights = torchvision.models.EfficientNet_B1_Weights.DEFAULT
    model = torchvision.models.efficientnet_b1(weights=weights)
    transforms = weights.transforms()
    model.name = "effnetb1"

    # Freeze the pretrained parameters *before* swapping the head, so that only
    # the classifier created below remains trainable.
    for param in model.parameters():
        param.requires_grad = False

    # torchvision's stock EfficientNet head is Sequential(Dropout, Linear).
    # Read the feature width from its Linear layer instead of hard-coding 1280,
    # so the new head always matches the backbone's output width.
    in_features = model.classifier[1].in_features

    # No device is forced here. The previous `.to("cpu")` pinned only this layer,
    # which could diverge from the backbone's device. Callers should move the
    # whole model with `model.to(device)`.
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes, bias=True),
    )

    return model, transforms