# BLADE_FRBNN / models / utils_batched_preproc.py
# peterma02's picture
# Upload folder using huggingface_hub
# f3972ea verified
import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader
import os
import torch
from copy import deepcopy
from blimpy import Waterfall
from tqdm import tqdm
from copy import deepcopy
from sigpyproc.readers import FilReader
from torch import nn
def renorm_batched(data):
    """Standardize every sample of a batch to zero mean and unit std.

    Statistics are computed independently for each index along dim 0,
    reducing over all remaining dimensions. ``torch.std`` is called with its
    default settings.

    Args:
        data: Tensor whose first dimension indexes the samples.

    Returns:
        Tensor of the same shape as ``data`` holding ``(data - mean) / std``
        per sample. A sample whose std is exactly zero (all values equal) is
        returned as ``data - mean`` instead of NaN.
    """
    reduce_dims = tuple(range(1, data.ndim))
    mean = torch.mean(data, dim=reduce_dims, keepdim=True)
    std = torch.std(data, dim=reduce_dims, keepdim=True)
    # A constant sample has std == 0, and 0 / 0 would poison the whole sample
    # with NaN. Dividing by 1 instead leaves the (already zero) residual as is.
    safe_std = std.masked_fill(std == 0, 1.0)
    return (data - mean) / safe_std
def transform_batched(data, masks_rms=(-1, 5), num_chunks=8):
    """Build a multi-channel, log-scaled, standardized view of a batch.

    For every sample, ``len(masks_rms) + 1`` layers are produced:

    * layer 0: ``renorm_batched(log10(data + 1e-10))``;
    * layer i (one per entry of ``masks_rms``): the same transform applied
      after zeroing part of the data. With ``rms``/``mean`` the per-sample
      std/mean of ``data``, a negative scale zeroes values *below*
      ``|scale| * rms + mean``; any other scale zeroes values *above*
      ``scale * rms + mean``.

    Each layer is then cut into ``num_chunks`` pieces along the last
    dimension and the (layer, chunk) pairs are flattened into one channel
    axis, with channel index ``layer * num_chunks + chunk``.

    Args:
        data: Input batch. The final reshape only works for a 3-D tensor
            ``(batch, height, width)`` whose last dimension is evenly
            divisible by ``num_chunks``; otherwise the stack/reshape at the
            end fails inside torch.
        masks_rms: Threshold scales, in units of the per-sample std.
            Defaults to the original hard-coded ``(-1, 5)``.
        num_chunks: Number of pieces the last dimension is split into.
            Defaults to the original hard-coded ``8``.

    Returns:
        ``torch.float32`` tensor of shape
        ``(batch, (len(masks_rms) + 1) * num_chunks, height,
        width // num_chunks)`` on the same device as ``data``. With the
        defaults this is 24 channels.
    """
    reduce_dims = tuple(range(1, data.ndim))
    # Per-sample statistics of the raw data; they only feed the thresholds.
    rms = torch.std(data, dim=reduce_dims, keepdim=True)
    mean = torch.mean(data, dim=reduce_dims, keepdim=True)

    # Detached view: nothing below mutates it, so no clone is required.
    raw = data.detach()

    num_layers = len(masks_rms) + 1
    # Shape: (num_layers, batch, ...). Allocated as float32, so every layer
    # written into it is stored as float32.
    layers = torch.zeros(
        (num_layers, *data.shape), dtype=torch.float32, device=data.device
    )

    # Layer 0: unmasked data.
    layers[0] = renorm_batched(torch.log10(raw + 1e-10))

    for i, scale in enumerate(masks_rms, start=1):
        threshold = abs(scale) * rms + mean
        # Negative scale: drop everything below the threshold.
        # Otherwise: drop everything above it.
        zero_out = raw < threshold if scale < 0 else raw > threshold
        # masked_fill returns a new tensor, leaving ``raw`` intact for the
        # next iteration.
        masked = raw.masked_fill(zero_out, 0)
        layers[i] = renorm_batched(torch.log10(masked + 1e-10))

    # Split the last dimension and collect the pieces on a new axis at dim=2:
    # (num_layers, batch, chunks, height, width / chunks).
    pieces = torch.chunk(layers, num_chunks, dim=-1)
    stacked = torch.stack(pieces, dim=2)
    # Batch first: (batch, num_layers, chunks, height, width / chunks).
    stacked = torch.swapaxes(stacked, 0, 1)
    # Merge the layer and chunk axes into a single channel axis.
    return stacked.reshape(
        stacked.size(0),
        num_layers * num_chunks,
        stacked.size(3),
        stacked.size(4),
    )
class preproc_flip(nn.Module):
    """Module wrapper that reverses the input along dim -2, then transforms it."""

    def forward(self, x, flip=True):
        """Return ``transform_batched`` applied to ``x`` flipped along dim -2.

        The ``flip`` argument is accepted but not consulted; the flip is
        always performed.
        """
        flipped = torch.flip(x, dims=(-2,))
        return transform_batched(flipped)
class preproc(nn.Module):
    """Module wrapper around ``transform_batched`` with no flipping."""

    def forward(self, x, flip=True):
        """Return ``transform_batched(x)``.

        The ``flip`` argument is accepted but not consulted; the input is
        never flipped here.
        """
        return transform_batched(x)