|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as f |
|
from torch.nn import init |
|
import math |
|
|
|
|
|
class ConvLayer(nn.Module):
    """Conv2d followed by optional normalization and an optional activation.

    The convolution bias is dropped when batch norm is used, since BN's
    learned shift makes it redundant.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None,
                 BN_momentum=0.1):
        super(ConvLayer, self).__init__()

        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                                bias=(norm != 'BN'))

        # Resolve the activation by name from the torch namespace (e.g. torch.relu).
        self.activation = getattr(torch, activation) if activation is not None else None

        self.norm = norm
        if norm == 'BN':
            self.norm_layer = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
        elif norm == 'IN':
            self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)

    def forward(self, x):
        out = self.conv2d(x)

        if self.norm in ('BN', 'IN'):
            out = self.norm_layer(out)

        if self.activation is not None:
            out = self.activation(out)

        return out
|
|
|
|
|
class TransposedConvLayer(nn.Module):
    """2x upsampling via transposed convolution, with optional norm/activation.

    NOTE(review): the `stride` parameter is accepted but ignored — the
    transposed convolution is hard-wired to stride=2 with output_padding=1.
    Confirm whether callers ever pass a different stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
        super(TransposedConvLayer, self).__init__()

        self.transposed_conv2d = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride=2, padding=padding, output_padding=1,
            bias=(norm != 'BN'))

        # Resolve the activation by name from the torch namespace.
        self.activation = getattr(torch, activation) if activation is not None else None

        self.norm = norm
        if norm == 'BN':
            self.norm_layer = nn.BatchNorm2d(out_channels)
        elif norm == 'IN':
            self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)

    def forward(self, x):
        out = self.transposed_conv2d(x)

        if self.norm in ('BN', 'IN'):
            out = self.norm_layer(out)

        if self.activation is not None:
            out = self.activation(out)

        return out
|
|
|
|
|
class UpsampleConvLayer(nn.Module):
    """Bilinear 2x upsampling followed by a convolution.

    This is the checkerboard-free alternative to a transposed convolution.
    Optional normalization and activation mirror ConvLayer.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation='relu', norm=None):
        super(UpsampleConvLayer, self).__init__()

        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                                bias=(norm != 'BN'))

        # Resolve the activation by name from the torch namespace.
        self.activation = getattr(torch, activation) if activation is not None else None

        self.norm = norm
        if norm == 'BN':
            self.norm_layer = nn.BatchNorm2d(out_channels)
        elif norm == 'IN':
            self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)

    def forward(self, x):
        # Double the spatial resolution, then convolve.
        out = self.conv2d(f.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False))

        if self.norm in ('BN', 'IN'):
            out = self.norm_layer(out)

        if self.activation is not None:
            out = self.activation(out)

        return out
|
|
|
|
|
class RecurrentConvLayer(nn.Module):
    """A bare recurrent block (ConvLSTM or ConvGRU) applied directly to the input.

    Unlike Recurrent2ConvLayer there is no preceding ConvLayer, so the
    kernel_size/stride/padding/activation/norm/BN_momentum arguments are kept
    only for signature compatibility and are unused here.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                 recurrent_block_type='convlstm', activation='relu', norm=None, BN_momentum=0.1):
        super(RecurrentConvLayer, self).__init__()

        assert recurrent_block_type in ['convlstm', 'convgru']
        self.recurrent_block_type = recurrent_block_type
        if self.recurrent_block_type == 'convlstm':
            RecurrentBlock = ConvLSTM
        else:
            RecurrentBlock = ConvGRU

        # BUG FIX: forward() feeds `x` directly into the recurrent block, so its
        # input size must be in_channels (it was out_channels, which crashes on a
        # channel mismatch whenever in_channels != out_channels and left
        # in_channels unused). Presumably all existing call sites use
        # in_channels == out_channels, in which case this change is a no-op —
        # confirm against callers.
        self.recurrent_block = RecurrentBlock(input_size=in_channels, hidden_size=out_channels, kernel_size=3)

    def forward(self, x, prev_state):
        """Return (output, state); for 'convlstm' the state is (hidden, cell)."""
        state = self.recurrent_block(x, prev_state)
        x = state[0] if self.recurrent_block_type == 'convlstm' else state
        return x, state
