import torch
import config
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from collections import OrderedDict


class DenseNet121(nn.Module):
    def __init__(self, out_size=14, checkpoint=None):
        super(DenseNet121, self).__init__()
        self.densenet121 = models.densenet121(weights='DEFAULT')
        num_features = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_features, out_size),
            nn.Sigmoid()
        )

        if checkpoint is not None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            checkpoint = torch.load(checkpoint, map_location=device)
            state_dict = checkpoint['state_dict']

            # Remap keys from checkpoints saved with nn.DataParallel and the
            # older DenseNet layer naming (e.g. CheXNet) so they match the
            # current torchvision DenseNet state dict.
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                if 'module' not in k:
                    k = f'module.{k}'
                else:
                    k = k.replace('module.densenet121.features', 'features')
                    k = k.replace('module.densenet121.classifier', 'classifier')
                    k = k.replace('.norm.1', '.norm1')
                    k = k.replace('.conv.1', '.conv1')
                    k = k.replace('.norm.2', '.norm2')
                    k = k.replace('.conv.2', '.conv2')
                new_state_dict[k] = v

            self.densenet121.load_state_dict(new_state_dict)

    def forward(self, x):
        return self.densenet121(x)


class EncoderCNN(nn.Module):
    def __init__(self, checkpoint=None):
        super(EncoderCNN, self).__init__()
        self.model = DenseNet121(checkpoint=checkpoint)
        # The CNN backbone is used as a frozen feature extractor.
        for param in self.model.densenet121.parameters():
            param.requires_grad_(False)

    def forward(self, images):
        # (batch, maps, H, W) -> (batch, H * W, maps) so that each spatial
        # location becomes one feature vector the attention module can attend over.
        features = self.model.densenet121.features(images)
        batch, maps, size_1, size_2 = features.size()
        features = features.permute(0, 2, 3, 1)
        features = features.reshape(batch, size_1 * size_2, maps)
        return features


class Attention(nn.Module):
    def __init__(self, features_size, hidden_size, output_size=1):
        super(Attention, self).__init__()
        self.W = nn.Linear(features_size, hidden_size)
        self.U = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, output_size)

    def forward(self, features, decoder_output):
        # Additive (Bahdanau-style) attention over the spatial features.
        decoder_output = decoder_output.unsqueeze(1)
        w = self.W(features)
        u = self.U(decoder_output)

        scores = self.v(torch.tanh(w + u))
        weights = F.softmax(scores, dim=1)
        context = torch.sum(weights * features, dim=1)

        weights = weights.squeeze(2)

        return context, weights


class DecoderRNN(nn.Module):
    def __init__(self, features_size, embed_size, hidden_size, vocab_size):
        super(DecoderRNN, self).__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTMCell(embed_size + features_size, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)

        self.attention = Attention(features_size, hidden_size)

        self.init_h = nn.Linear(features_size, hidden_size)
        self.init_c = nn.Linear(features_size, hidden_size)

    def forward(self, features, captions):
        embeddings = self.embedding(captions)

        h, c = self.init_hidden(features)

        seq_len = captions.size(1) - 1
        features_size = features.size(1)
        batch_size = captions.size(0)

        outputs = torch.zeros(batch_size, seq_len, self.vocab_size).to(config.DEVICE)
        atten_weights = torch.zeros(batch_size, seq_len, features_size).to(config.DEVICE)

        # Teacher forcing: at each step, feed the ground-truth token embedding
        # together with the attention context computed from the previous hidden state.
        for i in range(seq_len):
            context, attention = self.attention(features, h)
            inputs = torch.cat((embeddings[:, i, :], context), dim=1)

            h, c = self.lstm(inputs, (h, c))
            h = F.dropout(h, p=0.5, training=self.training)

            output = self.fc(h)

            outputs[:, i, :] = output
            atten_weights[:, i, :] = attention

        return outputs, atten_weights

    def init_hidden(self, features):
        # Initialise the LSTM state from the mean of the spatial features.
        features = torch.mean(features, dim=1)

        h = self.init_h(features)
        c = self.init_c(features)

        return h, c
class EncoderDecoderNet(nn.Module):
    def __init__(self, features_size, embed_size, hidden_size, vocabulary,
                 encoder_checkpoint=None):
        super(EncoderDecoderNet, self).__init__()
        self.vocabulary = vocabulary

        self.encoder = EncoderCNN(checkpoint=encoder_checkpoint)
        self.decoder = DecoderRNN(
            features_size=features_size,
            embed_size=embed_size,
            hidden_size=hidden_size,
            vocab_size=len(self.vocabulary)
        )

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs, _ = self.decoder(features, captions)

        return outputs

    def generate_caption(self, image, max_length=25):
        # Greedy decoding: start from the start-of-sequence token and take the
        # argmax word at every step until the end token or max_length is reached.
        caption = []

        with torch.no_grad():
            features = self.encoder(image)
            h, c = self.decoder.init_hidden(features)

            # NOTE: the special token names '<sos>' / '<eos>' below are assumptions;
            # the original strings were lost and must match the tokens used when
            # the vocabulary was built.
            word = torch.tensor(self.vocabulary.stoi['<sos>']).view(1, -1).to(config.DEVICE)
            embeddings = self.decoder.embedding(word).squeeze(0)

            for _ in range(max_length):
                context, _ = self.decoder.attention(features, h)
                inputs = torch.cat((embeddings, context), dim=1)

                h, c = self.decoder.lstm(inputs, (h, c))
                output = self.decoder.fc(F.dropout(h, p=0.5, training=self.training))
                output = output.view(1, -1)

                predicted = output.argmax(1)

                if self.vocabulary.itos[predicted.item()] == '<eos>':
                    break

                caption.append(predicted.item())

                embeddings = self.decoder.embedding(predicted)

        return [self.vocabulary.itos[idx] for idx in caption]
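
# --- Usage sketch (not part of the original module) ---------------------------
# A minimal, hedged example of wiring the classes together for inference.
# It assumes `config.DEVICE` is set and that the vocabulary object exposes
# `stoi`/`itos` mappings plus __len__ (built elsewhere in the project); the
# DummyVocab, image tensor, and layer sizes below are illustrative only.
if __name__ == '__main__':

    class DummyVocab:
        # Tiny stand-in vocabulary so the sketch is self-contained.
        def __init__(self):
            self.itos = ['<pad>', '<sos>', '<eos>', 'normal', 'lungs']
            self.stoi = {w: i for i, w in enumerate(self.itos)}

        def __len__(self):
            return len(self.itos)

    vocab = DummyVocab()
    model = EncoderDecoderNet(
        features_size=1024,   # DenseNet-121 feature maps have 1024 channels
        embed_size=256,       # illustrative
        hidden_size=512,      # illustrative
        vocabulary=vocab,
        encoder_checkpoint=None
    ).to(config.DEVICE)
    model.eval()

    # One fake 3-channel image; real inputs come from the dataset pipeline.
    image = torch.randn(1, 3, 224, 224).to(config.DEVICE)
    print(model.generate_caption(image))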