|
""" Misc classes """ |
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
|
|
class Elementwise(nn.ModuleList):
    """
    A simple network container.

    Parameters are a list of modules.

    Inputs are a 3d Tensor whose last dimension (dim 2) has the same length
    as the list: module ``i`` is applied to the slice ``inputs[:, :, i]``.

    Outputs are the result of applying modules to inputs elementwise.

    An optional merge parameter allows the outputs to be reduced to a
    single Tensor:

    * ``None``               -- return the list of per-module outputs.
    * ``'first'``            -- return only the first module's output.
    * ``'concat'`` / ``'mlp'`` -- concatenate the outputs along dim 2.
    * ``'sum'``              -- add the outputs together.
    """

    #: Accepted values for the ``merge`` argument.
    _MERGE_MODES = (None, 'first', 'concat', 'sum', 'mlp')

    def __init__(self, merge=None, *args):
        """
        Args:
            merge: one of ``None``, ``'first'``, ``'concat'``, ``'sum'``,
                ``'mlp'``; selects how ``forward`` combines the outputs.
            *args: forwarded to ``nn.ModuleList`` (the list of modules).

        Raises:
            ValueError: if ``merge`` is not an accepted mode.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an unknown mode silently fall
        # through to the "return the raw list" branch of forward().
        if merge not in Elementwise._MERGE_MODES:
            raise ValueError(
                f"merge must be one of {Elementwise._MERGE_MODES!r}, "
                f"got {merge!r}"
            )
        # Set up nn.Module bookkeeping before assigning attributes on self.
        super(Elementwise, self).__init__(*args)
        self.merge = merge

    def forward(self, inputs):
        """
        Apply each module to its own slice of ``inputs`` along dim 2.

        Args:
            inputs: tensor whose dim 2 has exactly one entry per module.

        Returns:
            The list of per-module outputs when ``merge`` is ``None``;
            otherwise a single tensor combined according to ``merge``.

        Raises:
            ValueError: if the size of dim 2 differs from the number of
                modules (``zip`` would otherwise silently truncate).
        """
        # unbind(2) yields one slice per index along dim 2, with that dim removed.
        features = inputs.unbind(2)
        if len(features) != len(self):
            raise ValueError(
                f"expected {len(self)} features along dim 2, "
                f"got {len(features)}"
            )
        outputs = [module(feat) for module, feat in zip(self, features)]

        if self.merge == 'first':
            return outputs[0]
        if self.merge in ('concat', 'mlp'):
            return torch.cat(outputs, 2)
        if self.merge == 'sum':
            return sum(outputs)
        return outputs
|
|
|
|
|
class Cast(nn.Module):
    """
    Layer that converts its input tensor to a fixed data type.

    The conversion is delegated to ``Tensor.to``, so an input that already
    has the target type is returned as-is rather than copied.

    Args:
        dtype: target handed to ``Tensor.to`` on every forward call.
    """

    def __init__(self, dtype):
        super().__init__()
        self._dtype = dtype

    def forward(self, x):
        """Return ``x`` converted to the stored target type."""
        target = self._dtype
        converted = x.to(target)
        return converted
|
|