|
import numpy as np |
|
import pickle |
|
from torch.utils.data import Dataset, DataLoader |
|
import os |
|
import torch |
|
from copy import deepcopy |
|
from blimpy import Waterfall |
|
from tqdm import tqdm |
|
from copy import deepcopy |
|
from sigpyproc.readers import FilReader |
|
from torch import nn |
|
|
|
|
|
def load_pickled_data(file_path):
    """Read one pickled object from disk and return it.

    Args:
        file_path: Path of the pickle file to open (opened in binary mode).

    Returns:
        Whatever object was stored in the file.

    Note:
        Unpickling can execute arbitrary code, so this must only be pointed
        at files that come from a trusted source.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)
|
|
|
|
|
class CustomDataset(Dataset):
    """Dataset of pickled samples laid out as ``data_dir/<class>/<file>``.

    Every sub-directory of ``data_dir`` is one class; every entry inside a
    class directory is one sample. Samples are loaded lazily in
    ``__getitem__`` through ``load_pickled_data`` and are expected to be
    mappings holding a NumPy array under ``'data'`` (and under ``'8_data'``
    when ``bit8`` is enabled).

    Args:
        data_dir: Root directory containing one sub-directory per class.
        bit8: When truthy, read the ``'8_data'`` entry and cast it to
            float32 instead of reading ``'data'``.
        transform: Optional callable applied to the sample tensor.
    """

    def __init__(self, data_dir, bit8=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.bit8 = bit8

        # Only directories are classes. Stray files in the root (README,
        # .DS_Store, ...) used to make the scan below fail with
        # NotADirectoryError.
        # NOTE: os.listdir returns entries in arbitrary order, so the
        # class -> index mapping follows whatever order the OS reports.
        self.classes = [
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
        ]
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.images = []
        self.labels = []
        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        """Return the number of samples found under ``data_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(tensor, label)``.

        The tensor is passed through ``transform`` when one was given;
        otherwise the untransformed tensor is returned (previously this
        case raised UnboundLocalError).
        """
        image_path = self.images[idx]
        label = self.labels[idx]

        image = load_pickled_data(image_path)
        if self.bit8:
            tensor = torch.from_numpy(image['8_data']).type(torch.float32)
        else:
            tensor = torch.from_numpy(image['data'])

        if self.transform is not None:
            tensor = self.transform(tensor)
        return tensor, label
|
|
|
|
|
class CustomDataset_Masked(Dataset):
    """Class-folder dataset that also returns a binary burst mask.

    The directory layout is ``data_dir/<class>/<file>``. Each pickled sample
    is expected to provide ``'burst'`` (a NumPy array) and ``'data'``.
    ``image['data'].data`` is handed to ``torch.from_numpy``, so ``'data'``
    is presumably a NumPy masked array -- TODO confirm against the code that
    writes these pickles.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to the image tensor.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        # Only directories are classes; stray files in the root used to make
        # the scan below fail with NotADirectoryError.
        # NOTE: os.listdir order is arbitrary, and so is the label mapping.
        self.classes = [
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
        ]
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.images = []
        self.labels = []
        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        """Return the number of samples found under ``data_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label, mask)``.

        The mask is ``'burst'`` scaled by its maximum (skipped when the
        maximum is 0) and thresholded at 0.1 into {0, 1}.

        With a transform, the mask is written into the first three channels
        of a zero tensor shaped like the transformed image. Without a
        transform, the raw image tensor and the 2-D mask are returned as-is
        (previously this case raised UnboundLocalError).
        """
        image_path = self.images[idx]
        label = self.labels[idx]
        image = load_pickled_data(image_path)

        burst = image['burst']
        peak = burst.max()
        if peak == 0:
            # Nothing to scale by; avoids a division by zero.
            new_burst = torch.from_numpy(burst)
        else:
            new_burst = torch.from_numpy(burst / peak)

        # Both masks are computed before either write so that the second
        # write cannot see values changed by the first.
        above = new_burst > 0.1
        not_above = new_burst <= 0.1
        new_burst[above] = 1
        new_burst[not_above] = 0

        raw_image = torch.from_numpy(image['data'].data)
        if self.transform is None:
            return raw_image, label, new_burst

        new_image = self.transform(raw_image)
        new_burst_arr = torch.zeros_like(new_image)
        for channel in range(3):
            new_burst_arr[channel, :, :] = new_burst
        return new_image, label, new_burst_arr
