adithiyyha committed on
Commit 10b84e8 · verified · 1 Parent(s): 1c369b2

Delete veeresh

veeresh/config.py DELETED
@@ -1,36 +0,0 @@
- import albumentations as A
- import torch
-
- from albumentations.pytorch import ToTensorV2
-
-
- CHECKPOINT_FILE = './checkpoints/x_ray_model.pth.tar'
- DATASET_PATH = './dataset'
- IMAGES_DATASET = './dataset/images'
-
- DEVICE = 'cpu'
- BATCH_SIZE = 16
- PIN_MEMORY = False
- VOCAB_THRESHOLD = 2
-
- FEATURES_SIZE = 1024
- EMBED_SIZE = 300
- HIDDEN_SIZE = 256
-
- LEARNING_RATE = 4e-5
- EPOCHS = 50
-
- LOAD_MODEL = True
- SAVE_MODEL = True
-
- basic_transforms = A.Compose([
-     A.Resize(
-         height=256,
-         width=256
-     ),
-     A.Normalize(
-         mean=(0.485, 0.456, 0.406),
-         std=(0.229, 0.224, 0.225),
-     ),
-     ToTensorV2()
- ])

veeresh/dataset.py DELETED
@@ -1,165 +0,0 @@
- import os
- import spacy
- import torch
- import config
- import utils
- import numpy as np
- import xml.etree.ElementTree as ET
-
- from PIL import Image
- from torch.nn.utils.rnn import pad_sequence
- from torch.utils.data import Dataset, DataLoader
-
-
- spacy_eng = spacy.load('en_core_web_sm')
-
-
- class Vocabulary:
-     def __init__(self, freq_threshold):
-         self.itos = {
-             0: '<PAD>',
-             1: '<SOS>',
-             2: '<EOS>',
-             3: '<UNK>',
-         }
-         self.stoi = {
-             '<PAD>': 0,
-             '<SOS>': 1,
-             '<EOS>': 2,
-             '<UNK>': 3,
-         }
-         self.freq_threshold = freq_threshold
-
-     @staticmethod
-     def tokenizer(text):
-         return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
-
-     def build_vocabulary(self, sentence_list):
-         frequencies = {}
-         idx = 4
-
-         for sent in sentence_list:
-             for word in self.tokenizer(sent):
-                 if word not in frequencies:
-                     frequencies[word] = 1
-                 else:
-                     frequencies[word] += 1
-
-                 if frequencies[word] == self.freq_threshold:
-                     self.stoi[word] = idx
-                     self.itos[idx] = word
-
-                     idx += 1
-
-     def numericalize(self, text):
-         tokenized_text = self.tokenizer(text)
-
-         return [
-             self.stoi[token] if token in self.stoi else self.stoi['<UNK>']
-             for token in tokenized_text
-         ]
-
-     def __len__(self):
-         return len(self.itos)
-
-
- class XRayDataset(Dataset):
-     def __init__(self, root, transform=None, freq_threshold=3, raw_caption=False):
-         self.root = root
-         self.transform = transform
-         self.raw_caption = raw_caption
-
-         self.vocab = Vocabulary(freq_threshold=freq_threshold)
-
-         self.captions = []
-         self.imgs = []
-
-         for file in os.listdir(os.path.join(self.root, 'reports')):
-             if file.endswith('.xml'):
-                 tree = ET.parse(os.path.join(self.root, 'reports', file))
-
-                 frontal_img = ''
-                 findings = tree.find(".//AbstractText[@Label='FINDINGS']").text
-
-                 if findings is None:
-                     continue
-
-                 for x in tree.findall('parentImage'):
-                     if frontal_img != '':
-                         break
-
-                     img = x.attrib['id']
-                     img = os.path.join(config.IMAGES_DATASET, f'{img}.png')
-
-                     frontal_img = img
-
-                 if frontal_img == '':
-                     continue
-
-                 self.captions.append(findings)
-                 self.imgs.append(frontal_img)
-
-         self.vocab.build_vocabulary(self.captions)
-
-     def __getitem__(self, item):
-         img = self.imgs[item]
-         caption = utils.normalize_text(self.captions[item])
-
-         img = np.array(Image.open(img).convert('L'))
-         img = np.expand_dims(img, axis=-1)
-         img = img.repeat(3, axis=-1)
-
-         if self.transform is not None:
-             img = self.transform(image=img)['image']
-
-         if self.raw_caption:
-             return img, caption
-
-         numericalized_caption = [self.vocab.stoi['<SOS>']]
-         numericalized_caption += self.vocab.numericalize(caption)
-         numericalized_caption.append(self.vocab.stoi['<EOS>'])
-
-         return img, torch.as_tensor(numericalized_caption, dtype=torch.long)
-
-     def __len__(self):
-         return len(self.captions)
-
-     def get_caption(self, item):
-         return self.captions[item].split(' ')
-
-
- class CollateDataset:
-     def __init__(self, pad_idx):
-         self.pad_idx = pad_idx
-
-     def __call__(self, batch):
-         images, captions = zip(*batch)
-
-         images = torch.stack(images, 0)
-
-         targets = [item for item in captions]
-         targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)
-
-         return images, targets
-
-
- if __name__ == '__main__':
-     all_dataset = XRayDataset(
-         root=config.DATASET_PATH,
-         transform=config.basic_transforms,
-         freq_threshold=config.VOCAB_THRESHOLD,
-     )
-
-     train_loader = DataLoader(
-         dataset=all_dataset,
-         batch_size=config.BATCH_SIZE,
-         pin_memory=config.PIN_MEMORY,
-         drop_last=True,
-         shuffle=True,
-         collate_fn=CollateDataset(pad_idx=all_dataset.vocab.stoi['<PAD>']),
-     )
-
-     for img, caption in train_loader:
-         print(img.shape, caption.shape)
-         break

