import json
import zlib

import pytorch_lightning as pl
import torch
import torchmetrics
import wandb
from timm import create_model as create_timm_model
from torch import nn
from torch.nn import functional as F

from constants import INPUT_IMAGE_SIZE

# The built-in ``hash()`` of a str is salted per interpreter process, so it
# cannot provide a reproducible seed. CRC32 is deterministic and already lies
# in the 0 .. 2**32 - 1 range.
pl.seed_everything(zlib.crc32(b"setting random seeds"))


class LitMLP(pl.LightningModule):
    """Linear classifier trained on top of a frozen timm feature extractor.

    The backbone returned by ``get_feature_extractor`` is always run in eval
    mode under ``torch.no_grad()`` (see ``forward``), so only the linear
    ``classifier`` head receives gradients.
    """

    def __init__(self, batch_size, n_classes, lr=1e-3):
        """Build the backbone, the linear head and the accuracy metrics.

        Args:
            batch_size: Batch dimension of the dummy input used for ONNX export.
            n_classes: Number of output units of the linear head.
            lr: Learning rate handed to Adam in ``configure_optimizers``.
                It is recorded by ``save_hyperparameters`` so that
                ``self.hparams["lr"]`` exists.
        """
        super().__init__()
        self.batch_size = batch_size
        self.feature_extractor, num_filters = get_feature_extractor()
        self.classifier = nn.Linear(num_filters, n_classes)
        # Stores the __init__ arguments (batch_size, n_classes, lr) in self.hparams.
        self.save_hyperparameters()

        self.train_acc = torchmetrics.Accuracy()
        self.valid_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

        # The map is indexed with str(class index) in predict_app.
        self.img_class_map = get_img_class_map()

    def _dummy_input(self):
        """Return an all-zero image batch on the module's device for ONNX tracing."""
        return torch.zeros(
            (self.batch_size, 3, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE),
            device=self.device,
        )

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``.

        ``x`` is presumably an image batch shaped like ``_dummy_input()``;
        confirm against the data module.
        """
        # Keep the backbone frozen: eval mode and no gradient tracking.
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        logits = self.classifier(representations)
        return F.log_softmax(logits, dim=1)

    def predict_app(self, x):
        """Predict the class of ``x`` and return its id and name.

        ``.item()`` requires a single-element tensor, so ``x`` must yield
        exactly one prediction (a batch of one).

        Returns:
            dict with ``'class_id'`` (int) and ``'class_name'`` (the entry of
            ``img_class_map`` under the stringified id).
        """
        self.eval()
        with torch.no_grad():  # inference only; no autograd graph is needed
            _, y_hat = self.forward(x).max(1)
        class_id = y_hat.item()
        return {
            'class_id': class_id,
            'class_name': self.img_class_map[str(class_id)],
        }

    def loss(self, xs, ys):
        """Return ``(log_probs, nll_loss)`` for inputs ``xs`` and targets ``ys``."""
        logits = self(xs)
        loss = F.nll_loss(logits, ys)
        return logits, loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log loss and accuracy."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)

        self.log('train/loss', loss, on_epoch=True)
        self.train_acc(preds, ys)
        self.log('train/acc', self.train_acc, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer using the learning rate stored in hparams."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])

    def test_step(self, batch, batch_idx):
        """Log epoch-level test loss and accuracy for one batch."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)

        self.test_acc(preds, ys)
        self.log("test/loss_epoch", loss, on_step=False, on_epoch=True)
        self.log("test/acc_epoch", self.test_acc, on_step=False, on_epoch=True)

    def test_epoch_end(self, test_step_outputs):  # args are defined as part of pl API
        """Export the model to ``model_final.onnx`` and pass the file to wandb."""
        model_filename = "model_final.onnx"
        self.to_onnx(model_filename, self._dummy_input(), export_params=True)
        wandb.save(model_filename)

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy; return the batch log-probabilities."""
        xs, ys = batch
        logits, loss = self.loss(xs, ys)
        preds = torch.argmax(logits, 1)

        self.valid_acc(preds, ys)
        self.log("valid/loss_epoch", loss)
        self.log('valid/acc_epoch', self.valid_acc)
        return logits

    def validation_epoch_end(self, validation_step_outputs):
        """Export an ONNX snapshot and log a histogram of validation outputs.

        Args:
            validation_step_outputs: The tensors returned by ``validation_step``
                for this epoch; they are concatenated and flattened.
        """
        model_filename = f"model_{str(self.global_step).zfill(5)}.onnx"
        # Use one path for both the export and the upload. Previously the file
        # was written under this name while wandb.save() was given the bare
        # model_filename, which was never created.
        export_path = 'latest_run' + model_filename
        torch.onnx.export(
            self,
            self._dummy_input(),
            export_path,
            opset_version=11,
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'},
                          'output': {0: 'batch_size'}},
        )
        wandb.save(export_path)

        flattened_logits = torch.flatten(torch.cat(validation_step_outputs))
        self.logger.experiment.log(
            {"valid/logits": wandb.Histogram(flattened_logits.to("cpu")),
             "global_step": self.global_step})


def get_img_class_map():
    """Load and return the parsed contents of ``index_to_name.json``."""
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open('index_to_name.json', encoding='utf-8') as f:
        return json.load(f)


def get_feature_extractor():
    """Build a pretrained ``resnet50d`` backbone without its final child module.

    Returns:
        tuple: ``(nn.Sequential of all backbone children except the last,
        in_features of the backbone's ``fc`` layer)``.
    """
    backbone = create_timm_model('resnet50d', pretrained=True)
    num_filters = backbone.fc.in_features
    layers = list(backbone.children())[:-1]
    return nn.Sequential(*layers), num_filters