denoising / models /heads.py
Yonuts's picture
gradio demo
12a4d59
raw
history blame
10.8 kB
import torch
from models.blocks import AffineConv2d, downsample_strideconv, upsample_convtranspose
class InHead(torch.nn.Module):
    """Input head holding one 3x3 ``AffineConv2d`` per supported channel count.

    The convolution applied in ``forward`` is picked by matching the input's
    channel dimension against ``in_channels_list``.

    Args:
        in_channels_list: channel counts this head accepts; entry ``k`` owns
            the convolution registered as attribute ``conv{k}``.
        out_channels: channel count produced by every convolution.
        mode: forwarded unchanged to ``AffineConv2d``.
        bias: forwarded unchanged to ``AffineConv2d``.
        input_layer: when True, one channel is discounted from ``x.size(1)``
            before the lookup.
    """

    def __init__(self, in_channels_list, out_channels, mode="", bias=False, input_layer=False):
        super().__init__()
        self.in_channels_list = in_channels_list
        self.input_layer = input_layer
        for k, n_in in enumerate(in_channels_list):
            layer = AffineConv2d(
                in_channels=n_in,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{k}", layer)

    def forward(self, x):
        """Apply the convolution registered for ``x``'s channel count.

        Raises:
            ValueError: if the (possibly adjusted) channel count is not in
                ``in_channels_list``.
        """
        n_in = x.size(1)
        if self.input_layer:
            # NOTE(review): the conv was built for n_in channels, yet here the
            # tensor carries one more; presumably AffineConv2d or the caller
            # accounts for the extra channel — confirm.
            n_in -= 1
        k = self.in_channels_list.index(n_in)
        return getattr(self, f"conv{k}")(x)
class OutTail(torch.nn.Module):
    """Output tail holding one 3x3 ``AffineConv2d`` per requested output width.

    The caller chooses which convolution to apply by passing the desired
    ``out_channels`` to ``forward``.

    Args:
        in_channels: channel count every convolution consumes.
        out_channels_list: supported output widths; entry ``k`` owns the
            convolution registered as attribute ``conv{k}``.
        mode: forwarded unchanged to ``AffineConv2d``.
        bias: forwarded unchanged to ``AffineConv2d``.
    """

    def __init__(self, in_channels, out_channels_list, mode="", bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels_list = out_channels_list
        for k, n_out in enumerate(out_channels_list):
            layer = AffineConv2d(
                in_channels=in_channels,
                out_channels=n_out,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{k}", layer)

    def forward(self, x, out_channels):
        """Apply the convolution producing ``out_channels`` channels.

        Raises:
            ValueError: if ``out_channels`` is not in ``out_channels_list``.
        """
        layer = getattr(self, f"conv{self.out_channels_list.index(out_channels)}")
        return layer(x)
# TODO: check that the heads are compatible with the old implementation
class Heads(torch.nn.Module):
    """Channel-count-dispatched input heads with optional spatial downsampling.

    One ``HeadBlock`` is built per entry of ``in_channels_list``, after each
    entry is multiplied by ``c_mult + c_add``. In ``forward`` the head is
    selected by matching ``x.size(1)`` against those scaled counts.

    Args:
        in_channels_list: base channel counts; the accepted counts are
            ``c * (c_mult + c_add)`` for each ``c``.
        out_channels: output width of every ``HeadBlock``.
        depth, bias, relu_in, skip_in: forwarded to ``HeadBlock``.
        scale: when != 1 the input is spatially reduced before the head.
        mode: ``"bilinear"`` reduces via ``interpolate(scale_factor=1/scale)``;
            any other value uses the ``down{i}`` conv followed by a ReLU.
        c_mult, c_add: channel multipliers described above.
    """

    def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, c_add=0, relu_in=False, skip_in=False):
        super(Heads, self).__init__()
        # Channel counts actually seen by forward.
        self.in_channels_list = [c * (c_mult + c_add) for c in in_channels_list]
        self.scale = scale
        self.mode = mode
        for i, in_channels in enumerate(self.in_channels_list):
            setattr(self, f"head{i}", HeadBlock(in_channels, out_channels, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))
        # forward applies self.nl after every non-bilinear downsample. It used to
        # exist only for mode == "", so any other non-bilinear mode raised
        # AttributeError. ReLU has no parameters, so registering it always
        # leaves the state_dict unchanged.
        self.nl = torch.nn.ReLU(inplace=False)
        if self.scale != 1:
            # NOTE(review): these convs use the *unscaled* channel counts, while
            # forward feeds them tensors with the scaled counts from
            # self.in_channels_list; the two only agree when c_mult + c_add == 1.
            # Shapes are kept as-is so existing checkpoints still load — confirm.
            for i, in_channels in enumerate(in_channels_list):
                setattr(self, f"down{i}", downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))

    def forward(self, x):
        """Optionally downsample ``x``, then apply the head matching its channels.

        Raises:
            ValueError: if ``x.size(1)`` is not one of the scaled channel counts.
        """
        # Find the index of the head (and downsampler) for this channel count.
        i = self.in_channels_list.index(x.size(1))
        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(x, scale_factor=1 / self.scale, mode='bilinear', align_corners=False)
            else:
                x = getattr(self, f"down{i}")(x)
                x = self.nl(x)
        return getattr(self, f"head{i}")(x)
class Tails(torch.nn.Module):
    """Output tails selected by requested width, with optional upsampling.

    For every entry ``c`` of ``out_channels_list`` a ``HeadBlock`` mapping
    ``in_channels`` to ``c * c_mult`` channels is registered as ``tail{k}``.

    Args:
        in_channels: channel count every tail consumes.
        out_channels_list: selectable output widths (before ``c_mult``).
        depth, bias, relu_in, skip_in: forwarded to ``HeadBlock``.
        scale: when != 1 the tail output is spatially enlarged by this factor.
        mode: ``"bilinear"`` enlarges via ``interpolate``; any other value uses
            the ``up{k}`` transposed-conv module.
        c_mult: multiplier applied to every output width.
    """

    def __init__(self, in_channels, out_channels_list, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, relu_in=False, skip_in=False):
        super().__init__()
        self.out_channels_list = out_channels_list
        self.scale = scale
        for k, base in enumerate(out_channels_list):
            width = base * c_mult
            setattr(self, f"tail{k}", HeadBlock(in_channels, width, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))
        self.mode = mode
        if mode == "":
            # Registered for mode == "" but not applied anywhere in forward.
            self.nl = torch.nn.ReLU(inplace=False)
        if scale != 1:
            for k, base in enumerate(out_channels_list):
                width = base * c_mult
                setattr(self, f"up{k}", upsample_convtranspose(width, width, bias=bias, mode=str(scale)))

    def forward(self, x, out_channels):
        """Apply the tail producing ``out_channels`` (times ``c_mult``), then upsample.

        Raises:
            ValueError: if ``out_channels`` is not in ``out_channels_list``.
        """
        k = self.out_channels_list.index(out_channels)
        y = getattr(self, f"tail{k}")(x)
        if self.scale == 1:
            return y
        if self.mode == "bilinear":
            return torch.nn.functional.interpolate(y, scale_factor=self.scale, mode='bilinear', align_corners=False)
        return getattr(self, f"up{k}")(y)
class ConvChannels(torch.nn.Module):
    """Two-layer conv stack dispatched on the input's channel count.

    TODO: replace this with convconv.

    For every entry ``ch`` of ``channels_list`` a ``conv -> ReLU -> conv``
    stack (3x3 kernels, padding 1, ``ch`` channels in and out) is registered
    as ``conv{k}_1`` / ``nl{k}`` / ``conv{k}_2``. ``forward`` uses the stack
    whose ``ch`` equals ``x.shape[1]``.

    Args:
        channels_list: supported channel counts.
        depth: accepted for interface compatibility; not used.
        bias: whether the convolutions have a bias term.
        residual: when True, the input is added to the stack's output.
    """

    def __init__(self, channels_list, depth=2, bias=False, residual=False):
        super().__init__()
        self.channels_list = channels_list
        self.residual = residual
        for k, ch in enumerate(channels_list):
            first = torch.nn.Conv2d(ch, ch, 3, bias=bias, padding=1)
            act = torch.nn.ReLU()
            second = torch.nn.Conv2d(ch, ch, 3, bias=bias, padding=1)
            setattr(self, f"conv{k}_1", first)
            setattr(self, f"nl{k}", act)
            setattr(self, f"conv{k}_2", second)

    def forward(self, x):
        """Run the stack matching ``x.shape[1]`` (plus ``x`` if residual).

        Raises:
            ValueError: if ``x.shape[1]`` is not in ``channels_list``.
        """
        k = self.channels_list.index(x.shape[1])
        out = x
        for name in (f"conv{k}_1", f"nl{k}", f"conv{k}_2"):
            out = getattr(self, name)(out)
        return x + out if self.residual else out
class HeadBlock(torch.nn.Module):
    """Input conv followed by ``depth - 1`` residual two-conv stages.

    Layout:
      * ``convin``: ``in_channels`` -> ``in_channels`` (or ``out_channels``
        when ``depth < 2``), optionally combined with a 1x1 ``zero_conv_skip``
        of the raw input.
      * for each stage ``s``: ``conv2{s}(relu(conv1{s}(h))) + skipconv{s}(h)``.
        Every stage keeps ``in_channels`` except the last, which maps to
        ``out_channels``.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count of the output.
        kernel_size: kernel of the non-1x1 convs; padding is ``kernel_size // 2``.
        bias: bias flag for the non-1x1 convs (1x1 skip convs never have bias).
        depth: ``depth - 1`` residual stages are built.
        relu_in: apply ReLU to ``convin``'s output (see note in ``forward``).
        skip_in: add ``zero_conv_skip(x)`` to the input conv's output.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, depth=2, relu_in=False, skip_in=False):
        super().__init__()
        pad = kernel_size // 2
        # Without residual stages the input conv must already emit out_channels.
        stem_width = in_channels if depth >= 2 else out_channels
        self.convin = torch.nn.Conv2d(in_channels, stem_width, kernel_size, padding=pad, bias=bias)
        self.zero_conv_skip = torch.nn.Conv2d(in_channels, stem_width, 1, bias=False)
        self.depth = depth
        self.nl_1 = torch.nn.ReLU(inplace=False)
        self.nl_2 = torch.nn.ReLU(inplace=False)
        self.relu_in = relu_in
        self.skip_in = skip_in
        n_stages = depth - 1
        for s in range(n_stages):
            # Only the final stage changes the width to out_channels.
            stage_out = out_channels if s == n_stages - 1 else in_channels
            setattr(self, f"conv1{s}", torch.nn.Conv2d(in_channels, in_channels, kernel_size, padding=pad, bias=bias))
            setattr(self, f"conv2{s}", torch.nn.Conv2d(in_channels, stage_out, kernel_size, padding=pad, bias=bias))
            setattr(self, f"skipconv{s}", torch.nn.Conv2d(in_channels, stage_out, 1, bias=False))

    def forward(self, x):
        """Apply the input conv and the residual stages to ``x``."""
        h = self.convin(x)
        # NOTE(review): relu_in only takes effect together with skip_in; with
        # skip_in=False the input conv output is used as-is — confirm intended.
        if self.skip_in:
            if self.relu_in:
                h = self.nl_1(h)
            h = h + self.zero_conv_skip(x)
        for s in range(self.depth - 1):
            branch = getattr(self, f"conv2{s}")(self.nl_2(getattr(self, f"conv1{s}")(h)))
            h = branch + getattr(self, f"skipconv{s}")(h)
        return h
class SNRModule(torch.nn.Module):
    """Per-channel mean(|u|) / RMS(u) of a small CNN applied to input + noise map.

    For each entry ``ch`` of ``channels_list`` a ``conv -> ReLU -> conv``
    stack (``ch + 1`` -> ``features`` -> ``out_channels``, 3x3, padding 1) is
    registered as ``conv{k}_1`` / ``nl{k}`` / ``conv{k}_2``.

    Args:
        channels_list: supported channel counts of ``x0``.
        out_channels: number of output channels of the second conv.
        bias: bias flag for both convs.
        residual: stored on the module; not used by ``forward``.
        features: hidden width between the two convs.
    """

    def __init__(self, channels_list, out_channels, bias=False, residual=False, features=64):
        super().__init__()
        self.channels_list = channels_list
        self.residual = residual
        for k, ch in enumerate(channels_list):
            # +1 input channel for the noise-level map appended in forward.
            setattr(self, f"conv{k}_1", torch.nn.Conv2d(ch + 1, features, 3, bias=bias, padding=1))
            setattr(self, f"nl{k}", torch.nn.ReLU())
            setattr(self, f"conv{k}_2", torch.nn.Conv2d(features, out_channels, 3, bias=bias, padding=1))

    def forward(self, x0, sigma):
        """Return ``mean(|u|) / (RMS(u) + 1e-8)`` over the last two dims, keepdim.

        ``sigma`` is multiplied into a ``(B, 1, H, W)`` tensor of ones, so it
        must broadcast against that shape (e.g. a Python float).

        Raises:
            ValueError: if ``x0.shape[1]`` is not in ``channels_list``.
        """
        def spatial_mean(t):
            return t.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)

        k = self.channels_list.index(x0.shape[1])
        sigma_map = torch.ones((x0.size(0), 1, x0.size(2), x0.size(3)), device=x0.device) * sigma
        u = getattr(self, f"conv{k}_1")(torch.cat((x0, sigma_map), 1))
        u = getattr(self, f"nl{k}")(u)
        u = getattr(self, f"conv{k}_2")(u)
        rms = spatial_mean(u.pow(2)).sqrt()
        return spatial_mean(u.abs() / (rms + 1e-8))
class EquivConvModule(torch.nn.Module):
    """Two-layer conv stack dispatched on channel count, with widths scaled by ``N``.

    For each entry ``ch`` of ``channels_list`` a stack
    ``conv(ch*N -> ch*N) -> ReLU -> conv(ch*N -> out_channels)`` (3x3,
    padding 1) is registered as ``conv{k}_1`` / ``nl{k}`` / ``conv{k}_2``.
    ``self.channels_list`` stores the scaled counts ``ch * N``.

    Args:
        channels_list: base channel counts.
        out_channels: output width of the second conv.
        bias: bias flag for both convs.
        residual: stored on the module; not used by ``forward``.
        features: accepted for interface compatibility; not used.
        N: multiplier applied to every channel count.
    """

    def __init__(self, channels_list, out_channels, bias=False, residual=False, features=64, N=1):
        super().__init__()
        self.channels_list = [ch * N for ch in channels_list]
        self.residual = residual
        for k, ch in enumerate(channels_list):
            width = ch * N
            setattr(self, f"conv{k}_1", torch.nn.Conv2d(width, width, 3, bias=bias, padding=1))
            setattr(self, f"nl{k}", torch.nn.ReLU())
            setattr(self, f"conv{k}_2", torch.nn.Conv2d(width, out_channels, 3, bias=bias, padding=1))

    def forward(self, x):
        """Run the stack whose scaled width equals ``x.shape[1]``.

        Raises:
            ValueError: if ``x.shape[1]`` is not in ``self.channels_list``.
        """
        k = self.channels_list.index(x.shape[1])
        hidden = getattr(self, f"nl{k}")(getattr(self, f"conv{k}_1")(x))
        return getattr(self, f"conv{k}_2")(hidden)
class EquivHeads(torch.nn.Module):
    """Input heads that append a constant noise-level channel before the head.

    One ``HeadBlock`` taking ``in_channels + 1`` channels is built per entry
    of ``in_channels_list``. The extra channel is filled with ``sigma`` in
    ``forward``.

    Args:
        in_channels_list: supported input channel counts.
        out_channels: output width of every ``HeadBlock``.
        depth, bias: forwarded to ``HeadBlock``.
        scale: when != 1 the input is spatially reduced before the head.
        mode: ``"bilinear"`` reduces via ``interpolate(scale_factor=1/scale)``;
            any other value uses the ``down{i}`` conv followed by a ReLU.
    """

    def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear"):
        super(EquivHeads, self).__init__()
        self.in_channels_list = in_channels_list
        self.scale = scale
        self.mode = mode
        for i, in_channels in enumerate(in_channels_list):
            # +1 input channel for the noise-level map concatenated in forward.
            setattr(self, f"head{i}", HeadBlock(in_channels + 1, out_channels, depth=depth, bias=bias))
        # forward applies self.nl after every non-bilinear downsample. It used to
        # exist only for mode == "", so any other non-bilinear mode raised
        # AttributeError. ReLU has no parameters, so registering it always
        # leaves the state_dict unchanged.
        self.nl = torch.nn.ReLU(inplace=False)
        if self.scale != 1:
            for i, in_channels in enumerate(in_channels_list):
                setattr(self, f"down{i}", downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))

    def forward(self, x, sigma):
        """Optionally downsample ``x``, append a ``sigma`` map, then apply the head.

        ``sigma`` is multiplied into a ``(B, 1, H, W)`` tensor of ones, so it
        must broadcast against that shape (e.g. a Python float).

        Raises:
            ValueError: if ``x.size(1)`` is not in ``in_channels_list``.
        """
        # Find the index of the head (and downsampler) for this channel count.
        i = self.in_channels_list.index(x.size(1))
        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(x, scale_factor=1 / self.scale, mode='bilinear', align_corners=False)
            else:
                x = getattr(self, f"down{i}")(x)
                x = self.nl(x)
        # Concatenate the constant noise-level map as an extra channel.
        noise_level_map = torch.ones((x.size(0), 1, x.size(2), x.size(3)), device=x.device) * sigma
        x = torch.cat((x, noise_level_map), 1)
        return getattr(self, f"head{i}")(x)