""" Misc classes """ import torch import torch.nn as nn # At the moment this class is only used by embeddings.Embeddings look-up tables class Elementwise(nn.ModuleList): """ A simple network container. Parameters are a list of modules. Inputs are a 3d Tensor whose last dimension is the same length as the list. Outputs are the result of applying modules to inputs elementwise. An optional merge parameter allows the outputs to be reduced to a single Tensor. """ def __init__(self, merge=None, *args): assert merge in [None, 'first', 'concat', 'sum', 'mlp'] self.merge = merge super(Elementwise, self).__init__(*args) def forward(self, inputs): inputs_ = [feat.squeeze(2) for feat in inputs.split(1, dim=2)] assert len(self) == len(inputs_) outputs = [f(x) for f, x in zip(self, inputs_)] if self.merge == 'first': return outputs[0] elif self.merge == 'concat' or self.merge == 'mlp': return torch.cat(outputs, 2) elif self.merge == 'sum': return sum(outputs) else: return outputs class Cast(nn.Module): """ Basic layer that casts its input to a specific data type. The same tensor is returned if the data type is already correct. """ def __init__(self, dtype): super(Cast, self).__init__() self._dtype = dtype def forward(self, x): return x.to(self._dtype)