File size: 1,486 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
""" Misc classes """
import torch
import torch.nn as nn
# At the moment this class is only used by embeddings.Embeddings look-up tables
class Elementwise(nn.ModuleList):
    """
    A simple network container that applies one module per input feature.

    Parameters are a list of modules.
    Inputs are a 3d Tensor whose last dimension is the same length
    as the list.
    Outputs are the result of applying modules to inputs elementwise.
    An optional merge parameter allows the outputs to be reduced to a
    single Tensor.

    Args:
        merge: how the per-module outputs are combined in ``forward``.
            ``None`` returns the list of outputs untouched, ``'first'``
            returns only the first output, ``'concat'`` and ``'mlp'``
            both concatenate the outputs along dim 2, and ``'sum'`` adds
            the outputs together.
        *args: forwarded unchanged to ``nn.ModuleList.__init__``
            (typically the list of modules).

    Raises:
        AssertionError: if ``merge`` is not one of the supported values.
    """

    def __init__(self, merge=None, *args):
        valid_merges = (None, 'first', 'concat', 'sum', 'mlp')
        # Explicit check instead of a bare ``assert`` so the validation is
        # not stripped when Python runs with -O. AssertionError is kept so
        # existing callers see the same exception type.
        if merge not in valid_merges:
            raise AssertionError(
                "merge must be one of {!r}, got {!r}".format(
                    valid_merges, merge))
        super(Elementwise, self).__init__(*args)
        self.merge = merge

    def forward(self, inputs):
        """
        Apply the i-th module to the i-th slice of ``inputs`` along dim 2.

        Args:
            inputs: tensor that is split into size-1 chunks along dim 2;
                each chunk has that dim squeezed away before being passed
                to its module.

        Returns:
            A single tensor for ``merge`` in ``'first'``, ``'concat'``,
            ``'mlp'`` or ``'sum'``; otherwise the list of per-module
            outputs.

        Raises:
            AssertionError: if the number of slices along dim 2 differs
                from the number of modules in the container.
        """
        features = [feat.squeeze(2) for feat in inputs.split(1, dim=2)]
        # Without this check ``zip`` below would silently drop the extra
        # features or modules, so it must also survive ``python -O``.
        if len(self) != len(features):
            raise AssertionError(
                "expected {} features along dim 2, got {}".format(
                    len(self), len(features)))
        outputs = [module(feat) for module, feat in zip(self, features)]

        if self.merge == 'first':
            return outputs[0]
        if self.merge in ('concat', 'mlp'):
            # 'mlp' is merged exactly like 'concat' here; any further
            # projection is not performed by this class.
            return torch.cat(outputs, 2)
        if self.merge == 'sum':
            return sum(outputs)
        return outputs
class Cast(nn.Module):
    """
    Thin layer that converts its input to a fixed data type.

    If the input already has the requested data type, the input tensor
    itself is handed back instead of a converted copy.
    """

    def __init__(self, dtype):
        super(Cast, self).__init__()
        # Target type, stored once and reused on every forward pass.
        self._dtype = dtype

    def forward(self, x):
        """Return ``x`` converted to the data type given at construction."""
        target_dtype = self._dtype
        converted = x.to(target_dtype)
        return converted
|