# Spaces: Runtime error
import torch.nn as nn | |
from typing import List, Tuple | |
class SharedMLP(nn.Sequential):
    """Chain of ``Conv2d`` wrapper layers built from a list of channel widths.

    Consecutive entries of ``args`` give the input/output sizes of each
    layer, so ``len(args) - 1`` layers are registered under the names
    ``name + 'layer0'``, ``name + 'layer1'``, ...

    Args:
        args: Channel widths; ``args[i] -> args[i + 1]`` for layer ``i``.
        bn: Request batch norm in every layer (see ``first`` for the exception).
        activation: Activation module handed to every layer. The same
            instance is passed to all layers.
        preact: Forwarded to each layer (norm/activation before the conv).
        first: When combined with ``preact``, layer 0 gets neither batch
            norm nor activation.
        name: Prefix for the registered layer names.
        instance_norm: Forwarded to each layer.
    """

    def __init__(
        self,
        args: List[int],
        *,
        bn: bool = False,
        activation=nn.ReLU(inplace=True),
        preact: bool = False,
        first: bool = False,
        name: str = "",
        instance_norm: bool = False,
    ):
        super().__init__()

        for idx, (n_in, n_out) in enumerate(zip(args[:-1], args[1:])):
            # Only the very first layer of a "first" pre-activation stack
            # skips its batch norm and activation.
            keep_pre_ops = not (first and preact and idx == 0)

            layer = Conv2d(
                n_in,
                n_out,
                bn=bn if keep_pre_ops else False,
                activation=activation if keep_pre_ops else None,
                preact=preact,
                instance_norm=instance_norm,
            )
            self.add_module(name + "layer{}".format(idx), layer)
class _ConvBase(nn.Sequential):
    """Sequential block holding one convolution plus optional extras.

    The optional extras are, in this order: batch norm (``'bn'``), the
    activation (``'activation'``) and instance norm (``'in'``). With
    ``preact`` they are registered before the convolution (``'conv'``) and
    sized with ``in_size``; otherwise they follow it and use ``out_size``.
    Instance norm is only registered when ``bn`` is off.

    Args:
        in_size: Input channel count passed to ``conv``.
        out_size: Output channel count passed to ``conv``.
        kernel_size, stride, padding: Forwarded to ``conv`` by keyword.
        activation: Module to register, or ``None`` for no activation.
        bn: Whether to register a batch-norm unit.
        init: Callable applied to the convolution weight.
        conv: Convolution constructor (must be supplied by the caller).
        batch_norm: Constructor called with the channel count when ``bn``.
        bias: Requested conv bias; forced off whenever ``bn`` is on.
        preact: Put the extras before the convolution instead of after.
        name: Prefix for every registered sub-module name.
        instance_norm: Whether to build an instance-norm unit.
        instance_norm_func: Instance-norm constructor, called with
            ``affine=False, track_running_stats=False``.
    """

    def __init__(
        self,
        in_size,
        out_size,
        kernel_size,
        stride,
        padding,
        activation,
        bn,
        init,
        conv=None,
        batch_norm=None,
        bias=True,
        preact=False,
        name="",
        instance_norm=False,
        instance_norm_func=None
    ):
        super().__init__()

        # The convolution only keeps its bias when batch norm is disabled.
        use_bias = bias and (not bn)
        conv_layer = conv(
            in_size,
            out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=use_bias,
        )
        init(conv_layer.weight)
        if use_bias:
            nn.init.constant_(conv_layer.bias, 0)

        # Normalisation sees the conv input when pre-activating, else its output.
        norm_channels = in_size if preact else out_size

        bn_layer = None
        if bn:
            bn_layer = batch_norm(norm_channels)

        in_layer = None
        if instance_norm:
            in_layer = instance_norm_func(
                norm_channels, affine=False, track_running_stats=False
            )

        # Extras keep the same relative order on either side of the conv.
        extras = []
        if bn:
            extras.append(("bn", bn_layer))
        if activation is not None:
            extras.append(("activation", activation))
        if not bn and instance_norm:
            extras.append(("in", in_layer))

        conv_entry = ("conv", conv_layer)
        if preact:
            ordered = extras + [conv_entry]
        else:
            ordered = [conv_entry] + extras

        for suffix, module in ordered:
            self.add_module(name + suffix, module)
