smishr-18 commited on
Commit
ce2eaae
·
verified ·
1 Parent(s): f5b6d4a

Upload resnet.py

Browse files
Files changed (1) hide show
  1. resnet.py +30 -0
resnet.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torchvision
3
+
4
+ class Resnet50Flower102(nn.Module):
5
+ def __init__(self, device, pretrained=True, freeze_backbone=True):
6
+ super().__init__()
7
+ self.device = device
8
+
9
+ if pretrained:
10
+ weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
11
+ else:
12
+ weights = None
13
+
14
+ self.model = torchvision.models.resnet50(weights=weights)
15
+
16
+ self.model.fc = nn.Sequential(
17
+ nn.Linear(2048, 1024),
18
+ nn.BatchNorm1d(1024),
19
+ nn.ReLU(),
20
+ nn.Dropout(0.2),
21
+ nn.Linear(1024, 512),
22
+ nn.BatchNorm1d(512),
23
+ nn.ReLU(),
24
+ nn.Dropout(0.2),
25
+ nn.Linear(512, 102),
26
+ )
27
+ self.model.to(device)
28
+
29
+ def forward(self, x):
30
+ return self.model(x)