# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Transpose the last two dimensions of the input.
"""
import torch.nn as nn | |
class TransposeLast(nn.Module):
    """Module that swaps the last two dimensions of its input.

    When ``deconstruct_idx`` is set, the input is first indexed with it
    (``x[deconstruct_idx]``) and the selected element is transposed instead.
    NOTE(review): presumably this is for upstream modules that return a
    tuple/list of tensors -- confirm against callers.
    """

    def __init__(self, deconstruct_idx=None):
        """Store the optional index applied to the input before transposing.

        Args:
            deconstruct_idx: index used as ``x[deconstruct_idx]`` in
                ``forward``; ``None`` (the default) disables the indexing
                step.
        """
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        """Return ``x`` with its last two dimensions swapped.

        Args:
            x: object exposing ``transpose(dim0, dim1)``, or, when
                ``deconstruct_idx`` is set, an indexable container whose
                selected element exposes it.

        Returns:
            The result of ``transpose(-2, -1)`` on the (optionally selected)
            input.
        """
        # Explicit ``is not None`` check so that index 0 is still honoured.
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)