adithiyyha committed on
Commit 0faaa54 · verified · 1 Parent(s): 10b84e8

Upload 8 files

Files changed (8)
  1. config.py +36 -0
  2. dataset.py +165 -0
  3. eval.py +62 -0
  4. gui.py +92 -0
  5. inference.py +88 -0
  6. model.py +198 -0
  7. train.py +84 -0
  8. utils.py +111 -0
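
Together these files form a chest X-ray report-generation pipeline: config.py holds paths, hyperparameters and image transforms, dataset.py parses the report XMLs and builds the vocabulary, model.py defines a DenseNet121 encoder with an attention LSTM decoder, and train.py, eval.py, inference.py and gui.py are the entry points. A minimal usage sketch of how the pieces connect (the image file name below is hypothetical; the scripts in this commit perform the same steps):

import numpy as np
from PIL import Image

import config
import utils

dataset = utils.load_dataset()                    # builds the vocabulary from ./dataset/reports
model = utils.get_model_instance(dataset.vocab)   # DenseNet121 encoder + attention LSTM decoder
if utils.can_load_checkpoint():
    utils.load_checkpoint(model)
model.eval()

# Preprocess one X-ray the same way dataset.py does: grayscale -> 3 channels -> basic_transforms.
img = np.array(Image.open('./dataset/images/example.png').convert('L'))  # hypothetical file name
img = np.expand_dims(img, axis=-1).repeat(3, axis=-1)
img = config.basic_transforms(image=img)['image']

print(' '.join(model.generate_caption(img.unsqueeze(0), max_length=25)))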
config.py ADDED
@@ -0,0 +1,36 @@
import albumentations as A
import torch

from albumentations.pytorch import ToTensorV2


CHECKPOINT_FILE = './checkpoints/x_ray_model.pth.tar'
DATASET_PATH = './dataset'
IMAGES_DATASET = './dataset/images'

DEVICE = 'cpu'
BATCH_SIZE = 16
PIN_MEMORY = False
VOCAB_THRESHOLD = 2

FEATURES_SIZE = 1024
EMBED_SIZE = 300
HIDDEN_SIZE = 256

LEARNING_RATE = 4e-5
EPOCHS = 50

LOAD_MODEL = True
SAVE_MODEL = True

basic_transforms = A.Compose([
    A.Resize(
        height=256,
        width=256
    ),
    A.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
    ToTensorV2()
])
dataset.py ADDED
@@ -0,0 +1,165 @@
import os
import spacy
import torch
import config
import utils
import numpy as np
import xml.etree.ElementTree as ET

from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader


spacy_eng = spacy.load('en_core_web_sm')


class Vocabulary:
    def __init__(self, freq_threshold):
        self.itos = {
            0: '<PAD>',
            1: '<SOS>',
            2: '<EOS>',
            3: '<UNK>',
        }
        self.stoi = {
            '<PAD>': 0,
            '<SOS>': 1,
            '<EOS>': 2,
            '<UNK>': 3,
        }
        self.freq_threshold = freq_threshold

    @staticmethod
    def tokenizer(text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4

        for sent in sentence_list:
            for word in self.tokenizer(sent):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word

                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi['<UNK>']
            for token in tokenized_text
        ]

    def __len__(self):
        return len(self.itos)


class XRayDataset(Dataset):
    def __init__(self, root, transform=None, freq_threshold=3, raw_caption=False):
        self.root = root
        self.transform = transform
        self.raw_caption = raw_caption

        self.vocab = Vocabulary(freq_threshold=freq_threshold)

        self.captions = []
        self.imgs = []

        for file in os.listdir(os.path.join(self.root, 'reports')):
            if file.endswith('.xml'):
                tree = ET.parse(os.path.join(self.root, 'reports', file))

                frontal_img = ''
                findings = tree.find(".//AbstractText[@Label='FINDINGS']").text

                if findings is None:
                    continue

                for x in tree.findall('parentImage'):
                    if frontal_img != '':
                        break

                    img = x.attrib['id']
                    img = os.path.join(config.IMAGES_DATASET, f'{img}.png')

                    frontal_img = img

                if frontal_img == '':
                    continue

                self.captions.append(findings)
                self.imgs.append(frontal_img)

        self.vocab.build_vocabulary(self.captions)

    def __getitem__(self, item):
        img = self.imgs[item]
        caption = utils.normalize_text(self.captions[item])

        img = np.array(Image.open(img).convert('L'))
        img = np.expand_dims(img, axis=-1)
        img = img.repeat(3, axis=-1)

        if self.transform is not None:
            img = self.transform(image=img)['image']

        if self.raw_caption:
            return img, caption

        numericalized_caption = [self.vocab.stoi['<SOS>']]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi['<EOS>'])

        return img, torch.as_tensor(numericalized_caption, dtype=torch.long)

    def __len__(self):
        return len(self.captions)

    def get_caption(self, item):
        return self.captions[item].split(' ')


class CollateDataset:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        images, captions = zip(*batch)

        images = torch.stack(images, 0)

        targets = [item for item in captions]
        targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)

        return images, targets


if __name__ == '__main__':
    all_dataset = XRayDataset(
        root=config.DATASET_PATH,
        transform=config.basic_transforms,
        freq_threshold=config.VOCAB_THRESHOLD,
    )

    train_loader = DataLoader(
        dataset=all_dataset,
        batch_size=config.BATCH_SIZE,
        pin_memory=config.PIN_MEMORY,
        drop_last=True,
        shuffle=True,
        collate_fn=CollateDataset(pad_idx=all_dataset.vocab.stoi['<PAD>']),
    )

    for img, caption in train_loader:
        print(img.shape, caption.shape)
        break
eval.py ADDED
@@ -0,0 +1,62 @@
import config
import utils
import numpy as np

from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu


def check_accuracy(dataset, model):
    print('=> Testing')

    model.eval()

    bleu1_score = []
    bleu2_score = []
    bleu3_score = []
    bleu4_score = []

    for image, caption in tqdm(dataset):
        image = image.to(config.DEVICE)

        generated = model.generate_caption(image.unsqueeze(0), max_length=len(caption.split(' ')))

        bleu1_score.append(
            sentence_bleu([caption.split()], generated, weights=(1, 0, 0, 0))
        )

        bleu2_score.append(
            sentence_bleu([caption.split()], generated, weights=(0.5, 0.5, 0, 0))
        )

        bleu3_score.append(
            sentence_bleu([caption.split()], generated, weights=(0.33, 0.33, 0.33, 0))
        )

        bleu4_score.append(
            sentence_bleu([caption.split()], generated, weights=(0.25, 0.25, 0.25, 0.25))
        )

    print(f'=> BLEU 1: {np.mean(bleu1_score)}')
    print(f'=> BLEU 2: {np.mean(bleu2_score)}')
    print(f'=> BLEU 3: {np.mean(bleu3_score)}')
    print(f'=> BLEU 4: {np.mean(bleu4_score)}')


def main():
    all_dataset = utils.load_dataset(raw_caption=True)

    model = utils.get_model_instance(all_dataset.vocab)

    utils.load_checkpoint(model)

    _, test_dataset = utils.train_test_split(dataset=all_dataset)

    check_accuracy(
        test_dataset,
        model
    )


if __name__ == '__main__':
    main()
gui.py ADDED
@@ -0,0 +1,92 @@
import config
import utils
import numpy as np

from tkinter import *
from PIL import Image, ImageTk
from tkinter import filedialog


label = None
image = None
model = None

def choose_image():
    global label, image

    path = filedialog.askopenfilename(initialdir='images', title='Select Photo')

    screen = Toplevel(root)
    screen.title('Report Generator')

    ff1 = Frame(screen, bg='grey', borderwidth=6, relief=GROOVE)
    ff1.pack(side=TOP, fill=X)

    ff2 = Frame(screen, bg='grey', borderwidth=6, relief=GROOVE)
    ff2.pack(side=TOP, fill=X)

    ff4 = Frame(screen, bg='grey', borderwidth=6, relief=GROOVE)
    ff4.pack(side=TOP, fill=X)

    ff3 = Frame(screen, bg='grey', borderwidth=6, relief=GROOVE)
    ff3.pack(side=TOP, fill=X)

    Label(ff1, text='Select X-Ray', fg='white', bg='grey', font='Helvetica 16 bold').pack()

    original_img = Image.open(path).convert('L')

    image = np.array(original_img)
    image = np.expand_dims(image, axis=-1)
    image = image.repeat(3, axis=-1)

    image = config.basic_transforms(image=image)['image']

    photo = ImageTk.PhotoImage(original_img)

    Label(ff2, image=photo).pack()
    label = Label(ff4, text='', fg='blue', bg='gray', font='Helvetica 16 bold')
    label.pack()

    Button(ff3, text='Generate Report', bg='violet', command=generate_report, height=2, width=20, font='Helvetica 16 bold').pack(side=LEFT)
    Button(ff3, text='Quit', bg='red', command=quit_gui, height=2, width=20, font='Helvetica 16 bold').pack()

    screen.bind('<Configure>', lambda event: label.configure(wraplength=label.winfo_width()))
    screen.mainloop()

def generate_report():
    global label, image, model

    model.eval()

    image = image.to(config.DEVICE)

    report = ' '.join(model.generate_caption(image.unsqueeze(0), max_length=25))  # join word tokens into a readable sentence

    label.config(text=report, fg='violet', bg='green', font='Helvetica 16 bold', width=40)
    label.update_idletasks()

def quit_gui():
    root.destroy()

root = Tk()
root.title('Chest X-Ray Report Generator')

f1 = Frame(root, bg='grey', borderwidth=6, relief=GROOVE)
f1.pack(side=TOP, fill=X)

f2 = Frame(root, bg='grey', borderwidth=6, relief=GROOVE)
f2.pack(side=TOP, fill=X)

Label(f1, text='Welcome to Chest X-Ray Report Generator', fg='white', bg='grey', font='Helvetica 16 bold').pack()

btn1 = Button(root, text='Choose Chest X-Ray', command=choose_image, height=2, width=20, bg='blue', font='Helvetica 16 bold', pady=10)
btn1.pack()

Button(root, text='Quit', command=quit_gui, height=2, width=20, bg='violet', font='Helvetica 16 bold', pady=10).pack()

if __name__ == '__main__':
    model = utils.get_model_instance(utils.load_dataset().vocab)

    utils.load_checkpoint(model)

    root.mainloop()
inference.py ADDED
@@ -0,0 +1,88 @@
import os
import torch
import config
from utils import (
    load_dataset,
    get_model_instance,
    load_checkpoint,
    can_load_checkpoint,
    normalize_text,
)
from PIL import Image
import torchvision.transforms as transforms

# Define device
DEVICE = 'cpu'

# Define image transformations (training uses 256x256 via config.basic_transforms; adjust to match your setup)
TRANSFORMS = transforms.Compose([
    transforms.Resize((224, 224)),  # Replace with your model's expected input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def load_model():
    """
    Loads the model with the vocabulary and checkpoint.
    """
    print("Loading dataset and vocabulary...")
    dataset = load_dataset()  # Load dataset to access vocabulary
    vocabulary = dataset.vocab  # 'vocab' is built by XRayDataset

    print("Initializing the model...")
    model = get_model_instance(vocabulary)  # Initialize the model

    if can_load_checkpoint():
        print("Loading checkpoint...")
        load_checkpoint(model)
    else:
        print("No checkpoint found, starting with untrained model.")

    model.eval()  # Set the model to evaluation mode
    print("Model is ready for inference.")
    return model


def preprocess_image(image_path):
    """
    Preprocess the input image for the model.
    """
    print(f"Preprocessing image: {image_path}")
    image = Image.open(image_path).convert("RGB")  # Ensure RGB format
    image = TRANSFORMS(image).unsqueeze(0)  # Add batch dimension
    return image.to(DEVICE)


def generate_report(model, image_path):
    """
    Generates a report for a given image using the model.
    """
    image = preprocess_image(image_path)

    print("Generating report...")
    with torch.no_grad():
        # The model exposes a 'generate_caption' method (see model.py)
        output = model.generate_caption(image, max_length=25)
        report = " ".join(output)

    print(f"Generated report: {report}")
    return report


if __name__ == "__main__":
    # Path to the checkpoint file
    CHECKPOINT_PATH = config.CHECKPOINT_FILE  # Ensure config.CHECKPOINT_FILE is correctly set

    # Path to the input image
    IMAGE_PATH = "./dataset/images/CXR1178_IM-0121-1001.png"  # Replace with your image path

    # Load the model
    model = load_model()

    # Ensure the image exists before inference
    if os.path.exists(IMAGE_PATH):
        report = generate_report(model, IMAGE_PATH)
        print("Final Report:", report)
    else:
        print(f"Image not found at path: {IMAGE_PATH}")
model.py ADDED
@@ -0,0 +1,198 @@
import re
import torch
import config
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from collections import OrderedDict


class DenseNet121(nn.Module):
    def __init__(self, out_size=14, checkpoint=None):
        super(DenseNet121, self).__init__()

        self.densenet121 = models.densenet121(weights='DEFAULT')
        num_features = self.densenet121.classifier.in_features

        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_features, out_size),
            nn.Sigmoid()
        )

        if checkpoint is not None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            checkpoint = torch.load(checkpoint, map_location=device)

            state_dict = checkpoint['state_dict']
            new_state_dict = OrderedDict()

            for k, v in state_dict.items():  # remap CheXNet checkpoint keys to torchvision DenseNet layer names
                if 'module' not in k:
                    k = f'module.{k}'
                else:
                    k = k.replace('module.densenet121.features', 'features')
                    k = k.replace('module.densenet121.classifier', 'classifier')
                    k = k.replace('.norm.1', '.norm1')
                    k = k.replace('.conv.1', '.conv1')
                    k = k.replace('.norm.2', '.norm2')
                    k = k.replace('.conv.2', '.conv2')

                new_state_dict[k] = v

            self.densenet121.load_state_dict(new_state_dict)


    def forward(self, x):
        return self.densenet121(x)


class EncoderCNN(nn.Module):
    def __init__(self, checkpoint=None):
        super(EncoderCNN, self).__init__()

        self.model = DenseNet121(
            checkpoint=checkpoint
        )

        for param in self.model.densenet121.parameters():
            param.requires_grad_(False)

    def forward(self, images):
        features = self.model.densenet121.features(images)

        batch, maps, size_1, size_2 = features.size()

        features = features.permute(0, 2, 3, 1)
        features = features.view(batch, size_1 * size_2, maps)

        return features


class Attention(nn.Module):
    def __init__(self, features_size, hidden_size, output_size=1):
        super(Attention, self).__init__()

        self.W = nn.Linear(features_size, hidden_size)
        self.U = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, output_size)

    def forward(self, features, decoder_output):
        decoder_output = decoder_output.unsqueeze(1)

        w = self.W(features)
        u = self.U(decoder_output)

        scores = self.v(torch.tanh(w + u))
        weights = F.softmax(scores, dim=1)
        context = torch.sum(weights * features, dim=1)

        weights = weights.squeeze(2)

        return context, weights


class DecoderRNN(nn.Module):
    def __init__(self, features_size, embed_size, hidden_size, vocab_size):
        super(DecoderRNN, self).__init__()

        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTMCell(embed_size + features_size, hidden_size)

        self.fc = nn.Linear(hidden_size, vocab_size)

        self.attention = Attention(features_size, hidden_size)

        self.init_h = nn.Linear(features_size, hidden_size)
        self.init_c = nn.Linear(features_size, hidden_size)

    def forward(self, features, captions):
        embeddings = self.embedding(captions)

        h, c = self.init_hidden(features)

        seq_len = len(captions[0]) - 1
        features_size = features.size(1)
        batch_size = captions.size(0)

        outputs = torch.zeros(batch_size, seq_len, self.vocab_size).to(config.DEVICE)
        atten_weights = torch.zeros(batch_size, seq_len, features_size).to(config.DEVICE)

        for i in range(seq_len):
            context, attention = self.attention(features, h)

            inputs = torch.cat((embeddings[:, i, :], context), dim=1)

            h, c = self.lstm(inputs, (h, c))
            h = F.dropout(h, p=0.5, training=self.training)  # apply dropout only while training

            output = self.fc(h)

            outputs[:, i, :] = output
            atten_weights[:, i, :] = attention

        return outputs, atten_weights

    def init_hidden(self, features):
        features = torch.mean(features, dim=1)

        h = self.init_h(features)
        c = self.init_c(features)

        return h, c


class EncoderDecoderNet(nn.Module):
    def __init__(self, features_size, embed_size, hidden_size, vocabulary, encoder_checkpoint=None):
        super(EncoderDecoderNet, self).__init__()

        self.vocabulary = vocabulary

        self.encoder = EncoderCNN(
            checkpoint=encoder_checkpoint
        )
        self.decoder = DecoderRNN(
            features_size=features_size,
            embed_size=embed_size,
            hidden_size=hidden_size,
            vocab_size=len(self.vocabulary)
        )

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs, _ = self.decoder(features, captions)

        return outputs

    def generate_caption(self, image, max_length=25):
        caption = []

        with torch.no_grad():
            features = self.encoder(image)
            h, c = self.decoder.init_hidden(features)

            word = torch.tensor(self.vocabulary.stoi['<SOS>']).view(1, -1).to(config.DEVICE)
            embeddings = self.decoder.embedding(word).squeeze(0)

            for _ in range(max_length):
                context, _ = self.decoder.attention(features, h)

                inputs = torch.cat((embeddings, context), dim=1)

                h, c = self.decoder.lstm(inputs, (h, c))

                output = self.decoder.fc(F.dropout(h, p=0.5, training=self.training))  # dropout is a no-op in eval mode
                output = output.view(1, -1)

                predicted = output.argmax(1)

                if self.vocabulary.itos[predicted.item()] == '<EOS>':
                    break

                caption.append(predicted.item())

                embeddings = self.decoder.embedding(predicted)

        return [self.vocabulary.itos[idx] for idx in caption]
train.py ADDED
@@ -0,0 +1,84 @@
import config
import utils
import torch.nn as nn
import torch.optim as optim
import numpy as np

from tqdm import tqdm
from torch.utils.data import DataLoader
from dataset import CollateDataset


def train_epoch(loader, model, optimizer, loss_fn, epoch):
    model.train()

    losses = []

    loader = tqdm(loader)

    for img, captions in loader:
        img = img.to(config.DEVICE)
        captions = captions.to(config.DEVICE)

        output = model(img, captions)

        loss = loss_fn(
            output.reshape(-1, output.shape[2]),
            captions[:, 1:].reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loader.set_postfix(loss=loss.item())

        losses.append(loss.item())

    if config.SAVE_MODEL:
        utils.save_checkpoint({
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'loss': np.mean(losses)
        })

    print(f'Epoch[{epoch}]: Loss {np.mean(losses)}')


def main():
    all_dataset = utils.load_dataset()

    train_dataset, _ = utils.train_test_split(dataset=all_dataset)

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.BATCH_SIZE,
        pin_memory=config.PIN_MEMORY,
        drop_last=False,
        shuffle=True,
        collate_fn=CollateDataset(pad_idx=all_dataset.vocab.stoi['<PAD>']),
    )

    model = utils.get_model_instance(all_dataset.vocab)

    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss(ignore_index=all_dataset.vocab.stoi['<PAD>'])

    starting_epoch = 1

    if utils.can_load_checkpoint():
        starting_epoch = utils.load_checkpoint(model, optimizer)

    for epoch in range(starting_epoch, config.EPOCHS):
        train_epoch(
            train_loader,
            model,
            optimizer,
            loss_fn,
            epoch
        )


if __name__ == '__main__':
    main()
utils.py ADDED
@@ -0,0 +1,111 @@
import os
import re
import html
import string
import torch
import config
import unicodedata
from nltk.tokenize import word_tokenize

from dataset import XRayDataset
from model import EncoderDecoderNet
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split as sklearn_train_test_split


def load_dataset(raw_caption=False):
    return XRayDataset(
        root=config.DATASET_PATH,
        transform=config.basic_transforms,
        freq_threshold=config.VOCAB_THRESHOLD,
        raw_caption=raw_caption
    )


def get_model_instance(vocabulary):
    model = EncoderDecoderNet(
        features_size=config.FEATURES_SIZE,
        embed_size=config.EMBED_SIZE,
        hidden_size=config.HIDDEN_SIZE,
        vocabulary=vocabulary,
        encoder_checkpoint='./weights/chexnet.pth.tar'
    )
    model = model.to(config.DEVICE)

    return model

def train_test_split(dataset, test_size=0.25, random_state=44):
    train_idx, test_idx = sklearn_train_test_split(
        list(range(len(dataset))),
        test_size=test_size,
        random_state=random_state
    )

    return Subset(dataset, train_idx), Subset(dataset, test_idx)


def save_checkpoint(checkpoint):
    print('=> Saving checkpoint')

    torch.save(checkpoint, config.CHECKPOINT_FILE)


def load_checkpoint(model, optimizer=None):
    print('=> Loading checkpoint')

    checkpoint = torch.load(config.CHECKPOINT_FILE, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])

    return checkpoint['epoch']


def can_load_checkpoint():
    return os.path.exists(config.CHECKPOINT_FILE) and config.LOAD_MODEL


def remove_special_chars(text):
    re1 = re.compile(r' +')
    x1 = text.lower().replace('#39;', "'").replace('amp;', '&').replace('#146;', "'").replace(
        'nbsp;', ' ').replace('#36;', '$').replace('\\n', "\n").replace('quot;', "'").replace(
        '<br />', "\n").replace('\\"', '"').replace('<unk>', 'u_n').replace(' @.@ ', '.').replace(
        ' @-@ ', '-').replace('\\', ' \\ ')

    return re1.sub(' ', html.unescape(x1))


def remove_non_ascii(text):
    return unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('utf-8', 'ignore')


def to_lowercase(text):
    return text.lower()


def remove_punctuation(text):
    translator = str.maketrans('', '', string.punctuation)
    return text.translate(translator)


def replace_numbers(text):
    return re.sub(r'\d+', '', text)


def text2words(text):
    return word_tokenize(text)


def normalize_text(text):
    text = remove_special_chars(text)
    text = remove_non_ascii(text)
    text = remove_punctuation(text)
    text = to_lowercase(text)
    text = replace_numbers(text)

    return text


def normalize_corpus(corpus):
    return [normalize_text(t) for t in corpus]