"""Thin ``nn.Sequential`` wrappers bundling conv / linear layers with
optional normalisation and activation modules."""
from typing import List, Tuple

import torch.nn as nn


class SharedMLP(nn.Sequential):
    """Chain of ``Conv2d`` blocks whose channel widths follow ``args``.

    Layer ``i`` maps ``args[i]`` channels to ``args[i + 1]`` channels and is
    registered as ``name + 'layer<i>'``. The ``Conv2d`` blocks are built with
    their default kernel size, stride and padding.

    Args:
        args: Channel sizes; ``len(args) - 1`` layers are created.
        bn: Add batch norm to each layer (see ``first`` for the exception).
        activation: Activation module handed to every layer. The same module
            instance is passed to all layers; the default instance is created
            once, when the class is defined.
        preact: Forwarded to ``Conv2d``; puts norm/activation before the conv.
        first: When set together with ``preact``, layer 0 is built without
            batch norm and without activation.
        name: Prefix for the registered layer names.
        instance_norm: Forwarded to every ``Conv2d``.
    """

    def __init__(
            self,
            args: List[int],
            *,
            bn: bool = False,
            activation=nn.ReLU(inplace=True),
            preact: bool = False,
            first: bool = False,
            name: str = "",
            instance_norm: bool = False,
    ):
        super().__init__()

        for i, (in_channels, out_channels) in enumerate(zip(args, args[1:])):
            # Only layer 0 of a `first` + `preact` stack is left bare.
            bare = first and preact and i == 0
            self.add_module(
                name + f"layer{i}",
                Conv2d(
                    in_channels,
                    out_channels,
                    bn=bn and not bare,
                    activation=None if bare else activation,
                    preact=preact,
                    instance_norm=instance_norm,
                ),
            )


class _ConvBase(nn.Sequential):
    """Convolution with optional batch/instance norm and activation.

    Children are registered (each prefixed with ``name``) in this order:

    * ``preact=False``: ``conv``, ``bn``, ``activation``, ``in``
    * ``preact=True``:  ``bn``, ``activation``, ``in``, ``conv``

    Optional children are simply omitted. The instance-norm child ``in`` is
    only added when ``instance_norm`` is set and ``bn`` is not. Normalisation
    layers are sized for ``out_size``, or ``in_size`` when ``preact`` is set.

    Args:
        in_size: Input channel count for the convolution.
        out_size: Output channel count for the convolution.
        kernel_size, stride, padding: Forwarded to ``conv``.
        activation: Module to register as the activation, or ``None``.
        bn: Whether to add ``batch_norm``; also disables the conv bias.
        init: Callable applied to the conv weight; ``None`` keeps the weight
            as produced by ``conv``.
        conv: Convolution class/factory.
        batch_norm: Factory called as ``batch_norm(channels)``.
        bias: Requested conv bias (forced off when ``bn`` is set).
        preact: Put norm/activation before the convolution.
        name: Prefix for the registered child names.
        instance_norm: Whether instance norm is requested.
        instance_norm_func: Factory called as
            ``instance_norm_func(channels, affine=False,
            track_running_stats=False)``.
    """

    def __init__(
            self,
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=None,
            batch_norm=None,
            bias=True,
            preact=False,
            name="",
            instance_norm=False,
            instance_norm_func=None
    ):
        super().__init__()

        # The conv bias is dropped whenever batch norm is enabled.
        bias = bias and (not bn)
        conv_unit = conv(
            in_size,
            out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        # `init=None` keeps the conv's own initialisation, mirroring `FC`.
        if init is not None:
            init(conv_unit.weight)
        if bias:
            nn.init.constant_(conv_unit.bias, 0)

        # Norm layers act on the conv input when pre-activating, else on
        # its output.
        norm_size = in_size if preact else out_size

        extras = []
        if bn:
            extras.append((name + 'bn', batch_norm(norm_size)))
        if activation is not None:
            extras.append((name + 'activation', activation))
        if instance_norm and not bn:
            in_unit = instance_norm_func(
                norm_size, affine=False, track_running_stats=False
            )
            extras.append((name + 'in', in_unit))

        conv_entry = (name + 'conv', conv_unit)
        children = extras + [conv_entry] if preact else [conv_entry] + extras
        for child_name, child in children:
            self.add_module(child_name, child)


class _BNBase(nn.Sequential):
    """Single batch-norm layer, registered as ``name + 'bn'``, with its
    weight set to 1 and its bias set to 0.

    Args:
        in_size: Feature/channel count passed to ``batch_norm``.
        batch_norm: Batch-norm class/factory called as ``batch_norm(in_size)``.
        name: Prefix for the registered child name.
    """

    def __init__(self, in_size, batch_norm=None, name=""):
        super().__init__()
        bn_unit = batch_norm(in_size)
        self.add_module(name + "bn", bn_unit)

        nn.init.constant_(bn_unit.weight, 1.0)
        nn.init.constant_(bn_unit.bias, 0)


class BatchNorm1d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm1d``."""

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)


class BatchNorm2d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm2d``."""

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)


class Conv1d(_ConvBase):
    """``_ConvBase`` built on ``nn.Conv1d``, ``BatchNorm1d`` and
    ``nn.InstanceNorm1d``; see ``_ConvBase`` for the argument semantics."""

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: int = 1,
            stride: int = 1,
            padding: int = 0,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm1d,
        )


class Conv2d(_ConvBase):
    """``_ConvBase`` built on ``nn.Conv2d``, ``BatchNorm2d`` and
    ``nn.InstanceNorm2d``; see ``_ConvBase`` for the argument semantics."""

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: Tuple[int, int] = (1, 1),
            stride: Tuple[int, int] = (1, 1),
            padding: Tuple[int, int] = (0, 0),
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal_,
            bias: bool = True,
            preact: bool = False,
            name: str = "",
            instance_norm=False
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size,
            stride,
            padding,
            activation,
            bn,
            init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm2d,
        )


class FC(nn.Sequential):
    """Linear layer with optional batch norm and activation.

    Children are registered (each prefixed with ``name``) as
    ``fc``, ``bn``, ``activation`` or, when ``preact`` is set,
    ``bn``, ``activation``, ``fc``. Optional children are omitted.

    Args:
        in_size: Input feature count.
        out_size: Output feature count.
        activation: Module to register as the activation, or ``None``.
        bn: Add a ``BatchNorm1d``; the linear layer is then built without
            bias.
        init: Optional callable applied to the linear weight.
        preact: Put norm/activation before the linear layer; the batch norm
            is then sized for ``in_size`` instead of ``out_size``.
        name: Prefix for the registered child names.
    """

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=None,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__()

        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            # In-place initialiser, as used by the conv/BN wrappers; the
            # un-suffixed `nn.init.constant` is deprecated.
            nn.init.constant_(fc.bias, 0)

        extras = []
        if bn:
            extras.append(
                (name + 'bn', BatchNorm1d(in_size if preact else out_size))
            )
        if activation is not None:
            extras.append((name + 'activation', activation))

        fc_entry = (name + 'fc', fc)
        children = extras + [fc_entry] if preact else [fc_entry] + extras
        for child_name, child in children:
            self.add_module(child_name, child)