#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import MultiheadAttention, GRU, Linear, LayerNorm, Dropout


class FFN(nn.Module):
    """Feed-forward block that replaces the usual position-wise MLP with a GRU."""

    def __init__(self, d_model, bidirectional=True, dropout=0):
        super(FFN, self).__init__()
        # Single-layer GRU that expands the features to 2 * d_model per direction.
        self.gru = GRU(d_model, d_model * 2, 1, bidirectional=bidirectional)
        if bidirectional:
            self.linear = Linear(d_model * 2 * 2, d_model)
        else:
            self.linear = Linear(d_model * 2, d_model)
        self.dropout = Dropout(dropout)

    def forward(self, x):
        # x: (seq_len, batch, d_model); the GRU uses its default batch_first=False layout.
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = F.leaky_relu(x)
        x = self.dropout(x)
        x = self.linear(x)  # project back to d_model
        return x


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: multi-head self-attention followed by the GRU-based FFN."""

    def __init__(self, d_model, n_heads, bidirectional=True, dropout=0):
        super(TransformerBlock, self).__init__()
        self.norm1 = LayerNorm(d_model)
        self.attention = MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout1 = Dropout(dropout)
        self.norm2 = LayerNorm(d_model)
        self.ffn = FFN(d_model, bidirectional=bidirectional)
        self.dropout2 = Dropout(dropout)
        self.norm3 = LayerNorm(d_model)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        # x: (seq_len, batch, d_model). The masks follow nn.MultiheadAttention's
        # conventions: attn_mask is (seq_len, seq_len), key_padding_mask is (batch, seq_len).
        xt = self.norm1(x)
        xt, _ = self.attention(xt, xt, xt, attn_mask=attn_mask,
                               key_padding_mask=key_padding_mask)
        x = x + self.dropout1(xt)   # residual connection around attention
        xt = self.norm2(x)
        xt = self.ffn(xt)
        x = x + self.dropout2(xt)   # residual connection around the FFN
        x = self.norm3(x)
        return x


def main():
    # Shape smoke test on a spectrogram-like tensor: (batch, channels, time, freq).
    x = torch.randn(4, 64, 401, 201)
    b, c, t, f = x.size()
    # Flatten the time-frequency grid into one axis: (batch, freq * time, channels).
    # Note that MultiheadAttention and GRU keep their default batch_first=False
    # layout, i.e. they read the first axis of their input as the sequence.
    x = x.permute(0, 3, 2, 1).contiguous().view(b, f * t, c)
    transformer = TransformerBlock(d_model=64, n_heads=4)
    x = transformer(x)
    # Restore the original (batch, channels, time, freq) layout.
    x = x.view(b, f, t, c).permute(0, 3, 2, 1)
    print(x.size())  # torch.Size([4, 64, 401, 201])


if __name__ == '__main__':
    main()
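

# Illustrative sketch: one way the mask arguments of TransformerBlock.forward
# could be exercised. It assumes a padded batch in the (seq_len, batch, d_model)
# layout expected by nn.MultiheadAttention with batch_first=False, where
# key_padding_mask is a boolean (batch, seq_len) tensor and True marks padded
# positions to ignore. The function name and the sizes below are hypothetical.
def masked_usage_example():
    seq_len, batch, d_model = 100, 2, 64
    block = TransformerBlock(d_model=d_model, n_heads=4)
    x = torch.randn(seq_len, batch, d_model)
    # Suppose the second sample only has 60 valid frames; mask out the rest.
    key_padding_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    key_padding_mask[1, 60:] = True
    y = block(x, key_padding_mask=key_padding_mask)
    print(y.size())  # torch.Size([100, 2, 64])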