veeresh/eval.py DELETED
@@ -1,62 +0,0 @@
- import config
- import utils
- import numpy as np
-
- from tqdm import tqdm
- from nltk.translate.bleu_score import sentence_bleu
-
-
- def check_accuracy(dataset, model):
-     print('=> Testing')
-
-     model.eval()
-
-     bleu1_score = []
-     bleu2_score = []
-     bleu3_score = []
-     bleu4_score = []
-
-     for image, caption in tqdm(dataset):
-         image = image.to(config.DEVICE)
-
-         generated = model.generate_caption(image.unsqueeze(0), max_length=len(caption.split(' ')))
-
-         bleu1_score.append(
-             sentence_bleu([caption.split()], generated, weights=(1, 0, 0, 0))
-         )
-
-         bleu2_score.append(
-             sentence_bleu([caption.split()], generated, weights=(0.5, 0.5, 0, 0))
-         )
-
-         bleu3_score.append(
-             sentence_bleu([caption.split()], generated, weights=(0.33, 0.33, 0.33, 0))
-         )
-
-         bleu4_score.append(
-             sentence_bleu([caption.split()], generated, weights=(0.25, 0.25, 0.25, 0.25))
-         )
-
-     print(f'=> BLEU 1: {np.mean(bleu1_score)}')
-     print(f'=> BLEU 2: {np.mean(bleu2_score)}')
-     print(f'=> BLEU 3: {np.mean(bleu3_score)}')
-     print(f'=> BLEU 4: {np.mean(bleu4_score)}')
-
-
- def main():
-     all_dataset = utils.load_dataset(raw_caption=True)
-
-     model = utils.get_model_instance(all_dataset.vocab)
-
-     utils.load_checkpoint(model)
-
-     _, test_dataset = utils.train_test_split(dataset=all_dataset)
-
-     check_accuracy(
-         test_dataset,
-         model
-     )
-
-
- if __name__ == '__main__':
-     main()

