pytholic commited on
Commit
35437ac
·
1 Parent(s): e60475d

model.py added

Browse files
Files changed (1) hide show
  1. model.py +97 -0
model.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ import torch
3
+ import torchmetrics
4
+ import torchvision.models as models
5
+ from simple_parsing import ArgumentParser
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from config.args import Args
10
+
11
+ parser = ArgumentParser()
12
+ parser.add_arguments(Args, dest="options")
13
+ args_namespace = parser.parse_args()
14
+ args = args_namespace.options
15
+
16
+ # Model class
17
+ class Model(nn.Module):
18
+ def __init__(self, input_shape, weights=args.weights):
19
+ super().__init__()
20
+
21
+ self.feature_extractor = models.resnet18(weights=weights)
22
+
23
+ if weights:
24
+ # layers are frozen by using eval()
25
+ self.feature_extractor.eval()
26
+ # freeze params
27
+ for param in self.feature_extractor.parameters():
28
+ param.requires_grad = False
29
+
30
+ n_size = self._get_conv_output(input_shape)
31
+
32
+ self.classifier = nn.Linear(n_size, args.num_classes)
33
+
34
+ # returns the size of the output tensor going into the Linear layer from the conv block.
35
+ def _get_conv_output(self, shape):
36
+ batch_size = 1
37
+ tmp_input = torch.autograd.Variable(torch.rand(batch_size, *shape))
38
+
39
+ output_feat = self.convs(tmp_input)
40
+ n_size = output_feat.data.view(batch_size, -1).size(1)
41
+ return n_size
42
+
43
+ def convs(self, x):
44
+ x = self.feature_extractor(x)
45
+ return x
46
+
47
+ def forward(self, x):
48
+
49
+ x = self.convs(x)
50
+ x = x.view(x.size(0), -1)
51
+ x = self.classifier(x)
52
+ return x
53
+
54
+
55
+ class Classifier(pl.LightningModule):
56
+ def __init__(self):
57
+ super().__init__()
58
+
59
+ self.model = Model(input_shape=args.input_shape)
60
+ self.accuracy = torchmetrics.Accuracy(
61
+ task="multiclass", num_classes=args.num_classes
62
+ )
63
+
64
+ def forward(self, x):
65
+ x = self.model(x)
66
+ return x
67
+
68
+ def ce_loss(self, logits, labels):
69
+ return F.cross_entropy(logits, labels)
70
+
71
+ def training_step(self, train_batch, batch_idx):
72
+ x, y = train_batch
73
+ logits = self.model(x)
74
+ loss = self.ce_loss(logits, y)
75
+ acc = self.accuracy(logits, y)
76
+ self.log("accuracy/train_accuracy", acc)
77
+ self.log("loss/train_loss", loss)
78
+ return loss
79
+
80
+ def validation_step(self, val_batch, batch_idx):
81
+ x, y = val_batch
82
+ logits = self.model(x)
83
+ loss = self.ce_loss(logits, y)
84
+ acc = self.accuracy(logits, y)
85
+ self.log("accuracy/val_accuracy", acc)
86
+ self.log("loss/val_loss", loss)
87
+
88
+ def configure_optimizers(self):
89
+ optimizer = torch.optim.Adam(self.parameters(), lr=args.learning_rate)
90
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
91
+ optimizer, mode="min", patience=7
92
+ )
93
+ return {
94
+ "optimizer": optimizer,
95
+ "lr_scheduler": scheduler,
96
+ "monitor": "loss/val_loss",
97
+ }