File size: 1,294 Bytes
201ff2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torchvision
from torchvision import transforms
from torch import nn


def create_ResNetb34_model(num_classes: int = 3, seed: int = 42):
    """Create a ResNet34 feature-extractor model and its image transforms.

    The pretrained ResNet34 backbone is frozen, and its final fully connected
    layer (``model.fc``) is replaced with a new, trainable classifier head.

    :param num_classes: number of classes in the classifier head.
                        Defaults to 3.
    :param seed: random seed set immediately before the new head is built,
                 so the head's initial weights are reproducible.
                 Defaults to 42.
    :return: a ``(model, transform)`` tuple. ``model`` is the ResNet34 with
             the new head. ``transform`` is a ``transforms.Compose`` that
             wraps the preprocessing belonging to the pretrained weights.
    """
    # 1. Set up pretrained ResNet34 weights.
    weights = torchvision.models.ResNet34_Weights.DEFAULT

    # 2. Get the preprocessing transforms that match those weights.
    transform = transforms.Compose([
        weights.transforms(),
    ])

    # 3. Set up the pretrained model from the same weights object, so the
    #    model and its transforms cannot drift apart.
    model = torchvision.models.resnet34(weights=weights)

    # 4. Freeze every existing layer so only the new head is trained.
    for param in model.parameters():
        param.requires_grad = False

    # 5. Replace the classifier head. ResNet's forward() calls ``self.fc``
    #    (it has no ``classifier`` attribute), so the head must go there.
    #    The input width is read from the existing layer (512 for ResNet34)
    #    rather than hard-coded. Seed first so the new head's initial weights
    #    are reproducible; freshly created layers are trainable by default.
    in_features = model.fc.in_features
    torch.manual_seed(seed)
    model.fc = nn.Sequential(
        nn.Dropout(p=0.2, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )
    return model, transform