Spaces:
Running
Running
import torch | |
from rstor.properties import LEAKY_RELU, RELU, SIMPLE_GATE | |
from typing import Optional, Tuple | |
class SimpleGate(torch.nn.Module):
    """Parameter-free gating activation.

    Splits the input into two chunks along dim 1 (presumably the channel axis
    of an NCHW tensor — confirm against callers) and returns their elementwise
    product, so the output has half as many entries along dim 1.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_a, gate_b = torch.chunk(x, 2, dim=1)
        return torch.mul(gate_a, gate_b)
def get_non_linearity(activation: Optional[str]) -> torch.nn.Module:
    """Build a new activation module from its configuration name.

    Args:
        activation: one of the ``LEAKY_RELU``, ``RELU`` or ``SIMPLE_GATE``
            constants (imported from ``rstor.properties``), or ``None`` for
            no activation at all.

    Returns:
        torch.nn.Module: a freshly constructed ``LeakyReLU``, ``ReLU``,
        ``Identity`` (for ``None``) or ``SimpleGate`` instance.

    Raises:
        ValueError: if ``activation`` matches none of the supported values.
    """
    if activation == LEAKY_RELU:
        return torch.nn.LeakyReLU()
    if activation == RELU:
        return torch.nn.ReLU()
    if activation is None:
        # "No activation" is expressed as a pass-through module so callers can
        # always insert the result into a Sequential / call it unconditionally.
        return torch.nn.Identity()
    if activation == SIMPLE_GATE:
        return SimpleGate()
    raise ValueError(f"Unknown activation {activation}")
class BaseModel(torch.nn.Module):
    """Base class for all restoration models with additional useful methods"""

    def count_parameters(self) -> int:
        """Return the total number of elements in trainable (requires_grad) parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def receptive_field(
        self,
        channels: Optional[int] = 3,
        size: Optional[int] = 256,
        device: Optional[str] = None
    ) -> Tuple[int, int]:
        """Empirically measure the receptive field of the model.

        A constant probe input of shape ``(1, channels, size, size)`` is passed
        through ``forward``. A NaN gradient is injected at the spatial centre of
        the output and back-propagated. NaN propagates through the products of
        the backward pass (``NaN * 0`` is NaN). Every input pixel of channel 0
        that ends up with a NaN gradient therefore lies in the receptive field,
        and the bounding extent of those pixels is returned.

        Args:
            channels: number of input channels of the probe tensor.
            size: height and width of the square probe tensor.
            device: device on which to allocate the probe (e.g. ``"cuda"``);
                ``None`` uses torch's default device. Presumably this should
                match the device of the model parameters.

        Returns:
            Tuple[int, int]: ``(receptive_x, receptive_y)``, the horizontal and
            vertical extent of the receptive field in input pixels.

        Raises:
            RuntimeError: if no input pixel received the NaN gradient.

        Note:
            Calls ``self.zero_grad()``, which discards any gradients previously
            accumulated on the model parameters.
        """
        # Force autograd on so the probe also works when called from inside a
        # torch.no_grad() context (otherwise ``out`` has no grad_fn).
        with torch.enable_grad():
            # Allocate directly on the target device so the probe stays a leaf
            # tensor: ``tensor.to(device)`` would return a non-leaf copy whose
            # ``.grad`` is never populated.
            input_tensor = torch.ones(
                1, channels, size, size, device=device, requires_grad=True
            )
            out = self.forward(input_tensor)
            grad = torch.zeros_like(out)
            # Set NaN gradient at the middle of the output (for all leading dims).
            grad[..., out.shape[-2] // 2, out.shape[-1] // 2] = torch.nan
            out.backward(gradient=grad)
        # The backward pass also wrote NaN gradients into the parameters.
        self.zero_grad()

        receptive_field_mask = input_tensor.grad.isnan()[0, 0]
        rows, cols = torch.where(receptive_field_mask)
        if cols.numel() == 0:
            raise RuntimeError(
                "receptive_field: the NaN gradient did not reach the input; "
                "cannot determine the receptive field"
            )
        receptive_x = 1 + cols.max() - cols.min()  # Horizontal x
        receptive_y = 1 + rows.max() - rows.min()  # Vertical y
        return receptive_x.item(), receptive_y.item()