import torch

from models.blocks import AffineConv2d, downsample_strideconv, upsample_convtranspose


def _channel_index(channels_list, channels, owner):
    """Return the position of ``channels`` in ``channels_list``.

    Every module in this file keeps one sub-network per supported channel
    count and selects it by looking that count up in a list.

    Raises:
        ValueError: if ``channels`` is not registered. This is the same
            exception type ``list.index`` raises, but the message names the
            module and the supported channel counts.
    """
    try:
        return channels_list.index(channels)
    except ValueError as err:
        raise ValueError(
            f"{owner}: no branch registered for {channels} channels "
            f"(supported: {list(channels_list)})"
        ) from err


def _noise_level_map(x, sigma):
    """Build a one-channel map filled with ``sigma`` matching ``x`` spatially.

    ``x`` is indexed on dimensions 0, 2 and 3, so it must be at least 4-D.
    The map is created on ``x.device`` with torch's default dtype and then
    multiplied by ``sigma``.

    NOTE(review): ``sigma`` must broadcast against a (B, 1, H, W) tensor; a
    per-sample 1-D tensor presumably has to be reshaped to (B, 1, 1, 1) by
    the caller -- confirm against call sites.
    """
    ones = torch.ones((x.size(0), 1, x.size(2), x.size(3)), device=x.device)
    return ones * sigma


class InHead(torch.nn.Module):
    """One 3x3 ``AffineConv2d`` per supported input channel count.

    The convolutions are registered as ``conv0``, ``conv1``, ... in the order
    of ``in_channels_list``; ``forward`` picks the one matching its input.

    Args:
        in_channels_list: supported channel counts.
        out_channels: number of output channels of every branch.
        mode: forwarded to ``AffineConv2d`` (project-defined).
        bias: forwarded to ``AffineConv2d``.
        input_layer: when true, ``forward`` subtracts one from the input's
            channel count before the lookup (the extra channel is
            presumably a noise-level map -- TODO confirm).
    """

    def __init__(self, in_channels_list, out_channels, mode="", bias=False, input_layer=False):
        super().__init__()
        self.in_channels_list = in_channels_list
        self.input_layer = input_layer
        for i, in_channels in enumerate(in_channels_list):
            conv = AffineConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{i}", conv)

    def forward(self, x):
        """Apply the branch selected by the channel dimension of ``x``.

        Raises:
            ValueError: if the (possibly decremented) channel count is not in
                ``in_channels_list``.
        """
        in_channels = x.size(1)
        if self.input_layer:
            in_channels = in_channels - 1
        i = _channel_index(self.in_channels_list, in_channels, "InHead")
        return getattr(self, f"conv{i}")(x)


class OutTail(torch.nn.Module):
    """One 3x3 ``AffineConv2d`` per supported output channel count.

    Branches are registered as ``conv0``, ``conv1``, ... in the order of
    ``out_channels_list`` and all consume ``in_channels`` channels.
    """

    def __init__(self, in_channels, out_channels_list, mode="", bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels_list = out_channels_list
        for i, out_channels in enumerate(out_channels_list):
            conv = AffineConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{i}", conv)

    def forward(self, x, out_channels):
        """Apply the branch that produces ``out_channels`` channels.

        Raises:
            ValueError: if ``out_channels`` is not in ``out_channels_list``.
        """
        i = _channel_index(self.out_channels_list, out_channels, "OutTail")
        return getattr(self, f"conv{i}")(x)


# TODO: check that the heads are compatible with the old implementation
class Heads(torch.nn.Module):
    """Per-channel-count input heads with optional spatial downsampling.

    The channel counts actually matched in ``forward`` are
    ``c * (c_mult + c_add)`` for each ``c`` in ``in_channels_list``.

    Args:
        in_channels_list: base channel counts.
        out_channels: output channels of every ``HeadBlock``.
        depth, bias, relu_in, skip_in: forwarded to ``HeadBlock``.
        scale: when different from 1 the input is downsampled before the
            head, either by bilinear interpolation (``mode == "bilinear"``)
            or by a per-branch ``downsample_strideconv`` followed by a ReLU
            (any other ``mode``).
        mode: resampling mode, see ``scale``.
        c_mult, c_add: channel-count scaling described above.
    """

    def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True,
                 mode="bilinear", c_mult=1, c_add=0, relu_in=False, skip_in=False):
        super().__init__()
        self.in_channels_list = [c * (c_mult + c_add) for c in in_channels_list]
        self.scale = scale
        self.mode = mode
        for i, in_channels in enumerate(self.in_channels_list):
            head = HeadBlock(in_channels, out_channels, depth=depth, bias=bias,
                             relu_in=relu_in, skip_in=skip_in)
            setattr(self, f"head{i}", head)

        # Mirror the branch taken in ``forward``: every non-bilinear mode goes
        # through the strided convolution followed by ``self.nl``.
        if self.mode != "bilinear":
            self.nl = torch.nn.ReLU(inplace=False)
            if self.scale != 1:
                # ``forward`` feeds these convolutions tensors whose channel
                # count comes from the *scaled* list, so they must be sized
                # from it as well (identical to the raw list by default).
                for i, in_channels in enumerate(self.in_channels_list):
                    down = downsample_strideconv(in_channels, in_channels, bias=False,
                                                 mode=str(self.scale))
                    setattr(self, f"down{i}", down)

    def forward(self, x):
        """Optionally downsample ``x`` and apply the matching head.

        Raises:
            ValueError: if ``x.size(1)`` is not a supported channel count.
        """
        i = _channel_index(self.in_channels_list, x.size(1), "Heads")
        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(
                    x, scale_factor=1 / self.scale, mode='bilinear', align_corners=False)
            else:
                x = getattr(self, f"down{i}")(x)
                x = self.nl(x)
        return getattr(self, f"head{i}")(x)


