import os import spacy import torch import config import utils import numpy as np import xml.etree.ElementTree as ET from PIL import Image from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset, DataLoader spacy_eng = spacy.load('en_core_web_sm') class Vocabulary: def __init__(self, freq_threshold): self.itos = { 0: '', 1: '', 2: '', 3: '', } self.stoi = { '': 0, '': 1, '': 2, '': 3, } self.freq_threshold = freq_threshold @staticmethod def tokenizer(text): return [tok.text.lower() for tok in spacy_eng.tokenizer(text)] def build_vocabulary(self, sentence_list): frequencies = {} idx = 4 for sent in sentence_list: for word in self.tokenizer(sent): if word not in frequencies: frequencies[word] = 1 else: frequencies[word] += 1 if frequencies[word] == self.freq_threshold: self.stoi[word] = idx self.itos[idx] = word idx += 1 def numericalize(self, text): tokenized_text = self.tokenizer(text) return [ self.stoi[token] if token in self.stoi else self.stoi[''] for token in tokenized_text ] def __len__(self): return len(self.itos) class XRayDataset(Dataset): def __init__(self, root, transform=None, freq_threshold=3, raw_caption=False): self.root = root self.transform = transform self.raw_caption = raw_caption self.vocab = Vocabulary(freq_threshold=freq_threshold) self.captions = [] self.imgs = [] for file in os.listdir(os.path.join(self.root, 'reports')): if file.endswith('.xml'): tree = ET.parse(os.path.join(self.root, 'reports', file)) frontal_img = '' findings = tree.find(".//AbstractText[@Label='FINDINGS']").text if findings is None: continue for x in tree.findall('parentImage'): if frontal_img != '': break img = x.attrib['id'] img = os.path.join(config.IMAGES_DATASET, f'{img}.png') frontal_img = img if frontal_img == '': continue self.captions.append(findings) self.imgs.append(frontal_img) self.vocab.build_vocabulary(self.captions) def __getitem__(self, item): img = self.imgs[item] caption = utils.normalize_text(self.captions[item]) img = np.array(Image.open(img).convert('L')) img = np.expand_dims(img, axis=-1) img = img.repeat(3, axis=-1) if self.transform is not None: img = self.transform(image=img)['image'] if self.raw_caption: return img, caption numericalized_caption = [self.vocab.stoi['']] numericalized_caption += self.vocab.numericalize(caption) numericalized_caption.append(self.vocab.stoi['']) return img, torch.as_tensor(numericalized_caption, dtype=torch.long) def __len__(self): return len(self.captions) def get_caption(self, item): return self.captions[item].split(' ') class CollateDataset: def __init__(self, pad_idx): self.pad_idx = pad_idx def __call__(self, batch): images, captions = zip(*batch) images = torch.stack(images, 0) targets = [item for item in captions] targets = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx) return images, targets if __name__ == '__main__': all_dataset = XRayDataset( root=config.DATASET_PATH, transform=config.basic_transforms, freq_threshold=config.VOCAB_THRESHOLD, ) train_loader = DataLoader( dataset=all_dataset, batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY, drop_last=True, shuffle=True, collate_fn=CollateDataset(pad_idx=all_dataset.vocab.stoi['']), ) for img, caption in train_loader: print(img.shape, caption.shape) break