|
|
|
|
|
class TestingDataset(Dataset):
    """Class-folder dataset that returns per-sample parameters with the label.

    The directory layout is ``data_dir/<class>/<file>``. Each pickled sample
    is expected to hold a NumPy array under ``'data'`` (or ``'8_data'`` when
    ``bit8`` is enabled) and a ``'params'`` mapping with the keys ``'dm'``,
    ``'freq_ref'``, ``'snr'`` and ``'boxcard'``.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        bit8: When truthy, read the ``'8_data'`` entry and cast it to
            float32 instead of reading ``'data'``.
        transform: Optional callable applied to the sample tensor.
    """

    def __init__(self, data_dir, bit8=False, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.bit8 = bit8

        # Only directories are classes; stray files in the root used to make
        # the scan below fail with NotADirectoryError.
        # NOTE: os.listdir order is arbitrary, and so is the label mapping.
        self.classes = [
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
        ]
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.images = []
        self.labels = []
        for cls in self.classes:
            class_dir = os.path.join(data_dir, cls)
            for image_name in os.listdir(class_dir):
                self.images.append(os.path.join(class_dir, image_name))
                self.labels.append(self.class_to_idx[cls])

    def __len__(self):
        """Return the number of samples found under ``data_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(tensor, (label, dm, freq_ref, snr, boxcard))``. The tensor is
            transformed when a transform was given and returned unchanged
            otherwise (previously that case raised UnboundLocalError).
        """
        image_path = self.images[idx]
        label = self.labels[idx]

        image = load_pickled_data(image_path)
        params = image['params']

        if self.bit8:
            tensor = torch.from_numpy(image['8_data']).type(torch.float32)
        else:
            tensor = torch.from_numpy(image['data'])
        if self.transform is not None:
            tensor = self.transform(tensor)

        # Kept from the original: the label is also recorded on the loaded
        # params mapping, although that mapping itself is not returned.
        params['labels'] = label
        return tensor, (
            label,
            params['dm'],
            params['freq_ref'],
            params['snr'],
            params['boxcard'],
        )
|
|
|
|
|
class SearchDataset(Dataset):
    """Cuts one observation into consecutive fixed-length windows.

    The observation is read either from a pickle file (a mapping with
    ``'header'`` and ``'data'`` entries) or through ``blimpy.Waterfall``.
    In both cases index 0 of the middle axis of the 3-D data array is
    selected, and the first axis is split into windows of ``window_size``
    (2048) rows.

    Args:
        data_dir: Path of the observation file (despite the name, this is
            passed straight to ``open`` / ``Waterfall``).
        transform: Optional callable applied to each window.
        pickle_data: When true, ``data_dir`` is read with ``pickle``.
            Unpickling can execute arbitrary code, so only use trusted files.
    """

    def __init__(self, data_dir, transform=None, pickle_data=False):
        self.window_size = 2048

        if pickle_data:
            with open(data_dir, 'rb') as f:
                self.d = pickle.load(f)
            self.header = self.d['header']
            raw = self.d['data']
        else:
            self.obs = Waterfall(data_dir, max_load=50)
            self.header = self.obs.header
            raw = self.obs.data

        self.images = self.crop(raw[:, 0, :], self.window_size)
        self.transform = transform
        self.SEC_PER_DAY = 86400

    def crop(self, data, window_size=2048):
        """Split ``data`` along axis 0 into whole windows.

        Args:
            data: 2-D array; axis 0 is split, axis 1 is kept intact.
            window_size: Number of rows per window.

        Returns:
            A new float64 array of shape
            ``(data.shape[0] // window_size, window_size, data.shape[1])``.
            Trailing rows that do not fill a whole window are dropped. The
            width is taken from the input instead of being fixed at 192.
        """
        n_samp = data.shape[0] // window_size
        # np.array copies, so the result never aliases the caller's buffer,
        # matching the original behaviour of filling a fresh array.
        whole = np.array(data[:n_samp * window_size], dtype=np.float64)
        return whole.reshape(n_samp, window_size, data.shape[1])

    def __len__(self):
        """Return the number of whole windows."""
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(window, idx)`` with the window transposed.

        The window is passed through ``transform`` when one was given;
        otherwise the transposed NumPy window is returned (previously this
        case raised UnboundLocalError).

        The original also computed, but never used, the window start time as
        ``header['tsamp'] * idx * window_size / SEC_PER_DAY +
        header['tstart']``; that dead computation has been dropped.
        """
        data = self.images[idx, :, :].T
        if self.transform is None:
            return data, idx
        return self.transform(data), idx
|
|
|
|
|
class SearchDataset_Sigproc(Dataset):
    """Cuts a file read with ``sigpyproc.FilReader`` into fixed windows.

    The whole file is read in one block, the first and last 1024 entries of
    the block's second axis are discarded, the axes are swapped, and the
    result is split along its first axis into windows of ``window_size``
    (2048) rows. Presumably the block is (channels, samples), making the
    windows runs of time samples -- TODO confirm against sigpyproc.

    Args:
        data_dir: Path of the file handed to ``FilReader``.
        transform: Optional callable applied to each window tensor.
    """

    def __init__(self, data_dir, transform=None):
        self.window_size = 2048
        fil = FilReader(data_dir)
        self.header = fil.header

        read_data = fil.read_block(0, fil.header.nsamples)[:, 1024:-1024]
        read_data = np.swapaxes(read_data, 0, -1)
        self.images = self.crop(read_data, self.window_size)
        self.transform = transform
        self.SEC_PER_DAY = 86400

    def crop(self, data, window_size=2048):
        """Split ``data`` along axis 0 into whole windows.

        Args:
            data: 2-D array; axis 0 is split, axis 1 is kept intact.
            window_size: Number of rows per window.

        Returns:
            A new float64 array of shape
            ``(data.shape[0] // window_size, window_size, data.shape[1])``.
            Trailing rows that do not fill a whole window are dropped. The
            width is taken from the input instead of being fixed at 192.
        """
        n_samp = data.shape[0] // window_size
        # np.array copies (and drops any ndarray subclass), matching the
        # original behaviour of filling a fresh plain array.
        whole = np.array(data[:n_samp * window_size], dtype=np.float64)
        return whole.reshape(n_samp, window_size, data.shape[1])

    def __len__(self):
        """Return the number of whole windows."""
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(window, idx)`` with the window as a transposed tensor.

        The tensor is passed through ``transform`` when one was given;
        otherwise it is returned unchanged (previously this case raised
        UnboundLocalError).

        The original also computed, but never used, a window start time as
        ``header.tsamp * idx * window_size / SEC_PER_DAY + header.tstart``.
        That dead computation has been dropped; note that the formula did
        not account for the 1024 entries trimmed in ``__init__``.
        """
        tensor = torch.from_numpy(self.images[idx, :, :].T)
        if self.transform is None:
            return tensor, idx
        return self.transform(tensor), idx
|
|
|
|
|
|
|
|
|
|
|
|
|
def renorm(data):
    """Standardise a tensor to zero mean and unit standard deviation.

    The mean and standard deviation are taken over all elements
    (``torch.std`` with its default settings).

    Args:
        data: Floating-point tensor of any shape.

    Returns:
        ``(data - mean) / std`` with the same shape as ``data``. When the
        standard deviation is exactly zero (constant input), the centred
        data -- all zeros -- is returned instead of NaN/inf.
    """
    mean = torch.mean(data)
    std = torch.std(data)

    # Divide by 1 where std is zero; the numerator is already zero there.
    safe_std = torch.where(std == 0, torch.ones_like(std), std)
    return (data - mean) / safe_std
|
|
|
def transform(data):
    """Build the multi-channel, width-chunked representation of a 2-D tensor.

    Three planes are derived from ``data``; each is ``log10(x + 1e-10)``
    passed through ``renorm``:

    * plane 0: the data unchanged;
    * plane 1: values below ``mean + 1 * std`` set to 0 first;
    * plane 2: values above ``mean + 5 * std`` set to 0 first.

    The planes are then cut into 8 chunks along the last axis and the chunks
    are folded into the leading axis.

    Args:
        data: 2-D tensor. Its last dimension needs to split into 8 equal
            chunks for the stacking step to succeed.

    Returns:
        float32 tensor of shape ``(3 * 8, data.shape[0], data.shape[1] / 8)``,
        ordered plane-major, then chunk.
    """
    thresholds = [-1, 5]
    sigma = torch.std(data)
    centre = torch.mean(data)

    planes = torch.zeros((len(thresholds) + 1, data.shape[0], data.shape[1]))
    planes[0, :, :] = renorm(torch.log10(data.detach().clone() + 1e-10))

    for plane_idx, scale in enumerate(thresholds, start=1):
        working = data.detach().clone()
        # abs() only changes the negative entries, which select the
        # "zero everything below the cutoff" branch.
        cutoff = abs(scale) * sigma + centre
        if scale < 0:
            working[working < cutoff] = 0
        else:
            working[working > cutoff] = 0
        planes[plane_idx, :, :] = renorm(torch.log10(working + 1e-10))

    planes = planes.type(torch.float32)
    pieces = torch.chunk(planes, 8, dim=-1)
    folded = torch.stack(pieces, dim=1)
    return folded.view(-1, folded.size(2), folded.size(3))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def renorm_batched(data):
    """Min-max scale every sample of a batch to the range [0, 1].

    The minimum and maximum are taken over the last two dimensions, so each
    leading index is normalised independently. Broadcasting is used instead
    of the former hard-coded ``(batch, 192, 2048)`` expansion, so any
    trailing 2-D size works; results for ``(batch, 192, 2048)`` input are
    unchanged.

    Args:
        data: Tensor with at least two dimensions.

    Returns:
        Tensor of the same shape as ``data``. A sample whose values are all
        equal maps to zeros instead of NaN.
    """
    mins = torch.amin(data, (-2, -1), keepdim=True)
    shifted = data - mins
    maxs = torch.amax(shifted, (-2, -1), keepdim=True)

    # Divide by 1 where the range is zero; the numerator is zero there.
    safe_maxs = torch.where(maxs == 0, torch.ones_like(maxs), maxs)
    return shifted / safe_maxs
|
|
|
|
|
def transform_mask(data):
    """Min-max scale a 2-D array and replicate it into three channels.

    Args:
        data: 2-D array. It is not modified; the subtraction below already
            produces a new array, so no explicit copy is needed.

    Returns:
        float32 NumPy array of shape ``(3, data.shape[0], data.shape[1])``
        in which every channel holds ``(data - min) / (max - min)``. A
        constant input yields all zeros instead of NaN.
    """
    shift = data - data.min()
    peak = shift.max()
    if peak != 0:
        # Skipped for constant input, where shift is already all zeros.
        shift = shift / peak

    # Float64 staging buffer, then a single cast, as in the original.
    new_data = np.empty((3, data.shape[0], data.shape[1]))
    new_data[:] = shift
    return new_data.astype(np.float32)
|
|
|
|
|
|
|
def Convert_ONNX(model, saveloc, input_data_mock):
    """Export ``model`` to an ONNX file with a dynamic batch dimension.

    The model is switched to evaluation mode (and left in it), then traced
    with ``input_data_mock`` as example input. Progress is printed to
    stdout.

    Args:
        model: The module to export.
        saveloc: Destination passed to ``torch.onnx.export``.
        input_data_mock: Example input tensor used for the export.
    """
    print("Saving to ONNX")

    model.eval()

    # torch.autograd.Variable is a legacy wrapper; it is kept here so the
    # example input handed to the exporter is unchanged.
    dummy_input = torch.autograd.Variable(input_data_mock)

    # Axis 0 of both the input and the output is exported as a named,
    # variable-size dimension.
    export_options = {
        'input_names': ['modelInput'],
        'output_names': ['modelOutput'],
        'dynamic_axes': {
            'modelInput': {0: 'batch_size'},
            'modelOutput': {0: 'batch_size'},
        },
    }
    torch.onnx.export(model, dummy_input, saveloc, **export_options)

    print(" ")
    print('Model has been converted to ONNX')
|
|
|
|
|
|
|
|
|
|
|
|
|
|