File size: 1,301 Bytes
d022305
 
 
91d6eae
d022305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f809395
d022305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torchvision
import torch
import functools
import os

from torch import nn
from pathlib import Path

@functools.cache
def create_effnetb2_model(num_class: int):
    """Create an EfficientNet-B2 feature extractor with a custom output head.

    The backbone is frozen (``requires_grad = False``), the classifier is
    replaced by a fresh ``Dropout`` + ``Linear`` head with ``num_class``
    outputs, and fine-tuned weights are loaded from
    ``./src/data/44158_3_efficientnet_b2.pth`` onto the CPU.

    Notes:
        * The function is wrapped in ``functools.cache``: repeated calls with
          the same ``num_class`` return the *same* model object, so any
          mutation by one caller is visible to all others.
        * The checkpoint path is relative to the current working directory.
        * The model is returned in whatever mode ``nn.Module`` defaults to;
          this function does not call ``model.eval()``.

    Args:
        num_class: Number of output classes of the classifier head. Must
            match the head stored in the checkpoint, otherwise loading the
            state dict fails.

    Returns:
        A tuple ``(model, transforms)`` where ``model`` is the EfficientNet-B2
        with loaded weights and ``transforms`` is the preprocessing pipeline
        associated with the ``IMAGENET1K_V1`` EfficientNet-B2 weights.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the checkpoint does not match the model layout.
    """
    # The ImageNet weight enum is only needed for its preprocessing transforms.
    weights_effnetb2 = torchvision.models.EfficientNet_B2_Weights.IMAGENET1K_V1
    transforms = weights_effnetb2.transforms()

    # Build the bare architecture without downloading the ImageNet checkpoint:
    # every parameter and buffer is overwritten below by the (strict)
    # load_state_dict call, so pretrained weights would be fetched for nothing.
    model = torchvision.models.efficientnet_b2(weights=None)

    # Freeze the backbone; the head created afterwards stays trainable.
    for param in model.parameters():
        param.requires_grad = False

    # Read the feature width from the stock classifier instead of hard-coding
    # it (1408 for EfficientNet-B2).
    in_features = model.classifier[-1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_class),
    )

    # Load the fine-tuned weights on CPU.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # torch version supports it).
    path_model_weights = Path("./src/data") / "44158_3_efficientnet_b2.pth"
    state_dict = torch.load(path_model_weights, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)

    return model, transforms