Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as f | |
from torch.nn import init | |
import math | |
class ConvLayer(nn.Module):
    """2-D convolution followed by optional normalization and an optional activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        activation: name of a function in the ``torch`` namespace (e.g. ``'relu'``)
            applied last, or ``None`` to skip the activation.
        norm: ``'BN'`` -> ``nn.BatchNorm2d``; ``'IN'`` -> ``nn.InstanceNorm2d`` with
            running statistics; any other value -> no normalization.
        BN_momentum: momentum of the batch-norm running statistics (BN only).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None,
                 BN_momentum=0.1):
        super(ConvLayer, self).__init__()
        # Batch norm has its own shift term, which makes a conv bias redundant.
        use_bias = norm != 'BN'
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=use_bias)
        self.activation = None if activation is None else getattr(torch, activation)
        self.norm = norm
        if norm == 'BN':
            self.norm_layer = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
        elif norm == 'IN':
            self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)

    def forward(self, x):
        y = self.conv2d(x)
        if self.norm in ('BN', 'IN'):
            y = self.norm_layer(y)
        return y if self.activation is None else self.activation(y)
class TransposedConvLayer(nn.Module):
    """Stride-2 transposed convolution followed by optional normalization and activation.

    The output spatial size is ``2 * H + (kernel_size - 2 * padding - 1)``, i.e.
    exactly ``2 * H`` when ``kernel_size == 2 * padding + 1``.

    NOTE(review): ``stride`` is accepted but ignored; the transposed convolution
    always uses ``stride=2`` and ``output_padding=1``. This is kept as-is because
    callers may rely on the default ``stride=1`` still upsampling by 2.

    Args:
        in_channels, out_channels, kernel_size, padding: forwarded to ``nn.ConvTranspose2d``.
        activation: name of a ``torch`` function applied last, or ``None``.
        norm: ``'BN'``, ``'IN'`` (with running statistics), or anything else for none.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
        super(TransposedConvLayer, self).__init__()
        # Batch norm provides the shift, so the conv bias is dropped in that case.
        use_bias = norm != 'BN'
        self.transposed_conv2d = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size,
            stride=2, padding=padding, output_padding=1, bias=use_bias,
        )
        self.activation = getattr(torch, activation) if activation is not None else None
        self.norm = norm
        if norm == 'BN':
            self.norm_layer = nn.BatchNorm2d(out_channels)
        elif norm == 'IN':
            self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)

    def forward(self, x):
        y = self.transposed_conv2d(x)
        if self.norm in ('BN', 'IN'):
            y = self.norm_layer(y)
        if self.activation is not None:
            y = self.activation(y)
        return y
class UpsampleConvLayer(nn.Module):
    """Bilinear 2x upsampling, then a convolution, optional normalization and activation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        activation: name of a ``torch`` function applied last, or ``None``.
        norm: ``'BN'``, ``'IN'`` (with running statistics), or anything else for none.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
        super(UpsampleConvLayer, self).__init__()
        use_bias = norm != 'BN'  # BN's shift makes a conv bias redundant
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=use_bias)
        self.activation = getattr(torch, activation) if activation is not None else None
        self.norm = norm
        if norm == 'BN':
            self.norm_layer = nn.BatchNorm2d(out_channels)
        elif norm == 'IN':
            self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)

    def forward(self, x):
        upsampled = f.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        y = self.conv2d(upsampled)
        if self.norm in ('BN', 'IN'):
            y = self.norm_layer(y)
        return y if self.activation is None else self.activation(y)
class RecurrentConvLayer(nn.Module):
    """A ConvLSTM or ConvGRU step applied directly to the input.

    NOTE(review): this layer has no leading convolution. ``in_channels``,
    ``kernel_size``, ``stride``, ``padding``, ``activation``, ``norm`` and
    ``BN_momentum`` are accepted only for signature compatibility and are unused.
    The recurrent block is built with ``input_size=out_channels``, so the input must
    already have ``out_channels`` channels.

    ``forward(x, prev_state)`` returns ``(output, state)``. For ``'convlstm'``, ``state``
    is the ``(hidden, cell)`` tuple and ``output`` is its hidden tensor. For
    ``'convgru'``, ``state`` is the new hidden tensor and ``output`` is that same tensor.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                 recurrent_block_type='convlstm', activation='relu', norm=None, BN_momentum=0.1):
        super(RecurrentConvLayer, self).__init__()
        assert recurrent_block_type in ['convlstm', 'convgru']
        self.recurrent_block_type = recurrent_block_type
        block_cls = ConvLSTM if recurrent_block_type == 'convlstm' else ConvGRU
        self.recurrent_block = block_cls(input_size=out_channels, hidden_size=out_channels, kernel_size=3)

    def forward(self, x, prev_state):
        new_state = self.recurrent_block(x, prev_state)
        is_lstm = self.recurrent_block_type == 'convlstm'
        output = new_state[0] if is_lstm else new_state
        return output, new_state
