from pathlib import Path
from typing import Callable, List, Optional
import os

import torch
from torch.utils.data import Dataset
from PIL import Image


def _find_pt_files(root: str) -> List[str]:
    """Return the paths of all ``.pt`` files below *root*, relative to *root*, sorted.

    Paths are kept relative (not bare file names) so that files found in
    sub-directories can be re-joined with *root* at load time. The list is
    sorted because ``os.walk`` yields entries in file-system-dependent order,
    which would otherwise make dataset indices irreproducible.
    """
    found = []
    for dirpath, _, files in os.walk(root):
        for file in files:
            if file.endswith(".pt"):
                found.append(os.path.relpath(os.path.join(dirpath, file), root))
    return sorted(found)


def _load_pt_image(path: str) -> torch.Tensor:
    """Load the ``'data'`` entry of a saved ``.pt`` sample and cast it to float32.

    ``weights_only=True`` restricts unpickling to tensors, primitive types and
    plain containers, guarding against code execution from a tampered file.
    """
    # NOTE(review): assumes every file was saved as a mapping holding a 'data'
    # tensor; confirm against the preprocessing script.
    sample = torch.load(path, weights_only=True)
    return sample["data"].float()


class Preprocessed_fastMRI(Dataset):
    """FastMRI samples read from preprocessed ``.pt`` files for faster loading.

    Args:
        root: Directory searched recursively for ``.pt`` files.
        transform: Optional callable applied to each loaded float32 tensor.
        preprocess: If True, ``__getitem__`` also returns the sample's file
            name without directory or extension.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        preprocess: bool = False,
    ) -> None:
        self.root = root
        self.transform = transform
        self.preprocess = preprocess
        # Paths of the samples relative to self.root; everything needed to
        # locate a sample on storage.
        self.sample_identifiers = _find_pt_files(self.root)

    def __len__(self) -> int:
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        """Return the (optionally transformed) image, plus its name if ``preprocess``."""
        fname = self.sample_identifiers[idx]
        img = _load_pt_image(os.path.join(self.root, fname))
        if self.transform is not None:
            img = self.transform(img)
        if not self.preprocess:
            return img
        # Strip the directory prefix and extension to get a bare identifier.
        return img, Path(fname).stem


class Preprocessed_LIDCIDRI(Dataset):
    """LIDC-IDRI samples read from preprocessed ``.pt`` files for faster loading.

    Args:
        root: Directory searched recursively for ``.pt`` files.
        transform: Optional callable applied to each loaded float32 tensor
            before the channel dimension is added.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ) -> None:
        self.root = root
        self.transform = transform
        # Paths of the samples relative to self.root; everything needed to
        # locate a sample on storage.
        self.sample_identifiers = _find_pt_files(self.root)

    def __len__(self) -> int:
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        """Return the (optionally transformed) image with a leading dim of size 1 added."""
        fname = self.sample_identifiers[idx]
        img = _load_pt_image(os.path.join(self.root, fname))
        if self.transform is not None:
            img = self.transform(img)
        img = img.unsqueeze(0)  # add channel dim
        return img


class LsdirMiniDataset(Dataset):
    """Images in a single directory, loaded as RGB ``PIL.Image`` objects.

    Args:
        root: Directory whose immediate ``.png`` / ``.jpeg`` files (any
            letter case) form the dataset; sub-directories are not searched.
        transform: Optional callable applied to each RGB image; its output is
            what ``__getitem__`` returns.
    """

    # Lower-case suffixes, compared against the lower-cased file name.
    extensions = (".png", ".jpeg")

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ) -> None:
        self.root = root
        # Lower-case the name and compare with lower-case suffixes so ".JPEG",
        # ".Png", etc. all match; sort so indices are stable across systems.
        self.image_files = sorted(
            f for f in os.listdir(self.root) if f.lower().endswith(self.extensions)
        )
        self.transform = transform

    def __len__(self) -> int:
        return len(self.image_files)

    def __getitem__(self, idx: int):
        """Return the RGB image at *idx*, transformed if a transform was given."""
        img_path = os.path.join(self.root, self.image_files[idx])
        # convert() reads the pixel data into a new image, so the file handle
        # can be closed immediately instead of waiting for garbage collection.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")  # Ensure consistent 3-channel format
        if self.transform is not None:
            img = self.transform(img)
        return img