|
import os |
|
import numpy as np |
|
import torch |
|
from torch.utils.data import Dataset |
|
import torch.nn.functional as F |
|
import torch.multiprocessing as mp |
|
|
|
class DEFAULTDataset(Dataset): |
|
def __init__(self, root_dir: str, internal_resolution: int): |
|
super().__init__() |
|
self.root_dir = root_dir |
|
self.internal_resolution = internal_resolution |
|
self.file_paths = self.get_data_files() |
|
|
|
def get_data_files(self): |
|
if not os.path.exists(self.root_dir): |
|
raise FileNotFoundError(f"Directory '{self.root_dir}' does not exist.") |
|
|
|
npy_file_names = os.listdir(self.root_dir) |
|
folder_names = [os.path.join(self.root_dir, npy_file_name) |
|
for npy_file_name in npy_file_names if npy_file_name.endswith('.npy')] |
|
return folder_names |
|
|
|
def __len__(self): |
|
return len(self.file_paths) |
|
|
|
def __getitem__(self, idx: int): |
|
filename = self.file_paths[idx] |
|
try: |
|
numpy_file = np.load(filename) |
|
torch_np = torch.from_numpy(numpy_file) |
|
torch_np = torch_np.unsqueeze(0).unsqueeze(0).float() |
|
interpolated_data = F.interpolate(input=torch_np, size=(self.internal_resolution, self.internal_resolution, self.internal_resolution), mode='trilinear') |
|
|
|
|
|
interpolated_data_tanh = torch.tanh(interpolated_data) |
|
interpolated_data_log = torch.log(interpolated_data + 1).squeeze(0) |
|
|
|
return {'data': interpolated_data_log} |
|
except Exception as e: |
|
print(f"Error loading file '{filename}': {e}") |
|
return None |
|
|