class Recurrent2ConvLayer(nn.Module):
    """``ConvLayer`` followed by a ConvLSTM or ConvGRU step.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, activation, norm,
        BN_momentum: forwarded to ``ConvLayer``.
        recurrent_block_type: ``'convlstm'`` or ``'convgru'``. The recurrent block maps
            ``out_channels`` to ``out_channels`` with a 3x3 kernel.

    ``forward(x, prev_state)`` returns ``(output, state)``. For ``'convlstm'``, ``state``
    is the ``(hidden, cell)`` tuple and ``output`` is its hidden tensor. For
    ``'convgru'``, ``output`` and ``state`` are the same tensor.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                 recurrent_block_type='convlstm', activation='relu', norm=None, BN_momentum=0.1):
        super(Recurrent2ConvLayer, self).__init__()
        assert recurrent_block_type in ['convlstm', 'convgru']
        self.recurrent_block_type = recurrent_block_type
        block_cls = ConvLSTM if recurrent_block_type == 'convlstm' else ConvGRU
        self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding, activation, norm,
                              BN_momentum=BN_momentum)
        self.recurrent_block = block_cls(input_size=out_channels, hidden_size=out_channels, kernel_size=3)

    def forward(self, x, prev_state):
        features = self.conv(x)
        new_state = self.recurrent_block(features, prev_state)
        if self.recurrent_block_type == 'convlstm':
            return new_state[0], new_state
        return new_state, new_state
class RecurrentPhasedConvLayer(nn.Module):
    """``ConvLayer`` followed by a time-gated ConvLSTM step (``PhasedConvLSTMCell``).

    Args:
        in_channels, out_channels, kernel_size, stride, padding, activation, norm,
        BN_momentum: forwarded to ``ConvLayer``.

    ``forward(x, times, prev_state)`` returns the ``(output, state)`` pair produced by
    the phased cell; ``times`` are the per-sample timestamps it gates on.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                 activation='relu', norm=None, BN_momentum=0.1):
        super(RecurrentPhasedConvLayer, self).__init__()
        self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding, activation, norm,
                              BN_momentum=BN_momentum)
        self.recurrent_block = PhasedConvLSTMCell(
            input_channels=out_channels,
            hidden_channels=out_channels,
            kernel_size=3,
        )

    def forward(self, x, times, prev_state):
        features = self.conv(x)
        output, new_state = self.recurrent_block(features, times, prev_state)
        return output, new_state
