import os
import pickle
from copy import deepcopy

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from blimpy import Waterfall
from sigpyproc.readers import FilReader


def load_pickled_data(file_path):
    with open(file_path, 'rb') as f:
        data = pickle.load(f)
    return data


# Custom dataset class for pickled, class-labelled training images
class CustomDataset(Dataset):
    def __init__(self, data_dir, bit8=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.classes = os.listdir(data_dir)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.bit8 = bit8

        # Load image paths and labels
        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                image_path = os.path.join(class_dir, image_name)
                self.images.append(image_path)
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        label = self.labels[idx]

        # Load image
        image = load_pickled_data(image_path)
        if self.transform is not None:
            if self.bit8:
                new_image = self.transform(torch.from_numpy(image['8_data']).type(torch.float32))
            else:
                new_image = self.transform(torch.from_numpy(image['data']))
        else:
            # Without a transform, fall back to the raw data tensor
            new_image = torch.from_numpy(image['data'])
        return new_image, label


# Custom dataset class that also returns a binary burst mask
class CustomDataset_Masked(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.classes = os.listdir(data_dir)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # Load image paths and labels
        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                image_path = os.path.join(class_dir, image_name)
                self.images.append(image_path)
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        label = self.labels[idx]

        # Load image
        image = load_pickled_data(image_path)
        if self.transform is not None:
            # Normalize the burst map, then threshold it into a binary mask
            if image['burst'].max() == 0:
                new_burst = torch.from_numpy(image['burst'])
            else:
                new_burst = torch.from_numpy(image['burst'] / image['burst'].max())
            ind = new_burst > 0.1
            ind_not = new_burst <= 0.1
            new_burst[ind] = 1
            new_burst[ind_not] = 0

            new_image = self.transform(torch.from_numpy(image['data'].data))

            # Replicate the binary mask across the three channels of the transformed image
            new_burst_arr = torch.zeros_like(new_image)
            new_burst_arr[0, :, :] = new_burst
            new_burst_arr[1, :, :] = new_burst
            new_burst_arr[2, :, :] = new_burst
        return new_image, label, new_burst_arr


# Custom dataset class that also returns the injection parameters for testing
class TestingDataset(Dataset):
    def __init__(self, data_dir, bit8=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []
        self.classes = os.listdir(data_dir)
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.bit8 = bit8

        # Load image paths and labels
        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                image_path = os.path.join(class_dir, image_name)
                self.images.append(image_path)
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        label = self.labels[idx]

        # Load image and its injection parameters
        image = load_pickled_data(image_path)
        params = image['params']
        if self.transform is not None:
            if self.bit8:
                new_image = self.transform(torch.from_numpy(image['8_data']).type(torch.float32))
            else:
                new_image = self.transform(torch.from_numpy(image['data']))
        else:
            # Without a transform, fall back to the raw data tensor
            new_image = torch.from_numpy(image['data'])
        params['labels'] = label
        return new_image, (label, params['dm'], params['freq_ref'], params['snr'], params['boxcard'])
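
# Example usage (a minimal sketch, not part of the pipeline): the on-disk
# layout is assumed to be data_dir/<class_name>/<sample>.pkl, and `transform`
# is the channelizing transform defined later in this module; the path
# 'training_set' is a placeholder.
def _example_custom_dataset_usage(data_dir='training_set'):
    dataset = CustomDataset(data_dir, bit8=False, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    images, labels = next(iter(loader))
    print(images.shape, labels.shape)  # e.g. torch.Size([32, 24, 192, 256]) and torch.Size([32])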
# Custom dataset class for searching a blimpy filterbank (or pickled) observation
class SearchDataset(Dataset):
    def __init__(self, data_dir, transform=None, pickle_data=False):
        self.window_size = 2048
        if pickle_data:
            with open(data_dir, 'rb') as f:
                self.d = pickle.load(f)
            self.header = self.d['header']
            self.images = self.crop(self.d['data'][:, 0, :], self.window_size)
        else:
            self.obs = Waterfall(data_dir, max_load=50)
            self.header = self.obs.header
            self.images = self.crop(self.obs.data[:, 0, :], self.window_size)
        self.transform = transform
        self.SEC_PER_DAY = 86400

    def crop(self, data, window_size=2048):
        # Split the observation into non-overlapping windows of window_size time samples
        n_samp = data.shape[0] // window_size
        new_data = np.zeros((n_samp, window_size, 192))
        for i in range(n_samp):
            new_data[i, :, :] = data[i * window_size:(i + 1) * window_size, :]
        return new_data

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        data = self.images[idx, :, :].T
        tindex = idx * self.window_size
        # MJD of the first sample in this window (currently unused)
        time = self.header['tsamp'] * tindex / self.SEC_PER_DAY + self.header['tstart']
        if self.transform is not None:
            new_image = self.transform(torch.from_numpy(data))
        else:
            new_image = torch.from_numpy(data)
        return new_image, idx


# Custom dataset class for searching a sigproc filterbank observation
class SearchDataset_Sigproc(Dataset):
    def __init__(self, data_dir, transform=None):
        self.window_size = 2048
        fil = FilReader(data_dir)
        self.header = fil.header
        # Read the full observation and drop the first and last 1024 time samples
        read_data = fil.read_block(0, fil.header.nsamples)[:, 1024:-1024]
        read_data = np.swapaxes(read_data, 0, -1)
        self.images = self.crop(read_data, self.window_size)
        self.transform = transform
        self.SEC_PER_DAY = 86400

    def crop(self, data, window_size=2048):
        # Split the observation into non-overlapping windows of window_size time samples
        n_samp = data.shape[0] // window_size
        new_data = np.zeros((n_samp, window_size, 192))
        for i in range(n_samp):
            new_data[i, :, :] = data[i * window_size:(i + 1) * window_size, :]
        return new_data

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        data = self.images[idx, :, :].T
        tindex = idx * self.window_size
        # MJD of the first sample in this window (currently unused)
        time = self.header.tsamp * tindex / self.SEC_PER_DAY + self.header.tstart
        if self.transform is not None:
            new_image = self.transform(torch.from_numpy(data))
        else:
            new_image = torch.from_numpy(data)
        return new_image, idx


# def renorm(data):
#     shifted = data - data.min()
#     shifted = shifted / shifted.max()
#     return shifted


def renorm(data):
    # Standardize to zero mean and unit standard deviation
    mean = torch.mean(data)
    std = torch.std(data)
    standardized_data = (data - mean) / std
    return standardized_data


def transform(data):
    copy_data = data.detach().clone()
    rms = torch.std(data)
    mean = torch.mean(data)
    masks_rms = [-1, 5]
    new_data = torch.zeros((len(masks_rms) + 1, data.shape[0], data.shape[1]))
    # Channel 0: log-scaled, standardized copy of the raw data
    new_data[0, :, :] = renorm(torch.log10(copy_data + 1e-10))
    for i in range(1, len(masks_rms) + 1):
        scale = masks_rms[i - 1]
        copy_data = data.detach().clone()
        if scale < 0:
            # Negative scale: zero everything below |scale| sigma above the mean
            ind = copy_data < abs(scale) * rms + mean
        else:
            # Positive scale: zero everything above scale sigma above the mean
            ind = copy_data > scale * rms + mean
        copy_data[ind] = 0
        new_data[i, :, :] = renorm(torch.log10(copy_data + 1e-10))
    new_data = new_data.type(torch.float32)

    # Split the time axis into 8 chunks and fold them into the channel axis:
    # (3, H, W) -> (3, 8, H, W/8) -> (24, H, W/8)
    slices = torch.chunk(new_data, 8, dim=-1)
    new_data = torch.stack(slices, dim=1)
    new_data = new_data.view(-1, new_data.size(2), new_data.size(3))
    return new_data
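
# Minimal sketch of the shapes `transform` produces: three RMS-masked,
# log-scaled copies of a (192, 2048) dynamic spectrum, each split into
# eight 256-sample time chunks and folded into the channel axis.
def _example_transform_shapes():
    dummy = torch.rand(192, 2048)  # synthetic dynamic spectrum (freq x time)
    out = transform(dummy)
    print(out.shape)  # torch.Size([24, 192, 256])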
# class preproc_debug(nn.Module):
#     def forward(self, x):
#         template = torch.zeros((32, 24, 192, 256))
#         # Unrolled over all 32 batch elements in an earlier version; the loop is equivalent
#         for i in range(x.shape[0]):
#             template[i, :, :, :] = transform_debug(torch.flip(x[i, :, :], dims=(0,)))
#         return template

# def transform_debug(data):
#     ...  # body was identical to transform() above


def renorm_batched(data):
    # Min-max scale each (192, 2048) image in the batch independently to [0, 1]
    mins = torch.amin(data, (-2, -1))
    mins = mins.unsqueeze(1).unsqueeze(2)
    mins = mins.expand(data.shape[0], 192, 2048)
    shifted = data - mins
    maxs = torch.amax(shifted, (-2, -1))
    maxs = maxs.unsqueeze(1).unsqueeze(2)
    maxs = maxs.expand(data.shape[0], 192, 2048)
    shifted = shifted / maxs
    return shifted
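
# Minimal sketch of renorm_batched: each (192, 2048) image in a batch is
# min-max scaled independently, so per-image minima are 0 and maxima are 1.
def _example_renorm_batched():
    batch = torch.rand(4, 192, 2048)
    scaled = renorm_batched(batch)
    print(torch.amin(scaled, (-2, -1)))  # tensor([0., 0., 0., 0.])
    print(torch.amax(scaled, (-2, -1)))  # tensor([1., 1., 1., 1.])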
def transform_mask(data):
    # Min-max normalize and replicate into a 3-channel float32 array
    copy_data = deepcopy(data)
    shift = copy_data - copy_data.min()
    normalized_data = shift / shift.max()
    new_data = np.zeros((3, data.shape[0], data.shape[1]))
    for i in range(3):
        new_data[i, :, :] = normalized_data
    new_data = new_data.astype(np.float32)
    return new_data


# Function to convert a trained model to ONNX
def Convert_ONNX(model, saveloc, input_data_mock):
    print("Saving to ONNX")

    # Set the model to inference mode
    model.eval()

    # A plain tensor works as the dummy input; torch.autograd.Variable is deprecated
    dummy_input = input_data_mock

    # Export the model
    torch.onnx.export(model,        # model being run
                      dummy_input,  # model input (or a tuple for multiple inputs)
                      saveloc,      # where to save the model
                      input_names=['modelInput'],    # the model's input names
                      output_names=['modelOutput'],  # the model's output names
                      dynamic_axes={'modelInput': {0: 'batch_size'},  # variable-length axes
                                    'modelOutput': {0: 'batch_size'}})
    print(" ")
    print('Model has been converted to ONNX')
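
# Minimal usage sketch for Convert_ONNX, assuming a trained classifier that
# accepts (batch, 24, 192, 256) inputs as produced by `transform` above;
# 'model.onnx' is a placeholder output path.
def _example_convert_onnx(model):
    dummy = torch.rand(1, 24, 192, 256)
    Convert_ONNX(model, 'model.onnx', dummy)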