import os
import pickle
from copy import deepcopy

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from blimpy import Waterfall
from sigpyproc.readers import FilReader
from tqdm import tqdm


def renorm_batched(data):
    """Standardize each sample in the batch to zero mean and unit variance,
    with statistics computed over all non-batch dimensions."""
    mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True)
    std = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True)
    standardized_data = (data - mean) / std
    return standardized_data

def transform_batched(data):
    """Build a 24-channel representation of each sample: one log10-normalized
    copy of the raw data plus one copy per RMS mask, each split into 8 chunks
    along the last axis (which must be divisible by 8)."""
    copy_data = data.detach().clone()
    rms = torch.std(data, dim=tuple(range(1, data.ndim)), keepdim=True)   # per-sample std
    mean = torch.mean(data, dim=tuple(range(1, data.ndim)), keepdim=True) # per-sample mean
    masks_rms = [-1, 5]

    # Output tensor: one layer for the unmasked data plus one per mask
    num_masks = len(masks_rms) + 1
    new_data = torch.zeros((num_masks, *data.shape), device=data.device)  # (num_masks, batch, ..., ...)

    # First layer: renormalized log10 of the raw data (epsilon avoids log10(0))
    new_data[0] = renorm_batched(torch.log10(copy_data + 1e-10))

    for i, scale in enumerate(masks_rms, start=1):
        copy_data = data.detach().clone()

        # Negative scale: zero out values below mean + |scale| * rms;
        # positive scale: zero out values above mean + scale * rms.
        if scale < 0:
            ind = copy_data < abs(scale) * rms + mean
        else:
            ind = copy_data > scale * rms + mean
        copy_data[ind] = 0

        # Renormalize the log10-transformed, masked copy
        new_data[i] = renorm_batched(torch.log10(copy_data + 1e-10))

    new_data = new_data.type(torch.float32)

    # Split the last axis into 8 equal chunks and stack them as a new dimension
    slices = torch.chunk(new_data, 8, dim=-1)
    new_data = torch.stack(slices, dim=2)       # (num_masks, batch, 8, H, W/8)
    new_data = torch.swapaxes(new_data, 0, 1)   # (batch, num_masks, 8, H, W/8)
    # Merge the mask and chunk dimensions: 3 masks x 8 chunks = 24 channels
    new_data = new_data.reshape(new_data.size(0), 24, new_data.size(3), new_data.size(4))
    return new_data

class preproc_flip(nn.Module):
    """Preprocessing module that flips the input along dim -2 before transforming."""
    def forward(self, x, flip=True):  # `flip` is kept for interface compatibility but unused
        return transform_batched(torch.flip(x, dims=(-2,)))


class preproc(nn.Module):
    """Preprocessing module that applies transform_batched to the input unchanged."""
    def forward(self, x, flip=True):  # `flip` is kept for interface compatibility but unused
        return transform_batched(x)
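

# --- Illustrative usage sketch (assumption, not part of the original pipeline) ---
# Shows how the preproc modules might be applied to a batch of dynamic spectra.
# The input shape (2, 192, 4096) is hypothetical; the only hard requirement is
# that the last dimension be divisible by 8.
if __name__ == "__main__":
    batch = torch.rand(2, 192, 4096)      # hypothetical (batch, rows, columns) data
    out = preproc()(batch)                # 3 views x 8 chunks -> 24 channels
    out_flipped = preproc_flip()(batch)   # same, with the input flipped along dim -2
    print(out.shape, out_flipped.shape)   # expected: torch.Size([2, 24, 192, 512]) each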