class _BNBase(nn.Sequential):
    """Sequential wrapper around a single batch-norm layer.

    The layer is built by calling ``batch_norm(in_size)``, registered as
    ``name + 'bn'``, and its weight and bias are reset to 1 and 0.
    """

    def __init__(self, in_size, batch_norm=None, name=""):
        super().__init__()

        norm_layer = batch_norm(in_size)
        self.add_module(name + "bn", norm_layer)

        # Explicitly start from the identity affine transform.
        nn.init.constant_(norm_layer.weight, 1.0)
        nn.init.constant_(norm_layer.bias, 0)
class BatchNorm1d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm1d``; ``name`` is keyword-only."""

    def __init__(self, in_size: int, *, name: str = ""):
        norm_cls = nn.BatchNorm1d
        super().__init__(in_size, norm_cls, name)
class BatchNorm2d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm2d``."""

    def __init__(self, in_size: int, name: str = ""):
        norm_cls = nn.BatchNorm2d
        super().__init__(in_size, norm_cls, name)
class Conv1d(_ConvBase):
    """1-D variant of ``_ConvBase``.

    Wires ``nn.Conv1d``, this file's ``BatchNorm1d`` wrapper and
    ``nn.InstanceNorm1d`` into the base class; every other argument is
    forwarded unchanged. All options after ``out_size`` are keyword-only.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=nn.init.kaiming_normal_,
        bias: bool = True,
        preact: bool = False,
        name: str = "",
        instance_norm=False
    ):
        # Everything is forwarded by keyword so the mapping onto the base
        # class parameters is explicit.
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm1d,
        )
class Conv2d(_ConvBase):
    """2-D variant of ``_ConvBase``.

    Wires ``nn.Conv2d``, this file's ``BatchNorm2d`` wrapper and
    ``nn.InstanceNorm2d`` into the base class; every other argument is
    forwarded unchanged. The default kernel is ``(1, 1)`` with unit stride
    and no padding. All options after ``out_size`` are keyword-only.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        kernel_size: Tuple[int, int] = (1, 1),
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=nn.init.kaiming_normal_,
        bias: bool = True,
        preact: bool = False,
        name: str = "",
        instance_norm=False
    ):
        # Everything is forwarded by keyword so the mapping onto the base
        # class parameters is explicit.
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm2d,
        )
class FC(nn.Sequential):
    """Linear layer with optional batch norm and activation.

    Sub-modules are registered as ``name + 'bn'``, ``name + 'activation'``
    and ``name + 'fc'``. With ``preact`` the batch norm (sized ``in_size``)
    and activation come before the linear layer; otherwise they follow it
    and the batch norm is sized ``out_size``.

    Args:
        in_size: Input feature count of the linear layer.
        out_size: Output feature count of the linear layer.
        activation: Module to register, or ``None`` for no activation.
            The default is a single ``nn.ReLU`` instance shared by every
            ``FC`` that does not pass its own.
        bn: Add this file's ``BatchNorm1d`` wrapper; the linear layer is
            then created without a bias.
        init: Optional callable applied to the linear weight.
        preact: Put batch norm / activation before the linear layer.
        name: Prefix for every registered sub-module name.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        *,
        activation=nn.ReLU(inplace=True),
        bn: bool = False,
        init=None,
        preact: bool = False,
        name: str = ""
    ):
        super().__init__()

        # The linear layer only carries a bias when batch norm is disabled.
        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            # In-place initialiser, matching the other blocks in this file;
            # the un-suffixed ``nn.init.constant`` is deprecated.
            nn.init.constant_(fc.bias, 0)

        if preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(in_size))
            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'fc', fc)

        if not preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(out_size))
            if activation is not None:
                self.add_module(name + 'activation', activation)