Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- fairseq/fairseq/models/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/fconv.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/lightconv.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/lstm.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/masked_lm.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/model_utils.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/transformer_align.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc +0 -0
- fairseq/fairseq/models/__pycache__/transformer_ulm.cpython-310.pyc +0 -0
- fairseq/fairseq/models/text_to_speech/__pycache__/codehifigan.cpython-310.pyc +0 -0
- fairseq/fairseq/models/text_to_speech/__pycache__/fastspeech2.cpython-310.pyc +0 -0
- fairseq/fairseq/models/text_to_speech/__pycache__/hifigan.cpython-310.pyc +0 -0
- fairseq/fairseq/models/text_to_speech/__pycache__/hub_interface.cpython-310.pyc +0 -0
- fairseq/fairseq/models/text_to_speech/__pycache__/tts_transformer.cpython-310.pyc +0 -0
- fairseq/fairseq/models/text_to_speech/__pycache__/vocoder.cpython-310.pyc +0 -0
- fairseq/fairseq/models/text_to_speech/tts_transformer.py +454 -0
- fairseq/fairseq/models/transformer/__init__.py +50 -0
- fairseq/fairseq/models/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/models/transformer/__pycache__/transformer_base.cpython-310.pyc +0 -0
- fairseq/fairseq/models/transformer/__pycache__/transformer_config.cpython-310.pyc +0 -0
- fairseq/fairseq/models/transformer/__pycache__/transformer_decoder.cpython-310.pyc +0 -0
- fairseq/fairseq/models/transformer/__pycache__/transformer_decoder_aug.cpython-310.pyc +0 -0
- fairseq/fairseq/models/transformer/__pycache__/transformer_encoder.cpython-310.pyc +0 -0
- fairseq/fairseq/models/transformer/__pycache__/transformer_legacy.cpython-310.pyc +0 -0
- fairseq/fairseq/models/transformer/transformer_base.py +193 -0
- fairseq/fairseq/models/transformer/transformer_config.py +341 -0
- fairseq/fairseq/models/transformer/transformer_decoder.py +474 -0
- fairseq/fairseq/models/transformer/transformer_decoder_aug.py +384 -0
- fairseq/fairseq/models/transformer/transformer_encoder.py +362 -0
- fairseq/fairseq/models/transformer/transformer_legacy.py +277 -0
- fairseq/fairseq/models/wav2vec/__init__.py +10 -0
- fairseq/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc +0 -0
- fairseq/fairseq/models/wav2vec/__pycache__/utils.cpython-310.pyc +0 -0
- fairseq/fairseq/models/wav2vec/__pycache__/wav2vec.cpython-310.pyc +0 -0
- fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2.cpython-310.pyc +0 -0
- fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_asr.cpython-310.pyc +0 -0
- fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_classification.cpython-310.pyc +0 -0
- fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_laser.cpython-310.pyc +0 -0
fairseq/fairseq/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (6.01 kB)
fairseq/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc
ADDED
Binary file (2.41 kB)
fairseq/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc
ADDED
Binary file (3.6 kB)
fairseq/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc
ADDED
Binary file (3.74 kB)
fairseq/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc
ADDED
Binary file (3.62 kB)
fairseq/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc
ADDED
Binary file (4.85 kB)
fairseq/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc
ADDED
Binary file (20.7 kB)
fairseq/fairseq/models/__pycache__/fconv.cpython-310.pyc
ADDED
Binary file (19.1 kB)
fairseq/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc
ADDED
Binary file (3.86 kB)
fairseq/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc
ADDED
Binary file (16.3 kB)
fairseq/fairseq/models/__pycache__/lightconv.cpython-310.pyc
ADDED
Binary file (27.5 kB)
fairseq/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc
ADDED
Binary file (7.03 kB)
fairseq/fairseq/models/__pycache__/lstm.cpython-310.pyc
ADDED
Binary file (18.7 kB)
fairseq/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc
ADDED
Binary file (4.38 kB)
fairseq/fairseq/models/__pycache__/masked_lm.cpython-310.pyc
ADDED
Binary file (10.1 kB)
fairseq/fairseq/models/__pycache__/model_utils.cpython-310.pyc
ADDED
Binary file (2.39 kB)
fairseq/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc
ADDED
Binary file (6.78 kB)
fairseq/fairseq/models/__pycache__/transformer_align.cpython-310.pyc
ADDED
Binary file (3.05 kB)
fairseq/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc
ADDED
Binary file (5.39 kB)
fairseq/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc
ADDED
Binary file (15.5 kB)
fairseq/fairseq/models/__pycache__/transformer_ulm.cpython-310.pyc
ADDED
Binary file (9.51 kB)
fairseq/fairseq/models/text_to_speech/__pycache__/codehifigan.cpython-310.pyc
ADDED
Binary file (2.92 kB)
fairseq/fairseq/models/text_to_speech/__pycache__/fastspeech2.cpython-310.pyc
ADDED
Binary file (12.8 kB)
fairseq/fairseq/models/text_to_speech/__pycache__/hifigan.cpython-310.pyc
ADDED
Binary file (3.84 kB)
fairseq/fairseq/models/text_to_speech/__pycache__/hub_interface.cpython-310.pyc
ADDED
Binary file (6.18 kB)
fairseq/fairseq/models/text_to_speech/__pycache__/tts_transformer.cpython-310.pyc
ADDED
Binary file (12.4 kB)
fairseq/fairseq/models/text_to_speech/__pycache__/vocoder.cpython-310.pyc
ADDED
Binary file (9.9 kB)
fairseq/fairseq/models/text_to_speech/tts_transformer.py
ADDED
@@ -0,0 +1,454 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import List, Optional

import torch
from torch import nn

from fairseq import utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
    register_model_architecture,
)
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
from fairseq.models.text_to_speech.tacotron2 import Postnet, Prenet
from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
    PositionalEmbedding,
    TransformerDecoderLayer,
    TransformerEncoderLayer,
)

logger = logging.getLogger(__name__)


def encoder_init(m):
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu"))


def Embedding(num_embeddings, embedding_dim):
    m = nn.Embedding(num_embeddings, embedding_dim)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    return m


class TTSTransformerEncoder(FairseqEncoder):
    def __init__(self, args, src_dict, embed_speaker):
        super().__init__(src_dict)
        self.padding_idx = src_dict.pad()
        self.embed_speaker = embed_speaker
        self.spk_emb_proj = None
        if embed_speaker is not None:
            self.spk_emb_proj = nn.Linear(
                args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
            )

        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        self.embed_tokens = nn.Embedding(
            len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
        )
        assert args.encoder_conv_kernel_size % 2 == 1
        self.prenet = nn.ModuleList(
            nn.Sequential(
                nn.Conv1d(
                    args.encoder_embed_dim,
                    args.encoder_embed_dim,
                    kernel_size=args.encoder_conv_kernel_size,
                    padding=((args.encoder_conv_kernel_size - 1) // 2),
                ),
                nn.BatchNorm1d(args.encoder_embed_dim),
                nn.ReLU(),
                nn.Dropout(args.encoder_dropout),
            )
            for _ in range(args.encoder_conv_layers)
        )
        self.prenet_proj = nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
        )
        self.pos_emb_alpha = nn.Parameter(torch.ones(1))

        self.transformer_layers = nn.ModuleList(
            TransformerEncoderLayer(args)
            for _ in range(args.encoder_transformer_layers)
        )
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(args.encoder_embed_dim)
        else:
            self.layer_norm = None

        self.apply(encoder_init)

    def forward(self, src_tokens, src_lengths=None, speaker=None, **kwargs):
        x = self.embed_tokens(src_tokens)
        x = x.transpose(1, 2).contiguous()  # B x T x C -> B x C x T
        for conv in self.prenet:
            x = conv(x)
        x = x.transpose(1, 2).contiguous()  # B x C x T -> B x T x C
        x = self.prenet_proj(x)

        padding_mask = src_tokens.eq(self.padding_idx)
        positions = self.embed_positions(padding_mask)
        x += self.pos_emb_alpha * positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        for layer in self.transformer_layers:
            x = layer(x, padding_mask)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        if self.embed_speaker is not None:
            seq_len, bsz, _ = x.size()
            emb = self.embed_speaker(speaker).transpose(0, 1)
            emb = emb.expand(seq_len, bsz, -1)
            x = self.spk_emb_proj(torch.cat([x, emb], dim=2))

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask]
            if padding_mask.any()
            else [],  # B x T
            "encoder_embedding": [],  # B x T x C
            "encoder_states": [],  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [],
        }


def decoder_init(m):
    if isinstance(m, torch.nn.Conv1d):
        nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh"))


class TTSTransformerDecoder(FairseqIncrementalDecoder):
    def __init__(self, args, src_dict, padding_idx=1):
        super().__init__(None)
        self._future_mask = torch.empty(0)

        self.args = args
        self.padding_idx = src_dict.pad() if src_dict else padding_idx
        self.n_frames_per_step = args.n_frames_per_step
        self.out_dim = args.output_frame_dim * args.n_frames_per_step

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions, args.decoder_embed_dim, self.padding_idx
        )
        self.pos_emb_alpha = nn.Parameter(torch.ones(1))
        self.prenet = nn.Sequential(
            Prenet(
                self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout
            ),
            nn.Linear(args.prenet_dim, args.decoder_embed_dim),
        )

        self.n_transformer_layers = args.decoder_transformer_layers
        self.transformer_layers = nn.ModuleList(
            TransformerDecoderLayer(args) for _ in range(self.n_transformer_layers)
        )
        if args.decoder_normalize_before:
            self.layer_norm = LayerNorm(args.decoder_embed_dim)
        else:
            self.layer_norm = None

        self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)
        self.eos_proj = nn.Linear(args.decoder_embed_dim, 1)

        self.postnet = Postnet(
            self.out_dim,
            args.postnet_conv_dim,
            args.postnet_conv_kernel_size,
            args.postnet_layers,
            args.postnet_dropout,
        )

        self.ctc_proj = None
        if getattr(args, "ctc_weight", 0.0) > 0.0:
            self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))

        self.apply(decoder_init)

    def extract_features(
        self,
        prev_outputs,
        encoder_out=None,
        incremental_state=None,
        target_lengths=None,
        speaker=None,
        **kwargs,
    ):
        alignment_layer = self.n_transformer_layers - 1
        self_attn_padding_mask = lengths_to_padding_mask(target_lengths)
        positions = self.embed_positions(
            self_attn_padding_mask, incremental_state=incremental_state
        )

        if incremental_state is not None:
            prev_outputs = prev_outputs[:, -1:, :]
            self_attn_padding_mask = self_attn_padding_mask[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        x = self.prenet(prev_outputs)
        x += self.pos_emb_alpha * positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if not self_attn_padding_mask.any():
            self_attn_padding_mask = None

        attn: Optional[torch.Tensor] = None
        inner_states: List[Optional[torch.Tensor]] = [x]
        for idx, transformer_layer in enumerate(self.transformer_layers):
            if incremental_state is None:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = transformer_layer(
                x,
                encoder_out["encoder_out"][0]
                if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
                else None,
                encoder_out["encoder_padding_mask"][0]
                if (
                    encoder_out is not None
                    and len(encoder_out["encoder_padding_mask"]) > 0
                )
                else None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

        if attn is not None:
            # average probabilities over heads, transpose to
            # (B, src_len, tgt_len)
            attn = attn.mean(dim=0).transpose(2, 1)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"attn": attn, "inner_states": inner_states}

    def forward(
        self,
        prev_output_tokens,
        encoder_out=None,
        incremental_state=None,
        target_lengths=None,
        speaker=None,
        **kwargs,
    ):
        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            target_lengths=target_lengths,
            speaker=speaker,
            **kwargs,
        )
        attn = extra["attn"]
        feat_out = self.feat_proj(x)
        bsz, seq_len, _ = x.size()
        eos_out = self.eos_proj(x)
        post_feat_out = feat_out + self.postnet(feat_out)
        return (
            post_feat_out,
            eos_out,
            {
                "attn": attn,
                "feature_out": feat_out,
                "inner_states": extra["inner_states"],
            },
        )

    def get_normalized_probs(self, net_output, log_probs, sample):
        logits = self.ctc_proj(net_output[2]["feature_out"])
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
        if (
            self._future_mask.size(0) == 0
            or (not self._future_mask.device == tensor.device)
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
            )
        self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]


@register_model("tts_transformer")
class TTSTransformerModel(FairseqEncoderDecoderModel):
    """
    Implementation for https://arxiv.org/pdf/1809.08895.pdf
    """

    @classmethod
    def hub_models(cls):
        base_url = "http://dl.fbaipublicfiles.com/fairseq/s2"
        model_ids = [
            "tts_transformer-en-ljspeech",
            "tts_transformer-en-200_speaker-cv4",
            "tts_transformer-es-css10",
            "tts_transformer-fr-cv7_css10",
            "tts_transformer-ru-cv7_css10",
            "tts_transformer-zh-cv7_css10",
            "tts_transformer-ar-cv7_css10",
            "tts_transformer-tr-cv7_css10",
            "tts_transformer-vi-cv7",
        ]
        return {i: f"{base_url}/{i}.tar.gz" for i in model_ids}

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        config_yaml="config.yaml",
        vocoder: str = "griffin_lim",
        fp16: bool = False,
        **kwargs,
    ):
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            config_yaml=config_yaml,
            vocoder=vocoder,
            fp16=fp16,
            **kwargs,
        )
        return TTSHubInterface(x["args"], x["task"], x["models"][0])

    @staticmethod
    def add_args(parser):
        parser.add_argument("--dropout", type=float)
        parser.add_argument("--output-frame-dim", type=int)
        parser.add_argument("--speaker-embed-dim", type=int)
        # encoder prenet
        parser.add_argument("--encoder-dropout", type=float)
        parser.add_argument("--encoder-conv-layers", type=int)
        parser.add_argument("--encoder-conv-kernel-size", type=int)
        # encoder transformer layers
        parser.add_argument("--encoder-transformer-layers", type=int)
        parser.add_argument("--encoder-embed-dim", type=int)
        parser.add_argument("--encoder-ffn-embed-dim", type=int)
        parser.add_argument("--encoder-normalize-before", action="store_true")
        parser.add_argument("--encoder-attention-heads", type=int)
        parser.add_argument("--attention-dropout", type=float)
        parser.add_argument("--activation-dropout", "--relu-dropout", type=float)
        parser.add_argument("--activation-fn", type=str, default="relu")
        # decoder prenet
        parser.add_argument("--prenet-dropout", type=float)
        parser.add_argument("--prenet-layers", type=int)
        parser.add_argument("--prenet-dim", type=int)
        # decoder postnet
        parser.add_argument("--postnet-dropout", type=float)
        parser.add_argument("--postnet-layers", type=int)
        parser.add_argument("--postnet-conv-dim", type=int)
        parser.add_argument("--postnet-conv-kernel-size", type=int)
        # decoder transformer layers
        parser.add_argument("--decoder-transformer-layers", type=int)
        parser.add_argument("--decoder-embed-dim", type=int)
        parser.add_argument("--decoder-ffn-embed-dim", type=int)
        parser.add_argument("--decoder-normalize-before", action="store_true")
        parser.add_argument("--decoder-attention-heads", type=int)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._num_updates = 0

    @classmethod
    def build_model(cls, args, task):
        embed_speaker = task.get_speaker_embeddings(args)
        encoder = TTSTransformerEncoder(args, task.src_dict, embed_speaker)
        decoder = TTSTransformerDecoder(args, task.src_dict)
        return cls(encoder, decoder)

    def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs):
        return self.encoder(
            src_tokens, src_lengths=src_lengths, speaker=speaker, **kwargs
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)
        self._num_updates = num_updates


@register_model_architecture("tts_transformer", "tts_transformer")
def base_architecture(args):
    args.dropout = getattr(args, "dropout", 0.1)
    args.output_frame_dim = getattr(args, "output_frame_dim", 80)
    args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64)
    # encoder prenet
    args.encoder_dropout = getattr(args, "encoder_dropout", 0.5)
    args.encoder_conv_layers = getattr(args, "encoder_conv_layers", 3)
    args.encoder_conv_kernel_size = getattr(args, "encoder_conv_kernel_size", 5)
    # encoder transformer layers
    args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(
        args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim
    )
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    # decoder prenet
    args.prenet_dropout = getattr(args, "prenet_dropout", 0.5)
    args.prenet_layers = getattr(args, "prenet_layers", 2)
    args.prenet_dim = getattr(args, "prenet_dim", 256)
    # decoder postnet
    args.postnet_dropout = getattr(args, "postnet_dropout", 0.5)
    args.postnet_layers = getattr(args, "postnet_layers", 5)
    args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512)
    args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5)
    # decoder transformer layers
    args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim
    )
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
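A minimal usage sketch for the file above, assuming a fairseq installation that includes this module and network access to the checkpoints listed in hub_models(); it only exercises TTSTransformerModel.from_pretrained as defined above and the TTSHubInterface object it returns:

from fairseq.models.text_to_speech.tts_transformer import TTSTransformerModel

# Resolves the id through hub_models(), downloads the tarball, and wraps
# args, task and the loaded model in a TTSHubInterface (see from_pretrained above).
tts_hub = TTSTransformerModel.from_pretrained(
    "tts_transformer-en-ljspeech",   # one of the ids returned by hub_models()
    vocoder="griffin_lim",           # default vocoder argument in the signature above
    fp16=False,
)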
fairseq/fairseq/models/transformer/__init__.py
ADDED
@@ -0,0 +1,50 @@
# Copyright (c) Facebook Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .transformer_config import (
    TransformerConfig,
    DEFAULT_MAX_SOURCE_POSITIONS,
    DEFAULT_MAX_TARGET_POSITIONS,
    DEFAULT_MIN_PARAMS_TO_WRAP,
)
from .transformer_decoder import TransformerDecoder, TransformerDecoderBase, Linear
from .transformer_encoder import TransformerEncoder, TransformerEncoderBase
from .transformer_legacy import (
    TransformerModel,
    base_architecture,
    tiny_architecture,
    transformer_iwslt_de_en,
    transformer_wmt_en_de,
    transformer_vaswani_wmt_en_de_big,
    transformer_vaswani_wmt_en_fr_big,
    transformer_wmt_en_de_big,
    transformer_wmt_en_de_big_t2t,
)
from .transformer_base import TransformerModelBase, Embedding


__all__ = [
    "TransformerModelBase",
    "TransformerConfig",
    "TransformerDecoder",
    "TransformerDecoderBase",
    "TransformerEncoder",
    "TransformerEncoderBase",
    "TransformerModel",
    "Embedding",
    "Linear",
    "base_architecture",
    "tiny_architecture",
    "transformer_iwslt_de_en",
    "transformer_wmt_en_de",
    "transformer_vaswani_wmt_en_de_big",
    "transformer_vaswani_wmt_en_fr_big",
    "transformer_wmt_en_de_big",
    "transformer_wmt_en_de_big_t2t",
    "DEFAULT_MAX_SOURCE_POSITIONS",
    "DEFAULT_MAX_TARGET_POSITIONS",
    "DEFAULT_MIN_PARAMS_TO_WRAP",
]
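Because of the re-exports above, downstream code can import the whole Transformer API from the package root; a tiny illustrative sketch, assuming fairseq is importable:

from fairseq.models.transformer import (
    TransformerConfig,
    DEFAULT_MAX_SOURCE_POSITIONS,
)

cfg = TransformerConfig()  # hierarchical dataclass with cfg.encoder / cfg.decoder sub-configs
assert cfg.max_source_positions == DEFAULT_MAX_SOURCE_POSITIONS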
fairseq/fairseq/models/transformer/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.18 kB)
fairseq/fairseq/models/transformer/__pycache__/transformer_base.cpython-310.pyc
ADDED
Binary file (5.42 kB)
fairseq/fairseq/models/transformer/__pycache__/transformer_config.cpython-310.pyc
ADDED
Binary file (8.86 kB)
fairseq/fairseq/models/transformer/__pycache__/transformer_decoder.cpython-310.pyc
ADDED
Binary file (11.9 kB)
fairseq/fairseq/models/transformer/__pycache__/transformer_decoder_aug.cpython-310.pyc
ADDED
Binary file (9.72 kB)
fairseq/fairseq/models/transformer/__pycache__/transformer_encoder.cpython-310.pyc
ADDED
Binary file (9 kB)
fairseq/fairseq/models/transformer/__pycache__/transformer_legacy.cpython-310.pyc
ADDED
Binary file (9.92 kB)
fairseq/fairseq/models/transformer/transformer_base.py
ADDED
@@ -0,0 +1,193 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

import logging

from fairseq import utils
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoderDecoderModel
from fairseq.models.transformer import (
    TransformerConfig,
    TransformerDecoderBase,
    TransformerEncoderBase,
)


logger = logging.getLogger(__name__)


class TransformerModelBase(FairseqEncoderDecoderModel):
    """
    Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
    <https://arxiv.org/abs/1706.03762>`_.

    Args:
        encoder (TransformerEncoder): the encoder
        decoder (TransformerDecoder): the decoder

    The Transformer model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.transformer_parser
        :prog:
    """

    def __init__(self, cfg, encoder, decoder):
        super().__init__(encoder, decoder)
        self.cfg = cfg
        self.supports_align_args = True

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        # we want to build the args recursively in this case.
        gen_parser_from_dataclass(
            parser, TransformerConfig(), delete_default=False, with_prefix=""
        )

    @classmethod
    def build_model(cls, cfg, task):
        """Build a new model instance."""

        # -- TODO T96535332
        # bug caused by interaction between OmegaConf II and argparsing
        cfg.decoder.input_dim = int(cfg.decoder.input_dim)
        cfg.decoder.output_dim = int(cfg.decoder.output_dim)
        # --

        if cfg.encoder.layers_to_keep:
            cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(","))
        if cfg.decoder.layers_to_keep:
            cfg.decoder.layers = len(cfg.decoder.layers_to_keep.split(","))

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if cfg.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if cfg.encoder.embed_dim != cfg.decoder.embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if cfg.decoder.embed_path and (
                cfg.decoder.embed_path != cfg.encoder.embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            cfg.share_decoder_input_output_embed = True
        elif cfg.merge_src_tgt_embed:
            logger.info(f"source dict size: {len(src_dict)}")
            logger.info(f"target dict size: {len(tgt_dict)}")
            src_dict.update(tgt_dict)
            task.src_dict = src_dict
            task.tgt_dict = src_dict
            logger.info(f"merged dict size: {len(src_dict)}")
            encoder_embed_tokens = cls.build_embedding(
                cfg, src_dict, cfg.encoder.embed_dim
            )
            decoder_embed_tokens = encoder_embed_tokens
            cfg.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(
                cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                cfg, tgt_dict, cfg.decoder.embed_dim, cfg.decoder.embed_path
            )
        if cfg.offload_activations:
            cfg.checkpoint_activations = True  # offloading implies checkpointing
        encoder = cls.build_encoder(cfg, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
        return cls(cfg, encoder, decoder)

    @classmethod
    def build_embedding(cls, cfg, dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        emb = Embedding(num_embeddings, embed_dim, padding_idx)
        # if provided, load from preloaded dictionaries
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, cfg, src_dict, embed_tokens):
        return TransformerEncoderBase(cfg, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, cfg, tgt_dict, embed_tokens):
        return TransformerDecoderBase(
            cfg,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=cfg.no_cross_attention,
        )

    # TorchScript doesn't support optional arguments with variable length (**kwargs).
    # Current workaround is to add union of all arguments in child classes.
    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = True,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Run the forward pass for an encoder-decoder model.

        Copied from the base class, but without ``**kwargs``,
        which are not supported by TorchScript.
        """
        encoder_out = self.encoder(
            src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
        )
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out

    # Since get_normalized_probs is in the Fairseq Model which is not scriptable,
    # I rewrite the get_normalized_probs from Base Class to call the
    # helper function in the Base Class.
    @torch.jit.export
    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m
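A short sketch of the Embedding helper defined at the end of the file above, assuming torch and fairseq are installed; it checks the two initialization properties the code applies (normal init with std embedding_dim ** -0.5, zeroed padding row):

import torch
from fairseq.models.transformer import Embedding

emb = Embedding(num_embeddings=100, embedding_dim=16, padding_idx=1)
assert torch.all(emb.weight[1] == 0)   # the padding row is zero-initialized
print(emb.weight.std().item())         # roughly 16 ** -0.5 = 0.25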
fairseq/fairseq/models/transformer/transformer_config.py
ADDED
@@ -0,0 +1,341 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import re
from dataclasses import dataclass, field, fields
from typing import List, Optional

from omegaconf import II

from fairseq import utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.utils import safe_getattr, safe_hasattr

DEFAULT_MAX_SOURCE_POSITIONS = 1024
DEFAULT_MAX_TARGET_POSITIONS = 1024

DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)

_NAME_PARSER = r"(decoder|encoder|quant_noise)_(.*)"


@dataclass
class EncDecBaseConfig(FairseqDataclass):
    embed_path: Optional[str] = field(
        default=None, metadata={"help": "path to pre-trained embedding"}
    )
    embed_dim: Optional[int] = field(
        default=512, metadata={"help": "embedding dimension"}
    )
    ffn_embed_dim: int = field(
        default=2048, metadata={"help": "embedding dimension for FFN"}
    )
    layers: int = field(default=6, metadata={"help": "number of layers"})
    attention_heads: int = field(
        default=8, metadata={"help": "number of attention heads"}
    )
    normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each block"}
    )
    learned_pos: bool = field(
        default=False, metadata={"help": "use learned positional embeddings"}
    )
    # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
    layerdrop: float = field(default=0, metadata={"help": "LayerDrop probability"})
    layers_to_keep: Optional[List[int]] = field(
        default=None, metadata={"help": "which layers to *keep* when pruning"}
    )

    xformers_att_config: Optional[str] = field(
        default=None,
        metadata={
            "help": "config for xFormers attention, defined in xformers.components.attention.AttentionConfig"
        },
    )


@dataclass
class DecoderConfig(EncDecBaseConfig):
    input_dim: int = II("model.decoder.embed_dim")
    output_dim: int = field(
        default=II("model.decoder.embed_dim"),
        metadata={
            "help": "decoder output dimension (extra linear layer if different from decoder embed dim)"
        },
    )

    def __post_init__(self):
        # II doesn't work if we are just creating the object outside of hydra so fix that
        if self.input_dim == II("model.decoder.embed_dim"):
            self.input_dim = self.embed_dim
        if self.output_dim == II("model.decoder.embed_dim"):
            self.output_dim = self.embed_dim


@dataclass
class QuantNoiseConfig(FairseqDataclass):
    pq: float = field(
        default=0.0,
        metadata={"help": "iterative PQ quantization noise at training time"},
    )
    pq_block_size: int = field(
        default=8,
        metadata={"help": "block size of quantization noise at training time"},
    )
    scalar: float = field(
        default=0.0,
        metadata={
            "help": "scalar quantization noise and scalar quantization at training time"
        },
    )


@dataclass
class TransformerConfig(FairseqDataclass):
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu",
        metadata={"help": "activation function to use"},
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    attention_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability for attention weights"}
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN.",
            "alias": "--relu-dropout",
        },
    )
    adaptive_input: bool = False
    encoder: EncDecBaseConfig = EncDecBaseConfig()
    # TODO should really be in the encoder config
    max_source_positions: int = field(
        default=DEFAULT_MAX_SOURCE_POSITIONS,
        metadata={"help": "Maximum input length supported by the encoder"},
    )
    decoder: DecoderConfig = DecoderConfig()
    # TODO should really be in the decoder config
    max_target_positions: int = field(
        default=DEFAULT_MAX_TARGET_POSITIONS,
        metadata={"help": "Maximum output length supported by the decoder"},
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    share_all_embeddings: bool = field(
        default=False,
        metadata={
            "help": "share encoder, decoder and output embeddings (requires shared dictionary and embed dim)"
        },
    )
    merge_src_tgt_embed: bool = field(
        default=False,
        metadata={
            "help": "if true then the source and target embedding table is "
            "merged into one table. This is going to make the model smaller but "
            "it might hurt performance."
        },
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if True, disables positional embeddings (outside self attention)"
        },
    )
    adaptive_softmax_cutoff: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "list of adaptive softmax cutoff points. Must be used with adaptive_loss criterion"
        },
    )
    adaptive_softmax_dropout: float = field(
        default=0.0,
        metadata={"help": "sets adaptive softmax dropout for the tail projections"},
    )
    adaptive_softmax_factor: float = field(
        default=4, metadata={"help": "adaptive input factor"}
    )
    layernorm_embedding: bool = field(
        default=False, metadata={"help": "add layernorm to embedding"}
    )
    tie_adaptive_weights: bool = field(
        default=False,
        metadata={
            "help": "if set, ties the weights of adaptive softmax and adaptive input"
        },
    )
    tie_adaptive_proj: bool = field(
        default=False,
        metadata={
            "help": "if set, ties the projection weights of adaptive softmax and adaptive input"
        },
    )
    no_scale_embedding: bool = field(
        default=False, metadata={"help": "if True, dont scale embeddings"}
    )
    checkpoint_activations: bool = field(
        default=False,
        metadata={
            "help": "checkpoint activations at each layer, which saves GPU memory usage at the cost of some additional compute"
        },
    )
    offload_activations: bool = field(
        default=False,
        metadata={
            "help": "checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations."
        },
    )
    # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
    no_cross_attention: bool = field(
        default=False, metadata={"help": "do not perform cross-attention"}
    )
    cross_self_attention: bool = field(
        default=False, metadata={"help": "perform cross+self-attention"}
    )
    # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
    quant_noise: QuantNoiseConfig = field(default=QuantNoiseConfig())
    min_params_to_wrap: int = field(
        default=DEFAULT_MIN_PARAMS_TO_WRAP,
        metadata={
            "help": "minimum number of params for a layer to be wrapped with FSDP() when "
            "training with --ddp-backend=fully_sharded. Smaller values will "
            "improve memory efficiency, but may make torch.distributed "
            "communication less efficient due to smaller input sizes. This option "
            "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
            "--offload-activations are passed."
        },
    )
    # DEPRECATED field, but some old checkpoints might have it
    char_inputs: bool = field(
        default=False, metadata={"help": "if set, model takes character ids as input"}
    )
    relu_dropout: float = 0.0
    # config for "BASE Layers: Simplifying Training of Large, Sparse Models"
    base_layers: Optional[int] = field(
        default=0, metadata={"help": "number of BASE layers in total"}
    )
    base_sublayers: Optional[int] = field(
        default=1, metadata={"help": "number of sublayers in each BASE layer"}
    )
    base_shuffle: Optional[int] = field(
        default=1,
        metadata={"help": "shuffle tokens between workers before computing assignment"},
    )

    export: bool = field(
        default=False,
        metadata={"help": "make the layernorm exportable with torchscript."},
    )

    # copied from transformer_lm but expected in transformer_decoder:
    no_decoder_final_norm: bool = field(
        default=False,
        metadata={"help": "don't add an extra layernorm after the last decoder block"},
    )

    # We need to make this hierarchical dataclass behave like the flat namespace.
    # __getattr__ and __setattr__ here allow backward compatibility
    # for subclasses of Transformer(Legacy) that depend on read/write on
    # the flat namespace.

    def __getattr__(self, name):
        match = re.match(_NAME_PARSER, name)
        if match:
            sub = safe_getattr(self, match[1])
            return safe_getattr(sub, match[2])
        raise AttributeError(f"invalid argument {name}.")

    def __setattr__(self, name, value):
        match = re.match(_NAME_PARSER, name)
        if match:
            sub = safe_getattr(self, match[1])
            setattr(sub, match[2], value)
        else:
            super().__setattr__(name, value)

    @staticmethod
    def _copy_keys(args, cls, prefix, seen):
        """
        copy the prefixed keys (decoder_embed_dim) to the DC fields: decoder.embed_dim
        """
        cfg = cls()
        for fld in fields(cls):
            # for all the fields in the DC, find the fields (e.g. embed_dim)
            # in the namespace with the prefix (e.g. decoder)
            # and set it on the dc.
            args_key = f"{prefix}_{fld.name}"
            if safe_hasattr(args, args_key):
                seen.add(args_key)
                setattr(cfg, fld.name, safe_getattr(args, args_key))
            if safe_hasattr(args, fld.name):
                seen.add(fld.name)
                setattr(cfg, fld.name, safe_getattr(args, fld.name))
        return cfg

    @classmethod
    def from_namespace(cls, args):
        if args is None:
            return None
        if not isinstance(args, cls):
            seen = set()
            config = cls()
            # currently, we can go generically from DC fields to args hierarchically
            # but we can't easily deconstruct a flat namespace to a hierarchical
            # DC. Mostly because we could have a sub-dc called `decoder-foo` that should not
            # go to the sub struct called `decoder`. There are ways to go around this, but let's keep it simple
            # for now.
            for fld in fields(cls):
                # concretely, the transformer_config knows what sub-dc it has, so we go through all the dc fields
                # and if it's one that has a sub-dc, we build that sub-dc with `copy_keys()`
                if fld.name == "decoder":
                    if safe_hasattr(args, "decoder"):
                        # in some cases, the args we receive is already structured (as DictConfigs), so let's just build the correct DC
                        seen.add("decoder")
                        config.decoder = DecoderConfig(**args.decoder)
                    else:
                        config.decoder = cls._copy_keys(
                            args, DecoderConfig, "decoder", seen
                        )
                elif fld.name == "encoder":
                    # same but for encoder
                    if safe_hasattr(args, "encoder"):
                        seen.add("encoder")
                        config.encoder = EncDecBaseConfig(**args.encoder)
                    else:
                        config.encoder = cls._copy_keys(
                            args, EncDecBaseConfig, "encoder", seen
                        )
                elif fld.name == "quant_noise":
                    # same but for quant_noise
                    if safe_hasattr(args, "quant_noise"):
                        seen.add("quant_noise")
                        config.quant_noise = QuantNoiseConfig(**args.quant_noise)
                    else:
                        config.quant_noise = cls._copy_keys(
                            args, QuantNoiseConfig, "quant_noise", seen
                        )
                elif safe_hasattr(args, fld.name):
                    # if it's not a structure field, it's just a normal field, copy it over
                    seen.add(fld.name)
                    setattr(config, fld.name, safe_getattr(args, fld.name))
            # we got all the fields defined in the dataclass, but
            # the argparse namespace might have extra args for two reasons:
            # - we are in a legacy class so all the args are not declared in the dataclass. Ideally once everyone has defined a dataclass for their model, we won't need this
            # - some places expect args to be there but never define them
            args_dict = (
                args._asdict()
                if safe_hasattr(args, "_asdict")
                else vars(args)
                if safe_hasattr(args, "__dict__")
                else {}
            )  # namedtuple doesn't have __dict__ :-/
            for key, value in args_dict.items():
                if key not in seen:
                    setattr(config, key, value)
            return config
        else:
            return args
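A minimal sketch of the backward-compatibility behavior implemented above (from_namespace, __getattr__, __setattr__), assuming fairseq is importable; flat argparse-style names are routed into the nested encoder/decoder dataclasses:

from argparse import Namespace
from fairseq.models.transformer import TransformerConfig

args = Namespace(encoder_embed_dim=256, decoder_layers=4, dropout=0.3)
cfg = TransformerConfig.from_namespace(args)

assert cfg.encoder.embed_dim == 256   # copied by _copy_keys via the "encoder_" prefix
assert cfg.decoder.layers == 4        # same for the "decoder_" prefix
assert cfg.dropout == 0.3             # plain field copied directly
assert cfg.encoder_embed_dim == 256   # flat reads still work through __getattr__
cfg.decoder_embed_dim = 128           # flat writes are routed by __setattr__
assert cfg.decoder.embed_dim == 128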
fairseq/fairseq/models/transformer/transformer_decoder.py
ADDED
@@ -0,0 +1,474 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqIncrementalDecoder
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import (
    AdaptiveSoftmax,
    BaseLayer,
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    PositionalEmbedding,
    SinusoidalPositionalEmbedding,
    transformer_layer,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_


# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
    if module_name == "TransformerDecoderBase":
        return "TransformerDecoder"
    else:
        return module_name


class TransformerDecoderBase(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        cfg (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        output_projection=None,
    ):
        self.cfg = cfg
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))
        self._future_mask = torch.empty(0)

        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
        )
        self.decoder_layerdrop = cfg.decoder.layerdrop
        self.share_input_output_embed = cfg.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = cfg.decoder.embed_dim
        self.embed_dim = embed_dim
        self.output_embed_dim = cfg.decoder.output_dim

        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions

        self.embed_tokens = embed_tokens

        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)

        if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                cfg.quant_noise.pq,
                cfg.quant_noise.pq_block_size,
            )
        else:
            self.quant_noise = None

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )
        self.embed_positions = (
            PositionalEmbedding(
                self.max_target_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.decoder.learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )
        if cfg.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layernorm_embedding = None

        self.cross_self_attention = cfg.cross_self_attention

        if self.decoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_decoder_layer(cfg, no_encoder_attn)
                for _ in range(cfg.decoder.layers)
            ]
        )
        self.num_layers = len(self.layers)

        if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
            self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layer_norm = None

        self.project_out_dim = (
            Linear(embed_dim, self.output_embed_dim, bias=False)
            if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights
            else None
        )

        self.adaptive_softmax = None
        self.output_projection = output_projection
        if self.output_projection is None:
            self.build_output_projection(cfg, dictionary, embed_tokens)

    def build_output_projection(self, cfg, dictionary, embed_tokens):
        if cfg.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                self.output_embed_dim,
                utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int),
                dropout=cfg.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None,
                factor=cfg.adaptive_softmax_factor,
                tie_proj=cfg.tie_adaptive_proj,
            )
        elif self.share_input_output_embed:
            self.output_projection = nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            self.output_projection.weight = self.embed_tokens.weight
        else:
            self.output_projection = nn.Linear(
                self.output_embed_dim, len(dictionary), bias=False
            )
            nn.init.normal_(
                self.output_projection.weight, mean=0, std=self.output_embed_dim**-0.5
            )
        num_base_layers = cfg.base_layers
        for i in range(num_base_layers):
            self.layers.insert(
                ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1),
                BaseLayer(cfg),
            )

    def build_decoder_layer(self, cfg, no_encoder_attn=False):
        layer = transformer_layer.TransformerDecoderLayerBase(cfg, no_encoder_attn)
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention, should be of size T x B x C
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """

        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )

        if not features_only:
            x = self.output_layer(x)
        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        return self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )

    """
    A scriptable subclass of this class has an extract_features method and calls
    super().extract_features, but super() is not supported in torchscript. A copy of
    this function is made to be used in the subclass instead.
    """

    def extract_features_scriptable(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        bs, slen = prev_output_tokens.size()
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        enc: Optional[Tensor] = None
        padding_mask: Optional[Tensor] = None
        if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
            enc = encoder_out["encoder_out"][0]
        if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
            padding_mask = encoder_out["encoder_padding_mask"][0]

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # Prevent torchscript exporting issue for dynamic quant embedding
        prev_output_tokens = prev_output_tokens.contiguous()
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = layer(
                x,
                enc,
                padding_mask,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": [attn], "inner_states": inner_states}

    def output_layer(self, features):
        """Project features to the vocabulary size."""
        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            return self.output_projection(features)
        else:
            return features

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
        if (
            self._future_mask.size(0) == 0
            or (not self._future_mask.device == tensor.device)
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
            )
        self._future_mask = self._future_mask.to(tensor)
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if f"{name}.output_projection.weight" not in state_dict:
            if self.share_input_output_embed:
                embed_out_key = f"{name}.embed_tokens.weight"
            else:
                embed_out_key = f"{name}.embed_out"
            if embed_out_key in state_dict:
                state_dict[f"{name}.output_projection.weight"] = state_dict[
                    embed_out_key
                ]
                if not self.share_input_output_embed:
                    del state_dict[embed_out_key]

        for i in range(self.num_layers):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            "{}.layers.{}.{}.{}".format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m


class TransformerDecoder(TransformerDecoderBase):
    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        output_projection=None,
    ):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            no_encoder_attn=no_encoder_attn,
            output_projection=output_projection,
        )

    def build_output_projection(self, args, dictionary, embed_tokens):
        super().build_output_projection(
            TransformerConfig.from_namespace(args), dictionary, embed_tokens
        )

    def build_decoder_layer(self, args, no_encoder_attn=False):
        return super().build_decoder_layer(
            TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
        )
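
TransformerDecoderBase.buffered_future_mask above caches a causal mask and only rebuilds it when the requested length grows or the device changes. The sketch below, with illustrative sizes and a hypothetical future_mask helper, shows the mask it produces and why adding it to attention scores before the softmax blocks attention to future positions.

# Standalone sketch of the causal (future) mask; names and sizes are illustrative.
import torch


def future_mask(dim: int, dtype=torch.float32) -> torch.Tensor:
    # -inf strictly above the diagonal, 0 on and below it
    return torch.triu(torch.full((dim, dim), float("-inf"), dtype=dtype), diagonal=1)


mask = future_mask(4)
# position 0 may only attend to itself; position 3 sees all four positions
scores = torch.zeros(4, 4)
probs = torch.softmax(scores + mask, dim=-1)
assert torch.allclose(probs[0], torch.tensor([1.0, 0.0, 0.0, 0.0]))
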
fairseq/fairseq/models/transformer/transformer_decoder_aug.py
ADDED
@@ -0,0 +1,384 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models.transformer import TransformerConfig
from fairseq.models.transformer.transformer_decoder import TransformerDecoderBase
from fairseq.modules import (
    LayerDropModuleList,
    SinusoidalPositionalEmbedding,
    transformer_layer_aug,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper


class AugTransformerDecoderBase(TransformerDecoderBase):
    """
    Transformer decoder augmented with an additional cross-attention. Each layer
    is a :class:`AugTransformerDecoderLayerBase`.

    Args:
        cfg (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        encoder_attn_merge_type (str, optional): the way to combine outputs from
            two cross-attention modules. If "sequential" is set, two cross-attention
            modules are stacked sequentially. If "parallel" is set, they are processed
            in parallel and combined before feeding it to FFN (default: sequential).
        dropnet_ratio (float, optional): a probability to drop each cross-attention
            module during training (default: 0.0).
    """

    def __init__(
        self,
        cfg,
        dictionary,
        embed_tokens,
        output_projection=None,
        encoder_attn_merge_type="sequential",
        dropnet_ratio=0.0,
    ):
        super().__init__(
            cfg,
            dictionary,
            embed_tokens,
            no_encoder_attn=False,
            output_projection=output_projection,
        )
        # assert cfg.cross_self_attention
        self.cross_self_attention = cfg.cross_self_attention

        if self.decoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                self.build_decoder_layer(cfg, encoder_attn_merge_type, dropnet_ratio)
                for _ in range(cfg.decoder.layers)
            ]
        )

    def build_decoder_layer(
        self,
        cfg,
        encoder_attn_merge_type="sequential",
        dropnet_ratio=0,
    ):
        layer = transformer_layer_aug.AugTransformerDecoderLayerBase(
            cfg,
            no_encoder_attn=False,
            encoder_attn_merge_type=encoder_attn_merge_type,
            dropnet_ratio=dropnet_ratio,
        )
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention, should be of size T x B x C
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """

        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            encoder_out_aug=encoder_out_aug,
            incremental_state=incremental_state,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
        )

        if not features_only:
            x = self.output_layer(x)
        return x, extra

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        encoder_out_aug: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        return self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            encoder_out_aug,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )

    """
    A scriptable subclass of this class has an extract_features method and calls
    super().extract_features, but super() is not supported in torchscript. A copy of
    this function is made to be used in the subclass instead.
    """

    def extract_features_scriptable(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        encoder_out_aug: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        bs, slen = prev_output_tokens.size()
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        enc: Optional[Tensor] = None
        padding_mask: Optional[Tensor] = None
        if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
            enc = encoder_out["encoder_out"][0]
        if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
            padding_mask = encoder_out["encoder_padding_mask"][0]

        enc_aug: Optional[Tensor] = None
        padding_mask_aug: Optional[Tensor] = None
        if encoder_out_aug is not None and len(encoder_out_aug["encoder_out"]) > 0:
            enc_aug = encoder_out_aug["encoder_out"][0]
        if (
            encoder_out_aug is not None
            and len(encoder_out_aug["encoder_padding_mask"]) > 0
        ):
            padding_mask_aug = encoder_out_aug["encoder_padding_mask"][0]

        # embed positions
        positions = None
        if self.embed_positions is not None:
            positions = self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # Prevent torchscript exporting issue for dynamic quant embedding
        prev_output_tokens = prev_output_tokens.contiguous()
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        attn: Optional[Tensor] = None
        attn_aug: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, layer_attn_aug, _ = layer(
                x,
                enc,
                padding_mask,
                enc_aug,
                padding_mask_aug,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)
            if layer_attn_aug is not None and idx == alignment_layer:
                attn_aug = layer_attn_aug.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if attn_aug is not None:
            if alignment_heads is not None:
                attn_aug = attn_aug[:alignment_heads]

            # average probabilities over heads
            attn_aug = attn_aug.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": [attn], "attn_aug": [attn_aug], "inner_states": inner_states}

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if f"{name}.output_projection.weight" not in state_dict:
            if self.share_input_output_embed:
                embed_out_key = f"{name}.embed_tokens.weight"
            else:
                embed_out_key = f"{name}.embed_out"
            if embed_out_key in state_dict:
                state_dict[f"{name}.output_projection.weight"] = state_dict[
                    embed_out_key
                ]
                if not self.share_input_output_embed:
                    del state_dict[embed_out_key]

        for i in range(self.num_layers):
            # update layer norms
            layer_norm_map = {
                "0": "self_attn_layer_norm",
                "1": "encoder_attn_layer_norm",
                "2": "encoder_attn_layer_norm2",
                "3": "final_layer_norm",
            }
            for old, new in layer_norm_map.items():
                for m in ("weight", "bias"):
                    k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
                    if k in state_dict:
                        state_dict[
                            "{}.layers.{}.{}.{}".format(name, i, new, m)
                        ] = state_dict[k]
                        del state_dict[k]

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict


class AugTransformerDecoder(AugTransformerDecoderBase):
    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        output_projection=None,
    ):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            no_encoder_attn=False,
            output_projection=output_projection,
            encoder_attn_merge_type=getattr(
                args, "synthesizer_augmented_cross_attention_merge_type", "sequential"
            ),
            dropnet_ratio=getattr(args, "dropnet_ratio", 0),
        )

    def build_output_projection(self, args, dictionary, embed_tokens):
        super().build_output_projection(
            TransformerConfig.from_namespace(args), dictionary, embed_tokens
        )

    def build_decoder_layer(
        self,
        args,
        encoder_attn_merge_type="sequential",
        dropnet_ratio=0,
    ):
        return super().build_decoder_layer(
            TransformerConfig.from_namespace(args),
            no_encoder_attn=False,
            encoder_attn_merge_type=encoder_attn_merge_type,
            dropnet_ratio=dropnet_ratio,
        )
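
AugTransformerDecoderBase wires a second cross-attention into every layer and, per the docstring, either stacks the two attentions sequentially or runs them in parallel, optionally dropping one branch at random (dropnet) during training. The toy merge function below only illustrates that control flow on plain tensors; it is an assumption-labeled sketch, not the AugTransformerDecoderLayerBase implementation, and the 0.5 averaging for the parallel case is a simplification.

# Toy illustration of sequential/parallel merging with a dropnet branch drop.
import torch


def merge(attn_main, attn_aug, merge_type="sequential", dropnet=0.0, training=True):
    if training and dropnet > 0.0:
        u = torch.rand(1).item()
        if u < dropnet:          # keep only the primary cross-attention branch
            return attn_main
        if u > 1.0 - dropnet:    # keep only the augmented branch
            return attn_aug
    if merge_type == "parallel":
        return 0.5 * (attn_main + attn_aug)  # combine before the FFN
    # "sequential": the augmented attention already consumed the first one's output
    return attn_aug


x_main, x_aug = torch.randn(5, 2, 8), torch.randn(5, 2, 8)  # T x B x C toys
y = merge(x_main, x_aug, merge_type="parallel", dropnet=0.1)
print(y.shape)  # torch.Size([5, 2, 8])
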
fairseq/fairseq/models/transformer/transformer_encoder.py
ADDED
@@ -0,0 +1,362 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoder
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import (
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    PositionalEmbedding,
    SinusoidalPositionalEmbedding,
    transformer_layer,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_


# rewrite name for backward compatibility in `make_generation_fast_`
def module_name_fordropout(module_name: str) -> str:
    if module_name == "TransformerEncoderBase":
        return "TransformerEncoder"
    else:
        return module_name


class TransformerEncoderBase(FairseqEncoder):
    """
    Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
        self.cfg = cfg
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))

        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
        )
        self.encoder_layerdrop = cfg.encoder.layerdrop
        self.return_fc = return_fc

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = cfg.max_source_positions

        self.embed_tokens = embed_tokens

        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.encoder.learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )
        if cfg.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layernorm_embedding = None

        if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                cfg.quant_noise.pq,
                cfg.quant_noise.pq_block_size,
            )
        else:
            self.quant_noise = None

        if self.encoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]
        )
        self.num_layers = len(self.layers)

        if cfg.encoder.normalize_before:
            self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layer_norm = None

    def build_encoder_layer(self, cfg):
        layer = transformer_layer.TransformerEncoderLayerBase(
            cfg, return_fc=self.return_fc
        )
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # if we are checkpointing, enforce that FSDP always wraps the
        # checkpointed layer, regardless of layer size
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward_embedding(
        self, src_tokens, token_embedding: Optional[torch.Tensor] = None
    ):
        # embed tokens and positions
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        if self.quant_noise is not None:
            x = self.quant_noise(x)
        return x, embed

    def forward(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        return self.forward_scriptable(
            src_tokens, src_lengths, return_all_hiddens, token_embeddings
        )

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
    # Current workaround is to add a helper function with different name and
    # call the helper function from scriptable Subclass.
    def forward_scriptable(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings
                default `None` will recompute embeddings

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads = (
            torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any()
        )
        # Torchscript doesn't handle bool Tensor correctly, so we need to work around.
        if torch.jit.is_scripting():
            has_pads = torch.tensor(1) if has_pads else torch.tensor(0)

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # account for padding while computing the representation
        x = x * (
            1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x)
        )

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = []
        fc_results = []

        if return_all_hiddens:
            encoder_states.append(x)

        # encoder layers
        for layer in self.layers:
            lr = layer(
                x, encoder_padding_mask=encoder_padding_mask if has_pads else None
            )

            if isinstance(lr, tuple) and len(lr) == 2:
                x, fc_result = lr
            else:
                x = lr
                fc_result = None

            if return_all_hiddens and not torch.jit.is_scripting():
                assert encoder_states is not None
                encoder_states.append(x)
                fc_results.append(fc_result)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
        # `forward` so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        src_lengths = (
            src_tokens.ne(self.padding_idx)
            .sum(dim=1, dtype=torch.int32)
            .reshape(-1, 1)
            .contiguous()
        )
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "fc_results": fc_results,  # List[T x B x C]
            "src_tokens": [],
            "src_lengths": [src_lengths],
        }

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        if len(encoder_out["encoder_embedding"]) == 0:
            new_encoder_embedding = []
        else:
            new_encoder_embedding = [
                encoder_out["encoder_embedding"][0].index_select(0, new_order)
            ]

        if len(encoder_out["src_tokens"]) == 0:
            src_tokens = []
        else:
            src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]

        if len(encoder_out["src_lengths"]) == 0:
            src_lengths = []
        else:
            src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": src_tokens,  # B x T
            "src_lengths": src_lengths,  # B x 1
        }

    @torch.jit.export
    def _reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """Dummy re-order function for beamable enc-dec attention"""
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        for i in range(self.num_layers):
            # update layer norms
            self.layers[i].upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, i)
            )

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict


class TransformerEncoder(TransformerEncoderBase):
    def __init__(self, args, dictionary, embed_tokens, return_fc=False):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            return_fc=return_fc,
        )

    def build_encoder_layer(self, args):
        return super().build_encoder_layer(
            TransformerConfig.from_namespace(args),
        )
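
reorder_encoder_out above exists for beam search: the effective batch holds batch x beam hypotheses, and after each step the generator gathers the encoder states of the surviving beams. The snippet below illustrates that index_select pattern on toy tensors; the shapes and new_order values are made up for the example.

# Toy illustration of reordering encoder outputs for beam search.
import torch

src_len, bsz, beam, dim = 7, 2, 3, 4
encoder_out = torch.randn(src_len, bsz * beam, dim)                 # T x (B*beam) x C
padding_mask = torch.zeros(bsz * beam, src_len, dtype=torch.bool)   # (B*beam) x T

# e.g. beam 2 of sentence 0 and beams 0/1 of sentence 1 survived (values are made up)
new_order = torch.tensor([2, 2, 1, 3, 3, 4])

reordered = {
    # batch is dim 1 for the T x B x C encoder states ...
    "encoder_out": [encoder_out.index_select(1, new_order)],
    # ... but dim 0 for the B x T padding mask
    "encoder_padding_mask": [padding_mask.index_select(0, new_order)],
}
print(reordered["encoder_out"][0].shape)  # torch.Size([7, 6, 4])
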
fairseq/fairseq/models/transformer/transformer_legacy.py
ADDED
@@ -0,0 +1,277 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer.transformer_config import (
    TransformerConfig,
    DEFAULT_MAX_SOURCE_POSITIONS,
    DEFAULT_MAX_TARGET_POSITIONS,
    DEFAULT_MIN_PARAMS_TO_WRAP,
)
from fairseq.models.transformer.transformer_base import (
    TransformerModelBase,
)


@register_model("transformer")
class TransformerModel(TransformerModelBase):
    """
    This is the legacy implementation of the transformer model that
    uses argparse for configuration.
    """

    @classmethod
    def hub_models(cls):
        # fmt: off

        def moses_subword(path):
            return {
                'path': path,
                'tokenizer': 'moses',
                'bpe': 'subword_nmt',
            }

        def moses_fastbpe(path):
            return {
                'path': path,
                'tokenizer': 'moses',
                'bpe': 'fastbpe',
            }

        def spm(path):
            return {
                'path': path,
                'bpe': 'sentencepiece',
                'tokenizer': 'space',
            }

        return {
            'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'),
            'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
            'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'),
            'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'),
            'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'),
            'transformer.wmt19.de-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'),
            'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'),
            'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'),
            'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'),
            'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'),
            'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'),
            'transformer.wmt20.en-ta': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz'),
            'transformer.wmt20.en-iu.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz'),
            'transformer.wmt20.en-iu.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz'),
            'transformer.wmt20.ta-en': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz'),
            'transformer.wmt20.iu-en.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'),
            'transformer.wmt20.iu-en.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'),
            'transformer.flores101.mm100.615M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz'),
            'transformer.flores101.mm100.175M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz'),
        }
        # fmt: on

    def __init__(self, args, encoder, decoder):
        cfg = TransformerConfig.from_namespace(args)
        super().__init__(cfg, encoder, decoder)
        self.args = args

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        # we want to build the args recursively in this case.
        # do not set defaults so that settings defaults from various architectures still works
        gen_parser_from_dataclass(
            parser, TransformerConfig(), delete_default=True, with_prefix=""
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present in older models
        base_architecture(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            args.share_decoder_input_output_embed = True

        if getattr(args, "offload_activations", False):
            args.checkpoint_activations = True  # offloading implies checkpointing

        if not args.share_all_embeddings:
            args.min_params_to_wrap = getattr(
                args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
            )
        cfg = TransformerConfig.from_namespace(args)
        return super().build_model(cfg, task)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        return super().build_embedding(
            TransformerConfig.from_namespace(args), dictionary, embed_dim, path
        )

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return super().build_encoder(
            TransformerConfig.from_namespace(args), src_dict, embed_tokens
        )

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return super().build_decoder(
            TransformerConfig.from_namespace(args), tgt_dict, embed_tokens
        )


# architectures


@register_model_architecture("transformer", "transformer_tiny")
def tiny_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 64)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 64)
    args.encoder_layers = getattr(args, "encoder_layers", 2)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2)
    args.decoder_layers = getattr(args, "decoder_layers", 2)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2)
    return base_architecture(args)


@register_model_architecture("transformer", "transformer")
def base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)

    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
193 |
+
args.share_decoder_input_output_embed = getattr(
|
194 |
+
args, "share_decoder_input_output_embed", False
|
195 |
+
)
|
196 |
+
args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
|
197 |
+
args.merge_src_tgt_embed = getattr(args, "merge_src_tgt_embed", False)
|
198 |
+
args.no_token_positional_embeddings = getattr(
|
199 |
+
args, "no_token_positional_embeddings", False
|
200 |
+
)
|
201 |
+
args.adaptive_input = getattr(args, "adaptive_input", False)
|
202 |
+
args.no_cross_attention = getattr(args, "no_cross_attention", False)
|
203 |
+
args.cross_self_attention = getattr(args, "cross_self_attention", False)
|
204 |
+
|
205 |
+
args.decoder_output_dim = getattr(
|
206 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
207 |
+
)
|
208 |
+
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
209 |
+
|
210 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
211 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
212 |
+
args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
|
213 |
+
args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
|
214 |
+
args.offload_activations = getattr(args, "offload_activations", False)
|
215 |
+
if args.offload_activations:
|
216 |
+
args.checkpoint_activations = True
|
217 |
+
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
|
218 |
+
args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
|
219 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
220 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
|
221 |
+
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
|
222 |
+
args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
|
223 |
+
args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
|
224 |
+
|
225 |
+
|
226 |
+
@register_model_architecture("transformer", "transformer_iwslt_de_en")
|
227 |
+
def transformer_iwslt_de_en(args):
|
228 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
229 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
|
230 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
|
231 |
+
args.encoder_layers = getattr(args, "encoder_layers", 6)
|
232 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
|
233 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
|
234 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
|
235 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
236 |
+
base_architecture(args)
|
237 |
+
|
238 |
+
|
239 |
+
@register_model_architecture("transformer", "transformer_wmt_en_de")
|
240 |
+
def transformer_wmt_en_de(args):
|
241 |
+
base_architecture(args)
|
242 |
+
|
243 |
+
|
244 |
+
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
|
245 |
+
@register_model_architecture("transformer", "transformer_vaswani_wmt_en_de_big")
|
246 |
+
def transformer_vaswani_wmt_en_de_big(args):
|
247 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
248 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
249 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
250 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
251 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
|
252 |
+
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
|
253 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
254 |
+
args.dropout = getattr(args, "dropout", 0.3)
|
255 |
+
base_architecture(args)
|
256 |
+
|
257 |
+
|
258 |
+
@register_model_architecture("transformer", "transformer_vaswani_wmt_en_fr_big")
|
259 |
+
def transformer_vaswani_wmt_en_fr_big(args):
|
260 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
261 |
+
transformer_vaswani_wmt_en_de_big(args)
|
262 |
+
|
263 |
+
|
264 |
+
@register_model_architecture("transformer", "transformer_wmt_en_de_big")
|
265 |
+
def transformer_wmt_en_de_big(args):
|
266 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
267 |
+
transformer_vaswani_wmt_en_de_big(args)
|
268 |
+
|
269 |
+
|
270 |
+
# default parameters used in tensor2tensor implementation
|
271 |
+
@register_model_architecture("transformer", "transformer_wmt_en_de_big_t2t")
|
272 |
+
def transformer_wmt_en_de_big_t2t(args):
|
273 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
|
274 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
|
275 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
276 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.1)
|
277 |
+
transformer_vaswani_wmt_en_de_big(args)
|
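The checkpoint names registered in hub_models() above are intended to be loaded through torch.hub, and the @register_model_architecture functions only fill in missing attributes via getattr(args, ...), so values set explicitly on the command line are never overridden. A minimal usage sketch, assuming fairseq plus the moses/fastbpe dependencies (sacremoses, fastBPE) are installed and the download URLs above are reachable; the tokenizer/bpe keywords simply mirror what the moses_fastbpe() helper already configures:

import torch

# Load one of the single-model WMT'19 entries listed in hub_models().
en2ru = torch.hub.load(
    "pytorch/fairseq",
    "transformer.wmt19.en-ru.single_model",
    tokenizer="moses",
    bpe="fastbpe",
)
en2ru.eval()
print(en2ru.translate("Machine learning is great!"))  # prints a Russian translation
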
fairseq/fairseq/models/wav2vec/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .wav2vec import *  # noqa
from .wav2vec2 import *  # noqa
from .wav2vec2_asr import *  # noqa
from .wav2vec2_laser import *  # noqa
from .wav2vec2_classification import *  # noqa
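The wildcard imports in this __init__.py are what execute the @register_model decorators in the individual wav2vec modules, so the models become visible in fairseq's registry as soon as the package is imported. A minimal check, assuming this fairseq tree is importable; the printed names depend on whatever each module registers upstream (e.g. "wav2vec2"):

# Importing the subpackage (directly, or implicitly via fairseq.models) runs the
# wildcard imports above and populates the model registry.
import fairseq.models.wav2vec  # noqa: F401
from fairseq.models import MODEL_REGISTRY

print(sorted(name for name in MODEL_REGISTRY if name.startswith("wav2vec")))
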
fairseq/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (345 Bytes).

fairseq/fairseq/models/wav2vec/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (670 Bytes).

fairseq/fairseq/models/wav2vec/__pycache__/wav2vec.cpython-310.pyc
ADDED
Binary file (15.1 kB).

fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2.cpython-310.pyc
ADDED
Binary file (32.6 kB).

fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_asr.cpython-310.pyc
ADDED
Binary file (23.9 kB).

fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_classification.cpython-310.pyc
ADDED
Binary file (9.41 kB).

fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_laser.cpython-310.pyc
ADDED
Binary file (1.6 kB).