Spaces:
Sleeping
Sleeping
from pathlib import Path | |
from typing import Callable, Optional | |
import os | |
import torch | |
from torch.utils.data import Dataset | |
class Preprocessed_fastMRI(torch.utils.data.Dataset):
    """FastMRI dataset backed by preprocessed ``.pt`` files for faster loading.

    Every file under ``root`` (searched recursively) whose name ends in
    ``.pt`` is one sample. A sample is loaded with
    ``torch.load(..., weights_only=True)`` and its ``'data'`` entry is
    converted to a float tensor.

    Args:
        root: Directory searched (recursively) for ``.pt`` files.
        transform: Optional callable applied to the loaded float tensor.
        preprocess: If True, ``__getitem__`` returns ``(img, name)`` where
            ``name`` is the file name without directory or extension;
            otherwise it returns only ``img``.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        preprocess: bool = False,
    ) -> None:
        self.root = root
        self.transform = transform
        self.preprocess = preprocess
        # Paths of the sample files, relative to ``self.root``. os.walk
        # descends into subdirectories, so the subdirectory part must be kept
        # or __getitem__ could not find those files again.
        identifiers = []
        for dirpath, _, files in os.walk(self.root):
            for file in files:
                if file.endswith(".pt"):
                    identifiers.append(
                        os.path.relpath(os.path.join(dirpath, file), self.root)
                    )
        # Sorting makes the index -> file mapping independent of the
        # (arbitrary) filesystem listing order, i.e. reproducible.
        self.sample_identifiers = sorted(identifiers)

    def __len__(self) -> int:
        """Return the number of discovered ``.pt`` sample files."""
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        """Load sample ``idx``.

        Returns:
            The float tensor (after ``transform``, if any). When
            ``preprocess`` is True, returns ``(tensor, name)``, where ``name``
            is the sample's file name without directory or extension.
        """
        fname = self.sample_identifiers[idx]
        tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
        img = tensor['data'].float()
        if self.transform is not None:
            img = self.transform(img)
        if not self.preprocess:
            return img
        # Path.stem drops both the directory part and the extension.
        return img, Path(fname).stem
class Preprocessed_LIDCIDRI(torch.utils.data.Dataset): | |
"""FastMRI from preprocessed data for faster lading.""" | |
def __init__( | |
self, | |
root: str, | |
transform: Optional[Callable] = None, | |
) -> None: | |
self.root = root | |
self.transform = transform | |
# should contain all the information to load a data sample from the storage | |
self.sample_identifiers = [] | |
# append all filenames in self.root ending with .pt | |
for root, _, files in os.walk(self.root): | |
for file in files: | |
if file.endswith(".pt"): | |
self.sample_identifiers.append(file) | |
def __len__(self) -> int: | |
return len(self.sample_identifiers) | |
def __getitem__(self, idx: int): | |
fname = self.sample_identifiers[idx] | |
tensor = torch.load(os.path.join(self.root, fname), weights_only=True) | |
img = tensor['data'].float() | |
if self.transform is not None: | |
img = self.transform(img) | |
img = img.unsqueeze(0) # add channel dim | |