import torch
import torch.nn as nn


def count_params(model):
    """Return the total number of scalar elements across all parameters of ``model``."""
    total_params = sum(p.numel() for p in model.parameters())
    return total_params


class ActNorm(nn.Module):
    """Per-channel affine layer ``h = scale * (x + loc)`` with data-dependent init.

    On the first forward call in training mode, ``loc`` and ``scale`` are set
    from the batch so that each channel of the output has zero mean and
    (approximately) unit standard deviation. Inputs may be 4-D ``(N, C, H, W)``
    or 2-D ``(N, C)``; 2-D inputs are temporarily expanded to ``(N, C, 1, 1)``
    and squeezed back on output.

    Args:
        num_features: number of channels ``C``.
        logdet: if True, ``forward`` also returns a per-sample log-determinant.
        affine: must be True (only the affine variant is implemented).
        allow_reverse_init: permit the lazy initialization to be triggered by
            ``reverse`` instead of ``forward``.
    """

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        # Only the affine variant is implemented.
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # 0 until the data-dependent initialization has run. Kept as a buffer
        # so the flag travels with the module's state dict.
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    @staticmethod
    def _as_4d(x):
        """Return ``(x as 4-D, was_2d)``; a 2-D ``(N, C)`` tensor becomes ``(N, C, 1, 1)``."""
        if len(x.shape) == 2:
            return x[:, :, None, None], True
        return x, False

    def initialize(self, input):
        """Set ``loc``/``scale`` so ``input`` is normalized per channel.

        ``input`` must be 4-D ``(N, C, H, W)``. Statistics are taken over all
        batch and spatial positions of each channel.
        """
        with torch.no_grad():
            num_channels = input.shape[1]
            # One row per channel, holding every value of that channel.
            flatten = input.permute(1, 0, 2, 3).contiguous().view(num_channels, -1)
            mean = flatten.mean(1).view(1, -1, 1, 1)
            std = flatten.std(1).view(1, -1, 1, 1)

            self.loc.data.copy_(-mean)
            # Small epsilon avoids division by zero for constant channels.
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Apply the layer (or its inverse when ``reverse=True``).

        Returns ``h``; when ``self.logdet`` is set, returns ``(h, logdet)``.
        ``logdet`` has shape ``(N,)`` and equals
        ``H * W * sum(log|scale|)`` for every sample.
        """
        if reverse:
            return self.reverse(input)

        input, squeeze = self._as_4d(input)
        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            # Broadcast the scalar to one value per sample, on input's device/dtype.
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        """Invert the layer: ``output / scale - loc``.

        If the layer is still uninitialized in training mode, this raises
        ``RuntimeError`` unless ``allow_reverse_init`` is True. In that case the
        statistics are computed from ``output``.
        """
        # Reshape BEFORE a possible initialization: ``initialize`` needs a 4-D
        # tensor, so running it on a raw 2-D input would fail.
        output, squeeze = self._as_4d(output)

        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            self.initialize(output)
            self.initialized.fill_(1)

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement ``encode``."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class Labelator(AbstractEncoder):
    """Net2Net Interface for Class-Conditional Model"""

    def __init__(self, n_classes, quantize_interface=True):
        super().__init__()
        self.n_classes = n_classes
        self.quantize_interface = quantize_interface

    def encode(self, c):
        """Add a trailing axis to ``c`` (via ``c[:, None]``).

        When ``quantize_interface`` is True, returns the triple
        ``(c, None, [None, None, c.long()])``; otherwise returns ``c`` alone.
        """
        c = c[:, None]
        if self.quantize_interface:
            return c, None, [None, None, c.long()]
        return c


class SOSProvider(AbstractEncoder):
    """Conditioning stub for unconditional training: emits a constant SOS token."""

    def __init__(self, sos_token, quantize_interface=True):
        super().__init__()
        self.sos_token = sos_token
        self.quantize_interface = quantize_interface

    def encode(self, x):
        """Return a ``(batch, 1)`` long tensor filled with ``sos_token``.

        The batch size is taken from ``x.shape[0]``, and the result is placed
        on ``x.device``. When ``quantize_interface`` is True, returns
        ``(c, None, [None, None, c])``.
        """
        # Build directly as int64: going through a float32 tensor would
        # silently round token ids larger than 2**24.
        c = torch.full((x.shape[0], 1), self.sos_token,
                       dtype=torch.long, device=x.device)
        if self.quantize_interface:
            return c, None, [None, None, c]
        return c


def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag