ssl-aasist (PyTorch, custom_code)
ash56 committed · commit e3b406d · verified · 1 parent: be46eed

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. fairseq/fairseq/models/__pycache__/__init__.cpython-310.pyc +0 -0
  2. fairseq/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc +0 -0
  3. fairseq/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc +0 -0
  4. fairseq/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc +0 -0
  5. fairseq/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc +0 -0
  6. fairseq/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc +0 -0
  7. fairseq/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc +0 -0
  8. fairseq/fairseq/models/__pycache__/fconv.cpython-310.pyc +0 -0
  9. fairseq/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc +0 -0
  10. fairseq/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc +0 -0
  11. fairseq/fairseq/models/__pycache__/lightconv.cpython-310.pyc +0 -0
  12. fairseq/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc +0 -0
  13. fairseq/fairseq/models/__pycache__/lstm.cpython-310.pyc +0 -0
  14. fairseq/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc +0 -0
  15. fairseq/fairseq/models/__pycache__/masked_lm.cpython-310.pyc +0 -0
  16. fairseq/fairseq/models/__pycache__/model_utils.cpython-310.pyc +0 -0
  17. fairseq/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc +0 -0
  18. fairseq/fairseq/models/__pycache__/transformer_align.cpython-310.pyc +0 -0
  19. fairseq/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc +0 -0
  20. fairseq/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc +0 -0
  21. fairseq/fairseq/models/__pycache__/transformer_ulm.cpython-310.pyc +0 -0
  22. fairseq/fairseq/models/text_to_speech/__pycache__/codehifigan.cpython-310.pyc +0 -0
  23. fairseq/fairseq/models/text_to_speech/__pycache__/fastspeech2.cpython-310.pyc +0 -0
  24. fairseq/fairseq/models/text_to_speech/__pycache__/hifigan.cpython-310.pyc +0 -0
  25. fairseq/fairseq/models/text_to_speech/__pycache__/hub_interface.cpython-310.pyc +0 -0
  26. fairseq/fairseq/models/text_to_speech/__pycache__/tts_transformer.cpython-310.pyc +0 -0
  27. fairseq/fairseq/models/text_to_speech/__pycache__/vocoder.cpython-310.pyc +0 -0
  28. fairseq/fairseq/models/text_to_speech/tts_transformer.py +454 -0
  29. fairseq/fairseq/models/transformer/__init__.py +50 -0
  30. fairseq/fairseq/models/transformer/__pycache__/__init__.cpython-310.pyc +0 -0
  31. fairseq/fairseq/models/transformer/__pycache__/transformer_base.cpython-310.pyc +0 -0
  32. fairseq/fairseq/models/transformer/__pycache__/transformer_config.cpython-310.pyc +0 -0
  33. fairseq/fairseq/models/transformer/__pycache__/transformer_decoder.cpython-310.pyc +0 -0
  34. fairseq/fairseq/models/transformer/__pycache__/transformer_decoder_aug.cpython-310.pyc +0 -0
  35. fairseq/fairseq/models/transformer/__pycache__/transformer_encoder.cpython-310.pyc +0 -0
  36. fairseq/fairseq/models/transformer/__pycache__/transformer_legacy.cpython-310.pyc +0 -0
  37. fairseq/fairseq/models/transformer/transformer_base.py +193 -0
  38. fairseq/fairseq/models/transformer/transformer_config.py +341 -0
  39. fairseq/fairseq/models/transformer/transformer_decoder.py +474 -0
  40. fairseq/fairseq/models/transformer/transformer_decoder_aug.py +384 -0
  41. fairseq/fairseq/models/transformer/transformer_encoder.py +362 -0
  42. fairseq/fairseq/models/transformer/transformer_legacy.py +277 -0
  43. fairseq/fairseq/models/wav2vec/__init__.py +10 -0
  44. fairseq/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc +0 -0
  45. fairseq/fairseq/models/wav2vec/__pycache__/utils.cpython-310.pyc +0 -0
  46. fairseq/fairseq/models/wav2vec/__pycache__/wav2vec.cpython-310.pyc +0 -0
  47. fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2.cpython-310.pyc +0 -0
  48. fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_asr.cpython-310.pyc +0 -0
  49. fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_classification.cpython-310.pyc +0 -0
  50. fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_laser.cpython-310.pyc +0 -0
fairseq/fairseq/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (6.01 kB). View file
 
fairseq/fairseq/models/__pycache__/composite_encoder.cpython-310.pyc ADDED
Binary file (2.41 kB). View file
 
fairseq/fairseq/models/__pycache__/distributed_fairseq_model.cpython-310.pyc ADDED
Binary file (3.6 kB). View file
 
fairseq/fairseq/models/__pycache__/fairseq_decoder.cpython-310.pyc ADDED
Binary file (3.74 kB). View file
 
fairseq/fairseq/models/__pycache__/fairseq_encoder.cpython-310.pyc ADDED
Binary file (3.62 kB). View file
 
fairseq/fairseq/models/__pycache__/fairseq_incremental_decoder.cpython-310.pyc ADDED
Binary file (4.85 kB). View file
 
fairseq/fairseq/models/__pycache__/fairseq_model.cpython-310.pyc ADDED
Binary file (20.7 kB). View file
 
fairseq/fairseq/models/__pycache__/fconv.cpython-310.pyc ADDED
Binary file (19.1 kB). View file
 
fairseq/fairseq/models/__pycache__/fconv_lm.cpython-310.pyc ADDED
Binary file (3.86 kB). View file
 
fairseq/fairseq/models/__pycache__/fconv_self_att.cpython-310.pyc ADDED
Binary file (16.3 kB). View file
 
fairseq/fairseq/models/__pycache__/lightconv.cpython-310.pyc ADDED
Binary file (27.5 kB). View file
 
fairseq/fairseq/models/__pycache__/lightconv_lm.cpython-310.pyc ADDED
Binary file (7.03 kB). View file
 
fairseq/fairseq/models/__pycache__/lstm.cpython-310.pyc ADDED
Binary file (18.7 kB). View file
 
fairseq/fairseq/models/__pycache__/lstm_lm.cpython-310.pyc ADDED
Binary file (4.38 kB). View file
 
fairseq/fairseq/models/__pycache__/masked_lm.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
fairseq/fairseq/models/__pycache__/model_utils.cpython-310.pyc ADDED
Binary file (2.39 kB). View file
 
fairseq/fairseq/models/__pycache__/multilingual_transformer.cpython-310.pyc ADDED
Binary file (6.78 kB). View file
 
fairseq/fairseq/models/__pycache__/transformer_align.cpython-310.pyc ADDED
Binary file (3.05 kB). View file
 
fairseq/fairseq/models/__pycache__/transformer_from_pretrained_xlm.cpython-310.pyc ADDED
Binary file (5.39 kB). View file
 
fairseq/fairseq/models/__pycache__/transformer_lm.cpython-310.pyc ADDED
Binary file (15.5 kB). View file
 
fairseq/fairseq/models/__pycache__/transformer_ulm.cpython-310.pyc ADDED
Binary file (9.51 kB). View file
 
fairseq/fairseq/models/text_to_speech/__pycache__/codehifigan.cpython-310.pyc ADDED
Binary file (2.92 kB). View file
 
fairseq/fairseq/models/text_to_speech/__pycache__/fastspeech2.cpython-310.pyc ADDED
Binary file (12.8 kB). View file
 
fairseq/fairseq/models/text_to_speech/__pycache__/hifigan.cpython-310.pyc ADDED
Binary file (3.84 kB). View file
 
fairseq/fairseq/models/text_to_speech/__pycache__/hub_interface.cpython-310.pyc ADDED
Binary file (6.18 kB). View file
 
fairseq/fairseq/models/text_to_speech/__pycache__/tts_transformer.cpython-310.pyc ADDED
Binary file (12.4 kB). View file
 
fairseq/fairseq/models/text_to_speech/__pycache__/vocoder.cpython-310.pyc ADDED
Binary file (9.9 kB). View file
 
fairseq/fairseq/models/text_to_speech/tts_transformer.py ADDED
@@ -0,0 +1,454 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from typing import List, Optional
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from fairseq import utils
13
+ from fairseq.data.data_utils import lengths_to_padding_mask
14
+ from fairseq.models import (
15
+ FairseqEncoder,
16
+ FairseqEncoderDecoderModel,
17
+ FairseqIncrementalDecoder,
18
+ register_model,
19
+ register_model_architecture,
20
+ )
21
+ from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
22
+ from fairseq.models.text_to_speech.tacotron2 import Postnet, Prenet
23
+ from fairseq.modules import (
24
+ FairseqDropout,
25
+ LayerNorm,
26
+ PositionalEmbedding,
27
+ TransformerDecoderLayer,
28
+ TransformerEncoderLayer,
29
+ )
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ def encoder_init(m):
35
+ if isinstance(m, nn.Conv1d):
36
+ nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu"))
37
+
38
+
39
+ def Embedding(num_embeddings, embedding_dim):
40
+ m = nn.Embedding(num_embeddings, embedding_dim)
41
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
42
+ return m
43
+
44
+
45
+ class TTSTransformerEncoder(FairseqEncoder):
46
+ def __init__(self, args, src_dict, embed_speaker):
47
+ super().__init__(src_dict)
48
+ self.padding_idx = src_dict.pad()
49
+ self.embed_speaker = embed_speaker
50
+ self.spk_emb_proj = None
51
+ if embed_speaker is not None:
52
+ self.spk_emb_proj = nn.Linear(
53
+ args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
54
+ )
55
+
56
+ self.dropout_module = FairseqDropout(
57
+ p=args.dropout, module_name=self.__class__.__name__
58
+ )
59
+ self.embed_tokens = nn.Embedding(
60
+ len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
61
+ )
62
+ assert args.encoder_conv_kernel_size % 2 == 1
63
+ self.prenet = nn.ModuleList(
64
+ nn.Sequential(
65
+ nn.Conv1d(
66
+ args.encoder_embed_dim,
67
+ args.encoder_embed_dim,
68
+ kernel_size=args.encoder_conv_kernel_size,
69
+ padding=((args.encoder_conv_kernel_size - 1) // 2),
70
+ ),
71
+ nn.BatchNorm1d(args.encoder_embed_dim),
72
+ nn.ReLU(),
73
+ nn.Dropout(args.encoder_dropout),
74
+ )
75
+ for _ in range(args.encoder_conv_layers)
76
+ )
77
+ self.prenet_proj = nn.Linear(args.encoder_embed_dim, args.encoder_embed_dim)
78
+ self.embed_positions = PositionalEmbedding(
79
+ args.max_source_positions, args.encoder_embed_dim, self.padding_idx
80
+ )
81
+ self.pos_emb_alpha = nn.Parameter(torch.ones(1))
82
+
83
+ self.transformer_layers = nn.ModuleList(
84
+ TransformerEncoderLayer(args)
85
+ for _ in range(args.encoder_transformer_layers)
86
+ )
87
+ if args.encoder_normalize_before:
88
+ self.layer_norm = LayerNorm(args.encoder_embed_dim)
89
+ else:
90
+ self.layer_norm = None
91
+
92
+ self.apply(encoder_init)
93
+
94
+ def forward(self, src_tokens, src_lengths=None, speaker=None, **kwargs):
95
+ x = self.embed_tokens(src_tokens)
96
+ x = x.transpose(1, 2).contiguous() # B x T x C -> B x C x T
97
+ for conv in self.prenet:
98
+ x = conv(x)
99
+ x = x.transpose(1, 2).contiguous() # B x C x T -> B x T x C
100
+ x = self.prenet_proj(x)
101
+
102
+ padding_mask = src_tokens.eq(self.padding_idx)
103
+ positions = self.embed_positions(padding_mask)
104
+ x += self.pos_emb_alpha * positions
105
+ x = self.dropout_module(x)
106
+
107
+ # B x T x C -> T x B x C
108
+ x = x.transpose(0, 1)
109
+
110
+ for layer in self.transformer_layers:
111
+ x = layer(x, padding_mask)
112
+
113
+ if self.layer_norm is not None:
114
+ x = self.layer_norm(x)
115
+
116
+ if self.embed_speaker is not None:
117
+ seq_len, bsz, _ = x.size()
118
+ emb = self.embed_speaker(speaker).transpose(0, 1)
119
+ emb = emb.expand(seq_len, bsz, -1)
120
+ x = self.spk_emb_proj(torch.cat([x, emb], dim=2))
121
+
122
+ return {
123
+ "encoder_out": [x], # T x B x C
124
+ "encoder_padding_mask": [padding_mask]
125
+ if padding_mask.any()
126
+ else [], # B x T
127
+ "encoder_embedding": [], # B x T x C
128
+ "encoder_states": [], # List[T x B x C]
129
+ "src_tokens": [],
130
+ "src_lengths": [],
131
+ }
132
+
133
+
134
+ def decoder_init(m):
135
+ if isinstance(m, torch.nn.Conv1d):
136
+ nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh"))
137
+
138
+
139
+ class TTSTransformerDecoder(FairseqIncrementalDecoder):
140
+ def __init__(self, args, src_dict, padding_idx=1):
141
+ super().__init__(None)
142
+ self._future_mask = torch.empty(0)
143
+
144
+ self.args = args
145
+ self.padding_idx = src_dict.pad() if src_dict else padding_idx
146
+ self.n_frames_per_step = args.n_frames_per_step
147
+ self.out_dim = args.output_frame_dim * args.n_frames_per_step
148
+
149
+ self.dropout_module = FairseqDropout(
150
+ args.dropout, module_name=self.__class__.__name__
151
+ )
152
+ self.embed_positions = PositionalEmbedding(
153
+ args.max_target_positions, args.decoder_embed_dim, self.padding_idx
154
+ )
155
+ self.pos_emb_alpha = nn.Parameter(torch.ones(1))
156
+ self.prenet = nn.Sequential(
157
+ Prenet(
158
+ self.out_dim, args.prenet_layers, args.prenet_dim, args.prenet_dropout
159
+ ),
160
+ nn.Linear(args.prenet_dim, args.decoder_embed_dim),
161
+ )
162
+
163
+ self.n_transformer_layers = args.decoder_transformer_layers
164
+ self.transformer_layers = nn.ModuleList(
165
+ TransformerDecoderLayer(args) for _ in range(self.n_transformer_layers)
166
+ )
167
+ if args.decoder_normalize_before:
168
+ self.layer_norm = LayerNorm(args.decoder_embed_dim)
169
+ else:
170
+ self.layer_norm = None
171
+
172
+ self.feat_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)
173
+ self.eos_proj = nn.Linear(args.decoder_embed_dim, 1)
174
+
175
+ self.postnet = Postnet(
176
+ self.out_dim,
177
+ args.postnet_conv_dim,
178
+ args.postnet_conv_kernel_size,
179
+ args.postnet_layers,
180
+ args.postnet_dropout,
181
+ )
182
+
183
+ self.ctc_proj = None
184
+ if getattr(args, "ctc_weight", 0.0) > 0.0:
185
+ self.ctc_proj = nn.Linear(self.out_dim, len(src_dict))
186
+
187
+ self.apply(decoder_init)
188
+
189
+ def extract_features(
190
+ self,
191
+ prev_outputs,
192
+ encoder_out=None,
193
+ incremental_state=None,
194
+ target_lengths=None,
195
+ speaker=None,
196
+ **kwargs,
197
+ ):
198
+ alignment_layer = self.n_transformer_layers - 1
199
+ self_attn_padding_mask = lengths_to_padding_mask(target_lengths)
200
+ positions = self.embed_positions(
201
+ self_attn_padding_mask, incremental_state=incremental_state
202
+ )
203
+
204
+ if incremental_state is not None:
205
+ prev_outputs = prev_outputs[:, -1:, :]
206
+ self_attn_padding_mask = self_attn_padding_mask[:, -1:]
207
+ if positions is not None:
208
+ positions = positions[:, -1:]
209
+
210
+ x = self.prenet(prev_outputs)
211
+ x += self.pos_emb_alpha * positions
212
+ x = self.dropout_module(x)
213
+
214
+ # B x T x C -> T x B x C
215
+ x = x.transpose(0, 1)
216
+
217
+ if not self_attn_padding_mask.any():
218
+ self_attn_padding_mask = None
219
+
220
+ attn: Optional[torch.Tensor] = None
221
+ inner_states: List[Optional[torch.Tensor]] = [x]
222
+ for idx, transformer_layer in enumerate(self.transformer_layers):
223
+ if incremental_state is None:
224
+ self_attn_mask = self.buffered_future_mask(x)
225
+ else:
226
+ self_attn_mask = None
227
+
228
+ x, layer_attn, _ = transformer_layer(
229
+ x,
230
+ encoder_out["encoder_out"][0]
231
+ if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0)
232
+ else None,
233
+ encoder_out["encoder_padding_mask"][0]
234
+ if (
235
+ encoder_out is not None
236
+ and len(encoder_out["encoder_padding_mask"]) > 0
237
+ )
238
+ else None,
239
+ incremental_state,
240
+ self_attn_mask=self_attn_mask,
241
+ self_attn_padding_mask=self_attn_padding_mask,
242
+ need_attn=bool((idx == alignment_layer)),
243
+ need_head_weights=bool((idx == alignment_layer)),
244
+ )
245
+ inner_states.append(x)
246
+ if layer_attn is not None and idx == alignment_layer:
247
+ attn = layer_attn.float().to(x)
248
+
249
+ if attn is not None:
250
+ # average probabilities over heads, transpose to
251
+ # (B, src_len, tgt_len)
252
+ attn = attn.mean(dim=0).transpose(2, 1)
253
+
254
+ if self.layer_norm is not None:
255
+ x = self.layer_norm(x)
256
+
257
+ # T x B x C -> B x T x C
258
+ x = x.transpose(0, 1)
259
+
260
+ return x, {"attn": attn, "inner_states": inner_states}
261
+
262
+ def forward(
263
+ self,
264
+ prev_output_tokens,
265
+ encoder_out=None,
266
+ incremental_state=None,
267
+ target_lengths=None,
268
+ speaker=None,
269
+ **kwargs,
270
+ ):
271
+ x, extra = self.extract_features(
272
+ prev_output_tokens,
273
+ encoder_out=encoder_out,
274
+ incremental_state=incremental_state,
275
+ target_lengths=target_lengths,
276
+ speaker=speaker,
277
+ **kwargs,
278
+ )
279
+ attn = extra["attn"]
280
+ feat_out = self.feat_proj(x)
281
+ bsz, seq_len, _ = x.size()
282
+ eos_out = self.eos_proj(x)
283
+ post_feat_out = feat_out + self.postnet(feat_out)
284
+ return (
285
+ post_feat_out,
286
+ eos_out,
287
+ {
288
+ "attn": attn,
289
+ "feature_out": feat_out,
290
+ "inner_states": extra["inner_states"],
291
+ },
292
+ )
293
+
294
+ def get_normalized_probs(self, net_output, log_probs, sample):
295
+ logits = self.ctc_proj(net_output[2]["feature_out"])
296
+ if log_probs:
297
+ return utils.log_softmax(logits.float(), dim=-1)
298
+ else:
299
+ return utils.softmax(logits.float(), dim=-1)
300
+
301
+ def buffered_future_mask(self, tensor):
302
+ dim = tensor.size(0)
303
+ # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
304
+ if (
305
+ self._future_mask.size(0) == 0
306
+ or (not self._future_mask.device == tensor.device)
307
+ or self._future_mask.size(0) < dim
308
+ ):
309
+ self._future_mask = torch.triu(
310
+ utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
311
+ )
312
+ self._future_mask = self._future_mask.to(tensor)
313
+ return self._future_mask[:dim, :dim]
314
+
315
+
316
+ @register_model("tts_transformer")
317
+ class TTSTransformerModel(FairseqEncoderDecoderModel):
318
+ """
319
+ Implementation for https://arxiv.org/pdf/1809.08895.pdf
320
+ """
321
+
322
+ @classmethod
323
+ def hub_models(cls):
324
+ base_url = "http://dl.fbaipublicfiles.com/fairseq/s2"
325
+ model_ids = [
326
+ "tts_transformer-en-ljspeech",
327
+ "tts_transformer-en-200_speaker-cv4",
328
+ "tts_transformer-es-css10",
329
+ "tts_transformer-fr-cv7_css10",
330
+ "tts_transformer-ru-cv7_css10",
331
+ "tts_transformer-zh-cv7_css10",
332
+ "tts_transformer-ar-cv7_css10",
333
+ "tts_transformer-tr-cv7_css10",
334
+ "tts_transformer-vi-cv7",
335
+ ]
336
+ return {i: f"{base_url}/{i}.tar.gz" for i in model_ids}
337
+
338
+ @classmethod
339
+ def from_pretrained(
340
+ cls,
341
+ model_name_or_path,
342
+ checkpoint_file="model.pt",
343
+ data_name_or_path=".",
344
+ config_yaml="config.yaml",
345
+ vocoder: str = "griffin_lim",
346
+ fp16: bool = False,
347
+ **kwargs,
348
+ ):
349
+ from fairseq import hub_utils
350
+
351
+ x = hub_utils.from_pretrained(
352
+ model_name_or_path,
353
+ checkpoint_file,
354
+ data_name_or_path,
355
+ archive_map=cls.hub_models(),
356
+ config_yaml=config_yaml,
357
+ vocoder=vocoder,
358
+ fp16=fp16,
359
+ **kwargs,
360
+ )
361
+ return TTSHubInterface(x["args"], x["task"], x["models"][0])
362
+
363
+ @staticmethod
364
+ def add_args(parser):
365
+ parser.add_argument("--dropout", type=float)
366
+ parser.add_argument("--output-frame-dim", type=int)
367
+ parser.add_argument("--speaker-embed-dim", type=int)
368
+ # encoder prenet
369
+ parser.add_argument("--encoder-dropout", type=float)
370
+ parser.add_argument("--encoder-conv-layers", type=int)
371
+ parser.add_argument("--encoder-conv-kernel-size", type=int)
372
+ # encoder transformer layers
373
+ parser.add_argument("--encoder-transformer-layers", type=int)
374
+ parser.add_argument("--encoder-embed-dim", type=int)
375
+ parser.add_argument("--encoder-ffn-embed-dim", type=int)
376
+ parser.add_argument("--encoder-normalize-before", action="store_true")
377
+ parser.add_argument("--encoder-attention-heads", type=int)
378
+ parser.add_argument("--attention-dropout", type=float)
379
+ parser.add_argument("--activation-dropout", "--relu-dropout", type=float)
380
+ parser.add_argument("--activation-fn", type=str, default="relu")
381
+ # decoder prenet
382
+ parser.add_argument("--prenet-dropout", type=float)
383
+ parser.add_argument("--prenet-layers", type=int)
384
+ parser.add_argument("--prenet-dim", type=int)
385
+ # decoder postnet
386
+ parser.add_argument("--postnet-dropout", type=float)
387
+ parser.add_argument("--postnet-layers", type=int)
388
+ parser.add_argument("--postnet-conv-dim", type=int)
389
+ parser.add_argument("--postnet-conv-kernel-size", type=int)
390
+ # decoder transformer layers
391
+ parser.add_argument("--decoder-transformer-layers", type=int)
392
+ parser.add_argument("--decoder-embed-dim", type=int)
393
+ parser.add_argument("--decoder-ffn-embed-dim", type=int)
394
+ parser.add_argument("--decoder-normalize-before", action="store_true")
395
+ parser.add_argument("--decoder-attention-heads", type=int)
396
+
397
+ def __init__(self, *args, **kwargs):
398
+ super().__init__(*args, **kwargs)
399
+ self._num_updates = 0
400
+
401
+ @classmethod
402
+ def build_model(cls, args, task):
403
+ embed_speaker = task.get_speaker_embeddings(args)
404
+ encoder = TTSTransformerEncoder(args, task.src_dict, embed_speaker)
405
+ decoder = TTSTransformerDecoder(args, task.src_dict)
406
+ return cls(encoder, decoder)
407
+
408
+ def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs):
409
+ return self.encoder(
410
+ src_tokens, src_lengths=src_lengths, speaker=speaker, **kwargs
411
+ )
412
+
413
+ def set_num_updates(self, num_updates):
414
+ super().set_num_updates(num_updates)
415
+ self._num_updates = num_updates
416
+
417
+
418
+ @register_model_architecture("tts_transformer", "tts_transformer")
419
+ def base_architecture(args):
420
+ args.dropout = getattr(args, "dropout", 0.1)
421
+ args.output_frame_dim = getattr(args, "output_frame_dim", 80)
422
+ args.speaker_embed_dim = getattr(args, "speaker_embed_dim", 64)
423
+ # encoder prenet
424
+ args.encoder_dropout = getattr(args, "encoder_dropout", 0.5)
425
+ args.encoder_conv_layers = getattr(args, "encoder_conv_layers", 3)
426
+ args.encoder_conv_kernel_size = getattr(args, "encoder_conv_kernel_size", 5)
427
+ # encoder transformer layers
428
+ args.encoder_transformer_layers = getattr(args, "encoder_transformer_layers", 6)
429
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
430
+ args.encoder_ffn_embed_dim = getattr(
431
+ args, "encoder_ffn_embed_dim", 4 * args.encoder_embed_dim
432
+ )
433
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
434
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
435
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
436
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
437
+ args.activation_fn = getattr(args, "activation_fn", "relu")
438
+ # decoder prenet
439
+ args.prenet_dropout = getattr(args, "prenet_dropout", 0.5)
440
+ args.prenet_layers = getattr(args, "prenet_layers", 2)
441
+ args.prenet_dim = getattr(args, "prenet_dim", 256)
442
+ # decoder postnet
443
+ args.postnet_dropout = getattr(args, "postnet_dropout", 0.5)
444
+ args.postnet_layers = getattr(args, "postnet_layers", 5)
445
+ args.postnet_conv_dim = getattr(args, "postnet_conv_dim", 512)
446
+ args.postnet_conv_kernel_size = getattr(args, "postnet_conv_kernel_size", 5)
447
+ # decoder transformer layers
448
+ args.decoder_transformer_layers = getattr(args, "decoder_transformer_layers", 6)
449
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
450
+ args.decoder_ffn_embed_dim = getattr(
451
+ args, "decoder_ffn_embed_dim", 4 * args.decoder_embed_dim
452
+ )
453
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
454
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
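
The TTSTransformerModel added above registers pretrained checkpoints through hub_models() and returns a TTSHubInterface from from_pretrained(). A minimal usage sketch follows; the checkpoint key, vocoder choice, and the synthesis helpers are assumptions based on the upstream fairseq API, since this diff only defines from_pretrained() itself:

    from fairseq.models.text_to_speech.tts_transformer import TTSTransformerModel
    from fairseq.models.text_to_speech.hub_interface import TTSHubInterface

    # Keys of hub_models() resolve to dl.fbaipublicfiles.com archives;
    # "tts_transformer-en-ljspeech" is used here purely as an example key.
    interface = TTSTransformerModel.from_pretrained(
        "tts_transformer-en-ljspeech",
        vocoder="griffin_lim",  # same default as in from_pretrained() above
        fp16=False,
    )

    # get_model_input()/get_prediction() and the interface attributes come from the
    # upstream fairseq TTSHubInterface; they are assumed here, since the diff only
    # shows that from_pretrained() constructs and returns the interface object.
    sample = TTSHubInterface.get_model_input(interface.task, "Hello world.")
    wav, rate = TTSHubInterface.get_prediction(
        interface.task, interface.model, interface.generator, sample
    )
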
fairseq/fairseq/models/transformer/__init__.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) Facebook Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ """isort:skip_file"""
6
+
7
+ from .transformer_config import (
8
+ TransformerConfig,
9
+ DEFAULT_MAX_SOURCE_POSITIONS,
10
+ DEFAULT_MAX_TARGET_POSITIONS,
11
+ DEFAULT_MIN_PARAMS_TO_WRAP,
12
+ )
13
+ from .transformer_decoder import TransformerDecoder, TransformerDecoderBase, Linear
14
+ from .transformer_encoder import TransformerEncoder, TransformerEncoderBase
15
+ from .transformer_legacy import (
16
+ TransformerModel,
17
+ base_architecture,
18
+ tiny_architecture,
19
+ transformer_iwslt_de_en,
20
+ transformer_wmt_en_de,
21
+ transformer_vaswani_wmt_en_de_big,
22
+ transformer_vaswani_wmt_en_fr_big,
23
+ transformer_wmt_en_de_big,
24
+ transformer_wmt_en_de_big_t2t,
25
+ )
26
+ from .transformer_base import TransformerModelBase, Embedding
27
+
28
+
29
+ __all__ = [
30
+ "TransformerModelBase",
31
+ "TransformerConfig",
32
+ "TransformerDecoder",
33
+ "TransformerDecoderBase",
34
+ "TransformerEncoder",
35
+ "TransformerEncoderBase",
36
+ "TransformerModel",
37
+ "Embedding",
38
+ "Linear",
39
+ "base_architecture",
40
+ "tiny_architecture",
41
+ "transformer_iwslt_de_en",
42
+ "transformer_wmt_en_de",
43
+ "transformer_vaswani_wmt_en_de_big",
44
+ "transformer_vaswani_wmt_en_fr_big",
45
+ "transformer_wmt_en_de_big",
46
+ "transformer_wmt_en_de_big_t2t",
47
+ "DEFAULT_MAX_SOURCE_POSITIONS",
48
+ "DEFAULT_MAX_TARGET_POSITIONS",
49
+ "DEFAULT_MIN_PARAMS_TO_WRAP",
50
+ ]
fairseq/fairseq/models/transformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.18 kB). View file
 
fairseq/fairseq/models/transformer/__pycache__/transformer_base.cpython-310.pyc ADDED
Binary file (5.42 kB). View file
 
fairseq/fairseq/models/transformer/__pycache__/transformer_config.cpython-310.pyc ADDED
Binary file (8.86 kB). View file
 
fairseq/fairseq/models/transformer/__pycache__/transformer_decoder.cpython-310.pyc ADDED
Binary file (11.9 kB). View file
 
fairseq/fairseq/models/transformer/__pycache__/transformer_decoder_aug.cpython-310.pyc ADDED
Binary file (9.72 kB). View file
 
fairseq/fairseq/models/transformer/__pycache__/transformer_encoder.cpython-310.pyc ADDED
Binary file (9 kB). View file
 
fairseq/fairseq/models/transformer/__pycache__/transformer_legacy.cpython-310.pyc ADDED
Binary file (9.92 kB). View file
 
fairseq/fairseq/models/transformer/transformer_base.py ADDED
@@ -0,0 +1,193 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import Tensor
11
+
12
+ import logging
13
+
14
+ from fairseq import utils
15
+ from fairseq.dataclass.utils import gen_parser_from_dataclass
16
+ from fairseq.distributed import fsdp_wrap
17
+ from fairseq.models import FairseqEncoderDecoderModel
18
+ from fairseq.models.transformer import (
19
+ TransformerConfig,
20
+ TransformerDecoderBase,
21
+ TransformerEncoderBase,
22
+ )
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class TransformerModelBase(FairseqEncoderDecoderModel):
29
+ """
30
+ Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
31
+ <https://arxiv.org/abs/1706.03762>`_.
32
+
33
+ Args:
34
+ encoder (TransformerEncoder): the encoder
35
+ decoder (TransformerDecoder): the decoder
36
+
37
+ The Transformer model provides the following named architectures and
38
+ command-line arguments:
39
+
40
+ .. argparse::
41
+ :ref: fairseq.models.transformer_parser
42
+ :prog:
43
+ """
44
+
45
+ def __init__(self, cfg, encoder, decoder):
46
+ super().__init__(encoder, decoder)
47
+ self.cfg = cfg
48
+ self.supports_align_args = True
49
+
50
+ @classmethod
51
+ def add_args(cls, parser):
52
+ """Add model-specific arguments to the parser."""
53
+ # we want to build the args recursively in this case.
54
+ gen_parser_from_dataclass(
55
+ parser, TransformerConfig(), delete_default=False, with_prefix=""
56
+ )
57
+
58
+ @classmethod
59
+ def build_model(cls, cfg, task):
60
+ """Build a new model instance."""
61
+
62
+ # -- TODO T96535332
63
+ # bug caused by interaction between OmegaConf II and argparsing
64
+ cfg.decoder.input_dim = int(cfg.decoder.input_dim)
65
+ cfg.decoder.output_dim = int(cfg.decoder.output_dim)
66
+ # --
67
+
68
+ if cfg.encoder.layers_to_keep:
69
+ cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(","))
70
+ if cfg.decoder.layers_to_keep:
71
+ cfg.decoder.layers = len(cfg.decoder.layers_to_keep.split(","))
72
+
73
+ src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
74
+
75
+ if cfg.share_all_embeddings:
76
+ if src_dict != tgt_dict:
77
+ raise ValueError("--share-all-embeddings requires a joined dictionary")
78
+ if cfg.encoder.embed_dim != cfg.decoder.embed_dim:
79
+ raise ValueError(
80
+ "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
81
+ )
82
+ if cfg.decoder.embed_path and (
83
+ cfg.decoder.embed_path != cfg.encoder.embed_path
84
+ ):
85
+ raise ValueError(
86
+ "--share-all-embeddings not compatible with --decoder-embed-path"
87
+ )
88
+ encoder_embed_tokens = cls.build_embedding(
89
+ cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path
90
+ )
91
+ decoder_embed_tokens = encoder_embed_tokens
92
+ cfg.share_decoder_input_output_embed = True
93
+ elif cfg.merge_src_tgt_embed:
94
+ logger.info(f"source dict size: {len(src_dict)}")
95
+ logger.info(f"target dict size: {len(tgt_dict)}")
96
+ src_dict.update(tgt_dict)
97
+ task.src_dict = src_dict
98
+ task.tgt_dict = src_dict
99
+ logger.info(f"merged dict size: {len(src_dict)}")
100
+ encoder_embed_tokens = cls.build_embedding(
101
+ cfg, src_dict, cfg.encoder.embed_dim
102
+ )
103
+ decoder_embed_tokens = encoder_embed_tokens
104
+ cfg.share_decoder_input_output_embed = True
105
+ else:
106
+ encoder_embed_tokens = cls.build_embedding(
107
+ cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path
108
+ )
109
+ decoder_embed_tokens = cls.build_embedding(
110
+ cfg, tgt_dict, cfg.decoder.embed_dim, cfg.decoder.embed_path
111
+ )
112
+ if cfg.offload_activations:
113
+ cfg.checkpoint_activations = True # offloading implies checkpointing
114
+ encoder = cls.build_encoder(cfg, src_dict, encoder_embed_tokens)
115
+ decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
116
+ return cls(cfg, encoder, decoder)
117
+
118
+ @classmethod
119
+ def build_embedding(cls, cfg, dictionary, embed_dim, path=None):
120
+ num_embeddings = len(dictionary)
121
+ padding_idx = dictionary.pad()
122
+
123
+ emb = Embedding(num_embeddings, embed_dim, padding_idx)
124
+ # if provided, load from preloaded dictionaries
125
+ if path:
126
+ embed_dict = utils.parse_embedding(path)
127
+ utils.load_embedding(embed_dict, dictionary, emb)
128
+ return emb
129
+
130
+ @classmethod
131
+ def build_encoder(cls, cfg, src_dict, embed_tokens):
132
+ return TransformerEncoderBase(cfg, src_dict, embed_tokens)
133
+
134
+ @classmethod
135
+ def build_decoder(cls, cfg, tgt_dict, embed_tokens):
136
+ return TransformerDecoderBase(
137
+ cfg,
138
+ tgt_dict,
139
+ embed_tokens,
140
+ no_encoder_attn=cfg.no_cross_attention,
141
+ )
142
+
143
+ # TorchScript doesn't support optional arguments with variable length (**kwargs).
144
+ # Current workaround is to add union of all arguments in child classes.
145
+ def forward(
146
+ self,
147
+ src_tokens,
148
+ src_lengths,
149
+ prev_output_tokens,
150
+ return_all_hiddens: bool = True,
151
+ features_only: bool = False,
152
+ alignment_layer: Optional[int] = None,
153
+ alignment_heads: Optional[int] = None,
154
+ ):
155
+ """
156
+ Run the forward pass for an encoder-decoder model.
157
+
158
+ Copied from the base class, but without ``**kwargs``,
159
+ which are not supported by TorchScript.
160
+ """
161
+ encoder_out = self.encoder(
162
+ src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
163
+ )
164
+ decoder_out = self.decoder(
165
+ prev_output_tokens,
166
+ encoder_out=encoder_out,
167
+ features_only=features_only,
168
+ alignment_layer=alignment_layer,
169
+ alignment_heads=alignment_heads,
170
+ src_lengths=src_lengths,
171
+ return_all_hiddens=return_all_hiddens,
172
+ )
173
+ return decoder_out
174
+
175
+ # Since get_normalized_probs is in the Fairseq Model which is not scriptable,
176
+ # I rewrite the get_normalized_probs from Base Class to call the
177
+ # helper function in the Base Class.
178
+ @torch.jit.export
179
+ def get_normalized_probs(
180
+ self,
181
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
182
+ log_probs: bool,
183
+ sample: Optional[Dict[str, Tensor]] = None,
184
+ ):
185
+ """Get normalized probabilities (or log probs) from a net's output."""
186
+ return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
187
+
188
+
189
+ def Embedding(num_embeddings, embedding_dim, padding_idx):
190
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
191
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
192
+ nn.init.constant_(m.weight[padding_idx], 0)
193
+ return m
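
The Embedding() helper that closes transformer_base.py draws weights from N(0, embed_dim^-0.5) and zeroes the padding row. A small sketch with arbitrary example dimensions:

    import torch
    from fairseq.models.transformer import Embedding

    # 1,000-entry vocabulary, 512-dim embeddings, padding index 1 (all example values).
    emb = Embedding(num_embeddings=1000, embedding_dim=512, padding_idx=1)

    # The padding row is zero-initialized, so padded positions contribute nothing.
    assert torch.equal(emb.weight[1], torch.zeros(512))
    print(round(emb.weight.std().item(), 3))  # close to 512 ** -0.5 ≈ 0.044
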
fairseq/fairseq/models/transformer/transformer_config.py ADDED
@@ -0,0 +1,341 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ import re
8
+ from dataclasses import dataclass, field, fields
9
+ from typing import List, Optional
10
+
11
+ from omegaconf import II
12
+
13
+ from fairseq import utils
14
+ from fairseq.dataclass import ChoiceEnum, FairseqDataclass
15
+ from fairseq.utils import safe_getattr, safe_hasattr
16
+
17
+ DEFAULT_MAX_SOURCE_POSITIONS = 1024
18
+ DEFAULT_MAX_TARGET_POSITIONS = 1024
19
+
20
+ DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
21
+
22
+ _NAME_PARSER = r"(decoder|encoder|quant_noise)_(.*)"
23
+
24
+
25
+ @dataclass
26
+ class EncDecBaseConfig(FairseqDataclass):
27
+ embed_path: Optional[str] = field(
28
+ default=None, metadata={"help": "path to pre-trained embedding"}
29
+ )
30
+ embed_dim: Optional[int] = field(
31
+ default=512, metadata={"help": "embedding dimension"}
32
+ )
33
+ ffn_embed_dim: int = field(
34
+ default=2048, metadata={"help": "embedding dimension for FFN"}
35
+ )
36
+ layers: int = field(default=6, metadata={"help": "number of layers"})
37
+ attention_heads: int = field(
38
+ default=8, metadata={"help": "number of attention heads"}
39
+ )
40
+ normalize_before: bool = field(
41
+ default=False, metadata={"help": "apply layernorm before each block"}
42
+ )
43
+ learned_pos: bool = field(
44
+ default=False, metadata={"help": "use learned positional embeddings"}
45
+ )
46
+ # args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
47
+ layerdrop: float = field(default=0, metadata={"help": "LayerDrop probability"})
48
+ layers_to_keep: Optional[List[int]] = field(
49
+ default=None, metadata={"help": "which layers to *keep* when pruning"}
50
+ )
51
+
52
+ xformers_att_config: Optional[str] = field(
53
+ default=None,
54
+ metadata={
55
+ "help": "config for xFormers attention, defined in xformers.components.attention.AttentionConfig"
56
+ },
57
+ )
58
+
59
+
60
+ @dataclass
61
+ class DecoderConfig(EncDecBaseConfig):
62
+ input_dim: int = II("model.decoder.embed_dim")
63
+ output_dim: int = field(
64
+ default=II("model.decoder.embed_dim"),
65
+ metadata={
66
+ "help": "decoder output dimension (extra linear layer if different from decoder embed dim)"
67
+ },
68
+ )
69
+
70
+ def __post_init__(self):
71
+ # II doesn't work if we are just creating the object outside of hydra so fix that
72
+ if self.input_dim == II("model.decoder.embed_dim"):
73
+ self.input_dim = self.embed_dim
74
+ if self.output_dim == II("model.decoder.embed_dim"):
75
+ self.output_dim = self.embed_dim
76
+
77
+
78
+ @dataclass
79
+ class QuantNoiseConfig(FairseqDataclass):
80
+ pq: float = field(
81
+ default=0.0,
82
+ metadata={"help": "iterative PQ quantization noise at training time"},
83
+ )
84
+ pq_block_size: int = field(
85
+ default=8,
86
+ metadata={"help": "block size of quantization noise at training time"},
87
+ )
88
+ scalar: float = field(
89
+ default=0.0,
90
+ metadata={
91
+ "help": "scalar quantization noise and scalar quantization at training time"
92
+ },
93
+ )
94
+
95
+
96
+ @dataclass
97
+ class TransformerConfig(FairseqDataclass):
98
+ activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
99
+ default="relu",
100
+ metadata={"help": "activation function to use"},
101
+ )
102
+ dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
103
+ attention_dropout: float = field(
104
+ default=0.0, metadata={"help": "dropout probability for attention weights"}
105
+ )
106
+ activation_dropout: float = field(
107
+ default=0.0,
108
+ metadata={
109
+ "help": "dropout probability after activation in FFN.",
110
+ "alias": "--relu-dropout",
111
+ },
112
+ )
113
+ adaptive_input: bool = False
114
+ encoder: EncDecBaseConfig = EncDecBaseConfig()
115
+ # TODO should really be in the encoder config
116
+ max_source_positions: int = field(
117
+ default=DEFAULT_MAX_SOURCE_POSITIONS,
118
+ metadata={"help": "Maximum input length supported by the encoder"},
119
+ )
120
+ decoder: DecoderConfig = DecoderConfig()
121
+ # TODO should really be in the decoder config
122
+ max_target_positions: int = field(
123
+ default=DEFAULT_MAX_TARGET_POSITIONS,
124
+ metadata={"help": "Maximum output length supported by the decoder"},
125
+ )
126
+ share_decoder_input_output_embed: bool = field(
127
+ default=False, metadata={"help": "share decoder input and output embeddings"}
128
+ )
129
+ share_all_embeddings: bool = field(
130
+ default=False,
131
+ metadata={
132
+ "help": "share encoder, decoder and output embeddings (requires shared dictionary and embed dim)"
133
+ },
134
+ )
135
+ merge_src_tgt_embed: bool = field(
136
+ default=False,
137
+ metadata={
138
+ "help": "if true then the source and target embedding table is "
139
+ "merged into one table. This is going to make the model smaller but "
140
+ "it might hurt performance."
141
+ },
142
+ )
143
+ no_token_positional_embeddings: bool = field(
144
+ default=False,
145
+ metadata={
146
+ "help": "if True, disables positional embeddings (outside self attention)"
147
+ },
148
+ )
149
+ adaptive_softmax_cutoff: Optional[List[int]] = field(
150
+ default=None,
151
+ metadata={
152
+ "help": "list of adaptive softmax cutoff points. Must be used with adaptive_loss criterion"
153
+ },
154
+ )
155
+ adaptive_softmax_dropout: float = field(
156
+ default=0.0,
157
+ metadata={"help": "sets adaptive softmax dropout for the tail projections"},
158
+ )
159
+ adaptive_softmax_factor: float = field(
160
+ default=4, metadata={"help": "adaptive input factor"}
161
+ )
162
+ layernorm_embedding: bool = field(
163
+ default=False, metadata={"help": "add layernorm to embedding"}
164
+ )
165
+ tie_adaptive_weights: bool = field(
166
+ default=False,
167
+ metadata={
168
+ "help": "if set, ties the weights of adaptive softmax and adaptive input"
169
+ },
170
+ )
171
+ tie_adaptive_proj: bool = field(
172
+ default=False,
173
+ metadata={
174
+ "help": "if set, ties the projection weights of adaptive softmax and adaptive input"
175
+ },
176
+ )
177
+ no_scale_embedding: bool = field(
178
+ default=False, metadata={"help": "if True, dont scale embeddings"}
179
+ )
180
+ checkpoint_activations: bool = field(
181
+ default=False,
182
+ metadata={
183
+ "help": "checkpoint activations at each layer, which saves GPU memory usage at the cost of some additional compute"
184
+ },
185
+ )
186
+ offload_activations: bool = field(
187
+ default=False,
188
+ metadata={
189
+ "help": "checkpoint activations at each layer, then save to gpu. Sets --checkpoint-activations."
190
+ },
191
+ )
192
+ # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
193
+ no_cross_attention: bool = field(
194
+ default=False, metadata={"help": "do not perform cross-attention"}
195
+ )
196
+ cross_self_attention: bool = field(
197
+ default=False, metadata={"help": "perform cross+self-attention"}
198
+ )
199
+ # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
200
+ quant_noise: QuantNoiseConfig = field(default=QuantNoiseConfig())
201
+ min_params_to_wrap: int = field(
202
+ default=DEFAULT_MIN_PARAMS_TO_WRAP,
203
+ metadata={
204
+ "help": "minimum number of params for a layer to be wrapped with FSDP() when "
205
+ "training with --ddp-backend=fully_sharded. Smaller values will "
206
+ "improve memory efficiency, but may make torch.distributed "
207
+ "communication less efficient due to smaller input sizes. This option "
208
+ "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
209
+ "--offload-activations are passed."
210
+ },
211
+ )
212
+ # DEPRECATED field, but some old checkpoints might have it
213
+ char_inputs: bool = field(
214
+ default=False, metadata={"help": "if set, model takes character ids as input"}
215
+ )
216
+ relu_dropout: float = 0.0
217
+ # config for "BASE Layers: Simplifying Training of Large, Sparse Models"
218
+ base_layers: Optional[int] = field(
219
+ default=0, metadata={"help": "number of BASE layers in total"}
220
+ )
221
+ base_sublayers: Optional[int] = field(
222
+ default=1, metadata={"help": "number of sublayers in each BASE layer"}
223
+ )
224
+ base_shuffle: Optional[int] = field(
225
+ default=1,
226
+ metadata={"help": "shuffle tokens between workers before computing assignment"},
227
+ )
228
+
229
+ export: bool = field(
230
+ default=False,
231
+ metadata={"help": "make the layernorm exportable with torchscript."},
232
+ )
233
+
234
+ # copied from transformer_lm but expected in transformer_decoder:
235
+ no_decoder_final_norm: bool = field(
236
+ default=False,
237
+ metadata={"help": "don't add an extra layernorm after the last decoder block"},
238
+ )
239
+
240
+ # We need to make this hierarchical dataclass like the flat namespace
241
+ # __getattr__ and __setattr__ here allow backward compatibility
242
+ # for subclasses of Transformer(Legacy) that depend on read/write on
243
+ # the flat namespace.
244
+
245
+ def __getattr__(self, name):
246
+ match = re.match(_NAME_PARSER, name)
247
+ if match:
248
+ sub = safe_getattr(self, match[1])
249
+ return safe_getattr(sub, match[2])
250
+ raise AttributeError(f"invalid argument {name}.")
251
+
252
+ def __setattr__(self, name, value):
253
+ match = re.match(_NAME_PARSER, name)
254
+ if match:
255
+ sub = safe_getattr(self, match[1])
256
+ setattr(sub, match[2], value)
257
+ else:
258
+ super().__setattr__(name, value)
259
+
260
+ @staticmethod
261
+ def _copy_keys(args, cls, prefix, seen):
262
+ """
263
+ copy the prefixed keys (decoder_embed_dim) to the DC fields: decoder.embed_dim
264
+ """
265
+ cfg = cls()
266
+ for fld in fields(cls):
267
+ # for all the fields in the DC, find the fields (e.g. embed_dim)
268
+ # in the namespace with the prefix (e.g. decoder)
269
+ # and set it on the dc.
270
+ args_key = f"{prefix}_{fld.name}"
271
+ if safe_hasattr(args, args_key):
272
+ seen.add(args_key)
273
+ setattr(cfg, fld.name, safe_getattr(args, args_key))
274
+ if safe_hasattr(args, fld.name):
275
+ seen.add(fld.name)
276
+ setattr(cfg, fld.name, safe_getattr(args, fld.name))
277
+ return cfg
278
+
279
+ @classmethod
280
+ def from_namespace(cls, args):
281
+ if args is None:
282
+ return None
283
+ if not isinstance(args, cls):
284
+ seen = set()
285
+ config = cls()
286
+ # currently, we can go generically from DC fields to args hierarchically
287
+ # but we can't easily deconstruct a flat namespace to a hierarchical
288
+ # DC. Mostly because we could have a sub-dc called `decoder-foo` that should not
289
+ # go to the sub struct called `decoder`. There are ways to go around this, but let's keep it simple
290
+ # for now.
291
+ for fld in fields(cls):
292
+ # concretelly, the transformer_config know what sub-dc it has, so we go through all the dc fields
293
+ # and if it's one that has a sub-dc, we build that sub-dc with `copy_keys()`
294
+ if fld.name == "decoder":
295
+ if safe_hasattr(args, "decoder"):
296
+ # in some cases, the args we receive is already structured (as DictConfigs), so let's just build the correct DC
297
+ seen.add("decoder")
298
+ config.decoder = DecoderConfig(**args.decoder)
299
+ else:
300
+ config.decoder = cls._copy_keys(
301
+ args, DecoderConfig, "decoder", seen
302
+ )
303
+ elif fld.name == "encoder":
304
+ # same but for encoder
305
+ if safe_hasattr(args, "encoder"):
306
+ seen.add("encoder")
307
+ config.encoder = EncDecBaseConfig(**args.encoder)
308
+ else:
309
+ config.encoder = cls._copy_keys(
310
+ args, EncDecBaseConfig, "encoder", seen
311
+ )
312
+ elif fld.name == "quant_noise":
313
+ # same but for quant_noise
314
+ if safe_hasattr(args, "quant_noise"):
315
+ seen.add("quant_noise")
316
+ config.quant_noise = QuantNoiseConfig(**args.quant_noise)
317
+ else:
318
+ config.quant_noise = cls._copy_keys(
319
+ args, QuantNoiseConfig, "quant_noise", seen
320
+ )
321
+ elif safe_hasattr(args, fld.name):
322
+ # if it's not a structure field, it's just a normal field, copy it over
323
+ seen.add(fld.name)
324
+ setattr(config, fld.name, safe_getattr(args, fld.name))
325
+ # we got all the fields defined in the dataclass, but
326
+ # the argparse namespace might have extra args for two reasons:
327
+ # - we are in a legacy class so all the args are not declared in the dataclass. Ideally once everyone has defined a dataclass for their model, we won't need this
328
+ # - some places expect args to be there but never define them
329
+ args_dict = (
330
+ args._asdict()
331
+ if safe_hasattr(args, "_asdict")
332
+ else vars(args)
333
+ if safe_hasattr(args, "__dict__")
334
+ else {}
335
+ ) # namedtupled doesn't have __dict__ :-/
336
+ for key, value in args_dict.items():
337
+ if key not in seen:
338
+ setattr(config, key, value)
339
+ return config
340
+ else:
341
+ return args
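
TransformerConfig.from_namespace() above lifts a flat argparse-style namespace (encoder_embed_dim, decoder_layers, ...) into the hierarchical dataclass, while __getattr__/__setattr__ keep the old flat names usable. A minimal sketch, using an ad-hoc Namespace rather than a real fairseq parser:

    from argparse import Namespace
    from fairseq.models.transformer import TransformerConfig

    # Flat, legacy-style arguments such as those produced by an argparse parser.
    args = Namespace(encoder_embed_dim=256, decoder_layers=4, dropout=0.3)

    cfg = TransformerConfig.from_namespace(args)
    print(cfg.encoder.embed_dim)  # 256, copied into the encoder sub-dataclass by _copy_keys()
    print(cfg.decoder.layers)     # 4
    print(cfg.encoder_embed_dim)  # 256 again, resolved through the _NAME_PARSER __getattr__ shim
    print(cfg.dropout)            # 0.3, a plain top-level field
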
fairseq/fairseq/models/transformer/transformer_decoder.py ADDED
@@ -0,0 +1,474 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Any, Dict, List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+
13
+ from fairseq import utils
14
+ from fairseq.distributed import fsdp_wrap
15
+ from fairseq.models import FairseqIncrementalDecoder
16
+ from fairseq.models.transformer import TransformerConfig
17
+ from fairseq.modules import (
18
+ AdaptiveSoftmax,
19
+ BaseLayer,
20
+ FairseqDropout,
21
+ LayerDropModuleList,
22
+ LayerNorm,
23
+ PositionalEmbedding,
24
+ SinusoidalPositionalEmbedding,
25
+ transformer_layer,
26
+ )
27
+ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
28
+ from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
29
+
30
+
31
+ # rewrite name for backward compatibility in `make_generation_fast_`
32
+ def module_name_fordropout(module_name: str) -> str:
33
+ if module_name == "TransformerDecoderBase":
34
+ return "TransformerDecoder"
35
+ else:
36
+ return module_name
37
+
38
+
39
+ class TransformerDecoderBase(FairseqIncrementalDecoder):
40
+ """
41
+ Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer
42
+ is a :class:`TransformerDecoderLayer`.
43
+
44
+ Args:
45
+ cfg (argparse.Namespace): parsed command-line arguments
46
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
47
+ embed_tokens (torch.nn.Embedding): output embedding
48
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
49
+ (default: False).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ cfg,
55
+ dictionary,
56
+ embed_tokens,
57
+ no_encoder_attn=False,
58
+ output_projection=None,
59
+ ):
60
+ self.cfg = cfg
61
+ super().__init__(dictionary)
62
+ self.register_buffer("version", torch.Tensor([3]))
63
+ self._future_mask = torch.empty(0)
64
+
65
+ self.dropout_module = FairseqDropout(
66
+ cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
67
+ )
68
+ self.decoder_layerdrop = cfg.decoder.layerdrop
69
+ self.share_input_output_embed = cfg.share_decoder_input_output_embed
70
+
71
+ input_embed_dim = embed_tokens.embedding_dim
72
+ embed_dim = cfg.decoder.embed_dim
73
+ self.embed_dim = embed_dim
74
+ self.output_embed_dim = cfg.decoder.output_dim
75
+
76
+ self.padding_idx = embed_tokens.padding_idx
77
+ self.max_target_positions = cfg.max_target_positions
78
+
79
+ self.embed_tokens = embed_tokens
80
+
81
+ self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
82
+
83
+ if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
84
+ self.quant_noise = apply_quant_noise_(
85
+ nn.Linear(embed_dim, embed_dim, bias=False),
86
+ cfg.quant_noise.pq,
87
+ cfg.quant_noise.pq_block_size,
88
+ )
89
+ else:
90
+ self.quant_noise = None
91
+
92
+ self.project_in_dim = (
93
+ Linear(input_embed_dim, embed_dim, bias=False)
94
+ if embed_dim != input_embed_dim
95
+ else None
96
+ )
97
+ self.embed_positions = (
98
+ PositionalEmbedding(
99
+ self.max_target_positions,
100
+ embed_dim,
101
+ self.padding_idx,
102
+ learned=cfg.decoder.learned_pos,
103
+ )
104
+ if not cfg.no_token_positional_embeddings
105
+ else None
106
+ )
107
+ if cfg.layernorm_embedding:
108
+ self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
109
+ else:
110
+ self.layernorm_embedding = None
111
+
112
+ self.cross_self_attention = cfg.cross_self_attention
113
+
114
+ if self.decoder_layerdrop > 0.0:
115
+ self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
116
+ else:
117
+ self.layers = nn.ModuleList([])
118
+ self.layers.extend(
119
+ [
120
+ self.build_decoder_layer(cfg, no_encoder_attn)
121
+ for _ in range(cfg.decoder.layers)
122
+ ]
123
+ )
124
+ self.num_layers = len(self.layers)
125
+
126
+ if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
127
+ self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
128
+ else:
129
+ self.layer_norm = None
130
+
131
+ self.project_out_dim = (
132
+ Linear(embed_dim, self.output_embed_dim, bias=False)
133
+ if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights
134
+ else None
135
+ )
136
+
137
+ self.adaptive_softmax = None
138
+ self.output_projection = output_projection
139
+ if self.output_projection is None:
140
+ self.build_output_projection(cfg, dictionary, embed_tokens)
141
+
142
+ def build_output_projection(self, cfg, dictionary, embed_tokens):
143
+ if cfg.adaptive_softmax_cutoff is not None:
144
+ self.adaptive_softmax = AdaptiveSoftmax(
145
+ len(dictionary),
146
+ self.output_embed_dim,
147
+ utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int),
148
+ dropout=cfg.adaptive_softmax_dropout,
149
+ adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None,
150
+ factor=cfg.adaptive_softmax_factor,
151
+ tie_proj=cfg.tie_adaptive_proj,
152
+ )
153
+ elif self.share_input_output_embed:
154
+ self.output_projection = nn.Linear(
155
+ self.embed_tokens.weight.shape[1],
156
+ self.embed_tokens.weight.shape[0],
157
+ bias=False,
158
+ )
159
+ self.output_projection.weight = self.embed_tokens.weight
160
+ else:
161
+ self.output_projection = nn.Linear(
162
+ self.output_embed_dim, len(dictionary), bias=False
163
+ )
164
+ nn.init.normal_(
165
+ self.output_projection.weight, mean=0, std=self.output_embed_dim**-0.5
166
+ )
167
+ num_base_layers = cfg.base_layers
168
+ for i in range(num_base_layers):
169
+ self.layers.insert(
170
+ ((i + 1) * cfg.decoder.layers) // (num_base_layers + 1),
171
+ BaseLayer(cfg),
172
+ )
173
+
174
+ def build_decoder_layer(self, cfg, no_encoder_attn=False):
175
+ layer = transformer_layer.TransformerDecoderLayerBase(cfg, no_encoder_attn)
176
+ checkpoint = cfg.checkpoint_activations
177
+ if checkpoint:
178
+ offload_to_cpu = cfg.offload_activations
179
+ layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
180
+ # if we are checkpointing, enforce that FSDP always wraps the
181
+ # checkpointed layer, regardless of layer size
182
+ min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
183
+ layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
184
+ return layer
185
+
186
+ def forward(
187
+ self,
188
+ prev_output_tokens,
189
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
190
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
191
+ features_only: bool = False,
192
+ full_context_alignment: bool = False,
193
+ alignment_layer: Optional[int] = None,
194
+ alignment_heads: Optional[int] = None,
195
+ src_lengths: Optional[Any] = None,
196
+ return_all_hiddens: bool = False,
197
+ ):
198
+ """
199
+ Args:
200
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
201
+ `(batch, tgt_len)`, for teacher forcing
202
+ encoder_out (optional): output from the encoder, used for
203
+ encoder-side attention, should be of size T x B x C
204
+ incremental_state (dict): dictionary used for storing state during
205
+ :ref:`Incremental decoding`
206
+ features_only (bool, optional): only return features without
207
+ applying output layer (default: False).
208
+ full_context_alignment (bool, optional): don't apply
209
+ auto-regressive mask to self-attention (default: False).
210
+
211
+ Returns:
212
+ tuple:
213
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
214
+ - a dictionary with any model-specific outputs
215
+ """
216
+
217
+ x, extra = self.extract_features(
218
+ prev_output_tokens,
219
+ encoder_out=encoder_out,
220
+ incremental_state=incremental_state,
221
+ full_context_alignment=full_context_alignment,
222
+ alignment_layer=alignment_layer,
223
+ alignment_heads=alignment_heads,
224
+ )
225
+
226
+ if not features_only:
227
+ x = self.output_layer(x)
228
+ return x, extra
229
+
230
+ def extract_features(
231
+ self,
232
+ prev_output_tokens,
233
+ encoder_out: Optional[Dict[str, List[Tensor]]],
234
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
235
+ full_context_alignment: bool = False,
236
+ alignment_layer: Optional[int] = None,
237
+ alignment_heads: Optional[int] = None,
238
+ ):
239
+ return self.extract_features_scriptable(
240
+ prev_output_tokens,
241
+ encoder_out,
242
+ incremental_state,
243
+ full_context_alignment,
244
+ alignment_layer,
245
+ alignment_heads,
246
+ )
247
+
248
+ """
249
+ A scriptable subclass of this class has an extract_features method and calls
250
+ super().extract_features, but super() is not supported in torchscript. A copy of
251
+ this function is made to be used in the subclass instead.
252
+ """
253
+
254
+ def extract_features_scriptable(
255
+ self,
256
+ prev_output_tokens,
257
+ encoder_out: Optional[Dict[str, List[Tensor]]],
258
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
259
+ full_context_alignment: bool = False,
260
+ alignment_layer: Optional[int] = None,
261
+ alignment_heads: Optional[int] = None,
262
+ ):
263
+ """
264
+ Similar to *forward* but only return features.
265
+
266
+ Includes several features from "Jointly Learning to Align and
267
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
268
+
269
+ Args:
270
+ full_context_alignment (bool, optional): don't apply
271
+ auto-regressive mask to self-attention (default: False).
272
+ alignment_layer (int, optional): return mean alignment over
273
+ heads at this layer (default: last layer).
274
+ alignment_heads (int, optional): only average alignment over
275
+ this many heads (default: all heads).
276
+
277
+ Returns:
278
+ tuple:
279
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
280
+ - a dictionary with any model-specific outputs
281
+ """
282
+ bs, slen = prev_output_tokens.size()
283
+ if alignment_layer is None:
284
+ alignment_layer = self.num_layers - 1
285
+
286
+ enc: Optional[Tensor] = None
287
+ padding_mask: Optional[Tensor] = None
288
+ if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
289
+ enc = encoder_out["encoder_out"][0]
290
+ if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
291
+ padding_mask = encoder_out["encoder_padding_mask"][0]
292
+
293
+ # embed positions
294
+ positions = None
295
+ if self.embed_positions is not None:
296
+ positions = self.embed_positions(
297
+ prev_output_tokens, incremental_state=incremental_state
298
+ )
299
+
300
+ if incremental_state is not None:
301
+ prev_output_tokens = prev_output_tokens[:, -1:]
302
+ if positions is not None:
303
+ positions = positions[:, -1:]
304
+
305
+ # Prevent torchscript exporting issue for dynamic quant embedding
306
+ prev_output_tokens = prev_output_tokens.contiguous()
307
+ # embed tokens and positions
308
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
309
+
310
+ if self.quant_noise is not None:
311
+ x = self.quant_noise(x)
312
+
313
+ if self.project_in_dim is not None:
314
+ x = self.project_in_dim(x)
315
+
316
+ if positions is not None:
317
+ x += positions
318
+
319
+ if self.layernorm_embedding is not None:
320
+ x = self.layernorm_embedding(x)
321
+
322
+ x = self.dropout_module(x)
323
+
324
+ # B x T x C -> T x B x C
325
+ x = x.transpose(0, 1)
326
+
327
+ self_attn_padding_mask: Optional[Tensor] = None
328
+ if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
329
+ self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
330
+
331
+ # decoder layers
332
+ attn: Optional[Tensor] = None
333
+ inner_states: List[Optional[Tensor]] = [x]
334
+ for idx, layer in enumerate(self.layers):
335
+ if incremental_state is None and not full_context_alignment:
336
+ self_attn_mask = self.buffered_future_mask(x)
337
+ else:
338
+ self_attn_mask = None
339
+
340
+ x, layer_attn, _ = layer(
341
+ x,
342
+ enc,
343
+ padding_mask,
344
+ incremental_state,
345
+ self_attn_mask=self_attn_mask,
346
+ self_attn_padding_mask=self_attn_padding_mask,
347
+ need_attn=bool((idx == alignment_layer)),
348
+ need_head_weights=bool((idx == alignment_layer)),
349
+ )
350
+ inner_states.append(x)
351
+ if layer_attn is not None and idx == alignment_layer:
352
+ attn = layer_attn.float().to(x)
353
+
354
+ if attn is not None:
355
+ if alignment_heads is not None:
356
+ attn = attn[:alignment_heads]
357
+
358
+ # average probabilities over heads
359
+ attn = attn.mean(dim=0)
360
+
361
+ if self.layer_norm is not None:
362
+ x = self.layer_norm(x)
363
+
364
+ # T x B x C -> B x T x C
365
+ x = x.transpose(0, 1)
366
+
367
+ if self.project_out_dim is not None:
368
+ x = self.project_out_dim(x)
369
+
370
+ return x, {"attn": [attn], "inner_states": inner_states}
371
+
372
+ def output_layer(self, features):
373
+ """Project features to the vocabulary size."""
374
+ if self.adaptive_softmax is None:
375
+ # project back to size of vocabulary
376
+ return self.output_projection(features)
377
+ else:
378
+ return features
379
+
380
+ def max_positions(self):
381
+ """Maximum output length supported by the decoder."""
382
+ if self.embed_positions is None:
383
+ return self.max_target_positions
384
+ return min(self.max_target_positions, self.embed_positions.max_positions)
385
+
386
+ def buffered_future_mask(self, tensor):
387
+ dim = tensor.size(0)
388
+ # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
389
+ if (
390
+ self._future_mask.size(0) == 0
391
+ or (not self._future_mask.device == tensor.device)
392
+ or self._future_mask.size(0) < dim
393
+ ):
394
+ self._future_mask = torch.triu(
395
+ utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
396
+ )
397
+ self._future_mask = self._future_mask.to(tensor)
398
+ return self._future_mask[:dim, :dim]
399
+
400
+ def upgrade_state_dict_named(self, state_dict, name):
401
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
402
+ if f"{name}.output_projection.weight" not in state_dict:
403
+ if self.share_input_output_embed:
404
+ embed_out_key = f"{name}.embed_tokens.weight"
405
+ else:
406
+ embed_out_key = f"{name}.embed_out"
407
+ if embed_out_key in state_dict:
408
+ state_dict[f"{name}.output_projection.weight"] = state_dict[
409
+ embed_out_key
410
+ ]
411
+ if not self.share_input_output_embed:
412
+ del state_dict[embed_out_key]
413
+
414
+ for i in range(self.num_layers):
415
+ # update layer norms
416
+ layer_norm_map = {
417
+ "0": "self_attn_layer_norm",
418
+ "1": "encoder_attn_layer_norm",
419
+ "2": "final_layer_norm",
420
+ }
421
+ for old, new in layer_norm_map.items():
422
+ for m in ("weight", "bias"):
423
+ k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
424
+ if k in state_dict:
425
+ state_dict[
426
+ "{}.layers.{}.{}.{}".format(name, i, new, m)
427
+ ] = state_dict[k]
428
+ del state_dict[k]
429
+
430
+ version_key = "{}.version".format(name)
431
+ if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
432
+ # earlier checkpoints did not normalize after the stack of layers
433
+ self.layer_norm = None
434
+ self.normalize = False
435
+ state_dict[version_key] = torch.Tensor([1])
436
+
437
+ return state_dict
438
+
439
+
440
+ def Linear(in_features, out_features, bias=True):
441
+ m = nn.Linear(in_features, out_features, bias)
442
+ nn.init.xavier_uniform_(m.weight)
443
+ if bias:
444
+ nn.init.constant_(m.bias, 0.0)
445
+ return m
446
+
447
+
448
+ class TransformerDecoder(TransformerDecoderBase):
449
+ def __init__(
450
+ self,
451
+ args,
452
+ dictionary,
453
+ embed_tokens,
454
+ no_encoder_attn=False,
455
+ output_projection=None,
456
+ ):
457
+ self.args = args
458
+ super().__init__(
459
+ TransformerConfig.from_namespace(args),
460
+ dictionary,
461
+ embed_tokens,
462
+ no_encoder_attn=no_encoder_attn,
463
+ output_projection=output_projection,
464
+ )
465
+
466
+ def build_output_projection(self, args, dictionary, embed_tokens):
467
+ super().build_output_projection(
468
+ TransformerConfig.from_namespace(args), dictionary, embed_tokens
469
+ )
470
+
471
+ def build_decoder_layer(self, args, no_encoder_attn=False):
472
+ return super().build_decoder_layer(
473
+ TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
474
+ )
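
For readers skimming the diff, here is a minimal, self-contained sketch (plain PyTorch, not part of the commit) of the causal masking that `buffered_future_mask` implements above: an upper-triangular matrix of `-inf` is added to the attention scores so each target position can only attend to itself and earlier positions. The tensor sizes below are made up for illustration.

```python
import torch

def future_mask(dim: int) -> torch.Tensor:
    # -inf strictly above the diagonal, 0 on and below it:
    # position i may attend to positions <= i only.
    return torch.triu(torch.full((dim, dim), float("-inf")), diagonal=1)

mask = future_mask(4)
scores = torch.zeros(4, 4) + mask       # pretend attention logits
probs = torch.softmax(scores, dim=-1)   # rows sum to 1 over the allowed positions
print(probs)                            # row 0 -> [1, 0, 0, 0], row 1 -> [0.5, 0.5, 0, 0], ...
```

The cached version in the decoder only rebuilds this mask when the requested length exceeds the cached size or the tensor's device changes.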
fairseq/fairseq/models/transformer/transformer_decoder_aug.py ADDED
@@ -0,0 +1,384 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch import Tensor
11
+
12
+ from fairseq import utils
13
+ from fairseq.distributed import fsdp_wrap
14
+ from fairseq.models.transformer import TransformerConfig
15
+ from fairseq.models.transformer.transformer_decoder import TransformerDecoderBase
16
+ from fairseq.modules import (
17
+ LayerDropModuleList,
18
+ SinusoidalPositionalEmbedding,
19
+ transformer_layer_aug,
20
+ )
21
+ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
22
+
23
+
24
+ class AugTransformerDecoderBase(TransformerDecoderBase):
25
+ """
26
+ Transformer decoder augmented with an additional cross-attention. Each layer
27
+ is a :class:`AugTransformerDecoderLayerBase`.
28
+
29
+ Args:
30
+ cfg (argparse.Namespace): parsed command-line arguments
31
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
32
+ embed_tokens (torch.nn.Embedding): output embedding
33
+ encoder_attn_merge_type (str, optional): the way to combine outputs from
34
+ two cross-attention modules. If "sequential" is set, two cross-attention
35
+ modules are stacked sequentially. If "parallel" is set, they are processed
36
+ in parallel and combined before being fed to the FFN (default: sequential).
+ dropnet_ratio (float, optional): probability of dropping each cross-attention
+ module during training (default: 0.0).
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ cfg,
44
+ dictionary,
45
+ embed_tokens,
46
+ output_projection=None,
47
+ encoder_attn_merge_type="sequential",
48
+ dropnet_ratio=0.0,
49
+ ):
50
+ super().__init__(
51
+ cfg,
52
+ dictionary,
53
+ embed_tokens,
54
+ no_encoder_attn=False,
55
+ output_projection=output_projection,
56
+ )
57
+ # assert cfg.cross_self_attention
58
+ self.cross_self_attention = cfg.cross_self_attention
59
+
60
+ if self.decoder_layerdrop > 0.0:
61
+ self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
62
+ else:
63
+ self.layers = nn.ModuleList([])
64
+ self.layers.extend(
65
+ [
66
+ self.build_decoder_layer(cfg, encoder_attn_merge_type, dropnet_ratio)
67
+ for _ in range(cfg.decoder.layers)
68
+ ]
69
+ )
70
+
71
+ def build_decoder_layer(
72
+ self,
73
+ cfg,
74
+ encoder_attn_merge_type="sequential",
75
+ dropnet_ratio=0,
76
+ ):
77
+ layer = transformer_layer_aug.AugTransformerDecoderLayerBase(
78
+ cfg,
79
+ no_encoder_attn=False,
80
+ encoder_attn_merge_type=encoder_attn_merge_type,
81
+ dropnet_ratio=dropnet_ratio,
82
+ )
83
+ checkpoint = cfg.checkpoint_activations
84
+ if checkpoint:
85
+ offload_to_cpu = cfg.offload_activations
86
+ layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
87
+ # if we are checkpointing, enforce that FSDP always wraps the
88
+ # checkpointed layer, regardless of layer size
89
+ min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
90
+ layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
91
+ return layer
92
+
93
+ def forward(
94
+ self,
95
+ prev_output_tokens,
96
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
97
+ encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None,
98
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
99
+ features_only: bool = False,
100
+ full_context_alignment: bool = False,
101
+ alignment_layer: Optional[int] = None,
102
+ alignment_heads: Optional[int] = None,
103
+ src_lengths: Optional[Any] = None,
104
+ return_all_hiddens: bool = False,
105
+ ):
106
+ """
107
+ Args:
108
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
109
+ `(batch, tgt_len)`, for teacher forcing
110
+ encoder_out (optional): output from the encoder, used for
111
+ encoder-side attention, should be of size T x B x C
112
+ incremental_state (dict): dictionary used for storing state during
113
+ :ref:`Incremental decoding`
114
+ features_only (bool, optional): only return features without
115
+ applying output layer (default: False).
116
+ full_context_alignment (bool, optional): don't apply
117
+ auto-regressive mask to self-attention (default: False).
118
+
119
+ Returns:
120
+ tuple:
121
+ - the decoder's output of shape `(batch, tgt_len, vocab)`
122
+ - a dictionary with any model-specific outputs
123
+ """
124
+
125
+ x, extra = self.extract_features(
126
+ prev_output_tokens,
127
+ encoder_out=encoder_out,
128
+ encoder_out_aug=encoder_out_aug,
129
+ incremental_state=incremental_state,
130
+ full_context_alignment=full_context_alignment,
131
+ alignment_layer=alignment_layer,
132
+ alignment_heads=alignment_heads,
133
+ )
134
+
135
+ if not features_only:
136
+ x = self.output_layer(x)
137
+ return x, extra
138
+
139
+ def extract_features(
140
+ self,
141
+ prev_output_tokens,
142
+ encoder_out: Optional[Dict[str, List[Tensor]]],
143
+ encoder_out_aug: Optional[Dict[str, List[Tensor]]],
144
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
145
+ full_context_alignment: bool = False,
146
+ alignment_layer: Optional[int] = None,
147
+ alignment_heads: Optional[int] = None,
148
+ ):
149
+ return self.extract_features_scriptable(
150
+ prev_output_tokens,
151
+ encoder_out,
152
+ encoder_out_aug,
153
+ incremental_state,
154
+ full_context_alignment,
155
+ alignment_layer,
156
+ alignment_heads,
157
+ )
158
+
159
+ """
160
+ A scriptable subclass of this class has an extract_features method and calls
161
+ super().extract_features, but super() is not supported in torchscript. A copy of
162
+ this function is made to be used in the subclass instead.
163
+ """
164
+
165
+ def extract_features_scriptable(
166
+ self,
167
+ prev_output_tokens,
168
+ encoder_out: Optional[Dict[str, List[Tensor]]],
169
+ encoder_out_aug: Optional[Dict[str, List[Tensor]]],
170
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
171
+ full_context_alignment: bool = False,
172
+ alignment_layer: Optional[int] = None,
173
+ alignment_heads: Optional[int] = None,
174
+ ):
175
+ """
176
+ Similar to *forward* but only return features.
177
+
178
+ Includes several features from "Jointly Learning to Align and
179
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
180
+
181
+ Args:
182
+ full_context_alignment (bool, optional): don't apply
183
+ auto-regressive mask to self-attention (default: False).
184
+ alignment_layer (int, optional): return mean alignment over
185
+ heads at this layer (default: last layer).
186
+ alignment_heads (int, optional): only average alignment over
187
+ this many heads (default: all heads).
188
+
189
+ Returns:
190
+ tuple:
191
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
192
+ - a dictionary with any model-specific outputs
193
+ """
194
+ bs, slen = prev_output_tokens.size()
195
+ if alignment_layer is None:
196
+ alignment_layer = self.num_layers - 1
197
+
198
+ enc: Optional[Tensor] = None
199
+ padding_mask: Optional[Tensor] = None
200
+ if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
201
+ enc = encoder_out["encoder_out"][0]
202
+ if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
203
+ padding_mask = encoder_out["encoder_padding_mask"][0]
204
+
205
+ enc_aug: Optional[Tensor] = None
206
+ padding_mask_aug: Optional[Tensor] = None
207
+ if encoder_out_aug is not None and len(encoder_out_aug["encoder_out"]) > 0:
208
+ enc_aug = encoder_out_aug["encoder_out"][0]
209
+ if (
210
+ encoder_out_aug is not None
211
+ and len(encoder_out_aug["encoder_padding_mask"]) > 0
212
+ ):
213
+ padding_mask_aug = encoder_out_aug["encoder_padding_mask"][0]
214
+
215
+ # embed positions
216
+ positions = None
217
+ if self.embed_positions is not None:
218
+ positions = self.embed_positions(
219
+ prev_output_tokens, incremental_state=incremental_state
220
+ )
221
+
222
+ if incremental_state is not None:
223
+ prev_output_tokens = prev_output_tokens[:, -1:]
224
+ if positions is not None:
225
+ positions = positions[:, -1:]
226
+
227
+ # Prevent torchscript exporting issue for dynamic quant embedding
228
+ prev_output_tokens = prev_output_tokens.contiguous()
229
+ # embed tokens and positions
230
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
231
+
232
+ if self.quant_noise is not None:
233
+ x = self.quant_noise(x)
234
+
235
+ if self.project_in_dim is not None:
236
+ x = self.project_in_dim(x)
237
+
238
+ if positions is not None:
239
+ x += positions
240
+
241
+ if self.layernorm_embedding is not None:
242
+ x = self.layernorm_embedding(x)
243
+
244
+ x = self.dropout_module(x)
245
+
246
+ # B x T x C -> T x B x C
247
+ x = x.transpose(0, 1)
248
+
249
+ self_attn_padding_mask: Optional[Tensor] = None
250
+ if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
251
+ self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
252
+
253
+ # decoder layers
254
+ attn: Optional[Tensor] = None
255
+ attn_aug: Optional[Tensor] = None
256
+ inner_states: List[Optional[Tensor]] = [x]
257
+ for idx, layer in enumerate(self.layers):
258
+ if incremental_state is None and not full_context_alignment:
259
+ self_attn_mask = self.buffered_future_mask(x)
260
+ else:
261
+ self_attn_mask = None
262
+
263
+ x, layer_attn, layer_attn_aug, _ = layer(
264
+ x,
265
+ enc,
266
+ padding_mask,
267
+ enc_aug,
268
+ padding_mask_aug,
269
+ incremental_state,
270
+ self_attn_mask=self_attn_mask,
271
+ self_attn_padding_mask=self_attn_padding_mask,
272
+ need_attn=bool((idx == alignment_layer)),
273
+ need_head_weights=bool((idx == alignment_layer)),
274
+ )
275
+ inner_states.append(x)
276
+ if layer_attn is not None and idx == alignment_layer:
277
+ attn = layer_attn.float().to(x)
278
+ if layer_attn_aug is not None and idx == alignment_layer:
279
+ attn_aug = layer_attn_aug.float().to(x)
280
+
281
+ if attn is not None:
282
+ if alignment_heads is not None:
283
+ attn = attn[:alignment_heads]
284
+
285
+ # average probabilities over heads
286
+ attn = attn.mean(dim=0)
287
+
288
+ if attn_aug is not None:
289
+ if alignment_heads is not None:
290
+ attn_aug = attn_aug[:alignment_heads]
291
+
292
+ # average probabilities over heads
293
+ attn_aug = attn_aug.mean(dim=0)
294
+
295
+ if self.layer_norm is not None:
296
+ x = self.layer_norm(x)
297
+
298
+ # T x B x C -> B x T x C
299
+ x = x.transpose(0, 1)
300
+
301
+ if self.project_out_dim is not None:
302
+ x = self.project_out_dim(x)
303
+
304
+ return x, {"attn": [attn], "attn_aug": [attn_aug], "inner_states": inner_states}
305
+
306
+ def upgrade_state_dict_named(self, state_dict, name):
307
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
308
+ if f"{name}.output_projection.weight" not in state_dict:
309
+ if self.share_input_output_embed:
310
+ embed_out_key = f"{name}.embed_tokens.weight"
311
+ else:
312
+ embed_out_key = f"{name}.embed_out"
313
+ if embed_out_key in state_dict:
314
+ state_dict[f"{name}.output_projection.weight"] = state_dict[
315
+ embed_out_key
316
+ ]
317
+ if not self.share_input_output_embed:
318
+ del state_dict[embed_out_key]
319
+
320
+ for i in range(self.num_layers):
321
+ # update layer norms
322
+ layer_norm_map = {
323
+ "0": "self_attn_layer_norm",
324
+ "1": "encoder_attn_layer_norm",
325
+ "2": "encoder_attn_layer_norm2",
326
+ "3": "final_layer_norm",
327
+ }
328
+ for old, new in layer_norm_map.items():
329
+ for m in ("weight", "bias"):
330
+ k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
331
+ if k in state_dict:
332
+ state_dict[
333
+ "{}.layers.{}.{}.{}".format(name, i, new, m)
334
+ ] = state_dict[k]
335
+ del state_dict[k]
336
+
337
+ version_key = "{}.version".format(name)
338
+ if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
339
+ # earlier checkpoints did not normalize after the stack of layers
340
+ self.layer_norm = None
341
+ self.normalize = False
342
+ state_dict[version_key] = torch.Tensor([1])
343
+
344
+ return state_dict
345
+
346
+
347
+ class AugTransformerDecoder(AugTransformerDecoderBase):
348
+ def __init__(
349
+ self,
350
+ args,
351
+ dictionary,
352
+ embed_tokens,
353
+ output_projection=None,
354
+ ):
355
+ self.args = args
356
+ super().__init__(
357
+ TransformerConfig.from_namespace(args),
358
+ dictionary,
359
+ embed_tokens,
360
+ no_encoder_attn=False,
361
+ output_projection=output_projection,
362
+ encoder_attn_merge_type=getattr(
363
+ args, "synthesizer_augmented_cross_attention_merge_type", "sequential"
364
+ ),
365
+ dropnet_ratio=getattr(args, "dropnet_ratio", 0),
366
+ )
367
+
368
+ def build_output_projection(self, args, dictionary, embed_tokens):
369
+ super().build_output_projection(
370
+ TransformerConfig.from_namespace(args), dictionary, embed_tokens
371
+ )
372
+
373
+ def build_decoder_layer(
374
+ self,
375
+ args,
376
+ encoder_attn_merge_type="sequential",
377
+ dropnet_ratio=0,
378
+ ):
379
+ return super().build_decoder_layer(
380
+ TransformerConfig.from_namespace(args),
381
+ no_encoder_attn=False,
382
+ encoder_attn_merge_type=encoder_attn_merge_type,
383
+ dropnet_ratio=dropnet_ratio,
384
+ )
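
As a rough illustration of the two `encoder_attn_merge_type` options described in the docstring above, the sketch below uses plain `torch.nn.MultiheadAttention` rather than fairseq's `AugTransformerDecoderLayerBase` (whose exact combination rule, dropnet included, may differ); all module names and shapes here are invented for the example.

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
attn_main = nn.MultiheadAttention(embed_dim, num_heads)  # attends to encoder_out
attn_aug = nn.MultiheadAttention(embed_dim, num_heads)   # attends to encoder_out_aug

x = torch.randn(5, 2, embed_dim)        # decoder states, T x B x C
enc = torch.randn(7, 2, embed_dim)      # first encoder output, S x B x C
enc_aug = torch.randn(7, 2, embed_dim)  # augmented encoder output, S x B x C

# "sequential": the second cross-attention runs on the first one's output
y_seq, _ = attn_main(x, enc, enc)
y_seq, _ = attn_aug(y_seq, enc_aug, enc_aug)

# "parallel": both read the same decoder states and are combined afterwards
y_main, _ = attn_main(x, enc, enc)
y_aug, _ = attn_aug(x, enc_aug, enc_aug)
y_par = 0.5 * (y_main + y_aug)
print(y_seq.shape, y_par.shape)  # both T x B x C
```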
fairseq/fairseq/models/transformer/transformer_encoder.py ADDED
@@ -0,0 +1,362 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from typing import Dict, List, Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+
13
+ from fairseq import utils
14
+ from fairseq.distributed import fsdp_wrap
15
+ from fairseq.models import FairseqEncoder
16
+ from fairseq.models.transformer import TransformerConfig
17
+ from fairseq.modules import (
18
+ FairseqDropout,
19
+ LayerDropModuleList,
20
+ LayerNorm,
21
+ PositionalEmbedding,
22
+ SinusoidalPositionalEmbedding,
23
+ transformer_layer,
24
+ )
25
+ from fairseq.modules.checkpoint_activations import checkpoint_wrapper
26
+ from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
27
+
28
+
29
+ # rewrite name for backward compatibility in `make_generation_fast_`
30
+ def module_name_fordropout(module_name: str) -> str:
31
+ if module_name == "TransformerEncoderBase":
32
+ return "TransformerEncoder"
33
+ else:
34
+ return module_name
35
+
36
+
37
+ class TransformerEncoderBase(FairseqEncoder):
38
+ """
39
+ Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
40
+ is a :class:`TransformerEncoderLayer`.
41
+
42
+ Args:
43
+ args (argparse.Namespace): parsed command-line arguments
44
+ dictionary (~fairseq.data.Dictionary): encoding dictionary
45
+ embed_tokens (torch.nn.Embedding): input embedding
46
+ """
47
+
48
+ def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
49
+ self.cfg = cfg
50
+ super().__init__(dictionary)
51
+ self.register_buffer("version", torch.Tensor([3]))
52
+
53
+ self.dropout_module = FairseqDropout(
54
+ cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
55
+ )
56
+ self.encoder_layerdrop = cfg.encoder.layerdrop
57
+ self.return_fc = return_fc
58
+
59
+ embed_dim = embed_tokens.embedding_dim
60
+ self.padding_idx = embed_tokens.padding_idx
61
+ self.max_source_positions = cfg.max_source_positions
62
+
63
+ self.embed_tokens = embed_tokens
64
+
65
+ self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
66
+
67
+ self.embed_positions = (
68
+ PositionalEmbedding(
69
+ cfg.max_source_positions,
70
+ embed_dim,
71
+ self.padding_idx,
72
+ learned=cfg.encoder.learned_pos,
73
+ )
74
+ if not cfg.no_token_positional_embeddings
75
+ else None
76
+ )
77
+ if cfg.layernorm_embedding:
78
+ self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
79
+ else:
80
+ self.layernorm_embedding = None
81
+
82
+ if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
83
+ self.quant_noise = apply_quant_noise_(
84
+ nn.Linear(embed_dim, embed_dim, bias=False),
85
+ cfg.quant_noise.pq,
86
+ cfg.quant_noise.pq_block_size,
87
+ )
88
+ else:
89
+ self.quant_noise = None
90
+
91
+ if self.encoder_layerdrop > 0.0:
92
+ self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
93
+ else:
94
+ self.layers = nn.ModuleList([])
95
+ self.layers.extend(
96
+ [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]
97
+ )
98
+ self.num_layers = len(self.layers)
99
+
100
+ if cfg.encoder.normalize_before:
101
+ self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
102
+ else:
103
+ self.layer_norm = None
104
+
105
+ def build_encoder_layer(self, cfg):
106
+ layer = transformer_layer.TransformerEncoderLayerBase(
107
+ cfg, return_fc=self.return_fc
108
+ )
109
+ checkpoint = cfg.checkpoint_activations
110
+ if checkpoint:
111
+ offload_to_cpu = cfg.offload_activations
112
+ layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
113
+ # if we are checkpointing, enforce that FSDP always wraps the
114
+ # checkpointed layer, regardless of layer size
115
+ min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
116
+ layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
117
+ return layer
118
+
119
+ def forward_embedding(
120
+ self, src_tokens, token_embedding: Optional[torch.Tensor] = None
121
+ ):
122
+ # embed tokens and positions
123
+ if token_embedding is None:
124
+ token_embedding = self.embed_tokens(src_tokens)
125
+ x = embed = self.embed_scale * token_embedding
126
+ if self.embed_positions is not None:
127
+ x = embed + self.embed_positions(src_tokens)
128
+ if self.layernorm_embedding is not None:
129
+ x = self.layernorm_embedding(x)
130
+ x = self.dropout_module(x)
131
+ if self.quant_noise is not None:
132
+ x = self.quant_noise(x)
133
+ return x, embed
134
+
135
+ def forward(
136
+ self,
137
+ src_tokens,
138
+ src_lengths: Optional[torch.Tensor] = None,
139
+ return_all_hiddens: bool = False,
140
+ token_embeddings: Optional[torch.Tensor] = None,
141
+ ):
142
+ """
143
+ Args:
144
+ src_tokens (LongTensor): tokens in the source language of shape
145
+ `(batch, src_len)`
146
+ src_lengths (torch.LongTensor): lengths of each source sentence of
147
+ shape `(batch)`
148
+ return_all_hiddens (bool, optional): also return all of the
149
+ intermediate hidden states (default: False).
150
+ token_embeddings (torch.Tensor, optional): precomputed embeddings;
+ default `None` will recompute embeddings
152
+
153
+ Returns:
154
+ dict:
155
+ - **encoder_out** (Tensor): the last encoder layer's output of
156
+ shape `(src_len, batch, embed_dim)`
157
+ - **encoder_padding_mask** (ByteTensor): the positions of
158
+ padding elements of shape `(batch, src_len)`
159
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
160
+ of shape `(batch, src_len, embed_dim)`
161
+ - **encoder_states** (List[Tensor]): all intermediate
162
+ hidden states of shape `(src_len, batch, embed_dim)`.
163
+ Only populated if *return_all_hiddens* is True.
164
+ """
165
+ return self.forward_scriptable(
166
+ src_tokens, src_lengths, return_all_hiddens, token_embeddings
167
+ )
168
+
169
+ # TorchScript doesn't support the super() method, so the scriptable subclass
+ # can't access the base class implementation in TorchScript.
+ # The current workaround is to add a helper function with a different name and
+ # call the helper function from the scriptable subclass.
173
+ def forward_scriptable(
174
+ self,
175
+ src_tokens,
176
+ src_lengths: Optional[torch.Tensor] = None,
177
+ return_all_hiddens: bool = False,
178
+ token_embeddings: Optional[torch.Tensor] = None,
179
+ ):
180
+ """
181
+ Args:
182
+ src_tokens (LongTensor): tokens in the source language of shape
183
+ `(batch, src_len)`
184
+ src_lengths (torch.LongTensor): lengths of each source sentence of
185
+ shape `(batch)`
186
+ return_all_hiddens (bool, optional): also return all of the
187
+ intermediate hidden states (default: False).
188
+ token_embeddings (torch.Tensor, optional): precomputed embeddings;
+ default `None` will recompute embeddings
190
+
191
+ Returns:
192
+ dict:
193
+ - **encoder_out** (Tensor): the last encoder layer's output of
194
+ shape `(src_len, batch, embed_dim)`
195
+ - **encoder_padding_mask** (ByteTensor): the positions of
196
+ padding elements of shape `(batch, src_len)`
197
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
198
+ of shape `(batch, src_len, embed_dim)`
199
+ - **encoder_states** (List[Tensor]): all intermediate
200
+ hidden states of shape `(src_len, batch, embed_dim)`.
201
+ Only populated if *return_all_hiddens* is True.
202
+ """
203
+ # compute padding mask
204
+ encoder_padding_mask = src_tokens.eq(self.padding_idx)
205
+ has_pads = (
206
+ torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any()
207
+ )
208
+ # Torchscript doesn't handle bool Tensor correctly, so we need to work around.
209
+ if torch.jit.is_scripting():
210
+ has_pads = torch.tensor(1) if has_pads else torch.tensor(0)
211
+
212
+ x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
213
+
214
+ # account for padding while computing the representation
215
+ x = x * (
216
+ 1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x)
217
+ )
218
+
219
+ # B x T x C -> T x B x C
220
+ x = x.transpose(0, 1)
221
+
222
+ encoder_states = []
223
+ fc_results = []
224
+
225
+ if return_all_hiddens:
226
+ encoder_states.append(x)
227
+
228
+ # encoder layers
229
+ for layer in self.layers:
230
+ lr = layer(
231
+ x, encoder_padding_mask=encoder_padding_mask if has_pads else None
232
+ )
233
+
234
+ if isinstance(lr, tuple) and len(lr) == 2:
235
+ x, fc_result = lr
236
+ else:
237
+ x = lr
238
+ fc_result = None
239
+
240
+ if return_all_hiddens and not torch.jit.is_scripting():
241
+ assert encoder_states is not None
242
+ encoder_states.append(x)
243
+ fc_results.append(fc_result)
244
+
245
+ if self.layer_norm is not None:
246
+ x = self.layer_norm(x)
247
+
248
+ # The PyTorch Mobile lite interpreter does not support returning NamedTuple in
249
+ # `forward` so we use a dictionary instead.
250
+ # TorchScript does not support mixed values so the values are all lists.
251
+ # The empty list is equivalent to None.
252
+ src_lengths = (
253
+ src_tokens.ne(self.padding_idx)
254
+ .sum(dim=1, dtype=torch.int32)
255
+ .reshape(-1, 1)
256
+ .contiguous()
257
+ )
258
+ return {
259
+ "encoder_out": [x], # T x B x C
260
+ "encoder_padding_mask": [encoder_padding_mask], # B x T
261
+ "encoder_embedding": [encoder_embedding], # B x T x C
262
+ "encoder_states": encoder_states, # List[T x B x C]
263
+ "fc_results": fc_results, # List[T x B x C]
264
+ "src_tokens": [],
265
+ "src_lengths": [src_lengths],
266
+ }
267
+
268
+ @torch.jit.export
269
+ def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
270
+ """
271
+ Reorder encoder output according to *new_order*.
272
+
273
+ Args:
274
+ encoder_out: output from the ``forward()`` method
275
+ new_order (LongTensor): desired order
276
+
277
+ Returns:
278
+ *encoder_out* rearranged according to *new_order*
279
+ """
280
+ if len(encoder_out["encoder_out"]) == 0:
281
+ new_encoder_out = []
282
+ else:
283
+ new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
284
+ if len(encoder_out["encoder_padding_mask"]) == 0:
285
+ new_encoder_padding_mask = []
286
+ else:
287
+ new_encoder_padding_mask = [
288
+ encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
289
+ ]
290
+ if len(encoder_out["encoder_embedding"]) == 0:
291
+ new_encoder_embedding = []
292
+ else:
293
+ new_encoder_embedding = [
294
+ encoder_out["encoder_embedding"][0].index_select(0, new_order)
295
+ ]
296
+
297
+ if len(encoder_out["src_tokens"]) == 0:
298
+ src_tokens = []
299
+ else:
300
+ src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
301
+
302
+ if len(encoder_out["src_lengths"]) == 0:
303
+ src_lengths = []
304
+ else:
305
+ src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]
306
+
307
+ encoder_states = encoder_out["encoder_states"]
308
+ if len(encoder_states) > 0:
309
+ for idx, state in enumerate(encoder_states):
310
+ encoder_states[idx] = state.index_select(1, new_order)
311
+
312
+ return {
313
+ "encoder_out": new_encoder_out, # T x B x C
314
+ "encoder_padding_mask": new_encoder_padding_mask, # B x T
315
+ "encoder_embedding": new_encoder_embedding, # B x T x C
316
+ "encoder_states": encoder_states, # List[T x B x C]
317
+ "src_tokens": src_tokens, # B x T
318
+ "src_lengths": src_lengths, # B x 1
319
+ }
320
+
321
+ @torch.jit.export
322
+ def _reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
323
+ """Dummy re-order function for beamable enc-dec attention"""
324
+ return encoder_out
325
+
326
+ def max_positions(self):
327
+ """Maximum input length supported by the encoder."""
328
+ if self.embed_positions is None:
329
+ return self.max_source_positions
330
+ return min(self.max_source_positions, self.embed_positions.max_positions)
331
+
332
+ def upgrade_state_dict_named(self, state_dict, name):
333
+ """Upgrade a (possibly old) state dict for new versions of fairseq."""
334
+ for i in range(self.num_layers):
335
+ # update layer norms
336
+ self.layers[i].upgrade_state_dict_named(
337
+ state_dict, "{}.layers.{}".format(name, i)
338
+ )
339
+
340
+ version_key = "{}.version".format(name)
341
+ if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
342
+ # earlier checkpoints did not normalize after the stack of layers
343
+ self.layer_norm = None
344
+ self.normalize = False
345
+ state_dict[version_key] = torch.Tensor([1])
346
+ return state_dict
347
+
348
+
349
+ class TransformerEncoder(TransformerEncoderBase):
350
+ def __init__(self, args, dictionary, embed_tokens, return_fc=False):
351
+ self.args = args
352
+ super().__init__(
353
+ TransformerConfig.from_namespace(args),
354
+ dictionary,
355
+ embed_tokens,
356
+ return_fc=return_fc,
357
+ )
358
+
359
+ def build_encoder_layer(self, args):
360
+ return super().build_encoder_layer(
361
+ TransformerConfig.from_namespace(args),
362
+ )
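
The padding bookkeeping in `forward_scriptable` above is easy to miss in the diff; here is a small stand-alone sketch (plain PyTorch, with made-up token ids and sizes) of the same three steps: build the padding mask, zero out padded positions, and recover `src_lengths` from the non-pad count.

```python
import torch

padding_idx = 1
src_tokens = torch.tensor([[5, 6, 7, 1, 1],
                           [8, 9, 1, 1, 1]])   # B x T, 1 is <pad>
x = torch.randn(2, 5, 8)                       # B x T x C token embeddings

encoder_padding_mask = src_tokens.eq(padding_idx)            # B x T, True at pads
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))  # zero padded positions
src_lengths = (
    src_tokens.ne(padding_idx).sum(dim=1, dtype=torch.int32).reshape(-1, 1)
)                                              # [[3], [2]]
x = x.transpose(0, 1)                          # B x T x C -> T x B x C for the layers
print(encoder_padding_mask, src_lengths.squeeze(1), x.shape)
```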
fairseq/fairseq/models/transformer/transformer_legacy.py ADDED
@@ -0,0 +1,277 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq.dataclass.utils import gen_parser_from_dataclass
7
+ from fairseq.models import (
8
+ register_model,
9
+ register_model_architecture,
10
+ )
11
+ from fairseq.models.transformer.transformer_config import (
12
+ TransformerConfig,
13
+ DEFAULT_MAX_SOURCE_POSITIONS,
14
+ DEFAULT_MAX_TARGET_POSITIONS,
15
+ DEFAULT_MIN_PARAMS_TO_WRAP,
16
+ )
17
+ from fairseq.models.transformer.transformer_base import (
18
+ TransformerModelBase,
19
+ )
20
+
21
+
22
+ @register_model("transformer")
23
+ class TransformerModel(TransformerModelBase):
24
+ """
25
+ This is the legacy implementation of the transformer model that
26
+ uses argparse for configuration.
27
+ """
28
+
29
+ @classmethod
30
+ def hub_models(cls):
31
+ # fmt: off
32
+
33
+ def moses_subword(path):
34
+ return {
35
+ 'path': path,
36
+ 'tokenizer': 'moses',
37
+ 'bpe': 'subword_nmt',
38
+ }
39
+
40
+ def moses_fastbpe(path):
41
+ return {
42
+ 'path': path,
43
+ 'tokenizer': 'moses',
44
+ 'bpe': 'fastbpe',
45
+ }
46
+
47
+ def spm(path):
48
+ return {
49
+ 'path': path,
50
+ 'bpe': 'sentencepiece',
51
+ 'tokenizer': 'space',
52
+ }
53
+
54
+ return {
55
+ 'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'),
56
+ 'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2',
57
+ 'transformer.wmt18.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz'),
58
+ 'transformer.wmt19.en-de': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz'),
59
+ 'transformer.wmt19.en-ru': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz'),
60
+ 'transformer.wmt19.de-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz'),
61
+ 'transformer.wmt19.ru-en': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz'),
62
+ 'transformer.wmt19.en-de.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz'),
63
+ 'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'),
64
+ 'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'),
65
+ 'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'),
66
+ 'transformer.wmt20.en-ta': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz'),
67
+ 'transformer.wmt20.en-iu.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz'),
68
+ 'transformer.wmt20.en-iu.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz'),
69
+ 'transformer.wmt20.ta-en': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz'),
70
+ 'transformer.wmt20.iu-en.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'),
71
+ 'transformer.wmt20.iu-en.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'),
72
+ 'transformer.flores101.mm100.615M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_615M.tar.gz'),
73
+ 'transformer.flores101.mm100.175M': spm('https://dl.fbaipublicfiles.com/flores101/pretrained_models/flores101_mm100_175M.tar.gz'),
74
+ }
75
+ # fmt: on
76
+
77
+ def __init__(self, args, encoder, decoder):
78
+ cfg = TransformerConfig.from_namespace(args)
79
+ super().__init__(cfg, encoder, decoder)
80
+ self.args = args
81
+
82
+ @classmethod
83
+ def add_args(cls, parser):
84
+ """Add model-specific arguments to the parser."""
85
+ # we want to build the args recursively in this case.
86
+ # do not set defaults so that setting defaults from various architectures still works
87
+ gen_parser_from_dataclass(
88
+ parser, TransformerConfig(), delete_default=True, with_prefix=""
89
+ )
90
+
91
+ @classmethod
92
+ def build_model(cls, args, task):
93
+ """Build a new model instance."""
94
+
95
+ # make sure all arguments are present in older models
96
+ base_architecture(args)
97
+
98
+ if args.encoder_layers_to_keep:
99
+ args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
100
+ if args.decoder_layers_to_keep:
101
+ args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
102
+
103
+ if getattr(args, "max_source_positions", None) is None:
104
+ args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
105
+ if getattr(args, "max_target_positions", None) is None:
106
+ args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
107
+
108
+ src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
109
+
110
+ if args.share_all_embeddings:
111
+ if src_dict != tgt_dict:
112
+ raise ValueError("--share-all-embeddings requires a joined dictionary")
113
+ if args.encoder_embed_dim != args.decoder_embed_dim:
114
+ raise ValueError(
115
+ "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
116
+ )
117
+ if args.decoder_embed_path and (
118
+ args.decoder_embed_path != args.encoder_embed_path
119
+ ):
120
+ raise ValueError(
121
+ "--share-all-embeddings not compatible with --decoder-embed-path"
122
+ )
123
+ args.share_decoder_input_output_embed = True
124
+
125
+ if getattr(args, "offload_activations", False):
126
+ args.checkpoint_activations = True # offloading implies checkpointing
127
+
128
+ if not args.share_all_embeddings:
129
+ args.min_params_to_wrap = getattr(
130
+ args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
131
+ )
132
+ cfg = TransformerConfig.from_namespace(args)
133
+ return super().build_model(cfg, task)
134
+
135
+ @classmethod
136
+ def build_embedding(cls, args, dictionary, embed_dim, path=None):
137
+ return super().build_embedding(
138
+ TransformerConfig.from_namespace(args), dictionary, embed_dim, path
139
+ )
140
+
141
+ @classmethod
142
+ def build_encoder(cls, args, src_dict, embed_tokens):
143
+ return super().build_encoder(
144
+ TransformerConfig.from_namespace(args), src_dict, embed_tokens
145
+ )
146
+
147
+ @classmethod
148
+ def build_decoder(cls, args, tgt_dict, embed_tokens):
149
+ return super().build_decoder(
150
+ TransformerConfig.from_namespace(args), tgt_dict, embed_tokens
151
+ )
152
+
153
+
154
+ # architectures
155
+
156
+
157
+ @register_model_architecture("transformer", "transformer_tiny")
158
+ def tiny_architecture(args):
159
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 64)
160
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 64)
161
+ args.encoder_layers = getattr(args, "encoder_layers", 2)
162
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 2)
163
+ args.decoder_layers = getattr(args, "decoder_layers", 2)
164
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 2)
165
+ return base_architecture(args)
166
+
167
+
168
+ @register_model_architecture("transformer", "transformer")
169
+ def base_architecture(args):
170
+ args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
171
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
172
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
173
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
174
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
175
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
176
+ args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)
177
+
178
+ args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
179
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
180
+ args.decoder_ffn_embed_dim = getattr(
181
+ args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
182
+ )
183
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
184
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
185
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
186
+ args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
187
+ args.attention_dropout = getattr(args, "attention_dropout", 0.0)
188
+ args.activation_dropout = getattr(args, "activation_dropout", 0.0)
189
+ args.activation_fn = getattr(args, "activation_fn", "relu")
190
+ args.dropout = getattr(args, "dropout", 0.1)
191
+ args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
192
+ args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
193
+ args.share_decoder_input_output_embed = getattr(
194
+ args, "share_decoder_input_output_embed", False
195
+ )
196
+ args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
197
+ args.merge_src_tgt_embed = getattr(args, "merge_src_tgt_embed", False)
198
+ args.no_token_positional_embeddings = getattr(
199
+ args, "no_token_positional_embeddings", False
200
+ )
201
+ args.adaptive_input = getattr(args, "adaptive_input", False)
202
+ args.no_cross_attention = getattr(args, "no_cross_attention", False)
203
+ args.cross_self_attention = getattr(args, "cross_self_attention", False)
204
+
205
+ args.decoder_output_dim = getattr(
206
+ args, "decoder_output_dim", args.decoder_embed_dim
207
+ )
208
+ args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
209
+
210
+ args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
211
+ args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
212
+ args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
213
+ args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
214
+ args.offload_activations = getattr(args, "offload_activations", False)
215
+ if args.offload_activations:
216
+ args.checkpoint_activations = True
217
+ args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
218
+ args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
219
+ args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
220
+ args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
221
+ args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
222
+ args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
223
+ args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
224
+
225
+
226
+ @register_model_architecture("transformer", "transformer_iwslt_de_en")
227
+ def transformer_iwslt_de_en(args):
228
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
229
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
230
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
231
+ args.encoder_layers = getattr(args, "encoder_layers", 6)
232
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
233
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
234
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
235
+ args.decoder_layers = getattr(args, "decoder_layers", 6)
236
+ base_architecture(args)
237
+
238
+
239
+ @register_model_architecture("transformer", "transformer_wmt_en_de")
240
+ def transformer_wmt_en_de(args):
241
+ base_architecture(args)
242
+
243
+
244
+ # parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
245
+ @register_model_architecture("transformer", "transformer_vaswani_wmt_en_de_big")
246
+ def transformer_vaswani_wmt_en_de_big(args):
247
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
248
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
249
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
250
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
251
+ args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024)
252
+ args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096)
253
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
254
+ args.dropout = getattr(args, "dropout", 0.3)
255
+ base_architecture(args)
256
+
257
+
258
+ @register_model_architecture("transformer", "transformer_vaswani_wmt_en_fr_big")
259
+ def transformer_vaswani_wmt_en_fr_big(args):
260
+ args.dropout = getattr(args, "dropout", 0.1)
261
+ transformer_vaswani_wmt_en_de_big(args)
262
+
263
+
264
+ @register_model_architecture("transformer", "transformer_wmt_en_de_big")
265
+ def transformer_wmt_en_de_big(args):
266
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
267
+ transformer_vaswani_wmt_en_de_big(args)
268
+
269
+
270
+ # default parameters used in tensor2tensor implementation
271
+ @register_model_architecture("transformer", "transformer_wmt_en_de_big_t2t")
272
+ def transformer_wmt_en_de_big_t2t(args):
273
+ args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
274
+ args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
275
+ args.attention_dropout = getattr(args, "attention_dropout", 0.1)
276
+ args.activation_dropout = getattr(args, "activation_dropout", 0.1)
277
+ transformer_vaswani_wmt_en_de_big(args)
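
The `@register_model_architecture` pattern above is also the extension point for downstream configs. As a hedged sketch (the architecture name and layer counts below are invented, not part of this commit), one more named preset could be layered on top of `base_architecture` like this:

```python
from fairseq.models import register_model_architecture
from fairseq.models.transformer.transformer_legacy import base_architecture


@register_model_architecture("transformer", "transformer_shallow_demo")
def transformer_shallow_demo(args):
    # Override only what differs; base_architecture fills in every other
    # hyperparameter with the same getattr(...) defaulting used above.
    args.encoder_layers = getattr(args, "encoder_layers", 3)
    args.decoder_layers = getattr(args, "decoder_layers", 3)
    base_architecture(args)
```

Such a preset would then be selected with `--arch transformer_shallow_demo` at training time.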
fairseq/fairseq/models/wav2vec/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .wav2vec import * # noqa
+ from .wav2vec2 import * # noqa
+ from .wav2vec2_asr import * # noqa
+ from .wav2vec2_laser import * # noqa
+ from .wav2vec2_classification import * # noqa
fairseq/fairseq/models/wav2vec/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (345 Bytes).
fairseq/fairseq/models/wav2vec/__pycache__/utils.cpython-310.pyc ADDED
Binary file (670 Bytes).
fairseq/fairseq/models/wav2vec/__pycache__/wav2vec.cpython-310.pyc ADDED
Binary file (15.1 kB).
fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2.cpython-310.pyc ADDED
Binary file (32.6 kB).
fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_asr.cpython-310.pyc ADDED
Binary file (23.9 kB).
fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_classification.cpython-310.pyc ADDED
Binary file (9.41 kB).
fairseq/fairseq/models/wav2vec/__pycache__/wav2vec2_laser.cpython-310.pyc ADDED
Binary file (1.6 kB).