# vshirasuna's picture
# Move code to 3dgrid_vqgan folder
# a4c759f
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.multiprocessing as mp
class DEFAULTDataset(Dataset):
    """Dataset of ``.npy`` volumes resampled to a cubic grid and log-scaled.

    Every file directly inside ``root_dir`` whose name ends with ``.npy`` is
    one sample. On access the array is loaded, cast to float32, resized to
    ``internal_resolution`` along each of its three axes with trilinear
    interpolation, and transformed with ``log(x + 1)``.
    """

    def __init__(self, root_dir: str, internal_resolution: int) -> None:
        """Index the ``.npy`` files found in ``root_dir``.

        Args:
            root_dir: Directory that contains the ``.npy`` files.
            internal_resolution: Edge length of the cubic grid every sample
                is interpolated to.

        Raises:
            FileNotFoundError: If ``root_dir`` does not exist.
        """
        super().__init__()
        self.root_dir = root_dir
        self.internal_resolution = internal_resolution
        self.file_paths = self.get_data_files()

    def get_data_files(self):
        """Return the sorted paths of all ``.npy`` files in ``root_dir``.

        Only direct children of ``root_dir`` are considered; sub-directories
        are not traversed.

        Raises:
            FileNotFoundError: If ``root_dir`` does not exist.
        """
        if not os.path.exists(self.root_dir):
            raise FileNotFoundError(f"Directory '{self.root_dir}' does not exist.")
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines.
        return sorted(
            os.path.join(self.root_dir, file_name)
            for file_name in os.listdir(self.root_dir)
            if file_name.endswith('.npy')
        )

    def __len__(self) -> int:
        """Return the number of indexed ``.npy`` files."""
        return len(self.file_paths)

    def __getitem__(self, idx: int):
        """Load, resample and log-scale the sample at position ``idx``.

        Args:
            idx: Position in ``self.file_paths``.

        Returns:
            ``{'data': tensor}`` where ``tensor`` is float32 with shape
            ``(1, R, R, R)`` and ``R == internal_resolution``, or ``None``
            when the file cannot be loaded or processed (the error is
            printed). Callers that batch samples must filter out ``None``.

        Raises:
            IndexError: If ``idx`` is outside ``self.file_paths``.
        """
        filename = self.file_paths[idx]
        try:
            volume = torch.from_numpy(np.load(filename))
            # Trilinear interpolation needs a 5-D (N, C, D, H, W) input, so
            # the stored array must be 3-D; add batch and channel axes.
            volume = volume.unsqueeze(0).unsqueeze(0).float()
            size = (self.internal_resolution,) * 3
            resampled = F.interpolate(input=volume, size=size, mode='trilinear')
            # log(x + 1) maps 0 to 0; values <= -1 produce -inf / NaN.
            # Drop the batch axis so the result is (1, R, R, R).
            log_scaled = torch.log(resampled + 1).squeeze(0)
            return {'data': log_scaled}
        except Exception as e:
            # Best effort: report the broken sample instead of aborting.
            print(f"Error loading file '{filename}': {e}")
            return None