# xrayreport / model.py
import re
import torch
import config
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from collections import OrderedDict

class DenseNet121(nn.Module):
    def __init__(self, out_size=14, checkpoint=None):
        super(DenseNet121, self).__init__()
        self.densenet121 = models.densenet121(weights='DEFAULT')
        num_features = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_features, out_size),
            nn.Sigmoid()
        )

        if checkpoint is not None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            checkpoint = torch.load(checkpoint, map_location=device)
            state_dict = checkpoint['state_dict']

            # Remap checkpoint keys (DataParallel 'module.' prefixes and the
            # old '.norm.1' / '.conv.1' naming) to the current torchvision layout.
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                if 'module' not in k:
                    k = f'module.{k}'
                else:
                    k = k.replace('module.densenet121.features', 'features')
                    k = k.replace('module.densenet121.classifier', 'classifier')
                    k = k.replace('.norm.1', '.norm1')
                    k = k.replace('.conv.1', '.conv1')
                    k = k.replace('.norm.2', '.norm2')
                    k = k.replace('.conv.2', '.conv2')
                new_state_dict[k] = v

            self.densenet121.load_state_dict(new_state_dict)

    def forward(self, x):
        return self.densenet121(x)
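
# Checkpoint note (hedged): the loader above expects a dict with a 'state_dict'
# entry, e.g. a CheXNet-style classifier checkpoint saved from a DataParallel
# model, so that a key such as
#   'module.densenet121.features.denseblock1.denselayer1.norm.1.weight'
# becomes
#   'features.denseblock1.denselayer1.norm1.weight'
# The exact checkpoint provenance is an assumption; only the string
# replacements themselves come from this file.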

class EncoderCNN(nn.Module):
    def __init__(self, checkpoint=None):
        super(EncoderCNN, self).__init__()
        self.model = DenseNet121(
            checkpoint=checkpoint
        )
        # Freeze the CNN backbone; only the decoder is trained.
        for param in self.model.densenet121.parameters():
            param.requires_grad_(False)

    def forward(self, images):
        # Take the convolutional feature maps (before pooling/classifier) and
        # flatten the spatial grid into a sequence of feature vectors.
        features = self.model.densenet121.features(images)
        batch, maps, size_1, size_2 = features.size()
        features = features.permute(0, 2, 3, 1)
        features = features.view(batch, size_1 * size_2, maps)
        return features
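
# Shape sketch (an assumption based on DenseNet-121, not verified against the
# training config): for 224x224 inputs, densenet121.features returns
# (B, 1024, 7, 7), so EncoderCNN yields (B, 49, 1024) -- 49 spatial locations
# of 1024-dim vectors for the attention module, i.e. features_size = 1024.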

class Attention(nn.Module):
    def __init__(self, features_size, hidden_size, output_size=1):
        super(Attention, self).__init__()
        self.W = nn.Linear(features_size, hidden_size)
        self.U = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, output_size)

    def forward(self, features, decoder_output):
        # Additive (Bahdanau-style) attention over the spatial feature vectors.
        decoder_output = decoder_output.unsqueeze(1)
        w = self.W(features)
        u = self.U(decoder_output)
        scores = self.v(torch.tanh(w + u))
        weights = F.softmax(scores, dim=1)
        context = torch.sum(weights * features, dim=1)
        weights = weights.squeeze(2)
        return context, weights
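
# Shape sketch for Attention.forward (assuming the encoder output above):
#   features:       (B, 49, features_size)
#   decoder_output: (B, hidden_size) -> unsqueezed to (B, 1, hidden_size)
#   scores:         (B, 49, 1), softmaxed over the 49 spatial positions
#   returns context (B, features_size) and weights (B, 49)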

class DecoderRNN(nn.Module):
    def __init__(self, features_size, embed_size, hidden_size, vocab_size):
        super(DecoderRNN, self).__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTMCell(embed_size + features_size, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.attention = Attention(features_size, hidden_size)
        self.init_h = nn.Linear(features_size, hidden_size)
        self.init_c = nn.Linear(features_size, hidden_size)

    def forward(self, features, captions):
        # Teacher forcing: at step i the ground-truth token i is fed in and
        # the model predicts token i + 1.
        embeddings = self.embedding(captions)
        h, c = self.init_hidden(features)

        seq_len = captions.size(1) - 1
        features_size = features.size(1)
        batch_size = captions.size(0)

        outputs = torch.zeros(batch_size, seq_len, self.vocab_size).to(config.DEVICE)
        atten_weights = torch.zeros(batch_size, seq_len, features_size).to(config.DEVICE)

        for i in range(seq_len):
            context, attention = self.attention(features, h)
            inputs = torch.cat((embeddings[:, i, :], context), dim=1)

            h, c = self.lstm(inputs, (h, c))
            # Respect train/eval mode; F.dropout defaults to training=True otherwise.
            h = F.dropout(h, p=0.5, training=self.training)
            output = self.fc(h)

            outputs[:, i, :] = output
            atten_weights[:, i, :] = attention

        return outputs, atten_weights

    def init_hidden(self, features):
        # Initialise the LSTM state from the mean of the image feature vectors.
        features = torch.mean(features, dim=1)
        h = self.init_h(features)
        c = self.init_c(features)
        return h, c
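
# Shape sketch for DecoderRNN.forward (hedged): with captions of shape (B, T)
# the decoder runs T - 1 teacher-forced steps and returns
# outputs (B, T - 1, vocab_size) and atten_weights (B, T - 1, num_locations);
# the final target token is only used as a label, never as an input.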

class EncoderDecoderNet(nn.Module):
    def __init__(self, features_size, embed_size, hidden_size, vocabulary, encoder_checkpoint=None):
        super(EncoderDecoderNet, self).__init__()
        self.vocabulary = vocabulary

        self.encoder = EncoderCNN(
            checkpoint=encoder_checkpoint
        )
        self.decoder = DecoderRNN(
            features_size=features_size,
            embed_size=embed_size,
            hidden_size=hidden_size,
            vocab_size=len(self.vocabulary)
        )

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs, _ = self.decoder(features, captions)
        return outputs

    def generate_caption(self, image, max_length=25):
        # Greedy decoding: start from <SOS> and repeatedly feed back the most
        # probable token until <EOS> is produced or max_length is reached.
        caption = []

        with torch.no_grad():
            features = self.encoder(image)
            h, c = self.decoder.init_hidden(features)

            word = torch.tensor(self.vocabulary.stoi['<SOS>']).view(1, -1).to(config.DEVICE)
            embeddings = self.decoder.embedding(word).squeeze(0)

            for _ in range(max_length):
                context, _ = self.decoder.attention(features, h)
                inputs = torch.cat((embeddings, context), dim=1)

                h, c = self.decoder.lstm(inputs, (h, c))
                # Respect train/eval mode so dropout is disabled during inference.
                output = self.decoder.fc(F.dropout(h, p=0.5, training=self.training))
                output = output.view(1, -1)

                predicted = output.argmax(1)
                if self.vocabulary.itos[predicted.item()] == '<EOS>':
                    break

                caption.append(predicted.item())
                embeddings = self.decoder.embedding(predicted)

        return [self.vocabulary.itos[idx] for idx in caption]
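

# A minimal shape-check sketch, not part of the original training code. It
# assumes a toy vocabulary exposing the stoi/itos/__len__ interface used above,
# that config.DEVICE is set, and illustrative embed/hidden sizes; only
# features_size=1024 follows from the DenseNet-121 backbone.
if __name__ == '__main__':
    class _ToyVocab:
        def __init__(self):
            self.itos = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: 'opacity'}
            self.stoi = {v: k for k, v in self.itos.items()}

        def __len__(self):
            return len(self.itos)

    vocab = _ToyVocab()
    model = EncoderDecoderNet(
        features_size=1024,   # channel count of densenet121.features output
        embed_size=256,       # illustrative value, not from this repo
        hidden_size=512,      # illustrative value, not from this repo
        vocabulary=vocab
    ).to(config.DEVICE)

    images = torch.randn(2, 3, 224, 224).to(config.DEVICE)
    captions = torch.randint(0, len(vocab), (2, 10)).to(config.DEVICE)

    outputs = model(images, captions)
    print(outputs.shape)  # expected: (2, 9, len(vocab))

    model.eval()
    print(model.generate_caption(images[:1]))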