Spaces:
Sleeping
Sleeping
File size: 962 Bytes
ec236ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
from torch import nn
import timm
class EfficientNet(nn.Module):
    """EfficientNet-B0 classifier built with ``timm``.

    Wraps the network returned by ``timm.create_model("efficientnet_b0", ...)``
    and delegates ``forward`` to it. The wrapped network is stored as
    ``self.efficientnet``; that attribute name is kept stable so existing
    ``state_dict`` keys continue to match.

    Args:
        num_classes: Forwarded to ``timm.create_model`` as ``num_classes``.
            Defaults to 25.
        pretrained: Forwarded to ``timm.create_model`` as ``pretrained``.
            Defaults to True.
        freeze_backbone: When True, ``requires_grad`` is set to False on every
            parameter of the wrapped network whose name does not start with
            ``"classifier"``. Defaults to False, so every parameter keeps the
            ``requires_grad`` value ``timm`` gave it.
    """

    def __init__(
        self,
        num_classes: int = 25,
        *,
        pretrained: bool = True,
        freeze_backbone: bool = False,
    ) -> None:
        super().__init__()
        self.efficientnet = timm.create_model(
            model_name="efficientnet_b0",
            pretrained=pretrained,
            num_classes=num_classes,
        )
        if freeze_backbone:
            # Only parameters under the "classifier" prefix stay trainable.
            for name, param in self.efficientnet.named_parameters():
                if not name.startswith("classifier"):
                    param.requires_grad = False

    def parameter_counts(self):
        """Return ``(trainable, total)`` element counts of the wrapped network.

        ``trainable`` sums ``numel()`` over parameters with ``requires_grad``
        set; ``total`` sums it over all parameters. Computed on demand rather
        than on every construction.
        """
        params = list(self.efficientnet.parameters())
        total = sum(p.numel() for p in params)
        trainable = sum(p.numel() for p in params if p.requires_grad)
        return trainable, total

    def forward(self, x):
        """Return the wrapped network's output for ``x``.

        NOTE(review): presumably ``x`` is a batch of images and the result is
        per-class logits -- confirm against the caller.
        """
        return self.efficientnet(x)
|