|
import numpy as np |
|
import pickle |
|
from torch.utils.data import Dataset, DataLoader |
|
import os |
|
import torch |
|
from copy import deepcopy |
|
from blimpy import Waterfall |
|
from tqdm import tqdm |
|
from copy import deepcopy |
|
from sigpyproc.readers import FilReader |
|
from torch import nn |
|
|
|
|
|
def renorm_batched(data):
    """Standardize each sample of a batch to zero mean and unit std.

    The mean and standard deviation are computed per sample, over every
    dimension except the leading one, so each ``data[i]`` is normalized
    independently of the others.

    Args:
        data: Tensor whose first dimension indexes samples.

    Returns:
        Tensor of the same shape as ``data`` holding ``(data - mean) / std``
        per sample. Samples whose standard deviation is exactly zero
        (constant samples) come out as all zeros rather than NaN.
    """
    reduce_dims = tuple(range(1, data.ndim))
    mean = torch.mean(data, dim=reduce_dims, keepdim=True)
    std = torch.std(data, dim=reduce_dims, keepdim=True)
    # A constant sample has std == 0; dividing by it would turn the whole
    # sample into NaN (0 / 0) and poison anything downstream. Divide by 1
    # instead so such a sample is simply centred at zero.
    std = torch.where(std == 0, torch.ones_like(std), std)
    return (data - mean) / std
|
|
|
def transform_batched(data, masks_rms=(-1, 5), n_chunks=8):
    """Build a multi-channel, log-scaled, per-sample-standardized view of a batch.

    One unmasked variant plus one masked variant per entry of ``masks_rms``
    are produced. Each is passed through ``log10(x + 1e-10)`` and then
    ``renorm_batched``:

    * variant 0: the data unchanged;
    * variant ``k`` (``k >= 1``), using ``scale = masks_rms[k - 1]``: values
      are compared with the per-sample threshold ``mean + |scale| * std``.
      If ``scale < 0``, values *below* the threshold are set to 0. Otherwise,
      values *above* it are set to 0.

    The last dimension of every variant is then cut into ``n_chunks`` equal
    pieces, and the (variant, chunk) pairs are flattened into a new channel
    dimension placed right after the batch dimension.

    Args:
        data: Tensor whose first dimension indexes samples. Statistics are
            taken over all remaining dimensions. For a 3-D input
            ``(B, H, W)`` with the default arguments the result is
            ``(B, 24, H, W // 8)``.
        masks_rms: Mask thresholds in units of the per-sample standard
            deviation (see above). Defaults to ``(-1, 5)``.
        n_chunks: Number of equal pieces the last dimension is split into.
            Defaults to 8.

    Returns:
        float32 tensor of shape
        ``(B, (len(masks_rms) + 1) * n_chunks, *data.shape[1:-1],
        W // n_chunks)``. Channel ``k * n_chunks + c`` holds chunk ``c`` of
        variant ``k``. The result is detached from ``data``'s autograd graph.

    Raises:
        ValueError: if the last dimension of ``data`` is not divisible by
            ``n_chunks``.
    """
    x = data.detach()

    width = x.shape[-1]
    # Unequal chunks would otherwise surface as an opaque stack/reshape error.
    if width % n_chunks != 0:
        raise ValueError(
            f"last dimension ({width}) must be divisible by n_chunks ({n_chunks})"
        )

    reduce_dims = tuple(range(1, x.ndim))
    rms = torch.std(x, dim=reduce_dims, keepdim=True)
    mean = torch.mean(x, dim=reduce_dims, keepdim=True)

    def log_renorm(t):
        # The 1e-10 offset keeps log10 finite where a mask zeroed values.
        return renorm_batched(torch.log10(t + 1e-10))

    variants = [log_renorm(x)]
    for scale in masks_rms:
        threshold = abs(scale) * rms + mean
        if scale < 0:
            drop = x < threshold
        else:
            drop = x > threshold
        variants.append(log_renorm(x.masked_fill(drop, 0)))

    # (num_variants, B, ..., W)
    stacked = torch.stack(variants).to(torch.float32)
    # (num_variants, B, n_chunks, ..., W // n_chunks)
    chunked = torch.stack(torch.chunk(stacked, n_chunks, dim=-1), dim=2)
    # (B, num_variants, n_chunks, ..., W // n_chunks)
    chunked = torch.swapaxes(chunked, 0, 1)
    return chunked.reshape(
        chunked.size(0), len(variants) * n_chunks, *chunked.shape[3:]
    )
|
|
|
class preproc_flip(nn.Module):
    """Preprocessing module that reverses dimension -2 before transforming.

    The input is flipped along its second-to-last dimension and the result
    is handed to ``transform_batched``. The ``flip`` argument of ``forward``
    is accepted but never read: the flip is always applied.
    """

    def forward(self, x, flip=True):
        flipped = x.flip(-2)
        return transform_batched(flipped)
|
|
|
class preproc(nn.Module):
    """Preprocessing module that applies ``transform_batched`` unchanged.

    The ``flip`` argument of ``forward`` is accepted but never read; no
    flipping is performed by this module.
    """

    def forward(self, x, flip=True):
        return transform_batched(x)
|
|