|
|
|
class Recurrent2ConvLayer(nn.Module):
    """ConvLayer followed by a recurrent block (ConvLSTM or ConvGRU)."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                 recurrent_block_type='convlstm', activation='relu', norm=None, BN_momentum=0.1):
        super(Recurrent2ConvLayer, self).__init__()

        assert recurrent_block_type in ['convlstm', 'convgru']
        self.recurrent_block_type = recurrent_block_type
        recurrent_cls = ConvLSTM if recurrent_block_type == 'convlstm' else ConvGRU

        self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding,
                              activation, norm, BN_momentum=BN_momentum)
        self.recurrent_block = recurrent_cls(input_size=out_channels, hidden_size=out_channels, kernel_size=3)

    def forward(self, x, prev_state):
        features = self.conv(x)
        state = self.recurrent_block(features, prev_state)
        # ConvLSTM returns (hidden, cell); ConvGRU returns the hidden state directly.
        out = state[0] if self.recurrent_block_type == 'convlstm' else state
        return out, state
|
|
|
|
|
class RecurrentPhasedConvLayer(nn.Module):
    """ConvLayer followed by a phased convolutional LSTM cell."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0,
                 activation='relu', norm=None, BN_momentum=0.1):
        super(RecurrentPhasedConvLayer, self).__init__()

        self.conv = ConvLayer(in_channels, out_channels, kernel_size, stride, padding,
                              activation, norm, BN_momentum=BN_momentum)
        self.recurrent_block = PhasedConvLSTMCell(input_channels=out_channels,
                                                  hidden_channels=out_channels,
                                                  kernel_size=3)

    def forward(self, x, times, prev_state):
        # `times` carries the per-sample timestamps consumed by the phased cell.
        features = self.conv(x)
        out, state = self.recurrent_block(features, times, prev_state)
        return out, state
|
|
|
|
|
class DownsampleRecurrentConvLayer(nn.Module):
    """Recurrent block followed by 2x bilinear downsampling and an activation."""

    def __init__(self, in_channels, out_channels, kernel_size=3, recurrent_block_type='convlstm', padding=0, activation='relu'):
        super(DownsampleRecurrentConvLayer, self).__init__()

        # Resolve the activation by name from the torch namespace.
        self.activation = getattr(torch, activation)

        assert recurrent_block_type in ['convlstm', 'convgru']
        self.recurrent_block_type = recurrent_block_type
        recurrent_cls = ConvLSTM if recurrent_block_type == 'convlstm' else ConvGRU
        self.recurrent_block = recurrent_cls(input_size=in_channels, hidden_size=out_channels, kernel_size=kernel_size)

    def forward(self, x, prev_state):
        state = self.recurrent_block(x, prev_state)
        # ConvLSTM returns (hidden, cell); ConvGRU returns the hidden state directly.
        out = state[0] if self.recurrent_block_type == 'convlstm' else state
        # Halve the spatial resolution before the activation.
        out = f.interpolate(out, scale_factor=0.5, mode='bilinear', align_corners=False)
        return self.activation(out), state
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """ResNet-style basic block: two 3x3 convolutions with optional
    normalization and an identity (or user-supplied downsampled) skip
    connection."""

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, norm=None,
                 BN_momentum=0.1):
        super(ResidualBlock, self).__init__()
        # Bias is redundant under batch norm.
        use_bias = norm != 'BN'
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=use_bias)
        self.norm = norm
        if norm == 'BN':
            self.bn1 = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
            self.bn2 = nn.BatchNorm2d(out_channels, momentum=BN_momentum)
        elif norm == 'IN':
            self.bn1 = nn.InstanceNorm2d(out_channels)
            self.bn2 = nn.InstanceNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.downsample = downsample

    def forward(self, x):
        out = self.conv1(x)
        if self.norm in ('BN', 'IN'):
            out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        if self.norm in ('BN', 'IN'):
            out = self.bn2(out)

        # Project the shortcut when the caller supplied a downsample module.
        shortcut = self.downsample(x) if self.downsample else x
        out = out + shortcut
        return self.relu(out)