veeresh/gui.py DELETED
@@ -1,92 +0,0 @@
- import config
- import utils
- import numpy as np
-
- from tkinter import *
- from PIL import Image, ImageTk
- from tkinter import filedialog
-
-
- label = None
- image = None
- model = None
-
- def choose_image():
-     global label, image
-
-     path = filedialog.askopenfilename(initialdir='images', title='Select Photo')
-
-     screen = Toplevel(root)
-     screen.title('Report Generator')
-
-     ff1 = Frame(screen, bg='grey', borderwidth=6, relief=GROOVE)
-     ff1.pack(side=TOP, fill=X)
-
-     ff2 = Frame(screen, bg='grey', borderwidth=6, relief=GROOVE)
-     ff2.pack(side=TOP, fill=X)
-
-     ff4 = Frame(screen, bg='grey', borderwidth=6, relief=GROOVE)
-     ff4.pack(side=TOP, fill=X)
-
-     ff3 = Frame(screen, bg='grey', borderwidth=6, relief=GROOVE)
-     ff3.pack(side=TOP, fill=X)
-
-     Label(ff1, text='Select X-Ray', fg='white', bg='grey', font='Helvetica 16 bold').pack()
-
-     original_img = Image.open(path).convert('L')
-
-     image = np.array(original_img)
-     image = np.expand_dims(image, axis=-1)
-     image = image.repeat(3, axis=-1)
-
-     image = config.basic_transforms(image=image)['image']
-
-     photo = ImageTk.PhotoImage(original_img)
-
-     Label(ff2, image=photo).pack()
-     label = Label(ff4, text='', fg='blue', bg='gray', font='Helvetica 16 bold')
-     label.pack()
-
-     Button(ff3, text='Generate Report', bg='violet', command=generate_report, height=2, width=20, font='Helvetica 16 bold').pack(side=LEFT)
-     Button(ff3, text='Quit', bg='red', command=quit_gui, height=2, width=20, font='Helvetica 16 bold').pack()
-
-     screen.bind('<Configure>', lambda event: label.configure(wraplength=label.winfo_width()))
-     screen.mainloop()
-
- def generate_report():
-     global label, image, model
-
-     model.eval()
-
-     image = image.to(config.DEVICE)
-
-     report = model.generate_caption(image.unsqueeze(0), max_length=25)
-
-     label.config(text=report, fg='violet', bg='green', font='Helvetica 16 bold', width=40)
-     label.update_idletasks()
-
- def quit_gui():
-     root.destroy()
-
- root = Tk()
- root.title('Chest X-Ray Report Generator')
-
- f1 = Frame(root, bg='grey', borderwidth=6, relief=GROOVE)
- f1.pack(side=TOP, fill=X)
-
- f2 = Frame(root, bg='grey', borderwidth=6, relief=GROOVE)
- f2.pack(side=TOP, fill=X)
-
- Label(f1, text='Welcome to Chest X-Ray Report Generator', fg='white', bg='grey', font='Helvetica 16 bold').pack()
-
- btn1 = Button(root, text='Choose Chest X-Ray', command=choose_image, height=2, width=20, bg='blue', font="Helvetica 16 bold", pady=10)
- btn1.pack()
-
- Button(root, text='Quit', command=quit_gui, height=2, width=20, bg='violet', font='Helvetica 16 bold', pady=10).pack()
-
- if __name__ == '__main__':
-     model = utils.get_model_instance(utils.load_dataset().vocab)
-
-     utils.load_checkpoint(model)
-
-     root.mainloop()

veeresh/inference.py DELETED
@@ -1,88 +0,0 @@
- import os
- import torch
- import config
- from utils import (
-     load_dataset,
-     get_model_instance,
-     load_checkpoint,
-     can_load_checkpoint,
-     normalize_text,
- )
- from PIL import Image
- import torchvision.transforms as transforms
-
- # Define device
- DEVICE = 'cpu'
-
- # Define image transformations (adjust based on training setup)
- TRANSFORMS = transforms.Compose([
-     transforms.Resize((224, 224)),  # Replace with your model's expected input size
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
- ])
-
-
- def load_model():
-     """
-     Loads the model with the vocabulary and checkpoint.
-     """
-     print("Loading dataset and vocabulary...")
-     dataset = load_dataset()  # Load dataset to access vocabulary
-     vocabulary = dataset.vocab  # Assuming 'vocab' is an attribute of the dataset
-
-     print("Initializing the model...")
-     model = get_model_instance(vocabulary)  # Initialize the model
-
-     if can_load_checkpoint():
-         print("Loading checkpoint...")
-         load_checkpoint(model)
-     else:
-         print("No checkpoint found, starting with untrained model.")
-
-     model.eval()  # Set the model to evaluation mode
-     print("Model is ready for inference.")
-     return model
-
-
- def preprocess_image(image_path):
-     """
-     Preprocess the input image for the model.
-     """
-     print(f"Preprocessing image: {image_path}")
-     image = Image.open(image_path).convert("RGB")  # Ensure RGB format
-     image = TRANSFORMS(image).unsqueeze(0)  # Add batch dimension
-     return image.to(DEVICE)
-
-
- def generate_report(model, image_path):
-     """
-     Generates a report for a given image using the model.
-     """
-     image = preprocess_image(image_path)
-
-     print("Generating report...")
-     with torch.no_grad():
-         # Assuming the model has a 'generate_caption' method
-         output = model.generate_caption(image, max_length=25)
-         report = " ".join(output)
-
-     print(f"Generated report: {report}")
-     return report
-
-
- if __name__ == "__main__":
-     # Path to the checkpoint file
-     CHECKPOINT_PATH = config.CHECKPOINT_FILE  # Ensure config.CHECKPOINT_FILE is correctly set
-
-     # Path to the input image
-     IMAGE_PATH = "./dataset/images/CXR1178_IM-0121-1001.png"  # Replace with your image path
-
-     # Load the model
-     model = load_model()
-
-     # Ensure the image exists before inference
-     if os.path.exists(IMAGE_PATH):
-         report = generate_report(model, IMAGE_PATH)
-         print("Final Report:", report)
-     else:
-         print(f"Image not found at path: {IMAGE_PATH}")