class Tails(torch.nn.Module):
    """Per-channel-count output tails with optional spatial upsampling.

    Each tail is a ``HeadBlock`` mapping ``in_channels`` to
    ``out_channels * c_mult`` channels. When ``scale != 1`` the result is
    upsampled by bilinear interpolation (``mode == "bilinear"``) or by a
    per-branch ``upsample_convtranspose`` (any other ``mode``).
    """

    def __init__(self, in_channels, out_channels_list, depth=2, scale=1, bias=True,
                 mode="bilinear", c_mult=1, relu_in=False, skip_in=False):
        super().__init__()
        self.out_channels_list = out_channels_list
        self.scale = scale
        for i, out_channels in enumerate(out_channels_list):
            tail = HeadBlock(in_channels, out_channels * c_mult, depth=depth, bias=bias,
                             relu_in=relu_in, skip_in=skip_in)
            setattr(self, f"tail{i}", tail)

        self.mode = mode
        # Same condition as the non-bilinear branch of ``forward``.
        if self.mode != "bilinear":
            # Registered for parity with ``Heads``; ``forward`` does not use it.
            self.nl = torch.nn.ReLU(inplace=False)
            if self.scale != 1:
                for i, out_channels in enumerate(out_channels_list):
                    up = upsample_convtranspose(out_channels * c_mult, out_channels * c_mult,
                                                bias=bias, mode=str(self.scale))
                    setattr(self, f"up{i}", up)

    def forward(self, x, out_channels):
        """Apply the tail registered for ``out_channels`` and upsample.

        Raises:
            ValueError: if ``out_channels`` is not in ``out_channels_list``.
        """
        i = _channel_index(self.out_channels_list, out_channels, "Tails")
        x = getattr(self, f"tail{i}")(x)
        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(
                    x, scale_factor=self.scale, mode='bilinear', align_corners=False)
            else:
                x = getattr(self, f"up{i}")(x)
        return x


class ConvChannels(torch.nn.Module):
    """
    TODO: replace this with convconv
    Channel-preserving conv -> ReLU -> conv branch, one per supported channel
    count, selected from the channel dimension of the input.

    Args:
        channels_list: supported channel counts.
        depth: accepted for interface compatibility; not used here.
        bias: whether the two 3x3 convolutions have a bias.
        residual: when true the input is added to the branch output.
    """

    def __init__(self, channels_list, depth=2, bias=False, residual=False):
        super().__init__()
        self.channels_list = channels_list
        self.residual = residual
        for i, channels in enumerate(channels_list):
            setattr(self, f"conv{i}_1", torch.nn.Conv2d(channels, channels, 3, bias=bias, padding=1))
            setattr(self, f"nl{i}", torch.nn.ReLU())
            setattr(self, f"conv{i}_2", torch.nn.Conv2d(channels, channels, 3, bias=bias, padding=1))

    def forward(self, x):
        """Run the branch matching ``x.shape[1]``.

        Raises:
            ValueError: if ``x.shape[1]`` is not in ``channels_list``.
        """
        i = _channel_index(self.channels_list, x.shape[1], "ConvChannels")
        u = getattr(self, f"conv{i}_1")(x)
        u = getattr(self, f"nl{i}")(u)
        u = getattr(self, f"conv{i}_2")(u)
        return x + u if self.residual else u


