PyTorch · ssl-aasist · custom_code
ash56 committed · verified · Commit 878264b · 1 Parent(s): d28af7f

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +2 -0
  2. fairseq/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml +31 -0
  3. fairseq/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml +49 -0
  4. fairseq/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml +49 -0
  5. fairseq/examples/MMPT/projects/task/coin.yaml +25 -0
  6. fairseq/examples/MMPT/projects/task/coin_videoclip.yaml +7 -0
  7. fairseq/examples/MMPT/projects/task/test.yaml +13 -0
  8. fairseq/examples/MMPT/projects/task/test_vtt.yaml +19 -0
  9. fairseq/examples/MMPT/projects/task/test_youcook.yaml +22 -0
  10. fairseq/examples/MMPT/projects/task/test_youcookcap.yaml +23 -0
  11. fairseq/examples/MMPT/projects/task/vtt.yaml +25 -0
  12. fairseq/examples/MMPT/projects/task/vtt_videoclip.yaml +12 -0
  13. fairseq/examples/MMPT/projects/task/vttqa_videoclip.yaml +10 -0
  14. fairseq/examples/MMPT/projects/task/youcook.yaml +25 -0
  15. fairseq/examples/MMPT/projects/task/youcook_videoclip.yaml +9 -0
  16. fairseq/examples/MMPT/projects/task/youcookcap.yaml +23 -0
  17. fairseq/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml +5 -0
  18. fairseq/examples/MMPT/scripts/text_token_extractor/pretokenization.py +106 -0
  19. fairseq/examples/MMPT/scripts/video_feature_extractor/extract.py +157 -0
  20. fairseq/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh +8 -0
  21. fairseq/examples/MMPT/scripts/video_feature_extractor/model.py +58 -0
  22. fairseq/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py +89 -0
  23. fairseq/examples/MMPT/scripts/video_feature_extractor/preprocessing.py +57 -0
  24. fairseq/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py +29 -0
  25. fairseq/examples/MMPT/scripts/video_feature_extractor/shard_feature.py +64 -0
  26. fairseq/examples/MMPT/scripts/video_feature_extractor/videoreader.py +242 -0
  27. fairseq/examples/MMPT/videoclip.png +3 -0
  28. fairseq/examples/MMPT/vlm.png +3 -0
  29. fairseq/examples/adaptive_span/README.md +90 -0
  30. fairseq/examples/adaptive_span/__init__.py +19 -0
  31. fairseq/examples/adaptive_span/adagrad_with_grad_clip.py +128 -0
  32. fairseq/examples/adaptive_span/adaptive_span_attention.py +160 -0
  33. fairseq/examples/adaptive_span/adaptive_span_loss.py +107 -0
  34. fairseq/examples/adaptive_span/adaptive_span_model.py +263 -0
  35. fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py +145 -0
  36. fairseq/examples/adaptive_span/truncated_bptt_lm_task.py +285 -0
  37. fairseq/examples/attention_head_selection/README.md +161 -0
  38. fairseq/examples/attention_head_selection/src/__init__.py +0 -0
  39. fairseq/examples/attention_head_selection/src/data/__init__.py +0 -0
  40. fairseq/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py +242 -0
  41. fairseq/examples/attention_head_selection/src/loss/__init__.py +0 -0
  42. fairseq/examples/attention_head_selection/src/loss/attention_head_selection.py +27 -0
  43. fairseq/examples/attention_head_selection/src/models/__init__.py +0 -0
  44. fairseq/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py +170 -0
  45. fairseq/examples/attention_head_selection/src/models/head_selection_transformer.py +215 -0
  46. fairseq/examples/attention_head_selection/src/modules/__init__.py +0 -0
  47. fairseq/examples/attention_head_selection/src/modules/attn_head_selector.py +81 -0
  48. fairseq/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py +92 -0
  49. fairseq/examples/attention_head_selection/src/modules/multihead_attention_selection.py +355 -0
  50. fairseq/examples/attention_head_selection/src/modules/multihead_functional.py +278 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ fairseq/examples/MMPT/vlm.png filter=lfs diff=lfs merge=lfs -text
+ fairseq/examples/MMPT/videoclip.png filter=lfs diff=lfs merge=lfs -text
fairseq/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml ADDED
@@ -0,0 +1,31 @@
+ slurm_config: big
+ task_type: local_predict
+ dataset:
+   split: test
+   video_processor: VideoProcessor
+   aligner: DiDeMoAligner
+   bert_name: bert-base-uncased
+   meta_processor: DiDeMoMetaProcessor
+   test_path: data/didemo/test_data.json
+   vfeat_dir: data/feat/feat_didemo_s3d
+   text_processor: DiDeMoTextProcessor
+   num_iso_layer: 12
+   max_video_len: 32
+   max_len: 96
+ fairseq:
+   dataset:
+     batch_size: 256
+     valid_subset: test
+     num_workers: 2
+   common_eval:
+     path: runs/retri/videoclip/checkpoint_best.pt
+ model:
+   model_cls: MMFusionSeparate
+   mm_encoder_cls: null
+   video_encoder_cls: MMBertForEncoder
+   text_encoder_cls: BertModel
+   num_hidden_video_layers: 6
+ eval:
+   save_path: runs/retri/videoclip/didemo_zs/eval
+ metric: DiDeMoMetric
+ predictor: DiDeMoPredictor
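A quick way to sanity-check a config like the one above is a small sketch (assuming it is run from `fairseq/examples/MMPT`, where `mmpt` and the `projects/` tree are importable; `recursive_config` is the same loader used by `scripts/text_token_extractor/pretokenization.py` later in this commit):

```python
# Minimal sketch: load the zero-shot DiDeMo config added above and inspect it.
from mmpt.utils import recursive_config

config = recursive_config("projects/retri/videoclip/test_didemo_zs.yaml")
print(config.task_type)                 # local_predict
print(config.dataset.meta_processor)    # DiDeMoMetaProcessor
print(config.fairseq.common_eval.path)  # runs/retri/videoclip/checkpoint_best.pt
```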
fairseq/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml ADDED
@@ -0,0 +1,49 @@
+ dataset:
+   video_processor: VideoProcessor
+   bert_name: bert-base-uncased
+   meta_processor: MSRVTTMetaProcessor
+   train_path: data/msrvtt/MSRVTT_train.csv
+   dup: 20
+   val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+   vfeat_dir: data/feat/feat_vtt_s3d
+   text_processor: MSRVTTTextProcessor
+   json_path: data/msrvtt/MSRVTT_data.json
+   aligner: DSAligner
+   num_iso_layer: 12
+   max_video_len: 32
+   max_len: 96
+ fairseq:
+   common:
+     tensorboard_logdir: run
+     log_interval: 1000
+     fp16: true
+   dataset:
+     num_workers: 4
+     batch_size: 128
+   optimization:
+     lr:
+     - 5.0e-05
+     clip_norm: 2.0
+     optimizer: adam
+     adam_betas: (0.9, 0.98)
+     lr_scheduler: polynomial_decay
+     total_num_update: 1000000
+     warmup_updates: 122
+     weight_decay: 0.0
+     ddp_backend: no_c10d
+     max_epoch: 5
+   checkpoint:
+     restore_file: runs/retri/videoclip/checkpoint_best.pt
+     reset_optimizer: true
+     reset_dataloader: true
+     reset_meters: true
+     save_dir: runs/retri/videoclip/vttqa
+ task_type: sweep_small
+ model:
+   model_cls: MMFusionSeparate
+   mm_encoder_cls: null
+   video_encoder_cls: MMBertForEncoder
+   text_encoder_cls: BertModel
+   num_hidden_video_layers: 6
+ loss:
+   loss_cls: V2TContraLoss
fairseq/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml ADDED
@@ -0,0 +1,49 @@
+ dataset:
+   video_processor: YoucookVideoProcessor
+   bert_name: bert-base-uncased
+   meta_processor: YoucookMetaProcessor
+   train_path: data/youcook/youcook_train.pkl
+   val_path: data/youcook/youcook_val.pkl
+   trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+   use_annotation_text: true
+   vfeat_dir: data/feat/feat_youcook_s3d
+   text_processor: TextProcessor
+   aligner: DSAligner
+   num_iso_layer: 12
+   max_video_len: 32
+   max_len: 96
+ fairseq:
+   common:
+     tensorboard_logdir: run
+     log_interval: 1000
+     fp16: true
+   dataset:
+     num_workers: 4
+     batch_size: 128
+   optimization:
+     lr:
+     - 5.0e-05
+     clip_norm: 2.0
+     optimizer: adam
+     adam_betas: (0.9, 0.98)
+     lr_scheduler: polynomial_decay
+     total_num_update: 1000000
+     warmup_updates: 122
+     weight_decay: 0.0
+     ddp_backend: no_c10d
+     max_epoch: 10
+   checkpoint:
+     restore_file: runs/retri/videoclip/checkpoint_best.pt
+     reset_optimizer: true
+     reset_dataloader: true
+     reset_meters: true
+     save_dir: runs/retri/videoclip/youcook
+ task_type: sweep_small
+ model:
+   model_cls: MMFusionSeparate
+   mm_encoder_cls: null
+   video_encoder_cls: MMBertForEncoder
+   text_encoder_cls: BertModel
+   num_hidden_video_layers: 6
+ loss:
+   loss_cls: T2VContraLoss
fairseq/examples/MMPT/projects/task/coin.yaml ADDED
@@ -0,0 +1,25 @@
+ includes: projects/task/ft.yaml
+ task_type: sweep_big
+ dataset:
+   meta_processor: COINActionSegmentationMetaProcessor
+   train_path: data/coin/COIN.json
+   val_path: data/coin/COIN.json
+   vfeat_dir: data/feat/feat_coin_s3d
+   video_processor: VideoProcessor
+   text_processor: COINActionSegmentationTextProcessor
+   aligner: COINActionSegmentationAligner
+   num_iso_layer: 12
+   sliding_window: 8
+   sliding_window_size: 32
+ model:
+   model_cls: MMFusionActionSegmentation
+   mm_encoder_cls: MMBertForTokenClassification
+ loss:
+   loss_cls: CrossEntropy
+ fairseq:
+   dataset:
+     batch_size: 1
+   optimization:
+     max_epoch: 8
+   checkpoint:
+     save_dir: runs/task/coin
fairseq/examples/MMPT/projects/task/coin_videoclip.yaml ADDED
@@ -0,0 +1,7 @@
+ includes: projects/task/coin.yaml
+ model:
+   model_cls: MMFusionSeparateActionSegmentation
+   mm_encoder_cls:
+   video_encoder_cls: MMBertForTokenClassification
+   text_encoder_cls: BertModel # dummy, not used.
+   num_hidden_video_layers: 6
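The `*_videoclip.yaml` files in this commit are thin overrides layered on a base config through the `includes` key. A minimal sketch of how such an override can be resolved outside MMPT's own loader (this is not MMPT's code; it assumes `omegaconf` is installed and paths are relative to `fairseq/examples/MMPT`):

```python
# Hypothetical resolver for an `includes:` override such as coin_videoclip.yaml:
# load the child, load the base it points at, and merge the child on top.
from omegaconf import OmegaConf

child = OmegaConf.to_container(OmegaConf.load("projects/task/coin_videoclip.yaml"))
base = OmegaConf.load(child.pop("includes"))           # projects/task/coin.yaml
merged = OmegaConf.merge(base, OmegaConf.create(child))
print(merged.model.model_cls)     # MMFusionSeparateActionSegmentation (override)
print(merged.dataset.train_path)  # data/coin/COIN.json (inherited from the base)
```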
fairseq/examples/MMPT/projects/task/test.yaml ADDED
@@ -0,0 +1,13 @@
+ # this yaml cannot be run alone: implement a test_${dataset}.yaml
+ slurm_config: big
+ task_type: local_predict
+ dataset:
+   split: test
+   video_processor: VideoProcessor
+   aligner: DSAligner
+   bert_name: bert-base-uncased
+ fairseq:
+   dataset:
+     batch_size: 256
+     valid_subset: test
+     num_workers: 2
fairseq/examples/MMPT/projects/task/test_vtt.yaml ADDED
@@ -0,0 +1,19 @@
+ includes: projects/task/test.yaml
+ dataset:
+   meta_processor: MSRVTTMetaProcessor
+   test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+   video_processor: VideoProcessor
+   vfeat_dir: data/feat/feat_vtt_s3d
+   text_processor: MSRVTTTextProcessor
+   num_iso_layer: 12
+ model:
+   model_cls: MMFusionJoint
+   mm_encoder_cls: MMBertForJoint
+ eval:
+   save_path: runs/task/vtt/eval
+ fairseq:
+   # read code and find what is the checkpoint arg.
+   common_eval:
+     path: runs/task/vtt/checkpoint_last.pt
+ metric: RetrievalMetric
+ predictor: RetrievalPredictor
fairseq/examples/MMPT/projects/task/test_youcook.yaml ADDED
@@ -0,0 +1,22 @@
+ includes: projects/task/test.yaml
+ dataset:
+   meta_processor: YoucookMetaProcessor
+   test_path: data/youcook/youcook_val.pkl
+   trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+   use_annotation_text: True
+   video_processor: YoucookVideoProcessor
+   vfeat_dir: data/feat/feat_youcook_s3d # /checkpoint/huxu/feat/youcook_vmz # /checkpoint/prarora/berniehuang/feat_youcook_vmz
+   text_processor: TextProcessor
+   aligner: DSAligner
+   num_iso_layer: 12
+ model:
+   model_cls: MMFusionJoint
+   mm_encoder_cls: MMBertForJoint
+ eval:
+   save_path: runs/task/youcook/eval
+ fairseq:
+   # read code and find what is the checkpoint arg.
+   common_eval:
+     path: runs/task/youcook/checkpoint_last.pt
+ metric: RetrievalMetric
+ predictor: RetrievalPredictor
fairseq/examples/MMPT/projects/task/test_youcookcap.yaml ADDED
@@ -0,0 +1,23 @@
+ includes: projects/task/test.yaml
+ dataset:
+   meta_processor: YoucookNLGMetaProcessor
+   test_path: data/youcook/val_list.txt
+   trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+   video_processor: YoucookVideoProcessor
+   vfeat_dir: data/feat/feat_youcook_s3d
+   text_processor: NLGTextProcessor
+   aligner: DSNLGAligner
+ model:
+   model_cls: MMFusionNLG
+   mm_encoder_cls: MMBertForNLG
+   max_decode_length: 24
+ eval:
+   save_path: runs/task/youcookcap/eval
+ fairseq:
+   # read code and find what is the checkpoint arg.
+   common_eval:
+     path: runs/task/youcookcap/checkpoint_best.pt
+ metric: NLGMetric
+ predictor: NLGPredictor
+ gen_param:
+   num_beams: 5
fairseq/examples/MMPT/projects/task/vtt.yaml ADDED
@@ -0,0 +1,25 @@
+ includes: projects/task/ft.yaml
+ dataset:
+   meta_processor: MSRVTTMetaProcessor
+   train_path: data/msrvtt/MSRVTT_train.csv
+   jsfusion_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+   full_test_path: data/msrvtt/MSRVTT_FULL_test.csv
+   dup: 20
+   val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
+   vfeat_dir: data/feat/feat_vtt_s3d
+   text_processor: MSRVTTTextProcessor
+   json_path: data/msrvtt/MSRVTT_data.json
+   aligner: DSAligner
+   num_iso_layer: 12
+ model:
+   model_cls: MMFusionJoint
+   mm_encoder_cls: MMBertForJoint
+ loss:
+   loss_cls: T2VContraLoss
+ fairseq:
+   dataset:
+     batch_size: 256
+   optimization:
+     max_epoch: 10
+   checkpoint:
+     save_dir: runs/task/vtt
fairseq/examples/MMPT/projects/task/vtt_videoclip.yaml ADDED
@@ -0,0 +1,12 @@
+ includes: projects/task/vtt.yaml
+ model:
+   model_cls: MMFusionSeparate
+   mm_encoder_cls:
+   video_encoder_cls: MMBertForEncoder
+   text_encoder_cls: BertModel
+   num_hidden_video_layers: 6
+ fairseq:
+   dataset:
+     batch_size: 224
+ # model_cls: MMFusionShare
+ # mm_encoder_cls: MMBertForEncoder
fairseq/examples/MMPT/projects/task/vttqa_videoclip.yaml ADDED
@@ -0,0 +1,10 @@
+ includes: projects/task/vttqa.yaml
+ model:
+   model_cls: MMFusionSeparate
+   mm_encoder_cls:
+   video_encoder_cls: MMBertForEncoder
+   text_encoder_cls: BertModel
+   num_hidden_video_layers: 6
+
+ # model_cls: MMFusionShare
+ # mm_encoder_cls: MMBertForEncoder
fairseq/examples/MMPT/projects/task/youcook.yaml ADDED
@@ -0,0 +1,25 @@
+ includes: projects/task/ft.yaml
+ dataset:
+   meta_processor: YoucookMetaProcessor
+   train_path: data/youcook/youcook_train.pkl
+   val_path: data/youcook/youcook_val.pkl
+   trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+   use_annotation_text: True
+   video_processor: YoucookVideoProcessor
+   vfeat_dir: data/feat/feat_youcook_s3d # /checkpoint/huxu/feat/youcook_vmz # /checkpoint/prarora/berniehuang/feat_youcook_vmz
+   text_processor: TextProcessor
+   aligner: DSAligner
+   num_iso_layer: 12
+ model:
+   model_cls: MMFusionJoint
+   mm_encoder_cls: MMBertForJoint
+ loss:
+   loss_cls: T2VContraLoss
+ fairseq:
+   dataset:
+     batch_size: 128
+   optimization:
+     max_epoch: 10
+   checkpoint:
+     save_dir: runs/task/youcook
+
fairseq/examples/MMPT/projects/task/youcook_videoclip.yaml ADDED
@@ -0,0 +1,9 @@
+ includes: projects/task/youcook.yaml
+ model:
+   model_cls: MMFusionSeparate
+   mm_encoder_cls:
+   video_encoder_cls: MMBertForEncoder
+   text_encoder_cls: BertModel
+   num_hidden_video_layers: 6
+ # model_cls: MMFusionShare
+ # mm_encoder_cls: MMBertForEncoder
fairseq/examples/MMPT/projects/task/youcookcap.yaml ADDED
@@ -0,0 +1,23 @@
+ # finetuning for youcook captioning.
+ includes: projects/task/ft.yaml
+ dataset:
+   meta_processor: YoucookNLGMetaProcessor
+   train_path: data/youcook/train_list.txt
+   val_path: data/youcook/val_list.txt
+   trainval_annotation: data/youcook/youcookii_annotations_trainval.json
+   video_processor: YoucookVideoProcessor
+   vfeat_dir: data/feat/feat_youcook_s3d
+   text_processor: NLGTextProcessor
+   aligner: DSNLGAligner
+ model:
+   model_cls: MMFusionNLG
+   mm_encoder_cls: MMBertForNLG
+ loss:
+   loss_cls: NLGLoss
+ fairseq:
+   dataset:
+     batch_size: 128
+   optimization:
+     max_epoch: 10
+   checkpoint:
+     save_dir: runs/task/youcookcap
fairseq/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml ADDED
@@ -0,0 +1,5 @@
+ dataset:
+   bert_name: bert-base-uncased
+   caption_pkl_path: data/how2/raw_caption_dedup.pkl
+   use_fast: true
+   target_dir: data/feat/feat_how2_s3d_shard_small
fairseq/examples/MMPT/scripts/text_token_extractor/pretokenization.py ADDED
@@ -0,0 +1,106 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import pickle
+ import os
+ import argparse
+ import numpy as np
+
+ from torch.utils.data import Dataset, DataLoader
+ from mmpt.processors import PKLJSONStrTextProcessor
+ from mmpt.utils import ShardedTensor, recursive_config
+
+
+ class TokenizerDataset(Dataset):
+     def __init__(self, config):
+         self.text_processor = PKLJSONStrTextProcessor(config)
+         self.video_ids = list(self.text_processor.data.keys())
+
+     def __getitem__(self, idx):
+         video_id = self.video_ids[idx]
+         return video_id, self.text_processor(video_id)
+
+     def __len__(self):
+         return len(self.video_ids)
+
+
+ def numpify(shard_idx, video_ids, captions, target_dir, split, prefix, max_cap_len=32):
+     startends = []
+     caps_ids = []
+     for video_id in video_ids:
+         caption = captions[video_id]
+         startend = []
+         cap_ids = []
+         for start, end, cap in zip(
+                 caption["start"], caption["end"], caption["cap"]):
+             startend.append(np.array([start, end]).astype("float32"))
+             cap_id = np.full((max_cap_len,), -1, dtype=np.int32)
+             cap = cap[:max_cap_len]
+             cap_id[:len(cap)] = cap
+             cap_ids.append(cap_id)
+         startends.append(np.stack(startend))
+         caps_ids.append(np.stack(cap_ids))
+
+     startends = ShardedTensor.from_list(startends)
+     target_path = os.path.join(
+         target_dir,
+         prefix + split + "_" + str(shard_idx)
+     )
+     print("save to", target_path)
+     startends.save(target_path + ".startends")
+     caps_ids = ShardedTensor.from_list(caps_ids)
+     caps_ids.save(target_path + ".caps_ids")
+
+
+ def sharding(config, out_file):
+     with open(out_file, "rb") as fr:
+         captions = pickle.load(fr)
+     target_dir = config.target_dir
+     prefix = os.path.basename(
+         os.path.splitext(config.caption_pkl_path)[0]
+     ) + "." + config.bert_name + "."
+     for split in ["train", "val"]:
+         target_path = os.path.join(target_dir, split + "_meta")
+         with open(target_path + ".pkl", "rb") as fr:
+             meta = pickle.load(fr)
+         print("load meta", target_path, len(meta))
+         for shard_id in meta:
+             numpify(
+                 shard_id, meta[shard_id], captions,
+                 target_dir, split, prefix
+             )
+
+
+ def tokenize(config, out_file):
+     def collator(samples):
+         return samples
+     dataset = TokenizerDataset(config)
+     data = {}
+     for idx, batch in enumerate(
+             DataLoader(dataset, collate_fn=collator, num_workers=16)):
+         for video_id, caption in batch:
+             data[video_id] = caption
+         if idx % 5000 == 0:
+             print(idx)
+     with open(out_file, "wb") as fw:
+         pickle.dump(data, fw, pickle.HIGHEST_PROTOCOL)
+
+
+ def main(args):
+     config = recursive_config(args.config).dataset
+
+     out_file = os.path.splitext(config.caption_pkl_path)[0] \
+         + "." + config.bert_name + ".pkl"
+     if not os.path.isfile(out_file):
+         tokenize(config, out_file)
+     sharding(config, out_file)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="pretokenize (raw_)caption.json into pkl.")
+     parser.add_argument('config', type=str)
+     args = parser.parse_args()
+     main(args)
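The script above takes a single positional config argument. A sketch of driving it with the `bert-base-uncased.yaml` added earlier in this commit (paths assumed relative to `fairseq/examples/MMPT`):

```python
# Sketch: invoke the pretokenizer exactly as its argparse interface expects,
# pointing it at the text_token_extractor config from this commit.
import subprocess

subprocess.run(
    [
        "python",
        "scripts/text_token_extractor/pretokenization.py",
        "scripts/text_token_extractor/configs/bert-base-uncased.yaml",
    ],
    check=True,
)
```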
fairseq/examples/MMPT/scripts/video_feature_extractor/extract.py ADDED
@@ -0,0 +1,157 @@
+ # Copyright Howto100M authors.
+ # Copyright (c) Facebook, Inc. All Rights Reserved
+
+ import torch as th
+ import torch.nn.functional as F
+ import math
+ import numpy as np
+ import argparse
+
+ from torch.utils.data import DataLoader
+ from model import get_model
+ from preprocessing import Preprocessing
+ from random_sequence_shuffler import RandomSequenceSampler
+
+ from tqdm import tqdm
+ from pathbuilder import PathBuilder
+ from videoreader import VideoLoader
+
+
+ parser = argparse.ArgumentParser(description='Easy video feature extractor')
+
+ parser.add_argument('--vdir', type=str)
+ parser.add_argument('--fdir', type=str)
+ parser.add_argument('--hflip', type=int, default=0)
+
+ parser.add_argument('--batch_size', type=int, default=64,
+                     help='batch size')
+ parser.add_argument('--type', type=str, default='2d',
+                     help='CNN type')
+ parser.add_argument('--half_precision', type=int, default=0,
+                     help='output half precision float')
+ parser.add_argument('--num_decoding_thread', type=int, default=4,
+                     help='Num parallel thread for video decoding')
+ parser.add_argument('--l2_normalize', type=int, default=1,
+                     help='l2 normalize feature')
+ parser.add_argument('--resnext101_model_path', type=str, default='model/resnext101.pth',
+                     help='Resnext model path')
+ parser.add_argument('--vmz_model_path', type=str, default='model/r2plus1d_34_clip8_ig65m_from_scratch-9bae36ae.pth',
+                     help='vmz model path')
+
+ args = parser.parse_args()
+
+
+ # TODO: refactor all args into config. (current code is from different people.)
+ CONFIGS = {
+     "2d": {
+         "fps": 1,
+         "size": 224,
+         "centercrop": False,
+         "shards": 0,
+     },
+     "3d": {
+         "fps": 24,
+         "size": 112,
+         "centercrop": True,
+         "shards": 0,
+     },
+     "s3d": {
+         "fps": 30,
+         "size": 224,
+         "centercrop": True,
+         "shards": 0,
+     },
+     "vmz": {
+         "fps": 24,
+         "size": 112,
+         "centercrop": True,
+         "shards": 0,
+     },
+     "vae": {
+         "fps": 2,
+         "size": 256,
+         "centercrop": True,
+         "shards": 100,
+     }
+ }
+
+ config = CONFIGS[args.type]
+
+
+ video_dirs = args.vdir
+ feature_dir = args.fdir
+
+ video_dict = PathBuilder.build(video_dirs, feature_dir, ".npy", config["shards"])
+
+ dataset = VideoLoader(
+     video_dict=video_dict,
+     framerate=config["fps"],
+     size=config["size"],
+     centercrop=config["centercrop"],
+     hflip=args.hflip
+ )
+ n_dataset = len(dataset)
+ sampler = RandomSequenceSampler(n_dataset, 10)
+ loader = DataLoader(
+     dataset,
+     batch_size=1,
+     shuffle=False,
+     num_workers=args.num_decoding_thread,
+     sampler=sampler if n_dataset > 10 else None,
+ )
+ preprocess = Preprocessing(args.type)
+ model = get_model(args)
+
+ with th.no_grad():
+     for k, data in tqdm(enumerate(loader), total=loader.__len__(), ascii=True):
+         input_file = data['input'][0]
+         output_file = data['output'][0]
+         if len(data['video'].shape) > 3:
+             video = data['video'].squeeze()
+             if len(video.shape) == 4:
+                 video = preprocess(video)
+                 n_chunk = len(video)
+                 if args.type == 'vmz':
+                     n_chunk = math.ceil(n_chunk/float(3))
+                     features = th.cuda.FloatTensor(n_chunk, 512).fill_(0)
+                 elif args.type == 's3d':
+                     features = th.cuda.FloatTensor(n_chunk, 512).fill_(0)
+                 elif args.type == "vae":
+                     features = th.cuda.LongTensor(n_chunk, 1024).fill_(0)
+                 else:
+                     features = th.cuda.FloatTensor(n_chunk, 2048).fill_(0)
+                 n_iter = int(math.ceil(n_chunk / float(args.batch_size)))
+                 for i in range(n_iter):
+                     factor = 1
+                     if args.type == 'vmz':
+                         factor = 3
+                     min_ind = factor * i * args.batch_size
+                     max_ind = factor * (i + 1) * args.batch_size
+                     video_batch = video[min_ind:max_ind:factor].cuda()
+                     if args.type == '2d':
+                         batch_features = model(video_batch)  # (51, 487), (51, 512)
+                     elif args.type == 's3d':
+                         batch_features = model(video_batch)
+                         batch_features = batch_features['video_embedding']
+                     elif args.type == "vae":
+                         # image_code.
+                         batch_features = model(video_batch)
+                     else:
+                         batch_pred, batch_features = model(video_batch)  # (51, 487), (51, 512)
+                     if args.l2_normalize:
+                         batch_features = F.normalize(batch_features, dim=1)
+                     features[i*args.batch_size:(i+1)*args.batch_size] = batch_features
+                 features = features.cpu().numpy()
+                 if args.half_precision:
+                     if args.type == "vae":
+                         features = features.astype(np.int16)
+                     else:
+                         features = features.astype('float16')
+                 else:
+                     if args.type == "vae":
+                         features = features.astype(np.int32)
+                     else:
+                         features = features.astype('float32')
+                 np.save(output_file, features)
+         else:
+             print('Video {} error.'.format(input_file))
fairseq/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh ADDED
@@ -0,0 +1,8 @@
+ #!/bin/bash
+
+
+ python scripts/video_feature_extractor/extract.py \
+   --vdir <path_to_video_folder> \
+   --fdir data/feat/feat_how2_s3d \
+   --type=s3d --num_decoding_thread=4 \
+   --batch_size 32 --half_precision 1
fairseq/examples/MMPT/scripts/video_feature_extractor/model.py ADDED
@@ -0,0 +1,58 @@
+ # Copyright (c) Howto100M authors and Facebook, Inc. All Rights Reserved
+
+ import torch as th
+
+ from torch import nn
+
+
+ class GlobalAvgPool(nn.Module):
+     def __init__(self):
+         super(GlobalAvgPool, self).__init__()
+
+     def forward(self, x):
+         return th.mean(x, dim=[-2, -1])
+
+
+ def get_model(args):
+     assert args.type in ['2d', '3d', 'vmz', 's3d', 'vae']
+     if args.type == '2d':
+         print('Loading 2D-ResNet-152 ...')
+         import torchvision.models as models
+         model = models.resnet152(pretrained=True)
+         model = nn.Sequential(*list(model.children())[:-2], GlobalAvgPool())
+         model = model.cuda()
+     elif args.type == 'vmz':
+         print('Loading VMZ ...')
+         from vmz34 import r2plus1d_34
+         model = r2plus1d_34(pretrained_path=args.vmz_model_path, pretrained_num_classes=487)
+         model = model.cuda()
+     elif args.type == 's3d':
+         # we use one copy of s3d instead of dup another one for feature extraction.
+         from mmpt.processors.models.s3dg import S3D
+         model = S3D('pretrained_models/s3d_dict.npy', 512)
+         model.load_state_dict(th.load('pretrained_models/s3d_howto100m.pth'))
+         model = model.cuda()
+
+     elif args.type == '3d':
+         print('Loading 3D-ResneXt-101 ...')
+         from videocnn.models import resnext
+         model = resnext.resnet101(
+             num_classes=400,
+             shortcut_type='B',
+             cardinality=32,
+             sample_size=112,
+             sample_duration=16,
+             last_fc=False)
+         model = model.cuda()
+         model_data = th.load(args.resnext101_model_path)
+         model.load_state_dict(model_data)
+     elif args.type == 'vae':
+         from openaivae import OpenAIParallelDiscreteVAE
+         model = OpenAIParallelDiscreteVAE()
+         model = model.cuda()
+     else:
+         raise ValueError("model not supported yet.")
+
+     model.eval()
+     print('loaded')
+     return model
fairseq/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ import os
+ import urllib.parse
+ import json
+ import pandas as pd
+
+ from tqdm import tqdm
+
+
+ # TODO: extending to other datasets.
+ supported_formats = {}
+
+
+ class PathBuilder(object):
+     @classmethod
+     def build(cls, video_dirs, feature_dir, ext, shards=0, split=None):
+         meta_fn = os.path.join(feature_dir, "meta_plan.json")
+         os.makedirs(feature_dir, exist_ok=True)
+         if os.path.isfile(meta_fn):
+             with open(meta_fn) as fr:
+                 meta = json.load(fr)
+             return meta
+         print("searching videos...")
+
+         video_id_to_path = {}
+         for video_dir in video_dirs.split(","):
+             # TODO: add supports of recursive listdir.
+             if video_dir in supported_formats:
+                 supported_formats[video_dir].load(video_dir, video_id_to_path)
+             else:
+                 for idx, fn in enumerate(tqdm(os.listdir(video_dir))):
+                     video_fn = os.path.join(video_dir, fn)
+                     if os.path.isfile(video_fn):
+                         video_id = os.path.splitext(fn)[0]
+                         video_id_to_path[video_id] = video_fn
+                     elif os.path.isdir(video_fn):
+                         # shards of folders.
+                         shard_dir = video_fn
+                         for idx, fn in enumerate(os.listdir(shard_dir)):
+                             video_fn = os.path.join(shard_dir, fn)
+                             if os.path.isfile(video_fn):
+                                 video_id = os.path.splitext(fn)[0]
+                                 video_id_to_path[video_id] = video_fn
+
+         video_path, feature_path = [], []
+         valid_ext = set()
+         for idx, video_id in enumerate(video_id_to_path):
+             video_path.append(video_id_to_path[video_id])
+             if ext is None:
+                 # use original file ext for format compatibility.
+                 video_id_to_path[video_id]
+                 path = urllib.parse.urlparse(video_id_to_path[video_id]).path
+                 ext = os.path.splitext(path)[1]
+             if ext not in valid_ext:
+                 valid_ext.add(ext)
+                 print("adding", ext)
+             if shards:
+                 shard_id = str(idx % shards)
+                 feature_fn = os.path.join(
+                     feature_dir, shard_id, video_id + ext)
+             else:
+                 feature_fn = os.path.join(
+                     feature_dir, video_id + ext)
+             feature_path.append(feature_fn)
+
+         print("targeting", len(feature_path), "videos")
+         meta = {
+             "video_path": video_path, "feature_path": feature_path}
+         with open(meta_fn, "w") as fw:
+             json.dump(meta, fw)
+
+         if split is not None:
+             splits = split.split("/")
+             assert len(splits) == 2
+             cur, total = int(splits[0]), int(splits[1])
+             assert cur < total
+             import math
+             chunk = math.ceil(len(meta["video_path"]) / total)
+             start = cur * chunk
+             end = (cur + 1) * chunk
+             meta = {
+                 "video_path": meta["video_path"][start:end],
+                 "feature_path": meta["feature_path"][start:end]
+             }
+
+         return meta
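A sketch of calling `PathBuilder` directly, mirroring how `extract.py` uses it below; the input video directories are placeholders, not real paths:

```python
# Sketch mirroring extract.py: plan output .npy paths for every video found
# under the (hypothetical) input directories; the plan is cached as
# meta_plan.json inside feature_dir on the first call.
from pathbuilder import PathBuilder

meta = PathBuilder.build(
    "<path_to_video_folder_a>,<path_to_video_folder_b>",  # comma-separated dirs
    "data/feat/feat_how2_s3d",
    ".npy",
    shards=0,
)
print(len(meta["video_path"]), "videos planned")
```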
fairseq/examples/MMPT/scripts/video_feature_extractor/preprocessing.py ADDED
@@ -0,0 +1,57 @@
+ # Copyright Howto100m authors.
+ # Copyright (c) Facebook, Inc. All Rights Reserved
+
+ import torch as th
+
+ class Normalize(object):
+
+     def __init__(self, mean, std):
+         self.mean = th.FloatTensor(mean).view(1, 3, 1, 1)
+         self.std = th.FloatTensor(std).view(1, 3, 1, 1)
+
+     def __call__(self, tensor):
+         tensor = (tensor - self.mean) / (self.std + 1e-8)
+         return tensor
+
+ class Preprocessing(object):
+
+     def __init__(self, type):
+         self.type = type
+         if type == '2d':
+             self.norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         elif type == '3d':
+             self.norm = Normalize(mean=[110.6, 103.2, 96.3], std=[1.0, 1.0, 1.0])
+         elif type == 'vmz':
+             self.norm = Normalize(mean=[110.201, 100.64, 95.997], std=[58.1489, 56.4701, 55.3324])
+
+     def _zero_pad(self, tensor, size):
+         n = size - len(tensor) % size
+         if n == size:
+             return tensor
+         else:
+             z = th.zeros(n, tensor.shape[1], tensor.shape[2], tensor.shape[3])
+             return th.cat((tensor, z), 0)
+
+     def __call__(self, tensor):
+         if self.type == '2d':
+             tensor = tensor / 255.0
+             tensor = self.norm(tensor)
+         elif self.type == 'vmz':
+             #tensor = self._zero_pad(tensor, 8)
+             tensor = self._zero_pad(tensor, 10)
+             tensor = self.norm(tensor)
+             #tensor = tensor.view(-1, 8, 3, 112, 112)
+             tensor = tensor.view(-1, 10, 3, 112, 112)
+             tensor = tensor.transpose(1, 2)
+         elif self.type == '3d':
+             tensor = self._zero_pad(tensor, 16)
+             tensor = self.norm(tensor)
+             tensor = tensor.view(-1, 16, 3, 112, 112)
+             tensor = tensor.transpose(1, 2)
+         elif self.type == 's3d':
+             tensor = tensor / 255.0
+             tensor = self._zero_pad(tensor, 30)
+             tensor = tensor.view(-1, 30, 3, 224, 224)  # N x 30 x 3 x H x W
+             tensor = tensor.transpose(1, 2)  # N x 3 x 30 x H x W
+         # for vae do nothing
+         return tensor
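A small shape check of the s3d branch above (a sketch; random frames stand in for decoded video):

```python
# The s3d path zero-pads the frame count to a multiple of 30 and regroups
# frames into clips: (T, 3, 224, 224) -> (N, 3, 30, 224, 224).
import torch as th
from preprocessing import Preprocessing

frames = th.randint(0, 256, (45, 3, 224, 224)).float()  # fake decoded frames
clips = Preprocessing("s3d")(frames)                     # pads 45 -> 60 frames
print(tuple(clips.shape))                                # (2, 3, 30, 224, 224)
```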
fairseq/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright (c) Facebook, Inc. All Rights Reserved
+
+ import numpy as np
+
+ from torch.utils.data.sampler import Sampler
+
+
+ class RandomSequenceSampler(Sampler):
+
+     def __init__(self, n_sample, seq_len):
+         self.n_sample = n_sample
+         self.seq_len = seq_len
+
+     def _pad_ind(self, ind):
+         zeros = np.zeros(self.seq_len - self.n_sample % self.seq_len)
+         ind = np.concatenate((ind, zeros))
+         return ind
+
+     def __iter__(self):
+         idx = np.arange(self.n_sample)
+         if self.n_sample % self.seq_len != 0:
+             idx = self._pad_ind(idx)
+         idx = np.reshape(idx, (-1, self.seq_len))
+         np.random.shuffle(idx)
+         idx = np.reshape(idx, (-1))
+         return iter(idx.astype(int))
+
+     def __len__(self):
+         return self.n_sample + (self.seq_len - self.n_sample % self.seq_len)
fairseq/examples/MMPT/scripts/video_feature_extractor/shard_feature.py ADDED
@@ -0,0 +1,64 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ import numpy as np
+ import os
+ import pickle
+
+ from mmpt.utils import ShardedTensor
+
+
+ class Shard(object):
+     def __init__(
+         self,
+         vfeat_dir,
+         tfeat_dir,
+         target_dir,
+         file_paths,
+         shard_size=4096
+     ):
+         self.vfeat_dir = vfeat_dir
+         self.tfeat_dir = tfeat_dir
+         self.target_dir = target_dir
+         self.video_ids = {}
+         for split, file_path in zip(["train", "val"], file_paths):
+             with open(file_path) as fr:
+                 self.video_ids[split] = [
+                     line.strip() for line in fr.readlines()]
+         self.shard_size = shard_size
+
+     def __call__(self, split="train"):
+         for split in ["train", "val"]:
+             meta = {}
+             for shard_idx, shard_offset in enumerate(
+                 range(0, len(self.video_ids[split]), self.shard_size)
+             ):
+                 print(shard_idx)
+                 meta_shard = []
+                 video_shard = []
+                 for video_id in self.video_ids[split][shard_offset:shard_offset+self.shard_size]:
+                     meta_shard.append(video_id)
+                     npy_file = os.path.join(self.vfeat_dir, video_id + ".npy")
+                     video_shard.append(np.load(npy_file))
+
+                 meta[shard_idx] = meta_shard
+                 video_shard = ShardedTensor.from_list(video_shard)
+                 target_path = os.path.join(
+                     self.target_dir, split + "_" + str(shard_idx))
+                 video_shard.save(target_path)
+
+             target_path = os.path.join(self.target_dir, split + "_meta")
+             with open(target_path + ".pkl", "wb") as fw:
+                 pickle.dump(meta, fw, pickle.HIGHEST_PROTOCOL)
+
+
+ if __name__ == "__main__":
+     shard = Shard(
+         "data/feat/feat_how2_s3d",
+         "data/how2/raw_caption_dedup.bert-base-uncased",
+         "data/feat/feat_how2_s3d_shard_small",
+         ["data/how2/how2_s3d_train.lst", "data/how2/how2_s3d_val.lst"]
+     )
+
+     shard()
fairseq/examples/MMPT/scripts/video_feature_extractor/videoreader.py ADDED
@@ -0,0 +1,242 @@
+ # Copyright Howto100M authors.
+ # Copyright (c) Facebook, Inc. All Rights Reserved
+
+ import torch as th
+ import pandas as pd
+ import os
+ import numpy as np
+ import ffmpeg
+ import random
+
+ from torch.utils.data import Dataset
+
+
+ class VideoLoader(Dataset):
+     """modified from how2's video_feature_extractor."""
+     def __init__(
+         self,
+         csv=None,
+         video_dict=None,
+         framerate=1,
+         size=112,
+         centercrop=False,
+         hflip=False,
+         **kwargs
+     ):
+         if csv is None and video_dict is None:
+             raise ValueError("csv and video_dict cannot be both None.")
+         if csv is not None:
+             self.csv = pd.read_csv(csv)
+         if video_dict is not None:
+             self.csv = pd.DataFrame.from_dict(video_dict)
+
+         self.centercrop = centercrop
+         self.size = size
+         self.framerate = framerate
+         self.hflip = hflip
+
+     def __len__(self):
+         return len(self.csv)
+
+     def _get_video_dim(self, video_path):
+         probe = ffmpeg.probe(video_path)
+         video_stream = next((stream for stream in probe['streams']
+                              if stream['codec_type'] == 'video'), None)
+         width = int(video_stream['width'])
+         height = int(video_stream['height'])
+         return height, width
+
+     def _get_video_info(self, video_path):
+         probe = ffmpeg.probe(video_path)
+         video_stream = next((stream for stream in probe['streams']
+                              if stream['codec_type'] == 'video'), None)
+         return video_stream
+
+     def _get_output_dim(self, h, w):
+         if isinstance(self.size, tuple) and len(self.size) == 2:
+             return self.size
+         elif h >= w:
+             return int(h * self.size / w), self.size
+         else:
+             return self.size, int(w * self.size / h)
+
+     def __getitem__(self, idx):
+         video_path = self.csv['video_path'].values[idx]
+         output_file = self.csv['feature_path'].values[idx]
+         return self._decode(output_file, video_path)
+
+     def _decode(self, output_file, video_path):
+         if not(os.path.isfile(output_file)) and os.path.isfile(video_path):
+             try:
+                 h, w = self._get_video_dim(video_path)
+             except Exception:
+                 print('ffprobe failed at: {}'.format(video_path))
+                 return {'video': th.zeros(1), 'input': video_path,
+                         'output': output_file}
+             try:
+                 os.makedirs(os.path.dirname(output_file), exist_ok=True)
+                 height, width = self._get_output_dim(h, w)
+
+                 cmd = (
+                     ffmpeg
+                     .input(video_path)
+                     .filter('fps', fps=self.framerate)
+                     .filter('scale', width, height)
+                 )
+                 if self.hflip:
+                     cmd = cmd.filter('hflip')
+
+                 if self.centercrop:
+                     x = int((width - self.size) / 2.0)
+                     y = int((height - self.size) / 2.0)
+                     cmd = cmd.crop(x, y, self.size, self.size)
+                 video = self._run(cmd, output_file)
+             except Exception:
+                 video = th.zeros(1)
+         else:
+             video = th.zeros(1)
+
+         return {'video': video, 'input': video_path, 'output': output_file}
+
+     def _run(self, cmd, output_file):
+         out, _ = (
+             cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
+             .run(capture_stdout=True, quiet=True)
+         )
+         if self.centercrop and isinstance(self.size, int):
+             height, width = self.size, self.size
+         video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
+         video = th.from_numpy(video.astype('float32'))
+         return video.permute(0, 3, 1, 2)
+
+
+ class VideoVerifier(VideoLoader):
+     def __getitem__(self, idx):
+         video_path = self.csv['video_path'].values[idx]
+         try:
+             return self._get_video_info(video_path)
+         except Exception:
+             # print('ffprobe failed at: {}'.format(video_path))
+             return None
+
+
+ class VideoCompressor(VideoLoader):
+     def __init__(
+         self,
+         csv=None,
+         video_dict=None,
+         framerate=1,
+         size=112,
+         centercrop=False,
+         hflip=False,
+         crf=32,
+         **kwargs
+     ):
+         super().__init__(
+             csv,
+             video_dict,
+             framerate,
+             size,
+             centercrop,
+             hflip
+         )
+         self.crf = crf
+
+     def _run(self, cmd, output_file):
+         out, _ = (
+             cmd.output(filename=output_file, crf=self.crf)
+             .run(quiet=True)
+         )
+         video = None
+         return video
+
+
+ class VideoDownloader(VideoCompressor):
+     """download"""
+     def __getitem__(self, idx):
+         video_path = self.csv['video_path'].values[idx]
+         output_file = self.csv['feature_path'].values[idx]
+         if not(os.path.isfile(output_file)):
+             os.makedirs(os.path.dirname(output_file), exist_ok=True)
+             cmd = "wget -O" + output_file + " " + video_path
+             # import subprocess
+             # subprocess.check_output(
+             #     cmd,
+             #     stderr=subprocess.STDOUT, shell=True)
+             os.system(cmd)
+         return {'video': None, 'input': video_path, 'output': output_file}
+
+
+ class AvKeyframeVideoCompressor(VideoLoader):
+     """extract keyframes from a video and save them as jpg.
+     TODO: consider to merge with `CodecProcessor`.
+     """
+     def __init__(
+         self,
+         csv=None,
+         video_dict=None,
+         framerate=1,
+         size=112,
+         centercrop=False,
+         max_num_frames=5,
+         **kwargs
+     ):
+         super().__init__(csv, video_dict, framerate, size, centercrop)
+         self.max_num_frames = max_num_frames
+
+     def _get_video_dim(self, video_fn):
+         """decord cannot probe the size of a video, we use pyav instead."""
+         import av
+         with av.open(video_fn) as container:
+             height = container.streams.video[0].codec_context.height
+             width = container.streams.video[0].codec_context.width
+         return height, width
+
+     def _get_output_dim(self, height, width):
+         """
+         keep the shorter side at `self.size` and stretch the other.
+         """
+         if height >= width:
+             return int(height * self.size / width), self.size
+         else:
+             return self.size, int(width * self.size / height)
+
+     def __getitem__(self, idx):
+         import av
+         video_path = self.csv['video_path'].values[idx]
+         output_file = self.csv['feature_path'].values[idx]
+         if not(os.path.isdir(output_file)) and os.path.isfile(video_path):
+             try:
+                 h, w = self._get_video_dim(video_path)
+             except Exception:
+                 print('probe failed at: {}'.format(video_path))
+                 return {'video': th.zeros(1), 'input': video_path,
+                         'output': output_file}
+
+             try:
+                 height, width = self._get_output_dim(h, w)
+
+                 # new for av.
+                 with av.open(video_path) as container:
+                     container.streams.video[0].thread_type = "AUTO"
+                     container.streams.video[0].codec_context.height = height
+                     container.streams.video[0].codec_context.width = width
+                     if self.framerate == 0:  # keyframe.
+                         container.streams.video[0].codec_context.skip_frame = 'NONKEY'
+                     frames = []
+                     for frame in container.decode(video=0):
+                         frames.append(frame)
+                     frames = random.sample(frames, self.max_num_frames)
+
+                 os.makedirs(output_file, exist_ok=True)
+                 for frame in frames:
+                     frame.to_image().save(
+                         os.path.join(
+                             output_file,
+                             "%04d.jpg" % frame.index))
+             except Exception:
+                 print('extract failed at: {}'.format(video_path))
+                 return {'video': th.zeros(1), 'input': video_path,
+                         'output': output_file}
+         video = th.zeros(1)
+         return {'video': video, 'input': video_path, 'output': output_file}
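Besides the `PathBuilder` dict used by `extract.py`, `VideoLoader` also accepts a csv with `video_path`/`feature_path` columns. A sketch (the csv filename is hypothetical):

```python
# Sketch: drive VideoLoader from a csv listing video_path / feature_path pairs.
from torch.utils.data import DataLoader
from videoreader import VideoLoader

loader = DataLoader(
    VideoLoader(csv="videos.csv", framerate=30, size=224, centercrop=True),
    batch_size=1,
    num_workers=2,
)
for batch in loader:
    # 'video' is a (T, 3, H, W) float tensor, or a th.zeros(1) stub on decode failure.
    print(batch["input"][0], tuple(batch["video"][0].shape))
```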
fairseq/examples/MMPT/videoclip.png ADDED

Git LFS Details

  • SHA256: 1d54fe18d1259ade9332e78fdb74f834fdfbdb0b0486517e6a7cd48956b30663
  • Pointer size: 131 Bytes
  • Size of remote file: 386 kB
fairseq/examples/MMPT/vlm.png ADDED

Git LFS Details

  • SHA256: 722852ed6258ac9f7ffd3e3913fa1a370702c4d989ef6d881847432d59ade4e5
  • Pointer size: 131 Bytes
  • Size of remote file: 418 kB
fairseq/examples/adaptive_span/README.md ADDED
@@ -0,0 +1,90 @@
+ # Adaptive Span
+
+ Adaptive Span is a novel self-attention mechanism that can learn its optimal
+ attention span. This allows us to significantly extend the maximum context size
+ used in Transformers, while maintaining control over their memory footprint
+ and computational time. It uses the Truncated BPTT technique for training,
+ as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).
+
+ Adaptive Span was introduced in the paper
+ [Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
+ which achieved state-of-the-art language modeling results at the time of publication.
+
+ We manage to reproduce their result in fairseq and keep most of the
+ [original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
+ You can refer to their sweep file as well if any combination of hyperparameters is not clear.
+
+ ##### 0. Setup
+
+ First you need to process the Enwik8 dataset; we use the pre-tokenized dataset
+ from the [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
+ You can download the dataset, and then run:
+ ```bash
+ fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
+     --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
+     --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
+ ```
+
+ ##### 1. Train an Adaptive Span model on Enwik8
+
+ We will train a 12-layer Adaptive Span model following the [hyperparameters
+ used in the original
+ paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
+
+ The following command assumes 4 GPUs, so that the total batch size is 64
+ sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
+     --user-dir examples/adaptive_span \
+     --data ~/data/enwik8/data-bin/ \
+     --fp16 --fp16-no-flatten-grads --max-update 600000 \
+     --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
+     --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
+     --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
+     --validate-interval-updates 1000 \
+     --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
+     --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
+     --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
+ ```
+ This should land around 1.05 on validation and 1.03 on test. You can lower
+ --aux-loss-scaler for better performance (a longer span); it gives roughly a 0.03 bpc
+ improvement over the transformerXL baseline here.
+ If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
+ and simulate training on 4 GPUs.
+ You can also reproduce the transformerXL result on enwik8 using this code base.
+ It should land around 1.06 on test, matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
+ You can try it with:
+ ```bash
+ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
+     --user-dir examples/truncated_bptt \
+     ~/data/enwik8/data-bin/ \
+     --task truncated_bptt_lm --fp16 --max-update 400000 \
+     --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
+     --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
+     --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
+     --lr-scheduler cosine --warmup-updates 0 \
+     --lr 0.0 --lr 0.00025 --batch-size 15 \
+     --update-freq 1 --seed 2 --log-format json --log-interval 25 \
+     --fp16
+ ```
+
+ ##### 2. Evaluate
+ For Adaptive Span:
+ ```bash
+ fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
+     --user-dir examples/adaptive_span \
+     --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
+ ```
+ For Transformer-XL evaluation:
+ ```bash
+ fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
+     --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
+     --tokens-per-sample 80 \
+     --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
+     --gen-subset valid
+ ```
+
+ *Note:* During training the model saw 512 tokens of context
+ (``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation
+ settings from [the original
+ paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
fairseq/examples/adaptive_span/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import importlib
+ import os
+
+ # automatically import any Python files in the current directory
+ cur_dir = os.path.dirname(__file__)
+ for file in os.listdir(cur_dir):
+     path = os.path.join(cur_dir, file)
+     if (
+         not file.startswith("_")
+         and not file.startswith(".")
+         and (file.endswith(".py") or os.path.isdir(path))
+     ):
+         mod_name = file[: file.find(".py")] if file.endswith(".py") else file
+         module = importlib.import_module(__name__ + "." + mod_name)
fairseq/examples/adaptive_span/adagrad_with_grad_clip.py ADDED
@@ -0,0 +1,128 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from torch.optim import Adagrad
+
+ from fairseq.optim import LegacyFairseqOptimizer, register_optimizer
+
+
+ @register_optimizer("adagrad_with_grad_clip")
+ class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
+     def __init__(self, args, params):
+         super().__init__(args)
+         self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)
+
+     @staticmethod
+     def add_args(parser):
+         """Add optimizer-specific arguments to the parser."""
+         # fmt: off
+         parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
+                             help='weight decay')
+         parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
+                             help='internal grad clip')
+         # fmt: on
+
+     @property
+     def optimizer_config(self):
+         """
+         Return a kwarg dictionary that will be used to override optimizer
+         args stored in checkpoints. This allows us to load a checkpoint and
+         resume training using a different set of optimizer args, e.g., with a
+         different learning rate.
+         """
+         return {
+             "lr": self.args.lr[0],
+             "weight_decay": self.args.weight_decay,
+             "grad_clip": self.args.adagrad_clip,
+         }
+
+     @property
+     def supports_flat_params(self):
+         return False
+
+
+ def _clip_grad(clr, grad, group_grad_clip):
+     if group_grad_clip > 0:
+         norm = grad.norm(2).item()
+         if norm > group_grad_clip:
+             clr *= group_grad_clip / (norm + 1e-10)
+     return clr
+
+
+ class AdagradWithGradClip(Adagrad):
+     """Adagrad algorithm with custom gradient clipping"""
+
+     def __init__(
+         self,
+         params,
+         lr=1e-2,
+         lr_decay=0,
+         weight_decay=0,
+         initial_accumulator_value=0,
+         grad_clip=0,
+     ):
+         Adagrad.__init__(
+             self,
+             params,
+             lr=lr,
+             lr_decay=lr_decay,
+             weight_decay=weight_decay,
+             initial_accumulator_value=initial_accumulator_value,
+         )
+         self.defaults["grad_clip"] = grad_clip
+         self.param_groups[0].setdefault("grad_clip", grad_clip)
+
+     def step(self, closure=None):
+         loss = None
+         if closure is not None:
+             loss = closure()
+
+         for group in self.param_groups:
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+
+                 grad = p.grad.data
+                 state = self.state[p]
+
+                 state["step"] += 1
+
+                 if group["weight_decay"] != 0:
+                     if p.grad.data.is_sparse:
+                         raise RuntimeError(
+                             "weight_decay option is "
+                             "not compatible with sparse "
+                             "gradients"
+                         )
+                     grad = grad.add(group["weight_decay"], p.data)
+
+                 clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])
+
+                 # clip
+                 clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])
+
+                 if grad.is_sparse:
+                     # the update is non-linear so indices must be unique
+                     grad = grad.coalesce()
+                     grad_indices = grad._indices()
+                     grad_values = grad._values()
+                     size = grad.size()
+
+                     def make_sparse(values):
+                         constructor = grad.new
+                         if grad_indices.dim() == 0 or values.dim() == 0:
+                             return constructor().resize_as_(grad)
+                         return constructor(grad_indices, values, size)
+
+                     state["sum"].add_(make_sparse(grad_values.pow(2)))
+                     std = state["sum"]._sparse_mask(grad)
+                     std_values = std._values().sqrt_().add_(1e-10)
+                     p.data.add_(-clr, make_sparse(grad_values / std_values))
+                 else:
+                     state["sum"].addcmul_(1, grad, grad)
+                     std = state["sum"].sqrt().add_(1e-10)
+                     p.data.addcdiv_(-clr, grad, std)
+
+         return loss
fairseq/examples/adaptive_span/adaptive_span_attention.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import math
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class AdaptiveMask(nn.Module):
13
+ """Soft masking function for adaptive size.
14
+ It masks out the last K values of an input. The masking value
15
+ goes from 1 to 0 gradually, so K can be learned with
16
+ back-propagation.
17
+ Args:
18
+ max_size: maximum size (i.e. input dimension)
19
+ ramp_size: size of the ramp going from 0 to 1
20
+ init_val: initial size proportion not to be masked out
21
+ shape: learn multiple sizes independent of each other
22
+ """
23
+
24
+ def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
25
+ nn.Module.__init__(self)
26
+ self._max_size = max_size
27
+ self._ramp_size = ramp_size
28
+ self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
29
+ mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
30
+ self.register_buffer("mask_template", mask_template)
31
+
32
+ def forward(self, x):
33
+ mask = self.mask_template.float() + self.current_val.float() * self._max_size
34
+ mask = mask / self._ramp_size + 1
35
+ mask = mask.clamp(0, 1)
36
+ if x.size(-1) < self._max_size:
37
+ # the input could have been trimmed beforehand to save computation
38
+ mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
39
+ x = (x * mask).type_as(x)
40
+ return x
41
+
42
+ def get_current_max_size(self, include_ramp=True):
43
+ current_size = math.ceil(self.current_val.max().item() * self._max_size)
44
+ if include_ramp:
45
+ current_size += self._ramp_size
46
+ current_size = max(0, min(self._max_size, current_size))
47
+ return current_size
48
+
49
+ def get_current_avg_size(self, include_ramp=True):
50
+ current_size = math.ceil(
51
+ self.current_val.float().mean().item() * self._max_size
52
+ )
53
+ if include_ramp:
54
+ current_size += self._ramp_size
55
+ current_size = max(0, min(self._max_size, current_size))
56
+ return current_size
57
+
58
+ def clamp_param(self):
59
+ """this need to be called after each update"""
60
+ self.current_val.data.clamp_(0, 1)
61
+
62
+
63
+ class AdaptiveSpan(nn.Module):
64
+ """Adaptive attention span for Transformerself.
65
+ This module learns an attention span length from data for each
66
+ self-attention head.
67
+ Args:
68
+ attn_span: maximum attention span
69
+ adapt_span_loss: loss coefficient for the span length
70
+ adapt_span_ramp: length of the masking ramp
71
+ adapt_span_init: initial size ratio
72
+ adapt_span_cache: adapt cache size to reduce memory usage
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ attn_span,
78
+ adapt_span_ramp,
79
+ adapt_span_init,
80
+ n_head,
81
+ adapt_span_layer,
82
+ **kargs
83
+ ):
84
+ nn.Module.__init__(self)
85
+ self._max_span = attn_span
86
+ self._n_head = n_head
87
+ self._adapt_span_layer = adapt_span_layer
88
+ if self._adapt_span_layer:
89
+ self._mask = AdaptiveMask(
90
+ max_size=self._max_span,
91
+ ramp_size=adapt_span_ramp,
92
+ init_val=adapt_span_init,
93
+ )
94
+ else:
95
+ self._mask = AdaptiveMask(
96
+ max_size=self._max_span,
97
+ ramp_size=adapt_span_ramp,
98
+ init_val=adapt_span_init,
99
+ shape=(n_head, 1, 1),
100
+ )
101
+
102
+ def forward(self, attn, normalize=True):
103
+ """mask attention with the right span"""
104
+ # batch and head dimensions are merged together, so separate them first
105
+ self.clamp_param()
106
+ if self._adapt_span_layer:
107
+ attn = self._mask(attn)
108
+ else:
109
+ B = attn.size(0) # batch size
110
+ M = attn.size(1) # block size
111
+ attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
112
+ attn = self._mask(attn)
113
+ attn = attn.view(B, M, -1)
114
+ return attn
115
+
116
+ def get_trim_len(self):
117
+ """how much of memory can be trimmed to reduce computation"""
118
+ L = self._max_span
119
+ trim_len = min(L - 1, L - self._mask.get_current_max_size())
120
+ # too fine granularity might be bad for the memory management
121
+ trim_len = math.floor(trim_len / 64) * 64
122
+ return trim_len
123
+
124
+ def trim_memory(self, query, key, value, key_pe):
125
+ """trim out unnecessary memory beforehand to reduce computation"""
126
+ trim_len = self.get_trim_len()
127
+ cache_size = key.size(1) - query.size(1)
128
+ trim_len_cache = trim_len - (self._max_span - cache_size)
129
+ if trim_len_cache > 0:
130
+ key = key[:, trim_len_cache:, :]
131
+ value = value[:, trim_len_cache:, :]
132
+ elif trim_len_cache < 0:
133
+ # cache is too short! this happens when validation resumes
134
+ # after a lot of updates.
135
+ key = F.pad(key, [0, 0, -trim_len_cache, 0])
136
+ value = F.pad(value, [0, 0, -trim_len_cache, 0])
137
+ if trim_len > 0:
138
+ if key_pe is not None:
139
+ key_pe = key_pe[:, :, trim_len:]
140
+ return key, value, key_pe
141
+
142
+ def get_cache_size(self):
143
+ """determine how long the cache should be"""
144
+ trim_len = self.get_trim_len()
145
+ # give a buffer of 64 steps since a span might increase
146
+ # in future updates
147
+ return min(self._max_span, self._max_span - trim_len + 64)
148
+
149
+ def get_loss(self):
150
+ """a loss term for regularizing the span length"""
151
+ return self._max_span * self._mask.current_val.float().mean()
152
+
153
+ def get_current_max_span(self):
154
+ return self._mask.get_current_max_size()
155
+
156
+ def get_current_avg_span(self):
157
+ return self._mask.get_current_avg_size()
158
+
159
+ def clamp_param(self):
160
+ self._mask.clamp_param()
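A small sketch (not part of the commit) of how `AdaptiveMask` behaves on a toy tensor; the flat import path is an assumption made for illustration.

```python
import torch
from adaptive_span_attention import AdaptiveMask  # assumed flat import path

# max_size=8, ramp_size=4, current_val initialised to 0.5
mask = AdaptiveMask(max_size=8, ramp_size=4, init_val=0.5)
x = torch.ones(1, 8)
print(mask(x))
# the oldest positions are softly masked, recent ones pass through unchanged:
# tensor([[0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], ...)
print(mask.get_current_max_size())  # ceil(0.5 * 8) + ramp_size = 8, clamped to max_size
```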
fairseq/examples/adaptive_span/adaptive_span_loss.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass
8
+
9
+ import torch.nn.functional as F
10
+ from fairseq import utils
11
+ from fairseq.logging import metrics
12
+ from fairseq.criterions import register_criterion
13
+ from fairseq.criterions.cross_entropy import CrossEntropyCriterion
14
+ from fairseq.dataclass import FairseqDataclass
15
+ from omegaconf import II
16
+
17
+
18
+ @dataclass
19
+ class AdaptiveSpanCriterionConfig(FairseqDataclass):
20
+ sentence_avg: bool = II("optimization.sentence_avg")
21
+
22
+
23
+ @register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
24
+ class AdaptiveSpanCriterion(CrossEntropyCriterion):
25
+ def __init__(self, task, sentence_avg):
26
+ super().__init__(task, sentence_avg)
27
+
28
+ def forward(self, model, sample, reduce=True):
29
+ """Compute the loss for the given sample.
30
+
31
+ Returns a tuple with three elements:
32
+ 1) the loss; here it is summed, unlike in the original adaptive-span code
33
+ 2) the sample size, which is used as the denominator for the gradient
34
+ 3) logging outputs to display while training
35
+ """
36
+ net_output = model(**sample["net_input"])
37
+ loss, aux_loss, avg_span, max_span = self.compute_loss(
38
+ model, net_output, sample, reduce=reduce
39
+ )
40
+ sample_size = (
41
+ sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
42
+ )
43
+ loss /= sample_size
44
+ total_loss = loss + aux_loss
45
+ sample_size = 1
46
+
47
+ logging_output = {
48
+ "loss": loss.data,
49
+ "ntokens": sample["ntokens"],
50
+ "nsentences": sample["target"].size(0),
51
+ "sample_size": sample_size,
52
+ "total_loss": total_loss.data,
53
+ "avg_span": avg_span * sample_size,
54
+ "max_span": max_span * sample_size,
55
+ }
56
+ return total_loss, sample_size, logging_output
57
+
58
+ def compute_loss(self, model, net_output, sample, reduce=True):
59
+ loss, _ = super().compute_loss(model, net_output, sample, reduce)
60
+ aux_loss = model.get_aux_loss()
61
+ avg_span = model.get_current_avg_span()
62
+ max_span = model.get_current_max_span()
63
+ return loss, aux_loss, avg_span, max_span
64
+
65
+ @staticmethod
66
+ def reduce_metrics(logging_outputs) -> None:
67
+ """Aggregate logging outputs from data parallel training."""
68
+ loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
69
+ ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
70
+ sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
71
+ total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
72
+ avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
73
+ max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)
74
+
75
+ # we divide by log(2) to convert the loss from base e to base 2
76
+ metrics.log_scalar(
77
+ "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
78
+ )
79
+ metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
80
+ metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
81
+ # total loss contains the L1 norm on adaptive-span
82
+ metrics.log_scalar(
83
+ "total_loss",
84
+ total_loss_sum / sample_size / math.log(2),
85
+ sample_size,
86
+ round=3,
87
+ )
88
+ if sample_size != ntokens:
89
+ metrics.log_scalar(
90
+ "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
91
+ )
92
+ metrics.log_derived(
93
+ "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
94
+ )
95
+ else:
96
+ metrics.log_derived(
97
+ "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
98
+ )
99
+
100
+ @staticmethod
101
+ def logging_outputs_can_be_summed() -> bool:
102
+ """
103
+ Whether the logging outputs returned by `forward` can be summed
104
+ across workers prior to calling `reduce_metrics`. Setting this
105
+ to True will improve distributed training speed.
106
+ """
107
+ return True
fairseq/examples/adaptive_span/adaptive_span_model.py ADDED
@@ -0,0 +1,263 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from fairseq.modules.layer_norm import LayerNorm
14
+
15
+ from .adaptive_span_attention import AdaptiveSpan
16
+
17
+ # Size notations:
18
+ # B = batch_size, H = d_model, M = block_size, L = attn_span
19
+
20
+
21
+ def _skew(X, pad_value):
22
+ """shift every row 1 step to right"""
23
+ # X = B x M x L
24
+ B, M, L = X.size()
25
+ X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1)
26
+ X = X.view(B, -1) # B x ML+MM+M
27
+ X = X[:, :-M] # B x ML+MM
28
+ X = X.view(B, M, M + L) # B x M x L+M
29
+ return X
30
+
31
+
32
+ def _unskew(X):
33
+ """reverse _skew operation"""
34
+ # X = B x M x L+M
35
+ B, M, L = X.size()
36
+ L -= M
37
+ X = X.view(B, -1) # B x ML+MM
38
+ X = F.pad(X, (0, M)) # B x ML+MM+M
39
+ X = X.view(B, M, M + L + 1) # B x M x L+M+1
40
+ X = X[:, :, :L] # B x M x L
41
+ return X
42
+
43
+
44
+ class SeqAttention(nn.Module):
45
+ """Sequential self-attention layer.
46
+ Each token will attend to a fixed number of previous steps.
47
+ Note that attention doesn't include the current step itself.
48
+ """
49
+
50
+ def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
51
+ nn.Module.__init__(self)
52
+ self.dropout = nn.Dropout(dropout)
53
+ self.d_model = d_model # size of a single head
54
+ self.attn_span = attn_span
55
+ self.adaptive_span = AdaptiveSpan(
56
+ attn_span=attn_span,
57
+ n_head=n_head,
58
+ adapt_span_layer=adapt_span_layer,
59
+ **kargs
60
+ )
61
+
62
+ def forward(self, query, key, value, key_pe):
63
+ # query size = B x M x H
64
+ # key, value sizes = B x (M+L) x H
65
+
66
+ key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
67
+
68
+ # compute attention from context
69
+ # B x M (dest) x (M+L) (src)
70
+ attn_cont = torch.matmul(query, key.transpose(-1, -2))
71
+ attn_cont = _unskew(attn_cont) # B x M x L
72
+
73
+ # compute the effect of position embedding
74
+ attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
75
+ attn = attn_cont + attn_pos
76
+
77
+ attn = attn / math.sqrt(self.d_model) # B x M X L_pos
78
+
79
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
80
+
81
+ # trim attention lengths according to the learned span
82
+ attn = self.adaptive_span(attn)
83
+
84
+ attn = self.dropout(attn) # B x M X L_pos
85
+
86
+ attn_cont = _skew(attn, 0) # B x M X (L+M)
87
+ out = torch.matmul(attn_cont, value) # B x M x H
88
+ return out
89
+
90
+ def get_cache_size(self):
91
+ return self.adaptive_span.get_cache_size()
92
+
93
+
94
+ class MultiHeadSeqAttention(nn.Module):
95
+ def __init__(self, d_model, n_head, **kargs):
96
+ nn.Module.__init__(self)
97
+ assert d_model % n_head == 0
98
+ self.n_head = n_head
99
+ self.head_dim = d_model // n_head
100
+ self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
101
+ self.proj_query = nn.Linear(d_model, d_model, bias=False)
102
+ nn.init.xavier_normal_(self.proj_query.weight)
103
+ self.proj_out = nn.Linear(d_model, d_model, bias=False)
104
+ nn.init.xavier_normal_(self.proj_out.weight)
105
+ self.proj_val = nn.Linear(d_model, d_model, bias=False)
106
+ nn.init.xavier_normal_(self.proj_val.weight)
107
+ self.proj_key = nn.Linear(d_model, d_model, bias=False)
108
+ nn.init.xavier_normal_(self.proj_key.weight)
109
+
110
+ def head_reshape(self, x):
111
+ K = self.n_head
112
+ D = self.head_dim
113
+ x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
114
+ x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
115
+ x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
116
+ return x
117
+
118
+ def forward(self, query, key, value, key_pe):
119
+ B = query.size(0)
120
+ K = self.n_head
121
+ D = self.head_dim
122
+ M = query.size(1)
123
+
124
+ query = self.proj_query(query)
125
+ query = self.head_reshape(query)
126
+ value = self.proj_val(value)
127
+ value = self.head_reshape(value)
128
+ key = self.proj_key(key)
129
+ key = self.head_reshape(key)
130
+
131
+ out = self.attn(query, key, value, key_pe) # B_K x M x D
132
+ out = out.view(B, K, M, D) # B x K x M x D
133
+ out = out.transpose(1, 2).contiguous() # B x M x K x D
134
+ out = out.view(B, M, -1) # B x M x K_D
135
+ out = self.proj_out(out)
136
+ return out
137
+
138
+
139
+ class FeedForwardLayer(nn.Module):
140
+ def __init__(self, d_model, d_inner, dropout, **kargs):
141
+ nn.Module.__init__(self)
142
+ self.fc1 = nn.Linear(d_model, d_inner)
143
+ self.fc2 = nn.Linear(d_inner, d_model)
144
+ nn.init.xavier_uniform_(self.fc1.weight)
145
+ nn.init.xavier_uniform_(self.fc2.weight)
146
+ self.dropout = nn.Dropout(dropout)
147
+
148
+ def forward(self, h):
149
+ h1 = F.relu(self.fc1(h))
150
+ h1 = self.dropout(h1)
151
+ h2 = self.fc2(h1)
152
+ return h2
153
+
154
+
155
+ class TransformerSeqLayer(nn.Module):
156
+ def __init__(self, d_model, **kargs):
157
+ nn.Module.__init__(self)
158
+ self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
159
+ self.norm1 = LayerNorm(d_model)
160
+ self.ff = FeedForwardLayer(d_model=d_model, **kargs)
161
+ self.norm2 = LayerNorm(d_model)
162
+
163
+ def forward(self, h, h_cache, key_pe):
164
+ # h = B x M x H
165
+ # h_cache = B x L x H
166
+ h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H
167
+ attn_out = self.attn(h, h_all, h_all, key_pe)
168
+ h = self.norm1(h + attn_out) # B x M x H
169
+ if self.ff is not None:
170
+ ff_out = self.ff(h)
171
+ out = self.norm2(h + ff_out) # B x M x H
172
+ else:
173
+ out = h
174
+ return out
175
+
176
+ def get_cache_size(self):
177
+ return self.attn.attn.get_cache_size()
178
+
179
+
180
+ class TransformerSeq(nn.Module):
181
+ def __init__(
182
+ self,
183
+ vocab_size,
184
+ d_model,
185
+ n_head,
186
+ n_layer,
187
+ attn_span,
188
+ emb_dropout,
189
+ aux_loss_scaler,
190
+ adapt_span_layer,
191
+ **kargs
192
+ ):
193
+ nn.Module.__init__(self)
194
+ # token embeddings
195
+ self.in_emb = nn.Embedding(vocab_size, d_model)
196
+ nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
197
+ self.out_emb = nn.Linear(d_model, vocab_size)
198
+ self.aux_loss_scaler = aux_loss_scaler
199
+ if emb_dropout > 0:
200
+ self.emb_dropout = nn.Dropout(emb_dropout)
201
+ else:
202
+ self.emb_dropout = None
203
+ # position embeddings
204
+ self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
205
+
206
+ self.layers = nn.ModuleList()
207
+ self.layers.extend(
208
+ TransformerSeqLayer(
209
+ d_model=d_model,
210
+ n_head=n_head,
211
+ attn_span=attn_span,
212
+ adapt_span_layer=adapt_span_layer,
213
+ **kargs
214
+ )
215
+ for _ in range(n_layer)
216
+ )
217
+
218
+ def forward(self, x, h_cache, target=None):
219
+ # x size = B x M
220
+ block_size = x.size(1)
221
+ h = self.in_emb(x) # B x M x H
222
+ if self.emb_dropout is not None:
223
+ h = self.emb_dropout(h)
224
+
225
+ h_cache_next = []
226
+ for l, layer in enumerate(self.layers):
227
+ cache_size = layer.attn.attn.get_cache_size()
228
+ if cache_size > block_size:
229
+ h_cache_next_l = torch.cat(
230
+ [h_cache[l][:, -cache_size + block_size :, :], h], dim=1
231
+ ).detach()
232
+ else:
233
+ h_cache_next_l = h[:, -cache_size:, :].detach()
234
+ h_cache_next.append(h_cache_next_l)
235
+ h = layer(h, h_cache[l], self.key_pe) # B x M x H
236
+
237
+ if self.emb_dropout is not None:
238
+ h = self.emb_dropout(h)
239
+
240
+ out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
241
+ dummy_loss = None
242
+
243
+ return out, h_cache_next, dummy_loss
244
+
245
+ def get_aux_loss(self):
246
+ loss = 0.0
247
+ for layer in self.layers:
248
+ loss += layer.attn.attn.adaptive_span.get_loss()
249
+ return self.aux_loss_scaler * loss
250
+
251
+ def get_current_max_span(self):
252
+ max_span = 0.0
253
+ for layer in self.layers:
254
+ max_span = max(
255
+ max_span, layer.attn.attn.adaptive_span.get_current_max_span()
256
+ )
257
+ return max_span
258
+
259
+ def get_current_avg_span(self):
260
+ avg_span = 0.0
261
+ for layer in self.layers:
262
+ avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
263
+ return avg_span / len(self.layers)
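The `_skew`/`_unskew` pair above is easiest to see on a tiny tensor. Below is a self-contained sketch (not part of the commit) with `_skew` reproduced from this file.

```python
import torch
import torch.nn.functional as F

def _skew(X, pad_value):
    """shift row i of a B x M x L tensor i steps to the right"""
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.view(B, -1)[:, :-M]                  # drop the trailing M pad entries
    return X.view(B, M, M + L)                 # B x M x (L+M)

X = torch.arange(6.0).view(1, 2, 3)  # B=1, M=2, L=3
print(_skew(X, 0))
# row i is shifted i steps to the right, which aligns relative positions:
# tensor([[[0., 1., 2., 0., 0.],
#          [0., 3., 4., 5., 0.]]])
```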
fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from dataclasses import dataclass
8
+ from typing import Dict, List, Optional
9
+
10
+ import torch
11
+ from fairseq.dataclass import FairseqDataclass
12
+ from fairseq.models import (
13
+ FairseqIncrementalDecoder,
14
+ FairseqLanguageModel,
15
+ register_model,
16
+ )
17
+ from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclass
24
+ class AdaptiveSpanSmallConfig(FairseqDataclass):
25
+ # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
26
+ vocab_size: int = 50
27
+ d_model: int = 256
28
+ n_head: int = 4
29
+ d_inner: int = 1024
30
+ n_layer: int = 8
31
+ attn_span: int = 1024
32
+ dropout: float = 0.0
33
+ emb_dropout: float = 0.0
34
+ adapt_span_ramp: int = 32
35
+ adapt_span_init: float = 0.0
36
+ aux_loss_scaler: float = 0.000002
37
+ adapt_span_layer: bool = False
38
+
39
+
40
+ @register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
41
+ class AdaptiveSpanTransformer(FairseqLanguageModel):
42
+ @classmethod
43
+ def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
44
+ return cls(AdaptiveSpanDecoder(cfg, task))
45
+
46
+ def get_aux_loss(self):
47
+ return self.decoder.get_aux_loss()
48
+
49
+ def get_current_max_span(self):
50
+ return self.decoder.get_current_max_span()
51
+
52
+ def get_current_avg_span(self):
53
+ return self.decoder.get_current_avg_span()
54
+
55
+
56
+ class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
57
+ def __init__(self, cfg, task):
58
+
59
+ super().__init__(task.target_dictionary)
60
+
61
+ self.config = cfg
62
+ config = AdaptiveSpanSmallConfig(
63
+ vocab_size=len(task.target_dictionary),
64
+ d_model=cfg.d_model,
65
+ n_head=cfg.n_head,
66
+ d_inner=cfg.d_inner,
67
+ n_layer=cfg.n_layer,
68
+ attn_span=cfg.attn_span,
69
+ dropout=cfg.dropout,
70
+ emb_dropout=cfg.emb_dropout,
71
+ adapt_span_ramp=cfg.adapt_span_ramp,
72
+ adapt_span_init=cfg.adapt_span_init,
73
+ aux_loss_scaler=cfg.aux_loss_scaler,
74
+ adapt_span_layer=cfg.adapt_span_layer,
75
+ )
76
+ logger.info(config)
77
+ self.model = AdaptiveSpanTransformerModel(**config.__dict__)
78
+
79
+ self._mems = None
80
+
81
+ def forward(
82
+ self,
83
+ src_tokens,
84
+ incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
85
+ encoder_out=None,
86
+ ):
87
+ bsz = src_tokens.size(0)
88
+ if incremental_state is not None: # used during inference
89
+ mems = self.get_incremental_state("mems")
90
+ src_tokens = src_tokens[:, -1:] # only keep the most recent token
91
+ else:
92
+ mems = self._mems
93
+
94
+ if mems is None:
95
+ # first time init
96
+ mems = self.init_hid_cache(bsz)
97
+ output = self.model(x=src_tokens, h_cache=mems,)
98
+ if incremental_state is not None:
99
+ self.set_incremental_state(incremental_state, "mems", output[1])
100
+ else:
101
+ self._mems = output[1]
102
+ return (output[0],)
103
+
104
+ def max_positions(self):
105
+ return self.config.attn_span
106
+
107
+ def init_hid_cache(self, batch_sz):
108
+ hid = []
109
+ for layer in self.model.layers:
110
+ param = next(self.model.parameters())
111
+ h = torch.zeros(
112
+ batch_sz,
113
+ layer.get_cache_size(),
114
+ self.config.d_model,
115
+ dtype=param.dtype,
116
+ device=param.device,
117
+ )
118
+ hid.append(h)
119
+ return hid
120
+
121
+ def get_aux_loss(self):
122
+ return self.model.get_aux_loss()
123
+
124
+ def get_current_max_span(self):
125
+ return self.model.get_current_max_span()
126
+
127
+ def get_current_avg_span(self):
128
+ return self.model.get_current_avg_span()
129
+
130
+ def reorder_incremental_state(
131
+ self,
132
+ incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
133
+ new_order: torch.Tensor,
134
+ ):
135
+ """Reorder incremental state.
136
+
137
+ This will be called when the order of the input has changed from the
138
+ previous time step. A typical use case is beam search, where the input
139
+ order changes between time steps based on the selection of beams.
140
+ """
141
+ raise NotImplementedError("This is required for generation/beam search")
142
+ # mems = self.get_incremental_state(incremental_state, "mems")
143
+ # if mems is not None:
144
+ # new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
145
+ # self.set_incremental_state(incremental_state, "mems", new_mems)
fairseq/examples/adaptive_span/truncated_bptt_lm_task.py ADDED
@@ -0,0 +1,285 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ from dataclasses import dataclass, field
9
+ from typing import List, Optional, Tuple
10
+
11
+ import torch
12
+ from fairseq import utils
13
+ from fairseq.data import (
14
+ Dictionary,
15
+ TokenBlockDataset,
16
+ data_utils,
17
+ iterators,
18
+ )
19
+ from fairseq.dataclass import FairseqDataclass
20
+ from fairseq.distributed import utils as dist_utils
21
+ from fairseq.tasks import FairseqTask, register_task
22
+ from omegaconf import II
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @dataclass
29
+ class TruncatedBPTTLMConfig(FairseqDataclass):
30
+ data: str = field(default="???", metadata={"help": "path to data directory"})
31
+ tokens_per_sample: int = field(
32
+ default=1024, metadata={"help": "max number of tokens per sequence"},
33
+ )
34
+ batch_size: int = II("dataset.batch_size")
35
+ # Some models use *max_target_positions* to know how many positional
36
+ # embeddings to learn. We use II(...) to make it default to
37
+ # *tokens_per_sample*, but in principle there could be more positional
38
+ # embeddings than tokens in a single batch. This may also be irrelevant for
39
+ # custom model implementations.
40
+ max_target_positions: int = II("task.tokens_per_sample")
41
+ # these will be populated automatically if not provided
42
+ data_parallel_rank: Optional[int] = None
43
+ data_parallel_size: Optional[int] = None
44
+
45
+
46
+ @register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig)
47
+ class TruncatedBPTTLMTask(FairseqTask):
48
+ def __init__(self, cfg: TruncatedBPTTLMConfig):
49
+ super().__init__(cfg)
50
+
51
+ if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
52
+ if torch.distributed.is_initialized():
53
+ cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
54
+ cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
55
+ else:
56
+ cfg.data_parallel_rank = 0
57
+ cfg.data_parallel_size = 1
58
+
59
+ # load the dictionary
60
+ paths = utils.split_paths(cfg.data)
61
+ assert len(paths) > 0
62
+ self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
63
+ logger.info("dictionary: {} types".format(len(self.dictionary)))
64
+
65
+ def load_dataset(self, split, epoch=1, combine=False, **kwargs):
66
+ """Load a given dataset split (e.g., train, valid, test)"""
67
+
68
+ # support sharded datasets
69
+ paths = utils.split_paths(self.cfg.data)
70
+ assert len(paths) > 0
71
+ data_path = paths[(epoch - 1) % len(paths)]
72
+ split_path = os.path.join(data_path, split)
73
+
74
+ # each element of *data* will be a tensorized line from the original
75
+ # text dataset, similar to ``open(split_path).readlines()``
76
+ data = data_utils.load_indexed_dataset(
77
+ split_path, self.dictionary, combine=combine
78
+ )
79
+ if data is None:
80
+ raise FileNotFoundError(
81
+ "Dataset not found: {} ({})".format(split, split_path)
82
+ )
83
+
84
+ # this is similar to ``data.view(-1).split(tokens_per_sample)``
85
+ data = TokenBlockDataset(
86
+ data,
87
+ data.sizes,
88
+ block_size=self.cfg.tokens_per_sample,
89
+ pad=None, # unused
90
+ eos=None, # unused
91
+ break_mode="none",
92
+ )
93
+
94
+ self.datasets[split] = TruncatedBPTTDataset(
95
+ data=data,
96
+ bsz_per_shard=self.cfg.batch_size,
97
+ shard_id=self.cfg.data_parallel_rank,
98
+ num_shards=self.cfg.data_parallel_size,
99
+ )
100
+
101
+ def dataset(self, split):
102
+ return self.datasets[split]
103
+
104
+ def get_batch_iterator(
105
+ self,
106
+ dataset,
107
+ num_workers=0,
108
+ epoch=1,
109
+ data_buffer_size=0,
110
+ skip_remainder_batch=False,
111
+ **kwargs
112
+ ):
113
+ return iterators.EpochBatchIterator(
114
+ dataset=dataset,
115
+ collate_fn=self._collate_fn,
116
+ num_workers=num_workers,
117
+ epoch=epoch,
118
+ buffer_size=data_buffer_size,
119
+ # we don't use the batching functionality from EpochBatchIterator;
120
+ # instead every item in *dataset* is a whole batch
121
+ batch_sampler=[[i] for i in range(len(dataset))],
122
+ disable_shuffling=True,
123
+ skip_remainder_batch=skip_remainder_batch,
124
+ )
125
+
126
+ def _collate_fn(self, items: List[List[torch.Tensor]]):
127
+ # we don't use fairseq's batching functionality, so we expect a single
128
+ # item, which is an (id, List[torch.Tensor]) tuple
129
+ assert len(items) == 1
130
+
131
+ # item will have shape B x T (the last batch may have length < T)
132
+ id, item = items[0]
133
+ item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad())
134
+ B, T = item.size()
135
+
136
+ # shift item one position over and append a padding token for the target
137
+ target = torch.nn.functional.pad(
138
+ item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad()
139
+ )
140
+
141
+ # fairseq expects batches to have the following structure
142
+ return {
143
+ "id": torch.tensor([id] * item.size(0)),
144
+ "net_input": {"src_tokens": item,},
145
+ "target": target,
146
+ "nsentences": item.size(0),
147
+ "ntokens": item.numel(),
148
+ }
149
+
150
+ def build_dataset_for_inference(
151
+ self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
152
+ ) -> torch.utils.data.Dataset:
153
+ eos = self.source_dictionary.eos()
154
+ dataset = TokenBlockDataset(
155
+ src_tokens,
156
+ src_lengths,
157
+ block_size=None, # ignored for "eos" break mode
158
+ pad=self.source_dictionary.pad(),
159
+ eos=eos,
160
+ break_mode="eos",
161
+ )
162
+
163
+ class Dataset(torch.utils.data.Dataset):
164
+ def __getitem__(self, i):
165
+ item = dataset[i]
166
+ if item[-1] == eos:
167
+ # remove eos to support generating with a prefix
168
+ item = item[:-1]
169
+ return (i, [item])
170
+
171
+ def __len__(self):
172
+ return len(dataset)
173
+
174
+ return Dataset()
175
+
176
+ def inference_step(
177
+ self, generator, models, sample, prefix_tokens=None, constraints=None
178
+ ):
179
+ with torch.no_grad():
180
+ if constraints is not None:
181
+ raise NotImplementedError
182
+
183
+ # SequenceGenerator doesn't use *src_tokens* directly; we need to
184
+ # pass the *prefix_tokens* argument instead.
185
+ if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
186
+ prefix_tokens = sample["net_input"]["src_tokens"]
187
+
188
+ # begin generation with the end-of-sentence token
189
+ bos_token = self.source_dictionary.eos()
190
+
191
+ return generator.generate(
192
+ models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
193
+ )
194
+
195
+ def eval_lm_dataloader(
196
+ self,
197
+ dataset,
198
+ max_tokens: Optional[int] = 36000,
199
+ batch_size: Optional[int] = None,
200
+ max_positions: Optional[int] = None,
201
+ num_shards: int = 1,
202
+ shard_id: int = 0,
203
+ num_workers: int = 1,
204
+ data_buffer_size: int = 10,
205
+ context_window: int = 0,
206
+ ):
207
+ if context_window > 0:
208
+ raise NotImplementedError(
209
+ "Transformer-XL doesn't need --context-window, try "
210
+ "--model-overrides '{\"mem_len\":42}' instead "
211
+ )
212
+ return self.get_batch_iterator(
213
+ dataset=dataset,
214
+ max_tokens=max_tokens,
215
+ max_sentences=batch_size,
216
+ max_positions=max_positions,
217
+ ignore_invalid_inputs=True,
218
+ num_shards=num_shards,
219
+ shard_id=shard_id,
220
+ num_workers=num_workers,
221
+ data_buffer_size=data_buffer_size,
222
+ ).next_epoch_itr(shuffle=False)
223
+
224
+ @property
225
+ def source_dictionary(self):
226
+ return self.dictionary
227
+
228
+ @property
229
+ def target_dictionary(self):
230
+ return self.dictionary
231
+
232
+
233
+ class TruncatedBPTTDataset(torch.utils.data.Dataset):
234
+ def __init__(
235
+ self,
236
+ data: List[torch.Tensor], # ordered list of items
237
+ bsz_per_shard, # number of items processed per GPU per forward pass
238
+ shard_id, # current GPU ID
239
+ num_shards, # number of GPUs
240
+ ):
241
+ super().__init__()
242
+ self.data = data
243
+
244
+ def batchify(data, bsz):
245
+ # Work out how cleanly we can divide the dataset into bsz parts.
246
+ nbatch = data.size(0) // bsz
247
+ # Trim off any extra elements that wouldn't cleanly fit (remainders).
248
+ data = data.narrow(0, 0, nbatch * bsz)
249
+ # Evenly divide the data across the bsz batches.
250
+ data = data.view(bsz, -1).contiguous()
251
+ return data
252
+
253
+ # total number of sequences processed by all GPUs in each forward pass
254
+ global_batch_size = bsz_per_shard * num_shards
255
+
256
+ """
257
+ With a 16 item dataset, bsz_per_shard=2 and num_shards=3,
258
+ *indices* might look like:
259
+
260
+ indices = [[0, 1],
261
+ [2, 3],
262
+ [4, 5],
263
+ [6, 7],
264
+ [8, 9],
265
+ [10, 11]]
266
+
267
+ The size of the TruncatedBPTTDataset instance will be 2,
268
+ and shard 1 will see items:
269
+
270
+ [(0, [data[4], data[6]]),
271
+ (1, [data[5], data[7]])]
272
+ """
273
+ indices = batchify(torch.arange(len(data)), global_batch_size)
274
+ assert indices.size(0) == global_batch_size
275
+
276
+ self.my_indices = indices[
277
+ shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard
278
+ ]
279
+ assert self.my_indices.size(0) == bsz_per_shard
280
+
281
+ def __len__(self):
282
+ return self.my_indices.size(1)
283
+
284
+ def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]:
285
+ return (i, [self.data[idx] for idx in self.my_indices[:, i]])
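A standalone sketch (not part of the commit) of the sharding index math described in the docstring above, using the same `batchify` helper.

```python
import torch

def batchify(data, bsz):
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)  # drop the remainder
    return data.view(bsz, -1).contiguous()

bsz_per_shard, num_shards, shard_id = 2, 3, 1
indices = batchify(torch.arange(16), bsz_per_shard * num_shards)
my_indices = indices[shard_id * bsz_per_shard:(shard_id + 1) * bsz_per_shard]
print(my_indices)        # tensor([[4, 5],
                         #         [6, 7]])
print(my_indices[:, 0])  # batch 0 on shard 1 -> items data[4] and data[6]
```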
fairseq/examples/attention_head_selection/README.md ADDED
@@ -0,0 +1,161 @@
1
+ # Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling (Gong et al., 2021)
2
+
3
+ [https://arxiv.org/pdf/2106.10840.pdf](https://arxiv.org/pdf/2106.10840.pdf)
4
+
5
+ ## Introduction
6
+
7
+ We present attention head selection strategies in multilingual and multi-domain sequence modeling including text translation, speech recognition and speech translation tasks.
8
+
9
+ Below is an example of training multilingual/multi-domain speech recognition models.
10
+
11
+ ## Data Preparation
12
+ Prepare mTEDx data as in [mTEDx example](https://github.com/fairinternal/fairseq-py/blob/0d9c5851e6fac40f9e366b3633ccd615c2901788/examples/speech_to_text/docs/mtedx_example.md) and CoVoST data as in [CoVoST example](https://github.com/fairinternal/fairseq-py/blob/0d9c5851e6fac40f9e366b3633ccd615c2901788/examples/speech_to_text/docs/covost_example.md). Similarly prepare EuroParl data.
13
+
14
+
15
+ ## Training a multilingual ASR model with attention head selection
16
+
17
+ ```bash
18
+ data_dir=<path to mtedx data>
19
+ train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
20
+ valid_subset="valid_ar_ar_tedx,valid_de_de_tedx,valid_el_el_tedx,valid_es_es_tedx,valid_fr_fr_tedx,valid_it_it_tedx,valid_pt_pt_tedx,valid_ru_ru_tedx"
21
+ strategy=<subset or group>
22
+
23
+ fairseq-train ${data_dir} \
24
+ --user-dir examples/attention_head_selection/src \
25
+ --train-subset "${train_subset}" \
26
+ --valid-subset "${valid_subset}" \
27
+ --config-yaml 'config_asr.yaml' \
28
+ --arch 'head_selection_s2t_transformer_s' \
29
+ --task 'speech_to_text_head_selection' \
30
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
31
+ --lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
32
+ --lr 5e-4 \
33
+ --clip-norm 10.0 \
34
+ --seed 1 \
35
+ --max-epoch 400 \
36
+ --max-tokens 32000 \
37
+ --ignore-prefix-size 1 \
38
+ --dropout 0.3 \
39
+ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
40
+ --skip-invalid-size-inputs-valid-test \
41
+ --encoder-attn-head-select \
42
+ --total-encoder-attention-heads 8 \
43
+ --decoder-self-attn-head-select \
44
+ --total-decoder-attention-heads 8 \
45
+ --attn-head-select-strategy ${strategy} \
46
+ --task-type lang \
47
+ ```
48
+
49
+ ## Training a multi-domain ASR model with attention head selection
50
+
51
+ ```bash
52
+ data_dir=<path to multi-domain data>
53
+ train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
54
+ valid_subset="dev_es_es_tedx,dev_fr_fr_tedx,dev_pt_pt_tedx,dev_it_it_tedx,dev_ru_ru_tedx,dev_el_el_tedx,dev_ar_ar_tedx,dev_de_de_tedx,dev_ar_ar_cv,dev_de_de_cv,dev_es_es_cv,dev_fr_fr_cv,dev_it_it_cv,dev_pt_pt_cv,dev_ru_ru_cv,dev_de_de_ep,dev_es_es_ep,dev_fr_fr_ep,dev_it_it_ep,dev_pt_pt_ep"
55
+ strategy=<subset or group>
56
+
57
+ fairseq-train ${data_dir} \
58
+ --user-dir examples/attention_head_selection/src \
59
+ --train-subset "${train_subset}" \
60
+ --valid-subset "${valid_subset}" \
61
+ --config-yaml 'config_asr.yaml' \
62
+ --arch head_selection_s2t_transformer_s \
63
+ --task speech_to_text_head_selection \
64
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
65
+ --lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
66
+ --lr 5e-4 \
67
+ --clip-norm 10.0 \
68
+ --seed 1 \
69
+ --max-epoch 400 \
70
+ --max-tokens 32000 \
71
+ --ignore-prefix-size 1 \
72
+ --dropout 0.3 \
73
+ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
74
+ --skip-invalid-size-inputs-valid-test \
75
+ --encoder-attn-head-select \
76
+ --total-encoder-attention-heads 8 \
77
+ --decoder-self-attn-head-select \
78
+ --total-decoder-attention-heads 8 \
79
+ --attn-head-select-strategy ${strategy} \
80
+ --task-type domain
81
+ ```
82
+
83
+ ## Inference in multilingual setting
84
+
85
+ ```bash
86
+ MODEL_DIR=<checkpoint directory>
87
+ data_dir=<path to mtedx data>
88
+ gen_subset=<data to test, e.g., test_ar_ar_tedx>
89
+ train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
90
+ last_n=10
91
+ CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
92
+ CHECKPOINT="_avg"
93
+ RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
94
+ if [ ! -d $RESULTS ]; then
95
+ mkdir -p $RESULTS
96
+ fi;
97
+
98
+ python scripts/average_checkpoints.py \
99
+ --inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
100
+ --output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"
101
+
102
+ fairseq-generate ${data_dir} \
103
+ --user-dir examples/attention_head_selection/src \
104
+ --arch 'head_selection_s2t_transformer_s' \
105
+ --task 'speech_to_text_head_selection' \
106
+ --train-subset ${train_subset} \
107
+ --gen-subset ${gen_subset} \
108
+ --path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
109
+ --config-yaml 'config_asr.yaml' \
110
+ --prefix-size 1 \
111
+ --max-tokens 40000 --beam 5 \
112
+ --skip-invalid-size-inputs-valid-test \
113
+ --results-path ${RESULTS} \
114
+ --scoring wer --wer-tokenizer 13a \
115
+ --wer-lowercase --wer-remove-punct --remove-bpe
116
+ ```
117
+
118
+ ## Inference in multi-domain setting
119
+
120
+ ```bash
121
+ MODEL_DIR=<checkpoint directory>
122
+ data_dir=<path to multi-domain data>
123
+ gen_subset=<data to test, e.g., test_pt_pt_cv>
124
+ train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
125
+ last_n=10
126
+ CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
127
+ CHECKPOINT="_avg"
128
+ RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
129
+ if [ ! -d $RESULTS ]; then
130
+ mkdir -p $RESULTS
131
+ fi;
132
+
133
+ python scripts/average_checkpoints.py \
134
+ --inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
135
+ --output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"
136
+
137
+ fairseq-generate ${data_dir} \
138
+ --user-dir examples/attention_head_selection/src \
139
+ --arch 'head_selection_s2t_transformer_s' \
140
+ --task 'speech_to_text_head_selection' \
141
+ --train-subset ${train_subset} \
142
+ --gen-subset ${gen_subset} \
143
+ --path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
144
+ --config-yaml 'config_asr.yaml' \
145
+ --prefix-size 1 \
146
+ --max-tokens 40000 --beam 5 \
147
+ --skip-invalid-size-inputs-valid-test \
148
+ --results-path ${RESULTS} \
149
+ --scoring wer --wer-tokenizer 13a \
150
+ --wer-lowercase --wer-remove-punct --remove-bpe
151
+ ```
152
+
153
+ ## Citation
154
+ ```bibtex
155
+ @article{gong2021pay,
156
+ title={Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling},
157
+ author={Gong, Hongyu and Tang, Yun and Pino, Juan and Li, Xian},
158
+ journal={arXiv preprint arXiv:2106.10840},
159
+ year={2021}
160
+ }
161
+ ```
fairseq/examples/attention_head_selection/src/__init__.py ADDED
File without changes
fairseq/examples/attention_head_selection/src/data/__init__.py ADDED
File without changes
fairseq/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py ADDED
@@ -0,0 +1,242 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional
9
+ from dataclasses import dataclass
10
+
11
+ import torch
12
+ from fairseq.data import (
13
+ ConcatDataset,
14
+ Dictionary,
15
+ FairseqDataset,
16
+ ResamplingDataset
17
+ )
18
+ from fairseq.data.audio.data_cfg import S2TDataConfig
19
+ from fairseq.data.audio.speech_to_text_dataset import (
20
+ SpeechToTextDatasetItem,
21
+ SpeechToTextDataset,
22
+ SpeechToTextDatasetCreator
23
+ )
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ @dataclass
29
+ class SpeechToTextDatasetItemWithDomain(SpeechToTextDatasetItem):
30
+ src_lang_id: Optional[torch.Tensor] = None
31
+ tgt_lang_id: Optional[torch.Tensor] = None
32
+ domain_id: Optional[torch.Tensor] = None
33
+
34
+
35
+ class SpeechToTextDatasetWithDomain(SpeechToTextDataset):
36
+
37
+ def __init__(
38
+ self,
39
+ split: str,
40
+ is_train_split: bool,
41
+ cfg: S2TDataConfig,
42
+ audio_paths: List[str],
43
+ n_frames: List[int],
44
+ src_texts: Optional[List[str]] = None,
45
+ tgt_texts: Optional[List[str]] = None,
46
+ speakers: Optional[List[str]] = None,
47
+ src_langs: Optional[List[str]] = None,
48
+ tgt_langs: Optional[List[str]] = None,
49
+ ids: Optional[List[str]] = None,
50
+ tgt_dict: Optional[Dictionary] = None,
51
+ pre_tokenizer=None,
52
+ bpe_tokenizer=None,
53
+ n_frames_per_step=1,
54
+ speaker_to_id=None,
55
+ src_lang_ids: Optional[List[int]] = None,
56
+ tgt_lang_ids: Optional[List[int]] = None,
57
+ domain_ids: Optional[List[int]] = None
58
+ ):
59
+ super().__init__(
60
+ split, is_train_split, cfg, audio_paths, n_frames,
61
+ src_texts, tgt_texts, speakers, src_langs, tgt_langs,
62
+ ids, tgt_dict, pre_tokenizer, bpe_tokenizer,
63
+ n_frames_per_step, speaker_to_id
64
+ )
65
+ assert src_lang_ids is None or len(src_lang_ids) == self.n_samples
66
+ assert tgt_lang_ids is None or len(tgt_lang_ids) == self.n_samples
67
+ assert domain_ids is None or len(domain_ids) == self.n_samples
68
+
69
+ self.src_lang_ids = src_lang_ids
70
+ self.tgt_lang_ids = tgt_lang_ids
71
+ self.domain_ids = domain_ids
72
+
73
+ def __getitem__(self, index: int) -> SpeechToTextDatasetItemWithDomain:
74
+ item = super().__getitem__(index)
75
+ src_lang_id = self.src_lang_ids[index]
76
+ tgt_lang_id = self.tgt_lang_ids[index]
77
+ domain_id = self.domain_ids[index]
78
+ return SpeechToTextDatasetItemWithDomain(
79
+ index=item.index, source=item.source,
80
+ target=item.target, speaker_id=item.speaker_id,
81
+ src_lang_id=src_lang_id,
82
+ tgt_lang_id=tgt_lang_id,
83
+ domain_id=domain_id
84
+ )
85
+
86
+ def collater(
87
+ self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
88
+ ) -> Dict:
89
+ if len(samples) == 0:
90
+ return {}
91
+ out = super().collater(samples, return_order=True)
92
+ order = out["order"]
93
+ src_lang_ids = torch.tensor([x.src_lang_id for x in samples], dtype=torch.long).index_select(0, order)
94
+ tgt_lang_ids = torch.tensor([x.tgt_lang_id for x in samples], dtype=torch.long).index_select(0, order)
95
+ domain_ids = torch.tensor([x.domain_id for x in samples], dtype=torch.long).index_select(0, order)
96
+
97
+ out["src_lang_ids"] = src_lang_ids
98
+ out["tgt_lang_ids"] = tgt_lang_ids
99
+ out["domain_ids"] = domain_ids
100
+ if not return_order:
101
+ del out["order"]
102
+ return out
103
+
104
+
105
+ class SpeechToTextDatasetCreatorWithDomain(SpeechToTextDatasetCreator):
106
+ KEY_SRC_LANG_ID, KEY_TGT_LANG_ID = "src_lang_id", "tgt_lang_id"
107
+ KEY_DOMAIN_ID = "domain_id"
108
+ # default values
109
+ DEFAULT_SRC_LANG_ID, DEFAULT_TGT_LANG_ID, DEFAULT_DOMAIN_ID = 0, 0, 0
110
+
111
+ @classmethod
112
+ def _from_list(
113
+ cls,
114
+ split_name: str,
115
+ is_train_split,
116
+ samples: List[Dict],
117
+ cfg: S2TDataConfig,
118
+ tgt_dict,
119
+ pre_tokenizer,
120
+ bpe_tokenizer,
121
+ n_frames_per_step,
122
+ speaker_to_id
123
+ ) -> SpeechToTextDatasetWithDomain:
124
+ audio_root = Path(cfg.audio_root)
125
+ ids = [s[cls.KEY_ID] for s in samples]
126
+ audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
127
+ n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
128
+ tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
129
+ src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
130
+ speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
131
+ src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
132
+ tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
133
+ src_lang_ids = [s.get(cls.KEY_SRC_LANG_ID, cls.DEFAULT_SRC_LANG_ID) for s in samples]
134
+ tgt_lang_ids = [s.get(cls.KEY_TGT_LANG_ID, cls.DEFAULT_TGT_LANG_ID) for s in samples]
135
+ domain_ids = [s.get(cls.KEY_DOMAIN_ID, cls.DEFAULT_DOMAIN_ID) for s in samples]
136
+ return SpeechToTextDatasetWithDomain(
137
+ split_name,
138
+ is_train_split,
139
+ cfg,
140
+ audio_paths,
141
+ n_frames,
142
+ src_texts=src_texts,
143
+ tgt_texts=tgt_texts,
144
+ speakers=speakers,
145
+ src_langs=src_langs,
146
+ tgt_langs=tgt_langs,
147
+ ids=ids,
148
+ tgt_dict=tgt_dict,
149
+ pre_tokenizer=pre_tokenizer,
150
+ bpe_tokenizer=bpe_tokenizer,
151
+ n_frames_per_step=n_frames_per_step,
152
+ speaker_to_id=speaker_to_id,
153
+ src_lang_ids=src_lang_ids,
154
+ tgt_lang_ids=tgt_lang_ids,
155
+ domain_ids=domain_ids
156
+ )
157
+
158
+ @classmethod
159
+ def _load_samples_from_tsv(
160
+ cls,
161
+ root: str,
162
+ split: str,
163
+ src_lang_map,
164
+ tgt_lang_map,
165
+ domain_map
166
+ ):
167
+ # metadata from split
168
+ _, src_lang, tgt_lang, domain = split.split("_")
169
+ src_lang_id = src_lang_map[src_lang]
170
+ tgt_lang_id = tgt_lang_map[tgt_lang]
171
+ domain_id = domain_map[domain]
172
+
173
+ samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
174
+ for s in samples:
175
+ s.update({
176
+ cls.KEY_SRC_LANG_ID: src_lang_id,
177
+ cls.KEY_TGT_LANG_ID: tgt_lang_id,
178
+ cls.KEY_DOMAIN_ID: domain_id
179
+ })
180
+ return samples
181
+
182
+ @classmethod
183
+ def _from_tsv(
184
+ cls,
185
+ root: str,
186
+ cfg: S2TDataConfig,
187
+ split: str,
188
+ tgt_dict,
189
+ is_train_split: bool,
190
+ pre_tokenizer,
191
+ bpe_tokenizer,
192
+ n_frames_per_step,
193
+ speaker_to_id,
194
+ src_lang_map: Dict[str, int],
195
+ tgt_lang_map: Dict[str, int],
196
+ domain_map: Dict[str, int]
197
+ ) -> SpeechToTextDatasetItemWithDomain:
198
+ samples = cls._load_samples_from_tsv(
199
+ root, split, src_lang_map,
200
+ tgt_lang_map, domain_map
201
+ )
202
+ return cls._from_list(
203
+ split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer,
204
+ bpe_tokenizer, n_frames_per_step, speaker_to_id
205
+ )
206
+
207
+ @classmethod
208
+ def from_tsv(
209
+ cls,
210
+ root: str,
211
+ cfg: S2TDataConfig,
212
+ splits: str,
213
+ tgt_dict,
214
+ pre_tokenizer,
215
+ bpe_tokenizer,
216
+ is_train_split: bool,
217
+ epoch: int,
218
+ seed: int,
219
+ src_lang_map: Dict[str, int],
220
+ tgt_lang_map: Dict[str, int],
221
+ domain_map: Dict[str, int],
222
+ n_frames_per_step: int = 1,
223
+ speaker_to_id=None
224
+ ) -> SpeechToTextDatasetWithDomain:
225
+ datasets = [
226
+ cls._from_tsv(
227
+ root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id, src_lang_map, tgt_lang_map, domain_map
228
+ )
229
+ for split in splits.split(",")
230
+ ]
231
+
232
+ if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
233
+ # temperature-based sampling
234
+ size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
235
+ datasets = [
236
+ ResamplingDataset(
237
+ d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
238
+ )
239
+ for r, d in zip(size_ratios, datasets)
240
+ ]
241
+
242
+ return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
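The split-name convention that `_load_samples_from_tsv` relies on is `<split>_<src>_<tgt>_<domain>`; a tiny sketch (not part of the commit), where the lang/domain maps are made-up placeholders:

```python
# Hypothetical maps for illustration; the real ones come from the task configuration.
src_lang_map = {"es": 0, "fr": 1}
tgt_lang_map = {"es": 0, "fr": 1}
domain_map = {"tedx": 0, "cv": 1, "ep": 2}

split = "train_es_es_tedx"  # same naming as the subsets in the README above
_, src_lang, tgt_lang, domain = split.split("_")
print(src_lang_map[src_lang], tgt_lang_map[tgt_lang], domain_map[domain])  # 0 0 0
```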
fairseq/examples/attention_head_selection/src/loss/__init__.py ADDED
File without changes
fairseq/examples/attention_head_selection/src/loss/attention_head_selection.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ from torch.nn.modules.loss import _Loss
10
+
11
+
12
+ class HeadSelectionLoss(_Loss):
13
+
14
+ def __init__(self, args):
15
+ super().__init__()
16
+ self.args = args
17
+ self.kl_weight = getattr(args, "kl_weight", 0.0)
18
+
19
+ def forward(self, head_samples, sample_sizes, prior=0.5, eps=1e-7):
20
+ """
21
+ head_scores: (num_tasks, num_layers, num_heads)
22
+ sample_sizes: (num_tasks, )
23
+ """
24
+ kl_loss = (head_samples * (torch.log(head_samples + eps) - math.log(prior))).sum(-1).sum(-1)
25
+ kl_loss /= (torch.numel(head_samples) / head_samples.size(0))
26
+ kl_loss = self.kl_weight * torch.matmul(kl_loss, sample_sizes)
27
+ return kl_loss
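A quick numeric sketch (not part of the commit) of the KL regularizer above; the flat import path and the `Namespace` config object are assumptions for illustration.

```python
import torch
from argparse import Namespace
from attention_head_selection import HeadSelectionLoss  # assumed flat import path

loss_fn = HeadSelectionLoss(Namespace(kl_weight=0.01))
head_samples = torch.full((2, 6, 8), 0.5)    # (num_tasks, num_layers, num_heads)
sample_sizes = torch.tensor([100.0, 50.0])   # e.g. number of tokens per task in the batch
print(loss_fn(head_samples, sample_sizes, prior=0.5))
# selection probabilities equal to the prior give a KL term of (almost) zero
```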
fairseq/examples/attention_head_selection/src/models/__init__.py ADDED
File without changes
fairseq/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from typing import Dict, List, Optional
8
+ from pathlib import Path
9
+ import torch.nn as nn
10
+ from torch import Tensor
11
+ from fairseq import checkpoint_utils
12
+
13
+ from fairseq.models import register_model, register_model_architecture
14
+ from fairseq.utils import safe_hasattr
15
+ from fairseq.models.speech_to_text.s2t_transformer import (
16
+ S2TTransformerModel,
17
+ S2TTransformerEncoder,
18
+ TransformerDecoderScriptable
19
+ )
20
+ from fairseq.models.speech_to_text.s2t_transformer import base_architecture as s2t_base_architecture
21
+
22
+ from ..modules.attn_head_selector import AttnHeadSelector
23
+ from ..modules.head_selection_transformer_layer import HeadSelectionTransformerEncoderLayer
24
+ from .head_selection_transformer import HeadSelectionTransformerDecoder
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ @register_model("head_selection_s2t_transformer")
31
+ class HeadSelectionS2TTransformerModel(S2TTransformerModel):
32
+ """
33
+ Head selection implemented in S2TTransformer
34
+ """
35
+ def __init__(self, encoder, decoder):
36
+ super().__init__(encoder, decoder)
37
+
38
+ @staticmethod
39
+ def add_args(parser):
40
+ S2TTransformerModel.add_args(parser)
41
+ # encoder head selection
42
+ parser.add_argument(
43
+ "--encoder-attn-head-select",
44
+ action="store_true",
45
+ default=False,
46
+ help="encoder head selection"
47
+ )
48
+ parser.add_argument(
49
+ "--total-encoder-attention-heads",
50
+ type=int,
51
+ help="total number of encoder attention heads"
52
+ )
53
+ # decoder self attention selection
54
+ parser.add_argument(
55
+ "--decoder-self-attn-head-select",
56
+ action="store_true",
57
+ default=False,
58
+ help="decoder self-attention head selection"
59
+ )
60
+ # decoder-encoder attention selection
61
+ parser.add_argument(
62
+ "--dec-enc-attn-head-select",
63
+ action="store_true",
64
+ default=False,
65
+ help="decoder-encoder attention head selection"
66
+ )
67
+ parser.add_argument(
68
+ "--total-decoder-attention-heads",
69
+ type=int,
70
+ help="total number of decoder attention heads"
71
+ )
72
+ # selection strategy
73
+ parser.add_argument(
74
+ "--attn-head-select-strategy",
75
+ type=str,
76
+ help="attention head selection strategy, subset or group"
77
+ )
78
+
79
+ @classmethod
80
+ def build_encoder(cls, args):
81
+ if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
82
+ encoder = HeadSelectionS2TTransformerEncoder(args)
83
+ else:
84
+ encoder = S2TTransformerEncoder(args)
85
+ pretraining_path = getattr(args, "load_pretrained_encoder_from", None)
86
+ if pretraining_path is not None:
87
+ if not Path(pretraining_path).exists():
88
+ logger.warning(
89
+ f"skipped pretraining because {pretraining_path} does not exist"
90
+ )
91
+ else:
92
+ encoder = checkpoint_utils.load_pretrained_component_from_model(
93
+ component=encoder, checkpoint=pretraining_path
94
+ )
95
+ logger.info(f"loaded pretrained encoder from: {pretraining_path}")
96
+ return encoder
97
+
98
+ @classmethod
99
+ def build_decoder(cls, args, task, embed_tokens):
100
+ if (safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select) or (safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select):
101
+ return HeadSelectionTransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
102
+ else:
103
+ return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
104
+
105
+
106
+ class HeadSelectionS2TTransformerEncoder(S2TTransformerEncoder):
107
+
108
+ def __init__(self, args):
109
+ super().__init__(args)
110
+ self.attn_head_selector = AttnHeadSelector(
111
+ args.encoder_tasks,
112
+ args.encoder_layers,
113
+ args.total_encoder_attention_heads,
114
+ args.encoder_attention_heads,
115
+ args.attn_head_select_strategy,
116
+ )
117
+ self.task_ids = None
118
+ self.transformer_layers = nn.ModuleList([
119
+ HeadSelectionTransformerEncoderLayer(args, layer_idx, attn_head_selector=self.attn_head_selector) for layer_idx in range(args.encoder_layers)
120
+ ])
121
+
122
+ def set_task_ids(self, task_ids):
123
+ self.task_ids = task_ids
124
+
125
+ def _forward(self, src_tokens, src_lengths, return_all_hiddens=False):
126
+ self.attn_head_selector.head_select(self.task_ids)
127
+ return super()._forward(src_tokens, src_lengths, return_all_hiddens)
128
+
129
+
130
+ class HeadSelectionTransformerDecoderScriptable(HeadSelectionTransformerDecoder):
131
+ def extract_features(
132
+ self,
133
+ prev_output_tokens,
134
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
135
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
136
+ full_context_alignment: bool = False,
137
+ alignment_layer: Optional[int] = None,
138
+ alignment_heads: Optional[int] = None,
139
+ ):
140
+ # call scriptable method from parent class
141
+ x, _ = self.extract_features_scriptable(
142
+ prev_output_tokens,
143
+ encoder_out,
144
+ incremental_state,
145
+ full_context_alignment,
146
+ alignment_layer,
147
+ alignment_heads,
148
+ )
149
+ return x, None
150
+
151
+
152
+ @register_model_architecture(model_name="head_selection_s2t_transformer", arch_name="head_selection_s2t_transformer")
153
+ def base_architecture(args):
154
+ s2t_base_architecture(args)
155
+ args.encoder_attn_head_select = getattr(args, "encoder_attn_head_select", False)
156
+ args.decoder_self_attn_head_select = getattr(args, "decoder_self_attn_head_select", False)
157
+ args.dec_enc_attn_head_select = getattr(args, "dec_enc_attn_head_select", False)
158
+ args.total_encoder_attention_heads = getattr(args, "total_encoder_attention_heads", 8)
159
+ args.total_decoder_attention_heads = getattr(args, "total_decoder_attention_heads", 8)
160
+ args.attn_head_select_strategy = getattr(args, "attn_head_select_strategy", "group")
161
+
162
+
163
+ @register_model_architecture("head_selection_s2t_transformer", "head_selection_s2t_transformer_s")
164
+ def head_selection_s2t_transformer_s(args):
165
+ args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
166
+ args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
167
+ args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
168
+ args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
169
+ args.dropout = getattr(args, "dropout", 0.1)
170
+ base_architecture(args)
fairseq/examples/attention_head_selection/src/models/head_selection_transformer.py ADDED
@@ -0,0 +1,215 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Any, List, Dict, Optional
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import Tensor
10
+
11
+ from fairseq.utils import safe_hasattr
12
+ from fairseq.models.transformer import (
13
+ TransformerModel,
14
+ TransformerEncoder,
15
+ TransformerDecoder
16
+ )
17
+
18
+ from ..modules.attn_head_selector import AttnHeadSelector
19
+ from ..modules.head_selection_transformer_layer import (
20
+ HeadSelectionTransformerEncoderLayer,
21
+ HeadSelectionTransformerDecoderLayer
22
+ )
23
+
24
+
25
+ class HeadSelectionTransformerModel(TransformerModel):
26
+ def __init__(self, args, encoder, decoder):
27
+ super().__init__(args, encoder, decoder)
28
+
29
+ @staticmethod
30
+ def add_args(parser):
31
+ TransformerModel.add_args(parser)
32
+ # encoder head selection
33
+ parser.add_argument(
34
+ "--encoder-attn-head-select",
35
+ action="store_true",
36
+ default=False,
37
+ help="encoder head selection"
38
+ )
39
+ parser.add_argument(
40
+ "--total-encoder-attention-heads",
41
+ type=int,
42
+ help="total number of encoder attention heads"
43
+ )
44
+ # decoder self attention
45
+ parser.add_argument(
46
+ "--decoder-self-attn-head-select",
47
+ action="store_true",
48
+ default=False,
49
+ help="decoder self-attention head selection"
50
+ )
51
+ # decoder-encoder attention
52
+ parser.add_argument(
53
+ "--dec-enc-attn-head-select",
54
+ action="store_true",
55
+ default=False,
56
+ help="decoder-encoder attention head selection"
57
+ )
58
+ parser.add_argument(
59
+ "--total-decoder-attention-heads",
60
+ type=int,
61
+ help="total number of decoder attention heads"
62
+ )
63
+ # selection strategy
64
+ parser.add_argument(
65
+ "--attn-head-select-strategy",
66
+ type=str,
67
+ help="attention head selection strategy, subset or group"
68
+ )
69
+
70
+ @classmethod
71
+ def build_encoder(cls, args, src_dict, embed_tokens):
72
+ if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
73
+ return HeadSelectionTransformerEncoder(
74
+ args, src_dict, embed_tokens
75
+ )
76
+ else:
77
+ return TransformerEncoder(args, src_dict, embed_tokens)
78
+
79
+ @classmethod
80
+ def build_decoder(cls, args, tgt_dict, embed_tokens):
81
+ if (safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select) or (safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select):
82
+ return HeadSelectionTransformerDecoder(
83
+ args, tgt_dict, embed_tokens
84
+ )
85
+ else:
86
+ return TransformerDecoder(args, tgt_dict, embed_tokens)
87
+
88
+
89
+ class HeadSelectionTransformerEncoder(TransformerEncoder):
90
+
91
+ def __init__(self, args, dictionary, embed_tokens):
92
+ self.num_tasks = args.encoder_tasks
93
+ self.num_layers = args.encoder_layers
94
+ self.total_num_heads = args.total_encoder_attention_heads
95
+ self.num_heads = args.encoder_attention_heads
96
+ self.select_strategy = args.attn_head_select_strategy
97
+
98
+ super().__init__(args, dictionary, embed_tokens)
99
+ self.attn_head_selector = AttnHeadSelector(
100
+ self.num_tasks,
101
+ self.num_layers,
102
+ self.total_num_heads,
103
+ self.num_heads,
104
+ self.select_strategy
105
+ )
106
+ self.task_ids = None
107
+ self.layers = nn.ModuleList(
108
+ [self.build_encoder_layer(args, i) for i in range(args.encoder_layers)]
109
+ )
110
+
111
+ def set_task_ids(self, task_ids):
112
+ self.task_ids = task_ids
113
+
114
+ def build_encoder_layer(self, args, layer_idx=None):
115
+ return HeadSelectionTransformerEncoderLayer(
116
+ args,
117
+ layer_idx,
118
+ attn_head_selector=self.attn_head_selector
119
+ )
120
+
121
+ def forward(
122
+ self,
123
+ src_tokens,
124
+ src_lengths: Optional[torch.Tensor] = None,
125
+ return_all_hiddens: bool = False,
126
+ token_embeddings: Optional[torch.Tensor] = None,
127
+ ):
128
+ self.attn_head_selector.head_select(self.task_ids)
129
+ return super().forward(src_tokens, src_lengths, return_all_hiddens, token_embeddings)
130
+
131
+
132
+ class HeadSelectionTransformerDecoder(TransformerDecoder):
133
+
134
+ def __init__(
135
+ self,
136
+ args,
137
+ dictionary,
138
+ embed_tokens,
139
+ no_encoder_attn=False,
140
+ output_projection=None,
141
+ ):
142
+ self.num_tasks = args.decoder_tasks
143
+ self.num_layers = args.decoder_layers
144
+ self.total_num_heads = args.total_decoder_attention_heads
145
+ self.num_heads = args.decoder_attention_heads
146
+ self.select_strategy = args.attn_head_select_strategy
147
+ super().__init__(
148
+ args, dictionary, embed_tokens,
149
+ no_encoder_attn=no_encoder_attn,
150
+ output_projection=output_projection
151
+ )
152
+ self.self_attn_head_selector = None
153
+ self.enc_attn_head_selector = None
154
+ if safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select:
155
+ self.self_attn_head_selector = AttnHeadSelector(
156
+ self.num_tasks,
157
+ self.num_layers,
158
+ self.total_num_heads,
159
+ self.num_heads,
160
+ self.select_strategy
161
+ )
162
+ if safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select:
163
+ self.enc_attn_head_selector = AttnHeadSelector(
164
+ self.num_tasks,
165
+ self.num_layers,
166
+ self.total_num_heads,
167
+ self.num_heads,
168
+ self.select_strategy
169
+ )
170
+ self.task_ids = None
171
+ self.layers = nn.ModuleList(
172
+ [
173
+ self.build_head_selection_decoder_layer(args, no_encoder_attn, idx) for idx in range(args.decoder_layers)
174
+ ]
175
+ )
176
+
177
+ def set_task_ids(self, task_ids):
178
+ self.task_ids = task_ids
179
+
180
+ def build_head_selection_decoder_layer(self, args, no_encoder_attn=False, layer_idx=None):
181
+ return HeadSelectionTransformerDecoderLayer(
182
+ args,
183
+ layer_idx,
184
+ self.self_attn_head_selector,
185
+ self.enc_attn_head_selector,
186
+ no_encoder_attn=no_encoder_attn
187
+ )
188
+
189
+ def forward(
190
+ self,
191
+ prev_output_tokens,
192
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
193
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
194
+ features_only: bool = False,
195
+ full_context_alignment: bool = False,
196
+ alignment_layer: Optional[int] = None,
197
+ alignment_heads: Optional[int] = None,
198
+ src_lengths: Optional[Any] = None,
199
+ return_all_hiddens: bool = False,
200
+ ):
201
+ if self.self_attn_head_selector is not None:
202
+ self.self_attn_head_selector.head_select(self.task_ids)
203
+ if self.enc_attn_head_selector is not None:
204
+ self.enc_attn_head_selector.head_select(self.task_ids)
205
+ return super().forward(
206
+ prev_output_tokens=prev_output_tokens,
207
+ encoder_out=encoder_out,
208
+ incremental_state=incremental_state,
209
+ features_only=features_only,
210
+ full_context_alignment=full_context_alignment,
211
+ alignment_layer=alignment_layer,
212
+ alignment_heads=alignment_heads,
213
+ src_lengths=src_lengths,
214
+ return_all_hiddens=return_all_hiddens
215
+ )
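Both classes above follow the same pattern: the trainer binds per-sentence task ids via `set_task_ids`, and every `forward` call first refreshes the sampled head subset through `AttnHeadSelector.head_select`. The toy sketch below (hypothetical class names, not part of fairseq) isolates just that calling order so the control flow is easy to see.

```python
import torch
import torch.nn as nn


class ToySelector(nn.Module):
    """Stand-in for AttnHeadSelector: only records which task ids were used."""
    def head_select(self, task_ids):
        self.last_task_ids = task_ids


class ToyHeadSelectionEncoder(nn.Module):
    """Mimics the set_task_ids -> forward control flow of the encoders above."""
    def __init__(self):
        super().__init__()
        self.attn_head_selector = ToySelector()
        self.task_ids = None

    def set_task_ids(self, task_ids):
        self.task_ids = task_ids

    def forward(self, src_tokens):
        # Selection is refreshed exactly once per forward pass, before the layers run.
        self.attn_head_selector.head_select(self.task_ids)
        return src_tokens


encoder = ToyHeadSelectionEncoder()
encoder.set_task_ids(torch.tensor([0, 2, 1]))  # one task id per sentence in the batch
_ = encoder(torch.randn(3, 5))
assert torch.equal(encoder.attn_head_selector.last_task_ids, torch.tensor([0, 2, 1]))
```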
fairseq/examples/attention_head_selection/src/modules/__init__.py ADDED
File without changes
fairseq/examples/attention_head_selection/src/modules/attn_head_selector.py ADDED
@@ -0,0 +1,81 @@
1
+ # This source code is licensed under the MIT license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import math
7
+
8
+
9
+ class AttnHeadSelector(nn.Module):
10
+ """
11
+ Latent variable modeling of attention head selection
12
+ """
13
+ def __init__(
14
+ self, num_tasks, num_layers,
15
+ total_num_heads, num_heads,
16
+ select_strategy="group",
17
+ head_select_temp=5.0
18
+ ):
19
+ super(AttnHeadSelector, self).__init__()
20
+ self.num_tasks = num_tasks
21
+ self.num_layers = num_layers
22
+ self.total_num_heads = total_num_heads
23
+ self.num_heads = num_heads
24
+ self.select_strategy = select_strategy
25
+ self.temp = head_select_temp
26
+
27
+ self.head_logits = torch.nn.Parameter(
28
+ torch.Tensor(self.num_tasks, self.num_layers, total_num_heads),
29
+ requires_grad=True
30
+ )
31
+ nn.init.uniform_(
32
+ self.head_logits, a=math.log(0.01),
33
+ b=math.log(1.0)
34
+ )
35
+
36
+ def gumbel_sample(self, logits, tau=1.0):
37
+ gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
38
+ gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
39
+ gumbels1 = (logits + gumbels1 - gumbels2) / tau
40
+ y_soft = gumbels1.sigmoid()
41
+ return y_soft
42
+
43
+ def subset_select(self, y_soft, topk, dim=-1):
44
+ top_values, top_inds = torch.topk(y_soft, k=topk, dim=dim)
45
+ top_ret = 1.0 - top_values.detach() + top_values
46
+ return top_inds.detach(), top_ret
47
+
48
+ def group_select(self, y_soft, topk, dim=-1):
49
+ # top_values: (num_tasks, num_layers, topk)
50
+ top_values, top_inds = torch.max(
51
+ y_soft.view(self.num_tasks, self.num_layers, -1, topk), dim=2
52
+ )
53
+ top_inds = top_inds * topk + torch.arange(topk, device=top_inds.device).unsqueeze(0).unsqueeze(1)
54
+ top_ret = 1.0 - top_values.detach() + top_values
55
+ return top_inds.detach(), top_ret
56
+
57
+ def head_select(self, task_ids=None):
58
+ # gumbel_sample
59
+ self.head_samples = self.gumbel_sample(self.head_logits, tau=self.temp)
60
+ # head select
61
+ if self.select_strategy == "subset":
62
+ self.subset_heads, self.subset_weights = self.subset_select(
63
+ self.head_samples,
64
+ topk=self.num_heads,
65
+ )
66
+ elif self.select_strategy == "group":
67
+ self.subset_heads, self.subset_weights = self.group_select(
68
+ self.head_samples,
69
+ topk=self.num_heads,
70
+ )
71
+ else:
72
+ raise ValueError("{} is not supported".format(self.select_strategy))
73
+
74
+ self.batch_subset = self.subset_heads[task_ids, :, :]
75
+ self.batch_weights = self.subset_weights[task_ids, :, :]
76
+
77
+ def forward(self, layer_idx):
78
+ assert layer_idx is not None
79
+ batch_subset = self.batch_subset[:, layer_idx, :]
80
+ batch_weights = self.batch_weights[:, layer_idx, :]
81
+ return batch_subset, batch_weights
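To make the sampling above concrete, here is a minimal standalone sketch (illustrative shapes only) of the Gumbel-sigmoid draw and the straight-through trick used by the selector: the forward value of the returned weights is (numerically) 1.0, while the gradient still flows through the soft samples into `head_logits`.

```python
import math
import torch

num_tasks, num_layers, total_heads, k = 3, 2, 8, 4
logits = torch.empty(num_tasks, num_layers, total_heads).uniform_(math.log(0.01), math.log(1.0))
logits.requires_grad_(True)

# Gumbel-sigmoid sample: difference of two Gumbel noises passed through a sigmoid.
g1 = -torch.empty_like(logits).exponential_().log()
g2 = -torch.empty_like(logits).exponential_().log()
y_soft = ((logits + g1 - g2) / 5.0).sigmoid()  # temperature 5.0, as in the module above

# "subset" strategy: keep the top-k heads per (task, layer).
top_values, top_inds = torch.topk(y_soft, k=k, dim=-1)

# Straight-through: forward value is ~1.0, but the gradient flows through top_values.
top_ret = 1.0 - top_values.detach() + top_values
assert torch.allclose(top_ret, torch.ones_like(top_ret))

top_ret.sum().backward()
assert logits.grad is not None  # the underlying logits still receive gradient
```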
fairseq/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq.utils import safe_getattr
7
+ from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
8
+ from ..modules.multihead_attention_selection import MultiheadAttentionSelection
9
+
10
+
11
+ class HeadSelectionTransformerEncoderLayer(TransformerEncoderLayer):
12
+
13
+ def __init__(self, args, layer_idx, attn_head_selector=None):
14
+ super().__init__(args)
15
+ self.layer_idx = layer_idx
16
+ self.self_attn = self.build_self_attention_selection(
17
+ self.embed_dim, args, attn_head_selector
18
+ )
19
+
20
+ def build_self_attention_selection(self, embed_dim, args, attn_head_selector=None):
21
+ return MultiheadAttentionSelection(
22
+ embed_dim,
23
+ args.total_encoder_attention_heads,
24
+ args.encoder_attention_heads,
25
+ dropout=args.attention_dropout,
26
+ self_attention=True,
27
+ q_noise=self.quant_noise,
28
+ qn_block_size=self.quant_noise_block_size,
29
+ layer_idx=self.layer_idx,
30
+ attn_head_selector=attn_head_selector
31
+ )
32
+
33
+
34
+ class HeadSelectionTransformerDecoderLayer(TransformerDecoderLayer):
35
+
36
+ def __init__(
37
+ self,
38
+ args,
39
+ layer_idx,
40
+ self_attn_head_selector=None,
41
+ enc_attn_head_selector=None,
42
+ no_encoder_attn=False,
43
+ add_bias_kv=False,
44
+ add_zero_attn=False,
45
+ ):
46
+ self.layer_idx = layer_idx
47
+ super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
48
+ if self_attn_head_selector is not None:
49
+ self.self_attn = self.build_self_attention_selection(
50
+ self.embed_dim, args,
51
+ self_attn_head_selector=self_attn_head_selector,
52
+ add_bias_kv=add_bias_kv,
53
+ add_zero_attn=add_zero_attn
54
+ )
55
+ if enc_attn_head_selector is not None:
56
+ self.encoder_attn = self.build_encoder_attention_selection(
57
+ self.embed_dim, args,
58
+ enc_attn_head_selector=enc_attn_head_selector
59
+ )
60
+
61
+ def build_self_attention_selection(
62
+ self, embed_dim, args, self_attn_head_selector=None,
63
+ add_bias_kv=False, add_zero_attn=False
64
+ ):
65
+ return MultiheadAttentionSelection(
66
+ embed_dim,
67
+ args.total_decoder_attention_heads,
68
+ args.decoder_attention_heads,
69
+ dropout=args.attention_dropout,
70
+ add_bias_kv=add_bias_kv,
71
+ add_zero_attn=add_zero_attn,
72
+ self_attention=not safe_getattr(args, "cross_self_attention"),
73
+ q_noise=self.quant_noise,
74
+ qn_block_size=self.quant_noise_block_size,
75
+ layer_idx=self.layer_idx,
76
+ attn_head_selector=self_attn_head_selector,
77
+ )
78
+
79
+ def build_encoder_attention_selection(self, embed_dim, args, enc_attn_head_selector=None):
80
+ return MultiheadAttentionSelection(
81
+ embed_dim,
82
+ args.total_decoder_attention_heads,
83
+ args.decoder_attention_heads,
84
+ kdim=args.encoder_embed_dim,
85
+ vdim=args.encoder_embed_dim,
86
+ dropout=args.attention_dropout,
87
+ encoder_decoder_attention=True,
88
+ q_noise=self.quant_noise,
89
+ qn_block_size=self.quant_noise_block_size,
90
+ layer_idx=self.layer_idx,
91
+ attn_head_selector=enc_attn_head_selector,
92
+ )
fairseq/examples/attention_head_selection/src/modules/multihead_attention_selection.py ADDED
@@ -0,0 +1,355 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Dict, Optional, Tuple
7
+ import torch
8
+ from fairseq import utils
9
+ from fairseq.modules.quant_noise import quant_noise
10
+ from torch import Tensor, nn
11
+ from torch.nn import Parameter
12
+
13
+ from fairseq.modules.multihead_attention import MultiheadAttention
14
+ from ..modules.multihead_functional import multi_head_attention_forward
15
+
16
+
17
+ class MultiheadAttentionSelection(MultiheadAttention):
18
+
19
+ def __init__(
20
+ self,
21
+ embed_dim,
22
+ total_num_heads,
23
+ num_heads,
24
+ kdim=None,
25
+ vdim=None,
26
+ dropout=0.0,
27
+ bias=True,
28
+ add_bias_kv=False,
29
+ add_zero_attn=False,
30
+ self_attention=False,
31
+ encoder_decoder_attention=False,
32
+ q_noise=0.0,
33
+ qn_block_size=8,
34
+ layer_idx=0,
35
+ attn_head_selector=None
36
+ ):
37
+ super().__init__(
38
+ embed_dim,
39
+ num_heads,
40
+ kdim=kdim,
41
+ vdim=vdim,
42
+ dropout=dropout,
43
+ bias=bias,
44
+ add_bias_kv=add_bias_kv,
45
+ add_zero_attn=add_zero_attn,
46
+ self_attention=self_attention,
47
+ encoder_decoder_attention=encoder_decoder_attention,
48
+ q_noise=q_noise,
49
+ qn_block_size=qn_block_size,
50
+ )
51
+ self.layer_idx = layer_idx
52
+ self.attn_head_selector = attn_head_selector
53
+ self.total_num_heads = total_num_heads
54
+ self.total_embed_dim = self.head_dim * total_num_heads
55
+ self.k_proj = quant_noise(
56
+ nn.Linear(self.kdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
57
+ )
58
+ self.v_proj = quant_noise(
59
+ nn.Linear(self.vdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
60
+ )
61
+ self.q_proj = quant_noise(
62
+ nn.Linear(embed_dim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
63
+ )
64
+ if add_bias_kv:
65
+ self.bias_k = Parameter(torch.Tensor(1, 1, self.total_embed_dim))
66
+ self.bias_v = Parameter(torch.Tensor(1, 1, self.total_embed_dim))
67
+ else:
68
+ self.bias_k = self.bias_v = None
69
+ self.reset_parameters()
70
+
71
+ def forward(
72
+ self,
73
+ query,
74
+ key: Optional[Tensor],
75
+ value: Optional[Tensor],
76
+ key_padding_mask: Optional[Tensor] = None,
77
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
78
+ need_weights: bool = True,
79
+ static_kv: bool = False,
80
+ attn_mask: Optional[Tensor] = None,
81
+ before_softmax: bool = False,
82
+ need_head_weights: bool = False,
83
+ # subset_heads: Optional[Tensor] = None,
84
+ # subset_weights: Optional[Tensor] = None
85
+ ) -> Tuple[Tensor, Optional[Tensor]]:
86
+ if need_head_weights:
87
+ need_weights = True
88
+
89
+ is_tpu = query.device.type == "xla"
90
+
91
+ subset_heads, subset_weights = self.attn_head_selector(self.layer_idx)
92
+
93
+ tgt_len, bsz, embed_dim = query.size()
94
+ src_len = tgt_len
95
+ assert list(query.size()) == [tgt_len, bsz, self.embed_dim]
96
+ if key is not None:
97
+ src_len, key_bsz, _ = key.size()
98
+ if not torch.jit.is_scripting():
99
+ assert key_bsz == bsz
100
+ assert value is not None
101
+ assert src_len, bsz == value.shape[:2]
102
+
103
+ if (
104
+ not self.onnx_trace
105
+ and not is_tpu # don't use PyTorch version on TPUs
106
+ and incremental_state is None
107
+ and not static_kv
108
+ # A workaround for quantization to work. Otherwise JIT compilation
109
+ # treats bias in linear module as method.
110
+ and not torch.jit.is_scripting()
111
+ ):
112
+ assert key is not None and value is not None
113
+ return multi_head_attention_forward(
114
+ query,
115
+ key,
116
+ value,
117
+ self.embed_dim,
118
+ self.total_num_heads,
119
+ self.num_heads,
120
+ torch.empty([0]),
121
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
122
+ self.bias_k,
123
+ self.bias_v,
124
+ self.add_zero_attn,
125
+ self.dropout_module.p,
126
+ self.out_proj.weight,
127
+ self.out_proj.bias,
128
+ self.training or self.dropout_module.apply_during_inference,
129
+ key_padding_mask,
130
+ need_weights,
131
+ attn_mask,
132
+ use_separate_proj_weight=True,
133
+ q_proj_weight=self.q_proj.weight,
134
+ k_proj_weight=self.k_proj.weight,
135
+ v_proj_weight=self.v_proj.weight,
136
+ subset_heads=subset_heads,
137
+ subset_weights=subset_weights
138
+ )
139
+
140
+ if incremental_state is not None:
141
+ saved_state = self._get_input_buffer(incremental_state)
142
+ if saved_state is not None and "prev_key" in saved_state:
143
+ # previous time steps are cached - no need to recompute
144
+ # key and value if they are static
145
+ if static_kv:
146
+ assert self.encoder_decoder_attention and not self.self_attention
147
+ key = value = None
148
+ else:
149
+ saved_state = None
150
+
151
+ if self.self_attention:
152
+ q = self.q_proj(query)
153
+ k = self.k_proj(query)
154
+ v = self.v_proj(query)
155
+ elif self.encoder_decoder_attention:
156
+ # encoder-decoder attention
157
+ q = self.q_proj(query)
158
+ if key is None:
159
+ assert value is None
160
+ k = v = None
161
+ else:
162
+ k = self.k_proj(key)
163
+ v = self.v_proj(key)
164
+
165
+ else:
166
+ assert key is not None and value is not None
167
+ q = self.q_proj(query)
168
+ k = self.k_proj(key)
169
+ v = self.v_proj(value)
170
+ q *= self.scaling
171
+
172
+ if self.bias_k is not None:
173
+ assert self.bias_v is not None
174
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
175
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
176
+ if attn_mask is not None:
177
+ attn_mask = torch.cat(
178
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
179
+ )
180
+ if key_padding_mask is not None:
181
+ key_padding_mask = torch.cat(
182
+ [
183
+ key_padding_mask,
184
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
185
+ ],
186
+ dim=1,
187
+ )
188
+
189
+ q = (
190
+ q.contiguous()
191
+ .view(tgt_len, bsz * self.total_num_heads, self.head_dim)
192
+ .transpose(0, 1)
193
+ )
194
+ if k is not None:
195
+ k = (
196
+ k.contiguous()
197
+ .view(-1, bsz * self.total_num_heads, self.head_dim)
198
+ .transpose(0, 1)
199
+ )
200
+ if v is not None:
201
+ v = (
202
+ v.contiguous()
203
+ .view(-1, bsz * self.total_num_heads, self.head_dim)
204
+ .transpose(0, 1)
205
+ )
206
+
207
+ if saved_state is not None:
208
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
209
+ if "prev_key" in saved_state:
210
+ _prev_key = saved_state["prev_key"]
211
+ assert _prev_key is not None
212
+ prev_key = _prev_key.view(bsz * self.total_num_heads, -1, self.head_dim)
213
+ if static_kv:
214
+ k = prev_key
215
+ else:
216
+ assert k is not None
217
+ k = torch.cat([prev_key, k], dim=1)
218
+ src_len = k.size(1)
219
+ if "prev_value" in saved_state:
220
+ _prev_value = saved_state["prev_value"]
221
+ assert _prev_value is not None
222
+ prev_value = _prev_value.view(bsz * self.total_num_heads, -1, self.head_dim)
223
+ if static_kv:
224
+ v = prev_value
225
+ else:
226
+ assert v is not None
227
+ v = torch.cat([prev_value, v], dim=1)
228
+ prev_key_padding_mask: Optional[Tensor] = None
229
+ if "prev_key_padding_mask" in saved_state:
230
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
231
+ assert k is not None and v is not None
232
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
233
+ key_padding_mask=key_padding_mask,
234
+ prev_key_padding_mask=prev_key_padding_mask,
235
+ batch_size=bsz,
236
+ src_len=k.size(1),
237
+ static_kv=static_kv,
238
+ )
239
+
240
+ saved_state["prev_key"] = k.view(bsz, self.total_num_heads, -1, self.head_dim)
241
+ saved_state["prev_value"] = v.view(bsz, self.total_num_heads, -1, self.head_dim)
242
+ saved_state["prev_key_padding_mask"] = key_padding_mask
243
+ # In this branch incremental_state is never None
244
+ assert incremental_state is not None
245
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
246
+ assert k is not None
247
+ assert k.size(1) == src_len
248
+
249
+ # This is part of a workaround to get around fork/join parallelism
250
+ # not supporting Optional types.
251
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
252
+ key_padding_mask = None
253
+
254
+ if key_padding_mask is not None:
255
+ assert key_padding_mask.size(0) == bsz
256
+ assert key_padding_mask.size(1) == src_len
257
+
258
+ if self.add_zero_attn:
259
+ assert v is not None
260
+ src_len += 1
261
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
262
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
263
+ if attn_mask is not None:
264
+ attn_mask = torch.cat(
265
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
266
+ )
267
+ if key_padding_mask is not None:
268
+ key_padding_mask = torch.cat(
269
+ [
270
+ key_padding_mask,
271
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
272
+ key_padding_mask
273
+ ),
274
+ ],
275
+ dim=1,
276
+ )
277
+
278
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
279
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
280
+
281
+ assert list(attn_weights.size()) == [bsz * self.total_num_heads, tgt_len, src_len]
282
+
283
+ if attn_mask is not None:
284
+ attn_mask = attn_mask.unsqueeze(0)
285
+ if self.onnx_trace:
286
+ attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
287
+ attn_weights += attn_mask
288
+
289
+ if key_padding_mask is not None:
290
+ # don't attend to padding symbols
291
+ attn_weights = attn_weights.view(bsz, self.total_num_heads, tgt_len, src_len)
292
+ if not is_tpu:
293
+ attn_weights = attn_weights.masked_fill(
294
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
295
+ float("-inf"),
296
+ )
297
+ else:
298
+ attn_weights = attn_weights.transpose(0, 2)
299
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
300
+ attn_weights = attn_weights.transpose(0, 2)
301
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
302
+
303
+ if before_softmax:
304
+ return attn_weights, v
305
+
306
+ attn_weights_float = utils.softmax(
307
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
308
+ )
309
+ attn_weights = attn_weights_float.type_as(attn_weights)
310
+ attn_probs = self.dropout_module(attn_weights)
311
+
312
+ assert v is not None
313
+
314
+ # evaluation
315
+ if subset_heads is not None and subset_heads.numel() == 1:
316
+ subset_heads = subset_heads.repeat(bsz)
317
+ subset_weights = subset_weights.repeat(bsz)
318
+
319
+ if subset_heads is None:
320
+ attn = torch.bmm(attn_probs, v)
321
+ else:
322
+ # training with head selection
323
+ mixed_attn = torch.bmm(attn_probs, v).contiguous().view(bsz, self.total_num_heads, tgt_len, self.head_dim)
324
+ attn = torch.stack(
325
+ [mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
326
+ )
327
+ attn = attn * subset_weights.unsqueeze(2).unsqueeze(3)
328
+ attn = attn.contiguous().view(bsz * self.num_heads, tgt_len, self.head_dim)
329
+
330
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
331
+ if self.onnx_trace and attn.size(1) == 1:
332
+ # when ONNX tracing a single decoder step (sequence length == 1)
333
+ # the transpose is a no-op copy before view, thus unnecessary
334
+ attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
335
+ else:
336
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
337
+ attn = self.out_proj(attn)
338
+ attn_weights: Optional[Tensor] = None
339
+ if need_weights:
340
+ if subset_heads is None:
341
+ attn_weights = attn_weights_float.view(
342
+ bsz, self.num_heads, tgt_len, src_len
343
+ ).transpose(1, 0)
344
+ else:
345
+ mixed_attn_weights = attn_weights_float.view(
346
+ bsz, self.total_num_heads, tgt_len, src_len
347
+ )
348
+ attn_weights = torch.stack(
349
+ [mixed_attn_weights[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
350
+ ).transpose(1, 0)
351
+ if not need_head_weights:
352
+ # average attention weights over heads
353
+ attn_weights = attn_weights.mean(dim=0)
354
+
355
+ return attn, attn_weights
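The gather step above (stacking `mixed_attn[torch.arange(bsz), subset_heads[:, col]]` over columns) is easy to check in isolation. The sketch below uses made-up shapes to confirm that, for each sentence in the batch, it picks exactly the `num_heads` per-head outputs chosen by the selector and scales them by the straight-through weights.

```python
import torch

bsz, total_heads, num_heads, tgt_len, head_dim = 2, 8, 4, 3, 5
mixed_attn = torch.randn(bsz, total_heads, tgt_len, head_dim)

# One selected head index per (sentence, slot); weights are ~1.0 at evaluation time.
subset_heads = torch.tensor([[0, 2, 5, 7],
                             [1, 3, 4, 6]])
subset_weights = torch.ones(bsz, num_heads)

attn = torch.stack(
    [mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))],
    dim=1,
)
attn = attn * subset_weights.unsqueeze(2).unsqueeze(3)

# Sentence 0 / slot 1 should be head 2 of sentence 0, untouched apart from the unit weight.
assert torch.allclose(attn[0, 1], mixed_attn[0, 2])
assert attn.shape == (bsz, num_heads, tgt_len, head_dim)
```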
fairseq/examples/attention_head_selection/src/modules/multihead_functional.py ADDED
@@ -0,0 +1,278 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Optional, Tuple
7
+ import torch
8
+ from torch import Tensor
9
+ from torch.nn.functional import (
10
+ linear, softmax, dropout, pad,
11
+ has_torch_function,
12
+ handle_torch_function,
13
+ _in_projection_packed,
14
+ )
15
+ import math
16
+ import warnings
17
+
18
+
19
+ def _scaled_dot_product_attention(
20
+ q: Tensor,
21
+ k: Tensor,
22
+ v: Tensor,
23
+ attn_mask: Optional[Tensor] = None,
24
+ dropout_p: float = 0.0,
25
+ bsz: int = 1,
26
+ subset_heads: Optional[Tensor] = None,
27
+ subset_weights: Optional[Tensor] = None,
28
+ ) -> Tuple[Tensor, Tensor]:
29
+ B, Nt, E = q.shape
30
+ q = q / math.sqrt(E)
31
+ # B: bsz * total_num_heads
32
+ # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
33
+ attn = torch.bmm(q, k.transpose(-2, -1))
34
+ if attn_mask is not None:
35
+ attn += attn_mask
36
+ attn = softmax(attn, dim=-1)
37
+ if dropout_p > 0.0:
38
+ attn = dropout(attn, p=dropout_p)
39
+ if subset_heads is None:
40
+ # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
41
+ output = torch.bmm(attn, v)
42
+ else:
43
+ mixed_output = torch.bmm(attn, v).contiguous().view(bsz, -1, Nt, E)
44
+ output = torch.stack(
45
+ [mixed_output[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))],
46
+ dim=1
47
+ )
48
+ output = output * subset_weights.unsqueeze(2).unsqueeze(3)
49
+ output = output.contiguous().view(-1, Nt, E)
50
+ if subset_heads is not None:
51
+ _, Nt, Ns = attn.size()
52
+ mixed_attn = attn.view(bsz, -1, Nt, Ns)
53
+ attn = torch.stack(
54
+ [mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
55
+ )
56
+ return output, attn
57
+
58
+
59
+ def _in_projection(
60
+ q: Tensor,
61
+ k: Tensor,
62
+ v: Tensor,
63
+ w_q: Tensor,
64
+ w_k: Tensor,
65
+ w_v: Tensor,
66
+ b_q: Optional[Tensor] = None,
67
+ b_k: Optional[Tensor] = None,
68
+ b_v: Optional[Tensor] = None,
69
+ ) -> Tuple[Tensor, Tensor, Tensor]:
70
+ return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
71
+
72
+
73
+ def multi_head_attention_forward(
74
+ query: Tensor,
75
+ key: Tensor,
76
+ value: Tensor,
77
+ embed_dim_to_check: int,
78
+ total_num_heads: int,
79
+ num_heads: int,
80
+ in_proj_weight: Tensor,
81
+ in_proj_bias: Optional[Tensor],
82
+ bias_k: Optional[Tensor],
83
+ bias_v: Optional[Tensor],
84
+ add_zero_attn: bool,
85
+ dropout_p: float,
86
+ out_proj_weight: Tensor,
87
+ out_proj_bias: Optional[Tensor],
88
+ training: bool = True,
89
+ key_padding_mask: Optional[Tensor] = None,
90
+ need_weights: bool = True,
91
+ attn_mask: Optional[Tensor] = None,
92
+ use_separate_proj_weight: bool = False,
93
+ q_proj_weight: Optional[Tensor] = None,
94
+ k_proj_weight: Optional[Tensor] = None,
95
+ v_proj_weight: Optional[Tensor] = None,
96
+ static_k: Optional[Tensor] = None,
97
+ static_v: Optional[Tensor] = None,
98
+ subset_heads: Optional[Tensor] = None,
99
+ subset_weights: Optional[Tensor] = None,
100
+ ):
101
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
102
+ if has_torch_function(tens_ops):
103
+ return handle_torch_function(
104
+ multi_head_attention_forward,
105
+ tens_ops,
106
+ query,
107
+ key,
108
+ value,
109
+ embed_dim_to_check,
110
+ total_num_heads,
111
+ num_heads,
112
+ in_proj_weight,
113
+ in_proj_bias,
114
+ bias_k,
115
+ bias_v,
116
+ add_zero_attn,
117
+ dropout_p,
118
+ out_proj_weight,
119
+ out_proj_bias,
120
+ training=training,
121
+ key_padding_mask=key_padding_mask,
122
+ need_weights=need_weights,
123
+ attn_mask=attn_mask,
124
+ use_separate_proj_weight=use_separate_proj_weight,
125
+ q_proj_weight=q_proj_weight,
126
+ k_proj_weight=k_proj_weight,
127
+ v_proj_weight=v_proj_weight,
128
+ static_k=static_k,
129
+ static_v=static_v,
130
+ subset_heads=subset_heads,
131
+ subset_weights=subset_weights
132
+ )
133
+
134
+ # set up shape vars
135
+ tgt_len, bsz, embed_dim = query.shape
136
+ src_len, _, _ = key.shape
137
+ assert embed_dim == embed_dim_to_check, \
138
+ f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
139
+ if isinstance(embed_dim, torch.Tensor):
140
+ # embed_dim can be a tensor when JIT tracing
141
+ head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
142
+ else:
143
+ head_dim = embed_dim // num_heads
144
+ assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
145
+ if use_separate_proj_weight:
146
+ # allow MHA to have different embedding dimensions when separate projection weights are used
147
+ assert key.shape[:2] == value.shape[:2], \
148
+ f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
149
+ else:
150
+ assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
151
+
152
+ #
153
+ # compute in-projection
154
+ #
155
+ if not use_separate_proj_weight:
156
+ q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
157
+ else:
158
+ assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
159
+ assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
160
+ assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
161
+ if in_proj_bias is None:
162
+ b_q = b_k = b_v = None
163
+ else:
164
+ b_q, b_k, b_v = in_proj_bias.chunk(3)
165
+ q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)
166
+
167
+ # prep attention mask
168
+ if attn_mask is not None:
169
+ if attn_mask.dtype == torch.uint8:
170
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
171
+ attn_mask = attn_mask.to(torch.bool)
172
+ else:
173
+ assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
174
+ f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
175
+ # ensure attn_mask's dim is 3
176
+ if attn_mask.dim() == 2:
177
+ correct_2d_size = (tgt_len, src_len)
178
+ if attn_mask.shape != correct_2d_size:
179
+ raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
180
+ attn_mask = attn_mask.unsqueeze(0)
181
+ elif attn_mask.dim() == 3:
182
+ correct_3d_size = (bsz * total_num_heads, tgt_len, src_len)
183
+ if attn_mask.shape != correct_3d_size:
184
+ raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
185
+ else:
186
+ raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
187
+
188
+ # prep key padding mask
189
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
190
+ warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
191
+ key_padding_mask = key_padding_mask.to(torch.bool)
192
+
193
+ # add bias along batch dimension (currently second)
194
+ if bias_k is not None and bias_v is not None:
195
+ assert static_k is None, "bias cannot be added to static key."
196
+ assert static_v is None, "bias cannot be added to static value."
197
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
198
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
199
+ if attn_mask is not None:
200
+ attn_mask = pad(attn_mask, (0, 1))
201
+ if key_padding_mask is not None:
202
+ key_padding_mask = pad(key_padding_mask, (0, 1))
203
+ else:
204
+ assert bias_k is None
205
+ assert bias_v is None
206
+
207
+ #
208
+ # reshape q, k, v for multihead attention and make em batch first
209
+ #
210
+ q = q.contiguous().view(tgt_len, bsz * total_num_heads, head_dim).transpose(0, 1)
211
+ if static_k is None:
212
+ k = k.contiguous().view(k.shape[0], bsz * total_num_heads, head_dim).transpose(0, 1)
213
+ else:
214
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
215
+ assert static_k.size(0) == bsz * total_num_heads, \
216
+ f"expecting static_k.size(0) of {bsz * total_num_heads}, but got {static_k.size(0)}"
217
+ assert static_k.size(2) == head_dim, \
218
+ f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
219
+ k = static_k
220
+ if static_v is None:
221
+ v = v.contiguous().view(v.shape[0], bsz * total_num_heads, head_dim).transpose(0, 1)
222
+ else:
223
+ # TODO finish disentangling control flow so we don't do in-projections when statics are passed
224
+ assert static_v.size(0) == bsz * total_num_heads, \
225
+ f"expecting static_v.size(0) of {bsz * total_num_heads}, but got {static_v.size(0)}"
226
+ assert static_v.size(2) == head_dim, \
227
+ f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
228
+ v = static_v
229
+
230
+ # add zero attention along batch dimension (now first)
231
+ if add_zero_attn:
232
+ zero_attn_shape = (bsz * total_num_heads, 1, head_dim)
233
+ k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
234
+ v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
235
+ if attn_mask is not None:
236
+ attn_mask = pad(attn_mask, (0, 1))
237
+ if key_padding_mask is not None:
238
+ key_padding_mask = pad(key_padding_mask, (0, 1))
239
+
240
+ # update source sequence length after adjustments
241
+ src_len = k.size(1)
242
+
243
+ # merge key padding and attention masks
244
+ if key_padding_mask is not None:
245
+ assert key_padding_mask.shape == (bsz, src_len), \
246
+ f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
247
+ key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
248
+ expand(-1, total_num_heads, -1, -1).reshape(bsz * total_num_heads, 1, src_len)
249
+ if attn_mask is None:
250
+ attn_mask = key_padding_mask
251
+ elif attn_mask.dtype == torch.bool:
252
+ attn_mask = attn_mask.logical_or(key_padding_mask)
253
+ else:
254
+ attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
255
+
256
+ # convert mask to float
257
+ if attn_mask is not None and attn_mask.dtype == torch.bool:
258
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
259
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
260
+ attn_mask = new_attn_mask
261
+
262
+ # adjust dropout probability
263
+ if not training:
264
+ dropout_p = 0.0
265
+
266
+ #
267
+ # (deep breath) calculate attention and out projection
268
+ #
269
+ attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, bsz, subset_heads, subset_weights)
270
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
271
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
272
+
273
+ if need_weights:
274
+ # average attention weights over heads
275
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
276
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
277
+ else:
278
+ return attn_output, None
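Near the end of `multi_head_attention_forward`, the key padding mask is expanded to one row per (sentence, head) and folded into an additive float mask. The following self-contained check (made-up shapes, not fairseq code) shows that conversion: padded positions end up as `-inf`, everything else as `0.0`.

```python
import torch

bsz, total_heads, src_len = 2, 4, 3
key_padding_mask = torch.tensor([[False, False, True],
                                 [False, True, True]])

# Expand to one row per (sentence, head), as done before merging with attn_mask.
expanded = (
    key_padding_mask.view(bsz, 1, 1, src_len)
    .expand(-1, total_heads, -1, -1)
    .reshape(bsz * total_heads, 1, src_len)
)

# Boolean mask -> additive float mask: masked positions become -inf.
float_mask = torch.zeros_like(expanded, dtype=torch.float)
float_mask.masked_fill_(expanded, float("-inf"))

assert float_mask.shape == (bsz * total_heads, 1, src_len)
assert torch.isinf(float_mask[0, 0, 2]) and float_mask[0, 0, 0] == 0.0
```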