veeresh/model.py DELETED
@@ -1,198 +0,0 @@
- import re
- import torch
- import config
- import torch.nn as nn
- import torch.nn.functional as F
- import torchvision.models as models
-
- from collections import OrderedDict
-
-
- class DenseNet121(nn.Module):
-     def __init__(self, out_size=14, checkpoint=None):
-         super(DenseNet121, self).__init__()
-
-         self.densenet121 = models.densenet121(weights='DEFAULT')
-         num_classes = self.densenet121.classifier.in_features
-
-         self.densenet121.classifier = nn.Sequential(
-             nn.Linear(num_classes, out_size),
-             nn.Sigmoid()
-         )
-
-         if checkpoint is not None:
-             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-             checkpoint = torch.load(checkpoint, map_location=device)
-
-             state_dict = checkpoint['state_dict']
-             new_state_dict = OrderedDict()
-
-             for k, v in state_dict.items():
-                 if 'module' not in k:
-                     k = f'module.{k}'
-                 else:
-                     k = k.replace('module.densenet121.features', 'features')
-                     k = k.replace('module.densenet121.classifier', 'classifier')
-                     k = k.replace('.norm.1', '.norm1')
-                     k = k.replace('.conv.1', '.conv1')
-                     k = k.replace('.norm.2', '.norm2')
-                     k = k.replace('.conv.2', '.conv2')
-
-                 new_state_dict[k] = v
-
-             self.densenet121.load_state_dict(new_state_dict)
-
-     def forward(self, x):
-         return self.densenet121(x)
-
-
- class EncoderCNN(nn.Module):
-     def __init__(self, checkpoint=None):
-         super(EncoderCNN, self).__init__()
-
-         self.model = DenseNet121(
-             checkpoint=checkpoint
-         )
-
-         for param in self.model.densenet121.parameters():
-             param.requires_grad_(False)
-
-     def forward(self, images):
-         features = self.model.densenet121.features(images)
-
-         batch, maps, size_1, size_2 = features.size()
-
-         features = features.permute(0, 2, 3, 1)
-         features = features.view(batch, size_1 * size_2, maps)
-
-         return features
-
-
- class Attention(nn.Module):
-     def __init__(self, features_size, hidden_size, output_size=1):
-         super(Attention, self).__init__()
-
-         self.W = nn.Linear(features_size, hidden_size)
-         self.U = nn.Linear(hidden_size, hidden_size)
-         self.v = nn.Linear(hidden_size, output_size)
-
-     def forward(self, features, decoder_output):
-         decoder_output = decoder_output.unsqueeze(1)
-
-         w = self.W(features)
-         u = self.U(decoder_output)
-
-         scores = self.v(torch.tanh(w + u))
-         weights = F.softmax(scores, dim=1)
-         context = torch.sum(weights * features, dim=1)
-
-         weights = weights.squeeze(2)
-
-         return context, weights
-
-
- class DecoderRNN(nn.Module):
-     def __init__(self, features_size, embed_size, hidden_size, vocab_size):
-         super(DecoderRNN, self).__init__()
-
-         self.vocab_size = vocab_size
-
-         self.embedding = nn.Embedding(vocab_size, embed_size)
-         self.lstm = nn.LSTMCell(embed_size + features_size, hidden_size)
-
-         self.fc = nn.Linear(hidden_size, vocab_size)
-
-         self.attention = Attention(features_size, hidden_size)
-
-         self.init_h = nn.Linear(features_size, hidden_size)
-         self.init_c = nn.Linear(features_size, hidden_size)
-
-     def forward(self, features, captions):
-         embeddings = self.embedding(captions)
-
-         h, c = self.init_hidden(features)
-
-         seq_len = len(captions[0]) - 1
-         features_size = features.size(1)
-         batch_size = captions.size(0)
-
-         outputs = torch.zeros(batch_size, seq_len, self.vocab_size).to(config.DEVICE)
-         atten_weights = torch.zeros(batch_size, seq_len, features_size).to(config.DEVICE)
-
-         for i in range(seq_len):
-             context, attention = self.attention(features, h)
-
-             inputs = torch.cat((embeddings[:, i, :], context), dim=1)
-
-             h, c = self.lstm(inputs, (h, c))
-             h = F.dropout(h, p=0.5)
-
-             output = self.fc(h)
-
-             outputs[:, i, :] = output
-             atten_weights[:, i, :] = attention
-
-         return outputs, atten_weights
-
-     def init_hidden(self, features):
-         features = torch.mean(features, dim=1)
-
-         h = self.init_h(features)
-         c = self.init_c(features)
-
-         return h, c
-
-
- class EncoderDecoderNet(nn.Module):
-     def __init__(self, features_size, embed_size, hidden_size, vocabulary, encoder_checkpoint=None):
-         super(EncoderDecoderNet, self).__init__()
-
-         self.vocabulary = vocabulary
-
-         self.encoder = EncoderCNN(
-             checkpoint=encoder_checkpoint
-         )
-         self.decoder = DecoderRNN(
-             features_size=features_size,
-             embed_size=embed_size,
-             hidden_size=hidden_size,
-             vocab_size=len(self.vocabulary)
-         )
-
-     def forward(self, images, captions):
-         features = self.encoder(images)
-         outputs, _ = self.decoder(features, captions)
-
-         return outputs
-
-     def generate_caption(self, image, max_length=25):
-         caption = []
-
-         with torch.no_grad():
-             features = self.encoder(image)
-             h, c = self.decoder.init_hidden(features)
-
-             word = torch.tensor(self.vocabulary.stoi['<SOS>']).view(1, -1).to(config.DEVICE)
-             embeddings = self.decoder.embedding(word).squeeze(0)
-
-             for _ in range(max_length):
-                 context, _ = self.decoder.attention(features, h)
-
-                 inputs = torch.cat((embeddings, context), dim=1)
-
-                 h, c = self.decoder.lstm(inputs, (h, c))
-
-                 output = self.decoder.fc(F.dropout(h, p=0.5))
-                 output = output.view(1, -1)
-
-                 predicted = output.argmax(1)
-
-                 if self.vocabulary.itos[predicted.item()] == '<EOS>':
-                     break
-
-                 caption.append(predicted.item())
-
-                 embeddings = self.decoder.embedding(predicted)
-
-         return [self.vocabulary.itos[idx] for idx in caption]

