|
import os |
|
import json |
|
|
|
import torch |
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
|
|
|
|
class BaseDataset(Dataset):
    """Base dataset that loads a JSON annotation file and pre-tokenizes reports.

    The annotation file is expected to be a JSON object keyed by split name
    ("train"/"val"/"test"), each mapping to a list of example dicts. Each
    example must have a 'finding' text field; tokenized ids (truncated to
    ``max_seq_length``) and an all-ones attention mask are attached in place.
    """

    def __init__(self, args, tokenizer, split, transform=None):
        """
        Args:
            args: namespace with ``image_dir``, ``ann_path`` and
                ``max_seq_length`` attributes.
            tokenizer: callable mapping a report string to a sequence of ids.
            split: which split of the annotation file to use.
            transform: optional image transform applied in ``__getitem__``.
        """
        self.image_dir = args.image_dir
        self.ann_path = args.ann_path
        self.max_seq_length = args.max_seq_length
        self.split = split
        self.tokenizer = tokenizer
        self.transform = transform

        # utf_8_sig tolerates a leading BOM some annotation files carry;
        # the context manager guarantees the handle is closed (the original
        # bare open() leaked it).
        with open(self.ann_path, 'r', encoding="utf_8_sig") as f:
            self.ann = json.load(f)

        self.examples = self.ann[self.split]
        # Pre-tokenize once so __getitem__ stays cheap; truncate to the
        # configured maximum and build a matching all-ones mask.
        for example in self.examples:
            example['ids'] = tokenizer(example['finding'])[:self.max_seq_length]
            example['mask'] = [1] * len(example['ids'])

    def __len__(self):
        """Number of examples in the selected split."""
        return len(self.examples)
|
|
|
|
|
class MyDataset(BaseDataset):
    """Dataset returning a stacked pair of images plus the pre-tokenized report."""

    def __getitem__(self, idx):
        """Return (uid, image pair tensor, token ids, mask, seq length, labels).

        The two images listed under ``image_path`` are loaded as RGB, each
        run through ``self.transform`` when one is set, and stacked along a
        new leading dimension.
        """
        record = self.examples[idx]
        uid = record['uid']
        paths = record['image_path']
        # assumes transform maps a PIL image to a tensor — TODO confirm
        view_1 = Image.open(os.path.join(self.image_dir, paths[0])).convert('RGB')
        view_2 = Image.open(os.path.join(self.image_dir, paths[1])).convert('RGB')
        if self.transform is not None:
            view_1 = self.transform(view_1)
            view_2 = self.transform(view_2)
        pair = torch.stack((view_1, view_2), 0)
        token_ids = record['ids']
        token_mask = record['mask']
        return (uid, pair, token_ids, token_mask, len(token_ids), record['labels'])
|
|
|
|
|
|
|
|