import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.misc import normalize
class build_lut_transform(nn.Module):
    """Predicts a per-image 3D LUT from appearance features and applies it."""

    def __init__(self, input_dim, lut_dim, input_resolution, opt):
        super().__init__()

        self.lut_dim = lut_dim
        self.opt = opt

        # self.compress_layer = nn.Linear(input_resolution, 1)
        # Map the feature vector to 3 * lut_dim^3 values, i.e. one RGB entry
        # per cell of a lut_dim x lut_dim x lut_dim lattice.
        self.transform_layers = nn.Sequential(
            nn.Linear(input_dim, 3 * lut_dim ** 3, bias=True),
            # nn.BatchNorm1d(3 * lut_dim ** 3, affine=False),
            nn.ReLU(inplace=True),
            nn.Linear(3 * lut_dim ** 3, 3 * lut_dim ** 3, bias=True),
        )

        self.transform_layers[-1].apply(hyper_weight_init)
    def forward(self, composite_image, fg_appearance_features, bg_appearance_features):
        # Map the composite back to [0, 1] before the LUT lookup; the
        # background features are currently unused.
        composite_image = normalize(composite_image, self.opt, 'inv')

        features = fg_appearance_features
        lut_params = self.transform_layers(features)

        # (B, 3 * D^3) -> (B, 3, D, D, D): one RGB LUT per image in the batch.
        fit_3DLUT = lut_params.view(lut_params.shape[0], 3, self.lut_dim, self.lut_dim, self.lut_dim)

        # Apply each LUT to its own composite image.
        lut_transform_image = torch.stack(
            [TrilinearInterpolation(lut, image)[0] for lut, image in zip(fit_3DLUT, composite_image)], dim=0)

        return fit_3DLUT, normalize(lut_transform_image, self.opt)
def TrilinearInterpolation(LUT, img):
    # Rescale the [0, 1] image to grid_sample's [-1, 1] coordinate range.
    img = (img - 0.5) * 2.

    # (3, H, W) -> (1, 1, H, W, 3) sampling grid. Note that grid_sample's
    # (x, y, z) coordinates index (W, H, D) of the LUT, i.e. the reverse of
    # DHW, hence the flip(-1).
    img = img.unsqueeze(0).permute(0, 2, 3, 1)[:, None].flip(-1)

    # (3, D, D, D) -> (1, 3, D, D, D)
    LUT = LUT[None]

    # With 5D inputs, mode='bilinear' performs trilinear interpolation.
    result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True)

    # Drop the added depth dimension: (1, 3, 1, H, W) -> (1, 3, H, W).
    result = result[:, :, 0]

    return result
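
# A minimal sanity check, not part of the original code: an identity LUT
# should reproduce its input image exactly, which confirms the xyz-vs-DHW
# coordinate convention above. `_check_identity_lut` is a hypothetical helper
# name; `indexing='ij'` requires PyTorch >= 1.10.
def _check_identity_lut(lut_dim=8):
    ramp = torch.linspace(0., 1., lut_dim)
    # Channel c of the identity LUT stores its own lattice coordinate:
    # D is indexed by red, H by green, W by blue (see flip(-1) above).
    r, g, b = torch.meshgrid(ramp, ramp, ramp, indexing='ij')
    identity_lut = torch.stack([r, g, b], dim=0)  # (3, D, D, D)

    img = torch.rand(3, 16, 16)  # random image in [0, 1]
    out = TrilinearInterpolation(identity_lut, img)[0]
    assert torch.allclose(out, img, atol=1e-5)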
def hyper_weight_init(m):
    if hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
        # Shrink the weights so the initial LUT prediction is dominated by the
        # bias term.
        m.weight.data = m.weight.data / 1.e2
    if hasattr(m, 'bias'):
        with torch.no_grad():
            m.bias.uniform_(0., 1.)
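
# A minimal usage sketch, not from the original repo: it only checks output
# shapes. The empty Namespace passed as `opt` is a hypothetical placeholder
# and works only if utils.misc.normalize reads no fields from it; substitute
# the repo's real option object in practice.
if __name__ == '__main__':
    from argparse import Namespace

    model = build_lut_transform(input_dim=256, lut_dim=8, input_resolution=None, opt=Namespace())

    composite = torch.rand(2, 3, 64, 64)  # batch of composite images
    fg_feat = torch.rand(2, 256)          # pooled foreground appearance features

    luts, harmonized = model(composite, fg_feat, None)  # bg features are unused
    print(luts.shape)        # torch.Size([2, 3, 8, 8, 8])
    print(harmonized.shape)  # torch.Size([2, 3, 64, 64])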