Spaces:
Sleeping
Sleeping
import torch | |
from models.blocks import AffineConv2d, downsample_strideconv, upsample_convtranspose | |
class InHead(torch.nn.Module):
    """Input head keeping one 3x3 ``AffineConv2d`` per accepted channel count.

    The conv used at forward time is chosen by looking up the channel count
    of the incoming tensor in ``in_channels_list``.

    Args:
        in_channels_list: accepted input channel counts; entry ``i`` is served
            by the submodule ``conv{i}``.
        out_channels: output channel count shared by every conv.
        mode: forwarded unchanged to ``AffineConv2d``.
        bias: forwarded unchanged to ``AffineConv2d``.
        input_layer: when True, the lookup uses ``x.size(1) - 1`` instead of
            ``x.size(1)``.
    """

    def __init__(self, in_channels_list, out_channels, mode="", bias=False, input_layer=False):
        super().__init__()
        self.in_channels_list = in_channels_list
        self.input_layer = input_layer
        # The names conv0, conv1, ... define the state_dict layout.
        for idx, n_in in enumerate(in_channels_list):
            layer = AffineConv2d(
                in_channels=n_in,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{idx}", layer)

    def forward(self, x):
        """Run ``x`` through the conv registered for its channel count.

        Raises:
            ValueError: if the looked-up channel count is not in
                ``in_channels_list``.
        """
        n_in = x.size(1)
        if self.input_layer:
            # NOTE(review): one channel is excluded from the lookup, yet the
            # full ``x`` is still fed to a conv built for ``n_in - 1`` input
            # channels -- confirm AffineConv2d accepts the extra channel.
            n_in -= 1
        layer = getattr(self, f"conv{self.in_channels_list.index(n_in)}")
        return layer(x)
class OutTail(torch.nn.Module):
    """Output tail keeping one 3x3 ``AffineConv2d`` per requested output width.

    Args:
        in_channels: input channel count shared by every conv.
        out_channels_list: supported output channel counts; entry ``i`` is
            served by the submodule ``conv{i}``.
        mode: forwarded unchanged to ``AffineConv2d``.
        bias: forwarded unchanged to ``AffineConv2d``.
    """

    def __init__(self, in_channels, out_channels_list, mode="", bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels_list = out_channels_list
        # The names conv0, conv1, ... define the state_dict layout.
        for idx, n_out in enumerate(out_channels_list):
            layer = AffineConv2d(
                in_channels=in_channels,
                out_channels=n_out,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{idx}", layer)

    def forward(self, x, out_channels):
        """Apply the conv producing ``out_channels`` channels.

        Raises:
            ValueError: if ``out_channels`` is not in ``out_channels_list``.
        """
        layer = getattr(self, f"conv{self.out_channels_list.index(out_channels)}")
        return layer(x)
# TODO: check that the heads are compatible with the old implementation | |
class Heads(torch.nn.Module):
    """Bank of ``HeadBlock`` input heads, one per accepted channel count.

    The head is selected by the channel count of the input. When
    ``scale != 1`` the input is first downsampled by ``scale``, either with
    bilinear interpolation (``mode == "bilinear"``) or with a learned strided
    conv ``down{i}`` followed by a ReLU (any other mode).

    Args:
        in_channels_list: base channel counts; each is multiplied by
            ``c_mult + c_add`` to obtain the channel counts accepted by
            ``forward``.
        out_channels: output channel count of every head.
        depth, bias, relu_in, skip_in: forwarded to ``HeadBlock``.
        scale: downsampling factor applied before the head (1 disables it).
        mode: ``"bilinear"`` for interpolation, anything else for the learned
            ``down{i}`` path.
        c_mult, c_add: channel multiplier terms (see ``in_channels_list``).
    """

    def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, c_add=0, relu_in=False, skip_in=False):
        super(Heads, self).__init__()
        # Channel counts actually seen by forward().
        self.in_channels_list = [c * (c_mult + c_add) for c in in_channels_list]
        self.scale = scale
        self.mode = mode
        for i, in_channels in enumerate(self.in_channels_list):
            setattr(self, f"head{i}", HeadBlock(in_channels, out_channels, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))
        if self.mode != "bilinear":
            # forward() applies self.nl after every learned downsampling, so it
            # must exist for every non-bilinear mode (previously only for
            # mode == "", which made other modes raise AttributeError).
            # ReLU has no parameters, so the state_dict is unaffected.
            self.nl = torch.nn.ReLU(inplace=False)
        if self.scale != 1:
            if self.mode == "bilinear":
                # down{i} is never called in bilinear mode; keep the historical
                # (unscaled) shapes so existing checkpoints still load.
                down_channels = in_channels_list
            else:
                # down{i} receives x with the scaled channel count, so its conv
                # must be built for that width (previously it used the unscaled
                # width and failed whenever c_mult + c_add != 1).
                down_channels = self.in_channels_list
            for i, channels in enumerate(down_channels):
                setattr(self, f"down{i}", downsample_strideconv(channels, channels, bias=False, mode=str(self.scale)))

    def forward(self, x):
        """Optionally downsample ``x``, then apply the head matching its width.

        Raises:
            ValueError: if ``x.size(1)`` is not an accepted channel count.
        """
        # Select the head/down index from the incoming channel count.
        i = self.in_channels_list.index(x.size(1))
        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(x, scale_factor=1 / self.scale, mode="bilinear", align_corners=False)
            else:
                x = getattr(self, f"down{i}")(x)
                x = self.nl(x)
        return getattr(self, f"head{i}")(x)
class Tails(torch.nn.Module):
    """Bank of ``HeadBlock`` output tails, one per requested output width.

    Tail ``i`` maps ``in_channels`` to ``out_channels_list[i] * c_mult``
    channels. When ``scale != 1`` its output is then upsampled by ``scale``:
    bilinearly when ``mode == "bilinear"``, otherwise with the learned
    ``up{i}`` module.

    Args:
        in_channels: input channel count of every tail.
        out_channels_list: supported output widths (before ``c_mult``).
        depth, bias, relu_in, skip_in: forwarded to ``HeadBlock``.
        scale: upsampling factor applied after the tail (1 disables it).
        mode: ``"bilinear"`` for interpolation, anything else for ``up{i}``.
        c_mult: multiplier applied to every output width.
    """

    def __init__(self, in_channels, out_channels_list, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, relu_in=False, skip_in=False):
        super().__init__()
        self.out_channels_list = out_channels_list
        self.scale = scale
        widths = [n_out * c_mult for n_out in out_channels_list]
        for idx, width in enumerate(widths):
            block = HeadBlock(in_channels, width, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in)
            setattr(self, f"tail{idx}", block)
        self.mode = mode
        if mode == "":
            # Kept for attribute compatibility; forward() does not use it.
            self.nl = torch.nn.ReLU(inplace=False)
        if scale != 1:
            for idx, width in enumerate(widths):
                setattr(self, f"up{idx}", upsample_convtranspose(width, width, bias=bias, mode=str(scale)))

    def forward(self, x, out_channels):
        """Apply the tail producing ``out_channels`` widths, then upsample.

        Raises:
            ValueError: if ``out_channels`` is not in ``out_channels_list``.
        """
        idx = self.out_channels_list.index(out_channels)
        y = getattr(self, f"tail{idx}")(x)
        if self.scale == 1:
            return y
        if self.mode == "bilinear":
            return torch.nn.functional.interpolate(y, scale_factor=self.scale, mode="bilinear", align_corners=False)
        return getattr(self, f"up{idx}")(y)
class ConvChannels(torch.nn.Module):
    """Conv -> ReLU -> conv using the weights registered for the input width.

    One pair of 3x3, padding-1 convolutions (``channels -> channels``) is kept
    per entry of ``channels_list``; forward picks the pair whose channel count
    equals ``x.shape[1]``.

    TODO: replace this with convconv.

    Args:
        channels_list: supported channel counts; entry ``i`` owns the
            submodules ``conv{i}_1``, ``nl{i}`` and ``conv{i}_2``.
        depth: accepted for interface compatibility; not used.
        bias: whether the convolutions have a bias term.
        residual: if True, the input is added to the output.
    """

    def __init__(self, channels_list, depth=2, bias=False, residual=False):
        super().__init__()
        self.channels_list = channels_list
        self.residual = residual
        for idx, ch in enumerate(channels_list):
            setattr(self, f"conv{idx}_1", torch.nn.Conv2d(ch, ch, 3, bias=bias, padding=1))
            setattr(self, f"nl{idx}", torch.nn.ReLU())
            setattr(self, f"conv{idx}_2", torch.nn.Conv2d(ch, ch, 3, bias=bias, padding=1))

    def forward(self, x):
        """Apply the conv pair matching ``x.shape[1]``.

        Raises:
            ValueError: if ``x.shape[1]`` is not in ``channels_list``.
        """
        idx = self.channels_list.index(x.shape[1])
        first = getattr(self, f"conv{idx}_1")
        act = getattr(self, f"nl{idx}")
        second = getattr(self, f"conv{idx}_2")
        out = second(act(first(x)))
        return x + out if self.residual else out
class HeadBlock(torch.nn.Module):
    """Entry convolution followed by ``depth - 1`` residual conv stages.

    With ``depth < 2`` the entry conv maps straight to ``out_channels`` and no
    stages follow. Otherwise the entry conv keeps ``in_channels`` and each
    stage computes ``conv2(relu(conv1(h))) + skipconv(h)``; only the last
    stage changes the width to ``out_channels``.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count of the output.
        kernel_size: kernel size of the non-skip convs (padding is
            ``kernel_size // 2``).
        bias: whether the non-skip convs have a bias term.
        depth: number of conv "levels" (see above).
        relu_in: apply a ReLU to the entry conv output. NOTE(review): this
            only has an effect when ``skip_in`` is also True.
        skip_in: add a 1x1 bias-free projection of the input to the entry
            conv output.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, depth=2, relu_in=False, skip_in=False):
        super().__init__()
        pad = kernel_size // 2
        entry_width = in_channels if depth >= 2 else out_channels
        # Submodule creation order is kept stable so seeded initialisation
        # and state_dict layout do not change.
        self.convin = torch.nn.Conv2d(in_channels, entry_width, kernel_size, padding=pad, bias=bias)
        self.zero_conv_skip = torch.nn.Conv2d(in_channels, entry_width, 1, bias=False)
        self.depth = depth
        self.nl_1 = torch.nn.ReLU(inplace=False)
        self.nl_2 = torch.nn.ReLU(inplace=False)
        self.relu_in = relu_in
        self.skip_in = skip_in
        n_stages = depth - 1
        for stage in range(n_stages):
            stage_out = out_channels if stage == n_stages - 1 else in_channels
            setattr(self, f"conv1{stage}", torch.nn.Conv2d(in_channels, in_channels, kernel_size, padding=pad, bias=bias))
            setattr(self, f"conv2{stage}", torch.nn.Conv2d(in_channels, stage_out, kernel_size, padding=pad, bias=bias))
            setattr(self, f"skipconv{stage}", torch.nn.Conv2d(in_channels, stage_out, 1, bias=False))

    def forward(self, x):
        """Run the entry conv and the residual stages on ``x``."""
        h = self.convin(x)
        if self.skip_in:
            if self.relu_in:
                h = self.nl_1(h)
            h = h + self.zero_conv_skip(x)
        for stage in range(self.depth - 1):
            inner = self.nl_2(getattr(self, f"conv1{stage}")(h))
            branch = getattr(self, f"conv2{stage}")(inner)
            h = branch + getattr(self, f"skipconv{stage}")(h)
        return h
class SNRModule(torch.nn.Module):
    """Noise-conditioned conv net reduced to one normalised score per channel.

    A constant map filled with ``sigma`` is concatenated to the input as an
    extra channel, then conv(3x3) -> ReLU -> conv(3x3) is applied with the
    weights registered for the input width. Each output channel is divided by
    its root-mean-square over the last two (spatial) dims, and the mean of its
    absolute value over those dims is returned.

    Args:
        channels_list: supported input channel counts (without the noise
            channel); entry ``i`` owns ``conv{i}_1``, ``nl{i}``, ``conv{i}_2``.
        out_channels: number of output scores per sample.
        bias: whether the convolutions have a bias term.
        residual: stored but not used by ``forward``.
        features: hidden width between the two convolutions.
    """

    def __init__(self, channels_list, out_channels, bias=False, residual=False, features=64):
        super(SNRModule, self).__init__()
        self.channels_list = channels_list
        self.residual = residual
        for i, channels in enumerate(channels_list):
            # +1 input channel for the noise-level map built in forward().
            setattr(self, f"conv{i}_1", torch.nn.Conv2d(channels + 1, features, 3, bias=bias, padding=1))
            setattr(self, f"nl{i}", torch.nn.ReLU())
            setattr(self, f"conv{i}_2", torch.nn.Conv2d(features, out_channels, 3, bias=bias, padding=1))

    def forward(self, x0, sigma):
        """Return a ``(B, out_channels, 1, 1)`` tensor of normalised scores.

        Args:
            x0: 4-D input whose channel count is in ``channels_list``.
            sigma: noise level; a scalar or anything broadcastable against a
                ``(B, 1, H, W)`` tensor.

        Raises:
            ValueError: if ``x0.shape[1]`` is not in ``channels_list``.
        """
        i = self.channels_list.index(x0.shape[1])
        # new_ones matches x0's dtype and device. torch.ones would always be
        # float32, so half/bfloat16 inputs got promoted to float32 by cat and
        # no longer matched the module's weights.
        noise_level_map = x0.new_ones((x0.size(0), 1, x0.size(2), x0.size(3))) * sigma
        x = torch.cat((x0, noise_level_map), 1)
        u = getattr(self, f"conv{i}_1")(x)
        u = getattr(self, f"nl{i}")(u)
        u = getattr(self, f"conv{i}_2")(u)
        # Per-(sample, channel) RMS over the last two dims.
        den = u.pow(2).mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True).sqrt()
        u = u.abs() / (den + 1e-8)  # epsilon guards against all-zero channels
        return u.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
class EquivConvModule(torch.nn.Module):
    """Two-layer 3x3 conv net whose weights are chosen by the input width.

    Accepted input widths are the entries of ``channels_list`` multiplied by
    ``N``. For the matching entry the net computes
    ``conv{i}_2(relu(conv{i}_1(x)))``, where ``conv{i}_1`` keeps the width and
    ``conv{i}_2`` maps it to ``out_channels``.

    Args:
        channels_list: base channel counts (multiplied by ``N``).
        out_channels: output channel count.
        bias: whether the convolutions have a bias term.
        residual: stored but not used by ``forward``.
        features: accepted for interface compatibility; not used.
        N: width multiplier applied to every entry of ``channels_list``.
    """

    def __init__(self, channels_list, out_channels, bias=False, residual=False, features=64, N=1):
        super().__init__()
        widths = [c * N for c in channels_list]
        self.channels_list = widths
        self.residual = residual
        for idx, width in enumerate(widths):
            setattr(self, f"conv{idx}_1", torch.nn.Conv2d(width, width, 3, bias=bias, padding=1))
            setattr(self, f"nl{idx}", torch.nn.ReLU())
            setattr(self, f"conv{idx}_2", torch.nn.Conv2d(width, out_channels, 3, bias=bias, padding=1))

    def forward(self, x):
        """Apply the conv stack matching ``x.shape[1]``.

        Raises:
            ValueError: if ``x.shape[1]`` is not an accepted width.
        """
        idx = self.channels_list.index(x.shape[1])
        hidden = getattr(self, f"nl{idx}")(getattr(self, f"conv{idx}_1")(x))
        return getattr(self, f"conv{idx}_2")(hidden)
class EquivHeads(torch.nn.Module):
    """Noise-conditioned input heads, one ``HeadBlock`` per accepted width.

    The head is chosen by the input's channel count. When ``scale != 1`` the
    input is first downsampled by ``scale``: bilinearly when
    ``mode == "bilinear"``, otherwise with the learned strided conv
    ``down{i}`` followed by a ReLU. A constant map filled with ``sigma`` is
    then concatenated as an extra channel before the head is applied.

    Args:
        in_channels_list: accepted input channel counts (without the noise
            channel).
        out_channels: output channel count of every head.
        depth, bias: forwarded to ``HeadBlock``.
        scale: downsampling factor (1 disables it).
        mode: ``"bilinear"`` for interpolation, anything else for ``down{i}``.
    """

    def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear"):
        super(EquivHeads, self).__init__()
        self.in_channels_list = in_channels_list
        self.scale = scale
        self.mode = mode
        for i, in_channels in enumerate(in_channels_list):
            # +1 input channel for the noise-level map concatenated in forward().
            setattr(self, f"head{i}", HeadBlock(in_channels + 1, out_channels, depth=depth, bias=bias))
        if self.mode != "bilinear":
            # forward() applies self.nl after every learned downsampling, so it
            # must exist for every non-bilinear mode (previously only for
            # mode == ""). ReLU has no parameters; the state_dict is unchanged.
            self.nl = torch.nn.ReLU(inplace=False)
        if self.scale != 1:
            for i, in_channels in enumerate(in_channels_list):
                setattr(self, f"down{i}", downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))

    def forward(self, x, sigma):
        """Downsample (optionally), append the noise map, and apply the head.

        Args:
            x: 4-D input whose channel count is in ``in_channels_list``.
            sigma: noise level; a scalar or anything broadcastable against a
                ``(B, 1, H, W)`` tensor.

        Raises:
            ValueError: if ``x.size(1)`` is not in ``in_channels_list``.
        """
        i = self.in_channels_list.index(x.size(1))
        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(x, scale_factor=1 / self.scale, mode="bilinear", align_corners=False)
            else:
                x = getattr(self, f"down{i}")(x)
                x = self.nl(x)
        # new_ones matches x's dtype and device. torch.ones would be float32
        # and break half/bfloat16 models after torch.cat's type promotion.
        noise_level_map = x.new_ones((x.size(0), 1, x.size(2), x.size(3))) * sigma
        x = torch.cat((x, noise_level_map), 1)
        return getattr(self, f"head{i}")(x)