File size: 729 Bytes
efc2ac4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

import torch 
import torchvision
from torch import nn 

# Factory that builds a pretrained EfficientNet-B2 with a new classification head.

def create_effnet_b2(num_classes: int = 3,
                     seed: int = 42) -> tuple[nn.Module, nn.Module]:
    """Build a pretrained EfficientNet-B2 with a fresh classification head.

    Every pretrained parameter is frozen. The classifier is then replaced with
    a newly initialised ``Dropout -> Linear`` head sized for ``num_classes``,
    so only the head's parameters are trainable.

    Args:
        num_classes: Number of output units in the new classification head.
        seed: Value passed to ``torch.manual_seed`` right before the head is
            built, so its initial weights are reproducible. Note that this
            reseeds the global torch RNG as a side effect.

    Returns:
        A ``(model, transforms)`` tuple, where ``transforms`` is the
        preprocessing pipeline returned by ``weights.transforms()`` for the
        pretrained weights used to build ``model``.
    """
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze every pretrained parameter. The replacement head built below
    # creates brand-new parameters, which default to requires_grad=True,
    # so the net effect is a frozen backbone with a trainable head.
    for param in model.parameters():
        param.requires_grad = False

    # Read the head's input width from the existing final Linear layer
    # rather than hard-coding it (previously 1408), so it always matches the
    # backbone that was actually loaded.
    in_features = model.classifier[-1].in_features

    # Seed immediately before creating the head so its init is reproducible.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )

    return model, transforms