veeresh/train.py DELETED
@@ -1,84 +0,0 @@
- import config
- import utils
- import torch.nn as nn
- import torch.optim as optim
- import numpy as np
-
- from tqdm import tqdm
- from torch.utils.data import DataLoader
- from dataset import CollateDataset
-
-
- def train_epoch(loader, model, optimizer, loss_fn, epoch):
-     model.train()
-
-     losses = []
-
-     loader = tqdm(loader)
-
-     for img, captions in loader:
-         img = img.to(config.DEVICE)
-         captions = captions.to(config.DEVICE)
-
-         output = model(img, captions)
-
-         loss = loss_fn(
-             output.reshape(-1, output.shape[2]),
-             captions[:, 1:].reshape(-1)
-         )
-
-         optimizer.zero_grad()
-         loss.backward()
-         optimizer.step()
-
-         loader.set_postfix(loss=loss.item())
-
-         losses.append(loss.item())
-
-     if config.SAVE_MODEL:
-         utils.save_checkpoint({
-             'state_dict': model.state_dict(),
-             'optimizer': optimizer.state_dict(),
-             'epoch': epoch,
-             'loss': np.mean(losses)
-         })
-
-     print(f'Epoch[{epoch}]: Loss {np.mean(losses)}')
-
-
- def main():
-     all_dataset = utils.load_dataset()
-
-     train_dataset, _ = utils.train_test_split(dataset=all_dataset)
-
-     train_loader = DataLoader(
-         dataset=train_dataset,
-         batch_size=config.BATCH_SIZE,
-         pin_memory=config.PIN_MEMORY,
-         drop_last=False,
-         shuffle=True,
-         collate_fn=CollateDataset(pad_idx=all_dataset.vocab.stoi['<PAD>']),
-     )
-
-     model = utils.get_model_instance(all_dataset.vocab)
-
-     optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
-     loss_fn = nn.CrossEntropyLoss(ignore_index=all_dataset.vocab.stoi['<PAD>'])
-
-     starting_epoch = 1
-
-     if utils.can_load_checkpoint():
-         starting_epoch = utils.load_checkpoint(model, optimizer)
-
-     for epoch in range(starting_epoch, config.EPOCHS):
-         train_epoch(
-             train_loader,
-             model,
-             optimizer,
-             loss_fn,
-             epoch
-         )
-
-
- if __name__ == '__main__':
-     main()

