import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.utils import uncenter_l


def find_local_patch(x, patch_size):
    """
    Extract all patches of size `patch_size` from `x` and return them as a tensor
    `x_unfold` of shape (N, C * patch_size ** 2, H, W).

    Args:
        x: the input tensor
        patch_size: the size of the patch to be extracted.
    """
    N, C, H, W = x.shape
    x_unfold = F.unfold(
        x, kernel_size=(patch_size, patch_size), padding=(patch_size // 2, patch_size // 2), stride=(1, 1)
    )
    return x_unfold.view(N, x_unfold.shape[1], H, W)


class WeightedAverage(nn.Module):
    def __init__(self):
        super(WeightedAverage, self).__init__()

    def forward(self, x_lab, patch_size=3, alpha=1, scale_factor=1):
        """
        Take a 3-channel image (L, A, B) and return a 2-channel image (A, B) where each
        pixel is a weighted average of the A and B values of the pixels in a
        patch_size x patch_size neighborhood around it.

        Args:
            x_lab: the input image in LAB color space
            patch_size: the size of the patch to use for the local average. Defaults to 3
            alpha: the higher the alpha, the smoother the output. Defaults to 1
            scale_factor: the scale factor of the input image. Defaults to 1

        Returns:
            A tensor of size (batch_size, 2, height, width)
        """
        # alpha=0: less smooth; alpha=inf: smoother
        x_lab = F.interpolate(x_lab, scale_factor=scale_factor)
        l = x_lab[:, 0:1, :, :]
        a = x_lab[:, 1:2, :, :]
        b = x_lab[:, 2:3, :, :]
        local_l = find_local_patch(l, patch_size)
        local_a = find_local_patch(a, patch_size)
        local_b = find_local_patch(b, patch_size)
        local_difference_l = (local_l - l) ** 2
        correlation = nn.functional.softmax(-1 * local_difference_l / alpha, dim=1)

        return torch.cat(
            (
                torch.sum(correlation * local_a, dim=1, keepdim=True),
                torch.sum(correlation * local_b, dim=1, keepdim=True),
            ),
            1,
        )
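
# Illustrative usage sketch (not part of the original module): shapes below are
# assumptions chosen for demonstration. `find_local_patch` unfolds each pixel's
# patch_size x patch_size neighborhood into the channel dimension, and
# `WeightedAverage` uses it to smooth the a/b channels guided by luminance.
def _demo_weighted_average():
    x_lab = torch.randn(2, 3, 64, 64)  # assumed (N, 3, H, W) Lab batch
    patches = find_local_patch(x_lab[:, 0:1], patch_size=3)
    assert patches.shape == (2, 9, 64, 64)  # 3*3 neighbors per pixel, single channel
    smoothed_ab = WeightedAverage()(x_lab, patch_size=3, alpha=1)
    assert smoothed_ab.shape == (2, 2, 64, 64)  # smoothed a and b channels
    return smoothed_ab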
""" """ alpha=0: less smooth; alpha=inf: smoother """ x_lab = F.interpolate(x_lab, scale_factor=scale_factor) l = uncenter_l(x_lab[:, 0:1, :, :]) a = x_lab[:, 1:2, :, :] b = x_lab[:, 2:3, :, :] a_predict = x_lab_predict[:, 1:2, :, :] b_predict = x_lab_predict[:, 2:3, :, :] local_l = find_local_patch(l, patch_size) local_a = find_local_patch(a, patch_size) local_b = find_local_patch(b, patch_size) local_a_predict = find_local_patch(a_predict, patch_size) local_b_predict = find_local_patch(b_predict, patch_size) local_color_difference = (local_l - l) ** 2 + (local_a - a) ** 2 + (local_b - b) ** 2 # so that sum of weights equal to 1 correlation = nn.functional.softmax(-1 * local_color_difference / alpha, dim=1) return torch.cat( ( torch.sum(correlation * local_a_predict, dim=1, keepdim=True), torch.sum(correlation * local_b_predict, dim=1, keepdim=True), ), 1, ) class NonlocalWeightedAverage(nn.Module): def __init__( self, ): super(NonlocalWeightedAverage, self).__init__() def forward(self, x_lab, feature, patch_size=3, alpha=0.1, scale_factor=1): """ It takes in a feature map and a label map, and returns a smoothed label map Args: x_lab: the input image in LAB color space feature: the feature map of the input image patch_size: the size of the patch to be used for the correlation matrix. Defaults to 3 alpha: the higher the alpha, the smoother the output. scale_factor: the scale factor of the input image. Defaults to 1 Returns: weighted_ab is the weighted ab channel of the image. """ # alpha=0: less smooth; alpha=inf: smoother # input feature is normalized feature x_lab = F.interpolate(x_lab, scale_factor=scale_factor) batch_size, channel, height, width = x_lab.shape feature = F.interpolate(feature, size=(height, width)) batch_size = x_lab.shape[0] x_ab = x_lab[:, 1:3, :, :].view(batch_size, 2, -1) x_ab = x_ab.permute(0, 2, 1) local_feature = find_local_patch(feature, patch_size) local_feature = local_feature.view(batch_size, local_feature.shape[1], -1) correlation_matrix = torch.matmul(local_feature.permute(0, 2, 1), local_feature) correlation_matrix = nn.functional.softmax(correlation_matrix / alpha, dim=-1) weighted_ab = torch.matmul(correlation_matrix, x_ab) weighted_ab = weighted_ab.permute(0, 2, 1).contiguous() weighted_ab = weighted_ab.view(batch_size, 2, height, width) return weighted_ab class CorrelationLayer(nn.Module): def __init__(self, search_range): super(CorrelationLayer, self).__init__() self.search_range = search_range def forward(self, x1, x2, alpha=1, raw_output=False, metric="similarity"): """ It takes two tensors, x1 and x2, and returns a tensor of shape (batch_size, (search_range * 2 + 1) ** 2, height, width) where each element is the dot product of the corresponding patch in x1 and x2 Args: x1: the first image x2: the image to be warped alpha: the temperature parameter for the softmax function. Defaults to 1 raw_output: if True, return the raw output of the network, otherwise return the softmax output. Defaults to False metric: "similarity" or "subtraction". Defaults to similarity Returns: The output of the forward function is a softmax of the correlation volume. 
""" shape = list(x1.size()) shape[1] = (self.search_range * 2 + 1) ** 2 cv = torch.zeros(shape).to(torch.device("cuda")) for i in range(-self.search_range, self.search_range + 1): for j in range(-self.search_range, self.search_range + 1): if i < 0: slice_h, slice_h_r = slice(None, i), slice(-i, None) elif i > 0: slice_h, slice_h_r = slice(i, None), slice(None, -i) else: slice_h, slice_h_r = slice(None), slice(None) if j < 0: slice_w, slice_w_r = slice(None, j), slice(-j, None) elif j > 0: slice_w, slice_w_r = slice(j, None), slice(None, -j) else: slice_w, slice_w_r = slice(None), slice(None) if metric == "similarity": cv[:, (self.search_range * 2 + 1) * i + j, slice_h, slice_w] = ( x1[:, :, slice_h, slice_w] * x2[:, :, slice_h_r, slice_w_r] ).sum(1) else: # patchwise subtraction cv[:, (self.search_range * 2 + 1) * i + j, slice_h, slice_w] = -( (x1[:, :, slice_h, slice_w] - x2[:, :, slice_h_r, slice_w_r]) ** 2 ).sum(1) # TODO sigmoid? if raw_output: return cv else: return nn.functional.softmax(cv / alpha, dim=1) class WTA_scale(torch.autograd.Function): """ We can implement our own custom autograd Functions by subclassing torch.autograd.Function and implementing the forward and backward passes which operate on Tensors. """ @staticmethod def forward(ctx, input, scale=1e-4): """ In the forward pass we receive a Tensor containing the input and return a Tensor containing the output. You can cache arbitrary Tensors for use in the backward pass using the save_for_backward method. """ activation_max, index_max = torch.max(input, -1, keepdim=True) input_scale = input * scale # default: 1e-4 # input_scale = input * scale # default: 1e-4 output_max_scale = torch.where(input == activation_max, input, input_scale) mask = (input == activation_max).type(torch.float) ctx.save_for_backward(input, mask) return output_max_scale @staticmethod def backward(ctx, grad_output): """ In the backward pass we receive a Tensor containing the gradient of the loss with respect to the output, and we need to compute the gradient of the loss with respect to the input. 
""" input, mask = ctx.saved_tensors mask_ones = torch.ones_like(mask) mask_small_ones = torch.ones_like(mask) * 1e-4 # mask_small_ones = torch.ones_like(mask) * 1e-4 grad_scale = torch.where(mask == 1, mask_ones, mask_small_ones) grad_input = grad_output.clone() * grad_scale return grad_input, None class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1): super(ResidualBlock, self).__init__() self.padding1 = nn.ReflectionPad2d(padding) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride) self.bn1 = nn.InstanceNorm2d(out_channels) self.prelu = nn.PReLU() self.padding2 = nn.ReflectionPad2d(padding) self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride) self.bn2 = nn.InstanceNorm2d(out_channels) def forward(self, x): residual = x out = self.padding1(x) out = self.conv1(out) out = self.bn1(out) out = self.prelu(out) out = self.padding2(out) out = self.conv2(out) out = self.bn2(out) out += residual out = self.prelu(out) return out class WarpNet(nn.Module): """input is Al, Bl, channel = 1, range~[0,255]""" def __init__(self, feature_channel=128): super(WarpNet, self).__init__() self.feature_channel = feature_channel self.in_channels = self.feature_channel * 4 self.inter_channels = 256 # 44*44 self.layer2_1 = nn.Sequential( nn.ReflectionPad2d(1), # nn.Conv2d(128, 128, kernel_size=3, padding=0, stride=1), # nn.Conv2d(96, 128, kernel_size=3, padding=20, stride=1), nn.Conv2d(96, 128, kernel_size=3, padding=0, stride=1), nn.InstanceNorm2d(128), nn.PReLU(), nn.ReflectionPad2d(1), nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=2), nn.InstanceNorm2d(self.feature_channel), nn.PReLU(), nn.Dropout(0.2), ) self.layer3_1 = nn.Sequential( nn.ReflectionPad2d(1), # nn.Conv2d(256, 128, kernel_size=3, padding=0, stride=1), # nn.Conv2d(192, 128, kernel_size=3, padding=10, stride=1), nn.Conv2d(192, 128, kernel_size=3, padding=0, stride=1), nn.InstanceNorm2d(128), nn.PReLU(), nn.ReflectionPad2d(1), nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=1), nn.InstanceNorm2d(self.feature_channel), nn.PReLU(), nn.Dropout(0.2), ) # 22*22->44*44 self.layer4_1 = nn.Sequential( nn.ReflectionPad2d(1), # nn.Conv2d(512, 256, kernel_size=3, padding=0, stride=1), # nn.Conv2d(384, 256, kernel_size=3, padding=5, stride=1), nn.Conv2d(384, 256, kernel_size=3, padding=0, stride=1), nn.InstanceNorm2d(256), nn.PReLU(), nn.ReflectionPad2d(1), nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1), nn.InstanceNorm2d(self.feature_channel), nn.PReLU(), nn.Upsample(scale_factor=2), nn.Dropout(0.2), ) # 11*11->44*44 self.layer5_1 = nn.Sequential( nn.ReflectionPad2d(1), # nn.Conv2d(1024, 256, kernel_size=3, padding=0, stride=1), # nn.Conv2d(768, 256, kernel_size=2, padding=2, stride=1), nn.Conv2d(768, 256, kernel_size=3, padding=0, stride=1), nn.InstanceNorm2d(256), nn.PReLU(), nn.Upsample(scale_factor=2), nn.ReflectionPad2d(1), nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1), nn.InstanceNorm2d(self.feature_channel), nn.PReLU(), nn.Upsample(scale_factor=2), nn.Dropout(0.2), ) self.layer = nn.Sequential( ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1), ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1), ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1), ) 

class WarpNet(nn.Module):
    """Input is Al, Bl; channel = 1; range ~ [0, 255]."""

    def __init__(self, feature_channel=128):
        super(WarpNet, self).__init__()
        self.feature_channel = feature_channel
        self.in_channels = self.feature_channel * 4
        self.inter_channels = 256

        # 44*44
        self.layer2_1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            # nn.Conv2d(128, 128, kernel_size=3, padding=0, stride=1),
            # nn.Conv2d(96, 128, kernel_size=3, padding=20, stride=1),
            nn.Conv2d(96, 128, kernel_size=3, padding=0, stride=1),
            nn.InstanceNorm2d(128),
            nn.PReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=2),
            nn.InstanceNorm2d(self.feature_channel),
            nn.PReLU(),
            nn.Dropout(0.2),
        )
        self.layer3_1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            # nn.Conv2d(256, 128, kernel_size=3, padding=0, stride=1),
            # nn.Conv2d(192, 128, kernel_size=3, padding=10, stride=1),
            nn.Conv2d(192, 128, kernel_size=3, padding=0, stride=1),
            nn.InstanceNorm2d(128),
            nn.PReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(128, self.feature_channel, kernel_size=3, padding=0, stride=1),
            nn.InstanceNorm2d(self.feature_channel),
            nn.PReLU(),
            nn.Dropout(0.2),
        )

        # 22*22 -> 44*44
        self.layer4_1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            # nn.Conv2d(512, 256, kernel_size=3, padding=0, stride=1),
            # nn.Conv2d(384, 256, kernel_size=3, padding=5, stride=1),
            nn.Conv2d(384, 256, kernel_size=3, padding=0, stride=1),
            nn.InstanceNorm2d(256),
            nn.PReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1),
            nn.InstanceNorm2d(self.feature_channel),
            nn.PReLU(),
            nn.Upsample(scale_factor=2),
            nn.Dropout(0.2),
        )

        # 11*11 -> 44*44
        self.layer5_1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            # nn.Conv2d(1024, 256, kernel_size=3, padding=0, stride=1),
            # nn.Conv2d(768, 256, kernel_size=2, padding=2, stride=1),
            nn.Conv2d(768, 256, kernel_size=3, padding=0, stride=1),
            nn.InstanceNorm2d(256),
            nn.PReLU(),
            nn.Upsample(scale_factor=2),
            nn.ReflectionPad2d(1),
            nn.Conv2d(256, self.feature_channel, kernel_size=3, padding=0, stride=1),
            nn.InstanceNorm2d(self.feature_channel),
            nn.PReLU(),
            nn.Upsample(scale_factor=2),
            nn.Dropout(0.2),
        )

        self.layer = nn.Sequential(
            ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
            ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
            ResidualBlock(self.feature_channel * 4, self.feature_channel * 4, kernel_size=3, padding=1, stride=1),
        )

        self.theta = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )
        self.phi = nn.Conv2d(
            in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0
        )

        self.upsampling = nn.Upsample(scale_factor=4)

    def forward(
        self,
        B_lab_map,
        A_relu2_1,
        A_relu3_1,
        A_relu4_1,
        A_relu5_1,
        B_relu2_1,
        B_relu3_1,
        B_relu4_1,
        B_relu5_1,
        temperature=0.001 * 5,
        detach_flag=False,
        WTA_scale_weight=1,
    ):
        batch_size = B_lab_map.shape[0]
        channel = B_lab_map.shape[1]
        image_height = B_lab_map.shape[2]
        image_width = B_lab_map.shape[3]
        feature_height = int(image_height / 4)
        feature_width = int(image_width / 4)

        # scale feature size to 44*44
        A_feature2_1 = self.layer2_1(A_relu2_1)
        B_feature2_1 = self.layer2_1(B_relu2_1)
        A_feature3_1 = self.layer3_1(A_relu3_1)
        B_feature3_1 = self.layer3_1(B_relu3_1)
        A_feature4_1 = self.layer4_1(A_relu4_1)
        B_feature4_1 = self.layer4_1(B_relu4_1)
        A_feature5_1 = self.layer5_1(A_relu5_1)
        B_feature5_1 = self.layer5_1(B_relu5_1)

        # concatenate features
        if A_feature5_1.shape[2] != A_feature2_1.shape[2] or A_feature5_1.shape[3] != A_feature2_1.shape[3]:
            A_feature5_1 = F.pad(A_feature5_1, (0, 0, 1, 1), "replicate")
            B_feature5_1 = F.pad(B_feature5_1, (0, 0, 1, 1), "replicate")
        A_features = self.layer(torch.cat((A_feature2_1, A_feature3_1, A_feature4_1, A_feature5_1), 1))
        B_features = self.layer(torch.cat((B_feature2_1, B_feature3_1, B_feature4_1, B_feature5_1), 1))

        # pairwise cosine similarity
        theta = self.theta(A_features).view(batch_size, self.inter_channels, -1)  # 2*256*(feature_height*feature_width)
        theta = theta - theta.mean(dim=-1, keepdim=True)  # center the feature
        theta_norm = torch.norm(theta, 2, 1, keepdim=True) + sys.float_info.epsilon
        theta = torch.div(theta, theta_norm)
        theta_permute = theta.permute(0, 2, 1)  # 2*(feature_height*feature_width)*256
        phi = self.phi(B_features).view(batch_size, self.inter_channels, -1)  # 2*256*(feature_height*feature_width)
        phi = phi - phi.mean(dim=-1, keepdim=True)  # center the feature
        phi_norm = torch.norm(phi, 2, 1, keepdim=True) + sys.float_info.epsilon
        phi = torch.div(phi, phi_norm)
        f = torch.matmul(theta_permute, phi)  # 2*(feature_height*feature_width)*(feature_height*feature_width)
        if detach_flag:
            f = f.detach()

        f_similarity = f.unsqueeze_(dim=1)
        similarity_map = torch.max(f_similarity, -1, keepdim=True)[0]
        similarity_map = similarity_map.view(batch_size, 1, feature_height, feature_width)

        # f can be negative
        f_WTA = f if WTA_scale_weight == 1 else WTA_scale.apply(f, WTA_scale_weight)
        f_WTA = f_WTA / temperature
        f_div_C = F.softmax(f_WTA.squeeze_(), dim=-1)  # 2*1936*1936

        # downsample the reference color
        B_lab = F.avg_pool2d(B_lab_map, 4)
        B_lab = B_lab.view(batch_size, channel, -1)
        B_lab = B_lab.permute(0, 2, 1)  # 2*1936*channel

        # multiply the corr map with color
        y = torch.matmul(f_div_C, B_lab)  # 2*1936*channel
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, channel, feature_height, feature_width)  # 2*3*44*44
        y = self.upsampling(y)
        similarity_map = self.upsampling(similarity_map)

        return y, similarity_map
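
# Illustrative usage sketch (not part of the original module): the relu*_1 inputs
# are assumed to be VGG-style feature maps at 1/2, 1/4, 1/8 and 1/16 of the image
# resolution with 96/192/384/768 channels, matching the conv layers defined above;
# the 176x176 image size (44x44 feature grid) is an assumption for demonstration.
def _demo_warpnet():
    warpnet = WarpNet(feature_channel=128).eval()
    B_lab_map = torch.randn(1, 3, 176, 176)
    feats = {}
    for name in ("A", "B"):
        feats[name] = (
            torch.randn(1, 96, 88, 88),    # relu2_1-like features
            torch.randn(1, 192, 44, 44),   # relu3_1-like features
            torch.randn(1, 384, 22, 22),   # relu4_1-like features
            torch.randn(1, 768, 11, 11),   # relu5_1-like features
        )
    with torch.no_grad():
        y, similarity_map = warpnet(B_lab_map, *feats["A"], *feats["B"])
    assert y.shape == (1, 3, 176, 176)               # warped reference color, upsampled x4
    assert similarity_map.shape == (1, 1, 176, 176)  # per-pixel matching confidence
    return y, similarity_map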