import os
import json
import re

import numpy as np
from PIL import Image
import torch.utils.data as data
from transformers import AutoImageProcessor


class FieldParser:
    """Parses annotation records into cleaned report text and image tensors."""

    def __init__(self, args):
        self.args = args
        self.dataset = args.dataset
        # Image processor matching the configured vision backbone.
        self.vit_feature_extractor = AutoImageProcessor.from_pretrained(args.vision_model)

    def _parse_image(self, img):
        # Preprocess one image and drop the batch dimension added by the
        # processor, so each sample stores a single (C, H, W) tensor.
        pixel_values = self.vit_feature_extractor(img, return_tensors="pt").pixel_values
        return pixel_values[0]

    # from https://github.com/cuhksz-nlp/R2Gen/blob/main/modules/tokenizers.py
    def clean_report(self, report):
        # Clean IU X-ray reports: drop list numbering, collapse repeated
        # periods, and strip punctuation from every sentence.
        if self.dataset == "iu_xray":
            report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
                .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
                .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
                .strip().lower().split('. ')
            # '-' is escaped so the character class does not define a range.
            sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+():\-\[\]{}]', '', t.replace('"', '').replace('/', '')
                                            .replace('\\', '').replace("'", '').strip().lower())
            tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != '']
            report = ' . '.join(tokens) + ' .'
        # Clean MIMIC-CXR reports: additionally normalize newlines, repeated
        # underscores, and runs of whitespace.
        else:
            report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
                .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
                .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
                .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
                .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
                .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
                .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ').replace(':', ' :') \
                .strip().lower().split('. ')
            sent_cleaner = lambda t: re.sub(r'[.,?;*!%^&_+()\[\]{}]', '', t.replace('"', '').replace('/', '')
                                            .replace('\\', '').replace("'", '').strip().lower())
            tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != '']
            report = ' . '.join(tokens) + ' .'
        # report = ' '.join(report.split()[:self.args.max_txt_len])
        return report
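
    # Worked example (illustrative input, not from the original file): for an
    # IU X-ray impression "1. No acute disease. 2. Heart size normal.",
    # clean_report drops the list numbering and sentence punctuation and
    # returns "no acute disease . heart size normal .".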

    def parse(self, features):
        to_return = {'id': features['id']}
        report = features.get("report", "")
        report = self.clean_report(report)
        to_return['input_text'] = report
        # Chest X-ray images: a study may have several views, so collect a
        # list of preprocessed tensors.
        images = []
        for image_path in features['image_path']:
            with Image.open(os.path.join(self.args.base_dir, image_path)) as pil:
                array = np.array(pil, dtype=np.uint8)
                # Convert grayscale or RGBA scans to 3-channel RGB.
                if array.ndim != 3 or array.shape[-1] != 3:
                    array = np.array(pil.convert("RGB"), dtype=np.uint8)
                image = self._parse_image(array)
                images.append(image)
        to_return["image"] = images
        return to_return
    def transform_with_parse(self, inputs):
        return self.parse(inputs)


class ParseDataset(data.Dataset):
    def __init__(self, args, split='train'):
        self.args = args
        # Load the annotation file and keep only the requested split.
        with open(args.annotation, 'r') as f:
            self.meta = json.load(f)[split]
        self.parser = FieldParser(args)

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, index):
        return self.parser.transform_with_parse(self.meta[index])


def create_datasets(args):
    train_dataset = ParseDataset(args, 'train')
    dev_dataset = ParseDataset(args, 'val')
    test_dataset = ParseDataset(args, 'test')
    return train_dataset, dev_dataset, test_dataset
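

# Minimal usage sketch (assumptions, not part of the original file): the
# annotation JSON maps split names to lists of records with 'id', 'report',
# and 'image_path' keys, and the checkpoint/paths below are placeholders for
# whatever the surrounding project actually configures.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        dataset="iu_xray",
        vision_model="google/vit-base-patch16-224-in21k",  # assumed backbone
        base_dir="data/iu_xray/images",                    # assumed image root
        annotation="data/iu_xray/annotation.json",         # assumed annotation file
    )
    train_ds, dev_ds, test_ds = create_datasets(args)
    sample = train_ds[0]
    print(sample["id"], len(sample["image"]), sample["input_text"][:80])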