import torch
import config
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from collections import OrderedDict

class DenseNet121(nn.Module):
    def __init__(self, out_size=14, checkpoint=None):
        super(DenseNet121, self).__init__()
        self.densenet121 = models.densenet121(weights='DEFAULT')
        # Replace the ImageNet classifier with a multi-label head
        # (out_size outputs, each squashed by a sigmoid).
        num_features = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_features, out_size),
            nn.Sigmoid()
        )

        if checkpoint is not None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            checkpoint = torch.load(checkpoint, map_location=device)
            state_dict = checkpoint['state_dict']
            # Remap keys from checkpoints saved with nn.DataParallel
            # ('module.densenet121.*') and from older torchvision releases
            # that used dotted layer names ('.norm.1', '.conv.2', ...).
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                if 'module' not in k:
                    k = f'module.{k}'
                else:
                    k = k.replace('module.densenet121.features', 'features')
                    k = k.replace('module.densenet121.classifier', 'classifier')
                    k = k.replace('.norm.1', '.norm1')
                    k = k.replace('.conv.1', '.conv1')
                    k = k.replace('.norm.2', '.norm2')
                    k = k.replace('.conv.2', '.conv2')
                new_state_dict[k] = v
            self.densenet121.load_state_dict(new_state_dict)

    def forward(self, x):
        return self.densenet121(x)

class EncoderCNN(nn.Module):
    def __init__(self, checkpoint=None):
        super(EncoderCNN, self).__init__()
        self.model = DenseNet121(checkpoint=checkpoint)
        # The encoder is used as a frozen feature extractor.
        for param in self.model.densenet121.parameters():
            param.requires_grad_(False)

    def forward(self, images):
        # Take the last convolutional feature map and flatten the spatial
        # grid so the decoder can attend over individual locations.
        features = self.model.densenet121.features(images)
        batch, maps, size_1, size_2 = features.size()
        features = features.permute(0, 2, 3, 1)
        features = features.view(batch, size_1 * size_2, maps)
        return features
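
# Shape sketch (assumption: 3x224x224 inputs, the usual DenseNet-121 resolution;
# the code itself does not fix the input size). The final DenseNet-121 feature
# map has 1024 channels on a 7x7 grid, so EncoderCNN returns a tensor of shape
# (batch, 49, 1024): 49 spatial locations for the decoder to attend over.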

class Attention(nn.Module):
    def __init__(self, features_size, hidden_size, output_size=1):
        super(Attention, self).__init__()
        self.W = nn.Linear(features_size, hidden_size)
        self.U = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, output_size)

    def forward(self, features, decoder_output):
        # Additive (Bahdanau-style) attention: score each spatial location
        # against the current decoder hidden state.
        decoder_output = decoder_output.unsqueeze(1)
        w = self.W(features)
        u = self.U(decoder_output)
        scores = self.v(torch.tanh(w + u))
        weights = F.softmax(scores, dim=1)
        context = torch.sum(weights * features, dim=1)
        weights = weights.squeeze(2)
        return context, weights
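
# Given features of shape (batch, num_pixels, features_size) and a hidden state
# of shape (batch, hidden_size), the block above returns a context vector of
# shape (batch, features_size) and attention weights of shape (batch, num_pixels)
# that sum to 1 over the spatial locations.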

class DecoderRNN(nn.Module):
    def __init__(self, features_size, embed_size, hidden_size, vocab_size):
        super(DecoderRNN, self).__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTMCell(embed_size + features_size, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.attention = Attention(features_size, hidden_size)
        self.init_h = nn.Linear(features_size, hidden_size)
        self.init_c = nn.Linear(features_size, hidden_size)

    def forward(self, features, captions):
        embeddings = self.embedding(captions)
        h, c = self.init_hidden(features)
        seq_len = captions.size(1) - 1
        num_pixels = features.size(1)
        batch_size = captions.size(0)

        outputs = torch.zeros(batch_size, seq_len, self.vocab_size).to(config.DEVICE)
        atten_weights = torch.zeros(batch_size, seq_len, num_pixels).to(config.DEVICE)

        for i in range(seq_len):
            context, attention = self.attention(features, h)
            inputs = torch.cat((embeddings[:, i, :], context), dim=1)
            h, c = self.lstm(inputs, (h, c))
            # Only drop units while training; F.dropout defaults to
            # training=True, which would also apply dropout at inference.
            h = F.dropout(h, p=0.5, training=self.training)
            output = self.fc(h)
            outputs[:, i, :] = output
            atten_weights[:, i, :] = attention
        return outputs, atten_weights

    def init_hidden(self, features):
        # Initialise the LSTM state from the mean of the image features.
        features = torch.mean(features, dim=1)
        h = self.init_h(features)
        c = self.init_c(features)
        return h, c
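
# The decoder is trained with teacher forcing: at step i the ground-truth
# embedding captions[:, i] is concatenated with the attended context and fed
# to the LSTM cell, and outputs[:, i] is the prediction for the following
# token; seq_len is captions.size(1) - 1 because the last token has no
# successor. How the training script aligns outputs with targets is assumed
# here, not shown in this file.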

class EncoderDecoderNet(nn.Module):
    def __init__(self, features_size, embed_size, hidden_size, vocabulary, encoder_checkpoint=None):
        super(EncoderDecoderNet, self).__init__()
        self.vocabulary = vocabulary
        self.encoder = EncoderCNN(checkpoint=encoder_checkpoint)
        self.decoder = DecoderRNN(
            features_size=features_size,
            embed_size=embed_size,
            hidden_size=hidden_size,
            vocab_size=len(self.vocabulary)
        )

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs, _ = self.decoder(features, captions)
        return outputs

    def generate_caption(self, image, max_length=25):
        # Greedy decoding: start from <SOS> and feed the previous prediction
        # back in until <EOS> is produced or max_length tokens are emitted.
        caption = []

        with torch.no_grad():
            features = self.encoder(image)
            h, c = self.decoder.init_hidden(features)

            word = torch.tensor(self.vocabulary.stoi['<SOS>']).view(1, -1).to(config.DEVICE)
            embeddings = self.decoder.embedding(word).squeeze(0)

            for _ in range(max_length):
                context, _ = self.decoder.attention(features, h)
                inputs = torch.cat((embeddings, context), dim=1)
                h, c = self.decoder.lstm(inputs, (h, c))
                # Respect train/eval mode so dropout is disabled at inference.
                output = self.decoder.fc(F.dropout(h, p=0.5, training=self.training))
                output = output.view(1, -1)

                predicted = output.argmax(1)
                if self.vocabulary.itos[predicted.item()] == '<EOS>':
                    break

                caption.append(predicted.item())
                embeddings = self.decoder.embedding(predicted)

        return [self.vocabulary.itos[idx] for idx in caption]
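
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the model definition. Assumptions: a
# vocabulary object exposing __len__, stoi and itos (the tiny _DemoVocab below
# only stands in for the project's real vocabulary), config.DEVICE set to
# 'cuda' or 'cpu', and illustrative sizes (embed_size=256, hidden_size=512)
# rather than the project's actual configuration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class _DemoVocab:
        def __init__(self, tokens):
            self.itos = dict(enumerate(tokens))
            self.stoi = {t: i for i, t in self.itos.items()}

        def __len__(self):
            return len(self.itos)

    vocab = _DemoVocab(['<PAD>', '<SOS>', '<EOS>', '<UNK>', 'no', 'acute', 'findings'])

    model = EncoderDecoderNet(
        features_size=1024,   # DenseNet-121 feature channels
        embed_size=256,
        hidden_size=512,
        vocabulary=vocab
    ).to(config.DEVICE)

    images = torch.randn(2, 3, 224, 224).to(config.DEVICE)             # dummy image batch
    captions = torch.randint(0, len(vocab), (2, 10)).to(config.DEVICE)  # dummy token ids

    outputs = model(images, captions)
    print(outputs.shape)                                  # (2, 9, len(vocab))
    print(model.eval().generate_caption(images[:1]))      # greedy-decoded tokens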