|
|
|
|
|
class PhasedLSTMCell(nn.Module):
    """Phased LSTM recurrent network cell (time gate).

    Each hidden unit carries a learned oscillation period (``tau``) and phase
    offset (``phase``). The gate k(t) is fully open only during a fraction
    ``ratio_on`` of each period and leaks slightly (``leak``) while closed;
    forward() blends proposed states with the stored previous states using k.
    """

    def __init__(
        self,
        hidden_size,
        leak=0.001,
        ratio_on=0.1,
        period_init_min=0.02,
        period_init_max=50.0
    ):
        """
        Args:
            hidden_size: int, The number of units in the Phased LSTM cell.
            leak: float or scalar float Tensor with value in [0, 1]. Leak applied
                during training.
            ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
                period during which the gates are open.
            period_init_min: float or scalar float Tensor. With value > 0.
                Minimum value of the initialized period.
                The period values are initialized by drawing from the distribution:
                e^U(log(period_init_min), log(period_init_max))
                Where U(.,.) is the uniform distribution.
            period_init_max: float or scalar float Tensor.
                With value > period_init_min. Maximum value of the initialized period.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.ratio_on = ratio_on
        self.leak = leak

        # Periods drawn log-uniformly from [period_init_min, period_init_max].
        period = torch.exp(
            torch.Tensor(hidden_size).uniform_(
                math.log(period_init_min), math.log(period_init_max)
            )
        )
        self.register_parameter("tau", nn.Parameter(period))

        # Phase offsets drawn uniformly within each unit's period.
        # (nn.Parameter has requires_grad=True by default, so no explicit
        # requires_grad assignments are needed.)
        phase = torch.Tensor(hidden_size).uniform_() * period
        self.register_parameter("phase", nn.Parameter(phase))

    def _compute_phi(self, t):
        """Return the in-period phase in [0, 1) for each unit at times ``t``."""
        t_ = t.view(-1, 1).repeat(1, self.hidden_size)
        phase_ = self.phase.view(1, -1).repeat(t.shape[0], 1)
        tau_ = self.tau.view(1, -1).repeat(t.shape[0], 1)
        # BUG FIX: Tensor.to() is not in-place; previously the results of
        # tau_.to(...) / phase_.to(...) were discarded, so the tensors were
        # never actually moved to t's device.
        tau_ = tau_.to(t_.device)
        phase_ = phase_.to(t_.device)
        phi = self._mod((t_ - phase_), tau_)
        phi = torch.abs(phi) / tau_
        return phi

    def _mod(self, x, y):
        """Modulo function that propagates x gradients."""
        return x + (torch.fmod(x, y) - x).detach()

    def set_state(self, c, h):
        # Store the previous (cell, hidden) states that forward() blends against.
        self.h0 = h
        self.c0 = c

    def forward(self, c_s, h_s, t):
        """Blend proposed states (c_s, h_s) with the stored (c0, h0).

        set_state() must have been called beforehand. Returns
        (h_s_new, c_s_new).
        """
        phi = self._compute_phi(t)

        # Piecewise-linear time gate k(phi):
        #   rises 0 -> 1 over the first half of the open window,
        #   falls 1 -> 0 over the second half,
        #   leaks (leak * phi) while closed.
        k_up = 2 * phi / self.ratio_on
        k_down = 2 - k_up
        k_closed = self.leak * phi

        k = torch.where(phi < self.ratio_on, k_down, k_closed)
        k = torch.where(phi < 0.5 * self.ratio_on, k_up, k)
        k = k.view(c_s.shape[0], -1)
        c_s_new = k * c_s + (1 - k) * self.c0
        h_s_new = k * h_s + (1 - k) * self.h0

        return h_s_new, c_s_new
|
|
|
|
|
class ConvLSTM(nn.Module):
    """Convolutional LSTM cell.

    Adapted from: https://github.com/Atcold/pytorch-CortexNet/blob/master/model/ConvLSTMCell.py
    """

    def __init__(self, input_size, hidden_size, kernel_size):
        super(ConvLSTM, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        pad = kernel_size // 2

        # Cache of zero (hidden, cell) pairs keyed by state shape, so repeated
        # forwards with prev_state=None reuse the same zero tensors.
        self.zero_tensors = {}

        # One convolution emits all four gate pre-activations at once.
        self.Gates = nn.Conv2d(input_size + hidden_size, 4 * hidden_size, kernel_size, padding=pad)

    def forward(self, input_, prev_state=None):
        batch = input_.data.size()[0]
        spatial = input_.data.size()[2:]

        if prev_state is None:
            shape = tuple([batch, self.hidden_size] + list(spatial))
            if shape not in self.zero_tensors:
                zeros = torch.zeros(shape, dtype=input_.dtype).to(input_.device)
                self.zero_tensors[shape] = (zeros, zeros.clone())
            prev_state = self.zero_tensors[shape]

        prev_hidden, prev_cell = prev_state

        # Stack input with previous hidden state along channels and compute gates.
        gates = self.Gates(torch.cat((input_, prev_hidden), 1))
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)

        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)
        cell_gate = torch.tanh(cell_gate)

        # Standard LSTM state update.
        cell = remember_gate * prev_cell + in_gate * cell_gate
        hidden = out_gate * torch.tanh(cell)

        return hidden, cell
|
|
|
|
|
class PhasedConvLSTMCell(nn.Module):
    """Convolutional LSTM whose state update is modulated by a PhasedLSTMCell
    time gate: the spatial ConvLSTM proposes new states, and the phased gate
    blends them with the previous states according to the timestamps."""

    def __init__(
        self,
        input_channels,
        hidden_channels,
        kernel_size=3
    ):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.lstm = ConvLSTM(
            input_size=input_channels,
            hidden_size=hidden_channels,
            kernel_size=kernel_size
        )

        # The phased gate is created lazily in forward() because its size
        # depends on the spatial dimensions of the input.
        # NOTE(review): parameters created after an optimizer is built are not
        # registered with it — confirm construction/first-forward order in the
        # training code.
        self.phased_cell = None
        self.hidden_size = None

    def forward(self, input, times, prev_state=None):
        # input: (B, C, H, W) feature map; times: timestamps consumed by the
        # phased gate; prev_state: optional (c, h) pair, zeros when None.
        B, C, H, W = input.shape

        if self.hidden_size is None:
            # Flattened state size for the phased gate: channels * spatial.
            self.hidden_size = self.hidden_channels * W * H
            self.phased_cell = PhasedLSTMCell(hidden_size=self.hidden_size)
            self.phased_cell = self.phased_cell.to(input.device)
            self.phased_cell.requires_grad = True

        if prev_state is None:
            h0 = input.new_zeros((B, self.hidden_channels, H, W))
            c0 = input.new_zeros((B, self.hidden_channels, H, W))
        else:
            c0, h0 = prev_state

        # Hand the previous (flattened) states to the phased gate for blending.
        self.phased_cell.set_state(c0.view(B, -1), h0.view(B, -1))

        # NOTE(review): ConvLSTM.forward returns (hidden, cell) but is unpacked
        # here as (c_t, h_t), and prev_state is passed as (c0, h0) while
        # ConvLSTM unpacks it as (hidden, cell) — the cell/hidden roles appear
        # swapped relative to ConvLSTM's convention. Confirm this is intentional
        # before relying on the names.
        c_t, h_t = self.lstm(input, (c0, h0))

        # Blend the proposed states with the previous ones via the time gate.
        (c_s, h_s) = self.phased_cell(c_t.view(B, -1), h_t.view(B, -1), times)

        # Restore the spatial layout of the blended states.
        c_s = c_s.view(B, -1, H, W)
        h_s = h_s.view(B, -1, H, W)

        return h_t, (c_s, h_s)
|
|
|
|
|
class ConvGRU(nn.Module):
    """
    Generate a convolutional GRU cell
    Adapted from: https://github.com/jacobkimmel/pytorch_convgru/blob/master/convgru.py
    """

    def __init__(self, input_size, hidden_size, kernel_size):
        super().__init__()
        pad = kernel_size // 2
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.reset_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=pad)
        self.update_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=pad)
        self.out_gate = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size, padding=pad)

        # Orthogonal weights, zero biases. The orthogonal_ calls run in the
        # same (reset, update, out) order as the original to keep RNG
        # consumption identical.
        init.orthogonal_(self.reset_gate.weight)
        init.orthogonal_(self.update_gate.weight)
        init.orthogonal_(self.out_gate.weight)
        for gate in (self.reset_gate, self.update_gate, self.out_gate):
            init.constant_(gate.bias, 0.)

    def forward(self, input_, prev_state):
        # Lazily create an all-zero hidden state matching the input's batch
        # and spatial dimensions.
        if prev_state is None:
            batch = input_.data.size()[0]
            spatial = input_.data.size()[2:]
            prev_state = torch.zeros([batch, self.hidden_size] + list(spatial),
                                     dtype=input_.dtype).to(input_.device)

        stacked = torch.cat([input_, prev_state], dim=1)
        update = torch.sigmoid(self.update_gate(stacked))
        reset = torch.sigmoid(self.reset_gate(stacked))
        candidate = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
        new_state = prev_state * (1 - update) + candidate * update

        return new_state
|
|
|
|
|
class RecurrentResidualLayer(nn.Module):
    """ResidualBlock followed by a recurrent block (ConvLSTM or ConvGRU)."""

    def __init__(self, in_channels, out_channels,
                 recurrent_block_type='convlstm', norm=None, BN_momentum=0.1):
        super(RecurrentResidualLayer, self).__init__()

        assert recurrent_block_type in ['convlstm', 'convgru']
        self.recurrent_block_type = recurrent_block_type
        recurrent_cls = ConvLSTM if recurrent_block_type == 'convlstm' else ConvGRU

        self.conv = ResidualBlock(in_channels=in_channels,
                                  out_channels=out_channels,
                                  norm=norm,
                                  BN_momentum=BN_momentum)
        self.recurrent_block = recurrent_cls(input_size=out_channels,
                                             hidden_size=out_channels,
                                             kernel_size=3)

    def forward(self, x, prev_state):
        features = self.conv(x)
        state = self.recurrent_block(features, prev_state)
        # ConvLSTM returns (hidden, cell); ConvGRU returns the hidden state directly.
        out = state[0] if self.recurrent_block_type == 'convlstm' else state
        return out, state
|
|