# def3/scripts/deforum_helpers/src/adabins/unet_adaptive_bins.py
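"""UNet-based AdaBins depth estimator: EfficientNet encoder, BatchNorm decoder
and a mini-ViT head that predicts adaptive depth bins (Bhat et al., CVPR 2021)."""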
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from pathlib import Path
from .miniViT import mViT
from modules.shared import opts
class UpSampleBN(nn.Module):
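    """Upsampling block: bilinearly resizes `x` to the skip connection's spatial
    size, concatenates the two along channels, and refines the result with two
    Conv-BN-LeakyReLU layers."""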
def __init__(self, skip_input, output_features):
super(UpSampleBN, self).__init__()
self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(output_features),
nn.LeakyReLU(),
nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(output_features),
nn.LeakyReLU())
def forward(self, x, concat_with):
up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True)
f = torch.cat([up_x, concat_with], dim=1)
return self._net(f)
class DecoderBN(nn.Module):
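    """BatchNorm U-Net decoder: a bottleneck conv followed by four UpSampleBN
    stages that fuse progressively shallower encoder features, ending in a conv
    that maps to `num_classes` output channels."""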
def __init__(self, num_features=2048, num_classes=1, bottleneck_features=2048):
super(DecoderBN, self).__init__()
features = int(num_features)
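        # Note: kernel_size=1 with padding=1 enlarges the feature map by a
        # one-pixel border; unusual, but it matches the original AdaBins
        # implementation, so it is left unchanged.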
self.conv2 = nn.Conv2d(bottleneck_features, features, kernel_size=1, stride=1, padding=1)
self.up1 = UpSampleBN(skip_input=features // 1 + 112 + 64, output_features=features // 2)
self.up2 = UpSampleBN(skip_input=features // 2 + 40 + 24, output_features=features // 4)
self.up3 = UpSampleBN(skip_input=features // 4 + 24 + 16, output_features=features // 8)
self.up4 = UpSampleBN(skip_input=features // 8 + 16 + 8, output_features=features // 16)
# self.up5 = UpSample(skip_input=features // 16 + 3, output_features=features//16)
self.conv3 = nn.Conv2d(features // 16, num_classes, kernel_size=3, stride=1, padding=1)
# self.act_out = nn.Softmax(dim=1) if output_activation == 'softmax' else nn.Identity()
def forward(self, features):
        x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
x_d0 = self.conv2(x_block4)
x_d1 = self.up1(x_d0, x_block3)
x_d2 = self.up2(x_d1, x_block2)
x_d3 = self.up3(x_d2, x_block1)
x_d4 = self.up4(x_d3, x_block0)
# x_d5 = self.up5(x_d4, features[0])
out = self.conv3(x_d4)
# out = self.act_out(out)
# if with_features:
# return out, features[-1]
# elif with_intermediate:
# return out, [x_block0, x_block1, x_block2, x_block3, x_block4, x_d1, x_d2, x_d3, x_d4]
return out
class Encoder(nn.Module):
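    """Runs the backbone module-by-module and collects every intermediate
    activation (including the input) so the decoder can index skip
    connections out of the returned list."""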
def __init__(self, backend):
super(Encoder, self).__init__()
self.original_model = backend
def forward(self, x):
features = [x]
for k, v in self.original_model._modules.items():
            if k == 'blocks':
for ki, vi in v._modules.items():
features.append(vi(features[-1]))
else:
features.append(v(features[-1]))
return features
class UnetAdaptiveBins(nn.Module):
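    """AdaBins depth estimator (Bhat et al., CVPR 2021): an EfficientNet
    encoder-decoder whose 128-channel output feeds a mini-ViT that predicts
    per-image adaptive bin widths and range-attention maps; the predicted
    depth is the probability-weighted sum of the bin centers."""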
def __init__(self, backend, n_bins=100, min_val=0.1, max_val=10, norm='linear'):
super(UnetAdaptiveBins, self).__init__()
self.num_classes = n_bins
self.min_val = min_val
self.max_val = max_val
self.encoder = Encoder(backend)
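        # The mini-ViT consumes the 128-channel decoder output and returns
        # n_bins normalized bin widths per image plus range-attention maps.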
self.adaptive_bins_layer = mViT(128, n_query_channels=128, patch_size=16,
dim_out=n_bins,
embedding_dim=128, norm=norm)
self.decoder = DecoderBN(num_classes=128)
self.conv_out = nn.Sequential(nn.Conv2d(128, n_bins, kernel_size=1, stride=1, padding=0),
nn.Softmax(dim=1))
def forward(self, x, **kwargs):
unet_out = self.decoder(self.encoder(x), **kwargs)
bin_widths_normed, range_attention_maps = self.adaptive_bins_layer(unet_out)
out = self.conv_out(range_attention_maps)
# Post process
# n, c, h, w = out.shape
# hist = torch.sum(out.view(n, c, h * w), dim=2) / (h * w) # not used for training
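        # Scale the normalized widths to [min_val, max_val], prepend min_val and
        # cumulatively sum so bin_edges[:, i] is the left edge of bin i; the bin
        # centers are the midpoints between consecutive edges.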
bin_widths = (self.max_val - self.min_val) * bin_widths_normed # .shape = N, dim_out
bin_widths = nn.functional.pad(bin_widths, (1, 0), mode='constant', value=self.min_val)
bin_edges = torch.cumsum(bin_widths, dim=1)
centers = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
n, dout = centers.size()
centers = centers.view(n, dout, 1, 1)
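        # Per-pixel depth is the expectation of the bin centers under the
        # softmax distribution produced by conv_out.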
pred = torch.sum(out * centers, dim=1, keepdim=True)
return bin_edges, pred
    def get_1x_lr_params(self):  # pretrained encoder is fine-tuned at lr / 10
        return self.encoder.parameters()

    def get_10x_lr_params(self):  # decoder, mViT and output head train at the base lr
        modules = [self.decoder, self.adaptive_bins_layer, self.conv_out]
        for m in modules:
            yield from m.parameters()
@classmethod
def build(cls, n_bins, **kwargs):
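        """Load a pretrained `tf_efficientnet_b5_ap` backbone via torch.hub,
        strip its classification head, and wrap it in the encoder-decoder."""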
DEBUG_MODE = opts.data.get("deforum_debug_mode_enabled", False)
basemodel_name = 'tf_efficientnet_b5_ap'
print('Loading AdaBins model...')
        # Try to fetch the model from the local torch hub cache first, and only
        # download it from the internet if it can't be found (enables offline usage).
        predicted_torch_model_cache_path = str(Path.home() / '.cache' / 'torch' / 'hub' / 'rwightman_gen-efficientnet-pytorch_master')
        predicted_gep_cache_testfile = Path(predicted_torch_model_cache_path) / 'hubconf.py'
        # print(f"predicted_gep_cache_testfile: {predicted_gep_cache_testfile}")
        if os.path.isfile(predicted_gep_cache_testfile):
            basemodel = torch.hub.load(predicted_torch_model_cache_path, basemodel_name, pretrained=True, source='local')
        else:
            basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
if DEBUG_MODE:
print('Done.')
        # Remove the final global pooling and classifier layers
if DEBUG_MODE:
print('Removing last two layers (global_pool & classifier).')
basemodel.global_pool = nn.Identity()
basemodel.classifier = nn.Identity()
# Building Encoder-Decoder model
if DEBUG_MODE:
print('Building Encoder-Decoder model..', end='')
m = cls(basemodel, n_bins=n_bins, **kwargs)
if DEBUG_MODE:
print('Done.')
return m
if __name__ == '__main__':
model = UnetAdaptiveBins.build(100)
x = torch.rand(2, 3, 480, 640)
bins, pred = model(x)
print(bins.shape, pred.shape)
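    # Expected shapes: bins (2, 101) cumulative bin edges, pred (2, 1, 240, 320)
    # depth at half the input resolution.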