File size: 986 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch


class SeparationModel(torch.nn.Module):
    """Base ``torch.nn.Module`` providing model introspection helpers.

    Offers trainable-parameter counting and an empirical receptive-field
    measurement. ``receptive_field`` relies on the ``forward`` method
    implemented by the subclass.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Pass all arguments through unchanged to ``torch.nn.Module``."""
        super().__init__(*args, **kwargs)

    def count_parameters(self) -> int:
        """Count the trainable parameters of the model.

        Only parameters with ``requires_grad=True`` are counted; frozen
        parameters are excluded.

        Returns:
            int: number of trainable scalar parameters
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def receptive_field(self, input_length: int = 4096) -> int:
        """Empirically measure the receptive field of the model.

        A random probe of shape ``(1, 1, input_length)`` is passed through
        the model, a NaN gradient is injected at the middle position of the
        output's last dimension, and the NaN entries that reach the probe's
        gradient are counted: every input sample that influences that output
        sample receives a NaN gradient.

        The forward pass runs in the module's current train/eval mode.
        Parameter ``.grad`` attributes are left untouched. The result cannot
        exceed ``input_length``; use a larger probe if the returned value
        equals it.

        Args:
            input_length: number of samples in the probe signal.

        Returns:
            int: receptive field, in input samples

        Raises:
            ValueError: if ``input_length`` is not positive.
        """
        if input_length < 1:
            raise ValueError(
                f"input_length must be positive, got {input_length}"
            )

        # Build the probe where the model lives; a CPU/float32 probe fails as
        # soon as the model has been moved to another device or dtype.
        reference = next(self.parameters(), None)
        if reference is None:
            reference = next(self.buffers(), None)
        factory_kwargs = {}
        if reference is not None:
            factory_kwargs["device"] = reference.device
            if reference.is_floating_point():
                factory_kwargs["dtype"] = reference.dtype

        # enable_grad makes the measurement work even when the caller is
        # inside a torch.no_grad() block.
        with torch.enable_grad():
            probe = torch.rand(
                1, 1, input_length, requires_grad=True, **factory_kwargs
            )
            # Call the module (not .forward) so registered hooks still run.
            result = self(probe)
            # forward may return the output alone or a sequence whose first
            # item is the output (e.g. ``(out, out_noise)``).
            out = result if isinstance(result, torch.Tensor) else result[0]

            seed = torch.zeros_like(out)
            # NaN gradient at the middle of the output
            seed[..., out.shape[-1] // 2] = torch.nan

            # Differentiate w.r.t. the probe only: parameter gradients are
            # never written, so nothing has to be reset afterwards and any
            # gradients accumulated by the caller are preserved.
            (probe_grad,) = torch.autograd.grad(out, probe, grad_outputs=seed)

        # Count the input positions reached by the NaN gradient.
        return int(probe_grad.isnan().sum().item())