Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- .gitattributes +2 -0
- fairseq/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml +31 -0
- fairseq/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml +49 -0
- fairseq/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml +49 -0
- fairseq/examples/MMPT/projects/task/coin.yaml +25 -0
- fairseq/examples/MMPT/projects/task/coin_videoclip.yaml +7 -0
- fairseq/examples/MMPT/projects/task/test.yaml +13 -0
- fairseq/examples/MMPT/projects/task/test_vtt.yaml +19 -0
- fairseq/examples/MMPT/projects/task/test_youcook.yaml +22 -0
- fairseq/examples/MMPT/projects/task/test_youcookcap.yaml +23 -0
- fairseq/examples/MMPT/projects/task/vtt.yaml +25 -0
- fairseq/examples/MMPT/projects/task/vtt_videoclip.yaml +12 -0
- fairseq/examples/MMPT/projects/task/vttqa_videoclip.yaml +10 -0
- fairseq/examples/MMPT/projects/task/youcook.yaml +25 -0
- fairseq/examples/MMPT/projects/task/youcook_videoclip.yaml +9 -0
- fairseq/examples/MMPT/projects/task/youcookcap.yaml +23 -0
- fairseq/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml +5 -0
- fairseq/examples/MMPT/scripts/text_token_extractor/pretokenization.py +106 -0
- fairseq/examples/MMPT/scripts/video_feature_extractor/extract.py +157 -0
- fairseq/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh +8 -0
- fairseq/examples/MMPT/scripts/video_feature_extractor/model.py +58 -0
- fairseq/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py +89 -0
- fairseq/examples/MMPT/scripts/video_feature_extractor/preprocessing.py +57 -0
- fairseq/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py +29 -0
- fairseq/examples/MMPT/scripts/video_feature_extractor/shard_feature.py +64 -0
- fairseq/examples/MMPT/scripts/video_feature_extractor/videoreader.py +242 -0
- fairseq/examples/MMPT/videoclip.png +3 -0
- fairseq/examples/MMPT/vlm.png +3 -0
- fairseq/examples/adaptive_span/README.md +90 -0
- fairseq/examples/adaptive_span/__init__.py +19 -0
- fairseq/examples/adaptive_span/adagrad_with_grad_clip.py +128 -0
- fairseq/examples/adaptive_span/adaptive_span_attention.py +160 -0
- fairseq/examples/adaptive_span/adaptive_span_loss.py +107 -0
- fairseq/examples/adaptive_span/adaptive_span_model.py +263 -0
- fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py +145 -0
- fairseq/examples/adaptive_span/truncated_bptt_lm_task.py +285 -0
- fairseq/examples/attention_head_selection/README.md +161 -0
- fairseq/examples/attention_head_selection/src/__init__.py +0 -0
- fairseq/examples/attention_head_selection/src/data/__init__.py +0 -0
- fairseq/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py +242 -0
- fairseq/examples/attention_head_selection/src/loss/__init__.py +0 -0
- fairseq/examples/attention_head_selection/src/loss/attention_head_selection.py +27 -0
- fairseq/examples/attention_head_selection/src/models/__init__.py +0 -0
- fairseq/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py +170 -0
- fairseq/examples/attention_head_selection/src/models/head_selection_transformer.py +215 -0
- fairseq/examples/attention_head_selection/src/modules/__init__.py +0 -0
- fairseq/examples/attention_head_selection/src/modules/attn_head_selector.py +81 -0
- fairseq/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py +92 -0
- fairseq/examples/attention_head_selection/src/modules/multihead_attention_selection.py +355 -0
- fairseq/examples/attention_head_selection/src/modules/multihead_functional.py +278 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+fairseq/examples/MMPT/vlm.png filter=lfs diff=lfs merge=lfs -text
+fairseq/examples/MMPT/videoclip.png filter=lfs diff=lfs merge=lfs -text
fairseq/examples/MMPT/projects/retri/videoclip/test_didemo_zs.yaml
ADDED
@@ -0,0 +1,31 @@
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: VideoProcessor
  aligner: DiDeMoAligner
  bert_name: bert-base-uncased
  meta_processor: DiDeMoMetaProcessor
  test_path: data/didemo/test_data.json
  vfeat_dir: data/feat/feat_didemo_s3d
  text_processor: DiDeMoTextProcessor
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  dataset:
    batch_size: 256
    valid_subset: test
    num_workers: 2
  common_eval:
    path: runs/retri/videoclip/checkpoint_best.pt
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
eval:
  save_path: runs/retri/videoclip/didemo_zs/eval
metric: DiDeMoMetric
predictor: DiDeMoPredictor
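
A config like the one above is consumed by MMPT's launcher rather than run directly. The exact entry point and flags are not part of this diff, so the following is only a hypothetical sketch, assuming MMPT's `locallaunch.py` script and a `--jobtype` flag, run from `fairseq/examples/MMPT` with the pretrained checkpoint already placed at `runs/retri/videoclip/checkpoint_best.pt`:

```bash
# Hypothetical invocation; locallaunch.py and --jobtype are assumptions, not shown in this diff.
cd fairseq/examples/MMPT
python locallaunch.py projects/retri/videoclip/test_didemo_zs.yaml --jobtype local_predict
```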
fairseq/examples/MMPT/projects/retri/videoclip/vttqa_videoclip.yaml
ADDED
@@ -0,0 +1,49 @@
dataset:
  video_processor: VideoProcessor
  bert_name: bert-base-uncased
  meta_processor: MSRVTTMetaProcessor
  train_path: data/msrvtt/MSRVTT_train.csv
  dup: 20
  val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTTextProcessor
  json_path: data/msrvtt/MSRVTT_data.json
  aligner: DSAligner
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  common:
    tensorboard_logdir: run
    log_interval: 1000
    fp16: true
  dataset:
    num_workers: 4
    batch_size: 128
  optimization:
    lr:
    - 5.0e-05
    clip_norm: 2.0
    optimizer: adam
    adam_betas: (0.9, 0.98)
    lr_scheduler: polynomial_decay
    total_num_update: 1000000
    warmup_updates: 122
    weight_decay: 0.0
    ddp_backend: no_c10d
    max_epoch: 5
  checkpoint:
    restore_file: runs/retri/videoclip/checkpoint_best.pt
    reset_optimizer: true
    reset_dataloader: true
    reset_meters: true
    save_dir: runs/retri/videoclip/vttqa
task_type: sweep_small
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
loss:
  loss_cls: V2TContraLoss
fairseq/examples/MMPT/projects/retri/videoclip/youcook_videoclip.yaml
ADDED
@@ -0,0 +1,49 @@
dataset:
  video_processor: YoucookVideoProcessor
  bert_name: bert-base-uncased
  meta_processor: YoucookMetaProcessor
  train_path: data/youcook/youcook_train.pkl
  val_path: data/youcook/youcook_val.pkl
  trainval_annotation: data/youcook/youcookii_annotations_trainval.json
  use_annotation_text: true
  vfeat_dir: data/feat/feat_youcook_s3d
  text_processor: TextProcessor
  aligner: DSAligner
  num_iso_layer: 12
  max_video_len: 32
  max_len: 96
fairseq:
  common:
    tensorboard_logdir: run
    log_interval: 1000
    fp16: true
  dataset:
    num_workers: 4
    batch_size: 128
  optimization:
    lr:
    - 5.0e-05
    clip_norm: 2.0
    optimizer: adam
    adam_betas: (0.9, 0.98)
    lr_scheduler: polynomial_decay
    total_num_update: 1000000
    warmup_updates: 122
    weight_decay: 0.0
    ddp_backend: no_c10d
    max_epoch: 10
  checkpoint:
    restore_file: runs/retri/videoclip/checkpoint_best.pt
    reset_optimizer: true
    reset_dataloader: true
    reset_meters: true
    save_dir: runs/retri/videoclip/youcook
task_type: sweep_small
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls: null
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
loss:
  loss_cls: T2VContraLoss
fairseq/examples/MMPT/projects/task/coin.yaml
ADDED
@@ -0,0 +1,25 @@
includes: projects/task/ft.yaml
task_type: sweep_big
dataset:
  meta_processor: COINActionSegmentationMetaProcessor
  train_path: data/coin/COIN.json
  val_path: data/coin/COIN.json
  vfeat_dir: data/feat/feat_coin_s3d
  video_processor: VideoProcessor
  text_processor: COINActionSegmentationTextProcessor
  aligner: COINActionSegmentationAligner
  num_iso_layer: 12
  sliding_window: 8
  sliding_window_size: 32
model:
  model_cls: MMFusionActionSegmentation
  mm_encoder_cls: MMBertForTokenClassification
loss:
  loss_cls: CrossEntropy
fairseq:
  dataset:
    batch_size: 1
  optimization:
    max_epoch: 8
  checkpoint:
    save_dir: runs/task/coin
fairseq/examples/MMPT/projects/task/coin_videoclip.yaml
ADDED
@@ -0,0 +1,7 @@
includes: projects/task/coin.yaml
model:
  model_cls: MMFusionSeparateActionSegmentation
  mm_encoder_cls:
  video_encoder_cls: MMBertForTokenClassification
  text_encoder_cls: BertModel # dummy, not used.
  num_hidden_video_layers: 6
fairseq/examples/MMPT/projects/task/test.yaml
ADDED
@@ -0,0 +1,13 @@
# this yaml cannot be run alone: implement a test_${dataset}.yaml
slurm_config: big
task_type: local_predict
dataset:
  split: test
  video_processor: VideoProcessor
  aligner: DSAligner
  bert_name: bert-base-uncased
fairseq:
  dataset:
    batch_size: 256
    valid_subset: test
    num_workers: 2
fairseq/examples/MMPT/projects/task/test_vtt.yaml
ADDED
@@ -0,0 +1,19 @@
includes: projects/task/test.yaml
dataset:
  meta_processor: MSRVTTMetaProcessor
  test_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
  video_processor: VideoProcessor
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTTextProcessor
  num_iso_layer: 12
model:
  model_cls: MMFusionJoint
  mm_encoder_cls: MMBertForJoint
eval:
  save_path: runs/task/vtt/eval
fairseq:
  # read code and find what is the checkpoint arg.
  common_eval:
    path: runs/task/vtt/checkpoint_last.pt
metric: RetrievalMetric
predictor: RetrievalPredictor
fairseq/examples/MMPT/projects/task/test_youcook.yaml
ADDED
@@ -0,0 +1,22 @@
includes: projects/task/test.yaml
dataset:
  meta_processor: YoucookMetaProcessor
  test_path: data/youcook/youcook_val.pkl
  trainval_annotation: data/youcook/youcookii_annotations_trainval.json
  use_annotation_text: True
  video_processor: YoucookVideoProcessor
  vfeat_dir: data/feat/feat_youcook_s3d # /checkpoint/huxu/feat/youcook_vmz # /checkpoint/prarora/berniehuang/feat_youcook_vmz
  text_processor: TextProcessor
  aligner: DSAligner
  num_iso_layer: 12
model:
  model_cls: MMFusionJoint
  mm_encoder_cls: MMBertForJoint
eval:
  save_path: runs/task/youcook/eval
fairseq:
  # read code and find what is the checkpoint arg.
  common_eval:
    path: runs/task/youcook/checkpoint_last.pt
metric: RetrievalMetric
predictor: RetrievalPredictor
fairseq/examples/MMPT/projects/task/test_youcookcap.yaml
ADDED
@@ -0,0 +1,23 @@
includes: projects/task/test.yaml
dataset:
  meta_processor: YoucookNLGMetaProcessor
  test_path: data/youcook/val_list.txt
  trainval_annotation: data/youcook/youcookii_annotations_trainval.json
  video_processor: YoucookVideoProcessor
  vfeat_dir: data/feat/feat_youcook_s3d
  text_processor: NLGTextProcessor
  aligner: DSNLGAligner
model:
  model_cls: MMFusionNLG
  mm_encoder_cls: MMBertForNLG
  max_decode_length: 24
eval:
  save_path: runs/task/youcookcap/eval
fairseq:
  # read code and find what is the checkpoint arg.
  common_eval:
    path: runs/task/youcookcap/checkpoint_best.pt
metric: NLGMetric
predictor: NLGPredictor
gen_param:
  num_beams: 5
fairseq/examples/MMPT/projects/task/vtt.yaml
ADDED
@@ -0,0 +1,25 @@
includes: projects/task/ft.yaml
dataset:
  meta_processor: MSRVTTMetaProcessor
  train_path: data/msrvtt/MSRVTT_train.csv
  jsfusion_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
  full_test_path: data/msrvtt/MSRVTT_FULL_test.csv
  dup: 20
  val_path: data/msrvtt/MSRVTT_JSFUSION_test.csv
  vfeat_dir: data/feat/feat_vtt_s3d
  text_processor: MSRVTTTextProcessor
  json_path: data/msrvtt/MSRVTT_data.json
  aligner: DSAligner
  num_iso_layer: 12
model:
  model_cls: MMFusionJoint
  mm_encoder_cls: MMBertForJoint
loss:
  loss_cls: T2VContraLoss
fairseq:
  dataset:
    batch_size: 256
  optimization:
    max_epoch: 10
  checkpoint:
    save_dir: runs/task/vtt
fairseq/examples/MMPT/projects/task/vtt_videoclip.yaml
ADDED
@@ -0,0 +1,12 @@
includes: projects/task/vtt.yaml
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls:
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
fairseq:
  dataset:
    batch_size: 224
#   model_cls: MMFusionShare
#   mm_encoder_cls: MMBertForEncoder
fairseq/examples/MMPT/projects/task/vttqa_videoclip.yaml
ADDED
@@ -0,0 +1,10 @@
includes: projects/task/vttqa.yaml
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls:
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6

#   model_cls: MMFusionShare
#   mm_encoder_cls: MMBertForEncoder
fairseq/examples/MMPT/projects/task/youcook.yaml
ADDED
@@ -0,0 +1,25 @@
includes: projects/task/ft.yaml
dataset:
  meta_processor: YoucookMetaProcessor
  train_path: data/youcook/youcook_train.pkl
  val_path: data/youcook/youcook_val.pkl
  trainval_annotation: data/youcook/youcookii_annotations_trainval.json
  use_annotation_text: True
  video_processor: YoucookVideoProcessor
  vfeat_dir: data/feat/feat_youcook_s3d # /checkpoint/huxu/feat/youcook_vmz # /checkpoint/prarora/berniehuang/feat_youcook_vmz
  text_processor: TextProcessor
  aligner: DSAligner
  num_iso_layer: 12
model:
  model_cls: MMFusionJoint
  mm_encoder_cls: MMBertForJoint
loss:
  loss_cls: T2VContraLoss
fairseq:
  dataset:
    batch_size: 128
  optimization:
    max_epoch: 10
  checkpoint:
    save_dir: runs/task/youcook
fairseq/examples/MMPT/projects/task/youcook_videoclip.yaml
ADDED
@@ -0,0 +1,9 @@
includes: projects/task/youcook.yaml
model:
  model_cls: MMFusionSeparate
  mm_encoder_cls:
  video_encoder_cls: MMBertForEncoder
  text_encoder_cls: BertModel
  num_hidden_video_layers: 6
#   model_cls: MMFusionShare
#   mm_encoder_cls: MMBertForEncoder
fairseq/examples/MMPT/projects/task/youcookcap.yaml
ADDED
@@ -0,0 +1,23 @@
# finetuning for youcook captioning.
includes: projects/task/ft.yaml
dataset:
  meta_processor: YoucookNLGMetaProcessor
  train_path: data/youcook/train_list.txt
  val_path: data/youcook/val_list.txt
  trainval_annotation: data/youcook/youcookii_annotations_trainval.json
  video_processor: YoucookVideoProcessor
  vfeat_dir: data/feat/feat_youcook_s3d
  text_processor: NLGTextProcessor
  aligner: DSNLGAligner
model:
  model_cls: MMFusionNLG
  mm_encoder_cls: MMBertForNLG
loss:
  loss_cls: NLGLoss
fairseq:
  dataset:
    batch_size: 128
  optimization:
    max_epoch: 10
  checkpoint:
    save_dir: runs/task/youcookcap
fairseq/examples/MMPT/scripts/text_token_extractor/configs/bert-base-uncased.yaml
ADDED
@@ -0,0 +1,5 @@
dataset:
  bert_name: bert-base-uncased
  caption_pkl_path: data/how2/raw_caption_dedup.pkl
  use_fast: true
  target_dir: data/feat/feat_how2_s3d_shard_small
fairseq/examples/MMPT/scripts/text_token_extractor/pretokenization.py
ADDED
@@ -0,0 +1,106 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pickle
import os
import argparse
import numpy as np

from torch.utils.data import Dataset, DataLoader
from mmpt.processors import PKLJSONStrTextProcessor
from mmpt.utils import ShardedTensor, recursive_config


class TokenizerDataset(Dataset):
    def __init__(self, config):
        self.text_processor = PKLJSONStrTextProcessor(config)
        self.video_ids = list(self.text_processor.data.keys())

    def __getitem__(self, idx):
        video_id = self.video_ids[idx]
        return video_id, self.text_processor(video_id)

    def __len__(self):
        return len(self.video_ids)


def numpify(shard_idx, video_ids, captions, target_dir, split, prefix, max_cap_len=32):
    startends = []
    caps_ids = []
    for video_id in video_ids:
        caption = captions[video_id]
        startend = []
        cap_ids = []
        for start, end, cap in zip(
                caption["start"], caption["end"], caption["cap"]):
            startend.append(np.array([start, end]).astype("float32"))
            cap_id = np.full((max_cap_len,), -1, dtype=np.int32)
            cap = cap[:max_cap_len]
            cap_id[:len(cap)] = cap
            cap_ids.append(cap_id)
        startends.append(np.stack(startend))
        caps_ids.append(np.stack(cap_ids))

    startends = ShardedTensor.from_list(startends)
    target_path = os.path.join(
        target_dir,
        prefix + split + "_" + str(shard_idx)
    )
    print("save to", target_path)
    startends.save(target_path + ".startends")
    caps_ids = ShardedTensor.from_list(caps_ids)
    caps_ids.save(target_path + ".caps_ids")


def sharding(config, out_file):
    with open(out_file, "rb") as fr:
        captions = pickle.load(fr)
    target_dir = config.target_dir
    prefix = os.path.basename(
        os.path.splitext(config.caption_pkl_path)[0]
    ) + "." + config.bert_name + "."
    for split in ["train", "val"]:
        target_path = os.path.join(target_dir, split + "_meta")
        with open(target_path + ".pkl", "rb") as fr:
            meta = pickle.load(fr)
        print("load meta", target_path, len(meta))
        for shard_id in meta:
            numpify(
                shard_id, meta[shard_id], captions,
                target_dir, split, prefix
            )


def tokenize(config, out_file):
    def collator(samples):
        return samples
    dataset = TokenizerDataset(config)
    data = {}
    for idx, batch in enumerate(
            DataLoader(dataset, collate_fn=collator, num_workers=16)):
        for video_id, caption in batch:
            data[video_id] = caption
        if idx % 5000 == 0:
            print(idx)
    with open(out_file, "wb") as fw:
        pickle.dump(data, fw, pickle.HIGHEST_PROTOCOL)


def main(args):
    config = recursive_config(args.config).dataset

    out_file = os.path.splitext(config.caption_pkl_path)[0] \
        + "." + config.bert_name + ".pkl"
    if not os.path.isfile(out_file):
        tokenize(config, out_file)
    sharding(config, out_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="pretokenize (raw_)caption.json into pkl.")
    parser.add_argument('config', type=str)
    args = parser.parse_args()
    main(args)
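
The script above takes a single positional config argument (see its argparse block), and the `bert-base-uncased.yaml` added earlier in this diff provides the fields it reads (`caption_pkl_path`, `bert_name`, `target_dir`). A minimal invocation sketch, assuming the working directory is `fairseq/examples/MMPT`:

```bash
# Tokenize the how2 captions and shard them; the config path is the one added in this diff.
cd fairseq/examples/MMPT
python scripts/text_token_extractor/pretokenization.py \
    scripts/text_token_extractor/configs/bert-base-uncased.yaml
```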
fairseq/examples/MMPT/scripts/video_feature_extractor/extract.py
ADDED
@@ -0,0 +1,157 @@
# Copyright Howto100M authors.
# Copyright (c) Facebook, Inc. All Rights Reserved

import torch as th
import torch.nn.functional as F
import math
import numpy as np
import argparse

from torch.utils.data import DataLoader
from model import get_model
from preprocessing import Preprocessing
from random_sequence_shuffler import RandomSequenceSampler

from tqdm import tqdm
from pathbuilder import PathBuilder
from videoreader import VideoLoader


parser = argparse.ArgumentParser(description='Easy video feature extractor')

parser.add_argument('--vdir', type=str)
parser.add_argument('--fdir', type=str)
parser.add_argument('--hflip', type=int, default=0)

parser.add_argument('--batch_size', type=int, default=64,
                    help='batch size')
parser.add_argument('--type', type=str, default='2d',
                    help='CNN type')
parser.add_argument('--half_precision', type=int, default=0,
                    help='output half precision float')
parser.add_argument('--num_decoding_thread', type=int, default=4,
                    help='Num parallel thread for video decoding')
parser.add_argument('--l2_normalize', type=int, default=1,
                    help='l2 normalize feature')
parser.add_argument('--resnext101_model_path', type=str, default='model/resnext101.pth',
                    help='Resnext model path')
parser.add_argument('--vmz_model_path', type=str, default='model/r2plus1d_34_clip8_ig65m_from_scratch-9bae36ae.pth',
                    help='vmz model path')

args = parser.parse_args()


# TODO: refactor all args into config. (current code is from different people.)
CONFIGS = {
    "2d": {
        "fps": 1,
        "size": 224,
        "centercrop": False,
        "shards": 0,
    },
    "3d": {
        "fps": 24,
        "size": 112,
        "centercrop": True,
        "shards": 0,
    },
    "s3d": {
        "fps": 30,
        "size": 224,
        "centercrop": True,
        "shards": 0,
    },
    "vmz": {
        "fps": 24,
        "size": 112,
        "centercrop": True,
        "shards": 0,
    },
    "vae": {
        "fps": 2,
        "size": 256,
        "centercrop": True,
        "shards": 100,
    }
}

config = CONFIGS[args.type]


video_dirs = args.vdir
feature_dir = args.fdir

video_dict = PathBuilder.build(video_dirs, feature_dir, ".npy", config["shards"])

dataset = VideoLoader(
    video_dict=video_dict,
    framerate=config["fps"],
    size=config["size"],
    centercrop=config["centercrop"],
    hflip=args.hflip
)
n_dataset = len(dataset)
sampler = RandomSequenceSampler(n_dataset, 10)
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=args.num_decoding_thread,
    sampler=sampler if n_dataset > 10 else None,
)
preprocess = Preprocessing(args.type)
model = get_model(args)

with th.no_grad():
    for k, data in tqdm(enumerate(loader), total=loader.__len__(), ascii=True):
        input_file = data['input'][0]
        output_file = data['output'][0]
        if len(data['video'].shape) > 3:
            video = data['video'].squeeze()
            if len(video.shape) == 4:
                video = preprocess(video)
                n_chunk = len(video)
                if args.type == 'vmz':
                    n_chunk = math.ceil(n_chunk/float(3))
                    features = th.cuda.FloatTensor(n_chunk, 512).fill_(0)
                elif args.type == 's3d':
                    features = th.cuda.FloatTensor(n_chunk, 512).fill_(0)
                elif args.type == "vae":
                    features = th.cuda.LongTensor(n_chunk, 1024).fill_(0)
                else:
                    features = th.cuda.FloatTensor(n_chunk, 2048).fill_(0)
                n_iter = int(math.ceil(n_chunk / float(args.batch_size)))
                for i in range(n_iter):
                    factor = 1
                    if args.type == 'vmz':
                        factor = 3
                    min_ind = factor * i * args.batch_size
                    max_ind = factor * (i + 1) * args.batch_size
                    video_batch = video[min_ind:max_ind:factor].cuda()
                    if args.type == '2d':
                        batch_features = model(video_batch)  # (51, 487), (51, 512)
                    elif args.type == 's3d':
                        batch_features = model(video_batch)
                        batch_features = batch_features['video_embedding']
                    elif args.type == "vae":
                        # image_code.
                        batch_features = model(video_batch)
                    else:
                        batch_pred, batch_features = model(video_batch)  # (51, 487), (51, 512)
                    if args.l2_normalize:
                        batch_features = F.normalize(batch_features, dim=1)
                    features[i*args.batch_size:(i+1)*args.batch_size] = batch_features
                features = features.cpu().numpy()
                if args.half_precision:
                    if args.type == "vae":
                        features = features.astype(np.int16)
                    else:
                        features = features.astype('float16')
                else:
                    if args.type == "vae":
                        features = features.astype(np.int32)
                    else:
                        features = features.astype('float32')
                np.save(output_file, features)
        else:
            print('Video {} error.'.format(input_file))
fairseq/examples/MMPT/scripts/video_feature_extractor/how2/s3d.sh
ADDED
@@ -0,0 +1,8 @@
#!/bin/bash


python scripts/video_feature_extractor/extract.py \
    --vdir <path_to_video_folder> \
    --fdir data/feat/feat_how2_s3d \
    --type=s3d --num_decoding_thread=4 \
    --batch_size 32 --half_precision 1
fairseq/examples/MMPT/scripts/video_feature_extractor/model.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) Howto100M authors and Facebook, Inc. All Rights Reserved

import torch as th

from torch import nn


class GlobalAvgPool(nn.Module):
    def __init__(self):
        super(GlobalAvgPool, self).__init__()

    def forward(self, x):
        return th.mean(x, dim=[-2, -1])


def get_model(args):
    assert args.type in ['2d', '3d', 'vmz', 's3d', 'vae']
    if args.type == '2d':
        print('Loading 2D-ResNet-152 ...')
        import torchvision.models as models
        model = models.resnet152(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-2], GlobalAvgPool())
        model = model.cuda()
    elif args.type == 'vmz':
        print('Loading VMZ ...')
        from vmz34 import r2plus1d_34
        model = r2plus1d_34(pretrained_path=args.vmz_model_path, pretrained_num_classes=487)
        model = model.cuda()
    elif args.type == 's3d':
        # we use one copy of s3d instead of dup another one for feature extraction.
        from mmpt.processors.models.s3dg import S3D
        model = S3D('pretrained_models/s3d_dict.npy', 512)
        model.load_state_dict(th.load('pretrained_models/s3d_howto100m.pth'))
        model = model.cuda()

    elif args.type == '3d':
        print('Loading 3D-ResneXt-101 ...')
        from videocnn.models import resnext
        model = resnext.resnet101(
            num_classes=400,
            shortcut_type='B',
            cardinality=32,
            sample_size=112,
            sample_duration=16,
            last_fc=False)
        model = model.cuda()
        model_data = th.load(args.resnext101_model_path)
        model.load_state_dict(model_data)
    elif args.type == 'vae':
        from openaivae import OpenAIParallelDiscreteVAE
        model = OpenAIParallelDiscreteVAE()
        model = model.cuda()
    else:
        raise ValueError("model not supported yet.")

    model.eval()
    print('loaded')
    return model
fairseq/examples/MMPT/scripts/video_feature_extractor/pathbuilder.py
ADDED
@@ -0,0 +1,89 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import urllib.parse
import json
import pandas as pd

from tqdm import tqdm


# TODO: extending to other datasets.
supported_formats = {}


class PathBuilder(object):
    @classmethod
    def build(cls, video_dirs, feature_dir, ext, shards=0, split=None):
        meta_fn = os.path.join(feature_dir, "meta_plan.json")
        os.makedirs(feature_dir, exist_ok=True)
        if os.path.isfile(meta_fn):
            with open(meta_fn) as fr:
                meta = json.load(fr)
            return meta
        print("searching videos...")

        video_id_to_path = {}
        for video_dir in video_dirs.split(","):
            # TODO: add supports of recursive listdir.
            if video_dir in supported_formats:
                supported_formats[video_dir].load(video_dir, video_id_to_path)
            else:
                for idx, fn in enumerate(tqdm(os.listdir(video_dir))):
                    video_fn = os.path.join(video_dir, fn)
                    if os.path.isfile(video_fn):
                        video_id = os.path.splitext(fn)[0]
                        video_id_to_path[video_id] = video_fn
                    elif os.path.isdir(video_fn):
                        # shards of folders.
                        shard_dir = video_fn
                        for idx, fn in enumerate(os.listdir(shard_dir)):
                            video_fn = os.path.join(shard_dir, fn)
                            if os.path.isfile(video_fn):
                                video_id = os.path.splitext(fn)[0]
                                video_id_to_path[video_id] = video_fn

        video_path, feature_path = [], []
        valid_ext = set()
        for idx, video_id in enumerate(video_id_to_path):
            video_path.append(video_id_to_path[video_id])
            if ext is None:
                # use original file ext for format compatibility.
                video_id_to_path[video_id]
                path = urllib.parse.urlparse(video_id_to_path[video_id]).path
                ext = os.path.splitext(path)[1]
            if ext not in valid_ext:
                valid_ext.add(ext)
                print("adding", ext)
            if shards:
                shard_id = str(idx % shards)
                feature_fn = os.path.join(
                    feature_dir, shard_id, video_id + ext)
            else:
                feature_fn = os.path.join(
                    feature_dir, video_id + ext)
            feature_path.append(feature_fn)

        print("targeting", len(feature_path), "videos")
        meta = {
            "video_path": video_path, "feature_path": feature_path}
        with open(meta_fn, "w") as fw:
            json.dump(meta, fw)

        if split is not None:
            splits = split.split("/")
            assert len(splits) == 2
            cur, total = int(splits[0]), int(splits[1])
            assert cur < total
            import math
            chunk = math.ceil(len(meta["video_path"]) / total)
            start = cur * chunk
            end = (cur + 1) * chunk
            meta = {
                "video_path": meta["video_path"][start:end],
                "feature_path": meta["feature_path"][start:end]
            }

        return meta
fairseq/examples/MMPT/scripts/video_feature_extractor/preprocessing.py
ADDED
@@ -0,0 +1,57 @@
# Copyright Howto100m authors.
# Copyright (c) Facebook, Inc. All Rights Reserved

import torch as th


class Normalize(object):

    def __init__(self, mean, std):
        self.mean = th.FloatTensor(mean).view(1, 3, 1, 1)
        self.std = th.FloatTensor(std).view(1, 3, 1, 1)

    def __call__(self, tensor):
        tensor = (tensor - self.mean) / (self.std + 1e-8)
        return tensor


class Preprocessing(object):

    def __init__(self, type):
        self.type = type
        if type == '2d':
            self.norm = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        elif type == '3d':
            self.norm = Normalize(mean=[110.6, 103.2, 96.3], std=[1.0, 1.0, 1.0])
        elif type == 'vmz':
            self.norm = Normalize(mean=[110.201, 100.64, 95.997], std=[58.1489, 56.4701, 55.3324])

    def _zero_pad(self, tensor, size):
        n = size - len(tensor) % size
        if n == size:
            return tensor
        else:
            z = th.zeros(n, tensor.shape[1], tensor.shape[2], tensor.shape[3])
            return th.cat((tensor, z), 0)

    def __call__(self, tensor):
        if self.type == '2d':
            tensor = tensor / 255.0
            tensor = self.norm(tensor)
        elif self.type == 'vmz':
            #tensor = self._zero_pad(tensor, 8)
            tensor = self._zero_pad(tensor, 10)
            tensor = self.norm(tensor)
            #tensor = tensor.view(-1, 8, 3, 112, 112)
            tensor = tensor.view(-1, 10, 3, 112, 112)
            tensor = tensor.transpose(1, 2)
        elif self.type == '3d':
            tensor = self._zero_pad(tensor, 16)
            tensor = self.norm(tensor)
            tensor = tensor.view(-1, 16, 3, 112, 112)
            tensor = tensor.transpose(1, 2)
        elif self.type == 's3d':
            tensor = tensor / 255.0
            tensor = self._zero_pad(tensor, 30)
            tensor = tensor.view(-1, 30, 3, 224, 224)  # N x 30 x 3 x H x W
            tensor = tensor.transpose(1, 2)  # N x 3 x 30 x H x W
        # for vae do nothing
        return tensor
fairseq/examples/MMPT/scripts/video_feature_extractor/random_sequence_shuffler.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. All Rights Reserved

import numpy as np

from torch.utils.data.sampler import Sampler


class RandomSequenceSampler(Sampler):

    def __init__(self, n_sample, seq_len):
        self.n_sample = n_sample
        self.seq_len = seq_len

    def _pad_ind(self, ind):
        zeros = np.zeros(self.seq_len - self.n_sample % self.seq_len)
        ind = np.concatenate((ind, zeros))
        return ind

    def __iter__(self):
        idx = np.arange(self.n_sample)
        if self.n_sample % self.seq_len != 0:
            idx = self._pad_ind(idx)
        idx = np.reshape(idx, (-1, self.seq_len))
        np.random.shuffle(idx)
        idx = np.reshape(idx, (-1))
        return iter(idx.astype(int))

    def __len__(self):
        return self.n_sample + (self.seq_len - self.n_sample % self.seq_len)
fairseq/examples/MMPT/scripts/video_feature_extractor/shard_feature.py
ADDED
@@ -0,0 +1,64 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import os
import pickle

from mmpt.utils import ShardedTensor


class Shard(object):
    def __init__(
        self,
        vfeat_dir,
        tfeat_dir,
        target_dir,
        file_paths,
        shard_size=4096
    ):
        self.vfeat_dir = vfeat_dir
        self.tfeat_dir = tfeat_dir
        self.target_dir = target_dir
        self.video_ids = {}
        for split, file_path in zip(["train", "val"], file_paths):
            with open(file_path) as fr:
                self.video_ids[split] = [
                    line.strip() for line in fr.readlines()]
        self.shard_size = shard_size

    def __call__(self, split="train"):
        for split in ["train", "val"]:
            meta = {}
            for shard_idx, shard_offset in enumerate(
                range(0, len(self.video_ids[split]), self.shard_size)
            ):
                print(shard_idx)
                meta_shard = []
                video_shard = []
                for video_id in self.video_ids[split][shard_offset:shard_offset+self.shard_size]:
                    meta_shard.append(video_id)
                    npy_file = os.path.join(self.vfeat_dir, video_id + ".npy")
                    video_shard.append(np.load(npy_file))

                meta[shard_idx] = meta_shard
                video_shard = ShardedTensor.from_list(video_shard)
                target_path = os.path.join(
                    self.target_dir, split + "_" + str(shard_idx))
                video_shard.save(target_path)

            target_path = os.path.join(self.target_dir, split + "_meta")
            with open(target_path + ".pkl", "wb") as fw:
                pickle.dump(meta, fw, pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    shard = Shard(
        "data/feat/feat_how2_s3d",
        "data/how2/raw_caption_dedup.bert-base-uncased",
        "data/feat/feat_how2_s3d_shard_small",
        ["data/how2/how2_s3d_train.lst", "data/how2/how2_s3d_val.lst"]
    )

    shard()
fairseq/examples/MMPT/scripts/video_feature_extractor/videoreader.py
ADDED
@@ -0,0 +1,242 @@
# Copyright Howto100M authors.
# Copyright (c) Facebook, Inc. All Rights Reserved

import torch as th
import pandas as pd
import os
import numpy as np
import ffmpeg
import random

from torch.utils.data import Dataset


class VideoLoader(Dataset):
    """modified from how2's video_feature_extractor."""
    def __init__(
        self,
        csv=None,
        video_dict=None,
        framerate=1,
        size=112,
        centercrop=False,
        hflip=False,
        **kwargs
    ):
        if csv is None and video_dict is None:
            raise ValueError("csv and video_dict cannot be both None.")
        if csv is not None:
            self.csv = pd.read_csv(csv)
        if video_dict is not None:
            self.csv = pd.DataFrame.from_dict(video_dict)

        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate
        self.hflip = hflip

    def __len__(self):
        return len(self.csv)

    def _get_video_dim(self, video_path):
        probe = ffmpeg.probe(video_path)
        video_stream = next((stream for stream in probe['streams']
                             if stream['codec_type'] == 'video'), None)
        width = int(video_stream['width'])
        height = int(video_stream['height'])
        return height, width

    def _get_video_info(self, video_path):
        probe = ffmpeg.probe(video_path)
        video_stream = next((stream for stream in probe['streams']
                             if stream['codec_type'] == 'video'), None)
        return video_stream

    def _get_output_dim(self, h, w):
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)

    def __getitem__(self, idx):
        video_path = self.csv['video_path'].values[idx]
        output_file = self.csv['feature_path'].values[idx]
        return self._decode(output_file, video_path)

    def _decode(self, output_file, video_path):
        if not(os.path.isfile(output_file)) and os.path.isfile(video_path):
            try:
                h, w = self._get_video_dim(video_path)
            except Exception:
                print('ffprobe failed at: {}'.format(video_path))
                return {'video': th.zeros(1), 'input': video_path,
                        'output': output_file}
            try:
                os.makedirs(os.path.dirname(output_file), exist_ok=True)
                height, width = self._get_output_dim(h, w)

                cmd = (
                    ffmpeg
                    .input(video_path)
                    .filter('fps', fps=self.framerate)
                    .filter('scale', width, height)
                )
                if self.hflip:
                    cmd = cmd.filter('hflip')

                if self.centercrop:
                    x = int((width - self.size) / 2.0)
                    y = int((height - self.size) / 2.0)
                    cmd = cmd.crop(x, y, self.size, self.size)
                video = self._run(cmd, output_file)
            except Exception:
                video = th.zeros(1)
        else:
            video = th.zeros(1)

        return {'video': video, 'input': video_path, 'output': output_file}

    def _run(self, cmd, output_file):
        out, _ = (
            cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
            .run(capture_stdout=True, quiet=True)
        )
        if self.centercrop and isinstance(self.size, int):
            height, width = self.size, self.size
        video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
        video = th.from_numpy(video.astype('float32'))
        return video.permute(0, 3, 1, 2)


class VideoVerifier(VideoLoader):
    def __getitem__(self, idx):
        video_path = self.csv['video_path'].values[idx]
        try:
            return self._get_video_info(video_path)
        except Exception:
            # print('ffprobe failed at: {}'.format(video_path))
            return None


class VideoCompressor(VideoLoader):
    def __init__(
        self,
        csv=None,
        video_dict=None,
        framerate=1,
        size=112,
        centercrop=False,
        hflip=False,
        crf=32,
        **kwargs
    ):
        super().__init__(
            csv,
            video_dict,
            framerate,
            size,
            centercrop,
            hflip
        )
        self.crf = crf

    def _run(self, cmd, output_file):
        out, _ = (
            cmd.output(filename=output_file, crf=self.crf)
            .run(quiet=True)
        )
        video = None
        return video


class VideoDownloader(VideoCompressor):
    """download"""
    def __getitem__(self, idx):
        video_path = self.csv['video_path'].values[idx]
        output_file = self.csv['feature_path'].values[idx]
        if not(os.path.isfile(output_file)):
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
            cmd = "wget -O" + output_file + " " + video_path
            # import subprocess
            # subprocess.check_output(
            #     cmd,
            #     stderr=subprocess.STDOUT, shell=True)
            os.system(cmd)
        return {'video': None, 'input': video_path, 'output': output_file}


class AvKeyframeVideoCompressor(VideoLoader):
    """extract keyframes from a video and save it as jpg.
    TODO: consider to merge with `CodecProcessor`.
    """
    def __init__(
        self,
        csv=None,
        video_dict=None,
        framerate=1,
        size=112,
        centercrop=False,
        max_num_frames=5,
        **kwargs
    ):
        super().__init__(csv, video_dict, framerate, size, centercrop)
        self.max_num_frames = max_num_frames

    def _get_video_dim(self, video_fn):
        """decord cannot probe the size of a video, we use pyav instead."""
        import av
        with av.open(video_fn) as container:
            height = container.streams.video[0].codec_context.height
            width = container.streams.video[0].codec_context.width
        return height, width

    def _get_output_dim(self, height, width):
        """
        keep the shorter side be `self.size`, strech the other.
        """
        if height >= width:
            return int(height * self.size / width), self.size
        else:
            return self.size, int(width * self.size / height)

    def __getitem__(self, idx):
        import av
        video_path = self.csv['video_path'].values[idx]
        output_file = self.csv['feature_path'].values[idx]
        if not(os.path.isdir(output_file)) and os.path.isfile(video_path):
            try:
                h, w = self._get_video_dim(video_path)
            except Exception:
                print('probe failed at: {}'.format(video_path))
                return {'video': th.zeros(1), 'input': video_path,
                        'output': output_file}

            try:
                height, width = self._get_output_dim(h, w)

                # new for av.
                with av.open(video_path) as container:
                    container.streams.video[0].thread_type = "AUTO"
                    container.streams.video[0].codec_context.height = height
                    container.streams.video[0].codec_context.width = width
                    if self.framerate == 0:  # keyframe.
                        container.streams.video[0].codec_context.skip_frame = 'NONKEY'
                    frames = []
                    for frame in container.decode(video=0):
                        frames.append(frame)
                    frames = random.sample(frames, self.max_num_frames)

                os.makedirs(output_file, exist_ok=True)
                for frame in frames:
                    frame.to_image().save(
                        os.path.join(
                            output_file,
                            "%04d.jpg" % frame.index))
            except Exception:
                print('extract failed at: {}'.format(video_path))
                return {'video': th.zeros(1), 'input': video_path,
                        'output': output_file}
        video = th.zeros(1)
        return {'video': video, 'input': video_path, 'output': output_file}
fairseq/examples/MMPT/videoclip.png
ADDED
(binary image, tracked with Git LFS)
fairseq/examples/MMPT/vlm.png
ADDED
(binary image, tracked with Git LFS)
fairseq/examples/adaptive_span/README.md
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adaptive Span
|
2 |
+
|
3 |
+
Adaptive Span is a novel self-attention mechanism that can learn its optimal
|
4 |
+
attention span. This allows us to extend significantly the maximum context size
|
5 |
+
used in Transformer, while maintaining control over their memory footprint
|
6 |
+
and computational time. It uses the Truncated BPTT technique for training,
|
7 |
+
as in [transformerXL](https://github.com/pytorch/fairseq/blob/main/examples/truncated_bptt/README.md).
|
8 |
+
|
9 |
+
Adaptive Span was introduced by paper:
|
10 |
+
[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799),
|
11 |
+
which achieved state-of-the-art language modeling results at the time of publication.
|
12 |
+
|
13 |
+
We manage to reproduce their result in fairseq and keep most of the
|
14 |
+
[original implementation](https://github.com/facebookresearch/adaptive-span) untouched.
|
15 |
+
You can refer to the their sweep file as well if any combination of hyperparameter is not clear.
|
16 |
+
|
17 |
+
##### 0. Setup
|
18 |
+
|
19 |
+
First you need to process the Enwik8 dataset, we use the pre-tokenized dataset
|
20 |
+
from [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh).
|
21 |
+
You can download the dataset, and then run:
|
22 |
+
```bash
|
23 |
+
fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \
    --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \
    --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20
```

##### 1. Train an Adaptive Span model on Enwik8

We will train a 12-layer Adaptive Span model following the [hyperparameters
used in the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).

The following command assumes 4 GPUs, so that the total batch size is 64
sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/adaptive_span \
    --data ~/data/enwik8/data-bin/ \
    --fp16 --fp16-no-flatten-grads --max-update 600000 \
    --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \
    --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \
    --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \
    --validate-interval-updates 1000 \
    --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \
    --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \
    --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07
```
This should land around 1.05 bpc on validation and 1.03 bpc on test. You can
lower `--aux-loss-scaler` for better performance (a longer span); this model
gives a ~0.03 bpc improvement over the Transformer-XL baseline here.
If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients
and simulate training on 4 GPUs.

You can also reproduce the Transformer-XL result on enwik8 using this code
base. It should land around 1.06 on test, matching the [original
paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh).
You can try it with:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
    --user-dir examples/truncated_bptt \
    ~/data/enwik8/data-bin/ \
    --task truncated_bptt_lm --fp16 --max-update 400000 \
    --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \
    --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \
    --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \
    --lr-scheduler cosine --warmup-updates 0 \
    --lr 0.0 --lr 0.00025 --batch-size 15 \
    --update-freq 1 --seed 2 --log-format json --log-interval 25
```

##### 2. Evaluate
For Adaptive Span:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
    --user-dir examples/adaptive_span \
    --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test
```
For Transformer-XL evaluation:
```bash
fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \
    --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \
    --tokens-per-sample 80 \
    --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \
    --gen-subset valid
```

*Note:* During training the model saw 512 tokens of context
(``--tokens-per-sample=512``), with batch size 8. These settings match the
evaluation settings from [the original
paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh).
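As a quick sanity check on the batch-size arithmetic above (it is easy to get wrong when moving between 1 and 4 GPUs), here is a minimal sketch of the relationship already stated in this README: 4 GPUs x batch size 16 x update-freq 1, or 1 GPU x 16 x update-freq 4, both give 64 sequences per optimizer step. The function name is only for illustration.

```python
def effective_batch_size(n_gpus: int, batch_size: int, update_freq: int) -> int:
    """Number of sequences contributing to each optimizer step."""
    return n_gpus * batch_size * update_freq


# 4-GPU recipe vs. the single-GPU variant with --update-freq=4
assert effective_batch_size(n_gpus=4, batch_size=16, update_freq=1) == 64
assert effective_batch_size(n_gpus=1, batch_size=16, update_freq=4) == 64
```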
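The bpc numbers quoted above come directly from the logged cross-entropy: the `adaptive_span_loss` criterion (included later in this diff) divides the loss by `math.log(2)` before logging, so on a character-level corpus like enwik8 the logged `loss` is already bits per character. A minimal conversion sketch, using an illustrative value rather than a real log entry:

```python
import math

loss_base2 = 1.05            # illustrative value read from the json training log
bpc = loss_base2             # base-2 cross-entropy == bits per character on enwik8
nats = bpc * math.log(2)     # the same loss expressed in nats (base e)
char_ppl = 2 ** bpc          # character-level perplexity implied by that bpc
print(f"bpc={bpc:.2f} nats={nats:.3f} ppl={char_ppl:.3f}")
```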
fairseq/examples/adaptive_span/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os

# automatically import any Python files in the current directory
cur_dir = os.path.dirname(__file__)
for file in os.listdir(cur_dir):
    path = os.path.join(cur_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        mod_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module(__name__ + "." + mod_name)
fairseq/examples/adaptive_span/adagrad_with_grad_clip.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.optim import Adagrad

from fairseq.optim import LegacyFairseqOptimizer, register_optimizer


@register_optimizer("adagrad_with_grad_clip")
class FairseqAdagradWithGradClip(LegacyFairseqOptimizer):
    def __init__(self, args, params):
        super().__init__(args)
        self._optimizer = AdagradWithGradClip(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                            help='weight decay')
        parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D',
                            help='internal grad clip')
        # fmt: on

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.args.lr[0],
            "weight_decay": self.args.weight_decay,
            "grad_clip": self.args.adagrad_clip,
        }

    @property
    def supports_flat_params(self):
        return False


def _clip_grad(clr, grad, group_grad_clip):
    if group_grad_clip > 0:
        norm = grad.norm(2).item()
        if norm > group_grad_clip:
            clr *= group_grad_clip / (norm + 1e-10)
    return clr


class AdagradWithGradClip(Adagrad):
    """Adagrad algorithm with custom gradient clipping"""

    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        grad_clip=0,
    ):
        Adagrad.__init__(
            self,
            params,
            lr=lr,
            lr_decay=lr_decay,
            weight_decay=weight_decay,
            initial_accumulator_value=initial_accumulator_value,
        )
        self.defaults["grad_clip"] = grad_clip
        self.param_groups[0].setdefault("grad_clip", grad_clip)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad.data
                state = self.state[p]

                state["step"] += 1

                if group["weight_decay"] != 0:
                    if p.grad.data.is_sparse:
                        raise RuntimeError(
                            "weight_decay option is "
                            "not compatible with sparse "
                            "gradients"
                        )
                    grad = grad.add(group["weight_decay"], p.data)

                clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"])

                # clip
                clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"])

                if grad.is_sparse:
                    # the update is non-linear so indices must be unique
                    grad = grad.coalesce()
                    grad_indices = grad._indices()
                    grad_values = grad._values()
                    size = grad.size()

                    def make_sparse(values):
                        constructor = grad.new
                        if grad_indices.dim() == 0 or values.dim() == 0:
                            return constructor().resize_as_(grad)
                        return constructor(grad_indices, values, size)

                    state["sum"].add_(make_sparse(grad_values.pow(2)))
                    std = state["sum"]._sparse_mask(grad)
                    std_values = std._values().sqrt_().add_(1e-10)
                    p.data.add_(-clr, make_sparse(grad_values / std_values))
                else:
                    state["sum"].addcmul_(1, grad, grad)
                    std = state["sum"].sqrt().add_(1e-10)
                    p.data.addcdiv_(-clr, grad, std)

        return loss
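To make the clipping rule above concrete, here is a tiny sketch of `_clip_grad` in isolation: when the gradient's L2 norm exceeds `--adagrad-clip`, the per-parameter learning rate is scaled down proportionally instead of rescaling the gradient itself. The import path is an assumption (it presumes the fairseq repo root is on `PYTHONPATH`), and the numbers are illustrative only.

```python
import torch

from examples.adaptive_span.adagrad_with_grad_clip import _clip_grad  # assumed path

grad = torch.full((10,), 0.5)      # ||grad||_2 = sqrt(10 * 0.25) ~= 1.58 > 0.03
clr = _clip_grad(clr=0.07, grad=grad, group_grad_clip=0.03)
print(clr)                         # ~= 0.07 * 0.03 / 1.58 ~= 0.0013 (lr shrunk)

small = torch.full((10,), 0.001)   # ||grad||_2 ~= 0.0032 < 0.03, so no scaling
print(_clip_grad(clr=0.07, grad=small, group_grad_clip=0.03))  # 0.07 unchanged
```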
fairseq/examples/adaptive_span/adaptive_span_attention.py
ADDED
@@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveMask(nn.Module):
    """Soft masking function for adaptive size.
    It masks out the last K values of an input. The masking value
    goes from 1 to 0 gradually, so K can be learned with
    back-propagation.
    Args:
        max_size: maximum size (i.e. input dimension)
        ramp_size: size of the ramp going from 0 to 1
        init_val: initial size proportion not to be masked out
        shape: learn multiple sizes independent of each other
    """

    def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
        nn.Module.__init__(self)
        self._max_size = max_size
        self._ramp_size = ramp_size
        self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
        mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
        self.register_buffer("mask_template", mask_template)

    def forward(self, x):
        mask = self.mask_template.float() + self.current_val.float() * self._max_size
        mask = mask / self._ramp_size + 1
        mask = mask.clamp(0, 1)
        if x.size(-1) < self._max_size:
            # the input could have been trimmed beforehand to save computation
            mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1))
        x = (x * mask).type_as(x)
        return x

    def get_current_max_size(self, include_ramp=True):
        current_size = math.ceil(self.current_val.max().item() * self._max_size)
        if include_ramp:
            current_size += self._ramp_size
        current_size = max(0, min(self._max_size, current_size))
        return current_size

    def get_current_avg_size(self, include_ramp=True):
        current_size = math.ceil(
            self.current_val.float().mean().item() * self._max_size
        )
        if include_ramp:
            current_size += self._ramp_size
        current_size = max(0, min(self._max_size, current_size))
        return current_size

    def clamp_param(self):
        """this needs to be called after each update"""
        self.current_val.data.clamp_(0, 1)


class AdaptiveSpan(nn.Module):
    """Adaptive attention span for the Transformer.
    This module learns an attention span length from data for each
    self-attention head.
    Args:
        attn_span: maximum attention span
        adapt_span_loss: loss coefficient for the span length
        adapt_span_ramp: length of the masking ramp
        adapt_span_init: initial size ratio
        adapt_span_cache: adapt cache size to reduce memory usage
    """

    def __init__(
        self,
        attn_span,
        adapt_span_ramp,
        adapt_span_init,
        n_head,
        adapt_span_layer,
        **kargs
    ):
        nn.Module.__init__(self)
        self._max_span = attn_span
        self._n_head = n_head
        self._adapt_span_layer = adapt_span_layer
        if self._adapt_span_layer:
            self._mask = AdaptiveMask(
                max_size=self._max_span,
                ramp_size=adapt_span_ramp,
                init_val=adapt_span_init,
            )
        else:
            self._mask = AdaptiveMask(
                max_size=self._max_span,
                ramp_size=adapt_span_ramp,
                init_val=adapt_span_init,
                shape=(n_head, 1, 1),
            )

    def forward(self, attn, normalize=True):
        """mask attention with the right span"""
        # batch and head dimensions are merged together, so separate them first
        self.clamp_param()
        if self._adapt_span_layer:
            attn = self._mask(attn)
        else:
            B = attn.size(0)  # batch size
            M = attn.size(1)  # block size
            attn = attn.reshape(B // self._n_head, self._n_head, M, -1)
            attn = self._mask(attn)
            attn = attn.view(B, M, -1)
        return attn

    def get_trim_len(self):
        """how much of memory can be trimmed to reduce computation"""
        L = self._max_span
        trim_len = min(L - 1, L - self._mask.get_current_max_size())
        # too fine granularity might be bad for the memory management
        trim_len = math.floor(trim_len / 64) * 64
        return trim_len

    def trim_memory(self, query, key, value, key_pe):
        """trim out unnecessary memory beforehand to reduce computation"""
        trim_len = self.get_trim_len()
        cache_size = key.size(1) - query.size(1)
        trim_len_cache = trim_len - (self._max_span - cache_size)
        if trim_len_cache > 0:
            key = key[:, trim_len_cache:, :]
            value = value[:, trim_len_cache:, :]
        elif trim_len_cache < 0:
            # cache is too short! this happens when validation resumes
            # after a lot of updates.
            key = F.pad(key, [0, 0, -trim_len_cache, 0])
            value = F.pad(value, [0, 0, -trim_len_cache, 0])
        if trim_len > 0:
            if key_pe is not None:
                key_pe = key_pe[:, :, trim_len:]
        return key, value, key_pe

    def get_cache_size(self):
        """determine how long the cache should be"""
        trim_len = self.get_trim_len()
        # give a buffer of 64 steps since a span might increase
        # in future updates
        return min(self._max_span, self._max_span - trim_len + 64)

    def get_loss(self):
        """a loss term for regularizing the span length"""
        return self._max_span * self._mask.current_val.float().mean()

    def get_current_max_span(self):
        return self._mask.get_current_max_size()

    def get_current_avg_span(self):
        return self._mask.get_current_avg_size()

    def clamp_param(self):
        self._mask.clamp_param()
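As a quick illustration of the soft mask above, the following sketch instantiates `AdaptiveMask` directly and applies it to a row of uniform attention weights. With `init_val=0.5` and a ramp of 4, the oldest positions are down-weighted smoothly rather than cut off, which is what keeps the span parameter differentiable. The import path assumes the fairseq repo root is on `PYTHONPATH`; the sizes are toy values.

```python
import torch

from examples.adaptive_span.adaptive_span_attention import AdaptiveMask  # assumed path

mask = AdaptiveMask(max_size=8, ramp_size=4, init_val=0.5)
attn = torch.ones(1, 1, 8)  # uniform weights over 8 past positions
with torch.no_grad():
    print(mask(attn))
# tensor([[[0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]]])
# The oldest positions (left) are ramped toward zero; increasing current_val
# widens the effective span, decreasing it narrows the span.
```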
fairseq/examples/adaptive_span/adaptive_span_loss.py
ADDED
@@ -0,0 +1,107 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass

import torch.nn.functional as F
from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import register_criterion
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.dataclass import FairseqDataclass
from omegaconf import II


@dataclass
class AdaptiveSpanCriterionConfig(FairseqDataclass):
    sentence_avg: bool = II("optimization.sentence_avg")


@register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig)
class AdaptiveSpanCriterion(CrossEntropyCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task, sentence_avg)

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss here is summed, different from the adaptive span code
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        loss, aux_loss, avg_span, max_span = self.compute_loss(
            model, net_output, sample, reduce=reduce
        )
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        loss /= sample_size
        total_loss = loss + aux_loss
        sample_size = 1

        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            "total_loss": total_loss.data,
            "avg_span": avg_span * sample_size,
            "max_span": max_span * sample_size,
        }
        return total_loss, sample_size, logging_output

    def compute_loss(self, model, net_output, sample, reduce=True):
        loss, _ = super().compute_loss(model, net_output, sample, reduce)
        aux_loss = model.get_aux_loss()
        avg_span = model.get_current_avg_span()
        max_span = model.get_current_max_span()
        return loss, aux_loss, avg_span, max_span

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs)
        avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs)
        max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs)

        # we divide by log(2) to convert the loss from base e to base 2
        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3)
        metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3)
        # total loss contains the L1 norm on adaptive-span
        metrics.log_scalar(
            "total_loss",
            total_loss_sum / sample_size / math.log(2),
            sample_size,
            round=3,
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return True
fairseq/examples/adaptive_span/adaptive_span_model.py
ADDED
@@ -0,0 +1,263 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
from fairseq.modules.layer_norm import LayerNorm
|
14 |
+
|
15 |
+
from .adaptive_span_attention import AdaptiveSpan
|
16 |
+
|
17 |
+
# Size notations:
|
18 |
+
# B = batch_size, H = d_model, M = block_size, L = attn_span
|
19 |
+
|
20 |
+
|
21 |
+
def _skew(X, pad_value):
|
22 |
+
"""shift every row 1 step to right"""
|
23 |
+
# X = B x M x L
|
24 |
+
B, M, L = X.size()
|
25 |
+
X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1)
|
26 |
+
X = X.view(B, -1) # B x ML+MM+M
|
27 |
+
X = X[:, :-M] # B x ML+MM
|
28 |
+
X = X.view(B, M, M + L) # B x M x L+M
|
29 |
+
return X
|
30 |
+
|
31 |
+
|
32 |
+
def _unskew(X):
|
33 |
+
"""reverse _skew operation"""
|
34 |
+
# X = B x M x L+M
|
35 |
+
B, M, L = X.size()
|
36 |
+
L -= M
|
37 |
+
X = X.view(B, -1) # B x ML+MM
|
38 |
+
X = F.pad(X, (0, M)) # B x ML+MM+M
|
39 |
+
X = X.view(B, M, M + L + 1) # B x M x L+M+1
|
40 |
+
X = X[:, :, :L] # B x M x L
|
41 |
+
return X
|
42 |
+
|
43 |
+
|
44 |
+
class SeqAttention(nn.Module):
|
45 |
+
"""Sequential self-attention layer.
|
46 |
+
Each token will attend to its previous fixed number of steps.
|
47 |
+
Note that attention doesn't include the current step itself.
|
48 |
+
"""
|
49 |
+
|
50 |
+
def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs):
|
51 |
+
nn.Module.__init__(self)
|
52 |
+
self.dropout = nn.Dropout(dropout)
|
53 |
+
self.d_model = d_model # size of a single head
|
54 |
+
self.attn_span = attn_span
|
55 |
+
self.adaptive_span = AdaptiveSpan(
|
56 |
+
attn_span=attn_span,
|
57 |
+
n_head=n_head,
|
58 |
+
adapt_span_layer=adapt_span_layer,
|
59 |
+
**kargs
|
60 |
+
)
|
61 |
+
|
62 |
+
def forward(self, query, key, value, key_pe):
|
63 |
+
# query size = B x M x H
|
64 |
+
# key, value sizes = B x (M+L) x H
|
65 |
+
|
66 |
+
key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe)
|
67 |
+
|
68 |
+
# compute attention from context
|
69 |
+
# B x M (dest) x (M+L) (src)
|
70 |
+
attn_cont = torch.matmul(query, key.transpose(-1, -2))
|
71 |
+
attn_cont = _unskew(attn_cont) # B x M x L
|
72 |
+
|
73 |
+
# compute the effect of position embedding
|
74 |
+
attn_pos = torch.matmul(query, key_pe) # B x M x L_pos
|
75 |
+
attn = attn_cont + attn_pos
|
76 |
+
|
77 |
+
attn = attn / math.sqrt(self.d_model) # B x M X L_pos
|
78 |
+
|
79 |
+
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
|
80 |
+
|
81 |
+
# trim attention lengths according to the learned span
|
82 |
+
attn = self.adaptive_span(attn)
|
83 |
+
|
84 |
+
attn = self.dropout(attn) # B x M X L_pos
|
85 |
+
|
86 |
+
attn_cont = _skew(attn, 0) # B x M X (L+M)
|
87 |
+
out = torch.matmul(attn_cont, value) # B x M x H
|
88 |
+
return out
|
89 |
+
|
90 |
+
def get_cache_size(self):
|
91 |
+
return self.adaptive_span.get_cache_size()
|
92 |
+
|
93 |
+
|
94 |
+
class MultiHeadSeqAttention(nn.Module):
|
95 |
+
def __init__(self, d_model, n_head, **kargs):
|
96 |
+
nn.Module.__init__(self)
|
97 |
+
assert d_model % n_head == 0
|
98 |
+
self.n_head = n_head
|
99 |
+
self.head_dim = d_model // n_head
|
100 |
+
self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs)
|
101 |
+
self.proj_query = nn.Linear(d_model, d_model, bias=False)
|
102 |
+
nn.init.xavier_normal_(self.proj_query.weight)
|
103 |
+
self.proj_out = nn.Linear(d_model, d_model, bias=False)
|
104 |
+
nn.init.xavier_normal_(self.proj_out.weight)
|
105 |
+
self.proj_val = nn.Linear(d_model, d_model, bias=False)
|
106 |
+
nn.init.xavier_normal_(self.proj_val.weight)
|
107 |
+
self.proj_key = nn.Linear(d_model, d_model, bias=False)
|
108 |
+
nn.init.xavier_normal_(self.proj_key.weight)
|
109 |
+
|
110 |
+
def head_reshape(self, x):
|
111 |
+
K = self.n_head
|
112 |
+
D = self.head_dim
|
113 |
+
x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D
|
114 |
+
x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D
|
115 |
+
x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D
|
116 |
+
return x
|
117 |
+
|
118 |
+
def forward(self, query, key, value, key_pe):
|
119 |
+
B = query.size(0)
|
120 |
+
K = self.n_head
|
121 |
+
D = self.head_dim
|
122 |
+
M = query.size(1)
|
123 |
+
|
124 |
+
query = self.proj_query(query)
|
125 |
+
query = self.head_reshape(query)
|
126 |
+
value = self.proj_val(value)
|
127 |
+
value = self.head_reshape(value)
|
128 |
+
key = self.proj_key(key)
|
129 |
+
key = self.head_reshape(key)
|
130 |
+
|
131 |
+
out = self.attn(query, key, value, key_pe) # B_K x M x D
|
132 |
+
out = out.view(B, K, M, D) # B x K x M x D
|
133 |
+
out = out.transpose(1, 2).contiguous() # B x M x K x D
|
134 |
+
out = out.view(B, M, -1) # B x M x K_D
|
135 |
+
out = self.proj_out(out)
|
136 |
+
return out
|
137 |
+
|
138 |
+
|
139 |
+
class FeedForwardLayer(nn.Module):
|
140 |
+
def __init__(self, d_model, d_inner, dropout, **kargs):
|
141 |
+
nn.Module.__init__(self)
|
142 |
+
self.fc1 = nn.Linear(d_model, d_inner)
|
143 |
+
self.fc2 = nn.Linear(d_inner, d_model)
|
144 |
+
nn.init.xavier_uniform_(self.fc1.weight)
|
145 |
+
nn.init.xavier_uniform_(self.fc2.weight)
|
146 |
+
self.dropout = nn.Dropout(dropout)
|
147 |
+
|
148 |
+
def forward(self, h):
|
149 |
+
h1 = F.relu(self.fc1(h))
|
150 |
+
h1 = self.dropout(h1)
|
151 |
+
h2 = self.fc2(h1)
|
152 |
+
return h2
|
153 |
+
|
154 |
+
|
155 |
+
class TransformerSeqLayer(nn.Module):
|
156 |
+
def __init__(self, d_model, **kargs):
|
157 |
+
nn.Module.__init__(self)
|
158 |
+
self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs)
|
159 |
+
self.norm1 = LayerNorm(d_model)
|
160 |
+
self.ff = FeedForwardLayer(d_model=d_model, **kargs)
|
161 |
+
self.norm2 = LayerNorm(d_model)
|
162 |
+
|
163 |
+
def forward(self, h, h_cache, key_pe):
|
164 |
+
# h = B x M x H
|
165 |
+
# h_cache = B x L x H
|
166 |
+
h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H
|
167 |
+
attn_out = self.attn(h, h_all, h_all, key_pe)
|
168 |
+
h = self.norm1(h + attn_out) # B x M x H
|
169 |
+
if self.ff is not None:
|
170 |
+
ff_out = self.ff(h)
|
171 |
+
out = self.norm2(h + ff_out) # B x M x H
|
172 |
+
else:
|
173 |
+
out = h
|
174 |
+
return out
|
175 |
+
|
176 |
+
def get_cache_size(self):
|
177 |
+
return self.attn.attn.get_cache_size()
|
178 |
+
|
179 |
+
|
180 |
+
class TransformerSeq(nn.Module):
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
vocab_size,
|
184 |
+
d_model,
|
185 |
+
n_head,
|
186 |
+
n_layer,
|
187 |
+
attn_span,
|
188 |
+
emb_dropout,
|
189 |
+
aux_loss_scaler,
|
190 |
+
adapt_span_layer,
|
191 |
+
**kargs
|
192 |
+
):
|
193 |
+
nn.Module.__init__(self)
|
194 |
+
# token embeddings
|
195 |
+
self.in_emb = nn.Embedding(vocab_size, d_model)
|
196 |
+
nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5)
|
197 |
+
self.out_emb = nn.Linear(d_model, vocab_size)
|
198 |
+
self.aux_loss_scaler = aux_loss_scaler
|
199 |
+
if emb_dropout > 0:
|
200 |
+
self.emb_dropout = nn.Dropout(emb_dropout)
|
201 |
+
else:
|
202 |
+
self.emb_dropout = None
|
203 |
+
# position embeddings
|
204 |
+
self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span))
|
205 |
+
|
206 |
+
self.layers = nn.ModuleList()
|
207 |
+
self.layers.extend(
|
208 |
+
TransformerSeqLayer(
|
209 |
+
d_model=d_model,
|
210 |
+
n_head=n_head,
|
211 |
+
attn_span=attn_span,
|
212 |
+
adapt_span_layer=adapt_span_layer,
|
213 |
+
**kargs
|
214 |
+
)
|
215 |
+
for _ in range(n_layer)
|
216 |
+
)
|
217 |
+
|
218 |
+
def forward(self, x, h_cache, target=None):
|
219 |
+
# x size = B x M
|
220 |
+
block_size = x.size(1)
|
221 |
+
h = self.in_emb(x) # B x M x H
|
222 |
+
if self.emb_dropout is not None:
|
223 |
+
h = self.emb_dropout(h)
|
224 |
+
|
225 |
+
h_cache_next = []
|
226 |
+
for l, layer in enumerate(self.layers):
|
227 |
+
cache_size = layer.attn.attn.get_cache_size()
|
228 |
+
if cache_size > block_size:
|
229 |
+
h_cache_next_l = torch.cat(
|
230 |
+
[h_cache[l][:, -cache_size + block_size :, :], h], dim=1
|
231 |
+
).detach()
|
232 |
+
else:
|
233 |
+
h_cache_next_l = h[:, -cache_size:, :].detach()
|
234 |
+
h_cache_next.append(h_cache_next_l)
|
235 |
+
h = layer(h, h_cache[l], self.key_pe) # B x M x H
|
236 |
+
|
237 |
+
if self.emb_dropout is not None:
|
238 |
+
h = self.emb_dropout(h)
|
239 |
+
|
240 |
+
out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h)
|
241 |
+
dummy_loss = None
|
242 |
+
|
243 |
+
return out, h_cache_next, dummy_loss
|
244 |
+
|
245 |
+
def get_aux_loss(self):
|
246 |
+
loss = 0.0
|
247 |
+
for layer in self.layers:
|
248 |
+
loss += layer.attn.attn.adaptive_span.get_loss()
|
249 |
+
return self.aux_loss_scaler * loss
|
250 |
+
|
251 |
+
def get_current_max_span(self):
|
252 |
+
max_span = 0.0
|
253 |
+
for layer in self.layers:
|
254 |
+
max_span = max(
|
255 |
+
max_span, layer.attn.attn.adaptive_span.get_current_max_span()
|
256 |
+
)
|
257 |
+
return max_span
|
258 |
+
|
259 |
+
def get_current_avg_span(self):
|
260 |
+
avg_span = 0.0
|
261 |
+
for layer in self.layers:
|
262 |
+
avg_span += layer.attn.attn.adaptive_span.get_current_avg_span()
|
263 |
+
return avg_span / len(self.layers)
|
fairseq/examples/adaptive_span/adaptive_span_model_wrapper.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from typing import Dict, List, Optional
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from fairseq.dataclass import FairseqDataclass
|
12 |
+
from fairseq.models import (
|
13 |
+
FairseqIncrementalDecoder,
|
14 |
+
FairseqLanguageModel,
|
15 |
+
register_model,
|
16 |
+
)
|
17 |
+
from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel
|
18 |
+
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class AdaptiveSpanSmallConfig(FairseqDataclass):
|
25 |
+
# defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh
|
26 |
+
vocab_size: int = 50
|
27 |
+
d_model: int = 256
|
28 |
+
n_head: int = 4
|
29 |
+
d_inner: int = 1024
|
30 |
+
n_layer: int = 8
|
31 |
+
attn_span: int = 1024
|
32 |
+
dropout: float = 0.0
|
33 |
+
emb_dropout: float = 0.0
|
34 |
+
adapt_span_ramp: int = 32
|
35 |
+
adapt_span_init: float = 0.0
|
36 |
+
aux_loss_scaler: float = 0.000002
|
37 |
+
adapt_span_layer: bool = False
|
38 |
+
|
39 |
+
|
40 |
+
@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig)
|
41 |
+
class AdaptiveSpanTransformer(FairseqLanguageModel):
|
42 |
+
@classmethod
|
43 |
+
def build_model(cls, cfg: AdaptiveSpanSmallConfig, task):
|
44 |
+
return cls(AdaptiveSpanDecoder(cfg, task))
|
45 |
+
|
46 |
+
def get_aux_loss(self):
|
47 |
+
return self.decoder.get_aux_loss()
|
48 |
+
|
49 |
+
def get_current_max_span(self):
|
50 |
+
return self.decoder.get_current_max_span()
|
51 |
+
|
52 |
+
def get_current_avg_span(self):
|
53 |
+
return self.decoder.get_current_avg_span()
|
54 |
+
|
55 |
+
|
56 |
+
class AdaptiveSpanDecoder(FairseqIncrementalDecoder):
|
57 |
+
def __init__(self, cfg, task):
|
58 |
+
|
59 |
+
super().__init__(task.target_dictionary)
|
60 |
+
|
61 |
+
self.config = cfg
|
62 |
+
config = AdaptiveSpanSmallConfig(
|
63 |
+
vocab_size=len(task.target_dictionary),
|
64 |
+
d_model=cfg.d_model,
|
65 |
+
n_head=cfg.n_head,
|
66 |
+
d_inner=cfg.d_inner,
|
67 |
+
n_layer=cfg.n_layer,
|
68 |
+
attn_span=cfg.attn_span,
|
69 |
+
dropout=cfg.dropout,
|
70 |
+
emb_dropout=cfg.emb_dropout,
|
71 |
+
adapt_span_ramp=cfg.adapt_span_ramp,
|
72 |
+
adapt_span_init=cfg.adapt_span_init,
|
73 |
+
aux_loss_scaler=cfg.aux_loss_scaler,
|
74 |
+
adapt_span_layer=cfg.adapt_span_layer,
|
75 |
+
)
|
76 |
+
logger.info(config)
|
77 |
+
self.model = AdaptiveSpanTransformerModel(**config.__dict__)
|
78 |
+
|
79 |
+
self._mems = None
|
80 |
+
|
81 |
+
def forward(
|
82 |
+
self,
|
83 |
+
src_tokens,
|
84 |
+
incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None,
|
85 |
+
encoder_out=None,
|
86 |
+
):
|
87 |
+
bsz = src_tokens.size(0)
|
88 |
+
if incremental_state is not None: # used during inference
|
89 |
+
mems = self.get_incremental_state("mems")
|
90 |
+
src_tokens = src_tokens[:, -1:] # only keep the most recent token
|
91 |
+
else:
|
92 |
+
mems = self._mems
|
93 |
+
|
94 |
+
if mems is None:
|
95 |
+
# first time init
|
96 |
+
mems = self.init_hid_cache(bsz)
|
97 |
+
output = self.model(x=src_tokens, h_cache=mems,)
|
98 |
+
if incremental_state is not None:
|
99 |
+
self.set_incremental_state(incremental_state, "mems", output[1])
|
100 |
+
else:
|
101 |
+
self._mems = output[1]
|
102 |
+
return (output[0],)
|
103 |
+
|
104 |
+
def max_positions(self):
|
105 |
+
return self.config.attn_span
|
106 |
+
|
107 |
+
def init_hid_cache(self, batch_sz):
|
108 |
+
hid = []
|
109 |
+
for layer in self.model.layers:
|
110 |
+
param = next(self.model.parameters())
|
111 |
+
h = torch.zeros(
|
112 |
+
batch_sz,
|
113 |
+
layer.get_cache_size(),
|
114 |
+
self.config.d_model,
|
115 |
+
dtype=param.dtype,
|
116 |
+
device=param.device,
|
117 |
+
)
|
118 |
+
hid.append(h)
|
119 |
+
return hid
|
120 |
+
|
121 |
+
def get_aux_loss(self):
|
122 |
+
return self.model.get_aux_loss()
|
123 |
+
|
124 |
+
def get_current_max_span(self):
|
125 |
+
return self.model.get_current_max_span()
|
126 |
+
|
127 |
+
def get_current_avg_span(self):
|
128 |
+
return self.model.get_current_avg_span()
|
129 |
+
|
130 |
+
def reorder_incremental_state(
|
131 |
+
self,
|
132 |
+
incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]],
|
133 |
+
new_order: torch.Tensor,
|
134 |
+
):
|
135 |
+
"""Reorder incremental state.
|
136 |
+
|
137 |
+
This will be called when the order of the input has changed from the
|
138 |
+
previous time step. A typical use case is beam search, where the input
|
139 |
+
order changes between time steps based on the selection of beams.
|
140 |
+
"""
|
141 |
+
raise NotImplementedError("This is required for generation/beam search")
|
142 |
+
# mems = self.get_incremental_state(incremental_state, "mems")
|
143 |
+
# if mems is not None:
|
144 |
+
# new_mems = [mems_i.index_select(1, new_order) for mems_i in mems]
|
145 |
+
# self.set_incremental_state(incremental_state, "mems", new_mems)
|
fairseq/examples/adaptive_span/truncated_bptt_lm_task.py
ADDED
@@ -0,0 +1,285 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
from typing import List, Optional, Tuple
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from fairseq import utils
|
13 |
+
from fairseq.data import (
|
14 |
+
Dictionary,
|
15 |
+
TokenBlockDataset,
|
16 |
+
data_utils,
|
17 |
+
iterators,
|
18 |
+
)
|
19 |
+
from fairseq.dataclass import FairseqDataclass
|
20 |
+
from fairseq.distributed import utils as dist_utils
|
21 |
+
from fairseq.tasks import FairseqTask, register_task
|
22 |
+
from omegaconf import II
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class TruncatedBPTTLMConfig(FairseqDataclass):
|
30 |
+
data: str = field(default="???", metadata={"help": "path to data directory"})
|
31 |
+
tokens_per_sample: int = field(
|
32 |
+
default=1024, metadata={"help": "max number of tokens per sequence"},
|
33 |
+
)
|
34 |
+
batch_size: int = II("dataset.batch_size")
|
35 |
+
# Some models use *max_target_positions* to know how many positional
|
36 |
+
# embeddings to learn. We use II(...) to make it default to
|
37 |
+
# *tokens_per_sample*, but in principle there could be more positional
|
38 |
+
# embeddings than tokens in a single batch. This may also be irrelevant for
|
39 |
+
# custom model implementations.
|
40 |
+
max_target_positions: int = II("task.tokens_per_sample")
|
41 |
+
# these will be populated automatically if not provided
|
42 |
+
data_parallel_rank: Optional[int] = None
|
43 |
+
data_parallel_size: Optional[int] = None
|
44 |
+
|
45 |
+
|
46 |
+
@register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig)
|
47 |
+
class TruncatedBPTTLMTask(FairseqTask):
|
48 |
+
def __init__(self, cfg: TruncatedBPTTLMConfig):
|
49 |
+
super().__init__(cfg)
|
50 |
+
|
51 |
+
if cfg.data_parallel_rank is None or cfg.data_parallel_size is None:
|
52 |
+
if torch.distributed.is_initialized():
|
53 |
+
cfg.data_parallel_rank = dist_utils.get_data_parallel_rank()
|
54 |
+
cfg.data_parallel_size = dist_utils.get_data_parallel_world_size()
|
55 |
+
else:
|
56 |
+
cfg.data_parallel_rank = 0
|
57 |
+
cfg.data_parallel_size = 1
|
58 |
+
|
59 |
+
# load the dictionary
|
60 |
+
paths = utils.split_paths(cfg.data)
|
61 |
+
assert len(paths) > 0
|
62 |
+
self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt"))
|
63 |
+
logger.info("dictionary: {} types".format(len(self.dictionary)))
|
64 |
+
|
65 |
+
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
66 |
+
"""Load a given dataset split (e.g., train, valid, test)"""
|
67 |
+
|
68 |
+
# support sharded datasets
|
69 |
+
paths = utils.split_paths(self.cfg.data)
|
70 |
+
assert len(paths) > 0
|
71 |
+
data_path = paths[(epoch - 1) % len(paths)]
|
72 |
+
split_path = os.path.join(data_path, split)
|
73 |
+
|
74 |
+
# each element of *data* will be a tensorized line from the original
|
75 |
+
# text dataset, similar to ``open(split_path).readlines()``
|
76 |
+
data = data_utils.load_indexed_dataset(
|
77 |
+
split_path, self.dictionary, combine=combine
|
78 |
+
)
|
79 |
+
if data is None:
|
80 |
+
raise FileNotFoundError(
|
81 |
+
"Dataset not found: {} ({})".format(split, split_path)
|
82 |
+
)
|
83 |
+
|
84 |
+
# this is similar to ``data.view(-1).split(tokens_per_sample)``
|
85 |
+
data = TokenBlockDataset(
|
86 |
+
data,
|
87 |
+
data.sizes,
|
88 |
+
block_size=self.cfg.tokens_per_sample,
|
89 |
+
pad=None, # unused
|
90 |
+
eos=None, # unused
|
91 |
+
break_mode="none",
|
92 |
+
)
|
93 |
+
|
94 |
+
self.datasets[split] = TruncatedBPTTDataset(
|
95 |
+
data=data,
|
96 |
+
bsz_per_shard=self.cfg.batch_size,
|
97 |
+
shard_id=self.cfg.data_parallel_rank,
|
98 |
+
num_shards=self.cfg.data_parallel_size,
|
99 |
+
)
|
100 |
+
|
101 |
+
def dataset(self, split):
|
102 |
+
return self.datasets[split]
|
103 |
+
|
104 |
+
def get_batch_iterator(
|
105 |
+
self,
|
106 |
+
dataset,
|
107 |
+
num_workers=0,
|
108 |
+
epoch=1,
|
109 |
+
data_buffer_size=0,
|
110 |
+
skip_remainder_batch=False,
|
111 |
+
**kwargs
|
112 |
+
):
|
113 |
+
return iterators.EpochBatchIterator(
|
114 |
+
dataset=dataset,
|
115 |
+
collate_fn=self._collate_fn,
|
116 |
+
num_workers=num_workers,
|
117 |
+
epoch=epoch,
|
118 |
+
buffer_size=data_buffer_size,
|
119 |
+
# we don't use the batching functionality from EpochBatchIterator;
|
120 |
+
# instead every item in *dataset* is a whole batch
|
121 |
+
batch_sampler=[[i] for i in range(len(dataset))],
|
122 |
+
disable_shuffling=True,
|
123 |
+
skip_remainder_batch=skip_remainder_batch,
|
124 |
+
)
|
125 |
+
|
126 |
+
def _collate_fn(self, items: List[List[torch.Tensor]]):
|
127 |
+
# we don't use fairseq's batching functionality, so we expect a single
|
128 |
+
# Tensor of type List[torch.Tensor]
|
129 |
+
assert len(items) == 1
|
130 |
+
|
131 |
+
# item will have shape B x T (the last batch may have length < T)
|
132 |
+
id, item = items[0]
|
133 |
+
item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad())
|
134 |
+
B, T = item.size()
|
135 |
+
|
136 |
+
# shift item one position over and append a padding token for the target
|
137 |
+
target = torch.nn.functional.pad(
|
138 |
+
item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad()
|
139 |
+
)
|
140 |
+
|
141 |
+
# fairseq expects batches to have the following structure
|
142 |
+
return {
|
143 |
+
"id": torch.tensor([id] * item.size(0)),
|
144 |
+
"net_input": {"src_tokens": item,},
|
145 |
+
"target": target,
|
146 |
+
"nsentences": item.size(0),
|
147 |
+
"ntokens": item.numel(),
|
148 |
+
}
|
149 |
+
|
150 |
+
def build_dataset_for_inference(
|
151 |
+
self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
|
152 |
+
) -> torch.utils.data.Dataset:
|
153 |
+
eos = self.source_dictionary.eos()
|
154 |
+
dataset = TokenBlockDataset(
|
155 |
+
src_tokens,
|
156 |
+
src_lengths,
|
157 |
+
block_size=None, # ignored for "eos" break mode
|
158 |
+
pad=self.source_dictionary.pad(),
|
159 |
+
eos=eos,
|
160 |
+
break_mode="eos",
|
161 |
+
)
|
162 |
+
|
163 |
+
class Dataset(torch.utils.data.Dataset):
|
164 |
+
def __getitem__(self, i):
|
165 |
+
item = dataset[i]
|
166 |
+
if item[-1] == eos:
|
167 |
+
# remove eos to support generating with a prefix
|
168 |
+
item = item[:-1]
|
169 |
+
return (i, [item])
|
170 |
+
|
171 |
+
def __len__(self):
|
172 |
+
return len(dataset)
|
173 |
+
|
174 |
+
return Dataset()
|
175 |
+
|
176 |
+
def inference_step(
|
177 |
+
self, generator, models, sample, prefix_tokens=None, constraints=None
|
178 |
+
):
|
179 |
+
with torch.no_grad():
|
180 |
+
if constraints is not None:
|
181 |
+
raise NotImplementedError
|
182 |
+
|
183 |
+
# SequenceGenerator doesn't use *src_tokens* directly, we need to
|
184 |
+
# pass the *prefix_tokens* argument instead.
|
185 |
+
if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
|
186 |
+
prefix_tokens = sample["net_input"]["src_tokens"]
|
187 |
+
|
188 |
+
# begin generation with the end-of-sentence token
|
189 |
+
bos_token = self.source_dictionary.eos()
|
190 |
+
|
191 |
+
return generator.generate(
|
192 |
+
models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token
|
193 |
+
)
|
194 |
+
|
195 |
+
def eval_lm_dataloader(
|
196 |
+
self,
|
197 |
+
dataset,
|
198 |
+
max_tokens: Optional[int] = 36000,
|
199 |
+
batch_size: Optional[int] = None,
|
200 |
+
max_positions: Optional[int] = None,
|
201 |
+
num_shards: int = 1,
|
202 |
+
shard_id: int = 0,
|
203 |
+
num_workers: int = 1,
|
204 |
+
data_buffer_size: int = 10,
|
205 |
+
context_window: int = 0,
|
206 |
+
):
|
207 |
+
if context_window > 0:
|
208 |
+
raise NotImplementedError(
|
209 |
+
"Transformer-XL doesn't need --context-window, try "
|
210 |
+
"--model-overrides '{\"mem_len\":42}' instead "
|
211 |
+
)
|
212 |
+
return self.get_batch_iterator(
|
213 |
+
dataset=dataset,
|
214 |
+
max_tokens=max_tokens,
|
215 |
+
max_sentences=batch_size,
|
216 |
+
max_positions=max_positions,
|
217 |
+
ignore_invalid_inputs=True,
|
218 |
+
num_shards=num_shards,
|
219 |
+
shard_id=shard_id,
|
220 |
+
num_workers=num_workers,
|
221 |
+
data_buffer_size=data_buffer_size,
|
222 |
+
).next_epoch_itr(shuffle=False)
|
223 |
+
|
224 |
+
@property
|
225 |
+
def source_dictionary(self):
|
226 |
+
return self.dictionary
|
227 |
+
|
228 |
+
@property
|
229 |
+
def target_dictionary(self):
|
230 |
+
return self.dictionary
|
231 |
+
|
232 |
+
|
233 |
+
class TruncatedBPTTDataset(torch.utils.data.Dataset):
|
234 |
+
def __init__(
|
235 |
+
self,
|
236 |
+
data: List[torch.Tensor], # ordered list of items
|
237 |
+
bsz_per_shard, # number of items processed per GPUs per forward
|
238 |
+
shard_id, # current GPU ID
|
239 |
+
num_shards, # number of GPUs
|
240 |
+
):
|
241 |
+
super().__init__()
|
242 |
+
self.data = data
|
243 |
+
|
244 |
+
def batchify(data, bsz):
|
245 |
+
# Work out how cleanly we can divide the dataset into bsz parts.
|
246 |
+
nbatch = data.size(0) // bsz
|
247 |
+
# Trim off any extra elements that wouldn't cleanly fit (remainders).
|
248 |
+
data = data.narrow(0, 0, nbatch * bsz)
|
249 |
+
# Evenly divide the data across the bsz batches.
|
250 |
+
data = data.view(bsz, -1).contiguous()
|
251 |
+
return data
|
252 |
+
|
253 |
+
# total number of sequences processed by all GPUs in each forward pass
|
254 |
+
global_batch_size = bsz_per_shard * num_shards
|
255 |
+
|
256 |
+
"""
|
257 |
+
With a 16 item dataset, bsz_per_shard=2 and num_shards=3,
|
258 |
+
*indices* might look like:
|
259 |
+
|
260 |
+
indices = [[0, 1],
|
261 |
+
[2, 3],
|
262 |
+
[4, 5],
|
263 |
+
[6, 7],
|
264 |
+
[8, 9],
|
265 |
+
[10, 11]]
|
266 |
+
|
267 |
+
The size of the TruncatedBPTTDataset instance will be 2,
|
268 |
+
and shard 1 will see items:
|
269 |
+
|
270 |
+
[(0, [data[4], data[6]]),
|
271 |
+
(1, [data[5], data[7]])]
|
272 |
+
"""
|
273 |
+
indices = batchify(torch.arange(len(data)), global_batch_size)
|
274 |
+
assert indices.size(0) == global_batch_size
|
275 |
+
|
276 |
+
self.my_indices = indices[
|
277 |
+
shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard
|
278 |
+
]
|
279 |
+
assert self.my_indices.size(0) == bsz_per_shard
|
280 |
+
|
281 |
+
def __len__(self):
|
282 |
+
return self.my_indices.size(1)
|
283 |
+
|
284 |
+
def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]:
|
285 |
+
return (i, [self.data[idx] for idx in self.my_indices[:, i]])
|
fairseq/examples/attention_head_selection/README.md
ADDED
@@ -0,0 +1,161 @@
# Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling (Gong et al., 2021)

[https://arxiv.org/pdf/2106.10840.pdf](https://arxiv.org/pdf/2106.10840.pdf)

## Introduction

We present attention head selection strategies for multilingual and multi-domain sequence modeling, including text translation, speech recognition and speech translation tasks.

Below is an example of training multilingual/multi-domain speech recognition models.

## Data Preparation
Prepare mTEDx data as in [mTEDx example](https://github.com/fairinternal/fairseq-py/blob/0d9c5851e6fac40f9e366b3633ccd615c2901788/examples/speech_to_text/docs/mtedx_example.md) and CoVoST data as in [CoVoST example](https://github.com/fairinternal/fairseq-py/blob/0d9c5851e6fac40f9e366b3633ccd615c2901788/examples/speech_to_text/docs/covost_example.md). Similarly, prepare EuroParl data.

## Training a multilingual ASR model with attention head selection

```bash
data_dir=<path to mtedx data>
train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
valid_subset="valid_ar_ar_tedx,valid_de_de_tedx,valid_el_el_tedx,valid_es_es_tedx,valid_fr_fr_tedx,valid_it_it_tedx,valid_pt_pt_tedx,valid_ru_ru_tedx"
strategy=<subset or group>

fairseq-train ${data_dir} \
  --user-dir examples/attention_head_selection/src \
  --train-subset "${train_subset}" \
  --valid-subset "${valid_subset}" \
  --config-yaml 'config_asr.yaml' \
  --arch 'head_selection_s2t_transformer_s' \
  --task 'speech_to_text_head_selection' \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
  --lr 5e-4 \
  --clip-norm 10.0 \
  --seed 1 \
  --max-epoch 400 \
  --max-tokens 32000 \
  --ignore-prefix-size 1 \
  --dropout 0.3 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --skip-invalid-size-inputs-valid-test \
  --encoder-attn-head-select \
  --total-encoder-attention-heads 8 \
  --decoder-self-attn-head-select \
  --total-decoder-attention-heads 8 \
  --attn-head-select-strategy ${strategy} \
  --task-type lang
```

## Training a multi-domain ASR model with attention head selection

```bash
data_dir=<path to multi-domain data>
train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
valid_subset="dev_es_es_tedx,dev_fr_fr_tedx,dev_pt_pt_tedx,dev_it_it_tedx,dev_ru_ru_tedx,dev_el_el_tedx,dev_ar_ar_tedx,dev_de_de_tedx,dev_ar_ar_cv,dev_de_de_cv,dev_es_es_cv,dev_fr_fr_cv,dev_it_it_cv,dev_pt_pt_cv,dev_ru_ru_cv,dev_de_de_ep,dev_es_es_ep,dev_fr_fr_ep,dev_it_it_ep,dev_pt_pt_ep"
strategy=<subset or group>

fairseq-train ${data_dir} \
  --user-dir examples/attention_head_selection/src \
  --train-subset "${train_subset}" \
  --valid-subset "${valid_subset}" \
  --config-yaml 'config_asr.yaml' \
  --arch head_selection_s2t_transformer_s \
  --task speech_to_text_head_selection \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
  --lr 5e-4 \
  --clip-norm 10.0 \
  --seed 1 \
  --max-epoch 400 \
  --max-tokens 32000 \
  --ignore-prefix-size 1 \
  --dropout 0.3 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --skip-invalid-size-inputs-valid-test \
  --encoder-attn-head-select \
  --total-encoder-attention-heads 8 \
  --decoder-self-attn-head-select \
  --total-decoder-attention-heads 8 \
  --attn-head-select-strategy ${strategy} \
  --task-type domain
```

## Inference in multilingual setting

```bash
MODEL_DIR=<checkpoint directory>
data_dir=<path to mtedx data>
gen_subset=<data to test, e.g., test_ar_ar_tedx>
train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
last_n=10
CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
CHECKPOINT="_avg"
RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
if [ ! -d $RESULTS ]; then
  mkdir -p $RESULTS
fi;

python scripts/average_checkpoints.py \
  --inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
  --output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"

fairseq-generate ${data_dir} \
  --user-dir examples/attention_head_selection/src \
  --arch 'head_selection_s2t_transformer_s' \
  --task 'speech_to_text_head_selection' \
  --train-subset ${train_subset} \
  --gen-subset ${gen_subset} \
  --path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
  --config-yaml 'config_asr.yaml' \
  --prefix-size 1 \
  --max-tokens 40000 --beam 5 \
  --skip-invalid-size-inputs-valid-test \
  --results-path ${RESULTS} \
  --scoring wer --wer-tokenizer 13a \
  --wer-lowercase --wer-remove-punct --remove-bpe
```

## Inference in multi-domain setting

```bash
MODEL_DIR=<checkpoint directory>
data_dir=<path to multi-domain data>
gen_subset=<data to test, e.g., test_pt_pt_cv>
train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
last_n=10
CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
CHECKPOINT="_avg"
RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
if [ ! -d $RESULTS ]; then
  mkdir -p $RESULTS
fi;

python scripts/average_checkpoints.py \
  --inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
  --output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"

fairseq-generate ${data_dir} \
  --user-dir examples/attention_head_selection/src \
  --arch 'head_selection_s2t_transformer_s' \
  --task 'speech_to_text_head_selection' \
  --train-subset ${train_subset} \
  --gen-subset ${gen_subset} \
  --path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
  --config-yaml 'config_asr.yaml' \
  --prefix-size 1 \
  --max-tokens 40000 --beam 5 \
  --skip-invalid-size-inputs-valid-test \
  --results-path ${RESULTS} \
  --scoring wer --wer-tokenizer 13a \
  --wer-lowercase --wer-remove-punct --remove-bpe
```

## Citation
```bibtex
@article{gong2021pay,
  title={Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling},
  author={Gong, Hongyu and Tang, Yun and Pino, Juan and Li, Xian},
  journal={arXiv preprint arXiv:2106.10840},
  year={2021}
}
```
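The `--train-subset`/`--valid-subset` strings above are long but entirely regular (`<split>_<lang>_<lang>_<corpus>`). If you prefer not to maintain them by hand, a small helper like the following (hypothetical, not part of the example scripts) reproduces the multilingual lists used in the training command:

```python
langs = ["ar", "de", "el", "es", "fr", "it", "pt", "ru"]


def subset_list(split: str, corpus: str = "tedx") -> str:
    # e.g. "train_ar_ar_tedx,train_de_de_tedx,...,train_ru_ru_tedx"
    return ",".join(f"{split}_{lang}_{lang}_{corpus}" for lang in langs)


print(subset_list("train"))  # matches train_subset in the multilingual command
print(subset_list("valid"))  # matches valid_subset in the multilingual command
```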
fairseq/examples/attention_head_selection/src/__init__.py
ADDED
File without changes
fairseq/examples/attention_head_selection/src/data/__init__.py
ADDED
File without changes
fairseq/examples/attention_head_selection/src/data/speech_to_text_dataset_with_domain.py
ADDED
@@ -0,0 +1,242 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
from typing import Dict, List, Optional
from dataclasses import dataclass

import torch
from fairseq.data import (
    ConcatDataset,
    Dictionary,
    FairseqDataset,
    ResamplingDataset
)
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDatasetItem,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator
)

logger = logging.getLogger(__name__)


@dataclass
class SpeechToTextDatasetItemWithDomain(SpeechToTextDatasetItem):
    src_lang_id: Optional[torch.Tensor] = None
    tgt_lang_id: Optional[torch.Tensor] = None
    domain_id: Optional[torch.Tensor] = None


class SpeechToTextDatasetWithDomain(SpeechToTextDataset):

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        src_lang_ids: Optional[List[int]] = None,
        tgt_lang_ids: Optional[List[int]] = None,
        domain_ids: Optional[List[int]] = None
    ):
        super().__init__(
            split, is_train_split, cfg, audio_paths, n_frames,
            src_texts, tgt_texts, speakers, src_langs, tgt_langs,
            ids, tgt_dict, pre_tokenizer, bpe_tokenizer,
            n_frames_per_step, speaker_to_id
        )
        assert src_lang_ids is None or len(src_lang_ids) == self.n_samples
        assert tgt_lang_ids is None or len(tgt_lang_ids) == self.n_samples
        assert domain_ids is None or len(domain_ids) == self.n_samples

        self.src_lang_ids = src_lang_ids
        self.tgt_lang_ids = tgt_lang_ids
        self.domain_ids = domain_ids

    def __getitem__(self, index: int) -> SpeechToTextDatasetItemWithDomain:
        item = super().__getitem__(index)
        src_lang_id = self.src_lang_ids[index]
        tgt_lang_id = self.tgt_lang_ids[index]
        domain_id = self.domain_ids[index]
        return SpeechToTextDatasetItemWithDomain(
            index=item.index, source=item.source,
            target=item.target, speaker_id=item.speaker_id,
            src_lang_id=src_lang_id,
            tgt_lang_id=tgt_lang_id,
            domain_id=domain_id
        )

    def collater(
        self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
    ) -> Dict:
        if len(samples) == 0:
            return {}
        out = super().collater(samples, return_order=True)
        order = out["order"]
        src_lang_ids = torch.tensor([x.src_lang_id for x in samples], dtype=torch.long).index_select(0, order)
        tgt_lang_ids = torch.tensor([x.tgt_lang_id for x in samples], dtype=torch.long).index_select(0, order)
        domain_ids = torch.tensor([x.domain_id for x in samples], dtype=torch.long).index_select(0, order)

        out["src_lang_ids"] = src_lang_ids
        out["tgt_lang_ids"] = tgt_lang_ids
        out["domain_ids"] = domain_ids
        if not return_order:
            del out["order"]
        return out


class SpeechToTextDatasetCreatorWithDomain(SpeechToTextDatasetCreator):
    KEY_SRC_LANG_ID, KEY_TGT_LANG_ID = "src_lang_id", "tgt_lang_id"
    KEY_DOMAIN_ID = "domain_id"
    # default values
    DEFAULT_SRC_LANG_ID, DEFAULT_TGT_LANG_ID, DEFAULT_DOMAIN_ID = 0, 0, 0

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id
    ) -> SpeechToTextDatasetWithDomain:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
        src_lang_ids = [s.get(cls.KEY_SRC_LANG_ID, cls.DEFAULT_SRC_LANG_ID) for s in samples]
        tgt_lang_ids = [s.get(cls.KEY_TGT_LANG_ID, cls.DEFAULT_TGT_LANG_ID) for s in samples]
        domain_ids = [s.get(cls.KEY_DOMAIN_ID, cls.DEFAULT_DOMAIN_ID) for s in samples]
        return SpeechToTextDatasetWithDomain(
            split_name,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
            src_lang_ids=src_lang_ids,
            tgt_lang_ids=tgt_lang_ids,
            domain_ids=domain_ids
        )

    @classmethod
    def _load_samples_from_tsv(
        cls,
        root: str,
        split: str,
        src_lang_map,
        tgt_lang_map,
        domain_map
    ):
        # metadata from split
        _, src_lang, tgt_lang, domain = split.split("_")
        src_lang_id = src_lang_map[src_lang]
        tgt_lang_id = tgt_lang_map[tgt_lang]
        domain_id = domain_map[domain]

        samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
        for s in samples:
            s.update({
                cls.KEY_SRC_LANG_ID: src_lang_id,
                cls.KEY_TGT_LANG_ID: tgt_lang_id,
                cls.KEY_DOMAIN_ID: domain_id
            })
        return samples

    @classmethod
    def _from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        split: str,
        tgt_dict,
        is_train_split: bool,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        src_lang_map: Dict[str, int],
        tgt_lang_map: Dict[str, int],
        domain_map: Dict[str, int]
    ) -> SpeechToTextDatasetItemWithDomain:
        samples = cls._load_samples_from_tsv(
            root, split, src_lang_map,
            tgt_lang_map, domain_map
        )
        return cls._from_list(
            split, is_train_split, samples, cfg, tgt_dict, pre_tokenizer,
            bpe_tokenizer, n_frames_per_step, speaker_to_id
        )

    @classmethod
    def from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        splits: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        src_lang_map: Dict[str, int],
        tgt_lang_map: Dict[str, int],
        domain_map: Dict[str, int],
        n_frames_per_step: int = 1,
        speaker_to_id=None
    ) -> SpeechToTextDatasetWithDomain:
        datasets = [
            cls._from_tsv(
                root, cfg, split, tgt_dict, is_train_split, pre_tokenizer, bpe_tokenizer, n_frames_per_step, speaker_to_id, src_lang_map, tgt_lang_map, domain_map
            )
            for split in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(
                    d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
                )
                for r, d in zip(size_ratios, datasets)
            ]

        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

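The creator reads the language and domain ids straight out of the split name (see `_load_samples_from_tsv` above), which is why the `<subset>_<src>_<tgt>_<domain>` naming of the train subsets in the generation script matters. A small, self-contained sketch of that convention; the id maps below are hypothetical stand-ins for the ones the task builds from its configuration:

```python
# Hypothetical sketch of the "<subset>_<src>_<tgt>_<domain>" split-name convention used above.
src_lang_map = {"es": 0, "fr": 1, "pt": 2}   # example id maps; the real ones come from the task setup
tgt_lang_map = {"es": 0, "fr": 1, "pt": 2}
domain_map = {"tedx": 0, "cv": 1, "ep": 2}

split = "train_es_es_tedx"                    # one of the train subsets listed in the script above
_, src_lang, tgt_lang, domain = split.split("_")
ids = (src_lang_map[src_lang], tgt_lang_map[tgt_lang], domain_map[domain])
print(ids)  # -> (0, 0, 0): attached to every sample of this split by _load_samples_from_tsv
```
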
fairseq/examples/attention_head_selection/src/loss/__init__.py
ADDED
File without changes

fairseq/examples/attention_head_selection/src/loss/attention_head_selection.py
ADDED
@@ -0,0 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch.nn.modules.loss import _Loss


class HeadSelectionLoss(_Loss):

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.kl_weight = getattr(args, "kl_weight", 0.0)

    def forward(self, head_samples, sample_sizes, prior=0.5, eps=1e-7):
        """
        head_samples: (num_tasks, num_layers, num_heads)
        sample_sizes: (num_tasks, )
        """
        kl_loss = (head_samples * (torch.log(head_samples + eps) - math.log(prior))).sum(-1).sum(-1)
        kl_loss /= (torch.numel(head_samples) / head_samples.size(0))
        kl_loss = self.kl_weight * torch.matmul(kl_loss, sample_sizes)
        return kl_loss

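A minimal usage sketch of the regularizer above, assuming `HeadSelectionLoss` from this module is in scope; the task counts and weights are hypothetical and the tensor shapes follow the docstring:

```python
# Hypothetical usage sketch for HeadSelectionLoss; shapes follow the docstring above.
from argparse import Namespace

import torch

head_samples = torch.rand(3, 6, 8)              # (num_tasks, num_layers, num_heads), values in (0, 1)
sample_sizes = torch.tensor([10.0, 20.0, 5.0])  # per-task sample counts in the batch

loss_fn = HeadSelectionLoss(Namespace(kl_weight=0.01))
kl = loss_fn(head_samples, sample_sizes, prior=0.5)  # scalar KL term, weighted by task sizes
print(kl.item())
```
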
fairseq/examples/attention_head_selection/src/models/__init__.py
ADDED
File without changes

fairseq/examples/attention_head_selection/src/models/head_selection_s2t_transformer.py
ADDED
@@ -0,0 +1,170 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, List, Optional
from pathlib import Path
import torch.nn as nn
from torch import Tensor
from fairseq import checkpoint_utils

from fairseq.models import register_model, register_model_architecture
from fairseq.utils import safe_hasattr
from fairseq.models.speech_to_text.s2t_transformer import (
    S2TTransformerModel,
    S2TTransformerEncoder,
    TransformerDecoderScriptable
)
from fairseq.models.speech_to_text.s2t_transformer import base_architecture as s2t_base_architecture

from ..modules.attn_head_selector import AttnHeadSelector
from ..modules.head_selection_transformer_layer import HeadSelectionTransformerEncoderLayer
from .head_selection_transformer import HeadSelectionTransformerDecoder


logger = logging.getLogger(__name__)


@register_model("head_selection_s2t_transformer")
class HeadSelectionS2TTransformerModel(S2TTransformerModel):
    """
    Head selection implemented in S2TTransformer
    """
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        S2TTransformerModel.add_args(parser)
        # encoder head selection
        parser.add_argument(
            "--encoder-attn-head-select",
            action="store_true",
            default=False,
            help="encoder head selection"
        )
        parser.add_argument(
            "--total-encoder-attention-heads",
            type=int,
            help="total number of encoder attention heads"
        )
        # decoder self attention selection
        parser.add_argument(
            "--decoder-self-attn-head-select",
            action="store_true",
            default=False,
            help="decoder self-attention head selection"
        )
        # decoder-encoder attention selection
        parser.add_argument(
            "--dec-enc-attn-head-select",
            action="store_true",
            default=False,
            help="decoder-encoder attention head selection"
        )
        parser.add_argument(
            "--total-decoder-attention-heads",
            type=int,
            help="total number of decoder attention heads"
        )
        # selection strategy
        parser.add_argument(
            "--attn-head-select-strategy",
            type=str,
            help="attention head selection strategy, subset or group"
        )

    @classmethod
    def build_encoder(cls, args):
        if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
            encoder = HeadSelectionS2TTransformerEncoder(args)
        else:
            encoder = S2TTransformerEncoder(args)
        pretraining_path = getattr(args, "load_pretrained_encoder_from", None)
        if pretraining_path is not None:
            if not Path(pretraining_path).exists():
                logger.warning(
                    f"skipped pretraining because {pretraining_path} does not exist"
                )
            else:
                encoder = checkpoint_utils.load_pretrained_component_from_model(
                    component=encoder, checkpoint=pretraining_path
                )
                logger.info(f"loaded pretrained encoder from: {pretraining_path}")
        return encoder

    @classmethod
    def build_decoder(cls, args, task, embed_tokens):
        if (safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select) or (safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select):
            return HeadSelectionTransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)
        else:
            return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens)


class HeadSelectionS2TTransformerEncoder(S2TTransformerEncoder):

    def __init__(self, args):
        super().__init__(args)
        self.attn_head_selector = AttnHeadSelector(
            args.encoder_tasks,
            args.encoder_layers,
            args.total_encoder_attention_heads,
            args.encoder_attention_heads,
            args.attn_head_select_strategy,
        )
        self.task_ids = None
        self.transformer_layers = nn.ModuleList([
            HeadSelectionTransformerEncoderLayer(args, layer_idx, attn_head_selector=self.attn_head_selector) for layer_idx in range(args.encoder_layers)
        ])

    def set_task_ids(self, task_ids):
        self.task_ids = task_ids

    def _forward(self, src_tokens, src_lengths, return_all_hiddens=False):
        self.attn_head_selector.head_select(self.task_ids)
        return super()._forward(src_tokens, src_lengths, return_all_hiddens)


class HeadSelectionTransformerDecoderScriptable(HeadSelectionTransformerDecoder):
    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        # call scriptable method from parent class
        x, _ = self.extract_features_scriptable(
            prev_output_tokens,
            encoder_out,
            incremental_state,
            full_context_alignment,
            alignment_layer,
            alignment_heads,
        )
        return x, None


@register_model_architecture(model_name="head_selection_s2t_transformer", arch_name="head_selection_s2t_transformer")
def base_architecture(args):
    s2t_base_architecture(args)
    args.encoder_attn_head_select = getattr(args, "encoder_attn_head_select", False)
    args.decoder_self_attn_head_select = getattr(args, "decoder_self_attn_head_select", False)
    args.dec_enc_attn_head_select = getattr(args, "dec_enc_attn_head_select", False)
    args.total_encoder_attention_heads = getattr(args, "total_encoder_attention_heads", 8)
    args.total_decoder_attention_heads = getattr(args, "total_decoder_attention_heads", 8)
    args.attn_head_select_strategy = getattr(args, "attn_head_select_strategy", "group")


@register_model_architecture("head_selection_s2t_transformer", "head_selection_s2t_transformer_s")
def head_selection_s2t_transformer_s(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.dropout = getattr(args, "dropout", 0.1)
    base_architecture(args)

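For reference, the architecture functions above rely on the usual fairseq `getattr` pattern: any flag the user passed wins, everything else falls back to the registered default, so the generation command in the README gets `total_encoder_attention_heads=8` without asking for it. A tiny standalone sketch of that pattern, with hypothetical values:

```python
# Hypothetical sketch of the getattr-default pattern used by the architecture functions above.
from argparse import Namespace

args = Namespace(encoder_attention_heads=4)  # pretend the user only set this flag
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)              # kept at 4
args.total_encoder_attention_heads = getattr(args, "total_encoder_attention_heads", 8)  # filled with 8
print(args.encoder_attention_heads, args.total_encoder_attention_heads)  # -> 4 8
```
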
fairseq/examples/attention_head_selection/src/models/head_selection_transformer.py
ADDED
@@ -0,0 +1,215 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Dict, Optional
import torch
import torch.nn as nn
from torch import Tensor

from fairseq.utils import safe_hasattr
from fairseq.models.transformer import (
    TransformerModel,
    TransformerEncoder,
    TransformerDecoder
)

from ..modules.attn_head_selector import AttnHeadSelector
from ..modules.head_selection_transformer_layer import (
    HeadSelectionTransformerEncoderLayer,
    HeadSelectionTransformerDecoderLayer
)


class HeadSelectionTransformerModel(TransformerModel):
    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

    @staticmethod
    def add_args(parser):
        TransformerModel.add_args(parser)
        # encoder head selection
        parser.add_argument(
            "--encoder-attn-head-select",
            action="store_true",
            default=False,
            help="encoder head selection"
        )
        parser.add_argument(
            "--total-encoder-attention-heads",
            type=int,
            help="total number of encoder attention heads"
        )
        # decoder self attention
        parser.add_argument(
            "--decoder-self-attn-head-select",
            action="store_true",
            default=False,
            help="decoder self-attention head selection"
        )
        # decoder-encoder attention
        parser.add_argument(
            "--dec-enc-attn-head-select",
            action="store_true",
            default=False,
            help="decoder-encoder attention head selection"
        )
        parser.add_argument(
            "--total-decoder-attention-heads",
            type=int,
            help="total number of decoder attention heads"
        )
        # selection strategy
        parser.add_argument(
            "--attn-head-select-strategy",
            type=str,
            help="attention head selection strategy, subset or group"
        )

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        if safe_hasattr(args, "encoder_attn_head_select") and args.encoder_attn_head_select:
            return HeadSelectionTransformerEncoder(
                args, src_dict, embed_tokens
            )
        else:
            return TransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        if (safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select) or (safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select):
            return HeadSelectionTransformerDecoder(
                args, tgt_dict, embed_tokens
            )
        else:
            return TransformerDecoder(args, tgt_dict, embed_tokens)


class HeadSelectionTransformerEncoder(TransformerEncoder):

    def __init__(self, args, dictionary, embed_tokens):
        self.num_tasks = args.encoder_tasks
        self.num_layers = args.encoder_layers
        self.total_num_heads = args.total_encoder_attention_heads
        self.num_heads = args.encoder_attention_heads
        self.select_strategy = args.attn_head_select_strategy

        super().__init__(args, dictionary, embed_tokens)
        self.attn_head_selector = AttnHeadSelector(
            self.num_tasks,
            self.num_layers,
            self.total_num_heads,
            self.num_heads,
            self.select_strategy
        )
        self.task_ids = None
        self.layers = nn.ModuleList(
            [self.build_encoder_layer(args, i) for i in range(args.encoder_layers)]
        )

    def set_task_ids(self, task_ids):
        self.task_ids = task_ids

    def build_encoder_layer(self, args, layer_idx=None):
        return HeadSelectionTransformerEncoderLayer(
            args,
            layer_idx,
            attn_head_selector=self.attn_head_selector
        )

    def forward(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):
        self.attn_head_selector.head_select(self.task_ids)
        return super().forward(src_tokens, src_lengths, return_all_hiddens, token_embeddings)


class HeadSelectionTransformerDecoder(TransformerDecoder):

    def __init__(
        self,
        args,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
        output_projection=None,
    ):
        self.num_tasks = args.decoder_tasks
        self.num_layers = args.decoder_layers
        self.total_num_heads = args.total_decoder_attention_heads
        self.num_heads = args.decoder_attention_heads
        self.select_strategy = args.attn_head_select_strategy
        super().__init__(
            args, dictionary, embed_tokens,
            no_encoder_attn=no_encoder_attn,
            output_projection=output_projection
        )
        self.self_attn_head_selector = None
        self.enc_attn_head_selector = None
        if safe_hasattr(args, "decoder_self_attn_head_select") and args.decoder_self_attn_head_select:
            self.self_attn_head_selector = AttnHeadSelector(
                self.num_tasks,
                self.num_layers,
                self.total_num_heads,
                self.num_heads,
                self.select_strategy
            )
        if safe_hasattr(args, "dec_enc_attn_head_select") and args.dec_enc_attn_head_select:
            self.enc_attn_head_selector = AttnHeadSelector(
                self.num_tasks,
                self.num_layers,
                self.total_num_heads,
                self.num_heads,
                self.select_strategy
            )
        self.task_ids = None
        self.layers = nn.ModuleList(
            [
                self.build_head_selection_decoder_layer(args, no_encoder_attn, idx) for idx in range(args.decoder_layers)
            ]
        )

    def set_task_ids(self, task_ids):
        self.task_ids = task_ids

    def build_head_selection_decoder_layer(self, args, no_encoder_attn=False, layer_idx=None):
        return HeadSelectionTransformerDecoderLayer(
            args,
            layer_idx,
            self.self_attn_head_selector,
            self.enc_attn_head_selector,
            no_encoder_attn=no_encoder_attn
        )

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        if self.self_attn_head_selector is not None:
            self.self_attn_head_selector.head_select(self.task_ids)
        if self.enc_attn_head_selector is not None:
            self.enc_attn_head_selector.head_select(self.task_ids)
        return super().forward(
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            features_only=features_only,
            full_context_alignment=full_context_alignment,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens
        )

fairseq/examples/attention_head_selection/src/modules/__init__.py
ADDED
File without changes

fairseq/examples/attention_head_selection/src/modules/attn_head_selector.py
ADDED
@@ -0,0 +1,81 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import math


class AttnHeadSelector(nn.Module):
    """
    Latent variable modeling of attention head selection
    """
    def __init__(
        self, num_tasks, num_layers,
        total_num_heads, num_heads,
        select_strategy="group",
        head_select_temp=5.0
    ):
        super(AttnHeadSelector, self).__init__()
        self.num_tasks = num_tasks
        self.num_layers = num_layers
        self.total_num_heads = total_num_heads
        self.num_heads = num_heads
        self.select_strategy = select_strategy
        self.temp = head_select_temp

        self.head_logits = torch.nn.Parameter(
            torch.Tensor(self.num_tasks, self.num_layers, total_num_heads),
            requires_grad=True
        )
        nn.init.uniform_(
            self.head_logits, a=math.log(0.01),
            b=math.log(1.0)
        )

    def gumbel_sample(self, logits, tau=1.0):
        gumbels1 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels2 = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels1 = (logits + gumbels1 - gumbels2) / tau
        y_soft = gumbels1.sigmoid()
        return y_soft

    def subset_select(self, y_soft, topk, dim=-1):
        top_values, top_inds = torch.topk(y_soft, k=topk, dim=dim)
        top_ret = 1.0 - top_values.detach() + top_values
        return top_inds.detach(), top_ret

    def group_selet(self, y_soft, topk, dim=-1):
        # top_values: (num_tasks, num_layers, topk)
        top_values, top_inds = torch.max(
            y_soft.view(self.num_tasks, self.num_layers, -1, topk), dim=2
        )
        top_inds = top_inds * topk + torch.arange(topk, device=top_inds.device).unsqueeze(0).unsqueeze(1)
        top_ret = 1.0 - top_values.detach() + top_values
        return top_inds.detach(), top_ret

    def head_select(self, task_ids=None):
        # gumbel_sample
        self.head_samples = self.gumbel_sample(self.head_logits, tau=self.temp)
        # head select
        if self.select_strategy == "subset":
            self.subset_heads, self.subset_weights = self.subset_select(
                self.head_samples,
                topk=self.num_heads,
            )
        elif self.select_strategy == "group":
            self.subset_heads, self.subset_weights = self.group_selet(
                self.head_samples,
                topk=self.num_heads,
            )
        else:
            raise ValueError("{} is not supported".format(self.select_strategy))

        self.batch_subset = self.subset_heads[task_ids, :, :]
        self.batch_weights = self.subset_weights[task_ids, :, :]

    def forward(self, layer_idx):
        assert layer_idx is not None
        batch_subset = self.batch_subset[:, layer_idx, :]
        batch_weights = self.batch_weights[:, layer_idx, :]
        return batch_subset, batch_weights

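A minimal standalone sketch of how the selector is driven, assuming `AttnHeadSelector` from this module is in scope (all dimensions below are hypothetical): `head_select` draws Gumbel-sigmoid samples once per forward pass and caches the per-sentence selection, and calling the module with a `layer_idx` returns the chosen heads and their straight-through weights for that layer:

```python
# Hypothetical usage sketch for AttnHeadSelector; dimensions chosen only for illustration.
import torch

selector = AttnHeadSelector(
    num_tasks=3, num_layers=6,
    total_num_heads=8, num_heads=4,
    select_strategy="group",
)

task_ids = torch.tensor([0, 2, 1, 0])   # one task id per sentence in the batch
selector.head_select(task_ids)          # sample and cache the selection for the whole batch
heads, weights = selector(layer_idx=0)  # selection used by the first layer's attention
print(heads.shape, weights.shape)       # -> torch.Size([4, 4]) torch.Size([4, 4])
```
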
fairseq/examples/attention_head_selection/src/modules/head_selection_transformer_layer.py
ADDED
@@ -0,0 +1,92 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.utils import safe_getattr
from fairseq.modules import TransformerEncoderLayer, TransformerDecoderLayer
from ..modules.multihead_attention_selection import MultiheadAttentionSelection


class HeadSelectionTransformerEncoderLayer(TransformerEncoderLayer):

    def __init__(self, args, layer_idx, attn_head_selector=None):
        super().__init__(args)
        self.layer_idx = layer_idx
        self.self_attn = self.build_self_attention_selection(
            self.embed_dim, args, attn_head_selector
        )

    def build_self_attention_selection(self, embed_dim, args, attn_head_selector=None):
        return MultiheadAttentionSelection(
            embed_dim,
            args.total_encoder_attention_heads,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            layer_idx=self.layer_idx,
            attn_head_selector=attn_head_selector
        )


class HeadSelectionTransformerDecoderLayer(TransformerDecoderLayer):

    def __init__(
        self,
        args,
        layer_idx,
        self_attn_head_selector=None,
        enc_attn_head_selector=None,
        no_encoder_attn=False,
        add_bias_kv=False,
        add_zero_attn=False,
    ):
        self.layer_idx = layer_idx
        super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
        if self_attn_head_selector is not None:
            self.self_attn = self.build_self_attention_selection(
                self.embed_dim, args,
                self_attn_head_selector=self_attn_head_selector,
                add_bias_kv=add_bias_kv,
                add_zero_attn=add_zero_attn
            )
        if enc_attn_head_selector is not None:
            self.encoder_attn = self.build_encoder_attention_selection(
                self.embed_dim, args,
                enc_attn_head_selector=enc_attn_head_selector
            )

    def build_self_attention_selection(
        self, embed_dim, args, self_attn_head_selector=None,
        add_bias_kv=False, add_zero_attn=False
    ):
        return MultiheadAttentionSelection(
            embed_dim,
            args.total_decoder_attention_heads,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not safe_getattr(args, "cross_self_attention"),
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            layer_idx=self.layer_idx,
            attn_head_selector=self_attn_head_selector,
        )

    def build_encoder_attention_selection(self, embed_dim, args, enc_attn_head_selector=None):
        return MultiheadAttentionSelection(
            embed_dim,
            args.total_decoder_attention_heads,
            args.decoder_attention_heads,
            kdim=args.encoder_embed_dim,
            vdim=args.encoder_embed_dim,
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            layer_idx=self.layer_idx,
            attn_head_selector=enc_attn_head_selector,
        )

fairseq/examples/attention_head_selection/src/modules/multihead_attention_selection.py
ADDED
@@ -0,0 +1,355 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional, Tuple
import torch
from fairseq import utils
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter

from fairseq.modules.multihead_attention import MultiheadAttention
from ..modules.multihead_functional import multi_head_attention_forward


class MultiheadAttentionSelection(MultiheadAttention):

    def __init__(
        self,
        embed_dim,
        total_num_heads,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        layer_idx=0,
        attn_head_selector=None
    ):
        super().__init__(
            embed_dim,
            num_heads,
            kdim=kdim,
            vdim=vdim,
            dropout=dropout,
            bias=bias,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=self_attention,
            encoder_decoder_attention=encoder_decoder_attention,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )
        self.layer_idx = layer_idx
        self.attn_head_selector = attn_head_selector
        self.total_num_heads = total_num_heads
        self.total_embed_dim = self.head_dim * total_num_heads
        self.k_proj = quant_noise(
            nn.Linear(self.kdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
        )
        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, self.total_embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, self.total_embed_dim))
        else:
            self.bias_k = self.bias_v = None
        self.reset_parameters()

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        # subset_heads: Optional[Tensor] = None,
        # subset_weights: Optional[Tensor] = None
    ) -> Tuple[Tensor, Optional[Tensor]]:
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        subset_heads, subset_weights = self.attn_head_selector(self.layer_idx)

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert list(query.size()) == [tgt_len, bsz, self.embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert src_len, bsz == value.shape[:2]

        if (
            not self.onnx_trace
            and not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
        ):
            assert key is not None and value is not None
            return multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.total_num_heads,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
                subset_heads=subset_heads,
                subset_weights=subset_weights
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.total_num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.total_num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.total_num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.total_num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.total_num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.total_num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.total_num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.total_num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.total_num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            # restore the flattened layout over all candidate heads
            attn_weights = attn_weights.view(bsz * self.total_num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None

        # evaluation
        if subset_heads is not None and subset_heads.numel() == 1:
            subset_heads = subset_heads.repeat(bsz)
            subset_weights = subset_weights.repeat(bsz)

        if subset_heads is None:
            attn = torch.bmm(attn_probs, v)
        else:
            # training with head selection
            mixed_attn = torch.bmm(attn_probs, v).contiguous().view(bsz, self.total_num_heads, tgt_len, self.head_dim)
            attn = torch.stack(
                [mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
            )
            attn = attn * subset_weights.unsqueeze(2).unsqueeze(3)
            attn = attn.contiguous().view(bsz * self.num_heads, tgt_len, self.head_dim)

        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            if subset_heads is None:
                attn_weights = attn_weights_float.view(
                    bsz, self.num_heads, tgt_len, src_len
                ).transpose(1, 0)
            else:
                mixed_attn_weights = attn_weights_float.view(
                    bsz, self.total_num_heads, tgt_len, src_len
                )
                attn_weights = torch.stack(
                    [mixed_attn_weights[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
                ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

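The only step above that differs from the stock `MultiheadAttention` forward is the gather that keeps, for every sentence, only the outputs of its selected heads out of the `total_num_heads` candidates. A small self-contained sketch of that indexing, with hypothetical shapes:

```python
# Hypothetical sketch of the per-sentence head gathering used in the selection branch above.
import torch

bsz, total_heads, num_heads, tgt_len, head_dim = 2, 8, 4, 5, 16
mixed_attn = torch.randn(bsz, total_heads, tgt_len, head_dim)  # outputs of all candidate heads
subset_heads = torch.tensor([[0, 2, 4, 6],                     # heads kept for sentence 0
                             [1, 3, 5, 7]])                    # heads kept for sentence 1

attn = torch.stack(
    [mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))],
    dim=1,
)
print(attn.shape)  # -> torch.Size([2, 4, 5, 16]): only the num_heads selected heads remain per sentence
```
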
fairseq/examples/attention_head_selection/src/modules/multihead_functional.py
ADDED
@@ -0,0 +1,278 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn.functional import (
    linear, softmax, dropout, pad,
    has_torch_function,
    handle_torch_function,
    _in_projection_packed,
)
import math
import warnings


def _scaled_dot_product_attention(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
    bsz: int = 1,
    subset_heads: Optional[Tensor] = None,
    subset_weights: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    B, Nt, E = q.shape
    q = q / math.sqrt(E)
    # B: bsz * total_num_heads
    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = torch.bmm(q, k.transpose(-2, -1))
    if attn_mask is not None:
        attn += attn_mask
    attn = softmax(attn, dim=-1)
    if dropout_p > 0.0:
        attn = dropout(attn, p=dropout_p)
    if subset_heads is None:
        # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
        output = torch.bmm(attn, v)
    else:
        mixed_output = torch.bmm(attn, v).contiguous().view(bsz, -1, Nt, E)
        output = torch.stack(
            [mixed_output[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))],
            dim=1
        )
        output = output * subset_weights.unsqueeze(2).unsqueeze(3)
        output = output.contiguous().view(-1, Nt, E)
    if subset_heads is not None:
        _, Nt, Ns = attn.size()
        mixed_attn = attn.view(bsz, -1, Nt, Ns)
        attn = torch.stack(
            [mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :] for col in range(subset_heads.size(1))], dim=1
        )
    return output, attn


def _in_projection(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    w_q: Tensor,
    w_k: Tensor,
    w_v: Tensor,
    b_q: Optional[Tensor] = None,
    b_k: Optional[Tensor] = None,
    b_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)


def multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    total_num_heads: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Optional[Tensor],
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Optional[Tensor],
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
    subset_heads: Optional[Tensor] = None,
    subset_weights: Optional[Tensor] = None,
):
    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            total_num_heads,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
            subset_heads=subset_heads,
            subset_weights=subset_weights
        )

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    assert embed_dim == embed_dim_to_check, \
        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
    if isinstance(embed_dim, torch.Tensor):
        # embed_dim can be a tensor when JIT tracing
        head_dim = embed_dim.div(num_heads, rounding_mode='trunc')
    else:
        head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
    if use_separate_proj_weight:
        # allow MHA to have different embedding dimensions when separate projection weights are used
        assert key.shape[:2] == value.shape[:2], \
            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
    else:
        assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"

    #
    # compute in-projection
    #
    if not use_separate_proj_weight:
        q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
    else:
        assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
        assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
        assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
        if in_proj_bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = in_proj_bias.chunk(3)
        q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

    # prep attention mask
    if attn_mask is not None:
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)
        else:
            assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
                f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}"
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * total_num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

    # prep key padding mask
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
        key_padding_mask = key_padding_mask.to(torch.bool)

    # add bias along batch dimension (currently second)
    if bias_k is not None and bias_v is not None:
        assert static_k is None, "bias cannot be added to static key."
        assert static_v is None, "bias cannot be added to static value."
        k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))
    else:
        assert bias_k is None
        assert bias_v is None

    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    q = q.contiguous().view(tgt_len, bsz * total_num_heads, head_dim).transpose(0, 1)
    if static_k is None:
        k = k.contiguous().view(k.shape[0], bsz * total_num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_k.size(0) == bsz * total_num_heads, \
            f"expecting static_k.size(0) of {bsz * total_num_heads}, but got {static_k.size(0)}"
        assert static_k.size(2) == head_dim, \
            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
        k = static_k
    if static_v is None:
        v = v.contiguous().view(v.shape[0], bsz * total_num_heads, head_dim).transpose(0, 1)
    else:
        # TODO finish disentangling control flow so we don't do in-projections when statics are passed
        assert static_v.size(0) == bsz * total_num_heads, \
|
225 |
+
f"expecting static_v.size(0) of {bsz * total_num_heads}, but got {static_v.size(0)}"
|
226 |
+
assert static_v.size(2) == head_dim, \
|
227 |
+
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
|
228 |
+
v = static_v
|
229 |
+
|
230 |
+
# add zero attention along batch dimension (now first)
|
231 |
+
if add_zero_attn:
|
232 |
+
zero_attn_shape = (bsz * total_num_heads, 1, head_dim)
|
233 |
+
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
|
234 |
+
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
|
235 |
+
if attn_mask is not None:
|
236 |
+
attn_mask = pad(attn_mask, (0, 1))
|
237 |
+
if key_padding_mask is not None:
|
238 |
+
key_padding_mask = pad(key_padding_mask, (0, 1))
|
239 |
+
|
240 |
+
# update source sequence length after adjustments
|
241 |
+
src_len = k.size(1)
|
242 |
+
|
243 |
+
# merge key padding and attention masks
|
244 |
+
if key_padding_mask is not None:
|
245 |
+
assert key_padding_mask.shape == (bsz, src_len), \
|
246 |
+
f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
|
247 |
+
key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). \
|
248 |
+
expand(-1, total_num_heads, -1, -1).reshape(bsz * total_num_heads, 1, src_len)
|
249 |
+
if attn_mask is None:
|
250 |
+
attn_mask = key_padding_mask
|
251 |
+
elif attn_mask.dtype == torch.bool:
|
252 |
+
attn_mask = attn_mask.logical_or(key_padding_mask)
|
253 |
+
else:
|
254 |
+
attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
|
255 |
+
|
256 |
+
# convert mask to float
|
257 |
+
if attn_mask is not None and attn_mask.dtype == torch.bool:
|
258 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
|
259 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
260 |
+
attn_mask = new_attn_mask
|
261 |
+
|
262 |
+
# adjust dropout probability
|
263 |
+
if not training:
|
264 |
+
dropout_p = 0.0
|
265 |
+
|
266 |
+
#
|
267 |
+
# (deep breath) calculate attention and out projection
|
268 |
+
#
|
269 |
+
attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, bsz, subset_heads, subset_weights)
|
270 |
+
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
271 |
+
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
272 |
+
|
273 |
+
if need_weights:
|
274 |
+
# average attention weights over heads
|
275 |
+
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
|
276 |
+
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
277 |
+
else:
|
278 |
+
return attn_output, None
|
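
The `subset_heads` and `subset_weights` arguments added by this diff are only threaded through to `_scaled_dot_product_attention`, which is defined earlier in this file and not shown above. A minimal standalone sketch of the underlying idea, gathering a per-example subset of head outputs and scaling each by a per-head weight, is given below; the shapes and variable names are assumptions for illustration only and are not code from this repository.

    import torch

    # assumed sizes for illustration
    bsz, total_num_heads, num_heads, tgt_len, head_dim = 2, 8, 4, 5, 16

    # hypothetical per-head attention outputs: (bsz, total_num_heads, tgt_len, head_dim)
    head_outputs = torch.randn(bsz, total_num_heads, tgt_len, head_dim)

    # per-example indices of the selected heads and their mixing weights
    subset_heads = torch.randint(0, total_num_heads, (bsz, num_heads))   # (bsz, num_heads)
    subset_weights = torch.softmax(torch.randn(bsz, num_heads), dim=-1)  # (bsz, num_heads)

    # gather the chosen heads for each example: (bsz, num_heads, tgt_len, head_dim)
    idx = subset_heads[:, :, None, None].expand(-1, -1, tgt_len, head_dim)
    selected = head_outputs.gather(1, idx)

    # scale each selected head by its weight, then flatten heads into the feature dim
    weighted = selected * subset_weights[:, :, None, None]
    output = weighted.permute(2, 0, 1, 3).reshape(tgt_len, bsz, num_heads * head_dim)
    print(output.shape)  # torch.Size([5, 2, 64])

In the function above, the analogous reduction happens inside `_scaled_dot_product_attention`, after which the result is reshaped to `(tgt_len, bsz, embed_dim)` and passed through the output projection.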