class DownsampleRecurrentConvLayer(nn.Module):
    """Recurrent step at full resolution, then bilinear 0.5x downsampling and activation.

    NOTE(review): ``padding`` is accepted but unused.

    Args:
        in_channels: channels of the input (the recurrent block's ``input_size``).
        out_channels: hidden channels of the recurrent block.
        kernel_size: recurrent block kernel size.
        recurrent_block_type: ``'convlstm'`` or ``'convgru'``.
        activation: name of a ``torch`` function applied to the downsampled output.

    ``forward(x, prev_state)`` returns ``(activated downsampled output, state)``. The
    state stays at the input's spatial resolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, recurrent_block_type='convlstm', padding=0, activation='relu'):
        super(DownsampleRecurrentConvLayer, self).__init__()
        self.activation = getattr(torch, activation)
        assert recurrent_block_type in ['convlstm', 'convgru']
        self.recurrent_block_type = recurrent_block_type
        block_cls = ConvLSTM if recurrent_block_type == 'convlstm' else ConvGRU
        self.recurrent_block = block_cls(input_size=in_channels, hidden_size=out_channels, kernel_size=kernel_size)

    def forward(self, x, prev_state):
        new_state = self.recurrent_block(x, prev_state)
        full_res = new_state[0] if self.recurrent_block_type == 'convlstm' else new_state
        half_res = f.interpolate(full_res, scale_factor=0.5, mode='bilinear', align_corners=False)
        return self.activation(half_res), new_state
# Residual block
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with optional normalization and a skip connection.

    Computes ``relu(norm2(conv2(relu(norm1(conv1(x))))) + skip)``. ``skip`` is ``x``
    itself, or ``downsample(x)`` when a ``downsample`` module is given.

    Args:
        in_channels: input channels of the first convolution.
        out_channels: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module producing the skip branch from ``x``. It must be
            supplied when ``stride != 1`` or ``in_channels != out_channels``;
            otherwise the residual addition sees mismatched shapes.
        norm: ``'BN'`` (batch norm), ``'IN'`` (instance norm without running
            statistics), or anything else for none.
        BN_momentum: momentum of the batch-norm running statistics (BN only).
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm=None,
                 BN_momentum=0.1):
        super(ResidualBlock, self).__init__()
        # Batch norm supplies the shift, so the conv biases are dropped in that case.
        bias = False if norm == 'BN' else True
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=bias)
        self.norm = norm
        if norm == 'BN':
            self.bn1 = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
            self.bn2 = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
        elif norm == 'IN':
            self.bn1 = nn.InstanceNorm2d(out_channels)
            self.bn2 = nn.InstanceNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias)
        self.downsample = downsample

    def forward(self, x):
        out = self.conv1(x)
        if self.norm in ('BN', 'IN'):
            out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        if self.norm in ('BN', 'IN'):
            out = self.bn2(out)
        # Compare against None explicitly. Container modules such as nn.Sequential
        # define __len__, so a plain truthiness test would silently skip an empty
        # (len-0) container instead of calling it.
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return self.relu(out)
class PhasedLSTMCell(nn.Module):
    """Phased LSTM time gate applied to flattened recurrent states.

    The module holds only the learnable time-gate parameters, one per unit:
    ``tau`` (period) and ``phase`` (shift). It does not compute LSTM updates itself.
    ``forward`` blends externally computed new states with the previous states
    registered via ``set_state``, using a time-dependent openness ``k``.
    """

    def __init__(
        self,
        hidden_size,
        leak=0.001,
        ratio_on=0.1,
        period_init_min=0.02,
        period_init_max=50.0
    ):
        """
        Args:
            hidden_size: int, the number of units gated by the cell.
            leak: float in [0, 1]. Slope of the gate while it is closed. It is
                applied in both training and evaluation, because the code does not
                check ``self.training``.
            ratio_on: float in [0, 1]. Fraction of the period during which the
                gate is open.
            period_init_min: float > 0. Minimum initial period. Periods are drawn
                from e^U(log(period_init_min), log(period_init_max)), where
                U(.,.) is the uniform distribution.
            period_init_max: float > period_init_min. Maximum initial period.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.ratio_on = ratio_on
        self.leak = leak
        # Log-uniform initial periods, one per unit.
        period = torch.exp(
            torch.Tensor(hidden_size).uniform_(
                math.log(period_init_min), math.log(period_init_max)
            )
        )
        self.register_parameter("tau", nn.Parameter(period))
        # Initial phase shifts are uniform in [0, period). Parameters already
        # require gradients by default.
        phase = torch.Tensor(hidden_size).uniform_() * period
        self.register_parameter("phase", nn.Parameter(phase))

    def _compute_phi(self, t):
        """Return the phase ratio ``|fmod(t - phase, tau)| / tau`` per (sample, unit).

        Args:
            t: timestamps, one per sample; flattened to shape (B, 1).
        Returns:
            Tensor of shape (B, hidden_size); values lie in [0, 1) while tau > 0.
        """
        t_ = t.reshape(-1, 1)
        # ``Tensor.to`` is not in-place, so its result must be used. Broadcasting
        # the (hidden_size,) parameters against (B, 1) replaces explicit repeats.
        tau = self.tau.to(t_.device)
        phase = self.phase.to(t_.device)
        phi = self._mod(t_ - phase, tau)
        return torch.abs(phi) / tau

    def _mod(self, x, y):
        """Modulo function that propagates x gradients."""
        # The value equals fmod(x, y); the gradient flows as if the result were x.
        return x + (torch.fmod(x, y) - x).detach()

    def set_state(self, c, h):
        """Register the previous cell (``c``) and hidden (``h``) states.

        Must be called before ``forward``, which blends towards these states.
        """
        self.h0 = h
        self.c0 = c

    def forward(self, c_s, h_s, t):
        """Blend new states with the states registered by ``set_state``.

        Args:
            c_s: new cell state, expected to be (B, hidden_size).
            h_s: new hidden state, expected to be (B, hidden_size).
            t: timestamps, one per sample.
        Returns:
            ``(h_new, c_new)``: the hidden state comes first, unlike the argument
            order.
        """
        phi = self._compute_phi(t)
        # Piecewise-linear openness. k rises 0 -> 1 over the first half of the open
        # phase, falls 1 -> 0 over the second half, and leaks while closed.
        k_up = 2 * phi / self.ratio_on
        k_down = 2 - k_up
        k_closed = self.leak * phi
        k = torch.where(phi < self.ratio_on, k_down, k_closed)
        k = torch.where(phi < 0.5 * self.ratio_on, k_up, k)
        k = k.view(c_s.shape[0], -1)
        c_s_new = k * c_s + (1 - k) * self.c0
        h_s_new = k * h_s + (1 - k) * self.h0
        return h_s_new, c_s_new
