import os
import pickle
from copy import deepcopy

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from blimpy import Waterfall
from tqdm import tqdm
from sigpyproc.readers import FilReader


def renorm_batched(data: torch.Tensor) -> torch.Tensor:
    """Standardize every sample of a batch independently.

    The mean and standard deviation are taken over all dimensions except
    dim 0, so each entry along dim 0 is shifted and scaled on its own.

    Args:
        data: Tensor whose dim 0 indexes the samples.

    Returns:
        ``(data - mean) / std`` with the same shape as ``data``.

    Note:
        No epsilon is added to the denominator: a sample whose std is zero
        produces NaN/inf values.
    """
    sample_dims = tuple(range(1, data.ndim))
    mean = torch.mean(data, dim=sample_dims, keepdim=True)
    std = torch.std(data, dim=sample_dims, keepdim=True)
    return (data - mean) / std


def transform_batched(
    data: torch.Tensor,
    masks_rms=(-1, 5),
    n_chunks: int = 8,
) -> torch.Tensor:
    """Build a stack of masked, log-scaled, standardized views of a batch.

    One layer is produced from the unmasked data plus one layer per entry
    of ``masks_rms``. Every layer is ``renorm_batched(log10(x + 1e-10))``.
    The layers are then cut into ``n_chunks`` pieces along the last
    dimension and the (layer, chunk) pairs are flattened into dim 1, with
    the layer index varying slowest: ``channel = layer * n_chunks + chunk``.

    Args:
        data: Input batch. The final reshape reads sizes 3 and 4 of the
            stacked tensor, so this expects a 3-D tensor
            ``(batch, rows, cols)`` with ``cols`` divisible by
            ``n_chunks``; other layouts fail inside torch.
        masks_rms: Thresholds in units of the per-sample std ("rms").
            A negative value zeroes every element *below*
            ``mean + abs(value) * rms``; a non-negative value zeroes every
            element *above* ``mean + value * rms``.
        n_chunks: Number of equal pieces the last dimension is split into.

    Returns:
        A float32 tensor of shape
        ``(batch, (len(masks_rms) + 1) * n_chunks, rows, cols / n_chunks)``
        on ``data.device``. It is computed from a detached view of
        ``data``, so it carries no autograd history.
    """
    # Work on a detached view; every operation below is out-of-place, so
    # the caller's tensor is never modified and no explicit clone is needed.
    source = data.detach()

    sample_dims = tuple(range(1, source.ndim))
    rms = torch.std(source, dim=sample_dims, keepdim=True)
    mean = torch.mean(source, dim=sample_dims, keepdim=True)

    # Layer 0 is the unmasked data; layers 1.. follow masks_rms in order.
    num_masks = len(masks_rms) + 1
    new_data = torch.zeros((num_masks, *source.shape), device=source.device)

    # The 1e-10 offset keeps log10 finite for elements that are exactly 0
    # (including the ones zeroed by the masks below).
    new_data[0] = renorm_batched(torch.log10(source + 1e-10))

    for i, scale in enumerate(masks_rms, start=1):
        if scale < 0:
            ind = source < abs(scale) * rms + mean
        else:
            ind = source > scale * rms + mean
        masked = source.masked_fill(ind, 0)
        new_data[i] = renorm_batched(torch.log10(masked + 1e-10))

    new_data = new_data.type(torch.float32)

    # (layer, batch, rows, cols) -> (layer, batch, chunk, rows, cols/chunk)
    slices = torch.chunk(new_data, n_chunks, dim=-1)
    new_data = torch.stack(slices, dim=2)
    # -> (batch, layer, chunk, rows, cols/chunk)
    new_data = torch.swapaxes(new_data, 0, 1)
    # Merge layer and chunk into one channel dimension. The channel count
    # is derived from the parameters instead of a hard-coded 24.
    return new_data.reshape(
        new_data.size(0),
        num_masks * n_chunks,
        new_data.size(3),
        new_data.size(4),
    )


class preproc_flip(nn.Module):
    """Module that reverses dim -2 of the input, then applies transform_batched."""

    def forward(self, x, flip=True):
        """Return ``transform_batched`` of ``x`` reversed along dim -2.

        Args:
            x: Input batch passed on to ``transform_batched``.
            flip: Accepted but never read; the flip is always applied.
        """
        flipped = torch.flip(x, dims=(-2,))
        return transform_batched(flipped)


class preproc(nn.Module):
    """Module that applies transform_batched to the input unchanged."""

    def forward(self, x, flip=True):
        """Return ``transform_batched(x)``.

        Args:
            x: Input batch passed on to ``transform_batched``.
            flip: Accepted but never read; no flip is applied.
        """
        return transform_batched(x)