veeresh/utils.py DELETED
@@ -1,111 +0,0 @@
- import os
- import re
- import html
- import string
- import torch
- import config
- import unicodedata
- from nltk.tokenize import word_tokenize
-
- from dataset import XRayDataset
- from model import EncoderDecoderNet
- from torch.utils.data import Subset
- from sklearn.model_selection import train_test_split as sklearn_train_test_split
-
-
- def load_dataset(raw_caption=False):
-     return XRayDataset(
-         root=config.DATASET_PATH,
-         transform=config.basic_transforms,
-         freq_threshold=config.VOCAB_THRESHOLD,
-         raw_caption=raw_caption
-     )
-
-
- def get_model_instance(vocabulary):
-     model = EncoderDecoderNet(
-         features_size=config.FEATURES_SIZE,
-         embed_size=config.EMBED_SIZE,
-         hidden_size=config.HIDDEN_SIZE,
-         vocabulary=vocabulary,
-         encoder_checkpoint='./weights/chexnet.pth.tar'
-     )
-     model = model.to(config.DEVICE)
-
-     return model
-
-
- def train_test_split(dataset, test_size=0.25, random_state=44):
-     train_idx, test_idx = sklearn_train_test_split(
-         list(range(len(dataset))),
-         test_size=test_size,
-         random_state=random_state
-     )
-
-     return Subset(dataset, train_idx), Subset(dataset, test_idx)
-
-
- def save_checkpoint(checkpoint):
-     print('=> Saving checkpoint')
-
-     torch.save(checkpoint, config.CHECKPOINT_FILE)
-
-
- def load_checkpoint(model, optimizer=None):
-     print('=> Loading checkpoint')
-
-     checkpoint = torch.load(config.CHECKPOINT_FILE, map_location=torch.device('cpu'))
-     model.load_state_dict(checkpoint['state_dict'])
-
-     if optimizer is not None:
-         optimizer.load_state_dict(checkpoint['optimizer'])
-
-     return checkpoint['epoch']
-
-
- def can_load_checkpoint():
-     return os.path.exists(config.CHECKPOINT_FILE) and config.LOAD_MODEL
-
-
- def remove_special_chars(text):
-     re1 = re.compile(r' +')
-     x1 = text.lower().replace('#39;', "'").replace('amp;', '&').replace('#146;', "'").replace(
-         'nbsp;', ' ').replace('#36;', '$').replace('\\n', "\n").replace('quot;', "'").replace(
-         '<br />', "\n").replace('\\"', '"').replace('<unk>', 'u_n').replace(' @.@ ', '.').replace(
-         ' @-@ ', '-').replace('\\', ' \\ ')
-
-     return re1.sub(' ', html.unescape(x1))
-
-
- def remove_non_ascii(text):
-     return unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('utf-8', 'ignore')
-
-
- def to_lowercase(text):
-     return text.lower()
-
-
- def remove_punctuation(text):
-     translator = str.maketrans('', '', string.punctuation)
-     return text.translate(translator)
-
-
- def replace_numbers(text):
-     return re.sub(r'\d+', '', text)
-
-
- def text2words(text):
-     return word_tokenize(text)
-
-
- def normalize_text(text):
-     text = remove_special_chars(text)
-     text = remove_non_ascii(text)
-     text = remove_punctuation(text)
-     text = to_lowercase(text)
-     text = replace_numbers(text)
-
-     return text
-
-
- def normalize_corpus(corpus):
-     return [normalize_text(t) for t in corpus]