class ConvLSTM(nn.Module):
    """Convolutional LSTM cell.

    Adapted from: https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py

    A single convolution over the channel-wise concatenation [input, previous hidden]
    produces all four gates.

    Args:
        input_size: number of input channels.
        hidden_size: number of hidden (and cell) channels.
        kernel_size: gate convolution kernel size. Padding is ``kernel_size // 2``,
            so odd sizes preserve the spatial dimensions.
    """

    def __init__(self, input_size, hidden_size, kernel_size):
        super(ConvLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        pad = kernel_size // 2
        # Zero initial states are cached so repeated state-less calls do not
        # reallocate. Keyed by (shape, dtype, device) so a tensor cached for one
        # dtype/device is never reused after the module or its inputs move.
        self.zero_tensors = {}
        self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad)

    def _zero_state(self, input_):
        """Return a cached ``(hidden, cell)`` pair of zeros matching ``input_``.

        The zeros match ``input_``'s batch size, spatial size, dtype and device.
        """
        state_size = (input_.shape[0], self.hidden_size, *input_.shape[2:])
        key = (state_size, input_.dtype, input_.device)
        if key not in self.zero_tensors:
            self.zero_tensors[key] = (
                torch.zeros(state_size, dtype=input_.dtype, device=input_.device),
                torch.zeros(state_size, dtype=input_.dtype, device=input_.device),
            )
        return self.zero_tensors[key]

    def forward(self, input_, prev_state=None):
        """Run one recurrent step.

        Args:
            input_: tensor of shape [batch, input_size, height, width].
            prev_state: ``(hidden, cell)`` tuple from the previous step, or ``None``
                to start from zeros.
        Returns:
            ``(hidden, cell)`` for this step.
        """
        if prev_state is None:
            prev_state = self._zero_state(input_)
        prev_hidden, prev_cell = prev_state

        gates = self.Gates(torch.cat((input_, prev_hidden), 1))
        # Split the gate activations across the channel dimension.
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)
        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)
        cell_gate = torch.tanh(cell_gate)

        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)
        return hidden, cell
