File size: 962 Bytes
ec236ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from torch import nn
import timm


class EfficientNet(nn.Module):
    """Image classifier wrapping a timm model (EfficientNet-B0 by default).

    Called with no arguments it builds exactly the original configuration:
    pretrained ``efficientnet_b0`` with a 25-way classification head and all
    parameters trainable.

    Args:
        num_classes: Number of outputs of the classification head.
        model_name: timm model identifier passed to ``timm.create_model``.
        pretrained: Whether timm should load its pretrained weights.
        freeze_backbone: If True, disable gradients for every parameter except
            the classification head (as returned by ``get_classifier()``).
        verbose: If True, print the model name, class count and parameter
            counts after construction.
    """

    def __init__(
        self,
        num_classes: int = 25,
        *,
        model_name: str = "efficientnet_b0",
        pretrained: bool = True,
        freeze_backbone: bool = False,
        verbose: bool = False,
    ):
        super().__init__()
        # Attribute name kept as ``efficientnet`` so existing state_dict keys
        # ("efficientnet.*") and external references keep working.
        self.efficientnet = timm.create_model(
            model_name=model_name, pretrained=pretrained, num_classes=num_classes
        )

        if freeze_backbone:
            self._freeze_backbone()

        if verbose:
            trainable, total = self.count_parameters()
            print(f"{model_name} with {num_classes} classes initialized")
            print(f"Trainable parameters: {trainable}")
            print(f"Total parameters: {total}")

    def _freeze_backbone(self):
        """Set ``requires_grad`` to False on all parameters except the head's."""
        # Identify head parameters by object identity rather than by name
        # prefix, so this works for any timm model exposing get_classifier().
        head_ids = {id(p) for p in self.efficientnet.get_classifier().parameters()}
        for param in self.efficientnet.parameters():
            param.requires_grad = id(param) in head_ids

    def count_parameters(self):
        """Return ``(trainable, total)`` element counts of the wrapped model."""
        params = list(self.efficientnet.parameters())
        trainable = sum(p.numel() for p in params if p.requires_grad)
        total = sum(p.numel() for p in params)
        return trainable, total

    def forward(self, x):
        """Run ``x`` through the wrapped timm model and return its output.

        NOTE(review): presumably ``x`` is an image batch shaped
        (N, C, H, W) matching the model's expected input — confirm with callers.
        """
        return self.efficientnet(x)