PyTorch
ssl-aasist
custom_code
ash56 committed on
Commit 6789f6f · verified
1 Parent(s): b1b22fb

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml +36 -0
  2. fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml +38 -0
  3. fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml +38 -0
  4. fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml +36 -0
  5. fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml +36 -0
  6. fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml +36 -0
  7. fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml +36 -0
  8. fairseq/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml +64 -0
  9. fairseq/examples/data2vec/config/vision/pretraining/run_config/local.yaml +15 -0
  10. fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml +37 -0
  11. fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml +37 -0
  12. fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml +36 -0
  13. fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml +36 -0
  14. fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml +36 -0
  15. fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml +36 -0
  16. fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml +36 -0
  17. fairseq/examples/data2vec/data/__init__.py +17 -0
  18. fairseq/examples/data2vec/data/add_class_target_dataset.py +63 -0
  19. fairseq/examples/data2vec/data/image_dataset.py +127 -0
  20. fairseq/examples/data2vec/data/mae_finetuning_image_dataset.py +135 -0
  21. fairseq/examples/data2vec/data/mae_image_dataset.py +418 -0
  22. fairseq/examples/data2vec/data/modality.py +14 -0
  23. fairseq/examples/data2vec/data/path_dataset.py +64 -0
  24. fairseq/examples/data2vec/models/__init__.py +0 -0
  25. fairseq/examples/data2vec/models/audio_classification.py +614 -0
  26. fairseq/examples/data2vec/models/data2vec2.py +813 -0
  27. fairseq/examples/data2vec/models/data2vec_audio.py +537 -0
  28. fairseq/examples/data2vec/models/data2vec_image_classification.py +143 -0
  29. fairseq/examples/data2vec/models/data2vec_text.py +517 -0
  30. fairseq/examples/data2vec/models/data2vec_text_classification.py +141 -0
  31. fairseq/examples/data2vec/models/data2vec_vision.py +727 -0
  32. fairseq/examples/data2vec/models/mae.py +829 -0
  33. fairseq/examples/data2vec/models/mae_image_classification.py +386 -0
  34. fairseq/examples/data2vec/models/modalities/__init__.py +0 -0
  35. fairseq/examples/data2vec/models/modalities/audio.py +192 -0
  36. fairseq/examples/data2vec/models/modalities/base.py +684 -0
  37. fairseq/examples/data2vec/models/modalities/images.py +256 -0
  38. fairseq/examples/data2vec/models/modalities/modules.py +589 -0
  39. fairseq/examples/data2vec/models/modalities/text.py +161 -0
  40. fairseq/examples/data2vec/models/utils.py +55 -0
  41. fairseq/examples/data2vec/scripts/convert_audioset_labels.py +63 -0
  42. fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh +18 -0
  43. fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh +16 -0
  44. fairseq/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh +28 -0
  45. fairseq/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh +17 -0
  46. fairseq/examples/data2vec/scripts/text/finetune_all_fair.sh +21 -0
  47. fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws.sh +21 -0
  48. fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh +17 -0
  49. fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh +23 -0
  50. fairseq/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh +25 -0
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 80
+     gpus_per_node: 8
+     tasks_per_node: 1
+     mem_gb: 0
+     nodes: 1
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: wav2vec,learnlab,learnfair
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml ADDED
@@ -0,0 +1,38 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+           - task.local_cache_path
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 10
+     gpus_per_node: 8
+     tasks_per_node: 8
+     mem_gb: 450
+     nodes: 2
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: devlab,learnlab,learnfair,scavenge
+     constraint: volta32gb,ib4
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml ADDED
@@ -0,0 +1,38 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+           - task.local_cache_path
+           - model.model_path
+   sweep:
+     dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 10
+     gpus_per_node: 8
+     tasks_per_node: 8
+     mem_gb: 0
+     nodes: 2
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: wav2vec,learnlab,learnfair
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 80
+     gpus_per_node: 8
+     tasks_per_node: 1
+     mem_gb: 450
+     nodes: 3
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: devlab,learnlab,learnfair,scavenge
+     constraint: volta32gb,ib4
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 10
+     gpus_per_node: 8
+     tasks_per_node: 8
+     mem_gb: 0
+     nodes: 4
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: wav2vec,learnlab,learnfair
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 10
+     gpus_per_node: 8
+     tasks_per_node: 8
+     mem_gb: 0
+     nodes: 6
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: wav2vec,learnlab,learnfair
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 10
+     gpus_per_node: 8
+     tasks_per_node: 8
+     mem_gb: 0
+     nodes: 8
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: wav2vec,learnlab,learnfair
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml ADDED
@@ -0,0 +1,64 @@
+ # @package _group_
+
+ common:
+   fp16: true
+   log_format: json
+   log_interval: 200
+   tensorboard_logdir: tb
+
+ checkpoint:
+   save_interval: 5
+   save_interval_updates: 25000
+   keep_interval_updates: 1
+   no_epoch_checkpoints: true
+
+ task:
+   _name: image_pretraining
+   data: /datasets01/imagenet_full_size/061417
+
+ dataset:
+   num_workers: 6
+   batch_size: 128
+   skip_invalid_size_inputs_valid_test: true
+   required_batch_size_multiple: 2
+   disable_validation: true
+
+ distributed_training:
+   distributed_world_size: 16
+   ddp_backend: legacy_ddp
+
+ criterion:
+   _name: model
+   log_keys:
+     - ema_decay
+     - target_var
+     - pred_var
+
+ optimization:
+   max_update: 375300 #300*1251
+   lr: [0.0005]
+   clip_norm: 3.0
+
+ optimizer:
+   _name: adam
+   adam_betas: (0.9,0.999)
+   adam_eps: 1e-08
+   weight_decay: 0.05
+
+ lr_scheduler:
+   _name: cosine
+   warmup_updates: 12510 # it should be 10 epochs
+
+ model:
+   _name: data2vec_vision
+
+   attention_dropout: 0.05
+
+   ema_decay: 0.999
+   ema_end_decay: 0.9998
+   layer_norm_targets: True
+   average_top_k_layers: 6
+
+   loss_beta: 2.0
+
+   drop_path: 0.25
fairseq/examples/data2vec/config/vision/pretraining/run_config/local.yaml ADDED
@@ -0,0 +1,15 @@
+ # @package _global_
+ hydra:
+   sweep:
+     dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+ distributed_training:
+   distributed_world_size: 1
+   nprocs_per_node: 1
+   distributed_port: -1
+
+ common:
+   log_interval: 1
+
+ dataset:
+   num_workers: 0
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml ADDED
@@ -0,0 +1,37 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 80
+     gpus_per_node: 8
+     tasks_per_node: 1
+     mem_gb: 450
+     nodes: 1
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: devlab,learnlab,learnfair,scavenge
+     constraint: volta32gb,ib4
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml ADDED
@@ -0,0 +1,37 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+           - task.local_cache_path
+   sweep:
+     dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 10
+     gpus_per_node: 8
+     tasks_per_node: 8
+     mem_gb: 0
+     nodes: 2
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: wav2vec,learnlab,learnfair
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 80
+     gpus_per_node: 8
+     tasks_per_node: 1
+     mem_gb: 450
+     nodes: 3
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: devlab,learnlab,learnfair,scavenge
+     constraint: volta32gb,ib4
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 10
+     gpus_per_node: 8
+     tasks_per_node: 8
+     mem_gb: 450
+     nodes: 4
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: devlab,learnlab,learnfair,scavenge
+     constraint: volta32gb,ib4
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 10
+     gpus_per_node: 8
+     tasks_per_node: 8
+     mem_gb: 0
+     nodes: 4
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: wav2vec,learnlab,learnfair
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 10
+     gpus_per_node: 8
+     tasks_per_node: 8
+     mem_gb: 0
+     nodes: 6
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: wav2vec,learnlab,learnfair
+     max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml ADDED
@@ -0,0 +1,36 @@
+ # @package _global_
+
+ hydra:
+   job:
+     config:
+       override_dirname:
+         kv_sep: ':'
+         item_sep: '/'
+         exclude_keys:
+           - run_config
+           - distributed_training.distributed_port
+           - distributed_training.distributed_world_size
+           - model.pretrained_model_path
+           - model.target_network_path
+           - next_script
+           - task.cache_in_scratch
+           - task.data
+           - checkpoint.save_interval_updates
+           - checkpoint.keep_interval_updates
+           - checkpoint.save_on_overflow
+           - common.log_interval
+           - common.user_dir
+   sweep:
+     dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+     subdir: ''
+   launcher:
+     submitit_folder: ${hydra.sweep.dir}
+     timeout_min: 4320
+     cpus_per_task: 10
+     gpus_per_node: 8
+     tasks_per_node: 8
+     mem_gb: 0
+     nodes: 8
+     name: ${env:PREFIX}_${hydra.job.config_name}
+     partition: wav2vec,learnlab,learnfair
+     max_num_timeout: 30
fairseq/examples/data2vec/data/__init__.py ADDED
@@ -0,0 +1,17 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .image_dataset import ImageDataset
+ from .path_dataset import PathDataset
+ from .mae_image_dataset import MaeImageDataset
+ from .mae_finetuning_image_dataset import MaeFinetuningImageDataset
+
+
+ __all__ = [
+     "ImageDataset",
+     "MaeImageDataset",
+     "MaeFinetuningImageDataset",
+     "PathDataset",
+ ]
fairseq/examples/data2vec/data/add_class_target_dataset.py ADDED
@@ -0,0 +1,63 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+
+ from fairseq.data import BaseWrapperDataset, data_utils
+
+
+ class AddClassTargetDataset(BaseWrapperDataset):
+     def __init__(
+         self,
+         dataset,
+         labels,
+         multi_class,
+         num_classes=None,
+         label_indices=None,
+         add_to_input=True,
+     ):
+         super().__init__(dataset)
+
+         self.label_indices = label_indices
+         self.labels = labels
+         self.multi_class = multi_class
+         self.add_to_input = add_to_input
+         if num_classes is None and multi_class:
+             assert self.label_indices is not None
+             num_classes = len(self.label_indices)
+
+         self.num_classes = num_classes
+
+     def __getitem__(self, index):
+         item = self.dataset[index]
+         item_labels = self.labels[index]
+         if self.multi_class:
+             item["label"] = torch.zeros(self.num_classes)
+             for il in item_labels:
+                 if self.label_indices is not None:
+                     il = self.label_indices[il]
+                 item["label"][il] = 1.0
+         else:
+             item["label"] = torch.tensor(
+                 self.labels[index]
+                 if self.label_indices is None
+                 else self.label_indices[self.labels[index]]
+             )
+
+         return item
+
+     def collater(self, samples):
+         collated = self.dataset.collater(samples)
+         if len(collated) == 0:
+             return collated
+
+         indices = set(collated["id"].tolist())
+         target = [s["label"] for s in samples if s["id"] in indices]
+         collated["label"] = torch.stack(target, dim=0)
+
+         if self.add_to_input:
+             collated["net_input"]["label"] = collated["label"]
+
+         return collated
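
Not part of the diff above: a minimal usage sketch of AddClassTargetDataset, assuming the bundled fairseq/ directory is on PYTHONPATH so that examples.data2vec.data is importable. The DummyBase dataset, the label lists, and the label_indices mapping are hypothetical stand-ins for a real base dataset.

import torch
from fairseq.data import FairseqDataset
from examples.data2vec.data.add_class_target_dataset import AddClassTargetDataset


class DummyBase(FairseqDataset):
    # minimal base dataset: items are dicts with an "id" and a feature tensor,
    # and the collater produces the usual {"id", "net_input"} structure
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return {"id": index, "source": torch.randn(8)}

    def collater(self, samples):
        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {"source": torch.stack([s["source"] for s in samples])},
        }


# multi-label case: each example carries a list of raw label ids, and
# label_indices maps raw ids to contiguous class indices
wrapped = AddClassTargetDataset(
    dataset=DummyBase(3),
    labels=[[0, 3], [1], [2, 3]],
    multi_class=True,
    label_indices={0: 0, 1: 1, 2: 2, 3: 3},
)
batch = wrapped.collater([wrapped[0], wrapped[2]])
# batch["label"] is a (2, 4) multi-hot tensor; because add_to_input defaults to
# True, it is also copied into batch["net_input"]["label"]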
fairseq/examples/data2vec/data/image_dataset.py ADDED
@@ -0,0 +1,127 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import logging
+
+ import numpy as np
+ import os
+ from typing import Optional, Callable, Set
+
+ import torch
+
+ from torchvision.datasets.vision import VisionDataset
+ from torchvision.transforms import ToTensor
+
+ from fairseq.data import FairseqDataset
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class ImageDataset(FairseqDataset, VisionDataset):
+     def __init__(
+         self,
+         root: str,
+         extensions: Set[str],
+         load_classes: bool,
+         transform: Optional[Callable] = None,
+         shuffle=True,
+     ):
+         FairseqDataset.__init__(self)
+         VisionDataset.__init__(self, root=root, transform=transform)
+
+         self.shuffle = shuffle
+         self.tensor_transform = ToTensor()
+
+         self.classes = None
+         self.labels = None
+         if load_classes:
+             classes = [d.name for d in os.scandir(root) if d.is_dir()]
+             classes.sort()
+             self.classes = {cls_name: i for i, cls_name in enumerate(classes)}
+             logger.info(f"loaded {len(self.classes)} classes")
+             self.labels = []
+
+         def walk_path(root_path):
+             for root, _, fnames in sorted(os.walk(root_path, followlinks=True)):
+                 for fname in sorted(fnames):
+                     fname_ext = os.path.splitext(fname)
+                     if fname_ext[-1].lower() not in extensions:
+                         continue
+
+                     path = os.path.join(root, fname)
+                     yield path
+
+         logger.info(f"finding images in {root}")
+         if self.classes is not None:
+             self.files = []
+             self.labels = []
+             for c, i in self.classes.items():
+                 for f in walk_path(os.path.join(root, c)):
+                     self.files.append(f)
+                     self.labels.append(i)
+         else:
+             self.files = [f for f in walk_path(root)]
+
+         logger.info(f"loaded {len(self.files)} examples")
+
+     def __getitem__(self, index):
+         from PIL import Image
+
+         fpath = self.files[index]
+
+         with open(fpath, "rb") as f:
+             img = Image.open(f).convert("RGB")
+
+         if self.transform is None:
+             img = self.tensor_transform(img)
+         else:
+             img = self.transform(img)
+             assert torch.is_tensor(img)
+
+         res = {"id": index, "img": img}
+
+         if self.labels is not None:
+             res["label"] = self.labels[index]
+
+         return res
+
+     def __len__(self):
+         return len(self.files)
+
+     def collater(self, samples):
+         if len(samples) == 0:
+             return {}
+
+         collated_img = torch.stack([s["img"] for s in samples], dim=0)
+
+         res = {
+             "id": torch.LongTensor([s["id"] for s in samples]),
+             "net_input": {
+                 "img": collated_img,
+             },
+         }
+
+         if "label" in samples[0]:
+             res["net_input"]["label"] = torch.LongTensor([s["label"] for s in samples])
+
+         return res
+
+     def num_tokens(self, index):
+         return 1
+
+     def size(self, index):
+         return 1
+
+     def ordered_indices(self):
+         """Return an ordered list of indices. Batches will be constructed based
+         on this order."""
+         if self.shuffle:
+             order = [np.random.permutation(len(self))]
+         else:
+             order = [np.arange(len(self))]
+
+         return order[0]
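
Not part of the diff above: an illustrative sketch of ImageDataset over an ImageNet-style directory. The directory path and the transform pipeline are hypothetical; with load_classes=True each sub-folder of the root becomes a class and its index is attached to every item found under it.

from torchvision import transforms
from examples.data2vec.data.image_dataset import ImageDataset

dataset = ImageDataset(
    root="/path/to/imagenet/train",          # hypothetical path: one sub-folder per class
    extensions={".jpg", ".jpeg", ".png"},    # compared against the lowercased file extension
    load_classes=True,
    transform=transforms.Compose(
        [transforms.RandomResizedCrop(224), transforms.ToTensor()]
    ),
)
sample = dataset[0]                          # {"id": 0, "img": (3, 224, 224) tensor, "label": int}
batch = dataset.collater([dataset[0], dataset[1]])
# batch["net_input"]["img"] has shape (2, 3, 224, 224); note that the collater
# places the integer labels under batch["net_input"]["label"]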
fairseq/examples/data2vec/data/mae_finetuning_image_dataset.py ADDED
@@ -0,0 +1,135 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ import logging
+
+ import numpy as np
+ import os
+
+ import torch
+
+ from torchvision import datasets, transforms
+
+ from timm.data import create_transform
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ import PIL
+
+ from fairseq.data import FairseqDataset
+ from .mae_image_dataset import caching_loader
+
+
+ logger = logging.getLogger(__name__)
+
+
+ def build_transform(is_train, input_size, color_jitter, aa, reprob, remode, recount):
+     mean = IMAGENET_DEFAULT_MEAN
+     std = IMAGENET_DEFAULT_STD
+     # train transform
+     if is_train:
+         # this should always dispatch to transforms_imagenet_train
+         transform = create_transform(
+             input_size=input_size,
+             is_training=True,
+             color_jitter=color_jitter,
+             auto_augment=aa,
+             interpolation="bicubic",
+             re_prob=reprob,
+             re_mode=remode,
+             re_count=recount,
+             mean=mean,
+             std=std,
+         )
+         return transform
+
+     # eval transform
+     t = []
+     if input_size <= 224:
+         crop_pct = 224 / 256
+     else:
+         crop_pct = 1.0
+     size = int(input_size / crop_pct)
+     t.append(
+         transforms.Resize(
+             size, interpolation=PIL.Image.BICUBIC
+         ),  # to maintain same ratio w.r.t. 224 images
+     )
+     t.append(transforms.CenterCrop(input_size))
+
+     t.append(transforms.ToTensor())
+     t.append(transforms.Normalize(mean, std))
+     return transforms.Compose(t)
+
+
+ class MaeFinetuningImageDataset(FairseqDataset):
+     def __init__(
+         self,
+         root: str,
+         split: str,
+         is_train: bool,
+         input_size,
+         color_jitter=None,
+         aa="rand-m9-mstd0.5-inc1",
+         reprob=0.25,
+         remode="pixel",
+         recount=1,
+         local_cache_path=None,
+         shuffle=True,
+     ):
+         FairseqDataset.__init__(self)
+
+         self.shuffle = shuffle
+
+         transform = build_transform(
+             is_train, input_size, color_jitter, aa, reprob, remode, recount
+         )
+
+         path = os.path.join(root, split)
+         loader = caching_loader(local_cache_path, datasets.folder.default_loader)
+
+         self.dataset = datasets.ImageFolder(path, loader=loader, transform=transform)
+
+         logger.info(f"loaded {len(self.dataset)} examples")
+
+     def __getitem__(self, index):
+         img, label = self.dataset[index]
+         return {"id": index, "img": img, "label": label}
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def collater(self, samples):
+         if len(samples) == 0:
+             return {}
+
+         collated_img = torch.stack([s["img"] for s in samples], dim=0)
+
+         res = {
+             "id": torch.LongTensor([s["id"] for s in samples]),
+             "net_input": {
+                 "imgs": collated_img,
+             },
+         }
+
+         if "label" in samples[0]:
+             res["net_input"]["labels"] = torch.LongTensor([s["label"] for s in samples])
+
+         return res
+
+     def num_tokens(self, index):
+         return 1
+
+     def size(self, index):
+         return 1
+
+     def ordered_indices(self):
+         """Return an ordered list of indices. Batches will be constructed based
+         on this order."""
+         if self.shuffle:
+             order = [np.random.permutation(len(self))]
+         else:
+             order = [np.arange(len(self))]
+
+         return order[0]
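
Not part of the diff above: a small sketch of building the fine-tuning dataset. The root path is hypothetical and must contain a "train" (or "val") sub-folder laid out for torchvision's ImageFolder; with is_train=True the transform comes from timm's create_transform (RandAugment, random erasing), otherwise the resize/center-crop eval transform is used.

from examples.data2vec.data.mae_finetuning_image_dataset import MaeFinetuningImageDataset

train_set = MaeFinetuningImageDataset(
    root="/path/to/imagenet",   # hypothetical ImageFolder-style root
    split="train",
    is_train=True,
    input_size=224,
)
batch = train_set.collater([train_set[0], train_set[1]])
# batch["net_input"]["imgs"] is (2, 3, 224, 224); the integer class targets are
# collated into batch["net_input"]["labels"]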
fairseq/examples/data2vec/data/mae_image_dataset.py ADDED
@@ -0,0 +1,418 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+
+ from functools import partial
+ import logging
+ import math
+ import random
+ import time
+
+ import numpy as np
+ import os
+
+ import torch
+
+ from torchvision import datasets, transforms
+ from .path_dataset import PathDataset
+
+ from fairseq.data import FairseqDataset
+ from fairseq.data.data_utils import compute_block_mask_1d, compute_block_mask_2d
+
+ from shutil import copyfile
+
+ logger = logging.getLogger(__name__)
+
+
+ def load(path, loader, cache):
+     if hasattr(caching_loader, "cache_root"):
+         cache = caching_loader.cache_root
+
+     cached_path = cache + path
+
+     num_tries = 3
+     for curr_try in range(num_tries):
+         try:
+             if curr_try == 2:
+                 return loader(path)
+             if not os.path.exists(cached_path) or curr_try > 0:
+                 os.makedirs(os.path.dirname(cached_path), exist_ok=True)
+                 copyfile(path, cached_path)
+                 os.chmod(cached_path, 0o777)
+             return loader(cached_path)
+         except Exception as e:
+             logger.warning(str(e))
+             if "Errno 13" in str(e):
+                 caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}"
+                 logger.warning(f"setting cache root to {caching_loader.cache_root}")
+                 cached_path = caching_loader.cache_root + path
+             if curr_try == (num_tries - 1):
+                 raise
+             time.sleep(2)
+
+
+ def caching_loader(cache_root: str, loader):
+     if cache_root is None:
+         return loader
+
+     if cache_root == "slurm_tmpdir":
+         cache_root = os.environ["SLURM_TMPDIR"]
+         assert len(cache_root) > 0
+
+     if not cache_root.endswith("/"):
+         cache_root += "/"
+
+     return partial(load, loader=loader, cache=cache_root)
+
+
+ class RandomResizedCropAndInterpolationWithTwoPic:
+     """Crop the given PIL Image to random size and aspect ratio with random interpolation.
+
+     A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+     aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+     is finally resized to given size.
+     This is popularly used to train the Inception networks.
+
+     Args:
+         size: expected output size of each edge
+         scale: range of size of the origin size cropped
+         ratio: range of aspect ratio of the origin aspect ratio cropped
+         interpolation: Default: PIL.Image.BILINEAR
+     """
+
+     def __init__(
+         self,
+         size,
+         second_size=None,
+         scale=(0.08, 1.0),
+         ratio=(3.0 / 4.0, 4.0 / 3.0),
+         interpolation="bilinear",
+         second_interpolation="lanczos",
+     ):
+         if isinstance(size, tuple):
+             self.size = size
+         else:
+             self.size = (size, size)
+         if second_size is not None:
+             if isinstance(second_size, tuple):
+                 self.second_size = second_size
+             else:
+                 self.second_size = (second_size, second_size)
+         else:
+             self.second_size = None
+         if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+             logger.warning("range should be of kind (min, max)")
+
+         if interpolation == "random":
+             from PIL import Image
+
+             self.interpolation = (Image.BILINEAR, Image.BICUBIC)
+         else:
+             self.interpolation = self._pil_interp(interpolation)
+
+         self.second_interpolation = (
+             self._pil_interp(second_interpolation)
+             if second_interpolation is not None
+             else None
+         )
+         self.scale = scale
+         self.ratio = ratio
+
+     def _pil_interp(self, method):
+         from PIL import Image
+
+         if method == "bicubic":
+             return Image.BICUBIC
+         elif method == "lanczos":
+             return Image.LANCZOS
+         elif method == "hamming":
+             return Image.HAMMING
+         else:
+             # default bilinear, do we want to allow nearest?
+             return Image.BILINEAR
+
+     @staticmethod
+     def get_params(img, scale, ratio):
+         """Get parameters for ``crop`` for a random sized crop.
+
+         Args:
+             img (PIL Image): Image to be cropped.
+             scale (tuple): range of size of the origin size cropped
+             ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+
+         Returns:
+             tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+                 sized crop.
+         """
+         area = img.size[0] * img.size[1]
+
+         for attempt in range(10):
+             target_area = random.uniform(*scale) * area
+             log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+             aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+             w = int(round(math.sqrt(target_area * aspect_ratio)))
+             h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+             if w <= img.size[0] and h <= img.size[1]:
+                 i = random.randint(0, img.size[1] - h)
+                 j = random.randint(0, img.size[0] - w)
+                 return i, j, h, w
+
+         # Fallback to central crop
+         in_ratio = img.size[0] / img.size[1]
+         if in_ratio < min(ratio):
+             w = img.size[0]
+             h = int(round(w / min(ratio)))
+         elif in_ratio > max(ratio):
+             h = img.size[1]
+             w = int(round(h * max(ratio)))
+         else:  # whole image
+             w = img.size[0]
+             h = img.size[1]
+         i = (img.size[1] - h) // 2
+         j = (img.size[0] - w) // 2
+         return i, j, h, w
+
+     def __call__(self, img):
+         import torchvision.transforms.functional as F
+
+         """
+         Args:
+             img (PIL Image): Image to be cropped and resized.
+
+         Returns:
+             PIL Image: Randomly cropped and resized image.
+         """
+         i, j, h, w = self.get_params(img, self.scale, self.ratio)
+         if isinstance(self.interpolation, (tuple, list)):
+             interpolation = random.choice(self.interpolation)
+         else:
+             interpolation = self.interpolation
+         if self.second_size is None:
+             return F.resized_crop(img, i, j, h, w, self.size, interpolation)
+         else:
+             return F.resized_crop(
+                 img, i, j, h, w, self.size, interpolation
+             ), F.resized_crop(
+                 img, i, j, h, w, self.second_size, self.second_interpolation
+             )
+
+
+ class MaeImageDataset(FairseqDataset):
+     def __init__(
+         self,
+         root: str,
+         split: str,
+         input_size,
+         local_cache_path=None,
+         shuffle=True,
+         key="imgs",
+         beit_transforms=False,
+         target_transform=False,
+         no_transform=False,
+         compute_mask=False,
+         patch_size: int = 16,
+         mask_prob: float = 0.75,
+         mask_prob_adjust: float = 0,
+         mask_length: int = 1,
+         inverse_mask: bool = False,
+         expand_adjacent: bool = False,
+         mask_dropout: float = 0,
+         non_overlapping: bool = False,
+         require_same_masks: bool = True,
+         clone_batch: int = 1,
+         dataset_type: str = "imagefolder",
+     ):
+         FairseqDataset.__init__(self)
+
+         self.shuffle = shuffle
+         self.key = key
+
+         loader = caching_loader(local_cache_path, datasets.folder.default_loader)
+
+         self.transform_source = None
+         self.transform_target = None
+
+         if target_transform:
+             self.transform_source = transforms.ColorJitter(0.4, 0.4, 0.4)
+             self.transform_target = transforms.ColorJitter(0.4, 0.4, 0.4)
+
+         if no_transform:
+             if input_size <= 224:
+                 crop_pct = 224 / 256
+             else:
+                 crop_pct = 1.0
+             size = int(input_size / crop_pct)
+
+             self.transform_train = transforms.Compose(
+                 [
+                     transforms.Resize(size, interpolation=3),
+                     transforms.CenterCrop(input_size),
+                 ]
+             )
+
+             self.transform_train = transforms.Resize((input_size, input_size))
+         elif beit_transforms:
+             beit_transform_list = []
+             if not target_transform:
+                 beit_transform_list.append(transforms.ColorJitter(0.4, 0.4, 0.4))
+             beit_transform_list.extend(
+                 [
+                     transforms.RandomHorizontalFlip(p=0.5),
+                     RandomResizedCropAndInterpolationWithTwoPic(
+                         size=input_size,
+                         second_size=None,
+                         interpolation="bicubic",
+                         second_interpolation=None,
+                     ),
+                 ]
+             )
+             self.transform_train = transforms.Compose(beit_transform_list)
+         else:
+             self.transform_train = transforms.Compose(
+                 [
+                     transforms.RandomResizedCrop(
+                         input_size, scale=(0.2, 1.0), interpolation=3
+                     ),  # 3 is bicubic
+                     transforms.RandomHorizontalFlip(),
+                 ]
+             )
+         self.final_transform = transforms.Compose(
+             [
+                 transforms.ToTensor(),
+                 transforms.Normalize(
+                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                 ),
+             ]
+         )
+
+         if dataset_type == "imagefolder":
+             self.dataset = datasets.ImageFolder(
+                 os.path.join(root, split), loader=loader
+             )
+         elif dataset_type == "path":
+             self.dataset = PathDataset(
+                 root,
+                 loader,
+                 None,
+                 None,
+                 mean=[0.485, 0.456, 0.406],
+                 std=[0.229, 0.224, 0.225],
+             )
+         else:
+             raise Exception(f"invalid dataset type {dataset_type}")
+
+         logger.info(
+             f"initial transform: {self.transform_train}, "
+             f"source transform: {self.transform_source}, "
+             f"target transform: {self.transform_target}, "
+             f"final transform: {self.final_transform}"
+         )
+         logger.info(f"loaded {len(self.dataset)} examples")
+
+         self.is_compute_mask = compute_mask
+         self.patches = (input_size // patch_size) ** 2
+         self.mask_prob = mask_prob
+         self.mask_prob_adjust = mask_prob_adjust
+         self.mask_length = mask_length
+         self.inverse_mask = inverse_mask
+         self.expand_adjacent = expand_adjacent
+         self.mask_dropout = mask_dropout
+         self.non_overlapping = non_overlapping
+         self.require_same_masks = require_same_masks
+         self.clone_batch = clone_batch
+
+     def __getitem__(self, index):
+         img, _ = self.dataset[index]
+
+         img = self.transform_train(img)
+
+         source = None
+         target = None
+         if self.transform_source is not None:
+             source = self.final_transform(self.transform_source(img))
+         if self.transform_target is not None:
+             target = self.final_transform(self.transform_target(img))
+
+         if source is None:
+             img = self.final_transform(img)
+
+         v = {"id": index, self.key: source if source is not None else img}
+         if target is not None:
+             v["target"] = target
+
+         if self.is_compute_mask:
+             if self.mask_length == 1:
+                 mask = compute_block_mask_1d(
+                     shape=(self.clone_batch, self.patches),
+                     mask_prob=self.mask_prob,
+                     mask_length=self.mask_length,
+                     mask_prob_adjust=self.mask_prob_adjust,
+                     inverse_mask=self.inverse_mask,
+                     require_same_masks=True,
+                 )
+             else:
+                 mask = compute_block_mask_2d(
+                     shape=(self.clone_batch, self.patches),
+                     mask_prob=self.mask_prob,
+                     mask_length=self.mask_length,
+                     mask_prob_adjust=self.mask_prob_adjust,
+                     inverse_mask=self.inverse_mask,
+                     require_same_masks=True,
+                     expand_adjcent=self.expand_adjacent,
+                     mask_dropout=self.mask_dropout,
+                     non_overlapping=self.non_overlapping,
+                 )
+
+             v["precomputed_mask"] = mask
+
+         return v
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def collater(self, samples):
+         if len(samples) == 0:
+             return {}
+
+         collated_img = torch.stack([s[self.key] for s in samples], dim=0)
+
+         res = {
+             "id": torch.LongTensor([s["id"] for s in samples]),
+             "net_input": {
+                 self.key: collated_img,
+             },
+         }
+
+         if "target" in samples[0]:
+             collated_target = torch.stack([s["target"] for s in samples], dim=0)
+             res["net_input"]["target"] = collated_target
+
+         if "precomputed_mask" in samples[0]:
+             collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0)
+             res["net_input"]["precomputed_mask"] = collated_mask
+
+         return res
+
+     def num_tokens(self, index):
+         return 1
+
+     def size(self, index):
+         return 1
+
+     @property
+     def sizes(self):
+         return np.full((len(self),), 1)
+
+     def ordered_indices(self):
+         """Return an ordered list of indices. Batches will be constructed based
+         on this order."""
+         if self.shuffle:
+             order = [np.random.permutation(len(self))]
+         else:
+             order = [np.arange(len(self))]
+
+         return order[0]
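
Not part of the diff above: an illustrative sketch of MaeImageDataset with pre-computed block masks. The root path and the mask hyper-parameters below are hypothetical; with patch_size=16 and input_size=224 there are (224 // 16) ** 2 = 196 patches per image, and clone_batch produces one mask row per clone of the sample.

from examples.data2vec.data.mae_image_dataset import MaeImageDataset

dataset = MaeImageDataset(
    root="/path/to/imagenet",   # hypothetical ImageFolder-style root
    split="train",
    input_size=224,
    compute_mask=True,
    patch_size=16,
    mask_prob=0.8,
    mask_length=3,              # mask_length > 1 selects the 2-D block mask
    inverse_mask=True,
    clone_batch=16,             # illustrative values, not taken from a config in this commit
)
item = dataset[0]
# item["imgs"] is a normalized (3, 224, 224) tensor and item["precomputed_mask"]
# has shape (clone_batch, 196); the collater concatenates the per-sample masks
# along dim 0 into net_input["precomputed_mask"]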
fairseq/examples/data2vec/data/modality.py ADDED
@@ -0,0 +1,14 @@
+ # Copyright (c) 2017-present, Facebook, Inc.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the LICENSE file in
+ # the root directory of this source tree. An additional grant of patent rights
+ # can be found in the PATENTS file in the same directory.
+
+ from enum import Enum, auto
+
+
+ class Modality(Enum):
+     AUDIO = auto()
+     IMAGE = auto()
+     TEXT = auto()
fairseq/examples/data2vec/data/path_dataset.py ADDED
@@ -0,0 +1,64 @@
+ import glob
+ import os
+ from typing import List, Optional, Tuple
+
+ import logging
+ import numpy as np
+ import torchvision.transforms.functional as TF
+ import PIL
+ from PIL import Image
+ from torchvision.datasets import VisionDataset
+
+ logger = logging.getLogger(__name__)
+
+
+ class PathDataset(VisionDataset):
+     def __init__(
+         self,
+         root: List[str],
+         loader: None = None,
+         transform: Optional[str] = None,
+         extra_transform: Optional[str] = None,
+         mean: Optional[List[float]] = None,
+         std: Optional[List[float]] = None,
+     ):
+         super().__init__(root=root)
+
+         PIL.Image.MAX_IMAGE_PIXELS = 256000001
+
+         self.files = []
+         for folder in self.root:
+             self.files.extend(
+                 sorted(glob.glob(os.path.join(folder, "**", "*.jpg"), recursive=True))
+             )
+             self.files.extend(
+                 sorted(glob.glob(os.path.join(folder, "**", "*.png"), recursive=True))
+             )
+
+         self.transform = transform
+         self.extra_transform = extra_transform
+         self.mean = mean
+         self.std = std
+
+         self.loader = loader
+
+         logger.info(f"loaded {len(self.files)} samples from {root}")
+
+         assert (mean is None) == (std is None)
+
+     def __len__(self) -> int:
+         return len(self.files)
+
+     def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
+         path = self.files[idx]
+
+         if self.loader is not None:
+             return self.loader(path), None
+
+         img = Image.open(path).convert("RGB")
+         if self.transform is not None:
+             img = self.transform(img)
+         img = TF.to_tensor(img)
+         if self.mean is not None and self.std is not None:
+             img = TF.normalize(img, self.mean, self.std)
+         return img, None
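
Not part of the diff above: a tiny sketch of PathDataset, which recursively globs *.jpg and *.png files under each root folder. The folder list is hypothetical; when a loader is passed (as MaeImageDataset does for the "path" dataset type), decoding is delegated to it and the transform/mean/std arguments are not applied on that code path.

from examples.data2vec.data.path_dataset import PathDataset

dataset = PathDataset(
    root=["/path/to/folder_a", "/path/to/folder_b"],   # hypothetical folders
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
img, _ = dataset[0]   # normalized CHW tensor; the second element is always None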
fairseq/examples/data2vec/models/__init__.py ADDED
File without changes
fairseq/examples/data2vec/models/audio_classification.py ADDED
@@ -0,0 +1,614 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import contextlib
+ import logging
+ import re
+ from dataclasses import dataclass, field
+ from typing import Any, Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from omegaconf import II, MISSING, open_dict
+
+ from fairseq import checkpoint_utils, tasks
+ from fairseq.dataclass import FairseqDataclass
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+ from fairseq.models import (
+     BaseFairseqModel,
+     register_model,
+ )
+ from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
+ from fairseq.modules import TransposeLast
+ from fairseq.tasks import FairseqTask
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class AudioClassificationConfig(FairseqDataclass):
+     model_path: str = field(
+         default=MISSING, metadata={"help": "path to wav2vec 2.0 model"}
+     )
+     no_pretrained_weights: bool = field(
+         default=False, metadata={"help": "if true, does not load pretrained weights"}
+     )
+     dropout_input: float = field(
+         default=0.0,
+         metadata={"help": "dropout to apply to the input (after feat extr)"},
+     )
+     final_dropout: float = field(
+         default=0.0,
+         metadata={"help": "dropout after transformer and before final projection"},
+     )
+     dropout: float = field(
+         default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"}
+     )
+     attention_dropout: float = field(
+         default=0.0,
+         metadata={
+             "help": "dropout probability for attention weights inside wav2vec 2.0 model"
+         },
+     )
+     activation_dropout: float = field(
+         default=0.0,
+         metadata={
+             "help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
+         },
+     )
+
+     # masking
+     apply_mask: bool = field(
+         default=False, metadata={"help": "apply masking during fine-tuning"}
+     )
+     mask_length: int = field(
+         default=10, metadata={"help": "repeat the mask indices multiple times"}
+     )
+     mask_prob: float = field(
+         default=0.5,
+         metadata={
+             "help": "probability of replacing a token with mask (normalized by length)"
+         },
+     )
+     mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
+         default="static", metadata={"help": "how to choose masks"}
+     )
+     mask_other: float = field(
+         default=0,
+         metadata={
+             "help": "secondary mask argument (used for more complex distributions), "
+             "see help in compute_mask_indices"
+         },
+     )
+     no_mask_overlap: bool = field(
+         default=False, metadata={"help": "whether to allow masks to overlap"}
+     )
+     mask_min_space: Optional[int] = field(
+         default=1,
+         metadata={"help": "min space between spans (if no overlap is enabled)"},
+     )
+     require_same_masks: bool = field(
+         default=True,
+         metadata={
+             "help": "whether to number of masked timesteps must be the same across all "
+             "examples in a batch"
+         },
+     )
+     mask_dropout: float = field(
+         default=0.0,
+         metadata={"help": "percent of masks to unmask for each sample"},
+     )
+
+     # channel masking
+     mask_channel_length: int = field(
+         default=10, metadata={"help": "length of the mask for features (channels)"}
+     )
+     mask_channel_prob: float = field(
+         default=0.0, metadata={"help": "probability of replacing a feature with 0"}
+     )
+     mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
+         default="static",
+         metadata={"help": "how to choose mask length for channel masking"},
+     )
+     mask_channel_other: float = field(
+         default=0,
+         metadata={
+             "help": "secondary mask argument (used for more complex distributions), "
+             "see help in compute_mask_indicesh"
+         },
+     )
+     no_mask_channel_overlap: bool = field(
+         default=False, metadata={"help": "whether to allow channel masks to overlap"}
+     )
+     freeze_finetune_updates: int = field(
+         default=0, metadata={"help": "dont finetune wav2vec for this many updates"}
+     )
+     feature_grad_mult: float = field(
+         default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"}
+     )
+     layerdrop: float = field(
+         default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
+     )
+     mask_channel_min_space: Optional[int] = field(
+         default=1,
+         metadata={"help": "min space between spans (if no overlap is enabled)"},
+     )
+     mask_channel_before: bool = False
+     normalize: bool = II("task.normalize")
+     data: str = II("task.data")
+     # this holds the loaded wav2vec args
+     d2v_args: Any = None
+     offload_activations: bool = field(
+         default=False, metadata={"help": "offload_activations"}
+     )
+     min_params_to_wrap: int = field(
+         default=int(1e8),
+         metadata={
+             "help": "minimum number of params for a layer to be wrapped with FSDP() when "
+             "training with --ddp-backend=fully_sharded. Smaller values will "
+             "improve memory efficiency, but may make torch.distributed "
+             "communication less efficient due to smaller input sizes. This option "
+             "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
+             "--offload-activations are passed."
+         },
+     )
+
+     checkpoint_activations: bool = field(
+         default=False,
+         metadata={"help": "recompute activations and save memory for extra compute"},
+     )
+     ddp_backend: str = II("distributed_training.ddp_backend")
+
+     prediction_mode: str = "lin_softmax"
+     eval_prediction_mode: Optional[str] = None
+     conv_kernel: int = -1
+     conv_stride: int = 1
+     two_convs: bool = False
+     extreme_factor: float = 1.0
+
+     conv_feature_layers: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "string describing convolutional feature extraction layers in form of a python list that contains "
+             "[(dim, kernel_size, stride), ...]"
+         },
+     )
+
+     mixup_prob: float = 1.0
+     source_mixup: float = -1
+     same_mixup: bool = True
+     label_mixup: bool = False
+
+     gain_mode: str = "none"
+
+
+ @register_model("audio_classification", dataclass=AudioClassificationConfig)
+ class AudioClassificationModel(BaseFairseqModel):
+     def __init__(self, cfg: AudioClassificationConfig, num_classes):
+         super().__init__()
+
+         self.apply_mask = cfg.apply_mask
+         self.cfg = cfg
+
+         arg_overrides = {
+             "dropout": cfg.dropout,
+             "activation_dropout": cfg.activation_dropout,
+             "dropout_input": cfg.dropout_input,
+             "attention_dropout": cfg.attention_dropout,
+             "mask_length": cfg.mask_length,
+             "mask_prob": cfg.mask_prob,
+             "require_same_masks": getattr(cfg, "require_same_masks", True),
+             "mask_dropout": getattr(cfg, "mask_dropout", 0),
+             "mask_selection": cfg.mask_selection,
+             "mask_other": cfg.mask_other,
+             "no_mask_overlap": cfg.no_mask_overlap,
+             "mask_channel_length": cfg.mask_channel_length,
+             "mask_channel_prob": cfg.mask_channel_prob,
+             "mask_channel_before": cfg.mask_channel_before,
+             "mask_channel_selection": cfg.mask_channel_selection,
+             "mask_channel_other": cfg.mask_channel_other,
+             "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
+             "encoder_layerdrop": cfg.layerdrop,
+             "feature_grad_mult": cfg.feature_grad_mult,
+             "checkpoint_activations": cfg.checkpoint_activations,
+             "offload_activations": cfg.offload_activations,
+             "min_params_to_wrap": cfg.min_params_to_wrap,
+             "mixup": -1,
+         }
+
+         if cfg.conv_feature_layers is not None:
+             arg_overrides["conv_feature_layers"] = cfg.conv_feature_layers
+
+         if cfg.d2v_args is None:
+             state = checkpoint_utils.load_checkpoint_to_cpu(
+                 cfg.model_path, arg_overrides
+             )
+             d2v_args = state.get("cfg", None)
+             if d2v_args is None:
+                 d2v_args = convert_namespace_to_omegaconf(state["args"])
+             d2v_args.criterion = None
+             d2v_args.lr_scheduler = None
+             cfg.d2v_args = d2v_args
+
+             logger.info(d2v_args)
+
+         else:
+             state = None
+             d2v_args = cfg.d2v_args
+
+         model_normalized = d2v_args.task.get(
+             "normalize", d2v_args.model.get("normalize", False)
+         )
+         assert cfg.normalize == model_normalized, (
+             "Fine-tuning works best when data normalization is the same. "
+             "Please check that --normalize is set or unset for both pre-training and here"
+         )
+
+         if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
+             with open_dict(d2v_args):
+                 d2v_args.model.checkpoint_activations = cfg.checkpoint_activations
+
+         d2v_args.task.data = cfg.data
+         task = tasks.setup_task(d2v_args.task)
+         model = task.build_model(d2v_args.model, from_checkpoint=True)
+
+         model.remove_pretraining_modules()
+
+         if state is not None and not cfg.no_pretrained_weights:
+             self.load_model_weights(state, model, cfg)
+
+         d = d2v_args.model.encoder_embed_dim
+
+         self.d2v_model = model
+
+         self.final_dropout = nn.Dropout(cfg.final_dropout)
+         self.freeze_finetune_updates = cfg.freeze_finetune_updates
+         self.num_updates = 0
+
+         for p in self.parameters():
+             p.param_group = "pretrained"
+
+         if cfg.prediction_mode == "proj_avg_proj":
+             self.proj = nn.Linear(d, d * 2)
+             self.proj2 = nn.Linear(d * 2, num_classes)
+
+             for p in self.proj.parameters():
+                 p.param_group = "projection"
+             for p in self.proj2.parameters():
+                 p.param_group = "projection"
+         elif self.cfg.prediction_mode == "summary_proj":
+             self.proj = nn.Linear(d // 3, num_classes)
+             for p in self.proj.parameters():
+                 p.param_group = "projection"
+         elif self.cfg.conv_kernel > 1 and not self.cfg.two_convs:
+             self.proj = nn.Sequential(
+                 TransposeLast(),
+                 nn.Conv1d(d, num_classes, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
+                 TransposeLast(),
292
+ )
293
+ for p in self.proj.parameters():
294
+ p.param_group = "projection"
295
+ elif self.cfg.conv_kernel > 0 and self.cfg.two_convs:
296
+ self.proj = nn.Sequential(
297
+ TransposeLast(),
298
+ nn.Conv1d(d, d, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
299
+ TransposeLast(),
300
+ nn.GELU(),
301
+ nn.Linear(d, num_classes),
302
+ )
303
+ for p in self.proj.parameters():
304
+ p.param_group = "projection"
305
+ else:
306
+ self.proj = nn.Linear(d, num_classes)
307
+ for p in self.proj.parameters():
308
+ p.param_group = "projection"
309
+
310
+ def upgrade_state_dict_named(self, state_dict, name):
311
+ super().upgrade_state_dict_named(state_dict, name)
312
+ return state_dict
313
+
314
+ @classmethod
315
+ def build_model(cls, cfg: AudioClassificationConfig, task: FairseqTask):
316
+ """Build a new model instance."""
317
+
318
+ assert hasattr(task, "labels"), f"Task {task} must have an attribute 'labels'"
319
+
320
+ return cls(cfg, len(task.labels))
321
+
322
+ def load_model_weights(self, state, model, cfg):
323
+ if cfg.ddp_backend == "fully_sharded":
324
+ from fairseq.distributed import FullyShardedDataParallel
325
+
326
+ for name, module in model.named_modules():
327
+ if "encoder.layers" in name and len(name.split(".")) == 3:
328
+ # Only for layers we do special handling and load the weights one by one.
329
+ # We don't load all weights together, as that would not be memory efficient and may
330
+ # cause OOM.
331
+ new_dict = {
332
+ k.replace(name + ".", ""): v
333
+ for (k, v) in state["model"].items()
334
+ if name + "." in k
335
+ }
336
+ assert isinstance(module, FullyShardedDataParallel)
337
+ with module.summon_full_params():
338
+ module.load_state_dict(new_dict, strict=True)
339
+ module._reset_lazy_init()
340
+
341
+ # Once layers are loaded, filter them out and load everything else.
342
+ r = re.compile(r"encoder.layers.\d.")
343
+ filtered_list = list(filter(r.match, state["model"].keys()))
344
+
345
+ new_big_dict = {
346
+ k: v for (k, v) in state["model"].items() if k not in filtered_list
347
+ }
348
+
349
+ model.load_state_dict(new_big_dict, strict=False)
350
+ else:
351
+ if "_ema" in state["model"]:
352
+ del state["model"]["_ema"]
353
+ model.load_state_dict(state["model"], strict=False)
354
+
355
+ def set_num_updates(self, num_updates):
356
+ """Set the number of parameters updates."""
357
+ super().set_num_updates(num_updates)
358
+ self.num_updates = num_updates
359
+
360
+ def compute_gain(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
361
+ if fs == 16000:
362
+ n_fft = 2048
363
+ elif fs == 44100:
364
+ n_fft = 4096
365
+ else:
366
+ raise Exception("Invalid fs {}".format(fs))
367
+ stride = n_fft // 2
368
+
369
+ def a_weight(fs, n_fft, min_db=-80.0):
370
+ freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
371
+ freq_sq = np.power(freq, 2)
372
+ freq_sq[0] = 1.0
373
+ weight = 2.0 + 20.0 * (
374
+ 2 * np.log10(12194)
375
+ + 2 * np.log10(freq_sq)
376
+ - np.log10(freq_sq + 12194 ** 2)
377
+ - np.log10(freq_sq + 20.6 ** 2)
378
+ - 0.5 * np.log10(freq_sq + 107.7 ** 2)
379
+ - 0.5 * np.log10(freq_sq + 737.9 ** 2)
380
+ )
381
+ weight = np.maximum(weight, min_db)
382
+
383
+ return weight
384
+
385
+ gain = []
386
+ for i in range(0, len(sound) - n_fft + 1, stride):
387
+ if mode == "RMSE":
388
+ g = np.mean(sound[i : i + n_fft] ** 2)
389
+ elif mode == "A_weighting":
390
+ spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i : i + n_fft])
391
+ power_spec = np.abs(spec) ** 2
392
+ a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
393
+ g = np.sum(a_weighted_spec)
394
+ else:
395
+ raise Exception("Invalid mode {}".format(mode))
396
+ gain.append(g)
397
+
398
+ gain = np.array(gain)
399
+ gain = np.maximum(gain, np.power(10, min_db / 10))
400
+ gain_db = 10 * np.log10(gain)
401
+
402
+ return gain_db
403
+
404
+ # adapted from https://github.com/mil-tokyo/bc_learning_sound/blob/master/utils.py
405
+ def compute_gain_torch(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
406
+ if fs == 16000:
407
+ n_fft = 2048
408
+ elif fs == 44100:
409
+ n_fft = 4096
410
+ else:
411
+ raise Exception("Invalid fs {}".format(fs))
412
+
413
+ if mode == "A_weighting":
414
+ if not hasattr(self, f"a_weight"):
415
+ self.a_weight = {}
416
+
417
+ if fs not in self.a_weight:
418
+
419
+ def a_weight(fs, n_fft, min_db=-80.0):
420
+ freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
421
+ freq_sq = freq ** 2
422
+ freq_sq[0] = 1.0
423
+ weight = 2.0 + 20.0 * (
424
+ 2 * np.log10(12194)
425
+ + 2 * np.log10(freq_sq)
426
+ - np.log10(freq_sq + 12194 ** 2)
427
+ - np.log10(freq_sq + 20.6 ** 2)
428
+ - 0.5 * np.log10(freq_sq + 107.7 ** 2)
429
+ - 0.5 * np.log10(freq_sq + 737.9 ** 2)
430
+ )
431
+ weight = np.maximum(weight, min_db)
432
+
433
+ return weight
434
+
435
+ self.a_weight[fs] = torch.from_numpy(
436
+ np.power(10, a_weight(fs, n_fft, min_db) / 10)
437
+ ).to(device=sound.device)
438
+
439
+ sound = sound.unfold(-1, n_fft, n_fft // 2)
440
+
441
+ if mode == "RMSE":
442
+ sound = sound ** 2
443
+ g = sound.mean(-1)
444
+ elif mode == "A_weighting":
445
+ w = torch.hann_window(n_fft, device=sound.device) * sound
446
+ spec = torch.fft.rfft(w)
447
+ power_spec = spec.abs() ** 2
448
+ a_weighted_spec = power_spec * self.a_weight[fs]
449
+ g = a_weighted_spec.sum(-1)
450
+ else:
451
+ raise Exception("Invalid mode {}".format(mode))
452
+
453
+ gain = torch.maximum(g, torch.tensor(10 ** (min_db / 10), device=g.device))
454
+ gain_db = 10 * torch.log10(gain)
455
+
456
+ return gain_db
457
+
458
+ def forward(self, source, padding_mask, label=None, **kwargs):
459
+
460
+ if self.cfg.source_mixup >= 0 and self.training and self.cfg.mixup_prob > 0:
461
+ with torch.no_grad():
462
+ mixed_source = source
463
+ mix_mask = None
464
+ if self.cfg.mixup_prob < 1:
465
+ mix_mask = (
466
+ torch.empty((source.size(0),), device=source.device)
467
+ .bernoulli_(self.cfg.mixup_prob)
468
+ .bool()
469
+ )
470
+ mixed_source = source[mix_mask]
471
+
472
+ r = (
473
+ torch.FloatTensor(
474
+ 1 if self.cfg.same_mixup else mixed_source.size(0)
475
+ )
476
+ .uniform_(max(1e-6, self.cfg.source_mixup), 1)
477
+ .to(dtype=source.dtype, device=source.device)
478
+ )
479
+
480
+ mixup_perm = torch.randperm(source.size(0))
481
+ s2 = source[mixup_perm]
482
+
483
+ if self.cfg.gain_mode == "none":
484
+ p = r.unsqueeze(-1)
485
+ if mix_mask is not None:
486
+ s2 = s2[mix_mask]
487
+ else:
488
+ if self.cfg.gain_mode == "naive_rms":
489
+ G1 = source.pow(2).mean(dim=-1).sqrt()
490
+ else:
491
+ G1, _ = self.compute_gain_torch(
492
+ source, mode=self.cfg.gain_mode
493
+ ).max(-1)
494
+ G1 = G1.to(dtype=source.dtype)
495
+
496
+ G2 = G1[mixup_perm]
497
+
498
+ if mix_mask is not None:
499
+ G1 = G1[mix_mask]
500
+ G2 = G2[mix_mask]
501
+ s2 = s2[mix_mask]
502
+
503
+ p = 1 / (1 + 10 ** ((G1 - G2) / 20) * (1 - r) / r)
504
+ p = p.unsqueeze(-1)
505
+
506
+ mixed = (p * mixed_source) + (1 - p) * s2
507
+
508
+ if mix_mask is None:
509
+ source = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
510
+ else:
511
+ source[mix_mask] = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
512
+
513
+ if label is not None and self.cfg.label_mixup:
514
+ r = r.unsqueeze(-1)
515
+ if mix_mask is None:
516
+ label = label * r + (1 - r) * label[mixup_perm]
517
+ else:
518
+ label[mix_mask] = (
519
+ label[mix_mask] * r + (1 - r) * label[mixup_perm][mix_mask]
520
+ )
521
+
522
+ d2v_args = {
523
+ "source": source,
524
+ "padding_mask": padding_mask,
525
+ "mask": self.apply_mask and self.training,
526
+ }
527
+
528
+ ft = self.freeze_finetune_updates <= self.num_updates
529
+
530
+ with torch.no_grad() if not ft else contextlib.ExitStack():
531
+ res = self.d2v_model.extract_features(**d2v_args)
532
+
533
+ x = res["x"]
534
+ padding_mask = res["padding_mask"]
535
+ if padding_mask is not None:
536
+ x[padding_mask] = 0
537
+
538
+ x = self.final_dropout(x)
539
+
540
+ if self.training or (
541
+ self.cfg.eval_prediction_mode is None or self.cfg.eval_prediction_mode == ""
542
+ ):
543
+ prediction_mode = self.cfg.prediction_mode
544
+ else:
545
+ prediction_mode = self.cfg.eval_prediction_mode
546
+
547
+ if prediction_mode == "average_before":
548
+ x = x.mean(dim=1)
549
+
550
+ if prediction_mode != "summary_mha" and prediction_mode != "summary_proj" and prediction_mode != "cls":
551
+ x = self.proj(x)
552
+
553
+ logits = True
554
+ if prediction_mode == "lin_softmax":
555
+ x = F.logsigmoid(x.float())
556
+ x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x, dim=1)
557
+ x = x.clamp(max=0)
558
+ x = x - torch.log(-(torch.expm1(x)))
559
+ elif prediction_mode == "extremized_odds":
560
+ x = x.float().sum(dim=1)
561
+ x = x * self.cfg.extreme_factor
562
+ elif prediction_mode == "average_before":
563
+ x = x.float()
564
+ elif prediction_mode == "average":
565
+ x = x.float().mean(dim=1)
566
+ elif prediction_mode == "average_sigmoid":
567
+ x = torch.sigmoid(x.float())
568
+ x = x.mean(dim=1)
569
+ logits = False
570
+ elif prediction_mode == "max":
571
+ x, _ = x.float().max(dim=1)
572
+ elif prediction_mode == "max_sigmoid":
573
+ x = torch.sigmoid(x.float())
574
+ x, _ = x.float().max(dim=1)
575
+ logits = False
576
+ elif prediction_mode == "proj_avg_proj":
577
+ x = x.mean(dim=1)
578
+ x = self.proj2(x)
579
+ elif prediction_mode == "summary_mha" or prediction_mode == "summary_proj":
580
+ x = self.d2v_model.summary(
581
+ x, padding_mask, proj=prediction_mode == "summary_proj"
582
+ )
583
+ x = x.type_as(source)
584
+ x = self.proj(x)
585
+ elif prediction_mode == "cls":
586
+ x = x[:,0]
587
+ x = self.proj(x)
588
+ else:
589
+ raise Exception(f"unknown prediction mode {prediction_mode}")
590
+
591
+ if label is None:
592
+ return torch.sigmoid(x) if logits else x
593
+
594
+ x = torch.nan_to_num(x)
595
+
596
+ if logits:
597
+ loss = F.binary_cross_entropy_with_logits(
598
+ x, label.float(), reduction="none"
599
+ )
600
+ else:
601
+ loss = F.binary_cross_entropy(x, label.float(), reduction="none")
602
+
603
+ result = {
604
+ "losses": {
605
+ "main": loss,
606
+ },
607
+ "sample_size": label.sum(),
608
+ }
609
+
610
+ if not self.training:
611
+ result["_predictions"] = torch.sigmoid(x) if logits else x
612
+ result["_targets"] = label
613
+
614
+ return result
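For readers skimming the diff, here is a minimal standalone sketch (not part of the repository; the tensor values, shapes, and names are hypothetical) of the gain-matched mixup used in `AudioClassificationModel.forward` above: given per-clip gains in dB, the mixing coefficient `p` compensates for loudness differences before the two waveforms are combined and renormalized.

```python
import torch

def gain_matched_mix(s1, s2, r, g1_db, g2_db):
    # p follows the formula in forward(): louder clips get a smaller weight
    # so the perceived mixing ratio stays close to r.
    p = 1.0 / (1.0 + 10 ** ((g1_db - g2_db) / 20) * (1 - r) / r)
    mixed = p * s1 + (1 - p) * s2
    # renormalize so the mixture keeps roughly the original energy
    return mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)

s1, s2 = torch.randn(16000), torch.randn(16000)  # two 1-second clips at 16 kHz
out = gain_matched_mix(
    s1, s2,
    r=torch.tensor(0.7),        # sampled from U(source_mixup, 1) in the model
    g1_db=torch.tensor(-23.0),  # A-weighted gains, e.g. from compute_gain_torch
    g2_db=torch.tensor(-20.0),
)
```

With `gain_mode="none"` the model skips the gain correction and uses `r` directly as the mixing coefficient, which is the simpler special case of the same formula.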
fairseq/examples/data2vec/models/data2vec2.py ADDED
@@ -0,0 +1,813 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import math
8
+ from dataclasses import dataclass, field
9
+ from typing import Optional, Callable
10
+ from functools import partial
11
+ import numpy as np
12
+
13
+ from omegaconf import II
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.distributed as dist
19
+
20
+ from fairseq.modules import EMAModule, EMAModuleConfig
21
+
22
+ from fairseq.dataclass import FairseqDataclass
23
+ from fairseq.models import BaseFairseqModel, register_model
24
+
25
+ from examples.data2vec.data.modality import Modality
26
+
27
+ from examples.data2vec.models.modalities.base import (
28
+ MaskSeed,
29
+ D2vModalityConfig,
30
+ ModalitySpecificEncoder,
31
+ get_annealed_rate,
32
+ )
33
+ from examples.data2vec.models.modalities.modules import (
34
+ D2vDecoderConfig,
35
+ AltBlock,
36
+ Decoder1d,
37
+ )
38
+
39
+ from examples.data2vec.models.modalities.audio import (
40
+ D2vAudioConfig,
41
+ AudioEncoder,
42
+ )
43
+ from examples.data2vec.models.modalities.images import (
44
+ D2vImageConfig,
45
+ ImageEncoder,
46
+ )
47
+ from examples.data2vec.models.modalities.text import (
48
+ D2vTextConfig,
49
+ TextEncoder,
50
+ )
51
+
52
+ logger = logging.getLogger(__name__)
53
+
54
+
55
+ @dataclass
56
+ class D2vModalitiesConfig(FairseqDataclass):
57
+ audio: D2vAudioConfig = D2vAudioConfig()
58
+ image: D2vImageConfig = D2vImageConfig()
59
+ text: D2vTextConfig = D2vTextConfig()
60
+
61
+
62
+ @dataclass
63
+ class Data2VecMultiConfig(FairseqDataclass):
64
+
65
+ loss_beta: float = field(
66
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
67
+ )
68
+ loss_scale: Optional[float] = field(
69
+ default=None,
70
+ metadata={
71
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
72
+ },
73
+ )
74
+
75
+ depth: int = 8
76
+ start_drop_path_rate: float = 0
77
+ end_drop_path_rate: float = 0
78
+ num_heads: int = 12
79
+ norm_eps: float = 1e-6
80
+ norm_affine: bool = True
81
+ encoder_dropout: float = 0.1
82
+ post_mlp_drop: float = 0.1
83
+ attention_dropout: float = 0.1
84
+ activation_dropout: float = 0.0
85
+ dropout_input: float = 0.0
86
+ layerdrop: float = 0.0
87
+ embed_dim: int = 768
88
+ mlp_ratio: float = 4
89
+ layer_norm_first: bool = False
90
+
91
+ average_top_k_layers: int = field(
92
+ default=8, metadata={"help": "how many layers to average"}
93
+ )
94
+
95
+ end_of_block_targets: bool = False
96
+
97
+ clone_batch: int = 1
98
+
99
+ layer_norm_target_layer: bool = False
100
+ batch_norm_target_layer: bool = False
101
+ instance_norm_target_layer: bool = False
102
+ instance_norm_targets: bool = False
103
+ layer_norm_targets: bool = False
104
+
105
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
106
+ ema_same_dtype: bool = True
107
+ log_norms: bool = True
108
+ ema_end_decay: float = field(
109
+ default=0.9999, metadata={"help": "final ema decay rate"}
110
+ )
111
+
112
+ # when to finish annealing ema decay rate
113
+ ema_anneal_end_step: int = II("optimization.max_update")
114
+
115
+ ema_encoder_only: bool = field(
116
+ default=True,
117
+ metadata={
118
+ "help": "whether to momentum update only the shared transformer encoder"
119
+ },
120
+ )
121
+
122
+ max_update: int = II("optimization.max_update")
123
+
124
+ modalities: D2vModalitiesConfig = D2vModalitiesConfig()
125
+
126
+ shared_decoder: Optional[D2vDecoderConfig] = None
127
+
128
+ min_target_var: float = field(
129
+ default=0.1, metadata={"help": "stop training if target var falls below this"}
130
+ )
131
+ min_pred_var: float = field(
132
+ default=0.01,
133
+ metadata={"help": "stop training if prediction var falls below this"},
134
+ )
135
+
136
+ supported_modality: Optional[Modality] = None
137
+ mae_init: bool = False
138
+
139
+ seed: int = II("common.seed")
140
+
141
+ skip_ema: bool = False
142
+
143
+ cls_loss: float = 0
144
+ recon_loss: float = 0
145
+ d2v_loss: float = 1
146
+
147
+ decoder_group: bool = False
148
+
149
+
150
+ @register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
151
+ class Data2VecMultiModel(BaseFairseqModel):
152
+ def make_modality_encoder(
153
+ self,
154
+ cfg: D2vModalityConfig,
155
+ embed_dim: int,
156
+ make_block: Callable[[float], nn.ModuleList],
157
+ norm_layer: Callable[[int], nn.LayerNorm],
158
+ layer_norm_first: bool,
159
+ alibi_biases,
160
+ task,
161
+ ) -> ModalitySpecificEncoder:
162
+ if cfg.type == Modality.AUDIO:
163
+ enc_cls = AudioEncoder
164
+ elif cfg.type == Modality.IMAGE:
165
+ enc_cls = ImageEncoder
166
+ elif cfg.type == Modality.TEXT:
167
+ enc_cls = TextEncoder
168
+ if hasattr(task, "text_task"):
169
+ task = task.text_task
170
+ else:
171
+ raise Exception(f"unsupported modality {cfg.type}")
172
+
173
+ return enc_cls(
174
+ cfg,
175
+ embed_dim,
176
+ make_block,
177
+ norm_layer,
178
+ layer_norm_first,
179
+ alibi_biases,
180
+ task,
181
+ )
182
+
183
+ def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
184
+ super().__init__()
185
+ self.cfg = cfg
186
+ self.modalities = modalities
187
+ self.task = task
188
+
189
+ make_layer_norm = partial(
190
+ nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
191
+ )
192
+
193
+ def make_block(drop_path, dim=None, heads=None):
194
+ return AltBlock(
195
+ cfg.embed_dim if dim is None else dim,
196
+ cfg.num_heads if heads is None else heads,
197
+ cfg.mlp_ratio,
198
+ qkv_bias=True,
199
+ drop=cfg.encoder_dropout,
200
+ attn_drop=cfg.attention_dropout,
201
+ mlp_drop=cfg.activation_dropout,
202
+ post_mlp_drop=cfg.post_mlp_drop,
203
+ drop_path=drop_path,
204
+ norm_layer=make_layer_norm,
205
+ layer_norm_first=cfg.layer_norm_first,
206
+ ffn_targets=not cfg.end_of_block_targets,
207
+ )
208
+
209
+ self.alibi_biases = {}
210
+ self.modality_encoders = nn.ModuleDict()
211
+ for mod in self.modalities:
212
+ mod_cfg = getattr(cfg.modalities, mod.name.lower())
213
+ enc = self.make_modality_encoder(
214
+ mod_cfg,
215
+ cfg.embed_dim,
216
+ make_block,
217
+ make_layer_norm,
218
+ cfg.layer_norm_first,
219
+ self.alibi_biases,
220
+ task,
221
+ )
222
+ self.modality_encoders[mod.name] = enc
223
+
224
+ self.ema = None
225
+
226
+ self.average_top_k_layers = cfg.average_top_k_layers
227
+ self.loss_beta = cfg.loss_beta
228
+ self.loss_scale = cfg.loss_scale
229
+
230
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
231
+
232
+ dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
233
+
234
+ self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
235
+
236
+ self.norm = None
237
+ if cfg.layer_norm_first:
238
+ self.norm = make_layer_norm(cfg.embed_dim)
239
+
240
+ if self.cfg.mae_init:
241
+ self.apply(self._init_weights)
242
+ else:
243
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
244
+
245
+ self.apply(init_bert_params)
246
+
247
+ for mod_enc in self.modality_encoders.values():
248
+ mod_enc.reset_parameters()
249
+
250
+ if not skip_ema:
251
+ self.ema = self.make_ema_teacher(cfg.ema_decay)
252
+ self.shared_decoder = (
253
+ Decoder1d(cfg.shared_decoder, cfg.embed_dim)
254
+ if self.cfg.shared_decoder is not None
255
+ else None
256
+ )
257
+ if self.shared_decoder is not None:
258
+ self.shared_decoder.apply(self._init_weights)
259
+
260
+ self.recon_proj = None
261
+ if cfg.recon_loss > 0:
262
+ self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
263
+
264
+ for pn, p in self.named_parameters():
265
+ if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
266
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
267
+ if cfg.decoder_group and "decoder" in pn:
268
+ p.param_group = "decoder"
269
+
270
+ self.num_updates = 0
271
+
272
+ def _init_weights(self, m):
273
+
274
+ try:
275
+ from apex.normalization import FusedLayerNorm
276
+
277
+ fn = FusedLayerNorm
278
+ except ImportError:
279
+ fn = nn.LayerNorm
280
+
281
+ if isinstance(m, nn.Linear):
282
+ torch.nn.init.xavier_uniform_(m.weight)
283
+ if isinstance(m, nn.Linear) and m.bias is not None:
284
+ nn.init.constant_(m.bias, 0)
285
+ elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
286
+ if m.bias is not None:
287
+ nn.init.constant_(m.bias, 0)
288
+ if m.weight is not None:
289
+ nn.init.constant_(m.weight, 1.0)
290
+
291
+ @torch.no_grad()
292
+ def make_ema_teacher(self, ema_decay):
293
+ ema_config = EMAModuleConfig(
294
+ ema_decay=ema_decay,
295
+ ema_fp32=True,
296
+ log_norms=self.cfg.log_norms,
297
+ add_missing_params=False,
298
+ )
299
+
300
+ model_copy = self.make_target_model()
301
+
302
+ return EMAModule(
303
+ model_copy,
304
+ ema_config,
305
+ copy_model=False,
306
+ )
307
+
308
+ def make_target_model(self):
309
+ logger.info("making target model")
310
+
311
+ model_copy = Data2VecMultiModel(
312
+ self.cfg, self.modalities, skip_ema=True, task=self.task
313
+ )
314
+
315
+ if self.cfg.ema_encoder_only:
316
+ model_copy = model_copy.blocks
317
+ for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
318
+ p_t.data.copy_(p_s.data)
319
+ else:
320
+ for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
321
+ p_t.data.copy_(p_s.data)
322
+
323
+ for mod_enc in model_copy.modality_encoders.values():
324
+ mod_enc.decoder = None
325
+ if not mod_enc.modality_cfg.ema_local_encoder:
326
+ mod_enc.local_encoder = None
327
+ mod_enc.project_features = None
328
+
329
+ model_copy.requires_grad_(False)
330
+ return model_copy
331
+
332
+ def set_num_updates(self, num_updates):
333
+ super().set_num_updates(num_updates)
334
+
335
+ if self.ema is not None and (
336
+ (self.num_updates == 0 and num_updates > 1)
337
+ or self.num_updates >= num_updates
338
+ ):
339
+ pass
340
+ elif self.training and self.ema is not None:
341
+ ema_weight_decay = None
342
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
343
+ if num_updates >= self.cfg.ema_anneal_end_step:
344
+ decay = self.cfg.ema_end_decay
345
+ else:
346
+ decay = get_annealed_rate(
347
+ self.cfg.ema_decay,
348
+ self.cfg.ema_end_decay,
349
+ num_updates,
350
+ self.cfg.ema_anneal_end_step,
351
+ )
352
+ self.ema.set_decay(decay, weight_decay=ema_weight_decay)
353
+ if self.ema.get_decay() < 1:
354
+ self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)
355
+
356
+ self.num_updates = num_updates
357
+
358
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
359
+ state = super().state_dict(destination, prefix, keep_vars)
360
+
361
+ if self.ema is not None:
362
+ state[prefix + "_ema"] = self.ema.fp32_params
363
+
364
+ return state
365
+
366
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
367
+ k = prefix + "_ema"
368
+ if self.ema is not None:
369
+ assert k in state_dict
370
+ self.ema.restore(state_dict[k], True)
371
+ del state_dict[k]
372
+ elif k in state_dict:
373
+ del state_dict[k]
374
+
375
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
376
+
377
+ @classmethod
378
+ def build_model(cls, cfg: Data2VecMultiConfig, task=None):
379
+ """Build a new model instance."""
380
+ if task is None or not hasattr(task, "supported_modalities"):
381
+ modalities = (
382
+ [cfg.supported_modality]
383
+ if cfg.supported_modality is not None
384
+ else [
385
+ Modality.AUDIO,
386
+ Modality.IMAGE,
387
+ Modality.TEXT,
388
+ ]
389
+ )
390
+ else:
391
+ modalities = task.supported_modalities
392
+ return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)
393
+
394
+ def forward(
395
+ self,
396
+ source,
397
+ target=None,
398
+ id=None,
399
+ mode=None,
400
+ padding_mask=None,
401
+ mask=True,
402
+ features_only=False,
403
+ force_remove_masked=False,
404
+ remove_extra_tokens=True,
405
+ precomputed_mask=None,
406
+ ):
407
+ if mode is None:
408
+ assert self.cfg.supported_modality is not None
409
+ mode = self.cfg.supported_modality
410
+
411
+ if isinstance(mode, Modality):
412
+ mode = mode.name
413
+
414
+ feature_extractor = self.modality_encoders[mode]
415
+
416
+ mask_seeds = None
417
+ if id is not None:
418
+ mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)
419
+
420
+ extractor_out = feature_extractor(
421
+ source,
422
+ padding_mask,
423
+ mask,
424
+ remove_masked=not features_only or force_remove_masked,
425
+ clone_batch=self.cfg.clone_batch if not features_only else 1,
426
+ mask_seeds=mask_seeds,
427
+ precomputed_mask=precomputed_mask,
428
+ )
429
+
430
+ x = extractor_out["x"]
431
+ encoder_mask = extractor_out["encoder_mask"]
432
+ masked_padding_mask = extractor_out["padding_mask"]
433
+ masked_alibi_bias = extractor_out.get("alibi_bias", None)
434
+ alibi_scale = extractor_out.get("alibi_scale", None)
435
+
436
+ if self.dropout_input is not None:
437
+ x = self.dropout_input(x)
438
+
439
+ layer_results = []
440
+ for i, blk in enumerate(self.blocks):
441
+ if (
442
+ not self.training
443
+ or self.cfg.layerdrop == 0
444
+ or (np.random.random() > self.cfg.layerdrop)
445
+ ):
446
+ ab = masked_alibi_bias
447
+ if ab is not None and alibi_scale is not None:
448
+ scale = (
449
+ alibi_scale[i]
450
+ if alibi_scale.size(0) > 1
451
+ else alibi_scale.squeeze(0)
452
+ )
453
+ ab = ab * scale.type_as(ab)
454
+
455
+ x, lr = blk(
456
+ x,
457
+ padding_mask=masked_padding_mask,
458
+ alibi_bias=ab,
459
+ )
460
+ if features_only:
461
+ layer_results.append(lr)
462
+
463
+ if self.norm is not None:
464
+ x = self.norm(x)
465
+
466
+ if features_only:
467
+ if remove_extra_tokens:
468
+ x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
469
+ if masked_padding_mask is not None:
470
+ masked_padding_mask = masked_padding_mask[
471
+ :, feature_extractor.modality_cfg.num_extra_tokens :
472
+ ]
473
+
474
+ return {
475
+ "x": x,
476
+ "padding_mask": masked_padding_mask,
477
+ "layer_results": layer_results,
478
+ "mask": encoder_mask,
479
+ }
480
+
481
+ xs = []
482
+
483
+ if self.shared_decoder is not None:
484
+ dx = self.forward_decoder(
485
+ x,
486
+ feature_extractor,
487
+ self.shared_decoder,
488
+ encoder_mask,
489
+ )
490
+ xs.append(dx)
491
+ if feature_extractor.decoder is not None:
492
+ dx = self.forward_decoder(
493
+ x,
494
+ feature_extractor,
495
+ feature_extractor.decoder,
496
+ encoder_mask,
497
+ )
498
+ xs.append(dx)
499
+ orig_x = x
500
+
501
+ assert len(xs) > 0
502
+
503
+ p = next(self.ema.model.parameters())
504
+ device = x.device
505
+ dtype = x.dtype
506
+ ema_device = p.device
507
+ ema_dtype = p.dtype
508
+
509
+ if not self.cfg.ema_same_dtype:
510
+ dtype = ema_dtype
511
+
512
+ if ema_device != device or ema_dtype != dtype:
513
+ logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
514
+ self.ema.model = self.ema.model.to(dtype=dtype, device=device)
515
+ ema_dtype = dtype
516
+
517
+ def to_device(d):
518
+ for k, p in d.items():
519
+ if isinstance(d[k], dict):
520
+ to_device(d[k])
521
+ else:
522
+ d[k] = p.to(device=device)
523
+
524
+ to_device(self.ema.fp32_params)
525
+ tm = self.ema.model
526
+
527
+ with torch.no_grad():
528
+ tm.eval()
529
+
530
+ if self.cfg.ema_encoder_only:
531
+ assert target is None
532
+ ema_input = extractor_out["local_features"]
533
+ ema_input = feature_extractor.contextualized_features(
534
+ ema_input.to(dtype=ema_dtype),
535
+ padding_mask,
536
+ mask=False,
537
+ remove_masked=False,
538
+ )
539
+ ema_blocks = tm
540
+ else:
541
+ ema_blocks = tm.blocks
542
+ if feature_extractor.modality_cfg.ema_local_encoder:
543
+ inp = (
544
+ target.to(dtype=ema_dtype)
545
+ if target is not None
546
+ else source.to(dtype=ema_dtype)
547
+ )
548
+ ema_input = tm.modality_encoders[mode](
549
+ inp,
550
+ padding_mask,
551
+ mask=False,
552
+ remove_masked=False,
553
+ )
554
+ else:
555
+ assert target is None
556
+ ema_input = extractor_out["local_features"]
557
+ ema_feature_enc = tm.modality_encoders[mode]
558
+ ema_input = ema_feature_enc.contextualized_features(
559
+ ema_input.to(dtype=ema_dtype),
560
+ padding_mask,
561
+ mask=False,
562
+ remove_masked=False,
563
+ )
564
+
565
+ ema_padding_mask = ema_input["padding_mask"]
566
+ ema_alibi_bias = ema_input.get("alibi_bias", None)
567
+ ema_alibi_scale = ema_input.get("alibi_scale", None)
568
+ ema_input = ema_input["x"]
569
+
570
+ y = []
571
+ ema_x = []
572
+ extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
573
+ for i, blk in enumerate(ema_blocks):
574
+ ab = ema_alibi_bias
575
+ if ab is not None and alibi_scale is not None:
576
+ scale = (
577
+ ema_alibi_scale[i]
578
+ if ema_alibi_scale.size(0) > 1
579
+ else ema_alibi_scale.squeeze(0)
580
+ )
581
+ ab = ab * scale.type_as(ab)
582
+
583
+ ema_input, lr = blk(
584
+ ema_input,
585
+ padding_mask=ema_padding_mask,
586
+ alibi_bias=ab,
587
+ )
588
+ y.append(lr[:, extra_tokens:])
589
+ ema_x.append(ema_input[:, extra_tokens:])
590
+
591
+ y = self.make_targets(y, self.average_top_k_layers)
592
+ orig_targets = y
593
+
594
+ if self.cfg.clone_batch > 1:
595
+ y = y.repeat_interleave(self.cfg.clone_batch, 0)
596
+
597
+ masked = encoder_mask.mask.unsqueeze(-1)
598
+ masked_b = encoder_mask.mask.bool()
599
+ y = y[masked_b]
600
+
601
+ if xs[0].size(1) == masked_b.size(1):
602
+ xs = [x[masked_b] for x in xs]
603
+ else:
604
+ xs = [x.reshape(-1, x.size(-1)) for x in xs]
605
+
606
+ sample_size = masked.sum().long()
607
+
608
+ result = {
609
+ "losses": {},
610
+ "sample_size": sample_size,
611
+ }
612
+
613
+ sample_size = result["sample_size"]
614
+
615
+ if self.cfg.cls_loss > 0:
616
+ assert extra_tokens > 0
617
+ cls_target = orig_targets.mean(dim=1)
618
+ if self.cfg.clone_batch > 1:
619
+ cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
620
+ cls_pred = x[:, extra_tokens - 1]
621
+ result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
622
+ self.cfg.cls_loss * sample_size
623
+ )
624
+
625
+ if self.cfg.recon_loss > 0:
626
+
627
+ with torch.no_grad():
628
+ target = feature_extractor.patchify(source)
629
+ mean = target.mean(dim=-1, keepdim=True)
630
+ var = target.var(dim=-1, keepdim=True)
631
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
632
+
633
+ if self.cfg.clone_batch > 1:
634
+ target = target.repeat_interleave(self.cfg.clone_batch, 0)
635
+
636
+ if masked_b is not None:
637
+ target = target[masked_b]
638
+
639
+ recon = xs[0]
640
+ if self.recon_proj is not None:
641
+ recon = self.recon_proj(recon)
642
+
643
+ result["losses"]["recon"] = (
644
+ self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
645
+ )
646
+
647
+ if self.cfg.d2v_loss > 0:
648
+ for i, x in enumerate(xs):
649
+ reg_loss = self.d2v_loss(x, y)
650
+ n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
651
+ result["losses"][n] = reg_loss * self.cfg.d2v_loss
652
+
653
+ suffix = "" if len(self.modalities) == 1 else f"_{mode}"
654
+ with torch.no_grad():
655
+ if encoder_mask is not None:
656
+ result["masked_pct"] = 1 - (
657
+ encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
658
+ )
659
+ for i, x in enumerate(xs):
660
+ n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
661
+ result[n] = self.compute_var(x.float())
662
+ if self.ema is not None:
663
+ for k, v in self.ema.logs.items():
664
+ result[k] = v
665
+
666
+ y = y.float()
667
+ result[f"target_var{suffix}"] = self.compute_var(y)
668
+
669
+ if self.num_updates > 5000:
670
+ if result[f"target_var{suffix}"] < self.cfg.min_target_var:
671
+ logger.error(
672
+ f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
673
+ )
674
+ raise Exception(
675
+ f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
676
+ )
677
+
678
+ for k in result.keys():
679
+ if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
680
+ logger.error(
681
+ f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
682
+ )
683
+ raise Exception(
684
+ f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
685
+ )
686
+
687
+ result["ema_decay"] = self.ema.get_decay() * 1000
688
+
689
+ return result
690
+
691
+ def forward_decoder(
692
+ self,
693
+ x,
694
+ feature_extractor,
695
+ decoder,
696
+ mask_info,
697
+ ):
698
+ x = feature_extractor.decoder_input(x, mask_info)
699
+ x = decoder(*x)
700
+
701
+ return x
702
+
703
+ def d2v_loss(self, x, y):
704
+ x = x.view(-1, x.size(-1)).float()
705
+ y = y.view(-1, x.size(-1))
706
+
707
+ if self.loss_beta == 0:
708
+ loss = F.mse_loss(x, y, reduction="none")
709
+ else:
710
+ loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)
711
+
712
+ if self.loss_scale is not None:
713
+ scale = self.loss_scale
714
+ else:
715
+ scale = 1 / math.sqrt(x.size(-1))
716
+
717
+ reg_loss = loss * scale
718
+
719
+ return reg_loss
720
+
721
+ def make_targets(self, y, num_layers):
722
+
723
+ with torch.no_grad():
724
+ target_layer_results = y[-num_layers:]
725
+
726
+ permuted = False
727
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
728
+ target_layer_results = [
729
+ tl.transpose(1, 2) for tl in target_layer_results # BTC -> BCT
730
+ ]
731
+ permuted = True
732
+ if self.cfg.batch_norm_target_layer:
733
+ target_layer_results = [
734
+ F.batch_norm(
735
+ tl.float(), running_mean=None, running_var=None, training=True
736
+ )
737
+ for tl in target_layer_results
738
+ ]
739
+ if self.cfg.instance_norm_target_layer:
740
+ target_layer_results = [
741
+ F.instance_norm(tl.float()) for tl in target_layer_results
742
+ ]
743
+ if permuted:
744
+ target_layer_results = [
745
+ tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
746
+ ]
747
+ if self.cfg.layer_norm_target_layer:
748
+ target_layer_results = [
749
+ F.layer_norm(tl.float(), tl.shape[-1:])
750
+ for tl in target_layer_results
751
+ ]
752
+
753
+ y = target_layer_results[0].float()
754
+ for tl in target_layer_results[1:]:
755
+ y.add_(tl.float())
756
+ y = y.div_(len(target_layer_results))
757
+
758
+ if self.cfg.layer_norm_targets:
759
+ y = F.layer_norm(y, y.shape[-1:])
760
+
761
+ if self.cfg.instance_norm_targets:
762
+ y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
763
+
764
+ return y
765
+
766
+ @staticmethod
767
+ def compute_var(y):
768
+ y = y.view(-1, y.size(-1))
769
+ if dist.is_initialized():
770
+ zc = torch.tensor(y.size(0)).cuda()
771
+ zs = y.sum(dim=0)
772
+ zss = (y**2).sum(dim=0)
773
+
774
+ dist.all_reduce(zc)
775
+ dist.all_reduce(zs)
776
+ dist.all_reduce(zss)
777
+
778
+ var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
779
+ return torch.sqrt(var + 1e-6).mean()
780
+ else:
781
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
782
+
783
+ def extract_features(
784
+ self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
785
+ ):
786
+ res = self.forward(
787
+ source,
788
+ mode=mode,
789
+ padding_mask=padding_mask,
790
+ mask=mask,
791
+ features_only=True,
792
+ remove_extra_tokens=remove_extra_tokens,
793
+ )
794
+ return res
795
+
796
+ def remove_pretraining_modules(self, modality=None, keep_decoder=False):
797
+ self.ema = None
798
+ self.cfg.clone_batch = 1
799
+ self.recon_proj = None
800
+
801
+ if not keep_decoder:
802
+ self.shared_decoder = None
803
+
804
+ modality = modality.lower() if modality is not None else None
805
+ for k in list(self.modality_encoders.keys()):
806
+ if modality is not None and k.lower() != modality:
807
+ del self.modality_encoders[k]
808
+ else:
809
+ self.modality_encoders[k].remove_pretraining_modules(
810
+ keep_decoder=keep_decoder
811
+ )
812
+ if not keep_decoder:
813
+ self.modality_encoders[k].decoder = None
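As a sanity check on the distributed variance computation in `compute_var` above, the following self-contained snippet (single process assumed, no `torch.distributed` involved; the shapes are arbitrary) verifies that the sum/sum-of-squares formula reproduces the unbiased per-dimension variance:

```python
import torch

y = torch.randn(1024, 64)
n = y.size(0)
s = y.sum(dim=0)           # per-dimension sum
ss = (y ** 2).sum(dim=0)   # per-dimension sum of squares

# same algebra as the dist.all_reduce branch of compute_var:
# unbiased var = ss / (n - 1) - s^2 / (n * (n - 1))
var = ss / (n - 1) - (s ** 2) / (n * (n - 1))

assert torch.allclose(var, y.var(dim=0), atol=1e-4)
print(torch.sqrt(var + 1e-6).mean())  # matches compute_var's return value
```

In the multi-GPU case the same identity holds after all-reducing `n`, `s`, and `ss`, which is why the model only needs three reductions to get the global target/prediction variance.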
fairseq/examples/data2vec/models/data2vec_audio.py ADDED
@@ -0,0 +1,537 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import math
8
+ from dataclasses import dataclass, field
9
+ from typing import Optional
10
+
11
+ from omegaconf import II
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import torch.distributed as dist
17
+
18
+ from fairseq.modules import EMAModule, EMAModuleConfig
19
+ from fairseq.data.data_utils import compute_mask_indices
20
+ from fairseq.models import BaseFairseqModel, register_model
21
+ from fairseq.models.wav2vec import (
22
+ ConvFeatureExtractionModel,
23
+ Wav2Vec2Config,
24
+ TransformerEncoder,
25
+ )
26
+ from fairseq.modules import (
27
+ GradMultiply,
28
+ LayerNorm,
29
+ )
30
+ from fairseq.utils import index_put
31
+
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ @dataclass
37
+ class Data2VecAudioConfig(Wav2Vec2Config):
38
+
39
+ loss_beta: float = field(
40
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
41
+ )
42
+ loss_scale: Optional[float] = field(
43
+ default=None,
44
+ metadata={
45
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
46
+ },
47
+ )
48
+ average_top_k_layers: int = field(
49
+ default=8, metadata={"help": "how many layers to average"}
50
+ )
51
+
52
+ layer_norm_target_layer: bool = False
53
+ instance_norm_target_layer: bool = False
54
+ instance_norm_targets: bool = False
55
+ layer_norm_targets: bool = False
56
+ batch_norm_target_layer: bool = False
57
+ group_norm_target_layer: bool = False
58
+
59
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
60
+ ema_end_decay: float = field(
61
+ default=0.9999, metadata={"help": "final ema decay rate"}
62
+ )
63
+
64
+ # when to finish annealing ema decay rate
65
+ ema_anneal_end_step: int = II("optimization.max_update")
66
+
67
+ ema_transformer_only: bool = field(
68
+ default=True,
69
+ metadata={"help": "whether to momentum update only the transformer"},
70
+ )
71
+ ema_layers_only: bool = field(
72
+ default=True,
73
+ metadata={"help": "whether to momentum update only the transformer layers"},
74
+ )
75
+
76
+ max_update: int = II("optimization.max_update")
77
+
78
+ min_target_var: float = field(
79
+ default=0.1, metadata={"help": "stop training if target var falls below this"}
80
+ )
81
+ min_pred_var: float = field(
82
+ default=0.01,
83
+ metadata={"help": "stop training if prediction var falls below this"},
84
+ )
85
+
86
+
87
+ def get_annealed_rate(start, end, curr_step, total_steps):
88
+ r = end - start
89
+ pct_remaining = 1 - curr_step / total_steps
90
+ return end - r * pct_remaining
91
+
92
+
93
+ @register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
94
+ class Data2VecAudioModel(BaseFairseqModel):
95
+ def __init__(self, cfg: Data2VecAudioConfig):
96
+ super().__init__()
97
+ self.cfg = cfg
98
+
99
+ feature_enc_layers = eval(cfg.conv_feature_layers)
100
+ self.extractor_embed = feature_enc_layers[-1][0]
101
+
102
+ self.ema = None
103
+ self.embed = cfg.encoder_embed_dim
104
+
105
+ self.average_top_k_layers = cfg.average_top_k_layers
106
+ self.loss_beta = cfg.loss_beta
107
+ self.loss_scale = cfg.loss_scale
108
+
109
+ self.feature_extractor = ConvFeatureExtractionModel(
110
+ conv_layers=feature_enc_layers,
111
+ dropout=0.0,
112
+ mode=cfg.extractor_mode,
113
+ conv_bias=cfg.conv_bias,
114
+ )
115
+
116
+ self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)
117
+
118
+ self.mask_prob = cfg.mask_prob
119
+ self.mask_selection = cfg.mask_selection
120
+ self.mask_other = cfg.mask_other
121
+ self.mask_length = cfg.mask_length
122
+ self.no_mask_overlap = cfg.no_mask_overlap
123
+ self.mask_min_space = cfg.mask_min_space
124
+
125
+ self.mask_channel_prob = cfg.mask_channel_prob
126
+ self.mask_channel_before = cfg.mask_channel_before
127
+ self.mask_channel_selection = cfg.mask_channel_selection
128
+ self.mask_channel_other = cfg.mask_channel_other
129
+ self.mask_channel_length = cfg.mask_channel_length
130
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
131
+ self.mask_channel_min_space = cfg.mask_channel_min_space
132
+
133
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
134
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
135
+
136
+ self.feature_grad_mult = cfg.feature_grad_mult
137
+
138
+ self.mask_emb = nn.Parameter(
139
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
140
+ )
141
+
142
+ self.encoder = TransformerEncoder(cfg)
143
+ self.layer_norm = LayerNorm(self.extractor_embed)
144
+
145
+ self.final_proj = nn.Linear(self.embed, self.embed)
146
+
147
+ self.num_updates = 0
148
+
149
+ def make_ema_teacher(self):
150
+ ema_config = EMAModuleConfig(
151
+ ema_decay=self.cfg.ema_decay,
152
+ ema_fp32=True,
153
+ )
154
+ skip_keys = set()
155
+ if self.cfg.ema_layers_only:
156
+ self.cfg.ema_transformer_only = True
157
+ for k, _ in self.encoder.pos_conv.named_parameters():
158
+ skip_keys.add(f"pos_conv.{k}")
159
+
160
+ self.ema = EMAModule(
161
+ self.encoder if self.cfg.ema_transformer_only else self,
162
+ ema_config,
163
+ skip_keys=skip_keys,
164
+ )
165
+
166
+ def set_num_updates(self, num_updates):
167
+ super().set_num_updates(num_updates)
168
+
169
+ if self.ema is None and self.final_proj is not None:
170
+ logger.info(f"making ema teacher")
171
+ self.make_ema_teacher()
172
+ elif self.training and self.ema is not None:
173
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
174
+ if num_updates >= self.cfg.ema_anneal_end_step:
175
+ decay = self.cfg.ema_end_decay
176
+ else:
177
+ decay = get_annealed_rate(
178
+ self.cfg.ema_decay,
179
+ self.cfg.ema_end_decay,
180
+ num_updates,
181
+ self.cfg.ema_anneal_end_step,
182
+ )
183
+ self.ema.set_decay(decay)
184
+ if self.ema.get_decay() < 1:
185
+ self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
186
+
187
+ self.num_updates = num_updates
188
+
189
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
190
+ state = super().state_dict(destination, prefix, keep_vars)
191
+
192
+ if self.ema is not None:
193
+ state[prefix + "_ema"] = self.ema.fp32_params
194
+
195
+ return state
196
+
197
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
198
+ if self.ema is not None:
199
+ k = prefix + "_ema"
200
+ assert k in state_dict
201
+ self.ema.restore(state_dict[k], True)
202
+ del state_dict[k]
203
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
204
+
205
+ @classmethod
206
+ def build_model(cls, cfg: Data2VecAudioConfig, task=None):
207
+ """Build a new model instance."""
208
+
209
+ return cls(cfg)
210
+
211
+ def apply_mask(
212
+ self,
213
+ x,
214
+ padding_mask,
215
+ mask_indices=None,
216
+ mask_channel_indices=None,
217
+ ):
218
+ B, T, C = x.shape
219
+
220
+ if self.mask_channel_prob > 0 and self.mask_channel_before:
221
+ mask_channel_indices = compute_mask_indices(
222
+ (B, C),
223
+ None,
224
+ self.mask_channel_prob,
225
+ self.mask_channel_length,
226
+ self.mask_channel_selection,
227
+ self.mask_channel_other,
228
+ no_overlap=self.no_mask_channel_overlap,
229
+ min_space=self.mask_channel_min_space,
230
+ )
231
+ mask_channel_indices = (
232
+ torch.from_numpy(mask_channel_indices)
233
+ .to(x.device)
234
+ .unsqueeze(1)
235
+ .expand(-1, T, -1)
236
+ )
237
+ x[mask_channel_indices] = 0
238
+
239
+ if self.mask_prob > 0:
240
+ if mask_indices is None:
241
+ mask_indices = compute_mask_indices(
242
+ (B, T),
243
+ padding_mask,
244
+ self.mask_prob,
245
+ self.mask_length,
246
+ self.mask_selection,
247
+ self.mask_other,
248
+ min_masks=1,
249
+ no_overlap=self.no_mask_overlap,
250
+ min_space=self.mask_min_space,
251
+ require_same_masks=self.cfg.require_same_masks,
252
+ mask_dropout=self.cfg.mask_dropout,
253
+ )
254
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
255
+ x = index_put(x, mask_indices, self.mask_emb)
256
+ else:
257
+ mask_indices = None
258
+
259
+ if self.mask_channel_prob > 0 and not self.mask_channel_before:
260
+ if mask_channel_indices is None:
261
+ mask_channel_indices = compute_mask_indices(
262
+ (B, C),
263
+ None,
264
+ self.mask_channel_prob,
265
+ self.mask_channel_length,
266
+ self.mask_channel_selection,
267
+ self.mask_channel_other,
268
+ no_overlap=self.no_mask_channel_overlap,
269
+ min_space=self.mask_channel_min_space,
270
+ )
271
+ mask_channel_indices = (
272
+ torch.from_numpy(mask_channel_indices)
273
+ .to(x.device)
274
+ .unsqueeze(1)
275
+ .expand(-1, T, -1)
276
+ )
277
+ x = index_put(x, mask_channel_indices, 0)
278
+
279
+ return x, mask_indices
280
+
281
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
282
+ """
283
+ Computes the output length of the convolutional layers
284
+ """
285
+
286
+ def _conv_out_length(input_length, kernel_size, stride):
287
+ return torch.floor((input_length - kernel_size) / stride + 1)
288
+
289
+ conv_cfg_list = eval(self.cfg.conv_feature_layers)
290
+
291
+ for i in range(len(conv_cfg_list)):
292
+ input_lengths = _conv_out_length(
293
+ input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
294
+ )
295
+
296
+ return input_lengths.to(torch.long)
297
+
298
+ def forward(
299
+ self,
300
+ source,
301
+ padding_mask=None,
302
+ mask=True,
303
+ features_only=False,
304
+ layer=None,
305
+ mask_indices=None,
306
+ mask_channel_indices=None,
307
+ padding_count=None,
308
+ ):
309
+ features = source
310
+
311
+ if self.feature_grad_mult > 0:
312
+ features = self.feature_extractor(features)
313
+ if self.feature_grad_mult != 1.0:
314
+ features = GradMultiply.apply(features, self.feature_grad_mult)
315
+ else:
316
+ with torch.no_grad():
317
+ features = self.feature_extractor(features)
318
+
319
+ features = features.transpose(1, 2)
320
+
321
+ features = self.layer_norm(features)
322
+
323
+ orig_padding_mask = padding_mask
324
+
325
+ if padding_mask is not None and padding_mask.any():
326
+ input_lengths = (1 - padding_mask.long()).sum(-1)
327
+ # apply conv formula to get real output_lengths
328
+ output_lengths = self._get_feat_extract_output_lengths(input_lengths)
329
+
330
+ padding_mask = torch.zeros(
331
+ features.shape[:2], dtype=features.dtype, device=features.device
332
+ )
333
+
334
+ # these two operations makes sure that all values
335
+ # before the output lengths indices are attended to
336
+ padding_mask[
337
+ (
338
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
339
+ output_lengths - 1,
340
+ )
341
+ ] = 1
342
+ padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
343
+ else:
344
+ padding_mask = None
345
+
346
+ if self.post_extract_proj is not None:
347
+ features = self.post_extract_proj(features)
348
+
349
+ pre_encoder_features = None
350
+ if self.cfg.ema_transformer_only:
351
+ pre_encoder_features = features.clone()
352
+
353
+ features = self.dropout_input(features)
354
+
355
+ if mask:
356
+ x, mask_indices = self.apply_mask(
357
+ features,
358
+ padding_mask,
359
+ mask_indices=mask_indices,
360
+ mask_channel_indices=mask_channel_indices,
361
+ )
362
+ else:
363
+ x = features
364
+ mask_indices = None
365
+
366
+ x, layer_results = self.encoder(
367
+ x,
368
+ padding_mask=padding_mask,
369
+ layer=layer,
370
+ )
371
+
372
+ if features_only:
373
+ return {
374
+ "x": x,
375
+ "padding_mask": padding_mask,
376
+ "layer_results": layer_results,
377
+ }
378
+
379
+ result = {
380
+ "losses": {},
381
+ }
382
+
383
+ with torch.no_grad():
384
+ self.ema.model.eval()
385
+
386
+ if self.cfg.ema_transformer_only:
387
+ y, layer_results = self.ema.model.extract_features(
388
+ pre_encoder_features,
389
+ padding_mask=padding_mask,
390
+ min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
391
+ )
392
+ y = {
393
+ "x": y,
394
+ "padding_mask": padding_mask,
395
+ "layer_results": layer_results,
396
+ }
397
+ else:
398
+ y = self.ema.model.extract_features(
399
+ source=source,
400
+ padding_mask=orig_padding_mask,
401
+ mask=False,
402
+ )
403
+
404
+ target_layer_results = [l[2] for l in y["layer_results"]]
405
+
406
+ permuted = False
407
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
408
+ target_layer_results = [
409
+ tl.permute(1, 2, 0) for tl in target_layer_results # TBC -> BCT
410
+ ]
411
+ permuted = True
412
+
413
+ if self.cfg.batch_norm_target_layer:
414
+ target_layer_results = [
415
+ F.batch_norm(
416
+ tl.float(), running_mean=None, running_var=None, training=True
417
+ )
418
+ for tl in target_layer_results
419
+ ]
420
+
421
+ if self.cfg.instance_norm_target_layer:
422
+ target_layer_results = [
423
+ F.instance_norm(tl.float()) for tl in target_layer_results
424
+ ]
425
+
426
+ if permuted:
427
+ target_layer_results = [
428
+ tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
429
+ ]
430
+
431
+ if self.cfg.group_norm_target_layer:
432
+ target_layer_results = [
433
+ F.layer_norm(tl.float(), tl.shape[-2:])
434
+ for tl in target_layer_results
435
+ ]
436
+
437
+ if self.cfg.layer_norm_target_layer:
438
+ target_layer_results = [
439
+ F.layer_norm(tl.float(), tl.shape[-1:])
440
+ for tl in target_layer_results
441
+ ]
442
+
443
+ y = sum(target_layer_results) / len(target_layer_results)
444
+
445
+ if self.cfg.layer_norm_targets:
446
+ y = F.layer_norm(y.float(), y.shape[-1:])
447
+
448
+ if self.cfg.instance_norm_targets:
449
+ y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
450
+
451
+ if not permuted:
452
+ y = y.transpose(0, 1)
453
+
454
+ y = y[mask_indices]
455
+
456
+ x = x[mask_indices]
457
+ x = self.final_proj(x)
458
+
459
+ sz = x.size(-1)
460
+
461
+ if self.loss_beta == 0:
462
+ loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
463
+ else:
464
+ loss = F.smooth_l1_loss(
465
+ x.float(), y.float(), reduction="none", beta=self.loss_beta
466
+ ).sum(dim=-1)
467
+
468
+ if self.loss_scale is not None:
469
+ scale = self.loss_scale
470
+ else:
471
+ scale = 1 / math.sqrt(sz)
472
+
473
+ result["losses"]["regression"] = loss.sum() * scale
474
+
475
+ if "sample_size" not in result:
476
+ result["sample_size"] = loss.numel()
477
+
478
+ with torch.no_grad():
479
+ result["target_var"] = self.compute_var(y)
480
+ result["pred_var"] = self.compute_var(x.float())
481
+
482
+ if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
483
+ logger.error(
484
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
485
+ )
486
+ raise Exception(
487
+ f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
488
+ )
489
+ if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
490
+ logger.error(
491
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
492
+ )
493
+ raise Exception(
494
+ f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
495
+ )
496
+
497
+ if self.ema is not None:
498
+ result["ema_decay"] = self.ema.get_decay() * 1000
499
+
500
+ return result
501
+
502
+ @staticmethod
503
+ def compute_var(y):
504
+ y = y.view(-1, y.size(-1))
505
+ if dist.is_initialized():
506
+ zc = torch.tensor(y.size(0)).cuda()
507
+ zs = y.sum(dim=0)
508
+ zss = (y ** 2).sum(dim=0)
509
+
510
+ dist.all_reduce(zc)
511
+ dist.all_reduce(zs)
512
+ dist.all_reduce(zss)
513
+
514
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
515
+ return torch.sqrt(var + 1e-6).mean()
516
+ else:
517
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
518
+
519
+ def extract_features(
520
+ self, source, padding_mask, mask=False, layer=None
521
+ ):
522
+ res = self.forward(
523
+ source,
524
+ padding_mask,
525
+ mask=mask,
526
+ features_only=True,
527
+ layer=layer,
528
+ )
529
+ return res
530
+
531
+ def remove_pretraining_modules(self, last_layer=None):
532
+ self.final_proj = None
533
+ self.ema = None
534
+ if last_layer is not None:
535
+ self.encoder.layers = nn.ModuleList(
536
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
537
+ )
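To make the EMA schedule in `set_num_updates` concrete, here is a small illustration (the step values and `max_update=100000` are made up for the example) of how `get_annealed_rate` moves the teacher decay from `ema_decay` to `ema_end_decay`:

```python
def get_annealed_rate(start, end, curr_step, total_steps):
    # linear interpolation from start to end over total_steps updates
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining

for step in (0, 25_000, 50_000, 100_000):
    print(step, round(get_annealed_rate(0.999, 0.9999, step, 100_000), 6))
# 0 0.999, 25000 0.999225, 50000 0.99945, 100000 0.9999
```

Once `curr_step` reaches `ema_anneal_end_step` the model simply clamps the decay to `ema_end_decay`, and EMA updates are skipped entirely when the decay reaches 1.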
fairseq/examples/data2vec/models/data2vec_image_classification.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # The code in this file is adapted from the BeiT implementation which can be found here:
7
+ # https://github.com/microsoft/unilm/tree/master/beit
8
+
9
+ import logging
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Any
13
+
14
+ from omegaconf import II, MISSING
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ from fairseq import checkpoint_utils, tasks
21
+
22
+ from fairseq.dataclass import FairseqDataclass
23
+ from fairseq.models import BaseFairseqModel, register_model
24
+
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ @dataclass
30
+ class Data2VecImageClassificationConfig(FairseqDataclass):
31
+ model_path: str = MISSING
32
+ no_pretrained_weights: bool = False
33
+ num_classes: int = 1000
34
+ mixup: float = 0.8
35
+ cutmix: float = 1.0
36
+ label_smoothing: float = 0.1
37
+
38
+ pretrained_model_args: Any = None
39
+ data: str = II("task.data")
40
+
41
+
42
+ @register_model(
43
+ "data2vec_image_classification", dataclass=Data2VecImageClassificationConfig
44
+ )
45
+ class Data2VecImageClassificationModel(BaseFairseqModel):
46
+ def __init__(self, cfg: Data2VecImageClassificationConfig):
47
+ super().__init__()
48
+ self.cfg = cfg
49
+
50
+ if cfg.pretrained_model_args is None:
51
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
52
+ pretrained_args = state.get("cfg", None)
53
+ pretrained_args.criterion = None
54
+ pretrained_args.lr_scheduler = None
55
+ cfg.pretrained_model_args = pretrained_args
56
+
57
+ logger.info(pretrained_args)
58
+ else:
59
+ state = None
60
+ pretrained_args = cfg.pretrained_model_args
61
+
62
+ pretrained_args.task.data = cfg.data
63
+ task = tasks.setup_task(pretrained_args.task)
64
+ model = task.build_model(pretrained_args.model, from_checkpoint=True)
65
+
66
+ model.remove_pretraining_modules()
67
+
68
+ self.model = model
69
+
70
+ if state is not None and not cfg.no_pretrained_weights:
71
+ self.load_model_weights(state, model, cfg)
72
+
73
+ self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim)
74
+ self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
75
+
76
+ self.head.weight.data.mul_(1e-3)
77
+ self.head.bias.data.mul_(1e-3)
78
+
79
+ self.mixup_fn = None
80
+
81
+ if cfg.mixup > 0 or cfg.cutmix > 0:
82
+ from timm.data import Mixup
83
+
84
+ self.mixup_fn = Mixup(
85
+ mixup_alpha=cfg.mixup,
86
+ cutmix_alpha=cfg.cutmix,
87
+ cutmix_minmax=None,
88
+ prob=1.0,
89
+ switch_prob=0.5,
90
+ mode="batch",
91
+ label_smoothing=cfg.label_smoothing,
92
+ num_classes=cfg.num_classes,
93
+ )
94
+
95
+ def load_model_weights(self, state, model, cfg):
96
+ if "_ema" in state["model"]:
97
+ del state["model"]["_ema"]
98
+ model.load_state_dict(state["model"], strict=True)
99
+
100
+ @classmethod
101
+ def build_model(cls, cfg: Data2VecImageClassificationConfig, task=None):
102
+ """Build a new model instance."""
103
+
104
+ return cls(cfg)
105
+
106
+ def forward(
107
+ self,
108
+ img,
109
+ label=None,
110
+ ):
111
+ if self.training and self.mixup_fn is not None and label is not None:
112
+ img, label = self.mixup_fn(img, label)
113
+
114
+ x = self.model(img, mask=False)
115
+ x = x[:, 1:]
116
+ x = self.fc_norm(x.mean(1))
117
+ x = self.head(x)
118
+
119
+ if label is None:
120
+ return x
121
+
122
+ if self.training and self.mixup_fn is not None:
123
+ loss = -label * F.log_softmax(x.float(), dim=-1)
124
+ else:
125
+ loss = F.cross_entropy(
126
+ x.float(),
127
+ label,
128
+ label_smoothing=self.cfg.label_smoothing if self.training else 0,
129
+ reduction="none",
130
+ )
131
+
132
+ result = {
133
+ "losses": {"regression": loss},
134
+ "sample_size": img.size(0),
135
+ }
136
+
137
+ if not self.training:
138
+ with torch.no_grad():
139
+ pred = x.argmax(-1)
140
+ correct = (pred == label).sum()
141
+ result["correct"] = correct
142
+
143
+ return result
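The forward pass above switches between hard-label cross entropy and a soft-target loss when Mixup/CutMix is active, since mixed samples no longer have a single integer label. A small standalone sketch of the soft-target variant follows; the helper name and shapes are assumptions for illustration, not the model's exact reduction.

```python
import torch
import torch.nn.functional as F

def soft_target_loss(logits: torch.Tensor, soft_targets: torch.Tensor) -> torch.Tensor:
    # Mixup/CutMix yield probability-valued targets, so cross entropy is computed
    # as -sum(target * log_softmax(logits)) per sample instead of F.cross_entropy.
    return (-soft_targets * F.log_softmax(logits.float(), dim=-1)).sum(dim=-1)

if __name__ == "__main__":
    logits = torch.randn(4, 1000)
    hard = torch.randint(0, 1000, (4,))
    soft = F.one_hot(hard, num_classes=1000).float()
    # With one-hot targets this reduces to ordinary cross entropy.
    print(soft_target_loss(logits, soft).mean(), F.cross_entropy(logits, hard))
```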
fairseq/examples/data2vec/models/data2vec_text.py ADDED
@@ -0,0 +1,517 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional
8
+ import logging
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from omegaconf import II
16
+
17
+ from fairseq.dataclass import FairseqDataclass
18
+ from fairseq.modules import EMAModule, EMAModuleConfig
19
+ from fairseq.models import (
20
+ FairseqEncoder,
21
+ FairseqEncoderModel,
22
+ register_model,
23
+ )
24
+ from fairseq.models.roberta.model import RobertaLMHead, RobertaClassificationHead
25
+ from fairseq.models.transformer import TransformerEncoder, TransformerConfig
26
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ @dataclass
32
+ class Data2VecTextConfig(FairseqDataclass):
33
+ max_positions: int = II("task.tokens_per_sample")
34
+
35
+ head_layers: int = 1
36
+
37
+ transformer: TransformerConfig = TransformerConfig()
38
+
39
+ load_checkpoint_heads: bool = field(
40
+ default=False,
41
+ metadata={"help": "(re-)register and load heads when loading checkpoints"},
42
+ )
43
+
44
+ loss_beta: float = field(
45
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
46
+ )
47
+ loss_scale: Optional[float] = field(
48
+ default=None,
49
+ metadata={
50
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
51
+ },
52
+ )
53
+ average_top_k_layers: int = field(
54
+ default=8, metadata={"help": "how many layers to average"}
55
+ )
56
+
57
+ layer_norm_target_layer: bool = False
58
+ instance_norm_target_layer: bool = False
59
+ batch_norm_target_layer: bool = False
60
+ instance_norm_targets: bool = False
61
+ layer_norm_targets: bool = False
62
+
63
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
64
+ ema_end_decay: float = field(
65
+ default=0.9999, metadata={"help": "final ema decay rate"}
66
+ )
67
+
68
+ # when to finish annealing ema decay rate
69
+ ema_anneal_end_step: int = II("optimization.max_update")
70
+
71
+ ema_transformer_layers_only: bool = field(
72
+ default=True,
73
+ metadata={"help": "whether to momentum update only the transformer layers"},
74
+ )
75
+
76
+
77
+ def get_annealed_rate(start, end, curr_step, total_steps):
78
+ r = end - start
79
+ pct_remaining = 1 - curr_step / total_steps
80
+ return end - r * pct_remaining
81
+
82
+
83
+ @register_model("data2vec_text", dataclass=Data2VecTextConfig)
84
+ class Data2VecTextModel(FairseqEncoderModel):
85
+ def __init__(self, cfg: Data2VecTextConfig, encoder):
86
+ super().__init__(encoder)
87
+ self.cfg = cfg
88
+
89
+ # We follow BERT's random weight initialization
90
+ self.apply(init_bert_params)
91
+
92
+ self.classification_heads = nn.ModuleDict()
93
+
94
+ @classmethod
95
+ def build_model(cls, cfg, task):
96
+ """Build a new model instance."""
97
+
98
+ encoder = Data2VecTextEncoder(cfg, task.source_dictionary, task.cfg.data)
99
+
100
+ return cls(cfg, encoder)
101
+
102
+ def forward(
103
+ self,
104
+ src_tokens,
105
+ target_tokens=None,
106
+ features_only=False,
107
+ return_all_hiddens=False,
108
+ classification_head_name=None,
109
+ **kwargs,
110
+ ):
111
+ if classification_head_name is not None:
112
+ features_only = True
113
+
114
+ res = self.encoder(
115
+ src_tokens, target_tokens, features_only, return_all_hiddens, **kwargs
116
+ )
117
+
118
+ if isinstance(res, tuple):
119
+ x, extra = res
120
+ else:
121
+ return res
122
+
123
+ if classification_head_name is not None:
124
+ x = self.classification_heads[classification_head_name](x)
125
+ return x, extra
126
+
127
+ def get_normalized_probs(self, net_output, log_probs, sample=None):
128
+ """Get normalized probabilities (or log probs) from a net's output."""
129
+ logits = net_output[0].float()
130
+ if log_probs:
131
+ return F.log_softmax(logits, dim=-1)
132
+ else:
133
+ return F.softmax(logits, dim=-1)
134
+
135
+ def register_classification_head(
136
+ self, name, num_classes=None, inner_dim=None, **kwargs
137
+ ):
138
+ """Register a classification head."""
139
+ if name in self.classification_heads:
140
+ prev_num_classes = self.classification_heads[name].out_proj.out_features
141
+ prev_inner_dim = self.classification_heads[name].dense.out_features
142
+ if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
143
+ logger.warning(
144
+ 're-registering head "{}" with num_classes {} (prev: {}) '
145
+ "and inner_dim {} (prev: {})".format(
146
+ name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
147
+ )
148
+ )
149
+ self.classification_heads[name] = RobertaClassificationHead(
150
+ input_dim=self.cfg.transformer.encoder.embed_dim,
151
+ inner_dim=inner_dim or self.cfg.transformer.encoder.embed_dim,
152
+ num_classes=num_classes,
153
+ activation_fn="tanh",
154
+ pooler_dropout=0,
155
+ )
156
+
157
+ @property
158
+ def supported_targets(self):
159
+ return {"self"}
160
+
161
+ def upgrade_state_dict_named(self, state_dict, name):
162
+ prefix = name + "." if name != "" else ""
163
+
164
+ # rename decoder -> encoder before upgrading children modules
165
+ for k in list(state_dict.keys()):
166
+ if k.startswith(prefix + "decoder"):
167
+ new_k = prefix + "encoder" + k[len(prefix + "decoder") :]
168
+ state_dict[new_k] = state_dict[k]
169
+ del state_dict[k]
170
+
171
+ # rename emb_layer_norm -> layernorm_embedding
172
+ for k in list(state_dict.keys()):
173
+ if ".emb_layer_norm." in k:
174
+ new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.")
175
+ state_dict[new_k] = state_dict[k]
176
+ del state_dict[k]
177
+
178
+ if self.encoder.regression_head is not None:
179
+ if ".lm_head." in k:
180
+ new_k = k.replace(".lm_head.", ".regression_head.")
181
+ state_dict[new_k] = state_dict[k]
182
+ del state_dict[k]
183
+ else:
184
+ if ".regression_head." in k:
185
+ del state_dict[k]
186
+
187
+ # upgrade children modules
188
+ super().upgrade_state_dict_named(state_dict, name)
189
+
190
+ # Handle new classification heads present in the state dict.
191
+ current_head_names = (
192
+ []
193
+ if not hasattr(self, "classification_heads")
194
+ or self.classification_heads is None
195
+ else self.classification_heads.keys()
196
+ )
197
+ keys_to_delete = []
198
+ for k in state_dict.keys():
199
+ if not k.startswith(prefix + "classification_heads."):
200
+ continue
201
+
202
+ head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
203
+ num_classes = state_dict[
204
+ prefix + "classification_heads." + head_name + ".out_proj.weight"
205
+ ].size(0)
206
+ inner_dim = state_dict[
207
+ prefix + "classification_heads." + head_name + ".dense.weight"
208
+ ].size(0)
209
+
210
+ if self.cfg.load_checkpoint_heads:
211
+ if head_name not in current_head_names:
212
+ self.register_classification_head(head_name, num_classes, inner_dim)
213
+ else:
214
+ if head_name not in current_head_names:
215
+ logger.warning(
216
+ "deleting classification head ({}) from checkpoint "
217
+ "not present in current model: {}".format(head_name, k)
218
+ )
219
+ keys_to_delete.append(k)
220
+ elif (
221
+ num_classes
222
+ != self.classification_heads[head_name].out_proj.out_features
223
+ or inner_dim
224
+ != self.classification_heads[head_name].dense.out_features
225
+ ):
226
+ logger.warning(
227
+ "deleting classification head ({}) from checkpoint "
228
+ "with different dimensions than current model: {}".format(
229
+ head_name, k
230
+ )
231
+ )
232
+ keys_to_delete.append(k)
233
+ for k in keys_to_delete:
234
+ del state_dict[k]
235
+
236
+ # Copy any newly-added classification heads into the state dict
237
+ # with their current weights.
238
+ if (
239
+ hasattr(self, "classification_heads")
240
+ and self.classification_heads is not None
241
+ and len(self.classification_heads) > 0
242
+ ):
243
+ cur_state = self.classification_heads.state_dict()
244
+ for k, v in cur_state.items():
245
+ if prefix + "classification_heads." + k not in state_dict:
246
+ logger.info("Overwriting " + prefix + "classification_heads." + k)
247
+ state_dict[prefix + "classification_heads." + k] = v
248
+
249
+ for k in list(state_dict.keys()):
250
+ if k.startswith(prefix + "encoder.lm_head.") or k.startswith(
251
+ prefix + "encoder.emb_head."
252
+ ):
253
+ del state_dict[k]
254
+
255
+ self.encoder.lm_head = None
256
+
257
+ if self.encoder.target_model is None:
258
+ for k in list(state_dict.keys()):
259
+ if k.startswith(prefix + "encoder.target_model."):
260
+ del state_dict[k]
261
+
262
+ if (self.encoder.ema is None) and (prefix + "encoder._ema" in state_dict):
263
+ del state_dict[prefix + "encoder._ema"]
264
+
265
+ def remove_pretraining_modules(self, last_layer=None):
266
+ self.encoder.lm_head = None
267
+ self.encoder.regression_head = None
268
+ self.encoder.ema = None
269
+ self.classification_heads = None
270
+
271
+ if last_layer is not None:
272
+ self.encoder.sentence_encoder.layers = nn.ModuleList(
273
+ l
274
+ for i, l in enumerate(self.encoder.sentence_encoder.layers)
275
+ if i <= last_layer
276
+ )
277
+ self.encoder.sentence_encoder.layer_norm = None
278
+
279
+
280
+ class Data2VecTextEncoder(FairseqEncoder):
281
+ def __init__(self, cfg: Data2VecTextConfig, dictionary, task_data):
282
+ super().__init__(dictionary)
283
+
284
+ self.cfg = cfg
285
+
286
+ embed_tokens = self.build_embedding(
287
+ len(dictionary), cfg.transformer.encoder.embed_dim, dictionary.pad()
288
+ )
289
+
290
+ self.sentence_encoder = self.build_encoder(cfg, dictionary, embed_tokens)
291
+ self.mask_idx = dictionary.index("<mask>")
292
+ assert self.mask_idx != dictionary.unk(), dictionary.symbols
293
+
294
+ self.ema = None
295
+ self.average_top_k_layers = cfg.average_top_k_layers
296
+ self.loss_scale = cfg.loss_scale
297
+
298
+ assert self.cfg.head_layers >= 1
299
+
300
+ embed_dim = cfg.transformer.encoder.embed_dim
301
+ curr_dim = embed_dim
302
+ projs = []
303
+ for i in range(self.cfg.head_layers - 1):
304
+ next_dim = embed_dim * 2 if i == 0 else curr_dim
305
+ projs.append(nn.Linear(curr_dim, next_dim))
306
+ projs.append(nn.GELU())
307
+ curr_dim = next_dim
308
+
309
+ projs.append(nn.Linear(curr_dim, embed_dim))
310
+ self.regression_head = nn.Sequential(*projs)
311
+
312
+ self.num_updates = 0
313
+
314
+ def build_embedding(self, vocab_size, embedding_dim, padding_idx):
315
+ return nn.Embedding(vocab_size, embedding_dim, padding_idx)
316
+
317
+ def build_encoder(self, cfg, dictionary, embed_tokens):
318
+ encoder = TransformerEncoder(cfg.transformer, dictionary, embed_tokens, return_fc=True)
319
+ encoder.apply(init_bert_params)
320
+ return encoder
321
+
322
+ def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
323
+ return RobertaLMHead(embed_dim, output_dim, activation_fn, weight)
324
+
325
+ def make_ema_teacher(self):
326
+ ema_config = EMAModuleConfig(
327
+ ema_decay=self.cfg.ema_decay,
328
+ ema_fp32=True,
329
+ )
330
+ skip_keys = set()
331
+ if self.cfg.ema_transformer_layers_only:
332
+ for k, _ in self.sentence_encoder.embed_tokens.named_parameters():
333
+ skip_keys.add(f"embed_tokens.{k}")
334
+ for k, _ in self.sentence_encoder.embed_positions.named_parameters():
335
+ skip_keys.add(f"embed_positions.{k}")
336
+ if self.sentence_encoder.layernorm_embedding is not None:
337
+ for (
338
+ k,
339
+ _,
340
+ ) in self.sentence_encoder.layernorm_embedding.named_parameters():
341
+ skip_keys.add(f"layernorm_embedding.{k}")
342
+ if self.sentence_encoder.layer_norm is not None:
343
+ for k, _ in self.sentence_encoder.layer_norm.named_parameters():
344
+ skip_keys.add(f"layernorm_embedding.{k}")
345
+
346
+ self.ema = EMAModule(
347
+ self.sentence_encoder,
348
+ ema_config,
349
+ skip_keys=skip_keys,
350
+ )
351
+
352
+ def set_num_updates(self, num_updates):
353
+ super().set_num_updates(num_updates)
354
+
355
+ if self.ema is None and self.regression_head is not None:
356
+ logger.info(f"making ema teacher")
357
+ self.make_ema_teacher()
358
+ elif self.training and self.ema is not None:
359
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
360
+ if num_updates >= self.cfg.ema_anneal_end_step:
361
+ decay = self.cfg.ema_end_decay
362
+ else:
363
+ decay = get_annealed_rate(
364
+ self.cfg.ema_decay,
365
+ self.cfg.ema_end_decay,
366
+ num_updates,
367
+ self.cfg.ema_anneal_end_step,
368
+ )
369
+ self.ema.set_decay(decay)
370
+ if self.ema.get_decay() < 1:
371
+ self.ema.step(self.sentence_encoder)
372
+
373
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
374
+ state = super().state_dict(destination, prefix, keep_vars)
375
+ if self.ema is not None:
376
+ state[prefix + "_ema"] = self.ema.fp32_params
377
+ return state
378
+
379
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
380
+ if self.ema is not None:
381
+ k = prefix + "_ema"
382
+ assert k in state_dict
383
+ self.ema.restore(state_dict[k], True)
384
+ del state_dict[k]
385
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
386
+
387
+ def forward(
388
+ self,
389
+ src_tokens,
390
+ target_tokens=None,
391
+ features_only=False,
392
+ return_all_hiddens=False,
393
+ masked_tokens=None,
394
+ **unused,
395
+ ):
396
+ """
397
+ Args:
398
+ src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
399
+ features_only (bool, optional): skip LM head and just return
400
+ features. If True, the output will be of shape
401
+ `(batch, src_len, embed_dim)`.
402
+ return_all_hiddens (bool, optional): also return all of the
403
+ intermediate hidden states (default: False).
404
+
405
+ Returns:
406
+ tuple:
407
+ - the LM output of shape `(batch, src_len, vocab)`
408
+ - a dictionary of additional data, where 'inner_states'
409
+ is a list of hidden states. Note that the hidden
410
+ states have shape `(src_len, batch, vocab)`.
411
+ """
412
+
413
+ x, extra = self.extract_features(
414
+ src_tokens, return_all_hiddens=return_all_hiddens
415
+ )
416
+
417
+ if features_only:
418
+ return x, extra
419
+
420
+ assert target_tokens is not None
421
+
422
+ with torch.no_grad():
423
+ # use EMA parameter as the teacher
424
+ self.ema.model.eval()
425
+
426
+ encoder_out = self.ema.model(
427
+ target_tokens,
428
+ return_all_hiddens=True,
429
+ )
430
+ y = encoder_out["fc_results"]
431
+
432
+ y = y[-self.average_top_k_layers :]
433
+
434
+ permuted = False
435
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
436
+ y = [tl.permute(1, 2, 0) for tl in y] # TBC -> BCT
437
+ permuted = True
438
+
439
+ if self.cfg.batch_norm_target_layer:
440
+ y = [
441
+ F.batch_norm(
442
+ tl.float(), running_mean=None, running_var=None, training=True
443
+ )
444
+ for tl in y
445
+ ]
446
+
447
+ if self.cfg.instance_norm_target_layer:
448
+ y = [F.instance_norm(tl.float()) for tl in y]
449
+
450
+ if permuted:
451
+ y = [tl.transpose(1, 2) for tl in y] # BCT -> BTC
452
+
453
+ if self.cfg.layer_norm_target_layer:
454
+ y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
455
+
456
+ y = sum(y) / len(y)
457
+
458
+ if not permuted:
459
+ y = y.transpose(0, 1)
460
+
461
+ if self.cfg.layer_norm_targets:
462
+ y = F.layer_norm(y.float(), y.shape[-1:])
463
+
464
+ if self.cfg.instance_norm_targets:
465
+ y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
466
+
467
+ masked_indices = src_tokens.eq(self.mask_idx)
468
+
469
+ x = x[masked_indices]
470
+ y = y[masked_indices]
471
+
472
+ x = self.regression_head(x)
473
+
474
+ sz = x.size(-1)
475
+ if self.cfg.loss_beta == 0:
476
+ loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
477
+ else:
478
+ loss = F.smooth_l1_loss(
479
+ x.float(), y.float(), reduction="none", beta=self.cfg.loss_beta
480
+ ).sum(dim=-1)
481
+
482
+ result = {
483
+ "losses": {
484
+ "main": loss.sum() / math.sqrt(sz)
485
+ if self.loss_scale is None or self.loss_scale <= 0
486
+ else loss.sum() * self.loss_scale,
487
+ },
488
+ "sample_size": loss.numel(),
489
+ }
490
+
491
+ # logging other values
492
+ other_logs = {
493
+ "ema_decay": self.ema.get_decay() * 1000
494
+ }
495
+ result["logs"] = other_logs
496
+ return result
497
+
498
+ def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs):
499
+ encoder_out = self.sentence_encoder(
500
+ src_tokens,
501
+ return_all_hiddens=return_all_hiddens,
502
+ token_embeddings=kwargs.get("token_embeddings", None),
503
+ )
504
+ # T x B x C -> B x T x C
505
+ features = encoder_out["encoder_out"][0].transpose(0, 1)
506
+ inner_states = encoder_out["encoder_states"] if return_all_hiddens else None
507
+ return features, {
508
+ "inner_states": inner_states,
509
+ "encoder_embedding": encoder_out["encoder_embedding"][0],
510
+ }
511
+
512
+ def output_layer(self, features, masked_tokens=None, **unused):
513
+ return self.lm_head(features, masked_tokens)
514
+
515
+ def max_positions(self):
516
+ """Maximum output length supported by the encoder."""
517
+ return self.cfg.max_positions
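To make the teacher's EMA schedule concrete, the snippet below exercises get_annealed_rate (copied from the model file above) with illustrative hyperparameters: the decay anneals linearly from ema_decay to ema_end_decay over ema_anneal_end_step updates, after which it stays at the final value.

```python
def get_annealed_rate(start: float, end: float, curr_step: int, total_steps: int) -> float:
    # Same linear schedule as in the model file above: interpolate from `start`
    # toward `end` as curr_step approaches total_steps.
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining

if __name__ == "__main__":
    # Illustrative numbers: decay anneals 0.999 -> 0.9999 over 100k updates.
    for step in (0, 25_000, 50_000, 100_000):
        print(step, round(get_annealed_rate(0.999, 0.9999, step, 100_000), 6))
```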
fairseq/examples/data2vec/models/data2vec_text_classification.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # The code in this file is adapted from the BeiT implementation which can be found here:
7
+ # https://github.com/microsoft/unilm/tree/master/beit
8
+
9
+ import logging
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Any
13
+
14
+ from omegaconf import II, MISSING
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ from fairseq import checkpoint_utils, tasks
21
+
22
+ from fairseq.dataclass import FairseqDataclass
23
+ from fairseq.models import BaseFairseqModel, register_model
24
+ from fairseq.models.roberta.model import RobertaClassificationHead
25
+
26
+ from examples.data2vec.data.modality import Modality
27
+
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ @dataclass
33
+ class Data2VecTextClassificationConfig(FairseqDataclass):
34
+ pooler_dropout: float = 0.0
35
+ pooler_activation_fn: str = "tanh"
36
+ quant_noise_pq: int = 0
37
+ quant_noise_pq_block_size: int = 8
38
+ spectral_norm_classification_head: bool = False
39
+
40
+ model_path: str = MISSING
41
+ no_pretrained_weights: bool = False
42
+
43
+ pretrained_model_args: Any = None
44
+
45
+
46
+ @register_model(
47
+ "data2vec_text_classification", dataclass=Data2VecTextClassificationConfig
48
+ )
49
+ class Data2VecTextClassificationModel(BaseFairseqModel):
50
+ def __init__(self, cfg: Data2VecTextClassificationConfig):
51
+ super().__init__()
52
+ self.cfg = cfg
53
+
54
+ if cfg.pretrained_model_args is None:
55
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
56
+ pretrained_args = state.get("cfg", None)
57
+ pretrained_args.criterion = None
58
+ pretrained_args.lr_scheduler = None
59
+ cfg.pretrained_model_args = pretrained_args
60
+
61
+ logger.info(pretrained_args)
62
+ else:
63
+ state = None
64
+ pretrained_args = cfg.pretrained_model_args
65
+
66
+ task = tasks.setup_task(pretrained_args.task)
67
+ model = task.build_model(pretrained_args.model, from_checkpoint=True)
68
+
69
+ model.remove_pretraining_modules()
70
+
71
+ self.model = model
72
+
73
+ if state is not None and not cfg.no_pretrained_weights:
74
+ self.load_model_weights(state, model, cfg)
75
+
76
+ self.classification_heads = nn.ModuleDict()
77
+
78
+
79
+ def load_model_weights(self, state, model, cfg):
80
+ for k in list(state["model"].keys()):
81
+ if (
82
+ k.startswith("shared_decoder") or
83
+ k.startswith("_ema") or
84
+ "decoder" in k
85
+ ):
86
+ logger.info(f"Deleting {k} from checkpoint")
87
+ del state["model"][k]
88
+ model.load_state_dict(state["model"], strict=True)
89
+
90
+ @classmethod
91
+ def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None):
92
+ """Build a new model instance."""
93
+
94
+ return cls(cfg)
95
+
96
+ def register_classification_head(
97
+ self, name, num_classes=None, inner_dim=None, **kwargs
98
+ ):
99
+ """Register a classification head."""
100
+ if name in self.classification_heads:
101
+ prev_num_classes = self.classification_heads[name].out_proj.out_features
102
+ prev_inner_dim = self.classification_heads[name].dense.out_features
103
+ if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
104
+ logger.warning(
105
+ 're-registering head "{}" with num_classes {} (prev: {}) '
106
+ "and inner_dim {} (prev: {})".format(
107
+ name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
108
+ )
109
+ )
110
+ embed_dim = self.cfg.pretrained_model_args.model.embed_dim
111
+ self.classification_heads[name] = RobertaClassificationHead(
112
+ input_dim=embed_dim,
113
+ inner_dim=inner_dim or embed_dim,
114
+ num_classes=num_classes,
115
+ activation_fn=self.cfg.pooler_activation_fn,
116
+ pooler_dropout=self.cfg.pooler_dropout,
117
+ q_noise=self.cfg.quant_noise_pq,
118
+ qn_block_size=self.cfg.quant_noise_pq_block_size,
119
+ do_spectral_norm=self.cfg.spectral_norm_classification_head,
120
+ )
121
+
122
+ def forward(
123
+ self,
124
+ source,
125
+ id,
126
+ padding_mask,
127
+ features_only=True,
128
+ remove_extra_tokens=True,
129
+ classification_head_name=None,
130
+ ):
131
+ encoder_out = self.model(
132
+ source,
133
+ id=id,
134
+ mode=Modality.TEXT,
135
+ padding_mask=padding_mask,
136
+ mask=False,
137
+ features_only=features_only,
138
+ remove_extra_tokens=remove_extra_tokens
139
+ )
140
+ logits = self.classification_heads[classification_head_name](encoder_out["x"])
141
+ return logits, encoder_out
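For orientation, the classification path above pools the encoder output through a RobertaClassificationHead registered per task. The sketch below shows the general shape of such a pooled head in plain PyTorch; it is illustrative only and is not the fairseq implementation (class name, defaults, and pooling choice are assumptions).

```python
import torch
import torch.nn as nn

class SimplePooledHead(nn.Module):
    """Illustrative pooled classification head; not the fairseq RobertaClassificationHead."""

    def __init__(self, embed_dim: int, num_classes: int, pooler_dropout: float = 0.0):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(pooler_dropout)
        self.out_proj = nn.Linear(embed_dim, num_classes)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        x = self.dropout(features[:, 0])      # pool the first (<s>) token
        x = torch.tanh(self.dense(x))
        return self.out_proj(self.dropout(x))

if __name__ == "__main__":
    head = SimplePooledHead(embed_dim=768, num_classes=2)
    print(head(torch.randn(4, 128, 768)).shape)  # torch.Size([4, 2])
```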
fairseq/examples/data2vec/models/data2vec_vision.py ADDED
@@ -0,0 +1,727 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # The code in this file is adapted from the BeiT implementation which can be found here:
7
+ # https://github.com/microsoft/unilm/tree/master/beit
8
+
9
+ import logging
10
+ import math
11
+ import numpy as np
12
+ import random
13
+
14
+ from dataclasses import dataclass, field
15
+ from typing import Optional
16
+
17
+ from omegaconf import II
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import torch.distributed as dist
23
+
24
+ from fairseq.modules import EMAModule, EMAModuleConfig
25
+ from fairseq.dataclass import FairseqDataclass
26
+ from fairseq.models import BaseFairseqModel, register_model
27
+
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ @dataclass
33
+ class Data2VecVisionConfig(FairseqDataclass):
34
+ layer_scale_init_value: float = field(
35
+ default=1e-4, metadata={"help": "rescale layer outputs, 0 to disable"}
36
+ )
37
+ num_mask_patches: int = field(
38
+ default=75,
39
+ metadata={"help": "number of the visual tokens/patches need be masked"},
40
+ )
41
+ min_mask_patches_per_block: int = 16
42
+ max_mask_patches_per_block: int = 196
43
+ image_size: int = 224
44
+ patch_size: int = 16
45
+ in_channels: int = 3
46
+
47
+ shared_rel_pos_bias: bool = True
48
+
49
+ drop_path: float = 0.1
50
+ attention_dropout: float = 0.0
51
+
52
+ depth: int = 12
53
+ embed_dim: int = 768
54
+ num_heads: int = 12
55
+ mlp_ratio: int = 4
56
+
57
+ loss_beta: float = field(
58
+ default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
59
+ )
60
+ loss_scale: Optional[float] = field(
61
+ default=None,
62
+ metadata={
63
+ "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
64
+ },
65
+ )
66
+ average_top_k_layers: int = field(
67
+ default=8, metadata={"help": "how many layers to average"}
68
+ )
69
+
70
+ end_of_block_targets: bool = True
71
+ layer_norm_target_layer: bool = False
72
+ instance_norm_target_layer: bool = False
73
+ batch_norm_target_layer: bool = False
74
+ instance_norm_targets: bool = False
75
+ layer_norm_targets: bool = False
76
+
77
+ ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
78
+ ema_end_decay: float = field(
79
+ default=0.9999, metadata={"help": "final ema decay rate"}
80
+ )
81
+
82
+ # when to finish annealing ema decay rate
83
+ ema_anneal_end_step: int = II("optimization.max_update")
84
+
85
+ ema_transformer_only: bool = field(
86
+ default=True,
87
+ metadata={"help": "whether to momentum update only the transformer layers"},
88
+ )
89
+
90
+
91
+ def get_annealed_rate(start, end, curr_step, total_steps):
92
+ r = end - start
93
+ pct_remaining = 1 - curr_step / total_steps
94
+ return end - r * pct_remaining
95
+
96
+
97
+ @register_model("data2vec_vision", dataclass=Data2VecVisionConfig)
98
+ class Data2VecVisionModel(BaseFairseqModel):
99
+ def __init__(self, cfg: Data2VecVisionConfig):
100
+ super().__init__()
101
+ self.cfg = cfg
102
+
103
+ self.ema = None
104
+
105
+ self.average_top_k_layers = cfg.average_top_k_layers
106
+ self.loss_beta = cfg.loss_beta
107
+ self.loss_scale = (
108
+ cfg.loss_scale
109
+ if cfg.loss_scale is not None
110
+ else 1 / math.sqrt(cfg.embed_dim)
111
+ )
112
+
113
+ self.patch_embed = PatchEmbed(
114
+ img_size=cfg.image_size,
115
+ patch_size=cfg.patch_size,
116
+ in_chans=cfg.in_channels,
117
+ embed_dim=cfg.embed_dim,
118
+ )
119
+
120
+ patch_size = self.patch_embed.patch_size
121
+ self.window_size = (
122
+ cfg.image_size // patch_size[0],
123
+ cfg.image_size // patch_size[1],
124
+ )
125
+
126
+ self.cls_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
127
+ self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
128
+
129
+ nn.init.trunc_normal_(self.cls_emb, 0.02)
130
+ nn.init.trunc_normal_(self.mask_emb, 0.02)
131
+
132
+ self.encoder = TransformerEncoder(cfg, self.patch_embed.patch_shape)
133
+
134
+ self.final_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
135
+ self.num_updates = 0
136
+
137
+ def make_ema_teacher(self):
138
+ ema_config = EMAModuleConfig(
139
+ ema_decay=self.cfg.ema_decay,
140
+ ema_fp32=True,
141
+ )
142
+ self.ema = EMAModule(
143
+ self.encoder if self.cfg.ema_transformer_only else self,
144
+ ema_config,
145
+ )
146
+
147
+ def set_num_updates(self, num_updates):
148
+ super().set_num_updates(num_updates)
149
+
150
+ if self.ema is None and self.final_proj is not None:
151
+ logger.info(f"making ema teacher")
152
+ self.make_ema_teacher()
153
+ elif self.training and self.ema is not None:
154
+ if self.cfg.ema_decay != self.cfg.ema_end_decay:
155
+ if num_updates >= self.cfg.ema_anneal_end_step:
156
+ decay = self.cfg.ema_end_decay
157
+ else:
158
+ decay = get_annealed_rate(
159
+ self.cfg.ema_decay,
160
+ self.cfg.ema_end_decay,
161
+ num_updates,
162
+ self.cfg.ema_anneal_end_step,
163
+ )
164
+ self.ema.set_decay(decay)
165
+ if self.ema.get_decay() < 1:
166
+ self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
167
+
168
+ self.num_updates = num_updates
169
+
170
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
171
+ state = super().state_dict(destination, prefix, keep_vars)
172
+
173
+ if self.ema is not None:
174
+ state[prefix + "_ema"] = self.ema.fp32_params
175
+
176
+ return state
177
+
178
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
179
+ if self.ema is not None:
180
+ k = prefix + "_ema"
181
+ assert k in state_dict
182
+ self.ema.restore(state_dict[k], True)
183
+ del state_dict[k]
184
+ return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
185
+
186
+ @classmethod
187
+ def build_model(cls, cfg: Data2VecVisionConfig, task=None):
188
+ """Build a new model instance."""
189
+
190
+ return cls(cfg)
191
+
192
+ def make_mask(self, bsz, num_masks, min_masks, max_masks):
193
+ height, width = self.window_size
194
+
195
+ masks = np.zeros(shape=(bsz, height, width), dtype=np.int64)  # np.int was removed in NumPy 1.24
196
+
197
+ for i in range(bsz):
198
+ mask = masks[i]
199
+ mask_count = 0
200
+
201
+ min_aspect = 0.3
202
+ max_aspect = 1 / min_aspect
203
+ log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
204
+
205
+ def _mask(mask, max_mask_patches):
206
+ delta = 0
207
+ for attempt in range(10):
208
+ target_area = random.uniform(min_masks, max_mask_patches)
209
+ aspect_ratio = math.exp(random.uniform(*log_aspect_ratio))
210
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
211
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
212
+ if w < width and h < height:
213
+ top = random.randint(0, height - h)
214
+ left = random.randint(0, width - w)
215
+
216
+ num_masked = mask[top : top + h, left : left + w].sum()
217
+ # Overlap
218
+ if 0 < h * w - num_masked <= max_mask_patches:
219
+ for i in range(top, top + h):
220
+ for j in range(left, left + w):
221
+ if mask[i, j] == 0:
222
+ mask[i, j] = 1
223
+ delta += 1
224
+
225
+ if delta > 0:
226
+ break
227
+ return delta
228
+
229
+ while mask_count < num_masks:
230
+ max_mask_patches = min(num_masks - mask_count, max_masks)
231
+
232
+ delta = _mask(mask, max_mask_patches)
233
+ if delta == 0:
234
+ break
235
+ else:
236
+ mask_count += delta
237
+
238
+ return torch.from_numpy(masks)
239
+
240
+ def forward(
241
+ self,
242
+ img,
243
+ mask: bool = True,
244
+ layer_results: bool = False,
245
+ ):
246
+ x = self.patch_embed(img)
247
+ batch_size, seq_len, _ = x.size()
248
+
249
+ if mask:
250
+ mask_indices = self.make_mask(
251
+ img.size(0),
252
+ self.cfg.num_mask_patches,
253
+ self.cfg.min_mask_patches_per_block,
254
+ self.cfg.max_mask_patches_per_block,
255
+ )
256
+ bool_mask = mask_indices.view(mask_indices.size(0), -1).bool()
257
+ else:
258
+ mask_indices = bool_mask = None
259
+
260
+ cls_tokens = self.cls_emb.expand(batch_size, -1, -1)
261
+ x = torch.cat((cls_tokens, x), dim=1)
262
+
263
+ if self.ema is not None:
264
+ with torch.no_grad():
265
+ self.ema.model.eval()
266
+
267
+ if self.cfg.ema_transformer_only:
268
+ y = self.ema.model(
269
+ x,
270
+ layer_results="end" if self.cfg.end_of_block_targets else "fc",
271
+ )
272
+ else:
273
+ y = self.ema.model(
274
+ img,
275
+ mask=False,
276
+ layer_results=True,
277
+ )
278
+
279
+ y = y[-self.cfg.average_top_k_layers :]
280
+
281
+ permuted = False
282
+ if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
283
+ y = [tl.transpose(1, 2) for tl in y] # BTC -> BCT
284
+ permuted = True
285
+
286
+ if self.cfg.batch_norm_target_layer:
287
+ y = [
288
+ F.batch_norm(
289
+ tl.float(), running_mean=None, running_var=None, training=True
290
+ )
291
+ for tl in y
292
+ ]
293
+
294
+ if self.cfg.instance_norm_target_layer:
295
+ y = [F.instance_norm(tl.float()) for tl in y]
296
+
297
+ if permuted:
298
+ y = [tl.transpose(1, 2) for tl in y] # BCT -> BTC
299
+
300
+ if self.cfg.layer_norm_target_layer:
301
+ y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]
302
+
303
+ y = sum(y) / len(y)
304
+
305
+ if self.cfg.layer_norm_targets:
306
+ y = F.layer_norm(y.float(), y.shape[-1:])
307
+
308
+ if self.cfg.instance_norm_targets:
309
+ y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)
310
+
311
+ y = y[bool_mask].float()
312
+
313
+ if mask_indices is not None:
314
+ mask_token = self.mask_emb.expand(batch_size, seq_len, -1)
315
+ w = mask_indices.view(mask_indices.size(0), -1, 1).type_as(mask_token)
316
+ x[:, 1:] = x[:, 1:] * (1 - w) + mask_token * w
317
+
318
+ if layer_results:
319
+ enc_layer_results = "end" if self.cfg.end_of_block_targets else "fc"
320
+ else:
321
+ enc_layer_results = None
322
+
323
+ x = self.encoder(x, layer_results=enc_layer_results)
324
+ if layer_results or mask_indices is None:
325
+ return x
326
+
327
+ x = x[bool_mask].float()
328
+
329
+ if self.loss_beta == 0:
330
+ loss = F.mse_loss(x, y, reduction="none").sum(dim=-1)
331
+ else:
332
+ loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta).sum(
333
+ dim=-1
334
+ )
335
+
336
+ if self.loss_scale > 0:
337
+ loss = loss * self.loss_scale
338
+
339
+ result = {
340
+ "losses": {"regression": loss.sum()},
341
+ "sample_size": loss.numel(),
342
+ "target_var": self.compute_var(y),
343
+ "pred_var": self.compute_var(x),
344
+ "ema_decay": self.ema.get_decay() * 1000,
345
+ }
346
+ return result
347
+
348
+ @staticmethod
349
+ def compute_var(y):
350
+ y = y.view(-1, y.size(-1))
351
+ if dist.is_initialized():
352
+ zc = torch.tensor(y.size(0)).cuda()
353
+ zs = y.sum(dim=0)
354
+ zss = (y ** 2).sum(dim=0)
355
+
356
+ dist.all_reduce(zc)
357
+ dist.all_reduce(zs)
358
+ dist.all_reduce(zss)
359
+
360
+ var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
361
+ return torch.sqrt(var + 1e-6).mean()
362
+ else:
363
+ return torch.sqrt(y.var(dim=0) + 1e-6).mean()
364
+
365
+ def remove_pretraining_modules(self, last_layer=None):
366
+ self.final_proj = None
367
+ self.ema = None
368
+ self.encoder.norm = nn.Identity()
369
+ self.mask_emb = None
370
+ if last_layer is not None:
371
+ self.encoder.layers = nn.ModuleList(
372
+ l for i, l in enumerate(self.encoder.layers) if i <= last_layer
373
+ )
374
+
375
+
376
+ class PatchEmbed(nn.Module):
377
+ """Image to Patch Embedding"""
378
+
379
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
380
+ super().__init__()
381
+ if isinstance(img_size, int):
382
+ img_size = img_size, img_size
383
+ if isinstance(patch_size, int):
384
+ patch_size = patch_size, patch_size
385
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
386
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
387
+ self.img_size = img_size
388
+ self.patch_size = patch_size
389
+ self.num_patches = num_patches
390
+
391
+ self.conv = nn.Conv2d(
392
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
393
+ )
394
+
395
+ def forward(self, x):
396
+ # BCHW -> BTC
397
+ x = self.conv(x).flatten(2).transpose(1, 2)
398
+ return x
399
+
400
+
401
+ class Attention(nn.Module):
402
+ def __init__(
403
+ self,
404
+ dim,
405
+ num_heads=8,
406
+ qkv_bias=True,
407
+ attn_drop=0.0,
408
+ proj_drop=0.0,
409
+ window_size=None,
410
+ attn_head_dim=None,
411
+ ):
412
+ super().__init__()
413
+ self.num_heads = num_heads
414
+ head_dim = dim // num_heads
415
+ if attn_head_dim is not None:
416
+ head_dim = attn_head_dim
417
+ all_head_dim = head_dim * self.num_heads
418
+ self.scale = head_dim ** -0.5
419
+
420
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
421
+ if qkv_bias:
422
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
423
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
424
+ else:
425
+ self.q_bias = None
426
+ self.v_bias = None
427
+
428
+ if window_size:
429
+ self.window_size = window_size
430
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
431
+ 2 * window_size[1] - 1
432
+ ) + 3
433
+ self.relative_position_bias_table = nn.Parameter(
434
+ torch.zeros(self.num_relative_distance, num_heads)
435
+ ) # 2*Wh-1 * 2*Ww-1, nH
436
+ # cls to token & token 2 cls & cls to cls
437
+
438
+ # get pair-wise relative position index for each token inside the window
439
+ coords_h = torch.arange(window_size[0])
440
+ coords_w = torch.arange(window_size[1])
441
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
442
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
443
+ relative_coords = (
444
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
445
+ ) # 2, Wh*Ww, Wh*Ww
446
+ relative_coords = relative_coords.permute(
447
+ 1, 2, 0
448
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
449
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
450
+ relative_coords[:, :, 1] += window_size[1] - 1
451
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
452
+ relative_position_index = torch.zeros(
453
+ size=(window_size[0] * window_size[1] + 1,) * 2,
454
+ dtype=relative_coords.dtype,
455
+ )
456
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
457
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
458
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
459
+ relative_position_index[0, 0] = self.num_relative_distance - 1
460
+
461
+ self.register_buffer("relative_position_index", relative_position_index)
462
+ else:
463
+ self.window_size = None
464
+ self.relative_position_bias_table = None
465
+ self.relative_position_index = None
466
+
467
+ self.attn_drop = nn.Dropout(attn_drop)
468
+ self.proj = nn.Linear(all_head_dim, dim)
469
+ self.proj_drop = nn.Dropout(proj_drop)
470
+
471
+ def forward(self, x, rel_pos_bias=None):
472
+ B, N, C = x.shape
473
+ qkv_bias = None
474
+ if self.q_bias is not None:
475
+ qkv_bias = torch.cat(
476
+ (
477
+ self.q_bias,
478
+ torch.zeros_like(self.v_bias, requires_grad=False),
479
+ self.v_bias,
480
+ )
481
+ )
482
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
483
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
484
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
485
+ q, k, v = (
486
+ qkv[0],
487
+ qkv[1],
488
+ qkv[2],
489
+ ) # make torchscript happy (cannot use tensor as tuple)
490
+
491
+ q = q * self.scale
492
+ attn = q @ k.transpose(-2, -1)
493
+
494
+ if self.relative_position_bias_table is not None:
495
+ assert 1==2
496
+ relative_position_bias = self.relative_position_bias_table[
497
+ self.relative_position_index.view(-1)
498
+ ].view(
499
+ self.window_size[0] * self.window_size[1] + 1,
500
+ self.window_size[0] * self.window_size[1] + 1,
501
+ -1,
502
+ ) # Wh*Ww,Wh*Ww,nH
503
+ relative_position_bias = relative_position_bias.permute(
504
+ 2, 0, 1
505
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
506
+ attn = attn + relative_position_bias.unsqueeze(0)
509
+ if rel_pos_bias is not None:
510
+ attn = attn + rel_pos_bias
511
+ attn = attn.softmax(dim=-1)
512
+ attn = self.attn_drop(attn)
513
+
514
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
515
+ x = self.proj(x)
516
+ x = self.proj_drop(x)
517
+ return x
518
+
519
+
520
+ class RelativePositionBias(nn.Module):
521
+ def __init__(self, window_size, num_heads):
522
+ super().__init__()
523
+ self.window_size = window_size
524
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
525
+ 2 * window_size[1] - 1
526
+ ) + 3
527
+ self.relative_position_bias_table = nn.Parameter(
528
+ torch.zeros(self.num_relative_distance, num_heads)
529
+ )
530
+
531
+ # get pair-wise relative position index for each token inside the window
532
+ coords_h = torch.arange(window_size[0])
533
+ coords_w = torch.arange(window_size[1])
534
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
535
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
536
+ relative_coords = (
537
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
538
+ ) # 2, Wh*Ww, Wh*Ww
539
+ relative_coords = relative_coords.permute(
540
+ 1, 2, 0
541
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
542
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
543
+ relative_coords[:, :, 1] += window_size[1] - 1
544
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
545
+ relative_position_index = torch.zeros(
546
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
547
+ )
548
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
549
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
550
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
551
+ relative_position_index[0, 0] = self.num_relative_distance - 1
552
+
553
+ self.register_buffer("relative_position_index", relative_position_index)
554
+
555
+ def forward(self):
556
+ relative_position_bias = self.relative_position_bias_table[
557
+ self.relative_position_index.view(-1)
558
+ ].view(
559
+ self.window_size[0] * self.window_size[1] + 1,
560
+ self.window_size[0] * self.window_size[1] + 1,
561
+ -1,
562
+ ) # Wh*Ww,Wh*Ww,nH
568
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
569
+
570
+
571
+ class DropPath(nn.Module):
572
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
573
+
574
+ def __init__(self, drop_prob=None):
575
+ super(DropPath, self).__init__()
576
+ self.drop_prob = drop_prob
577
+
578
+ def forward(self, x):
579
+ if self.drop_prob == 0.0 or not self.training:
580
+ return x
581
+ keep_prob = 1 - self.drop_prob
582
+ shape = (x.shape[0],) + (1,) * (
583
+ x.ndim - 1
584
+ ) # work with diff dim tensors, not just 2D ConvNets
585
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
586
+ random_tensor.floor_()
587
+ output = x.div(keep_prob) * random_tensor
588
+ return output
589
+
590
+ def extra_repr(self) -> str:
591
+ return "p={}".format(self.drop_prob)
592
+
593
+
594
+ class Block(nn.Module):
595
+ def __init__(
596
+ self,
597
+ dim,
598
+ num_heads,
599
+ mlp_ratio=4.0,
600
+ drop=0.0,
601
+ attn_drop=0.0,
602
+ drop_path=0.0,
603
+ init_values=None,
604
+ window_size=None,
605
+ ):
606
+ super().__init__()
607
+
608
+ self.norm1 = nn.LayerNorm(dim)
609
+ self.attn = Attention(
610
+ dim,
611
+ num_heads=num_heads,
612
+ attn_drop=attn_drop,
613
+ proj_drop=drop,
614
+ window_size=window_size,
615
+ )
616
+
617
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
618
+ self.norm2 = nn.LayerNorm(dim)
619
+ mlp_hidden_dim = int(dim * mlp_ratio)
620
+
621
+ self.mlp = nn.Sequential(
622
+ nn.Linear(dim, mlp_hidden_dim),
623
+ nn.GELU(),
624
+ nn.Linear(mlp_hidden_dim, dim),
625
+ nn.Dropout(drop),
626
+ )
627
+
628
+ if init_values > 0:
629
+ self.gamma_1 = nn.Parameter(
630
+ init_values * torch.ones((dim)), requires_grad=True
631
+ )
632
+ self.gamma_2 = nn.Parameter(
633
+ init_values * torch.ones((dim)), requires_grad=True
634
+ )
635
+ else:
636
+ self.gamma_1, self.gamma_2 = None, None
637
+
638
+ def forward(self, x, rel_pos_bias=None):
640
+ if self.gamma_1 is None:
641
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
642
+ fc_feature = self.drop_path(self.mlp(self.norm2(x)))
643
+ x = x + fc_feature
644
+ else:
645
+ x = x + self.drop_path(
646
+ self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
647
+ )
648
+ fc_feature = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
649
+ x = x + fc_feature
650
+ return x, fc_feature
651
+
652
+
653
+ class TransformerEncoder(nn.Module):
654
+ def __init__(self, cfg: Data2VecVisionConfig, patch_shape):
655
+ super().__init__()
656
+
657
+ self.rel_pos_bias = None
658
+ if cfg.shared_rel_pos_bias:
659
+ self.rel_pos_bias = RelativePositionBias(
660
+ window_size=patch_shape, num_heads=cfg.num_heads
661
+ )
662
+
663
+ dpr = [
664
+ x.item() for x in torch.linspace(0, cfg.drop_path, cfg.depth)
665
+ ] # stochastic depth decay rule
666
+
668
+ self.blocks = nn.ModuleList(
669
+ Block(
670
+ dim=cfg.embed_dim,
671
+ num_heads=cfg.num_heads,
672
+ attn_drop=cfg.attention_dropout,
673
+ drop_path=dpr[i],
674
+ init_values=cfg.layer_scale_init_value,
675
+ window_size=patch_shape if not cfg.shared_rel_pos_bias else None,
676
+ )
677
+ for i in range(cfg.depth)
678
+ )
679
+
680
+ self.norm = nn.LayerNorm(cfg.embed_dim)
681
+
682
+ self.apply(self.init_weights)
683
+ self.fix_init_weight()
684
+
685
+ def init_weights(self, m):
686
+ std = 0.02
687
+ if isinstance(m, nn.Linear):
688
+ nn.init.trunc_normal_(m.weight, std=std)
689
+ if isinstance(m, nn.Linear) and m.bias is not None:
690
+ nn.init.constant_(m.bias, 0)
691
+ elif isinstance(m, nn.LayerNorm):
692
+ nn.init.constant_(m.bias, 0)
693
+ nn.init.constant_(m.weight, 1.0)
694
+ elif isinstance(m, nn.Conv2d):
695
+ nn.init.trunc_normal_(m.weight, std=std)
696
+ if m.bias is not None:
697
+ nn.init.constant_(m.bias, 0)
698
+
699
+ def fix_init_weight(self):
700
+ def rescale(param, layer_id):
701
+ param.div_(math.sqrt(2.0 * layer_id))
702
+
703
+ for layer_id, layer in enumerate(self.blocks):
704
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
705
+ rescale(layer.mlp[2].weight.data, layer_id + 1)
706
+
707
+ def extract_features(self, x, layer_results):
708
+
709
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
710
+
711
+ z = []
712
+ for i, blk in enumerate(self.blocks):
713
+ x, fc_feature = blk(x, rel_pos_bias=rel_pos_bias)
714
+ if layer_results == "end":
715
+ z.append(x)
716
+ elif layer_results == "fc":
717
+ z.append(fc_feature)
718
+
719
+ return z if layer_results else self.norm(x)
720
+
721
+ def forward(self, x, layer_results=None):
722
+ x = self.extract_features(x, layer_results=layer_results)
723
+ if layer_results:
724
+ return [z[:, 1:] for z in x]
725
+
726
+ x = x[:, 1:]
727
+ return x
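A minimal sketch of how the block-wise patch mask produced by make_mask is consumed in Data2VecVisionModel.forward: masked patch embeddings are blended with a learned mask embedding while the CLS token at position 0 is left untouched. The shapes and the toy 7x7 masked block below are assumptions for illustration only.

```python
import torch

B, C = 2, 768
H = W = 14                                   # patch grid for 224px images with 16px patches
T = H * W
x = torch.randn(B, T + 1, C)                 # CLS token + patch tokens
mask = torch.zeros(B, H, W, dtype=torch.int64)
mask[:, :7, :7] = 1                          # pretend make_mask masked a 7x7 block
mask_emb = torch.zeros(1, 1, C)              # stands in for self.mask_emb

w = mask.view(B, -1, 1).float()              # B x T x 1, 1.0 where a patch is masked
x[:, 1:] = x[:, 1:] * (1 - w) + mask_emb.expand(B, T, -1) * w
print(x.shape, int(w.sum()))                 # torch.Size([2, 197, 768]) 98
```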
fairseq/examples/data2vec/models/mae.py ADDED
@@ -0,0 +1,829 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # The code in this file is adapted from the BeiT implementation which can be found here:
7
+ # https://github.com/microsoft/unilm/tree/master/beit
8
+
9
+ import logging
10
+ from dataclasses import dataclass
11
+ from functools import partial
12
+
13
+ from timm.models.vision_transformer import PatchEmbed, Block
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ import numpy as np
19
+
20
+ from fairseq.dataclass import FairseqDataclass
21
+ from fairseq.models import BaseFairseqModel, register_model
22
+ from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer
23
+
24
+ try:
25
+ from apex.normalization import FusedLayerNorm
26
+ except:
27
+ FusedLayerNorm = nn.LayerNorm
28
+
29
+ import torch.nn.functional as F
30
+
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ @dataclass
36
+ class MaeConfig(FairseqDataclass):
37
+ input_size: int = 224
38
+ in_chans: int = 3
39
+ patch_size: int = 16
40
+ embed_dim: int = 768
41
+ depth: int = 12
42
+ num_heads: int = 12
43
+ decoder_embed_dim: int = 512
44
+ decoder_depth: int = 8
45
+ decoder_num_heads: int = 16
46
+ mlp_ratio: int = 4
47
+ norm_eps: float = 1e-6
48
+
49
+ drop_path_rate: float = 0.0
50
+
51
+ mask_ratio: float = 0.75
52
+ norm_pix_loss: bool = True
53
+
54
+ w2v_block: bool = False
55
+ alt_block: bool = False
56
+ alt_block2: bool = False
57
+ alt_attention: bool = False
58
+ block_dropout: float = 0
59
+ attention_dropout: float = 0
60
+ activation_dropout: float = 0
61
+ layer_norm_first: bool = False
62
+
63
+ fused_ln: bool = True
64
+ end_of_block_targets: bool = True
65
+
66
+ no_decoder_embed: bool = False
67
+ no_decoder_pos_embed: bool = False
68
+ mask_noise_std: float = 0
69
+
70
+ single_qkv: bool = False
71
+ use_rel_pos_bias: bool = False
72
+ no_cls: bool = False
73
+
74
+
75
+ def modify_relative_position_bias(orig_bias, bsz, mask):
76
+ if mask is None:
77
+ return orig_bias.unsqueeze(0).repeat(
78
+ bsz, 1, 1, 1
79
+ ) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
80
+ heads, max_seq_len, max_seq_len = orig_bias.shape # includes CLS token
81
+ mask_for_rel_pos_bias = torch.cat(
82
+ (torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1
83
+ ).bool() # bsz x seqlen (add CLS token)
84
+ unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias
85
+ unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat(
86
+ 1, heads, 1
87
+ ) # bsz x seq_len => bsz x heads x seq_len
88
+ b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat(
89
+ bsz, 1, 1, 1
90
+ ) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
91
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
92
+ unmasked_for_rel_pos_bias.unsqueeze(-1)
93
+ )
94
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len)
95
+ new_len = b_t_t_rel_pos_bias.size(-2)
96
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
97
+ unmasked_for_rel_pos_bias.unsqueeze(-2)
98
+ )
99
+ b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len)
100
+ return b_t_t_rel_pos_bias
101
+
102
+
103
+ class AltBlock(nn.Module):
104
+ def __init__(
105
+ self,
106
+ dim,
107
+ num_heads,
108
+ mlp_ratio=4.0,
109
+ qkv_bias=False,
110
+ qk_scale=None,
111
+ drop=0.0,
112
+ attn_drop=0.0,
113
+ drop_path=0.0,
114
+ act_layer=nn.GELU,
115
+ norm_layer=nn.LayerNorm,
116
+ layer_norm_first=True,
117
+ ffn_targets=False,
118
+ use_rel_pos_bias=False,
119
+ window_size=None,
120
+ alt_attention=False,
121
+ ):
122
+ super().__init__()
123
+
124
+ self.layer_norm_first = layer_norm_first
125
+ self.ffn_targets = ffn_targets
126
+
127
+ from timm.models.vision_transformer import Attention, DropPath, Mlp
128
+
129
+ self.norm1 = norm_layer(dim)
130
+ self.use_rel_pos_bias = use_rel_pos_bias
131
+ if use_rel_pos_bias:
132
+ self.attn = AltAttention(
133
+ dim,
134
+ num_heads=num_heads,
135
+ qkv_bias=qkv_bias,
136
+ qk_scale=qk_scale,
137
+ attn_drop=attn_drop,
138
+ proj_drop=drop,
139
+ window_size=window_size,
140
+ )
141
+ else:
142
+ if alt_attention:
143
+ from .multi.modules import AltAttention as AltAttention2
144
+ self.attn = AltAttention2(
145
+ dim,
146
+ num_heads=num_heads,
147
+ qkv_bias=qkv_bias,
148
+ qk_scale=qk_scale,
149
+ attn_drop=attn_drop,
150
+ proj_drop=drop,
151
+ )
152
+ else:
153
+ self.attn = Attention(
154
+ dim,
155
+ num_heads=num_heads,
156
+ qkv_bias=qkv_bias,
157
+ qk_scale=qk_scale,
158
+ attn_drop=attn_drop,
159
+ proj_drop=drop,
160
+ )
161
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
162
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
163
+ self.norm2 = norm_layer(dim)
164
+ mlp_hidden_dim = int(dim * mlp_ratio)
165
+ self.mlp = Mlp(
166
+ in_features=dim,
167
+ hidden_features=mlp_hidden_dim,
168
+ act_layer=act_layer,
169
+ drop=drop,
170
+ )
171
+
172
+ def forward(self, x, rel_pos_bias=None, pos_mask=None):
173
+ if self.layer_norm_first:
174
+ if self.use_rel_pos_bias:
175
+ x = x + self.drop_path(
176
+ self.attn(
177
+ self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask
178
+ )
179
+ )
180
+ else:
181
+ x = x + self.drop_path(self.attn(self.norm1(x)))
182
+ t = self.mlp(self.norm2(x))
183
+ x = x + self.drop_path(t)
184
+ if not self.ffn_targets:
185
+ t = x
186
+ return x, t
187
+ else:
188
+ if self.use_rel_pos_bias:
189
+ x = x + self.drop_path(
190
+ self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask)
191
+ )
192
+ else:
193
+ x = x + self.drop_path(self.attn(x))
194
+ r = x = self.norm1(x)
195
+ x = self.mlp(x)
196
+ t = x
197
+ x = self.norm2(r + self.drop_path(x))
198
+ if not self.ffn_targets:
199
+ t = x
200
+ return x, t
201
+
202
+
203
+ class AltAttention(nn.Module):
204
+ def __init__(
205
+ self,
206
+ dim,
207
+ num_heads=8,
208
+ qkv_bias=True,
209
+ qk_scale=None,
210
+ attn_drop=0.0,
211
+ proj_drop=0.0,
212
+ window_size=None,
213
+ attn_head_dim=None,
214
+ ):
215
+ super().__init__()
216
+ self.num_heads = num_heads
217
+ head_dim = dim // num_heads
218
+ if attn_head_dim is not None:
219
+ head_dim = attn_head_dim
220
+ all_head_dim = head_dim * self.num_heads
221
+ self.scale = qk_scale or head_dim ** -0.5
222
+
223
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
224
+ if qkv_bias:
225
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
226
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
227
+ else:
228
+ self.q_bias = None
229
+ self.v_bias = None
230
+
231
+ if window_size:
232
+ self.window_size = window_size
233
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
234
+ 2 * window_size[1] - 1
235
+ ) + 3
236
+ self.relative_position_bias_table = nn.Parameter(
237
+ torch.zeros(self.num_relative_distance, num_heads)
238
+ ) # 2*Wh-1 * 2*Ww-1, nH
239
+ # cls to token & token 2 cls & cls to cls
240
+
241
+ # get pair-wise relative position index for each token inside the window
242
+ coords_h = torch.arange(window_size[0])
243
+ coords_w = torch.arange(window_size[1])
244
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
245
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
246
+ relative_coords = (
247
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
248
+ ) # 2, Wh*Ww, Wh*Ww
249
+ relative_coords = relative_coords.permute(
250
+ 1, 2, 0
251
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
252
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
253
+ relative_coords[:, :, 1] += window_size[1] - 1
254
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
255
+ relative_position_index = torch.zeros(
256
+ size=(window_size[0] * window_size[1] + 1,) * 2,
257
+ dtype=relative_coords.dtype,
258
+ )
259
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
260
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
261
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
262
+ relative_position_index[0, 0] = self.num_relative_distance - 1
263
+
264
+ self.register_buffer("relative_position_index", relative_position_index)
265
+ else:
266
+ self.window_size = None
267
+ self.relative_position_bias_table = None
268
+ self.relative_position_index = None
269
+
270
+ self.attn_drop = nn.Dropout(attn_drop)
271
+ self.proj = nn.Linear(all_head_dim, dim)
272
+ self.proj_drop = nn.Dropout(proj_drop)
273
+
274
+ def forward(self, x, rel_pos_bias=None, pos_mask=None):
275
+ B, N, C = x.shape
276
+ qkv_bias = None
277
+ if self.q_bias is not None:
278
+ qkv_bias = torch.cat(
279
+ (
280
+ self.q_bias,
281
+ torch.zeros_like(self.v_bias, requires_grad=False),
282
+ self.v_bias,
283
+ )
284
+ )
285
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
286
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
287
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
288
+ q, k, v = (
289
+ qkv[0],
290
+ qkv[1],
291
+ qkv[2],
292
+ ) # make torchscript happy (cannot use tensor as tuple)
293
+
294
+ q = q * self.scale
295
+ attn = q @ k.transpose(-2, -1)
296
+
297
+ if self.relative_position_bias_table is not None:
298
+ relative_position_bias = self.relative_position_bias_table[
299
+ self.relative_position_index.view(-1)
300
+ ].view(
301
+ self.window_size[0] * self.window_size[1] + 1,
302
+ self.window_size[0] * self.window_size[1] + 1,
303
+ -1,
304
+ ) # Wh*Ww,Wh*Ww,nH
305
+ relative_position_bias = relative_position_bias.permute(
306
+ 2, 0, 1
307
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
308
+ attn = attn + modify_relative_position_bias(
309
+ relative_position_bias, x.size(0), pos_mask
310
+ )
311
+
312
+ if rel_pos_bias is not None:
313
+ attn = attn + rel_pos_bias
314
+
315
+ attn = attn.softmax(dim=-1)
316
+ attn = self.attn_drop(attn)
317
+
318
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
319
+ x = self.proj(x)
320
+ x = self.proj_drop(x)
321
+ return x
322
+
323
+
324
+ class RelativePositionBias(nn.Module):
325
+ def __init__(self, window_size, num_heads):
326
+ super().__init__()
327
+ self.window_size = window_size
328
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
329
+ 2 * window_size[1] - 1
330
+ ) + 3
331
+ self.relative_position_bias_table = nn.Parameter(
332
+ torch.zeros(self.num_relative_distance, num_heads)
333
+ )
334
+
335
+ # get pair-wise relative position index for each token inside the window
336
+ coords_h = torch.arange(window_size[0])
337
+ coords_w = torch.arange(window_size[1])
338
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
339
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
340
+ relative_coords = (
341
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
342
+ ) # 2, Wh*Ww, Wh*Ww
343
+ relative_coords = relative_coords.permute(
344
+ 1, 2, 0
345
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
346
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
347
+ relative_coords[:, :, 1] += window_size[1] - 1
348
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
349
+ relative_position_index = torch.zeros(
350
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
351
+ )
352
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
353
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
354
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
355
+ relative_position_index[0, 0] = self.num_relative_distance - 1
356
+
357
+ self.register_buffer("relative_position_index", relative_position_index)
358
+
359
+ def forward(self):
360
+ relative_position_bias = self.relative_position_bias_table[
361
+ self.relative_position_index.view(-1)
362
+ ].view(
363
+ self.window_size[0] * self.window_size[1] + 1,
364
+ self.window_size[0] * self.window_size[1] + 1,
365
+ -1,
366
+ ) # Wh*Ww,Wh*Ww,nH
367
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
368
+
369
+
370
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
371
+ """
372
+ grid_size: int of the grid height and width
373
+ return:
374
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
375
+ """
376
+ grid_h = np.arange(grid_size, dtype=np.float32)
377
+ grid_w = np.arange(grid_size, dtype=np.float32)
378
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
379
+ grid = np.stack(grid, axis=0)
380
+
381
+ grid = grid.reshape([2, 1, grid_size, grid_size])
382
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
383
+ if cls_token:
384
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
385
+ return pos_embed
386
+
387
+
388
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
389
+ assert embed_dim % 2 == 0
390
+
391
+ # use half of dimensions to encode grid_h
392
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
393
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
394
+
395
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
396
+ return emb
397
+
398
+
399
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
400
+ """
401
+ embed_dim: output dimension for each position
402
+ pos: a list of positions to be encoded: size (M,)
403
+ out: (M, D)
404
+ """
405
+ assert embed_dim % 2 == 0
406
+ omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float alias removed in newer NumPy
407
+ omega /= embed_dim / 2.0
408
+ omega = 1.0 / 10000 ** omega # (D/2,)
409
+
410
+ pos = pos.reshape(-1) # (M,)
411
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
412
+
413
+ emb_sin = np.sin(out) # (M, D/2)
414
+ emb_cos = np.cos(out) # (M, D/2)
415
+
416
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
417
+ return emb
418
+
419
+
420
+ def interpolate_pos_embed(model, checkpoint_model):
421
+ if "pos_embed" in checkpoint_model:
422
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
423
+ embedding_size = pos_embed_checkpoint.shape[-1]
424
+ num_patches = model.patch_embed.num_patches
425
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
426
+ # height (== width) for the checkpoint position embedding
427
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
428
+ # height (== width) for the new position embedding
429
+ new_size = int(num_patches ** 0.5)
430
+ # class_token and dist_token are kept unchanged
431
+ if orig_size != new_size:
432
+ print(
433
+ "Position interpolate from %dx%d to %dx%d"
434
+ % (orig_size, orig_size, new_size, new_size)
435
+ )
436
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
437
+ # only the position tokens are interpolated
438
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
439
+ pos_tokens = pos_tokens.reshape(
440
+ -1, orig_size, orig_size, embedding_size
441
+ ).permute(0, 3, 1, 2)
442
+ pos_tokens = torch.nn.functional.interpolate(
443
+ pos_tokens,
444
+ size=(new_size, new_size),
445
+ mode="bicubic",
446
+ align_corners=False,
447
+ )
448
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
449
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
450
+ checkpoint_model["pos_embed"] = new_pos_embed
451
+
452
+
453
+ @register_model("mae", dataclass=MaeConfig)
454
+ class MaeModel(BaseFairseqModel):
455
+ def __init__(self, cfg: MaeConfig):
456
+ super().__init__()
457
+ self.cfg = cfg
458
+
459
+ self.mask_ratio = cfg.mask_ratio
460
+
461
+ # --------------------------------------------------------------------------
462
+ # MAE encoder specifics
463
+ self.patch_embed = PatchEmbed(
464
+ cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
465
+ )
466
+ num_patches = self.patch_embed.num_patches
467
+
468
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
469
+ self.pos_embed = nn.Parameter(
470
+ torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim), requires_grad=False
471
+ ) # fixed sin-cos embedding
472
+
473
+ norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)
474
+
475
+ dpr = [
476
+ x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)
477
+ ] # stochastic depth decay rule
478
+
479
+ def make_block(drop_path):
480
+ if cfg.w2v_block:
481
+ return TransformerSentenceEncoderLayer(
482
+ embedding_dim=cfg.embed_dim,
483
+ ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio,
484
+ num_attention_heads=cfg.num_heads,
485
+ dropout=cfg.block_dropout,
486
+ attention_dropout=cfg.attention_dropout,
487
+ activation_dropout=cfg.activation_dropout,
488
+ activation_fn="gelu",
489
+ layer_norm_first=cfg.layer_norm_first,
490
+ drop_path=drop_path,
491
+ norm_eps=1e-6,
492
+ single_qkv=cfg.single_qkv,
493
+ fused_ln=cfg.fused_ln,
494
+ )
495
+ elif cfg.alt_block:
496
+ window_size = (
497
+ cfg.input_size // self.patch_embed.patch_size[0],
498
+ cfg.input_size // self.patch_embed.patch_size[1],
499
+ )
500
+ return AltBlock(
501
+ cfg.embed_dim,
502
+ cfg.num_heads,
503
+ cfg.mlp_ratio,
504
+ qkv_bias=True,
505
+ qk_scale=None,
506
+ norm_layer=norm_layer,
507
+ drop_path=drop_path,
508
+ layer_norm_first=cfg.layer_norm_first,
509
+ ffn_targets=not cfg.end_of_block_targets,
510
+ use_rel_pos_bias=cfg.use_rel_pos_bias,
511
+ window_size=window_size
512
+ if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias)
513
+ else None,
514
+ alt_attention=cfg.alt_attention,
515
+ )
516
+ elif cfg.alt_block2:
517
+ from .multi.modules import AltBlock as AltBlock2
518
+ return AltBlock2(
519
+ cfg.embed_dim,
520
+ cfg.num_heads,
521
+ cfg.mlp_ratio,
522
+ qkv_bias=True,
523
+ qk_scale=None,
524
+ norm_layer=norm_layer,
525
+ drop_path=drop_path,
526
+ layer_norm_first=cfg.layer_norm_first,
527
+ ffn_targets=not cfg.end_of_block_targets,
528
+ )
529
+ else:
530
+ return Block(
531
+ cfg.embed_dim,
532
+ cfg.num_heads,
533
+ cfg.mlp_ratio,
534
+ qkv_bias=True,
535
+ qk_scale=None,
536
+ norm_layer=norm_layer,
537
+ drop_path=drop_path,
538
+ )
539
+
540
+ self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
541
+ self.norm = norm_layer(cfg.embed_dim)
542
+ # --------------------------------------------------------------------------
543
+
544
+ # --------------------------------------------------------------------------
545
+ # MAE decoder specifics
546
+ self.decoder_embed = (
547
+ nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True)
548
+ if not cfg.no_decoder_embed
549
+ else None
550
+ )
551
+
552
+ self.mask_token = (
553
+ nn.Parameter(
554
+ torch.zeros(
555
+ 1,
556
+ 1,
557
+ cfg.decoder_embed_dim
558
+ if not cfg.no_decoder_embed
559
+ else cfg.embed_dim,
560
+ )
561
+ )
562
+ if cfg.mask_noise_std <= 0
563
+ else None
564
+ )
565
+
566
+ self.decoder_pos_embed = (
567
+ nn.Parameter(
568
+ torch.zeros(
569
+ 1,
570
+ num_patches + 1,
571
+ cfg.decoder_embed_dim
572
+ if not cfg.no_decoder_embed
573
+ else cfg.embed_dim,
574
+ ),
575
+ requires_grad=False,
576
+ )
577
+ if not cfg.no_decoder_pos_embed
578
+ else None
579
+ )
580
+
581
+ self.decoder_blocks = nn.ModuleList(
582
+ [
583
+ Block(
584
+ cfg.decoder_embed_dim,
585
+ cfg.decoder_num_heads,
586
+ cfg.mlp_ratio,
587
+ qkv_bias=True,
588
+ qk_scale=None,
589
+ norm_layer=norm_layer,
590
+ )
591
+ for _ in range(cfg.decoder_depth)
592
+ ]
593
+ )
594
+
595
+ self.decoder_norm = norm_layer(cfg.decoder_embed_dim)
596
+ self.decoder_pred = nn.Linear(
597
+ cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True
598
+ ) # decoder to patch
599
+ # --------------------------------------------------------------------------
600
+
601
+ self.norm_pix_loss = cfg.norm_pix_loss
602
+
603
+ self.initialize_weights()
604
+
605
+ for pn, p in self.named_parameters():
606
+ if len(p.shape) == 1 or pn.endswith(".bias"):
607
+ p.param_group = "no_decay"
608
+ else:
609
+ p.param_group = "with_decay"
610
+
611
+ def initialize_weights(self):
612
+ # initialization
613
+ # initialize (and freeze) pos_embed by sin-cos embedding
614
+ pos_embed = get_2d_sincos_pos_embed(
615
+ self.pos_embed.shape[-1],
616
+ int(self.patch_embed.num_patches ** 0.5),
617
+ cls_token=not self.cfg.no_cls,
618
+ )
619
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
620
+
621
+ if self.decoder_pos_embed is not None:
622
+ decoder_pos_embed = get_2d_sincos_pos_embed(
623
+ self.decoder_pos_embed.shape[-1],
624
+ int(self.patch_embed.num_patches ** 0.5),
625
+ cls_token=not self.cfg.no_cls,
626
+ )
627
+ self.decoder_pos_embed.data.copy_(
628
+ torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
629
+ )
630
+
631
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
632
+ w = self.patch_embed.proj.weight.data
633
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
634
+
635
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
636
+ if self.cls_token is not None:
637
+ torch.nn.init.normal_(self.cls_token, std=0.02)
638
+
639
+ if self.mask_token is not None:
640
+ torch.nn.init.normal_(self.mask_token, std=0.02)
641
+
642
+ # initialize nn.Linear and nn.LayerNorm
643
+ self.apply(self._init_weights)
644
+
645
+ def _init_weights(self, m):
646
+ if isinstance(m, nn.Linear):
647
+ # we use xavier_uniform following official JAX ViT:
648
+ torch.nn.init.xavier_uniform_(m.weight)
649
+ if isinstance(m, nn.Linear) and m.bias is not None:
650
+ nn.init.constant_(m.bias, 0)
651
+ elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
652
+ nn.init.constant_(m.bias, 0)
653
+ nn.init.constant_(m.weight, 1.0)
654
+
655
+ def patchify(self, imgs):
656
+ """
657
+ imgs: (N, 3, H, W)
658
+ x: (N, L, patch_size**2 *3)
659
+ """
660
+ p = self.patch_embed.patch_size[0]
661
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
662
+
663
+ h = w = imgs.shape[2] // p
664
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
665
+ x = torch.einsum("nchpwq->nhwpqc", x)
666
+ x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
667
+ return x
668
+
669
+ def unpatchify(self, x):
670
+ """
671
+ x: (N, L, patch_size**2 *3)
672
+ imgs: (N, 3, H, W)
673
+ """
674
+ p = self.patch_embed.patch_size[0]
675
+ h = w = int(x.shape[1] ** 0.5)
676
+ assert h * w == x.shape[1]
677
+
678
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
679
+ x = torch.einsum("nhwpqc->nchpwq", x)
680
+ imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
681
+ return imgs
682
+
683
+ def random_masking(self, x, mask_ratio):
684
+ """
685
+ Perform per-sample random masking by per-sample shuffling.
686
+ Per-sample shuffling is done by argsort random noise.
687
+ x: [N, L, D], sequence
688
+ """
689
+ N, L, D = x.shape # batch, length, dim
690
+ len_keep = int(L * (1 - mask_ratio))
691
+
692
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
693
+
694
+ # sort noise for each sample
695
+ ids_shuffle = torch.argsort(
696
+ noise, dim=1
697
+ ) # ascend: small is keep, large is remove
698
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
699
+
700
+ # keep the first subset
701
+ ids_keep = ids_shuffle[:, :len_keep]
702
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
703
+
704
+ # generate the binary mask: 0 is keep, 1 is remove
705
+ mask = torch.ones([N, L], device=x.device)
706
+ mask[:, :len_keep] = 0
707
+ # unshuffle to get the binary mask
708
+ mask = torch.gather(mask, dim=1, index=ids_restore)
709
+
710
+ return x_masked, mask, ids_restore # x_masked is actually unmasked x
711
+
712
+ @classmethod
713
+ def build_model(cls, cfg: MaeConfig, task=None):
714
+ """Build a new model instance."""
715
+
716
+ return cls(cfg)
717
+
718
+ def forward_encoder(self, x, mask_ratio):
719
+ # embed patches
720
+ x = self.patch_embed(x)
721
+
722
+ # add pos embed w/o cls token
723
+ # if self.cls_token is not None:
724
+ # x = x + self.pos_embed
725
+ # else:
726
+ x = x + self.pos_embed[:, 1:, :]
727
+
728
+ # masking: length -> length * mask_ratio
729
+ if mask_ratio > 0:
730
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
731
+ else:
732
+ mask = ids_restore = None
733
+
734
+ # append cls token
735
+ if self.cls_token is not None:
736
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
737
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
738
+ x = torch.cat((cls_tokens, x), dim=1)
739
+
740
+ # apply Transformer blocks
741
+ for blk in self.blocks:
742
+ x = blk(x)
743
+
744
+ if self.norm is not None:
745
+ x = self.norm(x)
746
+
747
+ return x, mask, ids_restore
748
+
749
+ def forward_decoder(self, x, ids_restore):
750
+ # embed tokens
751
+ x = self.decoder_embed(x)
752
+
753
+ # append mask tokens to sequence
754
+ mask_tokens = self.mask_token.repeat(
755
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
756
+ )
757
+ if self.cls_token is not None:
758
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
759
+ else:
760
+ x_ = torch.cat([x, mask_tokens], dim=1) # no cls token
761
+
762
+ x_ = torch.gather(
763
+ x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
764
+ ) # unshuffle
765
+
766
+ if self.cls_token is not None:
767
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
768
+
769
+ # add pos embed
770
+ x = x + self.decoder_pos_embed
771
+
772
+ # apply Transformer blocks
773
+ for blk in self.decoder_blocks:
774
+ x = blk(x)
775
+ x = self.decoder_norm(x)
776
+
777
+ # predictor projection
778
+ x = self.decoder_pred(x)
779
+
780
+ if self.cls_token is not None:
781
+ # remove cls token
782
+ x = x[:, 1:, :]
783
+
784
+ return x
785
+
786
+ def forward_loss(self, imgs, pred, mask):
787
+ """
788
+ imgs: [N, 3, H, W]
789
+ pred: [N, L, p*p*3]
790
+ mask: [N, L], 0 is keep, 1 is remove,
791
+ """
792
+ target = self.patchify(imgs)
793
+ if self.norm_pix_loss:
794
+ mean = target.mean(dim=-1, keepdim=True)
795
+ var = target.var(dim=-1, keepdim=True)
796
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
797
+
798
+ loss = (pred - target) ** 2
799
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
800
+
801
+ loss = (loss * mask).sum()
802
+ return loss, mask.sum()
803
+
804
+ def forward(self, imgs, predictions_only=False):
805
+ latent, mask, ids_restore = self.forward_encoder(
806
+ imgs, self.mask_ratio if not predictions_only else 0
807
+ )
808
+
809
+ if predictions_only:
810
+ return latent
811
+
812
+ pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
813
+ loss, sample_size = self.forward_loss(imgs, pred, mask)
814
+
815
+ result = {
816
+ "losses": {"regression": loss},
817
+ "sample_size": sample_size,
818
+ }
819
+ return result
820
+
821
+ def remove_pretraining_modules(self):
822
+ self.decoder_embed = None
823
+ self.decoder_blocks = None
824
+ self.decoder_norm = None
825
+ self.decoder_pos_embed = None
826
+ self.decoder_pred = None
827
+ self.mask_token = None
828
+ if self.cfg.layer_norm_first:
829
+ self.norm = None
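
The sin-cos positional embedding helpers above (get_2d_sincos_pos_embed and friends) are plain NumPy and can be exercised on their own. A minimal shape check, assuming the module is importable as examples.data2vec.models.mae in your checkout (the import path is an assumption, not part of this commit):

import numpy as np

from examples.data2vec.models.mae import get_2d_sincos_pos_embed

embed_dim, grid_size = 768, 14  # 224x224 input with 16x16 patches -> 14x14 grid
pos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)

# one all-zero row for the CLS token plus one row per patch
assert pos.shape == (1 + grid_size * grid_size, embed_dim)
assert np.allclose(pos[0], 0.0)
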
fairseq/examples/data2vec/models/mae_image_classification.py ADDED
@@ -0,0 +1,386 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # The code in this file is adapted from the BeiT implementation which can be found here:
7
+ # https://github.com/microsoft/unilm/tree/master/beit
8
+
9
+ import logging
10
+
11
+ from dataclasses import dataclass
12
+ from enum import Enum, auto
13
+ from typing import Any, Optional
14
+
15
+ import numpy as np
16
+ from omegaconf import II, MISSING
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from fairseq import checkpoint_utils, tasks
23
+ from omegaconf import open_dict
24
+
25
+ from fairseq.dataclass import FairseqDataclass
26
+ from fairseq.models import BaseFairseqModel, register_model
27
+ from .mae import interpolate_pos_embed
28
+
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ class PredictionMode(Enum):
34
+ MEAN_POOLING = auto()
35
+ CLS_TOKEN = auto()
36
+ LIN_SOFTMAX = auto()
37
+
38
+
39
+ @dataclass
40
+ class MaeImageClassificationConfig(FairseqDataclass):
41
+ model_path: str = MISSING
42
+ no_pretrained_weights: bool = False
43
+ linear_classifier: bool = False
44
+ num_classes: int = 1000
45
+ mixup: float = 0.8
46
+ cutmix: float = 1.0
47
+ label_smoothing: float = 0.1
48
+
49
+ drop_path_rate: float = 0.1
50
+ layer_decay: float = 0.65
51
+
52
+ mixup_prob: float = 1.0
53
+ mixup_switch_prob: float = 0.5
54
+ mixup_mode: str = "batch"
55
+
56
+ pretrained_model_args: Any = None
57
+ data: str = II("task.data")
58
+
59
+ norm_eps: Optional[float] = None
60
+
61
+ remove_alibi: bool = False
62
+
63
+ # regularization overwrites
64
+ encoder_dropout: float = 0
65
+ post_mlp_drop: float = 0
66
+ attention_dropout: float = 0
67
+ activation_dropout: float = 0.0
68
+ dropout_input: float = 0.0
69
+ layerdrop: float = 0.0
70
+
71
+ prenet_layerdrop: float = 0
72
+ prenet_dropout: float = 0
73
+
74
+ use_fc_norm: bool = True
75
+ prediction_mode: PredictionMode = PredictionMode.MEAN_POOLING
76
+
77
+ no_decay_blocks: bool = True
78
+
79
+
80
+ def get_layer_id_for_vit(name, num_layers):
81
+ """
82
+ Assign a parameter with its layer id
83
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
84
+ """
85
+ if name in ["cls_token", "pos_embed"]:
86
+ return 0
87
+ elif name.startswith("patch_embed"):
88
+ return 0
89
+ elif name.startswith("rel_pos_bias"):
90
+ return num_layers - 1
91
+ elif name.startswith("blocks"):
92
+ return int(name.split(".")[1]) + 1
93
+ else:
94
+ return num_layers
95
+
96
+
97
+ @register_model("mae_image_classification", dataclass=MaeImageClassificationConfig)
98
+ class MaeImageClassificationModel(BaseFairseqModel):
99
+ def __init__(self, cfg: MaeImageClassificationConfig):
100
+ super().__init__()
101
+ self.cfg = cfg
102
+
103
+ if cfg.pretrained_model_args is None:
104
+ state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
105
+ pretrained_args = state.get("cfg", None)
106
+
107
+ pretrained_args.criterion = None
108
+ pretrained_args.lr_scheduler = None
109
+
110
+ logger.info(pretrained_args.model)
111
+
112
+ with open_dict(pretrained_args.model):
113
+ pretrained_args.model.drop_path_rate = cfg.drop_path_rate
114
+ if cfg.norm_eps is not None:
115
+ pretrained_args.model.norm_eps = cfg.norm_eps
116
+
117
+ cfg.pretrained_model_args = pretrained_args
118
+
119
+ logger.info(pretrained_args)
120
+ else:
121
+ state = None
122
+ pretrained_args = cfg.pretrained_model_args
123
+
124
+ if "data" in pretrained_args.task:
125
+ pretrained_args.task.data = cfg.data
126
+ elif "image" in pretrained_args.task:
127
+ pretrained_args.task.image.data = cfg.data
128
+
129
+ if "modalities" in pretrained_args.model:
130
+ prenet_blocks = pretrained_args.model["modalities"]["image"]["prenet_depth"]
131
+ model_blocks = pretrained_args.model["depth"]
132
+ with open_dict(pretrained_args):
133
+ dpr = np.linspace(0, cfg.drop_path_rate, model_blocks).tolist()
134
+ pretrained_args.model["modalities"]["image"][
135
+ "start_drop_path_rate"
136
+ ] = dpr[0]
137
+ pretrained_args.model["modalities"]["image"][
138
+ "end_drop_path_rate"
139
+ ] = max(0, dpr[prenet_blocks - 1])
140
+ pretrained_args.model["start_drop_path_rate"] = dpr[prenet_blocks]
141
+ pretrained_args.model["end_drop_path_rate"] = dpr[-1]
142
+
143
+ if "mae_masking" in pretrained_args.model["modalities"]["image"]:
144
+ del pretrained_args.model["modalities"]["image"]["mae_masking"]
145
+
146
+ if cfg.remove_alibi:
147
+ pretrained_args.model["modalities"]["image"][
148
+ "use_alibi_encoder"
149
+ ] = False
150
+ if (
151
+ state is not None
152
+ and "modality_encoders.IMAGE.alibi_bias" in state["model"]
153
+ ):
154
+ del state["model"]["modality_encoders.IMAGE.alibi_bias"]
155
+
156
+ pretrained_args.model["encoder_dropout"] = cfg.encoder_dropout
157
+ pretrained_args.model["post_mlp_drop"] = cfg.post_mlp_drop
158
+ pretrained_args.model["attention_dropout"] = cfg.attention_dropout
159
+ pretrained_args.model["activation_dropout"] = cfg.activation_dropout
160
+ pretrained_args.model["dropout_input"] = cfg.dropout_input
161
+ pretrained_args.model["layerdrop"] = cfg.layerdrop
162
+
163
+ pretrained_args.model["modalities"]["image"][
164
+ "prenet_layerdrop"
165
+ ] = cfg.prenet_layerdrop
166
+ pretrained_args.model["modalities"]["image"][
167
+ "prenet_dropout"
168
+ ] = cfg.prenet_dropout
169
+ else:
170
+ # not d2v multi
171
+ with open_dict(pretrained_args):
172
+ pretrained_args.model["drop_path_rate"] = cfg.drop_path_rate
173
+ pretrained_args.model["block_dropout"] = cfg.encoder_dropout
174
+ pretrained_args.model["attention_dropout"] = cfg.attention_dropout
175
+ pretrained_args.model["activation_dropout"] = cfg.activation_dropout
176
+
177
+ task = tasks.setup_task(pretrained_args.task)
178
+ model = task.build_model(pretrained_args.model, from_checkpoint=True)
179
+
180
+ self.d2v_multi = "data2vec_multi" in pretrained_args.model._name
181
+ self.linear_classifier = cfg.linear_classifier
182
+
183
+ self.model = model
184
+
185
+ if state is not None and not cfg.no_pretrained_weights:
186
+ interpolate_pos_embed(model, state)
187
+
188
+ if "modality_encoders.IMAGE.positional_encoder.pos_embed" in state["model"]:
189
+ state["model"][
190
+ "modality_encoders.IMAGE.positional_encoder.positions"
191
+ ] = state["model"][
192
+ "modality_encoders.IMAGE.positional_encoder.pos_embed"
193
+ ]
194
+ del state["model"][
195
+ "modality_encoders.IMAGE.positional_encoder.pos_embed"
196
+ ]
197
+ if "modality_encoders.IMAGE.encoder_mask" in state["model"]:
198
+ del state["model"]["modality_encoders.IMAGE.encoder_mask"]
199
+
200
+ model.load_state_dict(state["model"], strict=True)
201
+
202
+ if self.d2v_multi:
203
+ model.remove_pretraining_modules(modality="image")
204
+ else:
205
+ model.remove_pretraining_modules()
206
+
207
+ if self.linear_classifier:
208
+ model.requires_grad_(False)
209
+
210
+ self.fc_norm = None
211
+ if self.cfg.use_fc_norm:
212
+ self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim, eps=1e-6)
213
+ nn.init.constant_(self.fc_norm.bias, 0)
214
+ nn.init.constant_(self.fc_norm.weight, 1.0)
215
+
216
+ self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)
217
+
218
+ nn.init.trunc_normal_(self.head.weight, std=0.02)
219
+ nn.init.constant_(self.head.bias, 0)
220
+
221
+ self.mixup_fn = None
222
+
223
+ if cfg.mixup > 0 or cfg.cutmix > 0:
224
+ from timm.data import Mixup
225
+
226
+ self.mixup_fn = Mixup(
227
+ mixup_alpha=cfg.mixup,
228
+ cutmix_alpha=cfg.cutmix,
229
+ cutmix_minmax=None,
230
+ prob=cfg.mixup_prob,
231
+ switch_prob=cfg.mixup_switch_prob,
232
+ mode=cfg.mixup_mode,
233
+ label_smoothing=cfg.label_smoothing,
234
+ num_classes=cfg.num_classes,
235
+ )
236
+
237
+ if self.model.norm is not None:
238
+ for pn, p in self.model.norm.named_parameters():
239
+ if len(p.shape) == 1 or pn.endswith(".bias"):
240
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
241
+
242
+ if self.fc_norm is not None:
243
+ for pn, p in self.fc_norm.named_parameters():
244
+ if len(p.shape) == 1 or pn.endswith(".bias"):
245
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
246
+
247
+ for pn, p in self.head.named_parameters():
248
+ if len(p.shape) == 1 or pn.endswith(".bias"):
249
+ p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
250
+
251
+ if self.d2v_multi:
252
+ mod_encs = list(model.modality_encoders.values())
253
+ assert len(mod_encs) == 1, len(mod_encs)
254
+ blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks)
255
+ else:
256
+ blocks = model.blocks
257
+
258
+ num_layers = len(blocks) + 1
259
+ layer_scales = list(
260
+ cfg.layer_decay ** (num_layers - i) for i in range(num_layers + 1)
261
+ )
262
+
263
+ if self.d2v_multi:
264
+ for n, p in self.model.named_parameters():
265
+ optimizer_override_dict = {}
266
+
267
+ if len(p.shape) == 1 or n.endswith(".bias"):
268
+ optimizer_override_dict["weight_decay_scale"] = 0
269
+
270
+ p.optim_overrides = {"optimizer": optimizer_override_dict}
271
+
272
+ if cfg.layer_decay > 0:
273
+ for i, b in enumerate(blocks):
274
+ lid = i + 1
275
+ if layer_scales[lid] == 1.0:
276
+ continue
277
+
278
+ for n, p in b.named_parameters():
279
+ optim_override = getattr(p, "optim_overrides", {})
280
+ if "optimizer" not in optim_override:
281
+ optim_override["optimizer"] = {}
282
+
283
+ if cfg.no_decay_blocks:
284
+ optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
285
+ p.optim_overrides = optim_override
286
+ else:
287
+ optim_override["optimizer"] = {
288
+ "lr_scale": layer_scales[lid]
289
+ }
290
+ p.optim_overrides = optim_override
291
+
292
+ else:
293
+ for n, p in self.model.named_parameters():
294
+ optimizer_override_dict = {}
295
+ layer_id = get_layer_id_for_vit(n, num_layers)
296
+
297
+ if len(p.shape) == 1 or n.endswith(".bias"):
298
+ optimizer_override_dict["weight_decay_scale"] = 0
299
+
300
+ if cfg.layer_decay > 0:
301
+ optimizer_override_dict["lr_scale"] = layer_scales[layer_id]
302
+ p.optim_overrides = {"optimizer": optimizer_override_dict}
303
+
304
+ @classmethod
305
+ def build_model(cls, cfg: MaeImageClassificationConfig, task=None):
306
+ """Build a new model instance."""
307
+
308
+ return cls(cfg)
309
+
310
+ def forward(
311
+ self,
312
+ imgs,
313
+ labels=None,
314
+ ):
315
+ if self.training and self.mixup_fn is not None and labels is not None:
316
+ imgs, labels = self.mixup_fn(imgs, labels)
317
+
318
+ if self.linear_classifier:
319
+ with torch.no_grad():
320
+ x = self.model_forward(imgs)
321
+ else:
322
+ x = self.model_forward(imgs)
323
+
324
+ if self.cfg.prediction_mode == PredictionMode.MEAN_POOLING:
325
+ x = x.mean(dim=1)
326
+ elif self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
327
+ x = x[:, 0]
328
+ elif self.cfg.prediction_mode == PredictionMode.LIN_SOFTMAX:
329
+ dtype = x.dtype
330
+ x = F.logsigmoid(x.float())
331
+ x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x + 1e-6, dim=1)
332
+ x = x.clamp(max=0)
333
+ x = x - torch.log(-(torch.expm1(x)))
334
+ x = torch.nan_to_num(x, nan=0, posinf=0, neginf=0)
335
+ x = x.to(dtype=dtype)
336
+ else:
337
+ raise Exception(f"unknown prediction mode {self.cfg.prediction_mode.name}")
338
+
339
+ if self.fc_norm is not None:
340
+ x = self.fc_norm(x)
341
+
342
+ x = self.head(x)
343
+
344
+ if labels is None:
345
+ return x
346
+
347
+ if self.training and self.mixup_fn is not None:
348
+ loss = -labels * F.log_softmax(x.float(), dim=-1)
349
+ else:
350
+ loss = F.cross_entropy(
351
+ x.float(),
352
+ labels,
353
+ label_smoothing=self.cfg.label_smoothing if self.training else 0,
354
+ reduction="none",
355
+ )
356
+
357
+ result = {
358
+ "losses": {"regression": loss},
359
+ "sample_size": imgs.size(0),
360
+ }
361
+
362
+ if not self.training:
363
+ with torch.no_grad():
364
+ pred = x.argmax(-1)
365
+ correct = (pred == labels).sum()
366
+ result["correct"] = correct
367
+
368
+ return result
369
+
370
+ def model_forward(self, imgs):
371
+ if self.d2v_multi:
372
+ x = self.model.extract_features(
373
+ imgs,
374
+ mode="IMAGE",
375
+ mask=False,
376
+ remove_extra_tokens=(
377
+ self.cfg.prediction_mode != PredictionMode.CLS_TOKEN
378
+ ),
379
+ )["x"]
380
+ else:
381
+ x = self.model(imgs, predictions_only=True)
382
+ if (
383
+ "no_cls" not in self.model.cfg or not self.model.cfg.no_cls
384
+ ) and not self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
385
+ x = x[:, 1:]
386
+ return x
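
The fine-tuning model above assigns each block an lr_scale via layer_scales = cfg.layer_decay ** (num_layers - i), so blocks closer to the input train with a smaller learning rate than the head. A standalone sketch of that schedule; the concrete numbers are illustrative, not values read from a config in this commit:

layer_decay = 0.65   # cfg.layer_decay default above
num_layers = 13      # len(blocks) + 1 for a 12-block encoder
layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

print(f"patch/pos embed (layer 0): {layer_scales[0]:.4f}")   # smallest scale
print(f"middle block (layer 6): {layer_scales[6]:.4f}")
print(f"head (layer {num_layers}): {layer_scales[-1]:.1f}")  # always 1.0
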
fairseq/examples/data2vec/models/modalities/__init__.py ADDED
File without changes
fairseq/examples/data2vec/models/modalities/audio.py ADDED
@@ -0,0 +1,192 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from functools import partial
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+ from dataclasses import dataclass, field
11
+ from typing import Callable, Dict, Optional
12
+ from fairseq.models.wav2vec import ConvFeatureExtractionModel
13
+ from fairseq.modules import (
14
+ LayerNorm,
15
+ SamePad,
16
+ TransposeLast,
17
+ )
18
+ from fairseq.tasks import FairseqTask
19
+ from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
20
+ from .modules import BlockEncoder, Decoder1d
21
+ from examples.data2vec.data.modality import Modality
22
+
23
+
24
+ @dataclass
25
+ class D2vAudioConfig(D2vModalityConfig):
26
+ type: Modality = Modality.AUDIO
27
+ extractor_mode: str = "layer_norm"
28
+ feature_encoder_spec: str = field(
29
+ default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
30
+ metadata={
31
+ "help": "string describing convolutional feature extraction layers in form of a python list that contains "
32
+ "[(dim, kernel_size, stride), ...]"
33
+ },
34
+ )
35
+ conv_pos_width: int = field(
36
+ default=95,
37
+ metadata={"help": "number of filters for convolutional positional embeddings"},
38
+ )
39
+ conv_pos_groups: int = field(
40
+ default=16,
41
+ metadata={"help": "number of groups for convolutional positional embedding"},
42
+ )
43
+ conv_pos_depth: int = field(
44
+ default=5,
45
+ metadata={"help": "depth of positional encoder network"},
46
+ )
47
+ conv_pos_pre_ln: bool = False
48
+
49
+
50
+ class AudioEncoder(ModalitySpecificEncoder):
51
+
52
+ modality_cfg: D2vAudioConfig
53
+
54
+ def __init__(
55
+ self,
56
+ modality_cfg: D2vAudioConfig,
57
+ embed_dim: int,
58
+ make_block: Callable[[float], nn.ModuleList],
59
+ norm_layer: Callable[[int], nn.LayerNorm],
60
+ layer_norm_first: bool,
61
+ alibi_biases: Dict,
62
+ task: Optional[FairseqTask],
63
+ ):
64
+
65
+ self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
66
+ feature_embed_dim = self.feature_enc_layers[-1][0]
67
+
68
+ local_encoder = ConvFeatureExtractionModel(
69
+ conv_layers=self.feature_enc_layers,
70
+ dropout=0.0,
71
+ mode=modality_cfg.extractor_mode,
72
+ conv_bias=False,
73
+ )
74
+
75
+ project_features = nn.Sequential(
76
+ TransposeLast(),
77
+ nn.LayerNorm(feature_embed_dim),
78
+ nn.Linear(feature_embed_dim, embed_dim),
79
+ )
80
+
81
+ num_pos_layers = modality_cfg.conv_pos_depth
82
+ k = max(3, modality_cfg.conv_pos_width // num_pos_layers)
83
+
84
+ positional_encoder = nn.Sequential(
85
+ TransposeLast(),
86
+ *[
87
+ nn.Sequential(
88
+ nn.Conv1d(
89
+ embed_dim,
90
+ embed_dim,
91
+ kernel_size=k,
92
+ padding=k // 2,
93
+ groups=modality_cfg.conv_pos_groups,
94
+ ),
95
+ SamePad(k),
96
+ TransposeLast(),
97
+ LayerNorm(embed_dim, elementwise_affine=False),
98
+ TransposeLast(),
99
+ nn.GELU(),
100
+ )
101
+ for _ in range(num_pos_layers)
102
+ ],
103
+ TransposeLast(),
104
+ )
105
+
106
+ if modality_cfg.conv_pos_pre_ln:
107
+ positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)
108
+
109
+ dpr = np.linspace(
110
+ modality_cfg.start_drop_path_rate,
111
+ modality_cfg.end_drop_path_rate,
112
+ modality_cfg.prenet_depth,
113
+ )
114
+ context_encoder = BlockEncoder(
115
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
116
+ norm_layer(embed_dim) if not layer_norm_first else None,
117
+ layer_norm_first,
118
+ modality_cfg.prenet_layerdrop,
119
+ modality_cfg.prenet_dropout,
120
+ )
121
+
122
+ decoder = (
123
+ Decoder1d(modality_cfg.decoder, embed_dim)
124
+ if modality_cfg.decoder is not None
125
+ else None
126
+ )
127
+
128
+ alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
129
+
130
+ super().__init__(
131
+ modality_cfg=modality_cfg,
132
+ embed_dim=embed_dim,
133
+ local_encoder=local_encoder,
134
+ project_features=project_features,
135
+ fixed_positional_encoder=None,
136
+ relative_positional_encoder=positional_encoder,
137
+ context_encoder=context_encoder,
138
+ decoder=decoder,
139
+ get_alibi_bias=alibi_bias_fn,
140
+ )
141
+
142
+ def convert_padding_mask(self, x, padding_mask):
143
+ def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
144
+ """
145
+ Computes the output length of the convolutional layers
146
+ """
147
+
148
+ def _conv_out_length(input_length, kernel_size, stride):
149
+ return torch.floor((input_length - kernel_size) / stride + 1)
150
+
151
+ for i in range(len(self.feature_enc_layers)):
152
+ input_lengths = _conv_out_length(
153
+ input_lengths,
154
+ self.feature_enc_layers[i][1],
155
+ self.feature_enc_layers[i][2],
156
+ )
157
+
158
+ return input_lengths.to(torch.long)
159
+
160
+ if padding_mask is not None:
161
+ input_lengths = (1 - padding_mask.long()).sum(-1)
162
+ # apply conv formula to get real output_lengths
163
+ output_lengths = get_feat_extract_output_lengths(input_lengths)
164
+
165
+ if padding_mask.any():
166
+ padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
167
+
168
+ # these two operations makes sure that all values
169
+ # before the output lengths indices are attended to
170
+ padding_mask[
171
+ (
172
+ torch.arange(padding_mask.shape[0], device=padding_mask.device),
173
+ output_lengths - 1,
174
+ )
175
+ ] = 1
176
+ padding_mask = (
177
+ 1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
178
+ ).bool()
179
+ else:
180
+ padding_mask = torch.zeros(
181
+ x.shape[:2], dtype=torch.bool, device=x.device
182
+ )
183
+
184
+ return padding_mask
185
+
186
+ def reset_parameters(self):
187
+ super().reset_parameters()
188
+ for mod in self.project_features.children():
189
+ if isinstance(mod, nn.Linear):
190
+ mod.reset_parameters()
191
+ if self.decoder is not None:
192
+ self.decoder.reset_parameters()
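
convert_padding_mask above shrinks a sample-level padding mask to feature-frame resolution by pushing the input length through the conv stack. The length arithmetic can be reproduced in isolation; the layer spec below is the D2vAudioConfig default declared earlier in this file:

feature_enc_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] + [(512, 2, 2)]

def conv_out_length(n_samples: int) -> int:
    # mirrors _conv_out_length: floor((L - kernel) / stride) + 1 per conv layer
    for _, kernel, stride in feature_enc_layers:
        n_samples = (n_samples - kernel) // stride + 1
    return n_samples

print(conv_out_length(16000))  # 1 s of 16 kHz audio -> 49 feature frames
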
fairseq/examples/data2vec/models/modalities/base.py ADDED
@@ -0,0 +1,684 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import math
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from collections import namedtuple
13
+ from dataclasses import dataclass
14
+ from functools import partial
15
+ from omegaconf import MISSING, II
16
+ from typing import Optional, Callable
17
+ from fairseq.data.data_utils import compute_mask_indices
18
+ from fairseq.modules import GradMultiply
19
+ from fairseq.utils import index_put
20
+ from examples.data2vec.data.modality import Modality
21
+ from .modules import D2vDecoderConfig
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ @dataclass
27
+ class D2vModalityConfig:
28
+ type: Modality = MISSING
29
+ prenet_depth: int = 4
30
+ prenet_layerdrop: float = 0
31
+ prenet_dropout: float = 0
32
+ start_drop_path_rate: float = 0
33
+ end_drop_path_rate: float = 0
34
+
35
+ num_extra_tokens: int = 0
36
+ init_extra_token_zero: bool = True
37
+
38
+ mask_noise_std: float = 0.01
39
+ mask_prob_min: Optional[float] = None
40
+ mask_prob: float = 0.7
41
+ inverse_mask: bool = False
42
+ mask_prob_adjust: float = 0
43
+ keep_masked_pct: float = 0
44
+
45
+ mask_length: int = 5
46
+ add_masks: bool = False
47
+ remove_masks: bool = False
48
+ mask_dropout: float = 0.0
49
+ encoder_zero_mask: bool = True
50
+
51
+ mask_channel_prob: float = 0.0
52
+ mask_channel_length: int = 64
53
+
54
+ ema_local_encoder: bool = False # used in data2vec_multi
55
+ local_grad_mult: float = 1.0
56
+
57
+ use_alibi_encoder: bool = False
58
+ alibi_scale: float = 1.0
59
+ learned_alibi: bool = False
60
+ alibi_max_pos: Optional[int] = None
61
+ learned_alibi_scale: bool = False
62
+ learned_alibi_scale_per_head: bool = False
63
+ learned_alibi_scale_per_layer: bool = False
64
+
65
+ num_alibi_heads: int = II("model.num_heads")
66
+ model_depth: int = II("model.depth")
67
+
68
+ decoder: Optional[D2vDecoderConfig] = D2vDecoderConfig()
69
+
70
+
71
+ MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
72
+ MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
73
+
74
+
75
+ class ModalitySpecificEncoder(nn.Module):
76
+ def __init__(
77
+ self,
78
+ modality_cfg: D2vModalityConfig,
79
+ embed_dim: int,
80
+ local_encoder: nn.Module,
81
+ project_features: nn.Module,
82
+ fixed_positional_encoder: Optional[nn.Module],
83
+ relative_positional_encoder: Optional[nn.Module],
84
+ context_encoder: nn.Module,
85
+ decoder: nn.Module,
86
+ get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
87
+ ):
88
+ super().__init__()
89
+
90
+ self.modality_cfg = modality_cfg
91
+ self.local_encoder = local_encoder
92
+ self.project_features = project_features
93
+ self.fixed_positional_encoder = fixed_positional_encoder
94
+ self.relative_positional_encoder = relative_positional_encoder
95
+ self.context_encoder = context_encoder
96
+
97
+ self.decoder = decoder
98
+ self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None
99
+
100
+ self.local_grad_mult = self.modality_cfg.local_grad_mult
101
+
102
+ self.extra_tokens = None
103
+ if modality_cfg.num_extra_tokens > 0:
104
+ self.extra_tokens = nn.Parameter(
105
+ torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
106
+ )
107
+ if not modality_cfg.init_extra_token_zero:
108
+ nn.init.normal_(self.extra_tokens)
109
+ elif self.extra_tokens.size(1) > 1:
110
+ nn.init.normal_(self.extra_tokens[:, 1:])
111
+
112
+ self.alibi_scale = None
113
+ if self.get_alibi_bias is not None:
114
+ self.alibi_scale = nn.Parameter(
115
+ torch.full(
116
+ (
117
+ (modality_cfg.prenet_depth + modality_cfg.model_depth)
118
+ if modality_cfg.learned_alibi_scale_per_layer
119
+ else 1,
120
+ 1,
121
+ self.modality_cfg.num_alibi_heads
122
+ if modality_cfg.learned_alibi_scale_per_head
123
+ else 1,
124
+ 1,
125
+ 1,
126
+ ),
127
+ modality_cfg.alibi_scale,
128
+ dtype=torch.float,
129
+ ),
130
+ requires_grad=modality_cfg.learned_alibi_scale,
131
+ )
132
+
133
+ if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
134
+ assert modality_cfg.alibi_max_pos is not None
135
+ alibi_bias = self.get_alibi_bias(
136
+ batch_size=1,
137
+ time_steps=modality_cfg.alibi_max_pos,
138
+ heads=modality_cfg.num_alibi_heads,
139
+ scale=1.0,
140
+ dtype=torch.float,
141
+ device="cpu",
142
+ )
143
+ self.alibi_bias = nn.Parameter(alibi_bias)
144
+ self.get_alibi_bias = partial(
145
+ _learned_alibi_bias, alibi_bias=self.alibi_bias
146
+ )
147
+
148
+ def upgrade_state_dict_named(self, state_dict, name):
149
+ k = f"{name}.alibi_scale"
150
+ if k in state_dict and state_dict[k].dim() == 4:
151
+ state_dict[k] = state_dict[k].unsqueeze(0)
152
+
153
+ return state_dict
154
+
155
+ def convert_padding_mask(self, x, padding_mask):
156
+ return padding_mask
157
+
158
+ def decoder_input(self, x, mask_info: MaskInfo):
159
+ inp_drop = self.modality_cfg.decoder.input_dropout
160
+ if inp_drop > 0:
161
+ x = F.dropout(x, inp_drop, training=self.training, inplace=True)
162
+
163
+ num_extra = self.modality_cfg.num_extra_tokens
164
+
165
+ if mask_info is not None:
166
+ num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra
167
+
168
+ mask_tokens = x.new_empty(
169
+ x.size(0),
170
+ num_masked,
171
+ x.size(-1),
172
+ ).normal_(0, self.modality_cfg.mask_noise_std)
173
+
174
+ x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1)
175
+ x = torch.gather(x_, dim=1, index=mask_info.ids_restore)
176
+
177
+ if self.modality_cfg.decoder.add_positions_masked:
178
+ assert self.fixed_positional_encoder is not None
179
+ pos = self.fixed_positional_encoder(x, None)
180
+ x = x + (pos * mask_info.mask.unsqueeze(-1))
181
+ else:
182
+ x = x[:, num_extra:]
183
+
184
+ if self.modality_cfg.decoder.add_positions_all:
185
+ assert self.fixed_positional_encoder is not None
186
+ x = x + self.fixed_positional_encoder(x, None)
187
+
188
+ return x, mask_info
189
+
190
+ def local_features(self, features):
191
+ if self.local_grad_mult > 0:
192
+ if self.local_grad_mult == 1.0:
193
+ x = self.local_encoder(features)
194
+ else:
195
+ x = GradMultiply.apply(
196
+ self.local_encoder(features), self.local_grad_mult
197
+ )
198
+ else:
199
+ with torch.no_grad():
200
+ x = self.local_encoder(features)
201
+
202
+ x = self.project_features(x)
203
+ return x
204
+
205
+ def contextualized_features(
206
+ self,
207
+ x,
208
+ padding_mask,
209
+ mask,
210
+ remove_masked,
211
+ clone_batch: int = 1,
212
+ mask_seeds: Optional[torch.Tensor] = None,
213
+ precomputed_mask=None,
214
+ ):
215
+
216
+ if padding_mask is not None:
217
+ padding_mask = self.convert_padding_mask(x, padding_mask)
218
+
219
+ local_features = x
220
+ if mask and clone_batch == 1:
221
+ local_features = local_features.clone()
222
+
223
+ orig_B, orig_T, _ = x.shape
224
+ pre_mask_B = orig_B
225
+ mask_info = None
226
+
227
+ x_pos = None
228
+ if self.fixed_positional_encoder is not None:
229
+ x = x + self.fixed_positional_encoder(x, padding_mask)
230
+
231
+ if mask:
232
+ if clone_batch > 1:
233
+ x = x.repeat_interleave(clone_batch, 0)
234
+ if mask_seeds is not None:
235
+ clone_hash = [
236
+ int(hash((mask_seeds.seed, ind)) % 1e10)
237
+ for ind in range(clone_batch - 1)
238
+ ]
239
+ clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)
240
+
241
+ id = mask_seeds.ids
242
+ id = id.repeat_interleave(clone_batch, 0)
243
+ id = id.view(-1, clone_batch) + clone_hash.to(id)
244
+ id = id.view(-1)
245
+ mask_seeds = MaskSeed(
246
+ seed=mask_seeds.seed, update=mask_seeds.update, ids=id
247
+ )
248
+ if padding_mask is not None:
249
+ padding_mask = padding_mask.repeat_interleave(clone_batch, 0)
250
+
251
+ x, mask_info = self.compute_mask(
252
+ x,
253
+ padding_mask,
254
+ mask_seed=mask_seeds,
255
+ apply=self.relative_positional_encoder is not None or not remove_masked,
256
+ precomputed_mask=precomputed_mask,
257
+ )
258
+
259
+ if self.relative_positional_encoder is not None:
260
+ x_pos = self.relative_positional_encoder(x)
261
+
262
+ masked_padding_mask = padding_mask
263
+ if mask and remove_masked:
264
+ x = mask_info.x_unmasked
265
+ if x_pos is not None:
266
+ x = x + gather_unmasked(x_pos, mask_info)
267
+
268
+ if padding_mask is not None and padding_mask.any():
269
+ masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
270
+ if not masked_padding_mask.any():
271
+ masked_padding_mask = None
272
+ else:
273
+ masked_padding_mask = None
274
+
275
+ elif x_pos is not None:
276
+ x = x + x_pos
277
+
278
+ alibi_bias = None
279
+ alibi_scale = self.alibi_scale
280
+
281
+ if self.get_alibi_bias is not None:
282
+ alibi_bias = self.get_alibi_bias(
283
+ batch_size=pre_mask_B,
284
+ time_steps=orig_T,
285
+ heads=self.modality_cfg.num_alibi_heads,
286
+ dtype=torch.float32,
287
+ device=x.device,
288
+ )
289
+
290
+ if alibi_scale is not None:
291
+ alibi_scale = alibi_scale.clamp_min(0)
292
+ if alibi_scale.size(0) == 1:
293
+ alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
294
+ alibi_scale = None
295
+
296
+ if clone_batch > 1:
297
+ alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)
298
+
299
+ if mask_info is not None and remove_masked:
300
+ alibi_bias = masked_alibi(alibi_bias, mask_info)
301
+
302
+ if self.extra_tokens is not None:
303
+ num = self.extra_tokens.size(1)
304
+ x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
305
+ if masked_padding_mask is not None:
306
+ # B x T
307
+ masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
308
+ if alibi_bias is not None:
309
+ # B x H x T x T
310
+ alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))
311
+
312
+ x = self.context_encoder(
313
+ x,
314
+ masked_padding_mask,
315
+ alibi_bias,
316
+ alibi_scale[: self.modality_cfg.prenet_depth]
317
+ if alibi_scale is not None
318
+ else None,
319
+ )
320
+
321
+ return {
322
+ "x": x,
323
+ "local_features": local_features,
324
+ "padding_mask": masked_padding_mask,
325
+ "alibi_bias": alibi_bias,
326
+ "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
327
+ if alibi_scale is not None and alibi_scale.size(0) > 1
328
+ else alibi_scale,
329
+ "encoder_mask": mask_info,
330
+ }
331
+
332
+ def forward(
333
+ self,
334
+ features,
335
+ padding_mask,
336
+ mask: bool,
337
+ remove_masked: bool,
338
+ clone_batch: int = 1,
339
+ mask_seeds: Optional[torch.Tensor] = None,
340
+ precomputed_mask=None,
341
+ ):
342
+ x = self.local_features(features)
343
+ return self.contextualized_features(
344
+ x,
345
+ padding_mask,
346
+ mask,
347
+ remove_masked,
348
+ clone_batch,
349
+ mask_seeds,
350
+ precomputed_mask,
351
+ )
352
+
353
+ def reset_parameters(self):
354
+ pass
355
+
356
+ def compute_mask(
357
+ self,
358
+ x,
359
+ padding_mask,
360
+ mask_seed: Optional[MaskSeed],
361
+ apply,
362
+ precomputed_mask,
363
+ ):
364
+ if precomputed_mask is not None:
365
+ mask = precomputed_mask
366
+ mask_info = self.make_maskinfo(x, mask)
367
+ else:
368
+ B, T, C = x.shape
369
+ cfg = self.modality_cfg
370
+
371
+ mask_prob = cfg.mask_prob
372
+
373
+ if (
374
+ cfg.mask_prob_min is not None
375
+ and cfg.mask_prob_min >= 0
376
+ and cfg.mask_prob_min < mask_prob
377
+ ):
378
+ mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)
379
+
380
+ if mask_prob > 0:
381
+ if cfg.mask_length == 1:
382
+ mask_info = random_masking(x, mask_prob, mask_seed)
383
+ else:
384
+ if self.modality_cfg.inverse_mask:
385
+ mask_prob = 1 - mask_prob
386
+
387
+ mask = compute_mask_indices(
388
+ (B, T),
389
+ padding_mask,
390
+ mask_prob,
391
+ cfg.mask_length,
392
+ min_masks=1,
393
+ require_same_masks=True,
394
+ mask_dropout=cfg.mask_dropout,
395
+ add_masks=cfg.add_masks,
396
+ seed=mask_seed.seed if mask_seed is not None else None,
397
+ epoch=mask_seed.update if mask_seed is not None else None,
398
+ indices=mask_seed.ids if mask_seed is not None else None,
399
+ )
400
+
401
+ mask = torch.from_numpy(mask).to(device=x.device)
402
+ if self.modality_cfg.inverse_mask:
403
+ mask = 1 - mask
404
+ mask_info = self.make_maskinfo(x, mask)
405
+ else:
406
+ mask_info = None
407
+
408
+ if apply:
409
+ x = self.apply_mask(x, mask_info)
410
+
411
+ return x, mask_info
412
+
413
+ def make_maskinfo(self, x, mask, shape=None):
414
+ if shape is None:
415
+ B, T, D = x.shape
416
+ else:
417
+ B, T, D = shape
418
+
419
+ mask = mask.to(torch.uint8)
420
+ ids_shuffle = mask.argsort(dim=1)
421
+ ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)
422
+
423
+ len_keep = T - mask[0].sum()
424
+ if self.modality_cfg.keep_masked_pct > 0:
425
+ len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)
426
+
427
+ ids_keep = ids_shuffle[:, :len_keep]
428
+
429
+ if shape is not None:
430
+ x_unmasked = None
431
+ else:
432
+ ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
433
+ x_unmasked = torch.gather(x, dim=1, index=ids_keep)
434
+
435
+ mask_info = MaskInfo(
436
+ x_unmasked=x_unmasked,
437
+ mask=mask,
438
+ ids_restore=ids_restore,
439
+ ids_keep=ids_keep,
440
+ )
441
+ return mask_info
442
+
443
+ def apply_mask(self, x, mask_info):
444
+ cfg = self.modality_cfg
445
+ B, T, C = x.shape
446
+
447
+ if mask_info is not None:
448
+ mask = mask_info.mask
449
+ if cfg.encoder_zero_mask:
450
+ x = x * (1 - mask.type_as(x).unsqueeze(-1))
451
+ else:
452
+ num_masks = mask.sum().item()
453
+ masks = x.new_empty(num_masks, x.size(-1)).normal_(
454
+ 0, cfg.mask_noise_std
455
+ )
456
+ x = index_put(x, mask, masks)
457
+ if cfg.mask_channel_prob > 0:
458
+ mask_channel = compute_mask_indices(
459
+ (B, C),
460
+ None,
461
+ cfg.mask_channel_prob,
462
+ cfg.mask_channel_length,
463
+ )
464
+ mask_channel = (
465
+ torch.from_numpy(mask_channel)
466
+ .to(x.device)
467
+ .unsqueeze(1)
468
+ .expand(-1, T, -1)
469
+ )
470
+ x = index_put(x, mask_channel, 0)
471
+ return x
472
+
473
+ def remove_pretraining_modules(self, keep_decoder=False):
474
+ if not keep_decoder:
475
+ self.decoder = None
476
+
477
+
478
+ def get_annealed_rate(start, end, curr_step, total_steps):
479
+ if curr_step >= total_steps:
480
+ return end
481
+ r = end - start
482
+ pct_remaining = 1 - curr_step / total_steps
483
+ return end - r * pct_remaining
484
+
485
+
486
+ # adapted from MAE
487
+ def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
488
+ N, L, D = x.shape # batch, length, dim
489
+ len_keep = int(L * (1 - mask_ratio))
490
+
491
+ generator = None
492
+ if mask_seed is not None:
493
+ seed = int(
494
+ hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
495
+ )
496
+ generator = torch.Generator(device=x.device)
497
+ generator.manual_seed(seed)
498
+
499
+ noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1]
500
+
501
+ # sort noise for each sample
502
+ ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove
503
+ ids_restore = ids_shuffle.argsort(dim=1)
504
+
505
+ # keep the first subset
506
+ ids_keep = ids_shuffle[:, :len_keep]
507
+ ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
508
+ x_unmasked = torch.gather(x, dim=1, index=ids_keep)
509
+
510
+ # generate the binary mask: 0 is keep, 1 is remove
511
+ mask = torch.ones([N, L], dtype=x.dtype, device=x.device)
512
+ mask[:, :len_keep] = 0
513
+ # unshuffle to get the binary mask
514
+ mask = torch.gather(mask, dim=1, index=ids_restore)
515
+
516
+ ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)
517
+
518
+ return MaskInfo(
519
+ x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
520
+ )
521
+
522
+
523
+ def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
524
+ return torch.gather(
525
+ x,
526
+ dim=1,
527
+ index=mask_info.ids_keep,
528
+ )
529
+
530
+
531
+ def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
532
+ return torch.gather(
533
+ x,
534
+ dim=1,
535
+ index=mask_info.ids_keep[..., 0], # ignore the feature dimension
536
+ )
537
+
538
+
539
+ def get_alibi(
540
+ max_positions: int,
541
+ attention_heads: int,
542
+ dims: int = 1,
543
+ distance: str = "manhattan",
544
+ ):
545
+ def get_slopes(n):
546
+ def get_slopes_power_of_2(n):
547
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
548
+ ratio = start
549
+ return [start * ratio**i for i in range(n)]
550
+
551
+ # In the paper, we only train models that have 2^a heads for some
552
+ # a. This function has some good properties that only occur when
553
+ # the input is a power of 2. To maintain that even when the number
554
+ # of heads is not a power of 2, we use this workaround.
555
+ if math.log2(n).is_integer():
556
+ return get_slopes_power_of_2(n)
557
+ else:
558
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
559
+ return (
560
+ get_slopes_power_of_2(closest_power_of_2)
561
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
562
+ )
563
+
564
+ maxpos = max_positions
565
+ attn_heads = attention_heads
566
+ slopes = torch.Tensor(get_slopes(attn_heads))
567
+
568
+ if dims == 1:
569
+ # prepare alibi position linear bias. Note that wav2vec2 is a non-
570
+ # autoregressive model so we want a symmetric mask with 0 on the
571
+ # diagonal and otherwise linearly decreasing values
572
+ pos_bias = (
573
+ torch.abs(
574
+ torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
575
+ )
576
+ * -1
577
+ )
578
+ elif dims == 2:
579
+ if distance == "manhattan":
580
+ df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
581
+ elif distance == "euclidean":
582
+ df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
583
+
584
+ n = math.sqrt(max_positions)
585
+ assert n.is_integer(), n
586
+ n = int(n)
587
+
588
+ pos_bias = torch.zeros((max_positions, max_positions))
589
+
590
+ for i in range(n):
591
+ for j in range(n):
592
+ for k in range(n):
593
+ for l in range(n):
594
+ new_x = i * n + j
595
+ new_y = k * n + l
596
+ pos_bias[new_x, new_y] = -df(i, j, k, l)
597
+
598
+ else:
599
+ raise Exception(f"unsupported number of alibi dims: {dims}")
600
+
601
+ alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
602
+ attn_heads, -1, -1
603
+ )
604
+
605
+ return alibi_bias
606
+
607
+
608
+ def get_alibi_bias(
609
+ alibi_biases,
610
+ batch_size,
611
+ time_steps,
612
+ heads,
613
+ dtype,
614
+ device,
615
+ dims=1,
616
+ distance="manhattan",
617
+ ):
618
+ cache_key = f"{dims}_{heads}_{distance}"
619
+
620
+ buffered = alibi_biases.get(cache_key, None)
621
+
622
+ target_size = heads * batch_size
623
+ if (
624
+ buffered is None
625
+ or buffered.size(0) < target_size
626
+ or buffered.size(1) < time_steps
627
+ or buffered.dtype != dtype
628
+ or buffered.device != device
629
+ ):
630
+ bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
631
+ bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads
632
+
633
+ buffered = (
634
+ get_alibi(bt, heads, dims=dims, distance=distance)
635
+ .to(dtype=dtype, device=device)
636
+ .repeat(bn, 1, 1)
637
+ )
638
+
639
+ alibi_biases[cache_key] = buffered
640
+
641
+ b = buffered[:target_size, :time_steps, :time_steps]
642
+ b = b.view(batch_size, heads, time_steps, time_steps)
643
+ return b
644
+
645
+
646
+ def _learned_alibi_bias(
647
+ alibi_bias,
648
+ batch_size,
649
+ time_steps,
650
+ heads,
651
+ scale,
652
+ dtype,
653
+ device,
654
+ ):
655
+ assert alibi_bias.size(1) == heads, alibi_bias.shape
656
+ assert alibi_bias.dtype == dtype, alibi_bias.dtype
657
+ assert alibi_bias.device == device, alibi_bias.device
658
+
659
+ if alibi_bias.size(-1) < time_steps:
660
+ psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
661
+ alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")
662
+
663
+ alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
664
+ return alibi_bias[..., :time_steps, :time_steps]
665
+
666
+
667
+ def masked_alibi(alibi_bias, mask_info):
668
+ H = alibi_bias.size(1)
669
+
670
+ orig_bias = alibi_bias
671
+
672
+ index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
673
+ alibi_bias = torch.gather(
674
+ orig_bias,
675
+ dim=-2,
676
+ index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
677
+ )
678
+ alibi_bias = torch.gather(
679
+ alibi_bias,
680
+ dim=-1,
681
+ index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
682
+ )
683
+
684
+ return alibi_bias
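
The helpers above (get_alibi, get_alibi_bias, masked_alibi) construct and cache the relative-position bias that is added to the attention logits. Below is a minimal standalone sketch, separate from the committed files, that mirrors the 1-D branch of get_alibi under the assumption of a power-of-two head count: per-head slopes scale a symmetric negative-distance matrix, so the bias is zero on the diagonal and decreases linearly with distance.

import math
import torch

def alibi_sketch(max_positions: int, attention_heads: int) -> torch.Tensor:
    # per-head geometric slopes; assumes attention_heads is a power of two
    def slopes(n: int):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start ** i for i in range(n)]

    pos = torch.arange(max_positions)
    dist = -(pos.unsqueeze(0) - pos.unsqueeze(1)).abs()  # symmetric, 0 on the diagonal
    return torch.tensor(slopes(attention_heads)).view(-1, 1, 1) * dist

bias = alibi_sketch(max_positions=8, attention_heads=4)
assert bias.shape == (4, 8, 8)
assert torch.equal(bias, bias.transpose(-2, -1))  # symmetric, i.e. non-autoregressive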
fairseq/examples/data2vec/models/modalities/images.py ADDED
@@ -0,0 +1,256 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from functools import partial
11
+ from dataclasses import dataclass
12
+ from typing import Callable, Dict, Optional
13
+ from timm.models.layers import to_2tuple
14
+ from fairseq.tasks import FairseqTask
15
+ from examples.data2vec.models.mae import get_2d_sincos_pos_embed, PatchEmbed
16
+ from .base import (
17
+ D2vModalityConfig,
18
+ ModalitySpecificEncoder,
19
+ get_alibi_bias,
20
+ MaskSeed,
21
+ )
22
+ from .modules import (
23
+ BlockEncoder,
24
+ Decoder2d,
25
+ FixedPositionalEncoder,
26
+ TransformerDecoder,
27
+ EncDecTransformerDecoder,
28
+ )
29
+ from examples.data2vec.data.modality import Modality
30
+
31
+
32
+ @dataclass
33
+ class D2vImageConfig(D2vModalityConfig):
34
+ type: Modality = Modality.IMAGE
35
+
36
+ input_size: int = 224
37
+ in_chans: int = 3
38
+ patch_size: int = 16
39
+ embed_dim: int = 768
40
+
41
+ alibi_dims: int = 2
42
+ alibi_distance: str = "manhattan"
43
+
44
+ fixed_positions: bool = True
45
+
46
+ transformer_decoder: bool = False
47
+ enc_dec_transformer: bool = False
48
+
49
+
50
+ class ImageEncoder(ModalitySpecificEncoder):
51
+
52
+ modality_cfg: D2vImageConfig
53
+
54
+ def __init__(
55
+ self,
56
+ modality_cfg: D2vImageConfig,
57
+ embed_dim: int,
58
+ make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
59
+ norm_layer: Callable[[int], nn.LayerNorm],
60
+ layer_norm_first: bool,
61
+ alibi_biases: Dict,
62
+ task: Optional[FairseqTask],
63
+ ):
64
+
65
+ img_size = to_2tuple(modality_cfg.input_size)
66
+ patch_size = to_2tuple(modality_cfg.patch_size)
67
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
68
+
69
+ local_encoder = PatchEmbed(
70
+ modality_cfg.input_size,
71
+ modality_cfg.patch_size,
72
+ modality_cfg.in_chans,
73
+ modality_cfg.embed_dim,
74
+ )
75
+
76
+ w = local_encoder.proj.weight.data
77
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
78
+
79
+ if modality_cfg.embed_dim != embed_dim:
80
+ local_encoder = nn.Sequential(
81
+ local_encoder,
82
+ nn.Linear(modality_cfg.embed_dim, embed_dim),
83
+ )
84
+
85
+ project_features = nn.Identity()
86
+
87
+ pos_embed = nn.Parameter(
88
+ torch.zeros(1, num_patches, embed_dim), requires_grad=False
89
+ )
90
+
91
+ side_n = int(num_patches ** 0.5)
92
+
93
+ emb = get_2d_sincos_pos_embed(
94
+ pos_embed.shape[-1],
95
+ side_n,
96
+ cls_token=False,
97
+ )
98
+ pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
99
+ fixed_positional_encoder = (
100
+ FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None
101
+ )
102
+
103
+ dpr = np.linspace(
104
+ modality_cfg.start_drop_path_rate,
105
+ modality_cfg.end_drop_path_rate,
106
+ modality_cfg.prenet_depth,
107
+ )
108
+
109
+ context_encoder = BlockEncoder(
110
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
111
+ norm_layer(embed_dim) if not layer_norm_first else None,
112
+ layer_norm_first,
113
+ modality_cfg.prenet_layerdrop,
114
+ modality_cfg.prenet_dropout,
115
+ )
116
+
117
+ if modality_cfg.transformer_decoder:
118
+ if modality_cfg.enc_dec_transformer:
119
+ decoder = EncDecTransformerDecoder(modality_cfg.decoder, embed_dim)
120
+ else:
121
+ dec_enc = BlockEncoder(
122
+ nn.ModuleList(
123
+ make_block(0, modality_cfg.decoder.decoder_dim, 8)
124
+ for _ in range(modality_cfg.decoder.decoder_layers)
125
+ ),
126
+ None,
127
+ layer_norm_first,
128
+ 0,
129
+ 0,
130
+ )
131
+ decoder = TransformerDecoder(modality_cfg.decoder, embed_dim, dec_enc)
132
+ else:
133
+ decoder = (
134
+ Decoder2d(modality_cfg.decoder, embed_dim, side_n, side_n)
135
+ if modality_cfg.decoder is not None
136
+ else None
137
+ )
138
+
139
+ alibi_bias_fn = partial(
140
+ get_alibi_bias,
141
+ alibi_biases=alibi_biases,
142
+ heads=modality_cfg.num_alibi_heads,
143
+ dims=modality_cfg.alibi_dims,
144
+ distance=modality_cfg.alibi_distance,
145
+ )
146
+
147
+ super().__init__(
148
+ modality_cfg=modality_cfg,
149
+ embed_dim=embed_dim,
150
+ local_encoder=local_encoder,
151
+ project_features=project_features,
152
+ fixed_positional_encoder=fixed_positional_encoder,
153
+ relative_positional_encoder=None,
154
+ context_encoder=context_encoder,
155
+ decoder=decoder,
156
+ get_alibi_bias=alibi_bias_fn,
157
+ )
158
+
159
+ def reset_parameters(self):
160
+ super().reset_parameters()
161
+ if self.decoder is not None:
162
+ self.decoder.reset_parameters()
163
+
164
+ @torch.no_grad()
165
+ def patchify(self, imgs):
166
+ """
167
+ imgs: (N, 3, H, W)
168
+ x: (N, L, patch_size**2 *3)
169
+ """
170
+ p = self.modality_cfg.patch_size
171
+ h = w = imgs.shape[2] // p
172
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
173
+ x = torch.einsum("nchpwq->nhwpqc", x)
174
+ x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
175
+
176
+ return x
177
+
178
+ @torch.no_grad()
179
+ def unpatchify(self, x):
180
+ """
181
+ x: (N, L, patch_size**2 *3)
182
+ imgs: (N, 3, H, W)
183
+ """
184
+ p = self.modality_cfg.patch_size
185
+ h = w = int(x.shape[1] ** 0.5)
186
+ assert h * w == x.shape[1]
187
+
188
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
189
+ x = torch.einsum("nhwpqc->nchpwq", x)
190
+ imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
191
+ return imgs
192
+
193
+ def compute_mask(
194
+ self,
195
+ x,
196
+ padding_mask,
197
+ mask_seed: Optional[MaskSeed],
198
+ apply,
199
+ shape=None,
200
+ precomputed_mask=None,
201
+ ):
202
+ mlen = self.modality_cfg.mask_length
203
+ if mlen <= 1:
204
+ return super().compute_mask(
205
+ x, padding_mask, mask_seed, apply, precomputed_mask
206
+ )
207
+
208
+ if precomputed_mask is not None:
209
+ mask = precomputed_mask
210
+ else:
211
+ from fairseq.data.data_utils import compute_block_mask_2d
212
+
213
+ if shape is not None:
214
+ B, L, D = shape
215
+ else:
216
+ B, L, D = x.shape
217
+
218
+ mask = compute_block_mask_2d(
219
+ shape=(B, L),
220
+ mask_prob=self.modality_cfg.mask_prob,
221
+ mask_length=self.modality_cfg.mask_length,
222
+ mask_prob_adjust=self.modality_cfg.mask_prob_adjust,
223
+ inverse_mask=self.modality_cfg.inverse_mask,
224
+ require_same_masks=True,
225
+ mask_dropout=self.modality_cfg.mask_dropout,
226
+ )
227
+
228
+ mask_info = self.make_maskinfo(x, mask, shape)
229
+ if apply:
230
+ x = self.apply_mask(x, mask_info)
231
+
232
+ return x, mask_info
233
+
234
+ def decoder_input(self, x, mask_info):
235
+ if (
236
+ not self.modality_cfg.transformer_decoder
237
+ or not self.modality_cfg.enc_dec_transformer
238
+ ):
239
+ return super().decoder_input(x, mask_info)
240
+
241
+ inp_drop = self.modality_cfg.decoder.input_dropout
242
+ if inp_drop > 0:
243
+ x = F.dropout(x, inp_drop, training=self.training, inplace=True)
244
+
245
+ kv = x[:, self.modality_cfg.num_extra_tokens :]
246
+
247
+ assert self.fixed_positional_encoder is not None
248
+ pos = self.fixed_positional_encoder(x, None).expand(x.size(0), -1, -1)
249
+
250
+ mask = mask_info.mask.bool()
251
+ if self.modality_cfg.decoder.add_positions_all:
252
+ kv = kv + pos[~mask].view(kv.shape)
253
+
254
+ q = pos[mask].view(x.size(0), -1, x.size(-1))
255
+
256
+ return q, kv
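
Below is a minimal standalone sketch, separate from the committed files, of the patchify/unpatchify round trip implemented by ImageEncoder above: a square image is cut into p x p patches, flattened to (N, L, p*p*3), and reassembled without loss.

import torch

def patchify(imgs: torch.Tensor, p: int) -> torch.Tensor:
    n, _, h_img, _ = imgs.shape
    h = w = h_img // p
    x = imgs.reshape(n, 3, h, p, w, p)
    x = torch.einsum("nchpwq->nhwpqc", x)  # group pixels by patch
    return x.reshape(n, h * w, p * p * 3)

def unpatchify(x: torch.Tensor, p: int) -> torch.Tensor:
    n, l, _ = x.shape
    h = w = int(l ** 0.5)
    x = x.reshape(n, h, w, p, p, 3)
    x = torch.einsum("nhwpqc->nchpwq", x)  # undo the patch grouping
    return x.reshape(n, 3, h * p, w * p)

imgs = torch.randn(2, 3, 224, 224)
assert torch.equal(unpatchify(patchify(imgs, 16), 16), imgs)  # exact round trip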
fairseq/examples/data2vec/models/modalities/modules.py ADDED
@@ -0,0 +1,589 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from dataclasses import dataclass
11
+ from fairseq.modules import (
12
+ LayerNorm,
13
+ SamePad,
14
+ SamePad2d,
15
+ TransposeLast,
16
+ )
17
+
18
+
19
+ @dataclass
20
+ class D2vDecoderConfig:
21
+ decoder_dim: int = 384
22
+ decoder_groups: int = 16
23
+ decoder_kernel: int = 5
24
+ decoder_layers: int = 5
25
+ input_dropout: float = 0.1
26
+
27
+ add_positions_masked: bool = False
28
+ add_positions_all: bool = False
29
+
30
+ decoder_residual: bool = True
31
+ projection_layers: int = 1
32
+ projection_ratio: float = 2.0
33
+
34
+
35
+ class FixedPositionalEncoder(nn.Module):
36
+ def __init__(self, pos_embed):
37
+ super().__init__()
38
+ self.positions = pos_embed
39
+
40
+ def forward(self, x, padding_mask):
41
+ return self.positions
42
+
43
+
44
+ class TextFeatPositionalEncoder(nn.Module):
45
+ """
46
+ Original encoder expects (B, T) long input. This module wraps it to take
47
+ local_encoder output which are (B, T, D) float tensors
48
+ """
49
+
50
+ def __init__(self, pos_encoder):
51
+ super().__init__()
52
+ self.pos_encoder = pos_encoder
53
+
54
+ def forward(self, x, padding_mask):
55
+ # assume padded token embeddings are 0s
56
+ # TODO: consider using padding_mask as input
57
+ return self.pos_encoder(x[..., 0])
58
+
59
+
60
+ class BlockEncoder(nn.Module):
61
+ def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
62
+ super().__init__()
63
+ self.blocks = blocks
64
+ self.norm = norm_layer
65
+ self.layer_norm_first = layer_norm_first
66
+ self.layerdrop = layerdrop
67
+ self.dropout = nn.Dropout(dropout, inplace=True)
68
+
69
+ def forward(self, x, padding_mask, alibi_bias, alibi_scale):
70
+ if self.norm is not None and not self.layer_norm_first:
71
+ x = self.norm(x)
72
+
73
+ x = self.dropout(x)
74
+
75
+ for i, blk in enumerate(self.blocks):
76
+ if (
77
+ not self.training
78
+ or self.layerdrop == 0
79
+ or (np.random.random() > self.layerdrop)
80
+ ):
81
+ ab = alibi_bias
82
+ if ab is not None and alibi_scale is not None:
83
+ scale = (
84
+ alibi_scale[i]
85
+ if alibi_scale.size(0) > 1
86
+ else alibi_scale.squeeze(0)
87
+ )
88
+ ab = ab * scale.type_as(ab)
89
+ x, _ = blk(x, padding_mask, ab)
90
+
91
+ if self.norm is not None and self.layer_norm_first:
92
+ x = self.norm(x)
93
+
94
+ return x
95
+
96
+
97
+ class DecoderBase(nn.Module):
98
+ decoder_cfg: D2vDecoderConfig
99
+
100
+ def __init__(self, cfg: D2vDecoderConfig):
101
+ super().__init__()
102
+
103
+ self.decoder_cfg = cfg
104
+
105
+ def reset_parameters(self):
106
+ for mod in self.proj.modules():
107
+ if isinstance(mod, nn.Linear):
108
+ mod.reset_parameters()
109
+
110
+ def add_residual(self, x, residual, i, mask_info):
111
+ if (
112
+ residual is None
113
+ or not self.decoder_cfg.decoder_residual
114
+ or residual.size(1) != x.size(1)
115
+ ):
116
+ return x
117
+
118
+ ret = x + residual
119
+
120
+ return ret
121
+
122
+
123
+ class Decoder1d(DecoderBase):
124
+ def __init__(self, cfg: D2vDecoderConfig, input_dim):
125
+ super().__init__(cfg)
126
+
127
+ def make_block(in_dim):
128
+ block = [
129
+ nn.Conv1d(
130
+ in_dim,
131
+ cfg.decoder_dim,
132
+ kernel_size=cfg.decoder_kernel,
133
+ padding=cfg.decoder_kernel // 2,
134
+ groups=cfg.decoder_groups,
135
+ ),
136
+ SamePad(cfg.decoder_kernel),
137
+ TransposeLast(),
138
+ LayerNorm(cfg.decoder_dim, elementwise_affine=False),
139
+ TransposeLast(),
140
+ nn.GELU(),
141
+ ]
142
+
143
+ return nn.Sequential(*block)
144
+
145
+ self.blocks = nn.Sequential(
146
+ *[
147
+ make_block(input_dim if i == 0 else cfg.decoder_dim)
148
+ for i in range(cfg.decoder_layers)
149
+ ]
150
+ )
151
+
152
+ projs = []
153
+ curr_dim = cfg.decoder_dim
154
+ for i in range(cfg.projection_layers - 1):
155
+ next_dim = int(curr_dim * cfg.projection_ratio) if i == 0 else curr_dim
156
+ projs.append(nn.Linear(curr_dim, next_dim))
157
+ projs.append(nn.GELU())
158
+ curr_dim = next_dim
159
+ projs.append(nn.Linear(curr_dim, input_dim))
160
+ if len(projs) == 1:
161
+ self.proj = projs[0]
162
+ else:
163
+ self.proj = nn.Sequential(*projs)
164
+
165
+ def forward(self, x, mask_info):
166
+
167
+ x = x.transpose(1, 2)
168
+
169
+ residual = x
170
+
171
+ for i, layer in enumerate(self.blocks):
172
+ x = layer(x)
173
+ x = self.add_residual(x, residual, i, mask_info)
174
+ residual = x
175
+
176
+ x = x.transpose(1, 2)
177
+ x = self.proj(x)
178
+ return x
179
+
180
+
181
+ class Decoder2d(DecoderBase):
182
+ def __init__(self, cfg: D2vDecoderConfig, input_dim, h_size, w_size):
183
+ super().__init__(cfg)
184
+
185
+ self.h_size = h_size
186
+ self.w_size = w_size
187
+
188
+ def make_block(in_dim):
189
+ block = [
190
+ nn.Conv2d(
191
+ in_dim,
192
+ cfg.decoder_dim,
193
+ kernel_size=cfg.decoder_kernel,
194
+ padding=cfg.decoder_kernel // 2,
195
+ groups=cfg.decoder_groups,
196
+ ),
197
+ SamePad2d(cfg.decoder_kernel),
198
+ TransposeLast(tranpose_dim=-3),
199
+ LayerNorm(cfg.decoder_dim, elementwise_affine=False),
200
+ TransposeLast(tranpose_dim=-3),
201
+ nn.GELU(),
202
+ ]
203
+
204
+ return nn.Sequential(*block)
205
+
206
+ self.blocks = nn.Sequential(
207
+ *[
208
+ make_block(input_dim if i == 0 else cfg.decoder_dim)
209
+ for i in range(cfg.decoder_layers)
210
+ ]
211
+ )
212
+
213
+ self.proj = nn.Linear(cfg.decoder_dim, input_dim)
214
+
215
+ def forward(self, x, mask_info):
216
+ B, T, C = x.shape
217
+
218
+ x = x.transpose(1, 2).reshape(B, C, self.h_size, self.w_size)
219
+
220
+ residual = x
221
+
222
+ for i, layer in enumerate(self.blocks):
223
+ x = layer(x)
224
+ x = self.add_residual(x, residual, i, mask_info)
225
+ residual = x
226
+
227
+ x = x.reshape(B, -1, T).transpose(1, 2)
228
+ x = self.proj(x)
229
+ return x
230
+
231
+
232
+ class TransformerDecoder(nn.Module):
233
+ decoder_cfg: D2vDecoderConfig
234
+
235
+ def __init__(self, cfg: D2vDecoderConfig, input_dim, encoder):
236
+ super().__init__()
237
+
238
+ self.decoder_cfg = cfg
239
+
240
+ self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)
241
+
242
+ self.encoder = encoder
243
+
244
+ self.proj = nn.Linear(cfg.decoder_dim, input_dim)
245
+
246
+ def reset_parameters(self):
247
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
248
+
249
+ self.apply(init_bert_params)
250
+
251
+ def forward(self, x, mask_info):
252
+ x = self.input_proj(x)
253
+ x = self.encoder(x, None, None, 1)
254
+ x = self.proj(x)
255
+ return x
256
+
257
+
258
+ class AltBlock(nn.Module):
259
+ def __init__(
260
+ self,
261
+ dim,
262
+ num_heads,
263
+ mlp_ratio=4.0,
264
+ qkv_bias=False,
265
+ qk_scale=None,
266
+ drop=0.0,
267
+ attn_drop=0.0,
268
+ mlp_drop=0.0,
269
+ post_mlp_drop=0.0,
270
+ drop_path=0.0,
271
+ act_layer=nn.GELU,
272
+ norm_layer=nn.LayerNorm,
273
+ layer_norm_first=True,
274
+ ffn_targets=False,
275
+ cosine_attention=False,
276
+ ):
277
+ super().__init__()
278
+
279
+ self.layer_norm_first = layer_norm_first
280
+ self.ffn_targets = ffn_targets
281
+
282
+ from timm.models.vision_transformer import DropPath, Mlp
283
+
284
+ self.norm1 = norm_layer(dim)
285
+ self.attn = AltAttention(
286
+ dim,
287
+ num_heads=num_heads,
288
+ qkv_bias=qkv_bias,
289
+ qk_scale=qk_scale,
290
+ attn_drop=attn_drop,
291
+ proj_drop=drop,
292
+ cosine_attention=cosine_attention,
293
+ )
294
+
295
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
296
+ self.norm2 = norm_layer(dim)
297
+ mlp_hidden_dim = int(dim * mlp_ratio)
298
+ self.mlp = Mlp(
299
+ in_features=dim,
300
+ hidden_features=mlp_hidden_dim,
301
+ act_layer=act_layer,
302
+ drop=mlp_drop,
303
+ )
304
+ self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
305
+
306
+ def forward(self, x, padding_mask=None, alibi_bias=None):
307
+ if self.layer_norm_first:
308
+ x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
309
+ r = x = self.mlp(self.norm2(x))
310
+ t = x
311
+ x = r + self.drop_path(self.post_mlp_dropout(x))
312
+ if not self.ffn_targets:
313
+ t = x
314
+ else:
315
+ x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
316
+ r = x = self.norm1(x)
317
+ x = self.mlp(x)
318
+ t = x
319
+ x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
320
+ if not self.ffn_targets:
321
+ t = x
322
+
323
+ return x, t
324
+
325
+
326
+ class AltAttention(nn.Module):
327
+ def __init__(
328
+ self,
329
+ dim,
330
+ num_heads=8,
331
+ qkv_bias=False,
332
+ qk_scale=None,
333
+ attn_drop=0.0,
334
+ proj_drop=0.0,
335
+ cosine_attention=False,
336
+ ):
337
+ super().__init__()
338
+ self.num_heads = num_heads
339
+ head_dim = dim // num_heads
340
+ self.scale = qk_scale or head_dim ** -0.5
341
+
342
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
343
+ self.attn_drop = nn.Dropout(attn_drop)
344
+ self.proj = nn.Linear(dim, dim)
345
+ self.proj_drop = nn.Dropout(proj_drop)
346
+
347
+ self.cosine_attention = cosine_attention
348
+
349
+ if cosine_attention:
350
+ self.logit_scale = nn.Parameter(
351
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
352
+ )
353
+
354
+ def forward(self, x, padding_mask=None, alibi_bias=None):
355
+ B, N, C = x.shape
356
+ qkv = (
357
+ self.qkv(x)
358
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
359
+ .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
360
+ )
361
+ q, k, v = (
362
+ qkv[0],
363
+ qkv[1],
364
+ qkv[2],
365
+ ) # make torchscript happy (cannot use tensor as tuple)
366
+
367
+ dtype = q.dtype
368
+
369
+ if self.cosine_attention:
370
+ # cosine attention
371
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
372
+ logit_scale = torch.clamp(
373
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
374
+ ).exp()
375
+ attn = attn * logit_scale
376
+ else:
377
+ q = q * self.scale
378
+ attn = q @ k.transpose(-2, -1)
379
+
380
+ if alibi_bias is not None:
381
+ attn = attn.type_as(alibi_bias)
382
+ attn[:, : alibi_bias.size(1)] += alibi_bias
383
+
384
+ if padding_mask is not None and padding_mask.any():
385
+ attn = attn.masked_fill(
386
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
387
+ float("-inf"),
388
+ )
389
+
390
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
391
+ attn = self.attn_drop(attn)
392
+ x = (attn @ v).transpose(1, 2) #
393
+ x = x.reshape(B, N, C)
394
+ x = self.proj(x)
395
+ x = self.proj_drop(x)
396
+ return x
397
+
398
+
399
+ class EncDecAttention(nn.Module):
400
+ def __init__(
401
+ self,
402
+ q_dim,
403
+ kv_dim,
404
+ num_heads=8,
405
+ qkv_bias=False,
406
+ qk_scale=None,
407
+ attn_drop=0.0,
408
+ proj_drop=0.0,
409
+ cosine_attention=False,
410
+ ):
411
+ super().__init__()
412
+ self.num_heads = num_heads
413
+ head_dim = q_dim // num_heads
414
+ self.scale = qk_scale or head_dim ** -0.5
415
+
416
+ self.q_proj = nn.Linear(q_dim, q_dim, bias=qkv_bias)
417
+ self.kv_proj = nn.Linear(kv_dim, 2 * q_dim, bias=qkv_bias)
418
+ self.attn_drop = nn.Dropout(attn_drop)
419
+ self.proj = nn.Linear(q_dim, q_dim)
420
+ self.proj_drop = nn.Dropout(proj_drop)
421
+
422
+ self.cosine_attention = cosine_attention
423
+
424
+ if cosine_attention:
425
+ self.logit_scale = nn.Parameter(
426
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
427
+ )
428
+
429
+ def forward(self, q, kv, padding_mask=None, alibi_bias=None):
430
+ B, N, C = q.shape
431
+
432
+ q = (
433
+ self.q_proj(q)
434
+ .reshape(B, N, self.num_heads, C // self.num_heads)
435
+ .permute(0, 2, 1, 3)
436
+ ) # B x H x L x D
437
+ kv = (
438
+ self.kv_proj(kv)
439
+ .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
440
+ .permute(2, 0, 3, 1, 4)
441
+ ) # kv x B x H x L x D
442
+ k, v = (
443
+ kv[0],
444
+ kv[1],
445
+ ) # make torchscript happy (cannot use tensor as tuple)
446
+
447
+ dtype = q.dtype
448
+
449
+ if self.cosine_attention:
450
+ # cosine attention
451
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
452
+ logit_scale = torch.clamp(
453
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
454
+ ).exp()
455
+ attn = attn * logit_scale
456
+ else:
457
+ q = q * self.scale
458
+ attn = q @ k.transpose(-2, -1)
459
+
460
+ if alibi_bias is not None:
461
+ attn = attn.type_as(alibi_bias)
462
+ attn[:, : alibi_bias.size(1)] += alibi_bias
463
+
464
+ if padding_mask is not None and padding_mask.any():
465
+ attn = attn.masked_fill(
466
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
467
+ float("-inf"),
468
+ )
469
+
470
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
471
+ attn = self.attn_drop(attn)
472
+ x = (attn @ v).transpose(1, 2) #
473
+ x = x.reshape(B, N, C)
474
+ x = self.proj(x)
475
+ x = self.proj_drop(x)
476
+ return x
477
+
478
+
479
+ class EncDecBlock(nn.Module):
480
+ def __init__(
481
+ self,
482
+ q_dim,
483
+ kv_dim,
484
+ num_heads,
485
+ mlp_ratio=4.0,
486
+ qkv_bias=False,
487
+ qk_scale=None,
488
+ drop=0.0,
489
+ attn_drop=0.0,
490
+ mlp_drop=0.0,
491
+ post_mlp_drop=0.0,
492
+ drop_path=0.0,
493
+ act_layer=nn.GELU,
494
+ norm_layer=nn.LayerNorm,
495
+ layer_norm_first=True,
496
+ cosine_attention=False,
497
+ first_residual=True,
498
+ ):
499
+ super().__init__()
500
+
501
+ self.layer_norm_first = layer_norm_first
502
+
503
+ from timm.models.vision_transformer import DropPath, Mlp
504
+
505
+ self.norm1 = norm_layer(q_dim)
506
+ self.attn = EncDecAttention(
507
+ q_dim,
508
+ kv_dim,
509
+ num_heads=num_heads,
510
+ qkv_bias=qkv_bias,
511
+ qk_scale=qk_scale,
512
+ attn_drop=attn_drop,
513
+ proj_drop=drop,
514
+ cosine_attention=cosine_attention,
515
+ )
516
+
517
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
518
+ self.norm2 = norm_layer(q_dim)
519
+ mlp_hidden_dim = int(q_dim * mlp_ratio)
520
+ self.mlp = Mlp(
521
+ in_features=q_dim,
522
+ hidden_features=mlp_hidden_dim,
523
+ act_layer=act_layer,
524
+ drop=mlp_drop,
525
+ )
526
+ self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
527
+ self.first_residual = first_residual
528
+
529
+ def forward(self, q, kv, padding_mask=None, alibi_bias=None):
530
+ r = q if self.first_residual else 0
531
+ if self.layer_norm_first:
532
+ x = r + self.drop_path(
533
+ self.attn(self.norm1(q), kv, padding_mask, alibi_bias)
534
+ )
535
+ r = x = self.mlp(self.norm2(x))
536
+ x = r + self.drop_path(self.post_mlp_dropout(x))
537
+ else:
538
+ x = r + self.drop_path(self.attn(q, kv, padding_mask, alibi_bias))
539
+ r = x = self.norm1(x)
540
+ x = self.mlp(x)
541
+ x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
542
+
543
+ return x
544
+
545
+
546
+ class EncDecTransformerDecoder(nn.Module):
547
+ def __init__(self, cfg: D2vDecoderConfig, input_dim):
548
+ super().__init__()
549
+
550
+ self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)
551
+
552
+ self.blocks = nn.Sequential(
553
+ *[
554
+ EncDecBlock(
555
+ q_dim=cfg.decoder_dim,
556
+ kv_dim=input_dim,
557
+ num_heads=8,
558
+ mlp_ratio=4.0,
559
+ qkv_bias=True,
560
+ qk_scale=None,
561
+ drop=0.0,
562
+ attn_drop=0.0,
563
+ mlp_drop=0.0,
564
+ post_mlp_drop=0.0,
565
+ drop_path=0.0,
566
+ act_layer=nn.GELU,
567
+ norm_layer=nn.LayerNorm,
568
+ layer_norm_first=False,
569
+ cosine_attention=False,
570
+ first_residual=i > 0,
571
+ )
572
+ for i in range(cfg.decoder_layers)
573
+ ]
574
+ )
575
+
576
+ self.proj = nn.Linear(cfg.decoder_dim, input_dim)
577
+
578
+ def reset_parameters(self):
579
+ from fairseq.modules.transformer_sentence_encoder import init_bert_params
580
+
581
+ self.apply(init_bert_params)
582
+
583
+ def forward(self, x, kv):
584
+ x = self.input_proj(x)
585
+ for i, layer in enumerate(self.blocks):
586
+ x = layer(x, kv)
587
+
588
+ x = self.proj(x)
589
+ return x
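
Below is a minimal standalone sketch, separate from the committed files, of the cosine-attention path shared by AltAttention and EncDecAttention above: attention logits are cosine similarities between q and k, multiplied by a learnable per-head temperature whose value is clamped to at most 1/0.01 before the softmax.

import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 16, 32
q, k = torch.randn(B, H, T, D), torch.randn(B, H, T, D)
logit_scale = torch.log(10 * torch.ones(H, 1, 1))  # an nn.Parameter in the model

# cosine similarity instead of a scaled dot product
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
scale = torch.clamp(logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))).exp()
attn = (attn * scale).softmax(dim=-1)
assert attn.shape == (B, H, T, T)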
fairseq/examples/data2vec/models/modalities/text.py ADDED
@@ -0,0 +1,161 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ from typing import Callable, Dict, Optional
10
+
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from fairseq.modules import PositionalEmbedding, FairseqDropout, LayerNorm
15
+ from fairseq.tasks import FairseqTask
16
+ from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
17
+ from .modules import BlockEncoder, Decoder1d
18
+ from examples.data2vec.data.modality import Modality
19
+
20
+
21
+ @dataclass
22
+ class D2vTextConfig(D2vModalityConfig):
23
+ type: Modality = Modality.TEXT
24
+ max_source_positions: int = 512
25
+ learned_pos: bool = True
26
+ dropout: float = 0.1 # used for both local_encoder and contextualized encoder. tied with global transformer in data2vec_text
27
+
28
+ no_scale_embedding: bool = True
29
+ layernorm_embedding: bool = True
30
+ no_token_positional_embeddings: bool = False
31
+
32
+
33
+ class TextEncoder(ModalitySpecificEncoder):
34
+
35
+ modality_cfg: D2vTextConfig
36
+
37
+ def __init__(
38
+ self,
39
+ modality_cfg: D2vTextConfig,
40
+ embed_dim: int,
41
+ make_block: Callable[[float], nn.ModuleList],
42
+ norm_layer: Callable[[int], nn.LayerNorm],
43
+ layer_norm_first: bool,
44
+ alibi_biases: Dict,
45
+ task: Optional[FairseqTask],
46
+ ):
47
+ self.pad_idx = task.source_dictionary.pad()
48
+ self.vocab_size = len(task.source_dictionary)
49
+
50
+ local_encoder = TextLocalEncoder(
51
+ vocab_size=self.vocab_size,
52
+ embed_dim=embed_dim,
53
+ max_source_positions=modality_cfg.max_source_positions,
54
+ pad_idx=self.pad_idx,
55
+ no_scale_embedding=modality_cfg.no_scale_embedding,
56
+ layernorm_embedding=modality_cfg.layernorm_embedding,
57
+ dropout=modality_cfg.dropout,
58
+ no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
59
+ learned_pos=modality_cfg.learned_pos,
60
+ )
61
+ dpr = np.linspace(
62
+ modality_cfg.start_drop_path_rate,
63
+ modality_cfg.end_drop_path_rate,
64
+ modality_cfg.prenet_depth,
65
+ )
66
+ context_encoder = BlockEncoder(
67
+ nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
68
+ norm_layer(embed_dim)
69
+ if not layer_norm_first and modality_cfg.prenet_depth > 0
70
+ else None,
71
+ layer_norm_first,
72
+ modality_cfg.prenet_layerdrop,
73
+ modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
74
+ )
75
+ decoder = (
76
+ Decoder1d(modality_cfg.decoder, embed_dim)
77
+ if modality_cfg.decoder is not None
78
+ else None
79
+ )
80
+
81
+ alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
82
+
83
+ super().__init__(
84
+ modality_cfg=modality_cfg,
85
+ embed_dim=embed_dim,
86
+ local_encoder=local_encoder,
87
+ project_features=nn.Identity(),
88
+ fixed_positional_encoder=None,
89
+ relative_positional_encoder=None,
90
+ context_encoder=context_encoder,
91
+ decoder=decoder,
92
+ get_alibi_bias=alibi_bias_fn,
93
+ )
94
+
95
+ def reset_parameters(self):
96
+ super().reset_parameters()
97
+
98
+ def convert_padding_mask(self, x, padding_mask):
99
+ if padding_mask is None or padding_mask.size(1) == x.size(1):
100
+ return padding_mask
101
+
102
+ diff = self.downsample - padding_mask.size(1) % self.downsample
103
+ if 0 < diff < self.downsample:
104
+ padding_mask = F.pad(padding_mask, (0, diff), value=True)
105
+
106
+ padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
107
+ padding_mask = padding_mask.all(-1)
108
+ if padding_mask.size(1) > x.size(1):
109
+ padding_mask = padding_mask[:, : x.size(1)]
110
+
111
+ assert x.size(1) == padding_mask.size(
112
+ 1
113
+ ), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"
114
+
115
+ return padding_mask
116
+
117
+
118
+ class TextLocalEncoder(nn.Module):
119
+ def __init__(
120
+ self,
121
+ vocab_size,
122
+ embed_dim,
123
+ max_source_positions,
124
+ pad_idx,
125
+ no_scale_embedding,
126
+ layernorm_embedding,
127
+ dropout,
128
+ no_token_positional_embeddings,
129
+ learned_pos,
130
+ ):
131
+ super().__init__()
132
+ self.pad_idx = pad_idx
133
+ self.dropout_module = FairseqDropout(dropout)
134
+
135
+ self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
136
+ self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
137
+ self.embed_positions = (
138
+ PositionalEmbedding(
139
+ max_source_positions,
140
+ embed_dim,
141
+ pad_idx,
142
+ learned=learned_pos,
143
+ )
144
+ if not no_token_positional_embeddings
145
+ else None
146
+ )
147
+ self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
148
+
149
+ self.layernorm_embedding = None
150
+ if layernorm_embedding:
151
+ self.layernorm_embedding = LayerNorm(embed_dim)
152
+
153
+ def forward(self, src_tokens):
154
+ x = self.embed_scale * self.embed_tokens(src_tokens)
155
+ if self.embed_positions is not None:
156
+ x = x + self.embed_positions(src_tokens)
157
+
158
+ if self.layernorm_embedding is not None:
159
+ x = self.layernorm_embedding(x)
160
+ x = self.dropout_module(x)
161
+ return x
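
Below is a minimal standalone sketch, separate from the committed files, of the padding-mask downsampling in TextEncoder.convert_padding_mask above (the downsample factor comes from the shared base encoder; a value of 4 is assumed here purely for illustration): the (B, T) mask is right-padded with True to a multiple of the factor, and a downsampled position counts as padding only if every original position it covers was padding.

import torch
import torch.nn.functional as F

downsample = 4  # stands in for self.downsample of the base encoder
padding_mask = torch.zeros(2, 10, dtype=torch.bool)
padding_mask[:, 7:] = True  # last three positions are padding

diff = downsample - padding_mask.size(1) % downsample
if 0 < diff < downsample:
    padding_mask = F.pad(padding_mask, (0, diff), value=True)

down = padding_mask.view(padding_mask.size(0), -1, downsample).all(-1)
assert down.shape == (2, 3)
assert down[0].tolist() == [False, False, True]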
fairseq/examples/data2vec/models/utils.py ADDED
@@ -0,0 +1,55 @@
1
+ import math
2
+ import torch
3
+
4
+ def get_alibi(
5
+ max_positions: int,
6
+ attention_heads: int,
7
+ ):
8
+ def get_slopes(n):
9
+ def get_slopes_power_of_2(n):
10
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
11
+ ratio = start
12
+ return [start * ratio ** i for i in range(n)]
13
+
14
+ # In the paper, we only train models that have 2^a heads for some
15
+ # a. This function has some good properties that only occur when
16
+ # the input is a power of 2. To maintain that even when the number
17
+ # of heads is not a power of 2, we use this workaround.
18
+ if math.log2(n).is_integer():
19
+ return get_slopes_power_of_2(n)
20
+ else:
21
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
22
+ return (
23
+ get_slopes_power_of_2(closest_power_of_2)
24
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
25
+ )
26
+
27
+ maxpos = max_positions
28
+ attn_heads = attention_heads
29
+ slopes = torch.Tensor(get_slopes(attn_heads))
30
+ # prepare alibi position linear bias. Note that wav2vec2 is a non-
31
+ # autoregressive model so we want a symmetric mask with 0 on the
32
+ # diagonal and otherwise linearly decreasing values
33
+ pos_bias = (
34
+ torch.abs(
35
+ torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
36
+ )
37
+ * -1
38
+ )
39
+ alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
40
+ attn_heads, -1, -1
41
+ )
42
+ return alibi_bias
43
+
44
+ def masked_alibi(alibi_bias, mask_indices, orig_B, orig_T):
45
+ alibi_bias = alibi_bias.view(orig_B, -1, orig_T, orig_T)
46
+ H = alibi_bias.size(1)
47
+ alibi_mask = mask_indices.unsqueeze(1)
48
+ alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-1))
49
+ alibi_bias = alibi_bias.view(orig_B, H, -1, orig_T)
50
+ M = alibi_bias.size(-2)
51
+ alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-2))
52
+ alibi_bias = alibi_bias.view(-1, M, M)
53
+ return alibi_bias
54
+
55
+
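Below is a minimal standalone sketch, separate from the committed files, of masked_alibi above: two masked_select calls restrict the rows and columns of each (T, T) bias matrix to the masked positions, yielding a (B*H, M, M) bias. It assumes the same number of masked positions per sample, which compute_mask_indices with require_same_masks provides.

import torch

B, H, T, M = 2, 4, 6, 3
alibi_bias = torch.randn(B * H, T, T)
mask_indices = torch.zeros(B, T, dtype=torch.bool)
mask_indices[:, :M] = True  # same number of masked positions per sample

ab = alibi_bias.view(B, -1, T, T)
sel = mask_indices.unsqueeze(1)                             # B x 1 x T
ab = ab.masked_select(sel.unsqueeze(-1)).view(B, H, -1, T)  # keep masked rows
ab = ab.masked_select(sel.unsqueeze(-2)).view(-1, M, M)     # keep masked columns
assert ab.shape == (B * H, M, M)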
fairseq/examples/data2vec/scripts/convert_audioset_labels.py ADDED
@@ -0,0 +1,63 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import os
9
+
10
+
11
+ def get_parser():
12
+ parser = argparse.ArgumentParser(description="convert audioset labels")
13
+ # fmt: off
14
+ parser.add_argument('in_file', help='audioset csv file to convert')
15
+ parser.add_argument('--manifest', required=True, metavar='PATH', help='wav2vec-like manifest')
16
+ parser.add_argument('--descriptors', required=True, metavar='PATH', help='path to label descriptor file')
17
+ parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted labels')
18
+ # fmt: on
19
+
20
+ return parser
21
+
22
+
23
+ def main():
24
+ parser = get_parser()
25
+ args = parser.parse_args()
26
+
27
+ label_descriptors = {}
28
+ with open(args.descriptors, "r") as ldf:
29
+ next(ldf)
30
+ for line in ldf:
31
+ if line.strip() == "":
32
+ continue
33
+
34
+ items = line.split(",")
35
+ assert len(items) > 2, line
36
+ idx = items[0]
37
+ lbl = items[1]
38
+ assert lbl not in label_descriptors, lbl
39
+ label_descriptors[lbl] = idx
40
+
41
+ labels = {}
42
+ with open(args.in_file, "r") as ifd:
43
+ for line in ifd:
44
+ if line.lstrip().startswith("#"):
45
+ continue
46
+ items = line.rstrip().split(",")
47
+ id = items[0].strip()
48
+ start = items[1].strip()
49
+ end = items[2].strip()
50
+ lbls = [label_descriptors[it.strip(' "')] for it in items[3:]]
51
+ labels[id] = [start, end, ",".join(lbls)]
52
+
53
+ with open(args.manifest, "r") as mf, open(args.output, "w") as of:
54
+ next(mf)
55
+ for line in mf:
56
+ path, _ = line.split("\t")
57
+ id = os.path.splitext(os.path.basename(path))[0]
58
+ lbl = labels[id]
59
+ print("\t".join(lbl), file=of)
60
+
61
+
62
+ if __name__ == "__main__":
63
+ main()
fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh ADDED
@@ -0,0 +1,18 @@
1
+ #!/bin/bash
2
+
3
+ set -eu
4
+
5
+ job_id="$1"
6
+ task_id="$2"
7
+ dir="$3"
8
+
9
+ echo "job_id: $job_id, task_id: $task_id, dir: $dir"
10
+
11
+ mkdir -p "$dir/log"
12
+ sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
13
+ sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
14
+ sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
15
+ sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
16
+
17
+ sbatch $sbatch_args examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh $dir
18
+
fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh ADDED
@@ -0,0 +1,16 @@
1
+ #!/bin/bash
2
+
3
+ set -eu
4
+
5
+ dir="$1"
6
+
7
+ echo "dir: $dir"
8
+
9
+ mkdir -p "$dir/log"
10
+ sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
11
+ sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
12
+ sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out"
13
+ sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
14
+
15
+ sbatch $sbatch_args examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh $dir
16
+
fairseq/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh ADDED
@@ -0,0 +1,28 @@
1
+ #!/usr/bin/env zsh
2
+
3
+ dir="$1"
4
+ cp="$dir/checkpoints/checkpoint_last.pt"
5
+
6
+ echo "dir: $dir"
7
+
8
+ declare -A tasks
9
+ tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
10
+ tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
11
+ tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
12
+ tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
13
+ tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
14
+ tasks[mnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MNLI-bin"
15
+ tasks[qqp]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QQP-bin"
16
+ tasks[sts_b]="/fsx-wav2vec/abaevski/data/nlp/GLUE/STS-B-bin"
17
+
18
+ lrs=(5e-6 8e-6 1e-5 2e-5)
19
+
20
+ for task data_path in ${(kv)tasks}; do
21
+ for lr in $lrs; do
22
+ echo $lr $task
23
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
24
+ python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/multi/text_finetuning \
25
+ --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
26
+ model.model_path="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" +model=text_wrap
27
+ done
28
+ done
fairseq/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh ADDED
@@ -0,0 +1,17 @@
1
+ #!/bin/bash
2
+
3
+ set -eu
4
+
5
+ job_id="$1"
6
+ task_id="$2"
7
+ dir="$3"
8
+
9
+ echo "job_id: $job_id, task_id: $task_id, dir: $dir"
10
+
11
+ mkdir -p "$dir/log"
12
+ sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
13
+ sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
14
+ sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/ft_%A.out"
15
+ sbatch_args="$sbatch_args -e $dir/log/ft_%A.err"
16
+
17
+ sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_char_fair_local_lr.sh $dir
fairseq/examples/data2vec/scripts/text/finetune_all_fair.sh ADDED
@@ -0,0 +1,21 @@
1
+ #!/usr/bin/env zsh
2
+
3
+ job_id=$1
4
+ task_id=$2
5
+ dir="$3"
6
+ cp="$dir/$task_id/checkpoints/checkpoint_last.pt"
7
+
8
+ echo "job_id: $job_id, task_id: $task_id, dir: $dir"
9
+
10
+ declare -A tasks
11
+ tasks[cola]="/private/home/jgu/data/GLUE/CoLA-bin"
12
+ tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin"
13
+ tasks[mrpc]="/private/home/jgu/data/GLUE/MRPC-bin"
14
+ tasks[rte]="/private/home/jgu/data/GLUE/RTE-bin"
15
+ tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin"
16
+
17
+ for task data_path in ${(kv)tasks}; do
18
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
19
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
20
+ checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune/$task" &
21
+ done
fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws.sh ADDED
@@ -0,0 +1,21 @@
1
+ #!/usr/bin/env zsh
2
+
3
+ job_id=$1
4
+ task_id=$2
5
+ dir="$3"
6
+ cp="$dir/checkpoints/checkpoint_last.pt"
7
+
8
+ echo "job_id: $job_id, task_id: $task_id, dir: $dir"
9
+
10
+ declare -A tasks
11
+ tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
12
+ tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
13
+ tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
14
+ tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
15
+ tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
16
+
17
+ for task data_path in ${(kv)tasks}; do
18
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
19
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
20
+ checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune/$task" &
21
+ done
fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh ADDED
@@ -0,0 +1,17 @@
1
+ #!/bin/bash
2
+
3
+ set -eu
4
+
5
+ job_id="$1"
6
+ task_id="$2"
7
+ dir="$3"
8
+
9
+ echo "job_id: $job_id, task_id: $task_id, dir: $dir"
10
+
11
+ mkdir -p "$dir/log"
12
+ sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
13
+ sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
14
+ sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
15
+ sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"
16
+
17
+ sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh $dir
fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh ADDED
@@ -0,0 +1,23 @@
1
+ #!/usr/bin/env zsh
2
+
3
+ job_id=$1
4
+ task_id=$2
5
+ dir="$3"
6
+ cp="$dir/checkpoints/checkpoint_last.pt"
7
+
8
+ echo "job_id: $job_id, task_id: $task_id, dir: $dir"
9
+
10
+ declare -A tasks
11
+ tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
12
+ tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
13
+ tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
14
+ tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
15
+ tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
16
+
17
+ for task data_path in ${(kv)tasks}; do
18
+ for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do
19
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
20
+ --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
21
+ checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" &
22
+ done
23
+ done
fairseq/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh ADDED
@@ -0,0 +1,25 @@
 
1
+ #!/usr/bin/env zsh
2
+
3
+ dir="$1"
4
+ cp="$dir/checkpoints/checkpoint_last.pt"
5
+
6
+ echo "dir: $dir"
7
+
8
+ declare -A tasks
9
+ tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
10
+ tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
11
+ tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
12
+ tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
13
+ tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
14
+
15
+ lrs=(5e-6 8e-6 1e-5 2e-5)
16
+
17
+ for task data_path in ${(kv)tasks}; do
18
+ for lr in $lrs; do
19
+ echo $lr $task
20
+ PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
21
+ python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
22
+ --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
23
+ checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]"
24
+ done
25
+ done