class PhasedConvLSTMCell(nn.Module):
    """ConvLSTM step whose new state is time-gated by a ``PhasedLSTMCell``.

    The phased cell is created lazily on the first ``forward`` call, because its
    number of units (``hidden_channels * H * W``) depends on the input's spatial size.

    NOTE(review): because of this lazy creation, the gate parameters ``tau`` and
    ``phase`` do not exist until the first forward pass. An optimizer built before
    then will not see them. The spatial size must also stay fixed after the first
    call.
    """

    def __init__(
        self,
        input_channels,
        hidden_channels,
        kernel_size=3
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.lstm = ConvLSTM(
            input_size=input_channels,
            hidden_size=hidden_channels,
            kernel_size=kernel_size
        )
        # Instantiated on the first forward call, once the spatial size is known.
        self.phased_cell = None
        self.hidden_size = None

    def forward(self, input, times, prev_state=None):
        """Run one time-gated recurrent step.

        Args:
            input: tensor of shape B x C x H x W.
            times: timestamps, one per sample (B).
            prev_state: ``(cell, hidden)`` tuple as returned by a previous call,
                each B x C_out x H x W, or ``None`` to start from zeros.
        Returns:
            ``(output, (cell, hidden))``. ``output`` is the ConvLSTM hidden state of
            this step before time gating; the tuple is the time-gated state to pass
            back as ``prev_state``.
            NOTE(review): Phased LSTM formulations usually emit the gated hidden
            state; confirm that the ungated output is intended.
        """
        B, _, H, W = input.shape
        if self.hidden_size is None:
            self.hidden_size = self.hidden_channels * W * H
            self.phased_cell = PhasedLSTMCell(hidden_size=self.hidden_size)
        # Module.to is idempotent when the device already matches.
        self.phased_cell = self.phased_cell.to(input.device)

        if prev_state is None:
            h0 = input.new_zeros((B, self.hidden_channels, H, W))
            c0 = input.new_zeros((B, self.hidden_channels, H, W))
        else:
            c0, h0 = prev_state
        # The phased cell leaks towards these previous states while its gate is closed.
        self.phased_cell.set_state(c0.reshape(B, -1), h0.reshape(B, -1))
        # ConvLSTM takes and returns (hidden, cell), in that order.
        h_t, c_t = self.lstm(input, (h0, c0))
        # PhasedLSTMCell takes (cell, hidden, times) but returns (hidden, cell).
        h_s, c_s = self.phased_cell(c_t.reshape(B, -1), h_t.reshape(B, -1), times)
        # Restore the spatial layout for the next ConvLSTM step.
        c_s = c_s.view(B, -1, H, W)
        h_s = h_s.view(B, -1, H, W)
        return h_t, (c_s, h_s)
class ConvGRU(nn.Module):
    """Convolutional GRU cell.

    Adapted from: https://github.com/jacobkimmel/pytorch_convgru/blob/master/convgru.py

    Each gate is its own convolution over [input, state] (or [input, reset * state]
    for the candidate). Gate weights are orthogonally initialized and biases are
    zeroed.

    Args:
        input_size: number of input channels.
        hidden_size: number of state channels.
        kernel_size: gate kernel size, padded by ``kernel_size // 2``.
    """

    def __init__(self, input_size, hidden_size, kernel_size):
        super().__init__()
        padding = kernel_size // 2
        self.input_size = input_size
        self.hidden_size = hidden_size
        joint_channels = input_size + hidden_size
        self.reset_gate = nn.Conv2d(joint_channels, hidden_size, kernel_size, padding=padding)
        self.update_gate = nn.Conv2d(joint_channels, hidden_size, kernel_size, padding=padding)
        self.out_gate = nn.Conv2d(joint_channels, hidden_size, kernel_size, padding=padding)
        gate_convs = (self.reset_gate, self.update_gate, self.out_gate)
        for conv in gate_convs:
            init.orthogonal_(conv.weight)
        for conv in gate_convs:
            init.constant_(conv.bias, 0.)

    def forward(self, input_, prev_state):
        """Return the new state for ``input_`` (shape [batch, channel, height, width]).

        ``prev_state`` of ``None`` means a zero initial state.
        """
        if prev_state is None:
            zeros_shape = [input_.shape[0], self.hidden_size, *input_.shape[2:]]
            prev_state = torch.zeros(zeros_shape, dtype=input_.dtype).to(input_.device)
        joint = torch.cat([input_, prev_state], dim=1)
        update = torch.sigmoid(self.update_gate(joint))
        reset = torch.sigmoid(self.reset_gate(joint))
        candidate = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
        # Interpolate between keeping the old state and taking the candidate.
        return prev_state * (1 - update) + candidate * update
class RecurrentResidualLayer(nn.Module):
    """``ResidualBlock`` followed by a ConvLSTM or ConvGRU step.

    Args:
        in_channels, out_channels, norm, BN_momentum: forwarded to ``ResidualBlock``.
        recurrent_block_type: ``'convlstm'`` or ``'convgru'``. The recurrent block maps
            ``out_channels`` to ``out_channels`` with a 3x3 kernel.

    ``forward(x, prev_state)`` returns ``(output, state)``. For ``'convlstm'``, ``state``
    is the ``(hidden, cell)`` tuple and ``output`` is its hidden tensor. For
    ``'convgru'``, both are the same tensor.
    """

    def __init__(self, in_channels, out_channels,
                 recurrent_block_type='convlstm', norm=None, BN_momentum=0.1):
        super(RecurrentResidualLayer, self).__init__()
        assert recurrent_block_type in ['convlstm', 'convgru']
        self.recurrent_block_type = recurrent_block_type
        block_cls = ConvLSTM if recurrent_block_type == 'convlstm' else ConvGRU
        self.conv = ResidualBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            norm=norm,
            BN_momentum=BN_momentum,
        )
        self.recurrent_block = block_cls(
            input_size=out_channels,
            hidden_size=out_channels,
            kernel_size=3,
        )

    def forward(self, x, prev_state):
        features = self.conv(x)
        new_state = self.recurrent_block(features, prev_state)
        output = new_state[0] if self.recurrent_block_type == 'convlstm' else new_state
        return output, new_state