class HeadBlock(torch.nn.Module):
    """Input convolution followed by ``depth - 1`` skip-connected stages.

    Stage ``i`` computes ``conv2{i}(relu(conv1{i}(x))) + skipconv{i}(x)``
    where ``skipconv{i}`` is a bias-free 1x1 convolution. All stages keep
    ``in_channels`` channels except the last, which maps to ``out_channels``.
    With ``depth < 2`` there are no stages and the input convolution maps
    directly to ``out_channels``.

    Args:
        in_channels, out_channels: channel counts of the block.
        kernel_size: kernel size of the non-skip convolutions; padding is
            ``kernel_size // 2``.
        bias: bias flag of the non-skip convolutions.
        depth: number of stages plus one.
        relu_in: apply a ReLU after the input convolution. Only honoured
            when ``skip_in`` is also true (see ``forward``).
        skip_in: add ``zero_conv_skip(x)`` (bias-free 1x1 convolution) to
            the input convolution's output.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, depth=2,
                 relu_in=False, skip_in=False):
        super().__init__()
        padding = kernel_size // 2
        first_out = out_channels if depth < 2 else in_channels

        self.convin = torch.nn.Conv2d(in_channels, first_out, kernel_size, padding=padding, bias=bias)
        # Always registered, even when ``skip_in`` is false, so checkpoints
        # keep the same set of keys.
        self.zero_conv_skip = torch.nn.Conv2d(in_channels, first_out, 1, bias=False)
        self.depth = depth
        self.nl_1 = torch.nn.ReLU(inplace=False)
        self.nl_2 = torch.nn.ReLU(inplace=False)
        self.relu_in = relu_in
        self.skip_in = skip_in

        last_stage = depth - 2
        for i in range(depth - 1):
            stage_out = in_channels if i < last_stage else out_channels
            setattr(self, f"conv1{i}",
                    torch.nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, bias=bias))
            setattr(self, f"conv2{i}",
                    torch.nn.Conv2d(in_channels, stage_out, kernel_size, padding=padding, bias=bias))
            setattr(self, f"skipconv{i}", torch.nn.Conv2d(in_channels, stage_out, 1, bias=False))

    def forward(self, x):
        """Apply the input convolution and then every stage in order."""
        if self.skip_in:
            main = self.convin(x)
            if self.relu_in:
                main = self.nl_1(main)
            x = main + self.zero_conv_skip(x)
        else:
            # NOTE: ``relu_in`` is intentionally ignored without ``skip_in``;
            # this matches the behaviour existing checkpoints were trained with.
            x = self.convin(x)

        for i in range(self.depth - 1):
            hidden = self.nl_2(getattr(self, f"conv1{i}")(x))
            x = getattr(self, f"conv2{i}")(hidden) + getattr(self, f"skipconv{i}")(x)
        return x


class SNRModule(torch.nn.Module):
    """
    Per-channel-count conv -> ReLU -> conv branch applied to the input
    concatenated with a noise-level map, reduced to one value per output
    channel.

    Args:
        channels_list: supported input channel counts (without the extra
            noise-level channel).
        out_channels: channels produced by the second convolution.
        bias: bias flag of both convolutions.
        residual: stored but not used by ``forward``.
        features: hidden width between the two convolutions.
    """

    def __init__(self, channels_list, out_channels, bias=False, residual=False, features=64):
        super().__init__()
        self.channels_list = channels_list
        self.residual = residual
        for i, channels in enumerate(channels_list):
            # +1 input channel for the noise-level map concatenated in forward.
            setattr(self, f"conv{i}_1", torch.nn.Conv2d(channels + 1, features, 3, bias=bias, padding=1))
            setattr(self, f"nl{i}", torch.nn.ReLU())
            setattr(self, f"conv{i}_2", torch.nn.Conv2d(features, out_channels, 3, bias=bias, padding=1))

    def forward(self, x0, sigma):
        """Return the spatial mean of ``|u| / rms(u)`` for the branch output ``u``.

        The mean and the root-mean-square are both taken over the last two
        dimensions with ``keepdim=True``, so those dimensions have size 1 in
        the result.

        Raises:
            ValueError: if ``x0.shape[1]`` is not in ``channels_list``.
        """
        i = _channel_index(self.channels_list, x0.shape[1], "SNRModule")
        x = torch.cat((x0, _noise_level_map(x0, sigma)), 1)

        u = getattr(self, f"conv{i}_1")(x)
        u = getattr(self, f"nl{i}")(u)
        u = getattr(self, f"conv{i}_2")(u)

        rms = u.pow(2).mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True).sqrt()
        # The small constant guards against division by zero for an all-zero map.
        u = u.abs() / (rms + 1e-8)
        return u.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)


class EquivConvModule(torch.nn.Module):
    """
    Per-channel-count conv -> ReLU -> conv branch mapping ``channels * N``
    input channels to ``out_channels``.

    Args:
        channels_list: base channel counts; inputs are matched against
            ``c * N`` for each entry ``c``.
        out_channels: channels produced by the second convolution.
        bias: bias flag of both convolutions.
        residual: stored but not used by ``forward``.
        features: accepted for interface compatibility; not used here.
        N: channel multiplier.
    """

    def __init__(self, channels_list, out_channels, bias=False, residual=False, features=64, N=1):
        super().__init__()
        self.channels_list = [c * N for c in channels_list]
        self.residual = residual
        for i, channels in enumerate(channels_list):
            width = channels * N
            setattr(self, f"conv{i}_1", torch.nn.Conv2d(width, width, 3, bias=bias, padding=1))
            setattr(self, f"nl{i}", torch.nn.ReLU())
            setattr(self, f"conv{i}_2", torch.nn.Conv2d(width, out_channels, 3, bias=bias, padding=1))

    def forward(self, x):
        """Run the branch matching ``x.shape[1]``.

        Raises:
            ValueError: if ``x.shape[1]`` is not a supported channel count.
        """
        i = _channel_index(self.channels_list, x.shape[1], "EquivConvModule")
        u = getattr(self, f"conv{i}_1")(x)
        u = getattr(self, f"nl{i}")(u)
        return getattr(self, f"conv{i}_2")(u)


class EquivHeads(torch.nn.Module):
    """Input heads that append a noise-level channel before the ``HeadBlock``.

    Works like ``Heads`` without channel scaling: the input is optionally
    downsampled, a constant ``sigma`` map is concatenated as an extra
    channel, and the ``HeadBlock`` built for ``in_channels + 1`` channels is
    applied.
    """

    def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear"):
        super().__init__()
        self.in_channels_list = in_channels_list
        self.scale = scale
        self.mode = mode
        for i, in_channels in enumerate(in_channels_list):
            # +1 input channel for the noise-level map concatenated in forward.
            setattr(self, f"head{i}", HeadBlock(in_channels + 1, out_channels, depth=depth, bias=bias))

        # Mirror the branch taken in ``forward``: every non-bilinear mode goes
        # through the strided convolution followed by ``self.nl``.
        if self.mode != "bilinear":
            self.nl = torch.nn.ReLU(inplace=False)
            if self.scale != 1:
                for i, in_channels in enumerate(in_channels_list):
                    down = downsample_strideconv(in_channels, in_channels, bias=False,
                                                 mode=str(self.scale))
                    setattr(self, f"down{i}", down)

    def forward(self, x, sigma):
        """Downsample ``x`` if needed, append the ``sigma`` map, apply the head.

        Raises:
            ValueError: if ``x.size(1)`` is not in ``in_channels_list``.
        """
        i = _channel_index(self.in_channels_list, x.size(1), "EquivHeads")
        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(
                    x, scale_factor=1 / self.scale, mode='bilinear', align_corners=False)
            else:
                x = getattr(self, f"down{i}")(x)
                x = self.nl(x)

        # The map is built after resampling so it matches the new spatial size.
        x = torch.cat((x, _noise_level_map(x, sigma)), 1)
        return getattr(self, f"head{i}")(x)