Spaces:
Sleeping
Sleeping
File size: 2,449 Bytes
12a4d59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
from pathlib import Path
from typing import Callable, Optional
import os
import torch
from torch.utils.data import Dataset
class Preprocessed_fastMRI(torch.utils.data.Dataset):
    """Dataset of preprocessed fastMRI samples stored as ``.pt`` files.

    Every file ending in ``.pt`` found under ``root`` (searched recursively)
    is one sample. A sample file is read with ``torch.load``; the loaded
    object is indexed with ``'data'`` and the result is converted with
    ``.float()``.

    Args:
        root: Directory that is walked recursively for ``.pt`` files.
        transform: Optional callable applied to the float tensor of a sample.
        preprocess: If ``True``, ``__getitem__`` returns ``(img, name)`` where
            ``name`` is the sample's file name without directory and
            extension; otherwise only ``img`` is returned.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        preprocess: bool = False,
    ) -> None:
        self.root = root
        self.transform = transform
        self.preprocess = preprocess
        # Everything needed to load a sample from storage: the path of each
        # ``.pt`` file relative to ``self.root``. Keeping the relative path
        # (not just the bare file name) lets files in subdirectories be
        # loaded; sorting makes the index -> sample mapping independent of
        # the order in which the filesystem lists entries.
        self.sample_identifiers = sorted(
            os.path.relpath(os.path.join(dirpath, name), self.root)
            for dirpath, _, names in os.walk(self.root)
            for name in names
            if name.endswith(".pt")
        )

    def __len__(self) -> int:
        """Return the number of ``.pt`` files found under ``root``."""
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        """Load sample ``idx``.

        Returns:
            The (optionally transformed) float tensor, or a tuple
            ``(tensor, name)`` when ``preprocess`` is ``True``.
        """
        rel_path = self.sample_identifiers[idx]
        full_path = os.path.join(self.root, rel_path)
        sample = torch.load(full_path, weights_only=True)
        img = sample['data'].float()

        if self.transform is not None:
            img = self.transform(img)

        if not self.preprocess:
            return img
        # Strip directory components and the extension from the identifier.
        return img, Path(rel_path).stem
class Preprocessed_LIDCIDRI(torch.utils.data.Dataset):
    """Dataset of preprocessed LIDC-IDRI samples stored as ``.pt`` files.

    Every file ending in ``.pt`` found under ``root`` (searched recursively)
    is one sample. A sample file is read with ``torch.load``; the loaded
    object is indexed with ``'data'`` and the result is converted with
    ``.float()``. A leading dimension of size one is added to every sample
    after the optional transform.

    Args:
        root: Directory that is walked recursively for ``.pt`` files.
        transform: Optional callable applied to the float tensor of a sample
            before the leading dimension is added.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ) -> None:
        self.root = root
        self.transform = transform
        # Everything needed to load a sample from storage: the path of each
        # ``.pt`` file relative to ``self.root``. Keeping the relative path
        # (not just the bare file name) lets files in subdirectories be
        # loaded; sorting makes the index -> sample mapping independent of
        # the order in which the filesystem lists entries.
        self.sample_identifiers = sorted(
            os.path.relpath(os.path.join(dirpath, name), self.root)
            for dirpath, _, names in os.walk(self.root)
            for name in names
            if name.endswith(".pt")
        )

    def __len__(self) -> int:
        """Return the number of ``.pt`` files found under ``root``."""
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        """Load sample ``idx`` and return it with a leading size-one dim."""
        rel_path = self.sample_identifiers[idx]
        full_path = os.path.join(self.root, rel_path)
        sample = torch.load(full_path, weights_only=True)
        img = sample['data'].float()

        if self.transform is not None:
            img = self.transform(img)

        img = img.unsqueeze(0)  # add channel dim
        # Without this return every sample handed to a DataLoader is None.
        return img
|