# denoising/datasets.py — commit 12a4d59 ("gradio demo") by Yonuts, 2.45 kB.
from pathlib import Path
from typing import Callable, Optional
import os
import torch
from torch.utils.data import Dataset
class Preprocessed_fastMRI(torch.utils.data.Dataset):
    """FastMRI dataset backed by preprocessed ``.pt`` files for faster loading.

    Every file ending in ``.pt`` anywhere under ``root`` (subdirectories
    included) is one sample. Each file is expected to hold a mapping with a
    ``'data'`` entry; that entry is converted to float and optionally passed
    through ``transform``.

    Args:
        root: Directory searched recursively for ``.pt`` files.
        transform: Optional callable applied to each sample's float tensor.
        preprocess: If True, ``__getitem__`` also returns the sample's file
            name without directory or extension.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        preprocess: bool = False,
    ) -> None:
        self.root = root
        self.transform = transform
        self.preprocess = preprocess
        # Paths of the sample files relative to ``self.root``. Keeping the
        # subdirectory part (not just the basename) lets ``__getitem__`` find
        # files located below the top level; sorting makes indices
        # reproducible because os.walk order is filesystem-dependent.
        self.sample_identifiers = sorted(
            os.path.relpath(os.path.join(dirpath, name), self.root)
            for dirpath, _, filenames in os.walk(self.root)
            for name in filenames
            if name.endswith(".pt")
        )

    def __len__(self) -> int:
        """Return the number of ``.pt`` samples found under ``root``."""
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        """Load sample ``idx``.

        Returns:
            The float (and optionally transformed) ``'data'`` tensor or, when
            ``preprocess`` is True, a tuple ``(img, name)`` where ``name`` is
            the file name without directory and extension.
        """
        fname = self.sample_identifiers[idx]
        # weights_only=True limits unpickling to tensors, primitive types and
        # plain containers (see torch.load docs).
        sample = torch.load(os.path.join(self.root, fname), weights_only=True)
        img = sample['data'].float()
        if self.transform is not None:
            img = self.transform(img)
        if not self.preprocess:
            return img
        # Strip the directory part and the ".pt" extension.
        return img, Path(fname).stem
class Preprocessed_LIDCIDRI(torch.utils.data.Dataset):
"""FastMRI from preprocessed data for faster lading."""
def __init__(
self,
root: str,
transform: Optional[Callable] = None,
) -> None:
self.root = root
self.transform = transform
# should contain all the information to load a data sample from the storage
self.sample_identifiers = []
# append all filenames in self.root ending with .pt
for root, _, files in os.walk(self.root):
for file in files:
if file.endswith(".pt"):
self.sample_identifiers.append(file)
def __len__(self) -> int:
return len(self.sample_identifiers)
def __getitem__(self, idx: int):
fname = self.sample_identifiers[idx]
tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
img = tensor['data'].float()
if self.transform is not None:
img = self.transform(img)
img = img.unsqueeze(0) # add channel dim