Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml +36 -0
- fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml +38 -0
- fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml +38 -0
- fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml +36 -0
- fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml +36 -0
- fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml +36 -0
- fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml +36 -0
- fairseq/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml +64 -0
- fairseq/examples/data2vec/config/vision/pretraining/run_config/local.yaml +15 -0
- fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml +37 -0
- fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml +37 -0
- fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml +36 -0
- fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml +36 -0
- fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml +36 -0
- fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml +36 -0
- fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml +36 -0
- fairseq/examples/data2vec/data/__init__.py +17 -0
- fairseq/examples/data2vec/data/add_class_target_dataset.py +63 -0
- fairseq/examples/data2vec/data/image_dataset.py +127 -0
- fairseq/examples/data2vec/data/mae_finetuning_image_dataset.py +135 -0
- fairseq/examples/data2vec/data/mae_image_dataset.py +418 -0
- fairseq/examples/data2vec/data/modality.py +14 -0
- fairseq/examples/data2vec/data/path_dataset.py +64 -0
- fairseq/examples/data2vec/models/__init__.py +0 -0
- fairseq/examples/data2vec/models/audio_classification.py +614 -0
- fairseq/examples/data2vec/models/data2vec2.py +813 -0
- fairseq/examples/data2vec/models/data2vec_audio.py +537 -0
- fairseq/examples/data2vec/models/data2vec_image_classification.py +143 -0
- fairseq/examples/data2vec/models/data2vec_text.py +517 -0
- fairseq/examples/data2vec/models/data2vec_text_classification.py +141 -0
- fairseq/examples/data2vec/models/data2vec_vision.py +727 -0
- fairseq/examples/data2vec/models/mae.py +829 -0
- fairseq/examples/data2vec/models/mae_image_classification.py +386 -0
- fairseq/examples/data2vec/models/modalities/__init__.py +0 -0
- fairseq/examples/data2vec/models/modalities/audio.py +192 -0
- fairseq/examples/data2vec/models/modalities/base.py +684 -0
- fairseq/examples/data2vec/models/modalities/images.py +256 -0
- fairseq/examples/data2vec/models/modalities/modules.py +589 -0
- fairseq/examples/data2vec/models/modalities/text.py +161 -0
- fairseq/examples/data2vec/models/utils.py +55 -0
- fairseq/examples/data2vec/scripts/convert_audioset_labels.py +63 -0
- fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh +18 -0
- fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh +16 -0
- fairseq/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh +28 -0
- fairseq/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh +17 -0
- fairseq/examples/data2vec/scripts/text/finetune_all_fair.sh +21 -0
- fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws.sh +21 -0
- fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh +17 -0
- fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh +23 -0
- fairseq/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh +25 -0
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_1_aws.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 80
+    gpus_per_node: 8
+    tasks_per_node: 1
+    mem_gb: 0
+    nodes: 1
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: wav2vec,learnlab,learnfair
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2.yaml
ADDED
@@ -0,0 +1,38 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+          - task.local_cache_path
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 450
+    nodes: 2
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: devlab,learnlab,learnfair,scavenge
+    constraint: volta32gb,ib4
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_2_aws.yaml
ADDED
@@ -0,0 +1,38 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+          - task.local_cache_path
+          - model.model_path
+  sweep:
+    dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 0
+    nodes: 2
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: wav2vec,learnlab,learnfair
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_3.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 80
+    gpus_per_node: 8
+    tasks_per_node: 1
+    mem_gb: 450
+    nodes: 3
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: devlab,learnlab,learnfair,scavenge
+    constraint: volta32gb,ib4
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_4_aws.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 0
+    nodes: 4
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: wav2vec,learnlab,learnfair
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_6_aws.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 0
+    nodes: 6
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: wav2vec,learnlab,learnfair
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/finetuning/run_config/slurm_8_aws.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 0
+    nodes: 8
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: wav2vec,learnlab,learnfair
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/base_imagenet_d2v1.yaml
ADDED
@@ -0,0 +1,64 @@
+# @package _group_
+
+common:
+  fp16: true
+  log_format: json
+  log_interval: 200
+  tensorboard_logdir: tb
+
+checkpoint:
+  save_interval: 5
+  save_interval_updates: 25000
+  keep_interval_updates: 1
+  no_epoch_checkpoints: true
+
+task:
+  _name: image_pretraining
+  data: /datasets01/imagenet_full_size/061417
+
+dataset:
+  num_workers: 6
+  batch_size: 128
+  skip_invalid_size_inputs_valid_test: true
+  required_batch_size_multiple: 2
+  disable_validation: true
+
+distributed_training:
+  distributed_world_size: 16
+  ddp_backend: legacy_ddp
+
+criterion:
+  _name: model
+  log_keys:
+    - ema_decay
+    - target_var
+    - pred_var
+
+optimization:
+  max_update: 375300 #300*1251
+  lr: [0.0005]
+  clip_norm: 3.0
+
+optimizer:
+  _name: adam
+  adam_betas: (0.9,0.999)
+  adam_eps: 1e-08
+  weight_decay: 0.05
+
+lr_scheduler:
+  _name: cosine
+  warmup_updates: 12510 # it should be 10 epochs
+
+model:
+  _name: data2vec_vision
+
+  attention_dropout: 0.05
+
+  ema_decay: 0.999
+  ema_end_decay: 0.9998
+  layer_norm_targets: True
+  average_top_k_layers: 6
+
+  loss_beta: 2.0
+
+  drop_path: 0.25
fairseq/examples/data2vec/config/vision/pretraining/run_config/local.yaml
ADDED
@@ -0,0 +1,15 @@
+# @package _global_
+hydra:
+  sweep:
+    dir: ${env:PWD}/tmp_dbg/${now:%H-%M-%S}
+
+distributed_training:
+  distributed_world_size: 1
+  nprocs_per_node: 1
+  distributed_port: -1
+
+common:
+  log_interval: 1
+
+dataset:
+  num_workers: 0
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_1.yaml
ADDED
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 80
+    gpus_per_node: 8
+    tasks_per_node: 1
+    mem_gb: 450
+    nodes: 1
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: devlab,learnlab,learnfair,scavenge
+    constraint: volta32gb,ib4
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_2_aws.yaml
ADDED
@@ -0,0 +1,37 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+          - task.local_cache_path
+  sweep:
+    dir: /fsx-wav2vec/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 0
+    nodes: 2
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: wav2vec,learnlab,learnfair
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_3.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 80
+    gpus_per_node: 8
+    tasks_per_node: 1
+    mem_gb: 450
+    nodes: 3
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: devlab,learnlab,learnfair,scavenge
+    constraint: volta32gb,ib4
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 450
+    nodes: 4
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: devlab,learnlab,learnfair,scavenge
+    constraint: volta32gb,ib4
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_4_aws.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 0
+    nodes: 4
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: wav2vec,learnlab,learnfair
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_6_aws.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 0
+    nodes: 6
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: wav2vec,learnlab,learnfair
+    max_num_timeout: 30
fairseq/examples/data2vec/config/vision/pretraining/run_config/slurm_8_aws.yaml
ADDED
@@ -0,0 +1,36 @@
+# @package _global_
+
+hydra:
+  job:
+    config:
+      override_dirname:
+        kv_sep: ':'
+        item_sep: '/'
+        exclude_keys:
+          - run_config
+          - distributed_training.distributed_port
+          - distributed_training.distributed_world_size
+          - model.pretrained_model_path
+          - model.target_network_path
+          - next_script
+          - task.cache_in_scratch
+          - task.data
+          - checkpoint.save_interval_updates
+          - checkpoint.keep_interval_updates
+          - checkpoint.save_on_overflow
+          - common.log_interval
+          - common.user_dir
+  sweep:
+    dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
+    subdir: ''
+  launcher:
+    submitit_folder: ${hydra.sweep.dir}
+    timeout_min: 4320
+    cpus_per_task: 10
+    gpus_per_node: 8
+    tasks_per_node: 8
+    mem_gb: 0
+    nodes: 8
+    name: ${env:PREFIX}_${hydra.job.config_name}
+    partition: wav2vec,learnlab,learnfair
+    max_num_timeout: 30
fairseq/examples/data2vec/data/__init__.py
ADDED
@@ -0,0 +1,17 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .image_dataset import ImageDataset
+from .path_dataset import PathDataset
+from .mae_image_dataset import MaeImageDataset
+from .mae_finetuning_image_dataset import MaeFinetuningImageDataset
+
+
+__all__ = [
+    "ImageDataset",
+    "MaeImageDataset",
+    "MaeFinetuningImageDataset",
+    "PathDataset",
+]
fairseq/examples/data2vec/data/add_class_target_dataset.py
ADDED
@@ -0,0 +1,63 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from fairseq.data import BaseWrapperDataset, data_utils
+
+
+class AddClassTargetDataset(BaseWrapperDataset):
+    def __init__(
+        self,
+        dataset,
+        labels,
+        multi_class,
+        num_classes=None,
+        label_indices=None,
+        add_to_input=True,
+    ):
+        super().__init__(dataset)
+
+        self.label_indices = label_indices
+        self.labels = labels
+        self.multi_class = multi_class
+        self.add_to_input = add_to_input
+        if num_classes is None and multi_class:
+            assert self.label_indices is not None
+            num_classes = len(self.label_indices)
+
+        self.num_classes = num_classes
+
+    def __getitem__(self, index):
+        item = self.dataset[index]
+        item_labels = self.labels[index]
+        if self.multi_class:
+            item["label"] = torch.zeros(self.num_classes)
+            for il in item_labels:
+                if self.label_indices is not None:
+                    il = self.label_indices[il]
+                item["label"][il] = 1.0
+        else:
+            item["label"] = torch.tensor(
+                self.labels[index]
+                if self.label_indices is None
+                else self.label_indices[self.labels[index]]
+            )
+
+        return item
+
+    def collater(self, samples):
+        collated = self.dataset.collater(samples)
+        if len(collated) == 0:
+            return collated
+
+        indices = set(collated["id"].tolist())
+        target = [s["label"] for s in samples if s["id"] in indices]
+        collated["label"] = torch.stack(target, dim=0)
+
+        if self.add_to_input:
+            collated["net_input"]["label"] = collated["label"]
+
+        return collated
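
A minimal usage sketch of AddClassTargetDataset (not part of the diff; `raw_dataset` and the label list are hypothetical, and the wrapped dataset is assumed to be a FairseqDataset whose collater returns an "id" key):

# Hypothetical example; names other than AddClassTargetDataset are assumptions.
from examples.data2vec.data.add_class_target_dataset import AddClassTargetDataset

labels = [0, 2, 1]  # one integer class per example (single-label case)
wrapped = AddClassTargetDataset(
    raw_dataset,        # assumed existing FairseqDataset
    labels,
    multi_class=False,
    num_classes=3,
)
sample = wrapped[0]     # the wrapper adds sample["label"]
batch = wrapped.collater([wrapped[0], wrapped[1]])
# batch["label"] holds the stacked targets; with add_to_input=True it is also
# mirrored into batch["net_input"]["label"].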
fairseq/examples/data2vec/data/image_dataset.py
ADDED
@@ -0,0 +1,127 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+
+import numpy as np
+import os
+from typing import Optional, Callable, Set
+
+import torch
+
+from torchvision.datasets.vision import VisionDataset
+from torchvision.transforms import ToTensor
+
+from fairseq.data import FairseqDataset
+
+
+logger = logging.getLogger(__name__)
+
+
+class ImageDataset(FairseqDataset, VisionDataset):
+    def __init__(
+        self,
+        root: str,
+        extensions: Set[str],
+        load_classes: bool,
+        transform: Optional[Callable] = None,
+        shuffle=True,
+    ):
+        FairseqDataset.__init__(self)
+        VisionDataset.__init__(self, root=root, transform=transform)
+
+        self.shuffle = shuffle
+        self.tensor_transform = ToTensor()
+
+        self.classes = None
+        self.labels = None
+        if load_classes:
+            classes = [d.name for d in os.scandir(root) if d.is_dir()]
+            classes.sort()
+            self.classes = {cls_name: i for i, cls_name in enumerate(classes)}
+            logger.info(f"loaded {len(self.classes)} classes")
+            self.labels = []
+
+        def walk_path(root_path):
+            for root, _, fnames in sorted(os.walk(root_path, followlinks=True)):
+                for fname in sorted(fnames):
+                    fname_ext = os.path.splitext(fname)
+                    if fname_ext[-1].lower() not in extensions:
+                        continue
+
+                    path = os.path.join(root, fname)
+                    yield path
+
+        logger.info(f"finding images in {root}")
+        if self.classes is not None:
+            self.files = []
+            self.labels = []
+            for c, i in self.classes.items():
+                for f in walk_path(os.path.join(root, c)):
+                    self.files.append(f)
+                    self.labels.append(i)
+        else:
+            self.files = [f for f in walk_path(root)]
+
+        logger.info(f"loaded {len(self.files)} examples")
+
+    def __getitem__(self, index):
+        from PIL import Image
+
+        fpath = self.files[index]
+
+        with open(fpath, "rb") as f:
+            img = Image.open(f).convert("RGB")
+
+        if self.transform is None:
+            img = self.tensor_transform(img)
+        else:
+            img = self.transform(img)
+            assert torch.is_tensor(img)
+
+        res = {"id": index, "img": img}
+
+        if self.labels is not None:
+            res["label"] = self.labels[index]
+
+        return res
+
+    def __len__(self):
+        return len(self.files)
+
+    def collater(self, samples):
+        if len(samples) == 0:
+            return {}
+
+        collated_img = torch.stack([s["img"] for s in samples], dim=0)
+
+        res = {
+            "id": torch.LongTensor([s["id"] for s in samples]),
+            "net_input": {
+                "img": collated_img,
+            },
+        }
+
+        if "label" in samples[0]:
+            res["net_input"]["label"] = torch.LongTensor([s["label"] for s in samples])
+
+        return res
+
+    def num_tokens(self, index):
+        return 1
+
+    def size(self, index):
+        return 1
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.shuffle:
+            order = [np.random.permutation(len(self))]
+        else:
+            order = [np.arange(len(self))]
+
+        return order[0]
fairseq/examples/data2vec/data/mae_finetuning_image_dataset.py
ADDED
@@ -0,0 +1,135 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+
+import numpy as np
+import os
+
+import torch
+
+from torchvision import datasets, transforms
+
+from timm.data import create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+import PIL
+
+from fairseq.data import FairseqDataset
+from .mae_image_dataset import caching_loader
+
+
+logger = logging.getLogger(__name__)
+
+
+def build_transform(is_train, input_size, color_jitter, aa, reprob, remode, recount):
+    mean = IMAGENET_DEFAULT_MEAN
+    std = IMAGENET_DEFAULT_STD
+    # train transform
+    if is_train:
+        # this should always dispatch to transforms_imagenet_train
+        transform = create_transform(
+            input_size=input_size,
+            is_training=True,
+            color_jitter=color_jitter,
+            auto_augment=aa,
+            interpolation="bicubic",
+            re_prob=reprob,
+            re_mode=remode,
+            re_count=recount,
+            mean=mean,
+            std=std,
+        )
+        return transform
+
+    # eval transform
+    t = []
+    if input_size <= 224:
+        crop_pct = 224 / 256
+    else:
+        crop_pct = 1.0
+    size = int(input_size / crop_pct)
+    t.append(
+        transforms.Resize(
+            size, interpolation=PIL.Image.BICUBIC
+        ),  # to maintain same ratio w.r.t. 224 images
+    )
+    t.append(transforms.CenterCrop(input_size))
+
+    t.append(transforms.ToTensor())
+    t.append(transforms.Normalize(mean, std))
+    return transforms.Compose(t)
+
+
+class MaeFinetuningImageDataset(FairseqDataset):
+    def __init__(
+        self,
+        root: str,
+        split: str,
+        is_train: bool,
+        input_size,
+        color_jitter=None,
+        aa="rand-m9-mstd0.5-inc1",
+        reprob=0.25,
+        remode="pixel",
+        recount=1,
+        local_cache_path=None,
+        shuffle=True,
+    ):
+        FairseqDataset.__init__(self)
+
+        self.shuffle = shuffle
+
+        transform = build_transform(
+            is_train, input_size, color_jitter, aa, reprob, remode, recount
+        )
+
+        path = os.path.join(root, split)
+        loader = caching_loader(local_cache_path, datasets.folder.default_loader)
+
+        self.dataset = datasets.ImageFolder(path, loader=loader, transform=transform)
+
+        logger.info(f"loaded {len(self.dataset)} examples")
+
+    def __getitem__(self, index):
+        img, label = self.dataset[index]
+        return {"id": index, "img": img, "label": label}
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def collater(self, samples):
+        if len(samples) == 0:
+            return {}
+
+        collated_img = torch.stack([s["img"] for s in samples], dim=0)
+
+        res = {
+            "id": torch.LongTensor([s["id"] for s in samples]),
+            "net_input": {
+                "imgs": collated_img,
+            },
+        }
+
+        if "label" in samples[0]:
+            res["net_input"]["labels"] = torch.LongTensor([s["label"] for s in samples])
+
+        return res
+
+    def num_tokens(self, index):
+        return 1
+
+    def size(self, index):
+        return 1
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.shuffle:
+            order = [np.random.permutation(len(self))]
+        else:
+            order = [np.arange(len(self))]
+
+        return order[0]
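
A rough illustration (not part of the diff; the dataset path is a placeholder and an ImageNet-style `root/train` folder is assumed) of wiring MaeFinetuningImageDataset to a PyTorch DataLoader through its collater:

# Hypothetical example; paths and batch size are assumptions.
import torch
from examples.data2vec.data import MaeFinetuningImageDataset

ds = MaeFinetuningImageDataset(
    root="/path/to/imagenet",  # assumed dataset location
    split="train",
    is_train=True,
    input_size=224,
)
loader = torch.utils.data.DataLoader(ds, batch_size=8, collate_fn=ds.collater)
batch = next(iter(loader))
# batch["net_input"]["imgs"] has shape (8, 3, 224, 224);
# batch["net_input"]["labels"] has shape (8,).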
fairseq/examples/data2vec/data/mae_image_dataset.py
ADDED
@@ -0,0 +1,418 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from functools import partial
+import logging
+import math
+import random
+import time
+
+import numpy as np
+import os
+
+import torch
+
+from torchvision import datasets, transforms
+from .path_dataset import PathDataset
+
+from fairseq.data import FairseqDataset
+from fairseq.data.data_utils import compute_block_mask_1d, compute_block_mask_2d
+
+from shutil import copyfile
+
+logger = logging.getLogger(__name__)
+
+
+def load(path, loader, cache):
+    if hasattr(caching_loader, "cache_root"):
+        cache = caching_loader.cache_root
+
+    cached_path = cache + path
+
+    num_tries = 3
+    for curr_try in range(num_tries):
+        try:
+            if curr_try == 2:
+                return loader(path)
+            if not os.path.exists(cached_path) or curr_try > 0:
+                os.makedirs(os.path.dirname(cached_path), exist_ok=True)
+                copyfile(path, cached_path)
+                os.chmod(cached_path, 0o777)
+            return loader(cached_path)
+        except Exception as e:
+            logger.warning(str(e))
+            if "Errno 13" in str(e):
+                caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}"
+                logger.warning(f"setting cache root to {caching_loader.cache_root}")
+                cached_path = caching_loader.cache_root + path
+            if curr_try == (num_tries - 1):
+                raise
+            time.sleep(2)
+
+
+def caching_loader(cache_root: str, loader):
+    if cache_root is None:
+        return loader
+
+    if cache_root == "slurm_tmpdir":
+        cache_root = os.environ["SLURM_TMPDIR"]
+        assert len(cache_root) > 0
+
+    if not cache_root.endswith("/"):
+        cache_root += "/"
+
+    return partial(load, loader=loader, cache=cache_root)
+
+
+class RandomResizedCropAndInterpolationWithTwoPic:
+    """Crop the given PIL Image to random size and aspect ratio with random interpolation.
+
+    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+    is finally resized to given size.
+    This is popularly used to train the Inception networks.
+
+    Args:
+        size: expected output size of each edge
+        scale: range of size of the origin size cropped
+        ratio: range of aspect ratio of the origin aspect ratio cropped
+        interpolation: Default: PIL.Image.BILINEAR
+    """
+
+    def __init__(
+        self,
+        size,
+        second_size=None,
+        scale=(0.08, 1.0),
+        ratio=(3.0 / 4.0, 4.0 / 3.0),
+        interpolation="bilinear",
+        second_interpolation="lanczos",
+    ):
+        if isinstance(size, tuple):
+            self.size = size
+        else:
+            self.size = (size, size)
+        if second_size is not None:
+            if isinstance(second_size, tuple):
+                self.second_size = second_size
+            else:
+                self.second_size = (second_size, second_size)
+        else:
+            self.second_size = None
+        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+            logger.warning("range should be of kind (min, max)")
+
+        if interpolation == "random":
+            from PIL import Image
+
+            self.interpolation = (Image.BILINEAR, Image.BICUBIC)
+        else:
+            self.interpolation = self._pil_interp(interpolation)
+
+        self.second_interpolation = (
+            self._pil_interp(second_interpolation)
+            if second_interpolation is not None
+            else None
+        )
+        self.scale = scale
+        self.ratio = ratio
+
+    def _pil_interp(self, method):
+        from PIL import Image
+
+        if method == "bicubic":
+            return Image.BICUBIC
+        elif method == "lanczos":
+            return Image.LANCZOS
+        elif method == "hamming":
+            return Image.HAMMING
+        else:
+            # default bilinear, do we want to allow nearest?
+            return Image.BILINEAR
+
+    @staticmethod
+    def get_params(img, scale, ratio):
+        """Get parameters for ``crop`` for a random sized crop.
+
+        Args:
+            img (PIL Image): Image to be cropped.
+            scale (tuple): range of size of the origin size cropped
+            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+
+        Returns:
+            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+                sized crop.
+        """
+        area = img.size[0] * img.size[1]
+
+        for attempt in range(10):
+            target_area = random.uniform(*scale) * area
+            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+            aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if w <= img.size[0] and h <= img.size[1]:
+                i = random.randint(0, img.size[1] - h)
+                j = random.randint(0, img.size[0] - w)
+                return i, j, h, w
+
+        # Fallback to central crop
+        in_ratio = img.size[0] / img.size[1]
+        if in_ratio < min(ratio):
+            w = img.size[0]
+            h = int(round(w / min(ratio)))
+        elif in_ratio > max(ratio):
+            h = img.size[1]
+            w = int(round(h * max(ratio)))
+        else:  # whole image
+            w = img.size[0]
+            h = img.size[1]
+        i = (img.size[1] - h) // 2
+        j = (img.size[0] - w) // 2
+        return i, j, h, w
+
+    def __call__(self, img):
+        import torchvision.transforms.functional as F
+
+        """
+        Args:
+            img (PIL Image): Image to be cropped and resized.
+
+        Returns:
+            PIL Image: Randomly cropped and resized image.
+        """
+        i, j, h, w = self.get_params(img, self.scale, self.ratio)
+        if isinstance(self.interpolation, (tuple, list)):
+            interpolation = random.choice(self.interpolation)
+        else:
+            interpolation = self.interpolation
+        if self.second_size is None:
+            return F.resized_crop(img, i, j, h, w, self.size, interpolation)
+        else:
+            return F.resized_crop(
+                img, i, j, h, w, self.size, interpolation
+            ), F.resized_crop(
+                img, i, j, h, w, self.second_size, self.second_interpolation
+            )
+
+
+class MaeImageDataset(FairseqDataset):
+    def __init__(
+        self,
+        root: str,
+        split: str,
+        input_size,
+        local_cache_path=None,
+        shuffle=True,
+        key="imgs",
+        beit_transforms=False,
+        target_transform=False,
+        no_transform=False,
+        compute_mask=False,
+        patch_size: int = 16,
+        mask_prob: float = 0.75,
+        mask_prob_adjust: float = 0,
+        mask_length: int = 1,
+        inverse_mask: bool = False,
+        expand_adjacent: bool = False,
+        mask_dropout: float = 0,
+        non_overlapping: bool = False,
+        require_same_masks: bool = True,
+        clone_batch: int = 1,
+        dataset_type: str = "imagefolder",
+    ):
+        FairseqDataset.__init__(self)
+
+        self.shuffle = shuffle
+        self.key = key
+
+        loader = caching_loader(local_cache_path, datasets.folder.default_loader)
+
+        self.transform_source = None
+        self.transform_target = None
+
+        if target_transform:
+            self.transform_source = transforms.ColorJitter(0.4, 0.4, 0.4)
+            self.transform_target = transforms.ColorJitter(0.4, 0.4, 0.4)
+
+        if no_transform:
+            if input_size <= 224:
+                crop_pct = 224 / 256
+            else:
+                crop_pct = 1.0
+            size = int(input_size / crop_pct)
+
+            self.transform_train = transforms.Compose(
+                [
+                    transforms.Resize(size, interpolation=3),
+                    transforms.CenterCrop(input_size),
+                ]
+            )
+
+            self.transform_train = transforms.Resize((input_size, input_size))
+        elif beit_transforms:
+            beit_transform_list = []
+            if not target_transform:
+                beit_transform_list.append(transforms.ColorJitter(0.4, 0.4, 0.4))
+            beit_transform_list.extend(
+                [
+                    transforms.RandomHorizontalFlip(p=0.5),
+                    RandomResizedCropAndInterpolationWithTwoPic(
+                        size=input_size,
+                        second_size=None,
+                        interpolation="bicubic",
+                        second_interpolation=None,
+                    ),
+                ]
+            )
+            self.transform_train = transforms.Compose(beit_transform_list)
+        else:
+            self.transform_train = transforms.Compose(
+                [
+                    transforms.RandomResizedCrop(
+                        input_size, scale=(0.2, 1.0), interpolation=3
+                    ),  # 3 is bicubic
+                    transforms.RandomHorizontalFlip(),
+                ]
+            )
+        self.final_transform = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                ),
+            ]
+        )
+
+        if dataset_type == "imagefolder":
+            self.dataset = datasets.ImageFolder(
+                os.path.join(root, split), loader=loader
+            )
+        elif dataset_type == "path":
+            self.dataset = PathDataset(
+                root,
+                loader,
+                None,
+                None,
+                mean=[0.485, 0.456, 0.406],
+                std=[0.229, 0.224, 0.225],
+            )
+        else:
+            raise Exception(f"invalid dataset type {dataset_type}")
+
+        logger.info(
+            f"initial transform: {self.transform_train}, "
+            f"source transform: {self.transform_source}, "
+            f"target transform: {self.transform_target}, "
+            f"final transform: {self.final_transform}"
+        )
+        logger.info(f"loaded {len(self.dataset)} examples")
+
+        self.is_compute_mask = compute_mask
+        self.patches = (input_size // patch_size) ** 2
+        self.mask_prob = mask_prob
+        self.mask_prob_adjust = mask_prob_adjust
+        self.mask_length = mask_length
+        self.inverse_mask = inverse_mask
+        self.expand_adjacent = expand_adjacent
+        self.mask_dropout = mask_dropout
+        self.non_overlapping = non_overlapping
+        self.require_same_masks = require_same_masks
+        self.clone_batch = clone_batch
+
+    def __getitem__(self, index):
+        img, _ = self.dataset[index]
+
+        img = self.transform_train(img)
+
+        source = None
+        target = None
+        if self.transform_source is not None:
+            source = self.final_transform(self.transform_source(img))
+        if self.transform_target is not None:
+            target = self.final_transform(self.transform_target(img))
+
+        if source is None:
+            img = self.final_transform(img)
+
+        v = {"id": index, self.key: source if source is not None else img}
+        if target is not None:
+            v["target"] = target
+
+        if self.is_compute_mask:
+            if self.mask_length == 1:
+                mask = compute_block_mask_1d(
+                    shape=(self.clone_batch, self.patches),
+                    mask_prob=self.mask_prob,
+                    mask_length=self.mask_length,
+                    mask_prob_adjust=self.mask_prob_adjust,
+                    inverse_mask=self.inverse_mask,
+                    require_same_masks=True,
+                )
+            else:
+                mask = compute_block_mask_2d(
+                    shape=(self.clone_batch, self.patches),
+                    mask_prob=self.mask_prob,
+                    mask_length=self.mask_length,
+                    mask_prob_adjust=self.mask_prob_adjust,
+                    inverse_mask=self.inverse_mask,
+                    require_same_masks=True,
+                    expand_adjcent=self.expand_adjacent,
+                    mask_dropout=self.mask_dropout,
+                    non_overlapping=self.non_overlapping,
+                )
+
+            v["precomputed_mask"] = mask
+
+        return v
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def collater(self, samples):
+        if len(samples) == 0:
+            return {}
+
+        collated_img = torch.stack([s[self.key] for s in samples], dim=0)
+
+        res = {
+            "id": torch.LongTensor([s["id"] for s in samples]),
+            "net_input": {
+                self.key: collated_img,
+            },
+        }
+
+        if "target" in samples[0]:
+            collated_target = torch.stack([s["target"] for s in samples], dim=0)
+            res["net_input"]["target"] = collated_target
+
+        if "precomputed_mask" in samples[0]:
+            collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0)
+            res["net_input"]["precomputed_mask"] = collated_mask
+
+        return res
+
+    def num_tokens(self, index):
+        return 1
+
+    def size(self, index):
+        return 1
+
+    @property
+    def sizes(self):
+        return np.full((len(self),), 1)
+
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
+        if self.shuffle:
+            order = [np.random.permutation(len(self))]
+        else:
+            order = [np.arange(len(self))]
+
+        return order[0]
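
A hedged sketch of pretraining-style use of MaeImageDataset with precomputed block masks (not part of the diff; the path and the specific mask settings are placeholders):

# Hypothetical example; only the parameter names come from the class above.
from examples.data2vec.data import MaeImageDataset

ds = MaeImageDataset(
    root="/path/to/imagenet",  # assumed ImageFolder-style root
    split="train",
    input_size=224,
    compute_mask=True,         # each item then carries a "precomputed_mask"
    patch_size=16,             # 224/16 -> 14x14 = 196 patches
    mask_prob=0.8,
    mask_length=3,             # mask_length > 1 selects compute_block_mask_2d
    clone_batch=16,
)
item = ds[0]
# item["imgs"] is the normalized image tensor and item["precomputed_mask"]
# has shape (clone_batch, 196); ds.collater concatenates masks along dim 0
# into batch["net_input"]["precomputed_mask"].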
fairseq/examples/data2vec/data/modality.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+from enum import Enum, auto
+
+
+class Modality(Enum):
+    AUDIO = auto()
+    IMAGE = auto()
+    TEXT = auto()
fairseq/examples/data2vec/data/path_dataset.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import glob
import os
from typing import List, Optional, Tuple

import logging
import numpy as np
import torchvision.transforms.functional as TF
import PIL
from PIL import Image
from torchvision.datasets import VisionDataset

logger = logging.getLogger(__name__)


class PathDataset(VisionDataset):
    def __init__(
        self,
        root: List[str],
        loader: None = None,
        transform: Optional[str] = None,
        extra_transform: Optional[str] = None,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
    ):
        super().__init__(root=root)

        PIL.Image.MAX_IMAGE_PIXELS = 256000001

        self.files = []
        for folder in self.root:
            self.files.extend(
                sorted(glob.glob(os.path.join(folder, "**", "*.jpg"), recursive=True))
            )
            self.files.extend(
                sorted(glob.glob(os.path.join(folder, "**", "*.png"), recursive=True))
            )

        self.transform = transform
        self.extra_transform = extra_transform
        self.mean = mean
        self.std = std

        self.loader = loader

        logger.info(f"loaded {len(self.files)} samples from {root}")

        assert (mean is None) == (std is None)

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
        path = self.files[idx]

        if self.loader is not None:
            return self.loader(path), None

        img = Image.open(path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        img = TF.to_tensor(img)
        if self.mean is not None and self.std is not None:
            img = TF.normalize(img, self.mean, self.std)
        return img, None
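A minimal usage sketch for `PathDataset`, assuming the fairseq repository is on `PYTHONPATH` and `/data/images/train` is a placeholder folder of `.jpg`/`.png` files; the normalization statistics are the usual ImageNet values and are illustrative, not mandated by the class:

```python
import torchvision.transforms as T

from examples.data2vec.data.path_dataset import PathDataset

dataset = PathDataset(
    root=["/data/images/train"],          # recursively globbed for *.jpg and *.png
    transform=T.RandomResizedCrop(224),   # applied to the PIL image before to_tensor
    mean=[0.485, 0.456, 0.406],           # passed to TF.normalize together with std
    std=[0.229, 0.224, 0.225],
)

img, _ = dataset[0]  # (3, 224, 224) float tensor; the second element is always None
print(len(dataset), img.shape)
```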
fairseq/examples/data2vec/models/__init__.py
ADDED
File without changes
fairseq/examples/data2vec/models/audio_classification.py
ADDED
@@ -0,0 +1,614 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import contextlib
|
7 |
+
import logging
|
8 |
+
import re
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
from typing import Any, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import numpy as np
|
16 |
+
from omegaconf import II, MISSING, open_dict
|
17 |
+
|
18 |
+
from fairseq import checkpoint_utils, tasks
|
19 |
+
from fairseq.dataclass import FairseqDataclass
|
20 |
+
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
21 |
+
from fairseq.models import (
|
22 |
+
BaseFairseqModel,
|
23 |
+
register_model,
|
24 |
+
)
|
25 |
+
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
|
26 |
+
from fairseq.modules import TransposeLast
|
27 |
+
from fairseq.tasks import FairseqTask
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class AudioClassificationConfig(FairseqDataclass):
|
34 |
+
model_path: str = field(
|
35 |
+
default=MISSING, metadata={"help": "path to wav2vec 2.0 model"}
|
36 |
+
)
|
37 |
+
no_pretrained_weights: bool = field(
|
38 |
+
default=False, metadata={"help": "if true, does not load pretrained weights"}
|
39 |
+
)
|
40 |
+
dropout_input: float = field(
|
41 |
+
default=0.0,
|
42 |
+
metadata={"help": "dropout to apply to the input (after feat extr)"},
|
43 |
+
)
|
44 |
+
final_dropout: float = field(
|
45 |
+
default=0.0,
|
46 |
+
metadata={"help": "dropout after transformer and before final projection"},
|
47 |
+
)
|
48 |
+
dropout: float = field(
|
49 |
+
default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"}
|
50 |
+
)
|
51 |
+
attention_dropout: float = field(
|
52 |
+
default=0.0,
|
53 |
+
metadata={
|
54 |
+
"help": "dropout probability for attention weights inside wav2vec 2.0 model"
|
55 |
+
},
|
56 |
+
)
|
57 |
+
activation_dropout: float = field(
|
58 |
+
default=0.0,
|
59 |
+
metadata={
|
60 |
+
"help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
|
61 |
+
},
|
62 |
+
)
|
63 |
+
|
64 |
+
# masking
|
65 |
+
apply_mask: bool = field(
|
66 |
+
default=False, metadata={"help": "apply masking during fine-tuning"}
|
67 |
+
)
|
68 |
+
mask_length: int = field(
|
69 |
+
default=10, metadata={"help": "repeat the mask indices multiple times"}
|
70 |
+
)
|
71 |
+
mask_prob: float = field(
|
72 |
+
default=0.5,
|
73 |
+
metadata={
|
74 |
+
"help": "probability of replacing a token with mask (normalized by length)"
|
75 |
+
},
|
76 |
+
)
|
77 |
+
mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
78 |
+
default="static", metadata={"help": "how to choose masks"}
|
79 |
+
)
|
80 |
+
mask_other: float = field(
|
81 |
+
default=0,
|
82 |
+
metadata={
|
83 |
+
"help": "secondary mask argument (used for more complex distributions), "
|
84 |
+
"see help in compute_mask_indices"
|
85 |
+
},
|
86 |
+
)
|
87 |
+
no_mask_overlap: bool = field(
|
88 |
+
default=False, metadata={"help": "whether to allow masks to overlap"}
|
89 |
+
)
|
90 |
+
mask_min_space: Optional[int] = field(
|
91 |
+
default=1,
|
92 |
+
metadata={"help": "min space between spans (if no overlap is enabled)"},
|
93 |
+
)
|
94 |
+
require_same_masks: bool = field(
|
95 |
+
default=True,
|
96 |
+
metadata={
|
97 |
+
"help": "whether to number of masked timesteps must be the same across all "
|
98 |
+
"examples in a batch"
|
99 |
+
},
|
100 |
+
)
|
101 |
+
mask_dropout: float = field(
|
102 |
+
default=0.0,
|
103 |
+
metadata={"help": "percent of masks to unmask for each sample"},
|
104 |
+
)
|
105 |
+
|
106 |
+
# channel masking
|
107 |
+
mask_channel_length: int = field(
|
108 |
+
default=10, metadata={"help": "length of the mask for features (channels)"}
|
109 |
+
)
|
110 |
+
mask_channel_prob: float = field(
|
111 |
+
default=0.0, metadata={"help": "probability of replacing a feature with 0"}
|
112 |
+
)
|
113 |
+
mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
|
114 |
+
default="static",
|
115 |
+
metadata={"help": "how to choose mask length for channel masking"},
|
116 |
+
)
|
117 |
+
mask_channel_other: float = field(
|
118 |
+
default=0,
|
119 |
+
metadata={
|
120 |
+
"help": "secondary mask argument (used for more complex distributions), "
|
121 |
+
"see help in compute_mask_indicesh"
|
122 |
+
},
|
123 |
+
)
|
124 |
+
no_mask_channel_overlap: bool = field(
|
125 |
+
default=False, metadata={"help": "whether to allow channel masks to overlap"}
|
126 |
+
)
|
127 |
+
freeze_finetune_updates: int = field(
|
128 |
+
default=0, metadata={"help": "dont finetune wav2vec for this many updates"}
|
129 |
+
)
|
130 |
+
feature_grad_mult: float = field(
|
131 |
+
default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"}
|
132 |
+
)
|
133 |
+
layerdrop: float = field(
|
134 |
+
default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
|
135 |
+
)
|
136 |
+
mask_channel_min_space: Optional[int] = field(
|
137 |
+
default=1,
|
138 |
+
metadata={"help": "min space between spans (if no overlap is enabled)"},
|
139 |
+
)
|
140 |
+
mask_channel_before: bool = False
|
141 |
+
normalize: bool = II("task.normalize")
|
142 |
+
data: str = II("task.data")
|
143 |
+
# this holds the loaded wav2vec args
|
144 |
+
d2v_args: Any = None
|
145 |
+
offload_activations: bool = field(
|
146 |
+
default=False, metadata={"help": "offload_activations"}
|
147 |
+
)
|
148 |
+
min_params_to_wrap: int = field(
|
149 |
+
default=int(1e8),
|
150 |
+
metadata={
|
151 |
+
"help": "minimum number of params for a layer to be wrapped with FSDP() when "
|
152 |
+
"training with --ddp-backend=fully_sharded. Smaller values will "
|
153 |
+
"improve memory efficiency, but may make torch.distributed "
|
154 |
+
"communication less efficient due to smaller input sizes. This option "
|
155 |
+
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
|
156 |
+
"--offload-activations are passed."
|
157 |
+
},
|
158 |
+
)
|
159 |
+
|
160 |
+
checkpoint_activations: bool = field(
|
161 |
+
default=False,
|
162 |
+
metadata={"help": "recompute activations and save memory for extra compute"},
|
163 |
+
)
|
164 |
+
ddp_backend: str = II("distributed_training.ddp_backend")
|
165 |
+
|
166 |
+
prediction_mode: str = "lin_softmax"
|
167 |
+
eval_prediction_mode: Optional[str] = None
|
168 |
+
conv_kernel: int = -1
|
169 |
+
conv_stride: int = 1
|
170 |
+
two_convs: bool = False
|
171 |
+
extreme_factor: float = 1.0
|
172 |
+
|
173 |
+
conv_feature_layers: Optional[str] = field(
|
174 |
+
default=None,
|
175 |
+
metadata={
|
176 |
+
"help": "string describing convolutional feature extraction layers in form of a python list that contains "
|
177 |
+
"[(dim, kernel_size, stride), ...]"
|
178 |
+
},
|
179 |
+
)
|
180 |
+
|
181 |
+
mixup_prob: float = 1.0
|
182 |
+
source_mixup: float = -1
|
183 |
+
same_mixup: bool = True
|
184 |
+
label_mixup: bool = False
|
185 |
+
|
186 |
+
gain_mode: str = "none"
|
187 |
+
|
188 |
+
|
189 |
+
@register_model("audio_classification", dataclass=AudioClassificationConfig)
|
190 |
+
class AudioClassificationModel(BaseFairseqModel):
|
191 |
+
def __init__(self, cfg: AudioClassificationConfig, num_classes):
|
192 |
+
super().__init__()
|
193 |
+
|
194 |
+
self.apply_mask = cfg.apply_mask
|
195 |
+
self.cfg = cfg
|
196 |
+
|
197 |
+
arg_overrides = {
|
198 |
+
"dropout": cfg.dropout,
|
199 |
+
"activation_dropout": cfg.activation_dropout,
|
200 |
+
"dropout_input": cfg.dropout_input,
|
201 |
+
"attention_dropout": cfg.attention_dropout,
|
202 |
+
"mask_length": cfg.mask_length,
|
203 |
+
"mask_prob": cfg.mask_prob,
|
204 |
+
"require_same_masks": getattr(cfg, "require_same_masks", True),
|
205 |
+
"mask_dropout": getattr(cfg, "mask_dropout", 0),
|
206 |
+
"mask_selection": cfg.mask_selection,
|
207 |
+
"mask_other": cfg.mask_other,
|
208 |
+
"no_mask_overlap": cfg.no_mask_overlap,
|
209 |
+
"mask_channel_length": cfg.mask_channel_length,
|
210 |
+
"mask_channel_prob": cfg.mask_channel_prob,
|
211 |
+
"mask_channel_before": cfg.mask_channel_before,
|
212 |
+
"mask_channel_selection": cfg.mask_channel_selection,
|
213 |
+
"mask_channel_other": cfg.mask_channel_other,
|
214 |
+
"no_mask_channel_overlap": cfg.no_mask_channel_overlap,
|
215 |
+
"encoder_layerdrop": cfg.layerdrop,
|
216 |
+
"feature_grad_mult": cfg.feature_grad_mult,
|
217 |
+
"checkpoint_activations": cfg.checkpoint_activations,
|
218 |
+
"offload_activations": cfg.offload_activations,
|
219 |
+
"min_params_to_wrap": cfg.min_params_to_wrap,
|
220 |
+
"mixup": -1,
|
221 |
+
}
|
222 |
+
|
223 |
+
if cfg.conv_feature_layers is not None:
|
224 |
+
arg_overrides["conv_feature_layers"] = cfg.conv_feature_layers
|
225 |
+
|
226 |
+
if cfg.d2v_args is None:
|
227 |
+
state = checkpoint_utils.load_checkpoint_to_cpu(
|
228 |
+
cfg.model_path, arg_overrides
|
229 |
+
)
|
230 |
+
d2v_args = state.get("cfg", None)
|
231 |
+
if d2v_args is None:
|
232 |
+
d2v_args = convert_namespace_to_omegaconf(state["args"])
|
233 |
+
d2v_args.criterion = None
|
234 |
+
d2v_args.lr_scheduler = None
|
235 |
+
cfg.d2v_args = d2v_args
|
236 |
+
|
237 |
+
logger.info(d2v_args)
|
238 |
+
|
239 |
+
else:
|
240 |
+
state = None
|
241 |
+
d2v_args = cfg.d2v_args
|
242 |
+
|
243 |
+
model_normalized = d2v_args.task.get(
|
244 |
+
"normalize", d2v_args.model.get("normalize", False)
|
245 |
+
)
|
246 |
+
assert cfg.normalize == model_normalized, (
|
247 |
+
"Fine-tuning works best when data normalization is the same. "
|
248 |
+
"Please check that --normalize is set or unset for both pre-training and here"
|
249 |
+
)
|
250 |
+
|
251 |
+
if hasattr(cfg, "checkpoint_activations") and cfg.checkpoint_activations:
|
252 |
+
with open_dict(d2v_args):
|
253 |
+
d2v_args.model.checkpoint_activations = cfg.checkpoint_activations
|
254 |
+
|
255 |
+
d2v_args.task.data = cfg.data
|
256 |
+
task = tasks.setup_task(d2v_args.task)
|
257 |
+
model = task.build_model(d2v_args.model, from_checkpoint=True)
|
258 |
+
|
259 |
+
model.remove_pretraining_modules()
|
260 |
+
|
261 |
+
if state is not None and not cfg.no_pretrained_weights:
|
262 |
+
self.load_model_weights(state, model, cfg)
|
263 |
+
|
264 |
+
d = d2v_args.model.encoder_embed_dim
|
265 |
+
|
266 |
+
self.d2v_model = model
|
267 |
+
|
268 |
+
self.final_dropout = nn.Dropout(cfg.final_dropout)
|
269 |
+
self.freeze_finetune_updates = cfg.freeze_finetune_updates
|
270 |
+
self.num_updates = 0
|
271 |
+
|
272 |
+
for p in self.parameters():
|
273 |
+
p.param_group = "pretrained"
|
274 |
+
|
275 |
+
if cfg.prediction_mode == "proj_avg_proj":
|
276 |
+
self.proj = nn.Linear(d, d * 2)
|
277 |
+
self.proj2 = nn.Linear(d * 2, num_classes)
|
278 |
+
|
279 |
+
for p in self.proj.parameters():
|
280 |
+
p.param_group = "projection"
|
281 |
+
for p in self.proj2.parameters():
|
282 |
+
p.param_group = "projection"
|
283 |
+
elif self.cfg.prediction_mode == "summary_proj":
|
284 |
+
self.proj = nn.Linear(d // 3, num_classes)
|
285 |
+
for p in self.proj.parameters():
|
286 |
+
p.param_group = "projection"
|
287 |
+
elif self.cfg.conv_kernel > 1 and not self.cfg.two_convs:
|
288 |
+
self.proj = nn.Sequential(
|
289 |
+
TransposeLast(),
|
290 |
+
nn.Conv1d(d, num_classes, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
|
291 |
+
TransposeLast(),
|
292 |
+
)
|
293 |
+
for p in self.proj.parameters():
|
294 |
+
p.param_group = "projection"
|
295 |
+
elif self.cfg.conv_kernel > 0 and self.cfg.two_convs:
|
296 |
+
self.proj = nn.Sequential(
|
297 |
+
TransposeLast(),
|
298 |
+
nn.Conv1d(d, d, kernel_size=self.cfg.conv_kernel, stride=self.cfg.conv_stride),
|
299 |
+
TransposeLast(),
|
300 |
+
nn.GELU(),
|
301 |
+
nn.Linear(d, num_classes),
|
302 |
+
)
|
303 |
+
for p in self.proj.parameters():
|
304 |
+
p.param_group = "projection"
|
305 |
+
else:
|
306 |
+
self.proj = nn.Linear(d, num_classes)
|
307 |
+
for p in self.proj.parameters():
|
308 |
+
p.param_group = "projection"
|
309 |
+
|
310 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
311 |
+
super().upgrade_state_dict_named(state_dict, name)
|
312 |
+
return state_dict
|
313 |
+
|
314 |
+
@classmethod
|
315 |
+
def build_model(cls, cfg: AudioClassificationConfig, task: FairseqTask):
|
316 |
+
"""Build a new model instance."""
|
317 |
+
|
318 |
+
assert hasattr(task, "labels"), f"Task {task} must have an attribute 'labels'"
|
319 |
+
|
320 |
+
return cls(cfg, len(task.labels))
|
321 |
+
|
322 |
+
def load_model_weights(self, state, model, cfg):
|
323 |
+
if cfg.ddp_backend == "fully_sharded":
|
324 |
+
from fairseq.distributed import FullyShardedDataParallel
|
325 |
+
|
326 |
+
for name, module in model.named_modules():
|
327 |
+
if "encoder.layers" in name and len(name.split(".")) == 3:
|
328 |
+
# Only for layers, we do a special handling and load the weights one by one
|
329 |
+
# We dont load all weights together as that wont be memory efficient and may
|
330 |
+
# cause oom
|
331 |
+
new_dict = {
|
332 |
+
k.replace(name + ".", ""): v
|
333 |
+
for (k, v) in state["model"].items()
|
334 |
+
if name + "." in k
|
335 |
+
}
|
336 |
+
assert isinstance(module, FullyShardedDataParallel)
|
337 |
+
with module.summon_full_params():
|
338 |
+
module.load_state_dict(new_dict, strict=True)
|
339 |
+
module._reset_lazy_init()
|
340 |
+
|
341 |
+
# Once layers are loaded, filter them out and load everything else.
|
342 |
+
r = re.compile("encoder.layers.\d.")
|
343 |
+
filtered_list = list(filter(r.match, state["model"].keys()))
|
344 |
+
|
345 |
+
new_big_dict = {
|
346 |
+
k: v for (k, v) in state["model"].items() if k not in filtered_list
|
347 |
+
}
|
348 |
+
|
349 |
+
model.load_state_dict(new_big_dict, strict=False)
|
350 |
+
else:
|
351 |
+
if "_ema" in state["model"]:
|
352 |
+
del state["model"]["_ema"]
|
353 |
+
model.load_state_dict(state["model"], strict=False)
|
354 |
+
|
355 |
+
def set_num_updates(self, num_updates):
|
356 |
+
"""Set the number of parameters updates."""
|
357 |
+
super().set_num_updates(num_updates)
|
358 |
+
self.num_updates = num_updates
|
359 |
+
|
360 |
+
def compute_gain(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
|
361 |
+
if fs == 16000:
|
362 |
+
n_fft = 2048
|
363 |
+
elif fs == 44100:
|
364 |
+
n_fft = 4096
|
365 |
+
else:
|
366 |
+
raise Exception("Invalid fs {}".format(fs))
|
367 |
+
stride = n_fft // 2
|
368 |
+
|
369 |
+
def a_weight(fs, n_fft, min_db=-80.0):
|
370 |
+
freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
|
371 |
+
freq_sq = np.power(freq, 2)
|
372 |
+
freq_sq[0] = 1.0
|
373 |
+
weight = 2.0 + 20.0 * (
|
374 |
+
2 * np.log10(12194)
|
375 |
+
+ 2 * np.log10(freq_sq)
|
376 |
+
- np.log10(freq_sq + 12194 ** 2)
|
377 |
+
- np.log10(freq_sq + 20.6 ** 2)
|
378 |
+
- 0.5 * np.log10(freq_sq + 107.7 ** 2)
|
379 |
+
- 0.5 * np.log10(freq_sq + 737.9 ** 2)
|
380 |
+
)
|
381 |
+
weight = np.maximum(weight, min_db)
|
382 |
+
|
383 |
+
return weight
|
384 |
+
|
385 |
+
gain = []
|
386 |
+
for i in range(0, len(sound) - n_fft + 1, stride):
|
387 |
+
if mode == "RMSE":
|
388 |
+
g = np.mean(sound[i : i + n_fft] ** 2)
|
389 |
+
elif mode == "A_weighting":
|
390 |
+
spec = np.fft.rfft(np.hanning(n_fft + 1)[:-1] * sound[i : i + n_fft])
|
391 |
+
power_spec = np.abs(spec) ** 2
|
392 |
+
a_weighted_spec = power_spec * np.power(10, a_weight(fs, n_fft) / 10)
|
393 |
+
g = np.sum(a_weighted_spec)
|
394 |
+
else:
|
395 |
+
raise Exception("Invalid mode {}".format(mode))
|
396 |
+
gain.append(g)
|
397 |
+
|
398 |
+
gain = np.array(gain)
|
399 |
+
gain = np.maximum(gain, np.power(10, min_db / 10))
|
400 |
+
gain_db = 10 * np.log10(gain)
|
401 |
+
|
402 |
+
return gain_db
|
403 |
+
|
404 |
+
# adapted from https://github.com/mil-tokyo/bc_learning_sound/blob/master/utils.py
|
405 |
+
def compute_gain_torch(self, sound, fs=16_000, min_db=-80.0, mode="A_weighting"):
|
406 |
+
if fs == 16000:
|
407 |
+
n_fft = 2048
|
408 |
+
elif fs == 44100:
|
409 |
+
n_fft = 4096
|
410 |
+
else:
|
411 |
+
raise Exception("Invalid fs {}".format(fs))
|
412 |
+
|
413 |
+
if mode == "A_weighting":
|
414 |
+
if not hasattr(self, f"a_weight"):
|
415 |
+
self.a_weight = {}
|
416 |
+
|
417 |
+
if fs not in self.a_weight:
|
418 |
+
|
419 |
+
def a_weight(fs, n_fft, min_db=-80.0):
|
420 |
+
freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
|
421 |
+
freq_sq = freq ** 2
|
422 |
+
freq_sq[0] = 1.0
|
423 |
+
weight = 2.0 + 20.0 * (
|
424 |
+
2 * np.log10(12194)
|
425 |
+
+ 2 * np.log10(freq_sq)
|
426 |
+
- np.log10(freq_sq + 12194 ** 2)
|
427 |
+
- np.log10(freq_sq + 20.6 ** 2)
|
428 |
+
- 0.5 * np.log10(freq_sq + 107.7 ** 2)
|
429 |
+
- 0.5 * np.log10(freq_sq + 737.9 ** 2)
|
430 |
+
)
|
431 |
+
weight = np.maximum(weight, min_db)
|
432 |
+
|
433 |
+
return weight
|
434 |
+
|
435 |
+
self.a_weight[fs] = torch.from_numpy(
|
436 |
+
np.power(10, a_weight(fs, n_fft, min_db) / 10)
|
437 |
+
).to(device=sound.device)
|
438 |
+
|
439 |
+
sound = sound.unfold(-1, n_fft, n_fft // 2)
|
440 |
+
|
441 |
+
if mode == "RMSE":
|
442 |
+
sound = sound ** 2
|
443 |
+
g = sound.mean(-1)
|
444 |
+
elif mode == "A_weighting":
|
445 |
+
w = torch.hann_window(n_fft, device=sound.device) * sound
|
446 |
+
spec = torch.fft.rfft(w)
|
447 |
+
power_spec = spec.abs() ** 2
|
448 |
+
a_weighted_spec = power_spec * self.a_weight[fs]
|
449 |
+
g = a_weighted_spec.sum(-1)
|
450 |
+
else:
|
451 |
+
raise Exception("Invalid mode {}".format(mode))
|
452 |
+
|
453 |
+
gain = torch.maximum(g, torch.tensor(10 ** (min_db / 10), device=g.device))
|
454 |
+
gain_db = 10 * torch.log10(gain)
|
455 |
+
|
456 |
+
return gain_db
|
457 |
+
|
458 |
+
def forward(self, source, padding_mask, label=None, **kwargs):
|
459 |
+
|
460 |
+
if self.cfg.source_mixup >= 0 and self.training and self.cfg.mixup_prob > 0:
|
461 |
+
with torch.no_grad():
|
462 |
+
mixed_source = source
|
463 |
+
mix_mask = None
|
464 |
+
if self.cfg.mixup_prob < 1:
|
465 |
+
mix_mask = (
|
466 |
+
torch.empty((source.size(0),), device=source.device)
|
467 |
+
.bernoulli_(self.cfg.mixup_prob)
|
468 |
+
.bool()
|
469 |
+
)
|
470 |
+
mixed_source = source[mix_mask]
|
471 |
+
|
472 |
+
r = (
|
473 |
+
torch.FloatTensor(
|
474 |
+
1 if self.cfg.same_mixup else mixed_source.size(0)
|
475 |
+
)
|
476 |
+
.uniform_(max(1e-6, self.cfg.source_mixup), 1)
|
477 |
+
.to(dtype=source.dtype, device=source.device)
|
478 |
+
)
|
479 |
+
|
480 |
+
mixup_perm = torch.randperm(source.size(0))
|
481 |
+
s2 = source[mixup_perm]
|
482 |
+
|
483 |
+
if self.cfg.gain_mode == "none":
|
484 |
+
p = r.unsqueeze(-1)
|
485 |
+
if mix_mask is not None:
|
486 |
+
s2 = s2[mix_mask]
|
487 |
+
else:
|
488 |
+
if self.cfg.gain_mode == "naive_rms":
|
489 |
+
G1 = source.pow(2).mean(dim=-1).sqrt()
|
490 |
+
else:
|
491 |
+
G1, _ = self.compute_gain_torch(
|
492 |
+
source, mode=self.cfg.gain_mode
|
493 |
+
).max(-1)
|
494 |
+
G1 = G1.to(dtype=source.dtype)
|
495 |
+
|
496 |
+
G2 = G1[mixup_perm]
|
497 |
+
|
498 |
+
if mix_mask is not None:
|
499 |
+
G1 = G1[mix_mask]
|
500 |
+
G2 = G2[mix_mask]
|
501 |
+
s2 = s2[mix_mask]
|
502 |
+
|
503 |
+
p = 1 / (1 + 10 ** ((G1 - G2) / 20) * (1 - r) / r)
|
504 |
+
p = p.unsqueeze(-1)
|
505 |
+
|
506 |
+
mixed = (p * mixed_source) + (1 - p) * s2
|
507 |
+
|
508 |
+
if mix_mask is None:
|
509 |
+
source = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
|
510 |
+
else:
|
511 |
+
source[mix_mask] = mixed / torch.sqrt(p ** 2 + (1 - p) ** 2)
|
512 |
+
|
513 |
+
if label is not None and self.cfg.label_mixup:
|
514 |
+
r = r.unsqueeze(-1)
|
515 |
+
if mix_mask is None:
|
516 |
+
label = label * r + (1 - r) * label[mixup_perm]
|
517 |
+
else:
|
518 |
+
label[mix_mask] = (
|
519 |
+
label[mix_mask] * r + (1 - r) * label[mixup_perm][mix_mask]
|
520 |
+
)
|
521 |
+
|
522 |
+
d2v_args = {
|
523 |
+
"source": source,
|
524 |
+
"padding_mask": padding_mask,
|
525 |
+
"mask": self.apply_mask and self.training,
|
526 |
+
}
|
527 |
+
|
528 |
+
ft = self.freeze_finetune_updates <= self.num_updates
|
529 |
+
|
530 |
+
with torch.no_grad() if not ft else contextlib.ExitStack():
|
531 |
+
res = self.d2v_model.extract_features(**d2v_args)
|
532 |
+
|
533 |
+
x = res["x"]
|
534 |
+
padding_mask = res["padding_mask"]
|
535 |
+
if padding_mask is not None:
|
536 |
+
x[padding_mask] = 0
|
537 |
+
|
538 |
+
x = self.final_dropout(x)
|
539 |
+
|
540 |
+
if self.training or (
|
541 |
+
self.cfg.eval_prediction_mode is None or self.cfg.eval_prediction_mode == ""
|
542 |
+
):
|
543 |
+
prediction_mode = self.cfg.prediction_mode
|
544 |
+
else:
|
545 |
+
prediction_mode = self.cfg.eval_prediction_mode
|
546 |
+
|
547 |
+
if prediction_mode == "average_before":
|
548 |
+
x = x.mean(dim=1)
|
549 |
+
|
550 |
+
if prediction_mode != "summary_mha" and prediction_mode != "summary_proj" and prediction_mode != "cls":
|
551 |
+
x = self.proj(x)
|
552 |
+
|
553 |
+
logits = True
|
554 |
+
if prediction_mode == "lin_softmax":
|
555 |
+
x = F.logsigmoid(x.float())
|
556 |
+
x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x, dim=1)
|
557 |
+
x = x.clamp(max=0)
|
558 |
+
x = x - torch.log(-(torch.expm1(x)))
|
559 |
+
elif prediction_mode == "extremized_odds":
|
560 |
+
x = x.float().sum(dim=1)
|
561 |
+
x = x * self.cfg.extreme_factor
|
562 |
+
elif prediction_mode == "average_before":
|
563 |
+
x = x.float()
|
564 |
+
elif prediction_mode == "average":
|
565 |
+
x = x.float().mean(dim=1)
|
566 |
+
elif prediction_mode == "average_sigmoid":
|
567 |
+
x = torch.sigmoid(x.float())
|
568 |
+
x = x.mean(dim=1)
|
569 |
+
logits = False
|
570 |
+
elif prediction_mode == "max":
|
571 |
+
x, _ = x.float().max(dim=1)
|
572 |
+
elif prediction_mode == "max_sigmoid":
|
573 |
+
x = torch.sigmoid(x.float())
|
574 |
+
x, _ = x.float().max(dim=1)
|
575 |
+
logits = False
|
576 |
+
elif prediction_mode == "proj_avg_proj":
|
577 |
+
x = x.mean(dim=1)
|
578 |
+
x = self.proj2(x)
|
579 |
+
elif prediction_mode == "summary_mha" or prediction_mode == "summary_proj":
|
580 |
+
x = self.d2v_model.summary(
|
581 |
+
x, padding_mask, proj=prediction_mode == "summary_proj"
|
582 |
+
)
|
583 |
+
x = x.type_as(source)
|
584 |
+
x = self.proj(x)
|
585 |
+
elif prediction_mode == "cls":
|
586 |
+
x = x[:,0]
|
587 |
+
x = self.proj(x)
|
588 |
+
else:
|
589 |
+
raise Exception(f"unknown prediction mode {prediction_mode}")
|
590 |
+
|
591 |
+
if label is None:
|
592 |
+
return torch.sigmoid(x) if logits else x
|
593 |
+
|
594 |
+
x = torch.nan_to_num(x)
|
595 |
+
|
596 |
+
if logits:
|
597 |
+
loss = F.binary_cross_entropy_with_logits(
|
598 |
+
x, label.float(), reduction="none"
|
599 |
+
)
|
600 |
+
else:
|
601 |
+
loss = F.binary_cross_entropy(x, label.float(), reduction="none")
|
602 |
+
|
603 |
+
result = {
|
604 |
+
"losses": {
|
605 |
+
"main": loss,
|
606 |
+
},
|
607 |
+
"sample_size": label.sum(),
|
608 |
+
}
|
609 |
+
|
610 |
+
if not self.training:
|
611 |
+
result["_predictions"] = torch.sigmoid(x) if logits else x
|
612 |
+
result["_targets"] = label
|
613 |
+
|
614 |
+
return result
|
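In `AudioClassificationModel.forward`, the gain-aware mixup branch blends two clips with a coefficient that compensates for their relative loudness in dB (the BC-learning recipe referenced above `compute_gain_torch`). A standalone sketch of just that coefficient, with made-up gain values:

```python
import torch


def mix_coeff(g1_db, g2_db, r):
    # Same expression as in forward(): the louder clip (higher gain in dB) is
    # down-weighted so that r still controls the perceived mixing ratio.
    return 1 / (1 + 10 ** ((g1_db - g2_db) / 20) * (1 - r) / r)


g1 = torch.tensor([-20.0])  # hypothetical A-weighted gain of clip 1, in dB
g2 = torch.tensor([-26.0])  # clip 2 is about 6 dB quieter
r = torch.tensor([0.5])     # requested mixing ratio drawn in forward()

p = mix_coeff(g1, g2, r)
print(p)  # ~0.33: clip 1 contributes less amplitude because it is louder

s1, s2 = torch.randn(1, 16000), torch.randn(1, 16000)
mixed = (p * s1 + (1 - p) * s2) / torch.sqrt(p**2 + (1 - p)**2)  # energy renormalization as in forward()
```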
fairseq/examples/data2vec/models/data2vec2.py
ADDED
@@ -0,0 +1,813 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import math
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
from typing import Optional, Callable
|
10 |
+
from functools import partial
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from omegaconf import II
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import torch.distributed as dist
|
19 |
+
|
20 |
+
from fairseq.modules import EMAModule, EMAModuleConfig
|
21 |
+
|
22 |
+
from fairseq.dataclass import FairseqDataclass
|
23 |
+
from fairseq.models import BaseFairseqModel, register_model
|
24 |
+
|
25 |
+
from examples.data2vec.data.modality import Modality
|
26 |
+
|
27 |
+
from examples.data2vec.models.modalities.base import (
|
28 |
+
MaskSeed,
|
29 |
+
D2vModalityConfig,
|
30 |
+
ModalitySpecificEncoder,
|
31 |
+
get_annealed_rate,
|
32 |
+
)
|
33 |
+
from examples.data2vec.models.modalities.modules import (
|
34 |
+
D2vDecoderConfig,
|
35 |
+
AltBlock,
|
36 |
+
Decoder1d,
|
37 |
+
)
|
38 |
+
|
39 |
+
from examples.data2vec.models.modalities.audio import (
|
40 |
+
D2vAudioConfig,
|
41 |
+
AudioEncoder,
|
42 |
+
)
|
43 |
+
from examples.data2vec.models.modalities.images import (
|
44 |
+
D2vImageConfig,
|
45 |
+
ImageEncoder,
|
46 |
+
)
|
47 |
+
from examples.data2vec.models.modalities.text import (
|
48 |
+
D2vTextConfig,
|
49 |
+
TextEncoder,
|
50 |
+
)
|
51 |
+
|
52 |
+
logger = logging.getLogger(__name__)
|
53 |
+
|
54 |
+
|
55 |
+
@dataclass
|
56 |
+
class D2vModalitiesConfig(FairseqDataclass):
|
57 |
+
audio: D2vAudioConfig = D2vAudioConfig()
|
58 |
+
image: D2vImageConfig = D2vImageConfig()
|
59 |
+
text: D2vTextConfig = D2vTextConfig()
|
60 |
+
|
61 |
+
|
62 |
+
@dataclass
|
63 |
+
class Data2VecMultiConfig(FairseqDataclass):
|
64 |
+
|
65 |
+
loss_beta: float = field(
|
66 |
+
default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
|
67 |
+
)
|
68 |
+
loss_scale: Optional[float] = field(
|
69 |
+
default=None,
|
70 |
+
metadata={
|
71 |
+
"help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
|
72 |
+
},
|
73 |
+
)
|
74 |
+
|
75 |
+
depth: int = 8
|
76 |
+
start_drop_path_rate: float = 0
|
77 |
+
end_drop_path_rate: float = 0
|
78 |
+
num_heads: int = 12
|
79 |
+
norm_eps: float = 1e-6
|
80 |
+
norm_affine: bool = True
|
81 |
+
encoder_dropout: float = 0.1
|
82 |
+
post_mlp_drop: float = 0.1
|
83 |
+
attention_dropout: float = 0.1
|
84 |
+
activation_dropout: float = 0.0
|
85 |
+
dropout_input: float = 0.0
|
86 |
+
layerdrop: float = 0.0
|
87 |
+
embed_dim: int = 768
|
88 |
+
mlp_ratio: float = 4
|
89 |
+
layer_norm_first: bool = False
|
90 |
+
|
91 |
+
average_top_k_layers: int = field(
|
92 |
+
default=8, metadata={"help": "how many layers to average"}
|
93 |
+
)
|
94 |
+
|
95 |
+
end_of_block_targets: bool = False
|
96 |
+
|
97 |
+
clone_batch: int = 1
|
98 |
+
|
99 |
+
layer_norm_target_layer: bool = False
|
100 |
+
batch_norm_target_layer: bool = False
|
101 |
+
instance_norm_target_layer: bool = False
|
102 |
+
instance_norm_targets: bool = False
|
103 |
+
layer_norm_targets: bool = False
|
104 |
+
|
105 |
+
ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
|
106 |
+
ema_same_dtype: bool = True
|
107 |
+
log_norms: bool = True
|
108 |
+
ema_end_decay: float = field(
|
109 |
+
default=0.9999, metadata={"help": "final ema decay rate"}
|
110 |
+
)
|
111 |
+
|
112 |
+
# when to finish annealing ema decay rate
|
113 |
+
ema_anneal_end_step: int = II("optimization.max_update")
|
114 |
+
|
115 |
+
ema_encoder_only: bool = field(
|
116 |
+
default=True,
|
117 |
+
metadata={
|
118 |
+
"help": "whether to momentum update only the shared transformer encoder"
|
119 |
+
},
|
120 |
+
)
|
121 |
+
|
122 |
+
max_update: int = II("optimization.max_update")
|
123 |
+
|
124 |
+
modalities: D2vModalitiesConfig = D2vModalitiesConfig()
|
125 |
+
|
126 |
+
shared_decoder: Optional[D2vDecoderConfig] = None
|
127 |
+
|
128 |
+
min_target_var: float = field(
|
129 |
+
default=0.1, metadata={"help": "stop training if target var falls below this"}
|
130 |
+
)
|
131 |
+
min_pred_var: float = field(
|
132 |
+
default=0.01,
|
133 |
+
metadata={"help": "stop training if prediction var falls below this"},
|
134 |
+
)
|
135 |
+
|
136 |
+
supported_modality: Optional[Modality] = None
|
137 |
+
mae_init: bool = False
|
138 |
+
|
139 |
+
seed: int = II("common.seed")
|
140 |
+
|
141 |
+
skip_ema: bool = False
|
142 |
+
|
143 |
+
cls_loss: float = 0
|
144 |
+
recon_loss: float = 0
|
145 |
+
d2v_loss: float = 1
|
146 |
+
|
147 |
+
decoder_group: bool = False
|
148 |
+
|
149 |
+
|
150 |
+
@register_model("data2vec_multi", dataclass=Data2VecMultiConfig)
|
151 |
+
class Data2VecMultiModel(BaseFairseqModel):
|
152 |
+
def make_modality_encoder(
|
153 |
+
self,
|
154 |
+
cfg: D2vModalityConfig,
|
155 |
+
embed_dim: int,
|
156 |
+
make_block: Callable[[float], nn.ModuleList],
|
157 |
+
norm_layer: Callable[[int], nn.LayerNorm],
|
158 |
+
layer_norm_first: bool,
|
159 |
+
alibi_biases,
|
160 |
+
task,
|
161 |
+
) -> ModalitySpecificEncoder:
|
162 |
+
if cfg.type == Modality.AUDIO:
|
163 |
+
enc_cls = AudioEncoder
|
164 |
+
elif cfg.type == Modality.IMAGE:
|
165 |
+
enc_cls = ImageEncoder
|
166 |
+
elif cfg.type == Modality.TEXT:
|
167 |
+
enc_cls = TextEncoder
|
168 |
+
if hasattr(task, "text_task"):
|
169 |
+
task = task.text_task
|
170 |
+
else:
|
171 |
+
raise Exception(f"unsupported modality {cfg.type}")
|
172 |
+
|
173 |
+
return enc_cls(
|
174 |
+
cfg,
|
175 |
+
embed_dim,
|
176 |
+
make_block,
|
177 |
+
norm_layer,
|
178 |
+
layer_norm_first,
|
179 |
+
alibi_biases,
|
180 |
+
task,
|
181 |
+
)
|
182 |
+
|
183 |
+
def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None):
|
184 |
+
super().__init__()
|
185 |
+
self.cfg = cfg
|
186 |
+
self.modalities = modalities
|
187 |
+
self.task = task
|
188 |
+
|
189 |
+
make_layer_norm = partial(
|
190 |
+
nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine
|
191 |
+
)
|
192 |
+
|
193 |
+
def make_block(drop_path, dim=None, heads=None):
|
194 |
+
return AltBlock(
|
195 |
+
cfg.embed_dim if dim is None else dim,
|
196 |
+
cfg.num_heads if heads is None else heads,
|
197 |
+
cfg.mlp_ratio,
|
198 |
+
qkv_bias=True,
|
199 |
+
drop=cfg.encoder_dropout,
|
200 |
+
attn_drop=cfg.attention_dropout,
|
201 |
+
mlp_drop=cfg.activation_dropout,
|
202 |
+
post_mlp_drop=cfg.post_mlp_drop,
|
203 |
+
drop_path=drop_path,
|
204 |
+
norm_layer=make_layer_norm,
|
205 |
+
layer_norm_first=cfg.layer_norm_first,
|
206 |
+
ffn_targets=not cfg.end_of_block_targets,
|
207 |
+
)
|
208 |
+
|
209 |
+
self.alibi_biases = {}
|
210 |
+
self.modality_encoders = nn.ModuleDict()
|
211 |
+
for mod in self.modalities:
|
212 |
+
mod_cfg = getattr(cfg.modalities, mod.name.lower())
|
213 |
+
enc = self.make_modality_encoder(
|
214 |
+
mod_cfg,
|
215 |
+
cfg.embed_dim,
|
216 |
+
make_block,
|
217 |
+
make_layer_norm,
|
218 |
+
cfg.layer_norm_first,
|
219 |
+
self.alibi_biases,
|
220 |
+
task,
|
221 |
+
)
|
222 |
+
self.modality_encoders[mod.name] = enc
|
223 |
+
|
224 |
+
self.ema = None
|
225 |
+
|
226 |
+
self.average_top_k_layers = cfg.average_top_k_layers
|
227 |
+
self.loss_beta = cfg.loss_beta
|
228 |
+
self.loss_scale = cfg.loss_scale
|
229 |
+
|
230 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
231 |
+
|
232 |
+
dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth)
|
233 |
+
|
234 |
+
self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
|
235 |
+
|
236 |
+
self.norm = None
|
237 |
+
if cfg.layer_norm_first:
|
238 |
+
self.norm = make_layer_norm(cfg.embed_dim)
|
239 |
+
|
240 |
+
if self.cfg.mae_init:
|
241 |
+
self.apply(self._init_weights)
|
242 |
+
else:
|
243 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
244 |
+
|
245 |
+
self.apply(init_bert_params)
|
246 |
+
|
247 |
+
for mod_enc in self.modality_encoders.values():
|
248 |
+
mod_enc.reset_parameters()
|
249 |
+
|
250 |
+
if not skip_ema:
|
251 |
+
self.ema = self.make_ema_teacher(cfg.ema_decay)
|
252 |
+
self.shared_decoder = (
|
253 |
+
Decoder1d(cfg.shared_decoder, cfg.embed_dim)
|
254 |
+
if self.cfg.shared_decoder is not None
|
255 |
+
else None
|
256 |
+
)
|
257 |
+
if self.shared_decoder is not None:
|
258 |
+
self.shared_decoder.apply(self._init_weights)
|
259 |
+
|
260 |
+
self.recon_proj = None
|
261 |
+
if cfg.recon_loss > 0:
|
262 |
+
self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
|
263 |
+
|
264 |
+
for pn, p in self.named_parameters():
|
265 |
+
if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn:
|
266 |
+
p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}
|
267 |
+
if cfg.decoder_group and "decoder" in pn:
|
268 |
+
p.param_group = "decoder"
|
269 |
+
|
270 |
+
self.num_updates = 0
|
271 |
+
|
272 |
+
def _init_weights(self, m):
|
273 |
+
|
274 |
+
try:
|
275 |
+
from apex.normalization import FusedLayerNorm
|
276 |
+
|
277 |
+
fn = FusedLayerNorm
|
278 |
+
except:
|
279 |
+
fn = nn.LayerNorm
|
280 |
+
|
281 |
+
if isinstance(m, nn.Linear):
|
282 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
283 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
284 |
+
nn.init.constant_(m.bias, 0)
|
285 |
+
elif isinstance(m, nn.LayerNorm) or isinstance(m, fn):
|
286 |
+
if m.bias is not None:
|
287 |
+
nn.init.constant_(m.bias, 0)
|
288 |
+
if m.weight is not None:
|
289 |
+
nn.init.constant_(m.weight, 1.0)
|
290 |
+
|
291 |
+
@torch.no_grad()
|
292 |
+
def make_ema_teacher(self, ema_decay):
|
293 |
+
ema_config = EMAModuleConfig(
|
294 |
+
ema_decay=ema_decay,
|
295 |
+
ema_fp32=True,
|
296 |
+
log_norms=self.cfg.log_norms,
|
297 |
+
add_missing_params=False,
|
298 |
+
)
|
299 |
+
|
300 |
+
model_copy = self.make_target_model()
|
301 |
+
|
302 |
+
return EMAModule(
|
303 |
+
model_copy,
|
304 |
+
ema_config,
|
305 |
+
copy_model=False,
|
306 |
+
)
|
307 |
+
|
308 |
+
def make_target_model(self):
|
309 |
+
logger.info("making target model")
|
310 |
+
|
311 |
+
model_copy = Data2VecMultiModel(
|
312 |
+
self.cfg, self.modalities, skip_ema=True, task=self.task
|
313 |
+
)
|
314 |
+
|
315 |
+
if self.cfg.ema_encoder_only:
|
316 |
+
model_copy = model_copy.blocks
|
317 |
+
for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()):
|
318 |
+
p_t.data.copy_(p_s.data)
|
319 |
+
else:
|
320 |
+
for p_s, p_t in zip(self.parameters(), model_copy.parameters()):
|
321 |
+
p_t.data.copy_(p_s.data)
|
322 |
+
|
323 |
+
for mod_enc in model_copy.modality_encoders.values():
|
324 |
+
mod_enc.decoder = None
|
325 |
+
if not mod_enc.modality_cfg.ema_local_encoder:
|
326 |
+
mod_enc.local_encoder = None
|
327 |
+
mod_enc.project_features = None
|
328 |
+
|
329 |
+
model_copy.requires_grad_(False)
|
330 |
+
return model_copy
|
331 |
+
|
332 |
+
def set_num_updates(self, num_updates):
|
333 |
+
super().set_num_updates(num_updates)
|
334 |
+
|
335 |
+
if self.ema is not None and (
|
336 |
+
(self.num_updates == 0 and num_updates > 1)
|
337 |
+
or self.num_updates >= num_updates
|
338 |
+
):
|
339 |
+
pass
|
340 |
+
elif self.training and self.ema is not None:
|
341 |
+
ema_weight_decay = None
|
342 |
+
if self.cfg.ema_decay != self.cfg.ema_end_decay:
|
343 |
+
if num_updates >= self.cfg.ema_anneal_end_step:
|
344 |
+
decay = self.cfg.ema_end_decay
|
345 |
+
else:
|
346 |
+
decay = get_annealed_rate(
|
347 |
+
self.cfg.ema_decay,
|
348 |
+
self.cfg.ema_end_decay,
|
349 |
+
num_updates,
|
350 |
+
self.cfg.ema_anneal_end_step,
|
351 |
+
)
|
352 |
+
self.ema.set_decay(decay, weight_decay=ema_weight_decay)
|
353 |
+
if self.ema.get_decay() < 1:
|
354 |
+
self.ema.step(self.blocks if self.cfg.ema_encoder_only else self)
|
355 |
+
|
356 |
+
self.num_updates = num_updates
|
357 |
+
|
358 |
+
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
359 |
+
state = super().state_dict(destination, prefix, keep_vars)
|
360 |
+
|
361 |
+
if self.ema is not None:
|
362 |
+
state[prefix + "_ema"] = self.ema.fp32_params
|
363 |
+
|
364 |
+
return state
|
365 |
+
|
366 |
+
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
367 |
+
k = prefix + "_ema"
|
368 |
+
if self.ema is not None:
|
369 |
+
assert k in state_dict
|
370 |
+
self.ema.restore(state_dict[k], True)
|
371 |
+
del state_dict[k]
|
372 |
+
elif k in state_dict:
|
373 |
+
del state_dict[k]
|
374 |
+
|
375 |
+
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
376 |
+
|
377 |
+
@classmethod
|
378 |
+
def build_model(cls, cfg: Data2VecMultiConfig, task=None):
|
379 |
+
"""Build a new model instance."""
|
380 |
+
if task is None or not hasattr(task, "supported_modalities"):
|
381 |
+
modalities = (
|
382 |
+
[cfg.supported_modality]
|
383 |
+
if cfg.supported_modality is not None
|
384 |
+
else [
|
385 |
+
Modality.AUDIO,
|
386 |
+
Modality.IMAGE,
|
387 |
+
Modality.TEXT,
|
388 |
+
]
|
389 |
+
)
|
390 |
+
else:
|
391 |
+
modalities = task.supported_modalities
|
392 |
+
return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema)
|
393 |
+
|
394 |
+
def forward(
|
395 |
+
self,
|
396 |
+
source,
|
397 |
+
target=None,
|
398 |
+
id=None,
|
399 |
+
mode=None,
|
400 |
+
padding_mask=None,
|
401 |
+
mask=True,
|
402 |
+
features_only=False,
|
403 |
+
force_remove_masked=False,
|
404 |
+
remove_extra_tokens=True,
|
405 |
+
precomputed_mask=None,
|
406 |
+
):
|
407 |
+
if mode is None:
|
408 |
+
assert self.cfg.supported_modality is not None
|
409 |
+
mode = self.cfg.supported_modality
|
410 |
+
|
411 |
+
if isinstance(mode, Modality):
|
412 |
+
mode = mode.name
|
413 |
+
|
414 |
+
feature_extractor = self.modality_encoders[mode]
|
415 |
+
|
416 |
+
mask_seeds = None
|
417 |
+
if id is not None:
|
418 |
+
mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id)
|
419 |
+
|
420 |
+
extractor_out = feature_extractor(
|
421 |
+
source,
|
422 |
+
padding_mask,
|
423 |
+
mask,
|
424 |
+
remove_masked=not features_only or force_remove_masked,
|
425 |
+
clone_batch=self.cfg.clone_batch if not features_only else 1,
|
426 |
+
mask_seeds=mask_seeds,
|
427 |
+
precomputed_mask=precomputed_mask,
|
428 |
+
)
|
429 |
+
|
430 |
+
x = extractor_out["x"]
|
431 |
+
encoder_mask = extractor_out["encoder_mask"]
|
432 |
+
masked_padding_mask = extractor_out["padding_mask"]
|
433 |
+
masked_alibi_bias = extractor_out.get("alibi_bias", None)
|
434 |
+
alibi_scale = extractor_out.get("alibi_scale", None)
|
435 |
+
|
436 |
+
if self.dropout_input is not None:
|
437 |
+
x = self.dropout_input(x)
|
438 |
+
|
439 |
+
layer_results = []
|
440 |
+
for i, blk in enumerate(self.blocks):
|
441 |
+
if (
|
442 |
+
not self.training
|
443 |
+
or self.cfg.layerdrop == 0
|
444 |
+
or (np.random.random() > self.cfg.layerdrop)
|
445 |
+
):
|
446 |
+
ab = masked_alibi_bias
|
447 |
+
if ab is not None and alibi_scale is not None:
|
448 |
+
scale = (
|
449 |
+
alibi_scale[i]
|
450 |
+
if alibi_scale.size(0) > 1
|
451 |
+
else alibi_scale.squeeze(0)
|
452 |
+
)
|
453 |
+
ab = ab * scale.type_as(ab)
|
454 |
+
|
455 |
+
x, lr = blk(
|
456 |
+
x,
|
457 |
+
padding_mask=masked_padding_mask,
|
458 |
+
alibi_bias=ab,
|
459 |
+
)
|
460 |
+
if features_only:
|
461 |
+
layer_results.append(lr)
|
462 |
+
|
463 |
+
if self.norm is not None:
|
464 |
+
x = self.norm(x)
|
465 |
+
|
466 |
+
if features_only:
|
467 |
+
if remove_extra_tokens:
|
468 |
+
x = x[:, feature_extractor.modality_cfg.num_extra_tokens :]
|
469 |
+
if masked_padding_mask is not None:
|
470 |
+
masked_padding_mask = masked_padding_mask[
|
471 |
+
:, feature_extractor.modality_cfg.num_extra_tokens :
|
472 |
+
]
|
473 |
+
|
474 |
+
return {
|
475 |
+
"x": x,
|
476 |
+
"padding_mask": masked_padding_mask,
|
477 |
+
"layer_results": layer_results,
|
478 |
+
"mask": encoder_mask,
|
479 |
+
}
|
480 |
+
|
481 |
+
xs = []
|
482 |
+
|
483 |
+
if self.shared_decoder is not None:
|
484 |
+
dx = self.forward_decoder(
|
485 |
+
x,
|
486 |
+
feature_extractor,
|
487 |
+
self.shared_decoder,
|
488 |
+
encoder_mask,
|
489 |
+
)
|
490 |
+
xs.append(dx)
|
491 |
+
if feature_extractor.decoder is not None:
|
492 |
+
dx = self.forward_decoder(
|
493 |
+
x,
|
494 |
+
feature_extractor,
|
495 |
+
feature_extractor.decoder,
|
496 |
+
encoder_mask,
|
497 |
+
)
|
498 |
+
xs.append(dx)
|
499 |
+
orig_x = x
|
500 |
+
|
501 |
+
assert len(xs) > 0
|
502 |
+
|
503 |
+
p = next(self.ema.model.parameters())
|
504 |
+
device = x.device
|
505 |
+
dtype = x.dtype
|
506 |
+
ema_device = p.device
|
507 |
+
ema_dtype = p.dtype
|
508 |
+
|
509 |
+
if not self.cfg.ema_same_dtype:
|
510 |
+
dtype = ema_dtype
|
511 |
+
|
512 |
+
if ema_device != device or ema_dtype != dtype:
|
513 |
+
logger.info(f"adjusting ema dtype to {dtype} and device to {device}")
|
514 |
+
self.ema.model = self.ema.model.to(dtype=dtype, device=device)
|
515 |
+
ema_dtype = dtype
|
516 |
+
|
517 |
+
def to_device(d):
|
518 |
+
for k, p in d.items():
|
519 |
+
if isinstance(d[k], dict):
|
520 |
+
to_device(d[k])
|
521 |
+
else:
|
522 |
+
d[k] = p.to(device=device)
|
523 |
+
|
524 |
+
to_device(self.ema.fp32_params)
|
525 |
+
tm = self.ema.model
|
526 |
+
|
527 |
+
with torch.no_grad():
|
528 |
+
tm.eval()
|
529 |
+
|
530 |
+
if self.cfg.ema_encoder_only:
|
531 |
+
assert target is None
|
532 |
+
ema_input = extractor_out["local_features"]
|
533 |
+
ema_input = feature_extractor.contextualized_features(
|
534 |
+
ema_input.to(dtype=ema_dtype),
|
535 |
+
padding_mask,
|
536 |
+
mask=False,
|
537 |
+
remove_masked=False,
|
538 |
+
)
|
539 |
+
ema_blocks = tm
|
540 |
+
else:
|
541 |
+
ema_blocks = tm.blocks
|
542 |
+
if feature_extractor.modality_cfg.ema_local_encoder:
|
543 |
+
inp = (
|
544 |
+
target.to(dtype=ema_dtype)
|
545 |
+
if target is not None
|
546 |
+
else source.to(dtype=ema_dtype)
|
547 |
+
)
|
548 |
+
ema_input = tm.modality_encoders[mode](
|
549 |
+
inp,
|
550 |
+
padding_mask,
|
551 |
+
mask=False,
|
552 |
+
remove_masked=False,
|
553 |
+
)
|
554 |
+
else:
|
555 |
+
assert target is None
|
556 |
+
ema_input = extractor_out["local_features"]
|
557 |
+
ema_feature_enc = tm.modality_encoders[mode]
|
558 |
+
ema_input = ema_feature_enc.contextualized_features(
|
559 |
+
ema_input.to(dtype=ema_dtype),
|
560 |
+
padding_mask,
|
561 |
+
mask=False,
|
562 |
+
remove_masked=False,
|
563 |
+
)
|
564 |
+
|
565 |
+
ema_padding_mask = ema_input["padding_mask"]
|
566 |
+
ema_alibi_bias = ema_input.get("alibi_bias", None)
|
567 |
+
ema_alibi_scale = ema_input.get("alibi_scale", None)
|
568 |
+
ema_input = ema_input["x"]
|
569 |
+
|
570 |
+
y = []
|
571 |
+
ema_x = []
|
572 |
+
extra_tokens = feature_extractor.modality_cfg.num_extra_tokens
|
573 |
+
for i, blk in enumerate(ema_blocks):
|
574 |
+
ab = ema_alibi_bias
|
575 |
+
if ab is not None and alibi_scale is not None:
|
576 |
+
scale = (
|
577 |
+
ema_alibi_scale[i]
|
578 |
+
if ema_alibi_scale.size(0) > 1
|
579 |
+
else ema_alibi_scale.squeeze(0)
|
580 |
+
)
|
581 |
+
ab = ab * scale.type_as(ab)
|
582 |
+
|
583 |
+
ema_input, lr = blk(
|
584 |
+
ema_input,
|
585 |
+
padding_mask=ema_padding_mask,
|
586 |
+
alibi_bias=ab,
|
587 |
+
)
|
588 |
+
y.append(lr[:, extra_tokens:])
|
589 |
+
ema_x.append(ema_input[:, extra_tokens:])
|
590 |
+
|
591 |
+
y = self.make_targets(y, self.average_top_k_layers)
|
592 |
+
orig_targets = y
|
593 |
+
|
594 |
+
if self.cfg.clone_batch > 1:
|
595 |
+
y = y.repeat_interleave(self.cfg.clone_batch, 0)
|
596 |
+
|
597 |
+
masked = encoder_mask.mask.unsqueeze(-1)
|
598 |
+
masked_b = encoder_mask.mask.bool()
|
599 |
+
y = y[masked_b]
|
600 |
+
|
601 |
+
if xs[0].size(1) == masked_b.size(1):
|
602 |
+
xs = [x[masked_b] for x in xs]
|
603 |
+
else:
|
604 |
+
xs = [x.reshape(-1, x.size(-1)) for x in xs]
|
605 |
+
|
606 |
+
sample_size = masked.sum().long()
|
607 |
+
|
608 |
+
result = {
|
609 |
+
"losses": {},
|
610 |
+
"sample_size": sample_size,
|
611 |
+
}
|
612 |
+
|
613 |
+
sample_size = result["sample_size"]
|
614 |
+
|
615 |
+
if self.cfg.cls_loss > 0:
|
616 |
+
assert extra_tokens > 0
|
617 |
+
cls_target = orig_targets.mean(dim=1)
|
618 |
+
if self.cfg.clone_batch > 1:
|
619 |
+
cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0)
|
620 |
+
cls_pred = x[:, extra_tokens - 1]
|
621 |
+
result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * (
|
622 |
+
self.cfg.cls_loss * sample_size
|
623 |
+
)
|
624 |
+
|
625 |
+
if self.cfg.recon_loss > 0:
|
626 |
+
|
627 |
+
with torch.no_grad():
|
628 |
+
target = feature_extractor.patchify(source)
|
629 |
+
mean = target.mean(dim=-1, keepdim=True)
|
630 |
+
var = target.var(dim=-1, keepdim=True)
|
631 |
+
target = (target - mean) / (var + 1.0e-6) ** 0.5
|
632 |
+
|
633 |
+
if self.cfg.clone_batch > 1:
|
634 |
+
target = target.repeat_interleave(self.cfg.clone_batch, 0)
|
635 |
+
|
636 |
+
if masked_b is not None:
|
637 |
+
target = target[masked_b]
|
638 |
+
|
639 |
+
recon = xs[0]
|
640 |
+
if self.recon_proj is not None:
|
641 |
+
recon = self.recon_proj(recon)
|
642 |
+
|
643 |
+
result["losses"]["recon"] = (
|
644 |
+
self.d2v_loss(recon, target.float()) * self.cfg.recon_loss
|
645 |
+
)
|
646 |
+
|
647 |
+
if self.cfg.d2v_loss > 0:
|
648 |
+
for i, x in enumerate(xs):
|
649 |
+
reg_loss = self.d2v_loss(x, y)
|
650 |
+
n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression"
|
651 |
+
result["losses"][n] = reg_loss * self.cfg.d2v_loss
|
652 |
+
|
653 |
+
suffix = "" if len(self.modalities) == 1 else f"_{mode}"
|
654 |
+
with torch.no_grad():
|
655 |
+
if encoder_mask is not None:
|
656 |
+
result["masked_pct"] = 1 - (
|
657 |
+
encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1)
|
658 |
+
)
|
659 |
+
for i, x in enumerate(xs):
|
660 |
+
n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}"
|
661 |
+
result[n] = self.compute_var(x.float())
|
662 |
+
if self.ema is not None:
|
663 |
+
for k, v in self.ema.logs.items():
|
664 |
+
result[k] = v
|
665 |
+
|
666 |
+
y = y.float()
|
667 |
+
result[f"target_var{suffix}"] = self.compute_var(y)
|
668 |
+
|
669 |
+
if self.num_updates > 5000:
|
670 |
+
if result[f"target_var{suffix}"] < self.cfg.min_target_var:
|
671 |
+
logger.error(
|
672 |
+
f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
|
673 |
+
)
|
674 |
+
raise Exception(
|
675 |
+
f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})"
|
676 |
+
)
|
677 |
+
|
678 |
+
for k in result.keys():
|
679 |
+
if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var:
|
680 |
+
logger.error(
|
681 |
+
f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
|
682 |
+
)
|
683 |
+
raise Exception(
|
684 |
+
f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})"
|
685 |
+
)
|
686 |
+
|
687 |
+
result["ema_decay"] = self.ema.get_decay() * 1000
|
688 |
+
|
689 |
+
return result
|
690 |
+
|
691 |
+
def forward_decoder(
|
692 |
+
self,
|
693 |
+
x,
|
694 |
+
feature_extractor,
|
695 |
+
decoder,
|
696 |
+
mask_info,
|
697 |
+
):
|
698 |
+
x = feature_extractor.decoder_input(x, mask_info)
|
699 |
+
x = decoder(*x)
|
700 |
+
|
701 |
+
return x
|
702 |
+
|
703 |
+
def d2v_loss(self, x, y):
|
704 |
+
x = x.view(-1, x.size(-1)).float()
|
705 |
+
y = y.view(-1, x.size(-1))
|
706 |
+
|
707 |
+
if self.loss_beta == 0:
|
708 |
+
loss = F.mse_loss(x, y, reduction="none")
|
709 |
+
else:
|
710 |
+
loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta)
|
711 |
+
|
712 |
+
if self.loss_scale is not None:
|
713 |
+
scale = self.loss_scale
|
714 |
+
else:
|
715 |
+
scale = 1 / math.sqrt(x.size(-1))
|
716 |
+
|
717 |
+
reg_loss = loss * scale
|
718 |
+
|
719 |
+
return reg_loss
|
720 |
+
|
721 |
+
def make_targets(self, y, num_layers):
|
722 |
+
|
723 |
+
with torch.no_grad():
|
724 |
+
target_layer_results = y[-num_layers:]
|
725 |
+
|
726 |
+
permuted = False
|
727 |
+
if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
|
728 |
+
target_layer_results = [
|
729 |
+
tl.transpose(1, 2) for tl in target_layer_results # BTC -> BCT
|
730 |
+
]
|
731 |
+
permuted = True
|
732 |
+
if self.cfg.batch_norm_target_layer:
|
733 |
+
target_layer_results = [
|
734 |
+
F.batch_norm(
|
735 |
+
tl.float(), running_mean=None, running_var=None, training=True
|
736 |
+
)
|
737 |
+
for tl in target_layer_results
|
738 |
+
]
|
739 |
+
if self.cfg.instance_norm_target_layer:
|
740 |
+
target_layer_results = [
|
741 |
+
F.instance_norm(tl.float()) for tl in target_layer_results
|
742 |
+
]
|
743 |
+
if permuted:
|
744 |
+
target_layer_results = [
|
745 |
+
tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC
|
746 |
+
]
|
747 |
+
if self.cfg.layer_norm_target_layer:
|
748 |
+
target_layer_results = [
|
749 |
+
F.layer_norm(tl.float(), tl.shape[-1:])
|
750 |
+
for tl in target_layer_results
|
751 |
+
]
|
752 |
+
|
753 |
+
y = target_layer_results[0].float()
|
754 |
+
for tl in target_layer_results[1:]:
|
755 |
+
y.add_(tl.float())
|
756 |
+
y = y.div_(len(target_layer_results))
|
757 |
+
|
758 |
+
if self.cfg.layer_norm_targets:
|
759 |
+
y = F.layer_norm(y, y.shape[-1:])
|
760 |
+
|
761 |
+
if self.cfg.instance_norm_targets:
|
762 |
+
y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)
|
763 |
+
|
764 |
+
return y
|
765 |
+
|
766 |
+
@staticmethod
|
767 |
+
def compute_var(y):
|
768 |
+
y = y.view(-1, y.size(-1))
|
769 |
+
if dist.is_initialized():
|
770 |
+
zc = torch.tensor(y.size(0)).cuda()
|
771 |
+
zs = y.sum(dim=0)
|
772 |
+
zss = (y**2).sum(dim=0)
|
773 |
+
|
774 |
+
dist.all_reduce(zc)
|
775 |
+
dist.all_reduce(zs)
|
776 |
+
dist.all_reduce(zss)
|
777 |
+
|
778 |
+
var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
|
779 |
+
return torch.sqrt(var + 1e-6).mean()
|
780 |
+
else:
|
781 |
+
return torch.sqrt(y.var(dim=0) + 1e-6).mean()
|
782 |
+
|
783 |
+
def extract_features(
|
784 |
+
self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True
|
785 |
+
):
|
786 |
+
res = self.forward(
|
787 |
+
source,
|
788 |
+
mode=mode,
|
789 |
+
padding_mask=padding_mask,
|
790 |
+
mask=mask,
|
791 |
+
features_only=True,
|
792 |
+
remove_extra_tokens=remove_extra_tokens,
|
793 |
+
)
|
794 |
+
return res
|
795 |
+
|
796 |
+
def remove_pretraining_modules(self, modality=None, keep_decoder=False):
|
797 |
+
self.ema = None
|
798 |
+
self.cfg.clone_batch = 1
|
799 |
+
self.recon_proj = None
|
800 |
+
|
801 |
+
if not keep_decoder:
|
802 |
+
self.shared_decoder = None
|
803 |
+
|
804 |
+
modality = modality.lower() if modality is not None else None
|
805 |
+
for k in list(self.modality_encoders.keys()):
|
806 |
+
if modality is not None and k.lower() != modality:
|
807 |
+
del self.modality_encoders[k]
|
808 |
+
else:
|
809 |
+
self.modality_encoders[k].remove_pretraining_modules(
|
810 |
+
keep_decoder=keep_decoder
|
811 |
+
)
|
812 |
+
if not keep_decoder:
|
813 |
+
self.modality_encoders[k].decoder = None
|
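The student/teacher regression objective in `Data2VecMultiModel.d2v_loss` is plain MSE when `loss_beta == 0` and smooth-L1 otherwise, scaled by `1/sqrt(dim)` unless `loss_scale` overrides it. A self-contained sketch of that computation, with random tensors standing in for the masked-position predictions and EMA-teacher targets:

```python
import math

import torch
import torch.nn.functional as F


def d2v_loss(x, y, loss_beta=0.0, loss_scale=None):
    # Flatten to (N, D) as the model does, then take an element-wise regression loss.
    x = x.view(-1, x.size(-1)).float()
    y = y.view(-1, x.size(-1)).float()
    if loss_beta == 0:
        loss = F.mse_loss(x, y, reduction="none")
    else:
        loss = F.smooth_l1_loss(x, y, reduction="none", beta=loss_beta)
    scale = loss_scale if loss_scale is not None else 1 / math.sqrt(x.size(-1))
    return loss * scale


x = torch.randn(4, 16, 768)  # student predictions (illustrative shapes)
y = torch.randn(4, 16, 768)  # layer-averaged teacher targets
print(d2v_loss(x, y).sum())  # summed, 1/sqrt(768)-scaled regression loss
```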
fairseq/examples/data2vec/models/data2vec_audio.py
ADDED
@@ -0,0 +1,537 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from dataclasses import dataclass, field
from typing import Optional

from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    ConvFeatureExtractionModel,
    Wav2Vec2Config,
    TransformerEncoder,
)
from fairseq.modules import (
    GradMultiply,
    LayerNorm,
)
from fairseq.utils import index_put


logger = logging.getLogger(__name__)


@dataclass
class Data2VecAudioConfig(Wav2Vec2Config):

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False
    batch_norm_target_layer: bool = False
    group_norm_target_layer: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer"},
    )
    ema_layers_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )

    max_update: int = II("optimization.max_update")

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )


def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
class Data2VecAudioModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecAudioConfig):
        super().__init__()
        self.cfg = cfg

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]

        self.ema = None
        self.embed = cfg.encoder_embed_dim

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.extractor_embed)

        self.final_proj = nn.Linear(self.embed, self.embed)

        self.num_updates = 0

    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_layers_only:
            self.cfg.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")

        self.ema = EMAModule(
            self.encoder if self.cfg.ema_transformer_only else self,
            ema_config,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is None and self.final_proj is not None:
            logger.info(f"making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecAudioConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)

        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )

        return input_lengths.to(torch.long)

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        features = source

        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)

        features = features.transpose(1, 2)

        features = self.layer_norm(features)

        orig_padding_mask = padding_mask

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )

            # these two operations makes sure that all values
            # before the output lengths indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        pre_encoder_features = None
        if self.cfg.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        result = {
            "losses": {},
        }

        with torch.no_grad():
            self.ema.model.eval()

            if self.cfg.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=source,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

            target_layer_results = [l[2] for l in y["layer_results"]]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
                ]
                permuted = True

            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]

            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]

            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]

            if self.cfg.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]

            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

            y = sum(target_layer_results) / len(target_layer_results)

            if self.cfg.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.cfg.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            if not permuted:
                y = y.transpose(0, 1)

            y = y[mask_indices]

        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)

        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(sz)

        result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
            logger.error(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
            raise Exception(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
            logger.error(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(
        self, source, padding_mask, mask=False, layer=None
    ):
        res = self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )
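
Editor's note: the decay used by the EMA teacher in set_num_updates is annealed linearly from ema_decay to ema_end_decay by get_annealed_rate and then held constant. A small standalone sketch of that schedule (the step counts and decay values here are illustrative, not taken from any shipped config):

def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining

# Decay grows linearly from 0.999 to 0.9999 over 50k updates, then stays at the
# end value, mirroring the num_updates >= ema_anneal_end_step branch above.
for step in (0, 10_000, 25_000, 50_000, 75_000):
    decay = 0.9999 if step >= 50_000 else get_annealed_rate(0.999, 0.9999, step, 50_000)
    print(step, round(decay, 6))
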
|
fairseq/examples/data2vec/models/data2vec_image_classification.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit

import logging

from dataclasses import dataclass
from typing import Any

from omegaconf import II, MISSING

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils, tasks

from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model


logger = logging.getLogger(__name__)


@dataclass
class Data2VecImageClassificationConfig(FairseqDataclass):
    model_path: str = MISSING
    no_pretrained_weights: bool = False
    num_classes: int = 1000
    mixup: float = 0.8
    cutmix: float = 1.0
    label_smoothing: float = 0.1

    pretrained_model_args: Any = None
    data: str = II("task.data")


@register_model(
    "data2vec_image_classification", dataclass=Data2VecImageClassificationConfig
)
class Data2VecImageClassificationModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecImageClassificationConfig):
        super().__init__()
        self.cfg = cfg

        if cfg.pretrained_model_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
            pretrained_args = state.get("cfg", None)
            pretrained_args.criterion = None
            pretrained_args.lr_scheduler = None
            cfg.pretrained_model_args = pretrained_args

            logger.info(pretrained_args)
        else:
            state = None
            pretrained_args = cfg.pretrained_model_args

        pretrained_args.task.data = cfg.data
        task = tasks.setup_task(pretrained_args.task)
        model = task.build_model(pretrained_args.model, from_checkpoint=True)

        model.remove_pretraining_modules()

        self.model = model

        if state is not None and not cfg.no_pretrained_weights:
            self.load_model_weights(state, model, cfg)

        self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim)
        self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)

        self.head.weight.data.mul_(1e-3)
        self.head.bias.data.mul_(1e-3)

        self.mixup_fn = None

        if cfg.mixup > 0 or cfg.cutmix > 0:
            from timm.data import Mixup

            self.mixup_fn = Mixup(
                mixup_alpha=cfg.mixup,
                cutmix_alpha=cfg.cutmix,
                cutmix_minmax=None,
                prob=1.0,
                switch_prob=0.5,
                mode="batch",
                label_smoothing=cfg.label_smoothing,
                num_classes=cfg.num_classes,
            )

    def load_model_weights(self, state, model, cfg):
        if "_ema" in state["model"]:
            del state["model"]["_ema"]
        model.load_state_dict(state["model"], strict=True)

    @classmethod
    def build_model(cls, cfg: Data2VecImageClassificationConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def forward(
        self,
        img,
        label=None,
    ):
        if self.training and self.mixup_fn is not None and label is not None:
            img, label = self.mixup_fn(img, label)

        x = self.model(img, mask=False)
        x = x[:, 1:]
        x = self.fc_norm(x.mean(1))
        x = self.head(x)

        if label is None:
            return x

        if self.training and self.mixup_fn is not None:
            loss = -label * F.log_softmax(x.float(), dim=-1)
        else:
            loss = F.cross_entropy(
                x.float(),
                label,
                label_smoothing=self.cfg.label_smoothing if self.training else 0,
                reduction="none",
            )

        result = {
            "losses": {"regression": loss},
            "sample_size": img.size(0),
        }

        if not self.training:
            with torch.no_grad():
                pred = x.argmax(-1)
                correct = (pred == label).sum()
                result["correct"] = correct

        return result
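
Editor's note: when timm's Mixup is active the labels arrive as soft class distributions, so the forward above switches from F.cross_entropy to the -label * log_softmax(x) form. A minimal sketch of that branch on random data (batch size and class count are arbitrary):

import torch
import torch.nn.functional as F

batch, num_classes = 4, 10
logits = torch.randn(batch, num_classes)

# Soft labels of the kind mixup produces: each row mixes two classes and sums to 1.
soft = torch.zeros(batch, num_classes)
soft[torch.arange(batch), torch.randint(num_classes, (batch,))] += 0.7
soft[torch.arange(batch), torch.randint(num_classes, (batch,))] += 0.3

# Same form as the mixup branch: elementwise -p(c) * log q(c), summed per sample.
loss = (-soft * F.log_softmax(logits.float(), dim=-1)).sum(dim=-1)
print(loss.mean())
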
fairseq/examples/data2vec/models/data2vec_text.py
ADDED
@@ -0,0 +1,517 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
)
from fairseq.models.roberta.model import RobertaLMHead, RobertaClassificationHead
from fairseq.models.transformer import TransformerEncoder, TransformerConfig
from fairseq.modules.transformer_sentence_encoder import init_bert_params

logger = logging.getLogger(__name__)


@dataclass
class Data2VecTextConfig(FairseqDataclass):
    max_positions: int = II("task.tokens_per_sample")

    head_layers: int = 1

    transformer: TransformerConfig = TransformerConfig()

    load_checkpoint_heads: bool = field(
        default=False,
        metadata={"help": "(re-)register and load heads when loading checkpoints"},
    )

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_layers_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )


def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


@register_model("data2vec_text", dataclass=Data2VecTextConfig)
class Data2VecTextModel(FairseqEncoderModel):
    def __init__(self, cfg: Data2VecTextConfig, encoder):
        super().__init__(encoder)
        self.cfg = cfg

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

        self.classification_heads = nn.ModuleDict()

    @classmethod
    def build_model(cls, cfg, task):
        """Build a new model instance."""

        encoder = Data2VecTextEncoder(cfg, task.source_dictionary, task.cfg.data)

        return cls(cfg, encoder)

    def forward(
        self,
        src_tokens,
        target_tokens=None,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        **kwargs,
    ):
        if classification_head_name is not None:
            features_only = True

        res = self.encoder(
            src_tokens, target_tokens, features_only, return_all_hiddens, **kwargs
        )

        if isinstance(res, tuple):
            x, extra = res
        else:
            return res

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
            return F.softmax(logits, dim=-1)

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = RobertaClassificationHead(
            input_dim=self.cfg.transformer.encoder.embed_dim,
            inner_dim=inner_dim or self.cfg.transformer.encoder.embed_dim,
            num_classes=num_classes,
            activation_fn="tanh",
            pooler_dropout=0,
        )

    @property
    def supported_targets(self):
        return {"self"}

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""

        # rename decoder -> encoder before upgrading children modules
        for k in list(state_dict.keys()):
            if k.startswith(prefix + "decoder"):
                new_k = prefix + "encoder" + k[len(prefix + "decoder") :]
                state_dict[new_k] = state_dict[k]
                del state_dict[k]

        # rename emb_layer_norm -> layernorm_embedding
        for k in list(state_dict.keys()):
            if ".emb_layer_norm." in k:
                new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.")
                state_dict[new_k] = state_dict[k]
                del state_dict[k]

            if self.encoder.regression_head is not None:
                if ".lm_head." in k:
                    new_k = k.replace(".lm_head.", ".regression_head.")
                    state_dict[new_k] = state_dict[k]
                    del state_dict[k]
            else:
                if ".regression_head." in k:
                    del state_dict[k]

        # upgrade children modules
        super().upgrade_state_dict_named(state_dict, name)

        # Handle new classification heads present in the state dict.
        current_head_names = (
            []
            if not hasattr(self, "classification_heads")
            or self.classification_heads is None
            else self.classification_heads.keys()
        )
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(prefix + "classification_heads."):
                continue

            head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
            num_classes = state_dict[
                prefix + "classification_heads." + head_name + ".out_proj.weight"
            ].size(0)
            inner_dim = state_dict[
                prefix + "classification_heads." + head_name + ".dense.weight"
            ].size(0)

            if self.cfg.load_checkpoint_heads:
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "not present in current model: {}".format(head_name, k)
                    )
                    keys_to_delete.append(k)
                elif (
                    num_classes
                    != self.classification_heads[head_name].out_proj.out_features
                    or inner_dim
                    != self.classification_heads[head_name].dense.out_features
                ):
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "with different dimensions than current model: {}".format(
                            head_name, k
                        )
                    )
                    keys_to_delete.append(k)
        for k in keys_to_delete:
            del state_dict[k]

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if (
            hasattr(self, "classification_heads")
            and self.classification_heads is not None
            and len(self.classification_heads) > 0
        ):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                if prefix + "classification_heads." + k not in state_dict:
                    logger.info("Overwriting " + prefix + "classification_heads." + k)
                    state_dict[prefix + "classification_heads." + k] = v

        for k in list(state_dict.keys()):
            if k.startswith(prefix + "encoder.lm_head.") or k.startswith(
                prefix + "encoder.emb_head."
            ):
                del state_dict[k]

        self.encoder.lm_head = None

        if self.encoder.target_model is None:
            for k in list(state_dict.keys()):
                if k.startswith(prefix + "encoder.target_model."):
                    del state_dict[k]

        if (self.encoder.ema is None) and (prefix + "encoder._ema" in state_dict):
            del state_dict[prefix + "encoder._ema"]

    def remove_pretraining_modules(self, last_layer=None):
        self.encoder.lm_head = None
        self.encoder.regression_head = None
        self.encoder.ema = None
        self.classification_heads = None

        if last_layer is not None:
            self.encoder.sentence_encoder.layers = nn.ModuleList(
                l
                for i, l in enumerate(self.encoder.sentence_encoder.layers)
                if i <= last_layer
            )
            self.encoder.sentence_encoder.layer_norm = None


class Data2VecTextEncoder(FairseqEncoder):
    def __init__(self, cfg: Data2VecTextConfig, dictionary, task_data):
        super().__init__(dictionary)

        self.cfg = cfg

        embed_tokens = self.build_embedding(
            len(dictionary), cfg.transformer.encoder.embed_dim, dictionary.pad()
        )

        self.sentence_encoder = self.build_encoder(cfg, dictionary, embed_tokens)
        self.mask_idx = dictionary.index("<mask>")
        assert self.mask_idx != dictionary.unk(), dictionary.symbols

        self.ema = None
        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_scale = cfg.loss_scale

        assert self.cfg.head_layers >= 1

        embed_dim = cfg.transformer.encoder.embed_dim
        curr_dim = embed_dim
        projs = []
        for i in range(self.cfg.head_layers - 1):
            next_dim = embed_dim * 2 if i == 0 else curr_dim
            projs.append(nn.Linear(curr_dim, next_dim))
            projs.append(nn.GELU())
            curr_dim = next_dim

        projs.append(nn.Linear(curr_dim, embed_dim))
        self.regression_head = nn.Sequential(*projs)

        self.num_updates = 0

    def build_embedding(self, vocab_size, embedding_dim, padding_idx):
        return nn.Embedding(vocab_size, embedding_dim, padding_idx)

    def build_encoder(self, cfg, dictionary, embed_tokens):
        encoder = TransformerEncoder(cfg.transformer, dictionary, embed_tokens, return_fc=True)
        encoder.apply(init_bert_params)
        return encoder

    def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
        return RobertaLMHead(embed_dim, output_dim, activation_fn, weight)

    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_transformer_layers_only:
            for k, _ in self.sentence_encoder.embed_positions.named_parameters():
                skip_keys.add(f"embed_tokens.{k}")
            for k, _ in self.sentence_encoder.embed_positions.named_parameters():
                skip_keys.add(f"embed_positions.{k}")
            if self.sentence_encoder.layernorm_embedding is not None:
                for (
                    k,
                    _,
                ) in self.sentence_encoder.layernorm_embedding.named_parameters():
                    skip_keys.add(f"layernorm_embedding.{k}")
            if self.sentence_encoder.layer_norm is not None:
                for k, _ in self.sentence_encoder.layer_norm.named_parameters():
                    skip_keys.add(f"layernorm_embedding.{k}")

        self.ema = EMAModule(
            self.sentence_encoder,
            ema_config,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is None and self.regression_head is not None:
            logger.info(f"making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.sentence_encoder)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)
        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params
        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(
        self,
        src_tokens,
        target_tokens=None,
        features_only=False,
        return_all_hiddens=False,
        masked_tokens=None,
        **unused,
    ):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states. Note that the hidden
                  states have shape `(src_len, batch, vocab)`.
        """

        x, extra = self.extract_features(
            src_tokens, return_all_hiddens=return_all_hiddens
        )

        if features_only:
            return x, extra

        assert target_tokens is not None

        with torch.no_grad():
            # use EMA parameter as the teacher
            self.ema.model.eval()

            encoder_out = self.ema.model(
                target_tokens,
                return_all_hiddens=True,
            )
            y = encoder_out["fc_results"]

            y = y[-self.average_top_k_layers :]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                y = [tl.permute(1, 2, 0) for tl in y]  # TBC -> BCT
                permuted = True

            if self.cfg.batch_norm_target_layer:
                y = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in y
                ]

            if self.cfg.instance_norm_target_layer:
                y = [F.instance_norm(tl.float()) for tl in y]

            if permuted:
                y = [tl.transpose(1, 2) for tl in y]  # BCT -> BTC

            if self.cfg.layer_norm_target_layer:
                y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]

            y = sum(y) / len(y)

            if not permuted:
                y = y.transpose(0, 1)

            if self.cfg.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.cfg.instance_norm_targets:
                y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)

        masked_indices = src_tokens.eq(self.mask_idx)

        x = x[masked_indices]
        y = y[masked_indices]

        x = self.regression_head(x)

        sz = x.size(-1)
        if self.cfg.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.cfg.loss_beta
            ).sum(dim=-1)

        result = {
            "losses": {
                "main": loss.sum() / math.sqrt(sz)
                if self.loss_scale <= 0
                else loss.sum() * self.loss_scale,
            },
            "sample_size": loss.numel(),
        }

        # logging other values
        other_logs = {
            "ema_decay": self.ema.get_decay() * 1000
        }
        result["logs"] = other_logs
        return result

    def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs):
        encoder_out = self.sentence_encoder(
            src_tokens,
            return_all_hiddens=return_all_hiddens,
            token_embeddings=kwargs.get("token_embeddings", None),
        )
        # T x B x C -> B x T x C
        features = encoder_out["encoder_out"][0].transpose(0, 1)
        inner_states = encoder_out["encoder_states"] if return_all_hiddens else None
        return features, {
            "inner_states": inner_states,
            "encoder_embedding": encoder_out["encoder_embedding"][0],
        }

    def output_layer(self, features, masked_tokens=None, **unused):
        return self.lm_head(features, masked_tokens)

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.cfg.max_positions
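
Editor's note: the student's regression head above is a small MLP whose depth comes from head_layers, with the first hidden layer doubling the width before projecting back to embed_dim. A standalone sketch of the same construction, assuming embed_dim=768 and head_layers=2 (both values are only examples):

import torch.nn as nn

embed_dim, head_layers = 768, 2

curr_dim = embed_dim
projs = []
for i in range(head_layers - 1):
    next_dim = embed_dim * 2 if i == 0 else curr_dim
    projs.append(nn.Linear(curr_dim, next_dim))
    projs.append(nn.GELU())
    curr_dim = next_dim
projs.append(nn.Linear(curr_dim, embed_dim))

regression_head = nn.Sequential(*projs)
print(regression_head)  # Linear(768->1536), GELU, Linear(1536->768) for these settings
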
fairseq/examples/data2vec/models/data2vec_text_classification.py
ADDED
@@ -0,0 +1,141 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit

import logging

from dataclasses import dataclass
from typing import Any

from omegaconf import II, MISSING

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils, tasks

from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.roberta.model import RobertaClassificationHead

from examples.data2vec.data.modality import Modality


logger = logging.getLogger(__name__)


@dataclass
class Data2VecTextClassificationConfig(FairseqDataclass):
    pooler_dropout: float = 0.0
    pooler_activation_fn: str = "tanh"
    quant_noise_pq: int = 0
    quant_noise_pq_block_size: int = 8
    spectral_norm_classification_head: bool = False

    model_path: str = MISSING
    no_pretrained_weights: bool = False

    pretrained_model_args: Any = None


@register_model(
    "data2vec_text_classification", dataclass=Data2VecTextClassificationConfig
)
class Data2VecTextClassificationModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecTextClassificationConfig):
        super().__init__()
        self.cfg = cfg

        if cfg.pretrained_model_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
            pretrained_args = state.get("cfg", None)
            pretrained_args.criterion = None
            pretrained_args.lr_scheduler = None
            cfg.pretrained_model_args = pretrained_args

            logger.info(pretrained_args)
        else:
            state = None
            pretrained_args = cfg.pretrained_model_args

        task = tasks.setup_task(pretrained_args.task)
        model = task.build_model(pretrained_args.model, from_checkpoint=True)

        model.remove_pretraining_modules()

        self.model = model

        if state is not None and not cfg.no_pretrained_weights:
            self.load_model_weights(state, model, cfg)

        self.classification_heads = nn.ModuleDict()

    def load_model_weights(self, state, model, cfg):
        for k in list(state["model"].keys()):
            if (
                k.startswith("shared_decoder") or
                k.startswith("_ema") or
                "decoder" in k
            ):
                logger.info(f"Deleting {k} from checkpoint")
                del state["model"][k]
        model.load_state_dict(state["model"], strict=True)

    @classmethod
    def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        embed_dim = self.cfg.pretrained_model_args.model.embed_dim
        self.classification_heads[name] = RobertaClassificationHead(
            input_dim=embed_dim,
            inner_dim=inner_dim or embed_dim,
            num_classes=num_classes,
            activation_fn=self.cfg.pooler_activation_fn,
            pooler_dropout=self.cfg.pooler_dropout,
            q_noise=self.cfg.quant_noise_pq,
            qn_block_size=self.cfg.quant_noise_pq_block_size,
            do_spectral_norm=self.cfg.spectral_norm_classification_head,
        )

    def forward(
        self,
        source,
        id,
        padding_mask,
        features_only=True,
        remove_extra_tokens=True,
        classification_head_name=None,
    ):
        encoder_out = self.model(
            source,
            id=id,
            mode=Modality.TEXT,
            padding_mask=padding_mask,
            mask=False,
            features_only=features_only,
            remove_extra_tokens=remove_extra_tokens
        )
        logits = self.classification_heads[classification_head_name](encoder_out["x"])
        return logits, encoder_out
fairseq/examples/data2vec/models/data2vec_vision.py
ADDED
@@ -0,0 +1,727 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit

import logging
import math
import numpy as np
import random

from dataclasses import dataclass, field
from typing import Optional

from omegaconf import II

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model


logger = logging.getLogger(__name__)


@dataclass
class Data2VecVisionConfig(FairseqDataclass):
    layer_scale_init_value: float = field(
        default=1e-4, metadata={"help": "rescale layer outputs, 0 to disable"}
    )
    num_mask_patches: int = field(
        default=75,
        metadata={"help": "number of the visual tokens/patches need be masked"},
    )
    min_mask_patches_per_block: int = 16
    max_mask_patches_per_block: int = 196
    image_size: int = 224
    patch_size: int = 16
    in_channels: int = 3

    shared_rel_pos_bias: bool = True

    drop_path: float = 0.1
    attention_dropout: float = 0.0

    depth: int = 12
    embed_dim: int = 768
    num_heads: int = 12
    mlp_ratio: int = 4

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    end_of_block_targets: bool = True
    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )


def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


@register_model("data2vec_vision", dataclass=Data2VecVisionConfig)
class Data2VecVisionModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecVisionConfig):
        super().__init__()
        self.cfg = cfg

        self.ema = None

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = (
            cfg.loss_scale
            if cfg.loss_scale is not None
            else 1 / math.sqrt(cfg.embed_dim)
        )

        self.patch_embed = PatchEmbed(
            img_size=cfg.image_size,
            patch_size=cfg.patch_size,
            in_chans=cfg.in_channels,
            embed_dim=cfg.embed_dim,
        )

        patch_size = self.patch_embed.patch_size
        self.window_size = (
            cfg.image_size // patch_size[0],
            cfg.image_size // patch_size[1],
        )

        self.cls_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))
        self.mask_emb = nn.Parameter(torch.FloatTensor(1, 1, cfg.embed_dim))

        nn.init.trunc_normal_(self.cls_emb, 0.02)
        nn.init.trunc_normal_(self.mask_emb, 0.02)

        self.encoder = TransformerEncoder(cfg, self.patch_embed.patch_shape)

        self.final_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim)
        self.num_updates = 0

    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        self.ema = EMAModule(
            self.encoder if self.cfg.ema_transformer_only else self,
            ema_config,
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is None and self.final_proj is not None:
            logger.info("making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecVisionConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def make_mask(self, bsz, num_masks, min_masks, max_masks):
        height, width = self.window_size

        masks = np.zeros(shape=(bsz, height, width), dtype=int)

        for i in range(bsz):
            mask = masks[i]
            mask_count = 0

            min_aspect = 0.3
            max_aspect = 1 / min_aspect
            log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

            def _mask(mask, max_mask_patches):
                delta = 0
                for attempt in range(10):
                    target_area = random.uniform(min_masks, max_mask_patches)
                    aspect_ratio = math.exp(random.uniform(*log_aspect_ratio))
                    h = int(round(math.sqrt(target_area * aspect_ratio)))
                    w = int(round(math.sqrt(target_area / aspect_ratio)))
                    if w < width and h < height:
                        top = random.randint(0, height - h)
                        left = random.randint(0, width - w)

                        num_masked = mask[top : top + h, left : left + w].sum()
                        # Overlap
                        if 0 < h * w - num_masked <= max_mask_patches:
                            for i in range(top, top + h):
                                for j in range(left, left + w):
                                    if mask[i, j] == 0:
                                        mask[i, j] = 1
                                        delta += 1

                        if delta > 0:
                            break
                return delta

            while mask_count < num_masks:
                max_mask_patches = min(num_masks - mask_count, max_masks)

                delta = _mask(mask, max_mask_patches)
                if delta == 0:
                    break
                else:
                    mask_count += delta

        return torch.from_numpy(masks)

    def forward(
        self,
        img,
        mask: bool = True,
        layer_results: bool = False,
    ):
        x = self.patch_embed(img)
        batch_size, seq_len, _ = x.size()

        if mask:
            mask_indices = self.make_mask(
                img.size(0),
                self.cfg.num_mask_patches,
                self.cfg.min_mask_patches_per_block,
                self.cfg.max_mask_patches_per_block,
            )
            bool_mask = mask_indices.view(mask_indices.size(0), -1).bool()
        else:
            mask_indices = bool_mask = None

        cls_tokens = self.cls_emb.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.ema is not None:
            with torch.no_grad():
                self.ema.model.eval()

                if self.cfg.ema_transformer_only:
                    y = self.ema.model(
                        x,
                        layer_results="end" if self.cfg.end_of_block_targets else "fc",
                    )
                else:
                    y = self.ema.model(
                        img,
                        mask=False,
                        layer_results=True,
                    )

            y = y[-self.cfg.average_top_k_layers :]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                y = [tl.transpose(1, 2) for tl in y]  # BTC -> BCT
                permuted = True

            if self.cfg.batch_norm_target_layer:
                y = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in y
                ]

            if self.cfg.instance_norm_target_layer:
                y = [F.instance_norm(tl.float()) for tl in y]

            if permuted:
                y = [tl.transpose(1, 2) for tl in y]  # BCT -> BTC

            if self.cfg.layer_norm_target_layer:
                y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]

            y = sum(y) / len(y)

            if self.cfg.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.cfg.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            y = y[bool_mask].float()

        if mask_indices is not None:
            mask_token = self.mask_emb.expand(batch_size, seq_len, -1)
            w = mask_indices.view(mask_indices.size(0), -1, 1).type_as(mask_token)
            x[:, 1:] = x[:, 1:] * (1 - w) + mask_token * w

        if layer_results:
            enc_layer_results = "end" if self.cfg.end_of_block_targets else "fc"
        else:
            enc_layer_results = None

        x = self.encoder(x, layer_results=enc_layer_results)
        if layer_results or mask_indices is None:
            return x

        x = x[bool_mask].float()

        if self.loss_beta == 0:
            loss = F.mse_loss(x, y, reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta).sum(
                dim=-1
            )

        if self.loss_scale > 0:
            loss = loss * self.loss_scale

        result = {
            "losses": {"regression": loss.sum()},
            "sample_size": loss.numel(),
            "target_var": self.compute_var(y),
            "pred_var": self.compute_var(x),
            "ema_decay": self.ema.get_decay() * 1000,
        }
        return result

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def remove_pretraining_modules(self, last_layer=None):
        self.final_proj = None
        self.ema = None
        self.encoder.norm = nn.Identity()
        self.mask_emb = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        if isinstance(img_size, int):
            img_size = img_size, img_size
        if isinstance(patch_size, int):
            patch_size = patch_size, patch_size
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.conv = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # BCHW -> BTC
        x = self.conv(x).flatten(2).transpose(1, 2)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        attn_drop=0.0,
        proj_drop=0.0,
        window_size=None,
        attn_head_dim=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (
                2 * window_size[1] - 1
            ) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads)
            )  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = (
                coords_flatten[:, :, None] - coords_flatten[:, None, :]
            )  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(
                1, 2, 0
            ).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = torch.zeros(
                size=(window_size[0] * window_size[1] + 1,) * 2,
                dtype=relative_coords.dtype,
            )
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat(
                (
                    self.q_bias,
                    torch.zeros_like(self.v_bias, requires_grad=False),
                    self.v_bias,
                )
            )
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)
            ].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1,
                -1,
            )  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1
            ).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class RelativePositionBias(nn.Module):
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (
            2 * window_size[1] - 1
        ) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads)
        )

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1] + 1,
            self.window_size[0] * self.window_size[1] + 1,
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (
            x.ndim - 1
        )  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        output = x.div(keep_prob) * random_tensor
        return output

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        window_size=None,
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            attn_drop=attn_drop,
            proj_drop=drop,
            window_size=window_size,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop),
        )

        if init_values > 0:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            fc_feature = self.drop_path(self.mlp(self.norm2(x)))
            x = x + fc_feature
        else:
            x = x + self.drop_path(
                self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
            )
            fc_feature = self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
            x = x + fc_feature
        return x, fc_feature


class TransformerEncoder(nn.Module):
    def __init__(self, cfg: Data2VecVisionConfig, patch_shape):
        super().__init__()

        self.rel_pos_bias = None
        if cfg.shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(
                window_size=patch_shape, num_heads=cfg.num_heads
            )

        dpr = [
            x.item() for x in torch.linspace(0, cfg.drop_path, cfg.depth)
        ]  # stochastic depth decay rule

        self.blocks = nn.ModuleList(
            Block(
                dim=cfg.embed_dim,
                num_heads=cfg.num_heads,
                attn_drop=cfg.attention_dropout,
                drop_path=dpr[i],
                init_values=cfg.layer_scale_init_value,
                window_size=patch_shape if not cfg.shared_rel_pos_bias else None,
            )
            for i in range(cfg.depth)
        )

        self.norm = nn.LayerNorm(cfg.embed_dim)

        self.apply(self.init_weights)
        self.fix_init_weight()

    def init_weights(self, m):
        std = 0.02
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.trunc_normal_(m.weight, std=std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp[2].weight.data, layer_id + 1)

    def extract_features(self, x, layer_results):

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None

        z = []
        for i, blk in enumerate(self.blocks):
            x, fc_feature = blk(x, rel_pos_bias=rel_pos_bias)
            if layer_results == "end":
                z.append(x)
            elif layer_results == "fc":
                z.append(fc_feature)

        return z if layer_results else self.norm(x)

    def forward(self, x, layer_results=None):
        x = self.extract_features(x, layer_results=layer_results)
        if layer_results:
            return [z[:, 1:] for z in x]

        x = x[:, 1:]
        return x
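
The file above only registers the model with fairseq's model registry; nothing in the commit shows it being run directly. As a rough, hypothetical smoke-test sketch (not part of the diff), assuming the repository root is on PYTHONPATH so that the examples package resolves, and using only the default config values:

    import torch
    from examples.data2vec.models.data2vec_vision import (
        Data2VecVisionConfig,
        Data2VecVisionModel,
    )

    cfg = Data2VecVisionConfig()            # defaults: 224x224 images, 16x16 patches
    model = Data2VecVisionModel.build_model(cfg)
    model.set_num_updates(1)                # first update call instantiates the EMA teacher
    imgs = torch.randn(2, 3, 224, 224)      # dummy image batch
    out = model(imgs, mask=True)            # dict with "losses", "sample_size", "ema_decay", ...
    print(out["losses"]["regression"])

In normal use the fairseq trainer drives set_num_updates and the loss dict, so the snippet is only meant to illustrate the student/EMA-teacher flow of the forward pass.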
fairseq/examples/data2vec/models/mae.py
ADDED
@@ -0,0 +1,829 @@
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# The code in this file is adapted from the BeiT implementation which can be found here:
|
7 |
+
# https://github.com/microsoft/unilm/tree/master/beit
|
8 |
+
|
9 |
+
import logging
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from functools import partial
|
12 |
+
|
13 |
+
from timm.models.vision_transformer import PatchEmbed, Block
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
from fairseq.dataclass import FairseqDataclass
|
21 |
+
from fairseq.models import BaseFairseqModel, register_model
|
22 |
+
from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer
|
23 |
+
|
24 |
+
try:
|
25 |
+
from apex.normalization import FusedLayerNorm
|
26 |
+
except:
|
27 |
+
FusedLayerNorm = nn.LayerNorm
|
28 |
+
|
29 |
+
import torch.nn.functional as F
|
30 |
+
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
@dataclass
|
36 |
+
class MaeConfig(FairseqDataclass):
|
37 |
+
input_size: int = 224
|
38 |
+
in_chans: int = 3
|
39 |
+
patch_size: int = 16
|
40 |
+
embed_dim: int = 768
|
41 |
+
depth: int = 12
|
42 |
+
num_heads: int = 12
|
43 |
+
decoder_embed_dim: int = 512
|
44 |
+
decoder_depth: int = 8
|
45 |
+
decoder_num_heads: int = 16
|
46 |
+
mlp_ratio: int = 4
|
47 |
+
norm_eps: float = 1e-6
|
48 |
+
|
49 |
+
drop_path_rate: float = 0.0
|
50 |
+
|
51 |
+
mask_ratio: float = 0.75
|
52 |
+
norm_pix_loss: bool = True
|
53 |
+
|
54 |
+
w2v_block: bool = False
|
55 |
+
alt_block: bool = False
|
56 |
+
alt_block2: bool = False
|
57 |
+
alt_attention: bool = False
|
58 |
+
block_dropout: float = 0
|
59 |
+
attention_dropout: float = 0
|
60 |
+
activation_dropout: float = 0
|
61 |
+
layer_norm_first: bool = False
|
62 |
+
|
63 |
+
fused_ln: bool = True
|
64 |
+
end_of_block_targets: bool = True
|
65 |
+
|
66 |
+
no_decoder_embed: bool = False
|
67 |
+
no_decoder_pos_embed: bool = False
|
68 |
+
mask_noise_std: float = 0
|
69 |
+
|
70 |
+
single_qkv: bool = False
|
71 |
+
use_rel_pos_bias: bool = False
|
72 |
+
no_cls: bool = False
|
73 |
+
|
74 |
+
|
75 |
+
def modify_relative_position_bias(orig_bias, bsz, mask):
|
76 |
+
if mask is None:
|
77 |
+
return orig_bias.unsqueeze(0).repeat(
|
78 |
+
bsz, 1, 1, 1
|
79 |
+
) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
|
80 |
+
heads, max_seq_len, max_seq_len = orig_bias.shape # includes CLS token
|
81 |
+
mask_for_rel_pos_bias = torch.cat(
|
82 |
+
(torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1
|
83 |
+
).bool() # bsz x seqlen (add CLS token)
|
84 |
+
unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias
|
85 |
+
unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat(
|
86 |
+
1, heads, 1
|
87 |
+
) # bsz x seq_len => bsz x heads x seq_len
|
88 |
+
b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat(
|
89 |
+
bsz, 1, 1, 1
|
90 |
+
) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
|
91 |
+
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
|
92 |
+
unmasked_for_rel_pos_bias.unsqueeze(-1)
|
93 |
+
)
|
94 |
+
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len)
|
95 |
+
new_len = b_t_t_rel_pos_bias.size(-2)
|
96 |
+
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select(
|
97 |
+
unmasked_for_rel_pos_bias.unsqueeze(-2)
|
98 |
+
)
|
99 |
+
b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len)
|
100 |
+
return b_t_t_rel_pos_bias
|
101 |
+
|
102 |
+
|
103 |
+
class AltBlock(nn.Module):
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
dim,
|
107 |
+
num_heads,
|
108 |
+
mlp_ratio=4.0,
|
109 |
+
qkv_bias=False,
|
110 |
+
qk_scale=None,
|
111 |
+
drop=0.0,
|
112 |
+
attn_drop=0.0,
|
113 |
+
drop_path=0.0,
|
114 |
+
act_layer=nn.GELU,
|
115 |
+
norm_layer=nn.LayerNorm,
|
116 |
+
layer_norm_first=True,
|
117 |
+
ffn_targets=False,
|
118 |
+
use_rel_pos_bias=False,
|
119 |
+
window_size=None,
|
120 |
+
alt_attention=False,
|
121 |
+
):
|
122 |
+
super().__init__()
|
123 |
+
|
124 |
+
self.layer_norm_first = layer_norm_first
|
125 |
+
self.ffn_targets = ffn_targets
|
126 |
+
|
127 |
+
from timm.models.vision_transformer import Attention, DropPath, Mlp
|
128 |
+
|
129 |
+
self.norm1 = norm_layer(dim)
|
130 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
131 |
+
if use_rel_pos_bias:
|
132 |
+
self.attn = AltAttention(
|
133 |
+
dim,
|
134 |
+
num_heads=num_heads,
|
135 |
+
qkv_bias=qkv_bias,
|
136 |
+
qk_scale=qk_scale,
|
137 |
+
attn_drop=attn_drop,
|
138 |
+
proj_drop=drop,
|
139 |
+
window_size=window_size,
|
140 |
+
)
|
141 |
+
else:
|
142 |
+
if alt_attention:
|
143 |
+
from .multi.modules import AltAttention as AltAttention2
|
144 |
+
self.attn = AltAttention2(
|
145 |
+
dim,
|
146 |
+
num_heads=num_heads,
|
147 |
+
qkv_bias=qkv_bias,
|
148 |
+
qk_scale=qk_scale,
|
149 |
+
attn_drop=attn_drop,
|
150 |
+
proj_drop=drop,
|
151 |
+
)
|
152 |
+
else:
|
153 |
+
self.attn = Attention(
|
154 |
+
dim,
|
155 |
+
num_heads=num_heads,
|
156 |
+
qkv_bias=qkv_bias,
|
157 |
+
qk_scale=qk_scale,
|
158 |
+
attn_drop=attn_drop,
|
159 |
+
proj_drop=drop,
|
160 |
+
)
|
161 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
162 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
163 |
+
self.norm2 = norm_layer(dim)
|
164 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
165 |
+
self.mlp = Mlp(
|
166 |
+
in_features=dim,
|
167 |
+
hidden_features=mlp_hidden_dim,
|
168 |
+
act_layer=act_layer,
|
169 |
+
drop=drop,
|
170 |
+
)
|
171 |
+
|
172 |
+
def forward(self, x, rel_pos_bias=None, pos_mask=None):
|
173 |
+
if self.layer_norm_first:
|
174 |
+
if self.use_rel_pos_bias:
|
175 |
+
x = x + self.drop_path(
|
176 |
+
self.attn(
|
177 |
+
self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask
|
178 |
+
)
|
179 |
+
)
|
180 |
+
else:
|
181 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
182 |
+
t = self.mlp(self.norm2(x))
|
183 |
+
x = x + self.drop_path(t)
|
184 |
+
if not self.ffn_targets:
|
185 |
+
t = x
|
186 |
+
return x, t
|
187 |
+
else:
|
188 |
+
if self.use_rel_pos_bias:
|
189 |
+
x = x + self.drop_path(
|
190 |
+
self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask)
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
x = x + self.drop_path(self.attn(x))
|
194 |
+
r = x = self.norm1(x)
|
195 |
+
x = self.mlp(x)
|
196 |
+
t = x
|
197 |
+
x = self.norm2(r + self.drop_path(x))
|
198 |
+
if not self.ffn_targets:
|
199 |
+
t = x
|
200 |
+
return x, t
|
201 |
+
|
202 |
+
|
203 |
+
class AltAttention(nn.Module):
|
204 |
+
def __init__(
|
205 |
+
self,
|
206 |
+
dim,
|
207 |
+
num_heads=8,
|
208 |
+
qkv_bias=True,
|
209 |
+
qk_scale=None,
|
210 |
+
attn_drop=0.0,
|
211 |
+
proj_drop=0.0,
|
212 |
+
window_size=None,
|
213 |
+
attn_head_dim=None,
|
214 |
+
):
|
215 |
+
super().__init__()
|
216 |
+
self.num_heads = num_heads
|
217 |
+
head_dim = dim // num_heads
|
218 |
+
if attn_head_dim is not None:
|
219 |
+
head_dim = attn_head_dim
|
220 |
+
all_head_dim = head_dim * self.num_heads
|
221 |
+
self.scale = qk_scale or head_dim ** -0.5
|
222 |
+
|
223 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
224 |
+
if qkv_bias:
|
225 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
226 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
227 |
+
else:
|
228 |
+
self.q_bias = None
|
229 |
+
self.v_bias = None
|
230 |
+
|
231 |
+
if window_size:
|
232 |
+
self.window_size = window_size
|
233 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
234 |
+
2 * window_size[1] - 1
|
235 |
+
) + 3
|
236 |
+
self.relative_position_bias_table = nn.Parameter(
|
237 |
+
torch.zeros(self.num_relative_distance, num_heads)
|
238 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
239 |
+
# cls to token & token 2 cls & cls to cls
|
240 |
+
|
241 |
+
# get pair-wise relative position index for each token inside the window
|
242 |
+
coords_h = torch.arange(window_size[0])
|
243 |
+
coords_w = torch.arange(window_size[1])
|
244 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
245 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
246 |
+
relative_coords = (
|
247 |
+
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
248 |
+
) # 2, Wh*Ww, Wh*Ww
|
249 |
+
relative_coords = relative_coords.permute(
|
250 |
+
1, 2, 0
|
251 |
+
).contiguous() # Wh*Ww, Wh*Ww, 2
|
252 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
253 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
254 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
255 |
+
relative_position_index = torch.zeros(
|
256 |
+
size=(window_size[0] * window_size[1] + 1,) * 2,
|
257 |
+
dtype=relative_coords.dtype,
|
258 |
+
)
|
259 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
260 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
261 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
262 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
263 |
+
|
264 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
265 |
+
else:
|
266 |
+
self.window_size = None
|
267 |
+
self.relative_position_bias_table = None
|
268 |
+
self.relative_position_index = None
|
269 |
+
|
270 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
271 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
272 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
273 |
+
|
274 |
+
def forward(self, x, rel_pos_bias=None, pos_mask=None):
|
275 |
+
B, N, C = x.shape
|
276 |
+
qkv_bias = None
|
277 |
+
if self.q_bias is not None:
|
278 |
+
qkv_bias = torch.cat(
|
279 |
+
(
|
280 |
+
self.q_bias,
|
281 |
+
torch.zeros_like(self.v_bias, requires_grad=False),
|
282 |
+
self.v_bias,
|
283 |
+
)
|
284 |
+
)
|
285 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
286 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
287 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
288 |
+
q, k, v = (
|
289 |
+
qkv[0],
|
290 |
+
qkv[1],
|
291 |
+
qkv[2],
|
292 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
293 |
+
|
294 |
+
q = q * self.scale
|
295 |
+
attn = q @ k.transpose(-2, -1)
|
296 |
+
|
297 |
+
if self.relative_position_bias_table is not None:
|
298 |
+
relative_position_bias = self.relative_position_bias_table[
|
299 |
+
self.relative_position_index.view(-1)
|
300 |
+
].view(
|
301 |
+
self.window_size[0] * self.window_size[1] + 1,
|
302 |
+
self.window_size[0] * self.window_size[1] + 1,
|
303 |
+
-1,
|
304 |
+
) # Wh*Ww,Wh*Ww,nH
|
305 |
+
relative_position_bias = relative_position_bias.permute(
|
306 |
+
2, 0, 1
|
307 |
+
).contiguous() # nH, Wh*Ww, Wh*Ww
|
308 |
+
attn = attn + modify_relative_position_bias(
|
309 |
+
relative_position_bias, x.size(0), pos_mask
|
310 |
+
)
|
311 |
+
|
312 |
+
if rel_pos_bias is not None:
|
313 |
+
attn = attn + rel_pos_bias
|
314 |
+
|
315 |
+
attn = attn.softmax(dim=-1)
|
316 |
+
attn = self.attn_drop(attn)
|
317 |
+
|
318 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
319 |
+
x = self.proj(x)
|
320 |
+
x = self.proj_drop(x)
|
321 |
+
return x
|
322 |
+
|
323 |
+
|
324 |
+
class RelativePositionBias(nn.Module):
|
325 |
+
def __init__(self, window_size, num_heads):
|
326 |
+
super().__init__()
|
327 |
+
self.window_size = window_size
|
328 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
329 |
+
2 * window_size[1] - 1
|
330 |
+
) + 3
|
331 |
+
self.relative_position_bias_table = nn.Parameter(
|
332 |
+
torch.zeros(self.num_relative_distance, num_heads)
|
333 |
+
)
|
334 |
+
|
335 |
+
# get pair-wise relative position index for each token inside the window
|
336 |
+
coords_h = torch.arange(window_size[0])
|
337 |
+
coords_w = torch.arange(window_size[1])
|
338 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
339 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
340 |
+
relative_coords = (
|
341 |
+
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
342 |
+
) # 2, Wh*Ww, Wh*Ww
|
343 |
+
relative_coords = relative_coords.permute(
|
344 |
+
1, 2, 0
|
345 |
+
).contiguous() # Wh*Ww, Wh*Ww, 2
|
346 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
347 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
348 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
349 |
+
relative_position_index = torch.zeros(
|
350 |
+
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
|
351 |
+
)
|
352 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
353 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
354 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
355 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
356 |
+
|
357 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
358 |
+
|
359 |
+
def forward(self):
|
360 |
+
relative_position_bias = self.relative_position_bias_table[
|
361 |
+
self.relative_position_index.view(-1)
|
362 |
+
].view(
|
363 |
+
self.window_size[0] * self.window_size[1] + 1,
|
364 |
+
self.window_size[0] * self.window_size[1] + 1,
|
365 |
+
-1,
|
366 |
+
) # Wh*Ww,Wh*Ww,nH
|
367 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
368 |
+
|
369 |
+
|
370 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
371 |
+
"""
|
372 |
+
grid_size: int of the grid height and width
|
373 |
+
return:
|
374 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
375 |
+
"""
|
376 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
377 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
378 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
379 |
+
grid = np.stack(grid, axis=0)
|
380 |
+
|
381 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
382 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
383 |
+
if cls_token:
|
384 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
385 |
+
return pos_embed
|
386 |
+
|
387 |
+
|
388 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
389 |
+
assert embed_dim % 2 == 0
|
390 |
+
|
391 |
+
# use half of dimensions to encode grid_h
|
392 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
393 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
394 |
+
|
395 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
396 |
+
return emb
|
397 |
+
|
398 |
+
|
399 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
400 |
+
"""
|
401 |
+
embed_dim: output dimension for each position
|
402 |
+
pos: a list of positions to be encoded: size (M,)
|
403 |
+
out: (M, D)
|
404 |
+
"""
|
405 |
+
assert embed_dim % 2 == 0
|
406 |
+
omega = np.arange(embed_dim // 2, dtype=np.float)
|
407 |
+
omega /= embed_dim / 2.0
|
408 |
+
omega = 1.0 / 10000 ** omega # (D/2,)
|
409 |
+
|
410 |
+
pos = pos.reshape(-1) # (M,)
|
411 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
412 |
+
|
413 |
+
emb_sin = np.sin(out) # (M, D/2)
|
414 |
+
emb_cos = np.cos(out) # (M, D/2)
|
415 |
+
|
416 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
417 |
+
return emb
|
418 |
+
|
419 |
+
|
420 |
+
def interpolate_pos_embed(model, checkpoint_model):
|
421 |
+
if "pos_embed" in checkpoint_model:
|
422 |
+
pos_embed_checkpoint = checkpoint_model["pos_embed"]
|
423 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
424 |
+
num_patches = model.patch_embed.num_patches
|
425 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
426 |
+
# height (== width) for the checkpoint position embedding
|
427 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
428 |
+
# height (== width) for the new position embedding
|
429 |
+
new_size = int(num_patches ** 0.5)
|
430 |
+
# class_token and dist_token are kept unchanged
|
431 |
+
if orig_size != new_size:
|
432 |
+
print(
|
433 |
+
"Position interpolate from %dx%d to %dx%d"
|
434 |
+
% (orig_size, orig_size, new_size, new_size)
|
435 |
+
)
|
436 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
437 |
+
# only the position tokens are interpolated
|
438 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
439 |
+
pos_tokens = pos_tokens.reshape(
|
440 |
+
-1, orig_size, orig_size, embedding_size
|
441 |
+
).permute(0, 3, 1, 2)
|
442 |
+
pos_tokens = torch.nn.functional.interpolate(
|
443 |
+
pos_tokens,
|
444 |
+
size=(new_size, new_size),
|
445 |
+
mode="bicubic",
|
446 |
+
align_corners=False,
|
447 |
+
)
|
448 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
449 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
450 |
+
checkpoint_model["pos_embed"] = new_pos_embed
|
451 |
+
|
452 |
+
|
453 |
+
@register_model("mae", dataclass=MaeConfig)
|
454 |
+
class MaeModel(BaseFairseqModel):
|
455 |
+
def __init__(self, cfg: MaeConfig):
|
456 |
+
super().__init__()
|
457 |
+
self.cfg = cfg
|
458 |
+
|
459 |
+
self.mask_ratio = cfg.mask_ratio
|
460 |
+
|
461 |
+
# --------------------------------------------------------------------------
|
462 |
+
# MAE encoder specifics
|
463 |
+
self.patch_embed = PatchEmbed(
|
464 |
+
cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
|
465 |
+
)
|
466 |
+
num_patches = self.patch_embed.num_patches
|
467 |
+
|
468 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
|
469 |
+
self.pos_embed = nn.Parameter(
|
470 |
+
torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim), requires_grad=False
|
471 |
+
) # fixed sin-cos embedding
|
472 |
+
|
473 |
+
norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)
|
474 |
+
|
475 |
+
dpr = [
|
476 |
+
x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)
|
477 |
+
] # stochastic depth decay rule
|
478 |
+
|
479 |
+
def make_block(drop_path):
|
480 |
+
if cfg.w2v_block:
|
481 |
+
return TransformerSentenceEncoderLayer(
|
482 |
+
embedding_dim=cfg.embed_dim,
|
483 |
+
ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio,
|
484 |
+
num_attention_heads=cfg.num_heads,
|
485 |
+
dropout=cfg.block_dropout,
|
486 |
+
attention_dropout=cfg.attention_dropout,
|
487 |
+
activation_dropout=cfg.activation_dropout,
|
488 |
+
activation_fn="gelu",
|
489 |
+
layer_norm_first=cfg.layer_norm_first,
|
490 |
+
drop_path=drop_path,
|
491 |
+
norm_eps=1e-6,
|
492 |
+
single_qkv=cfg.single_qkv,
|
493 |
+
fused_ln=cfg.fused_ln,
|
494 |
+
)
|
495 |
+
elif cfg.alt_block:
|
496 |
+
window_size = (
|
497 |
+
cfg.input_size // self.patch_embed.patch_size[0],
|
498 |
+
cfg.input_size // self.patch_embed.patch_size[1],
|
499 |
+
)
|
500 |
+
return AltBlock(
|
501 |
+
cfg.embed_dim,
|
502 |
+
cfg.num_heads,
|
503 |
+
cfg.mlp_ratio,
|
504 |
+
qkv_bias=True,
|
505 |
+
qk_scale=None,
|
506 |
+
norm_layer=norm_layer,
|
507 |
+
drop_path=drop_path,
|
508 |
+
layer_norm_first=cfg.layer_norm_first,
|
509 |
+
ffn_targets=not cfg.end_of_block_targets,
|
510 |
+
use_rel_pos_bias=cfg.use_rel_pos_bias,
|
511 |
+
window_size=window_size
|
512 |
+
if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias)
|
513 |
+
else None,
|
514 |
+
alt_attention=cfg.alt_attention,
|
515 |
+
)
|
516 |
+
elif cfg.alt_block2:
|
517 |
+
from .multi.modules import AltBlock as AltBlock2
|
518 |
+
return AltBlock2(
|
519 |
+
cfg.embed_dim,
|
520 |
+
cfg.num_heads,
|
521 |
+
cfg.mlp_ratio,
|
522 |
+
qkv_bias=True,
|
523 |
+
qk_scale=None,
|
524 |
+
norm_layer=norm_layer,
|
525 |
+
drop_path=drop_path,
|
526 |
+
layer_norm_first=cfg.layer_norm_first,
|
527 |
+
ffn_targets=not cfg.end_of_block_targets,
|
528 |
+
)
|
529 |
+
else:
|
530 |
+
return Block(
|
531 |
+
cfg.embed_dim,
|
532 |
+
cfg.num_heads,
|
533 |
+
cfg.mlp_ratio,
|
534 |
+
qkv_bias=True,
|
535 |
+
qk_scale=None,
|
536 |
+
norm_layer=norm_layer,
|
537 |
+
drop_path=drop_path,
|
538 |
+
)
|
539 |
+
|
540 |
+
self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
|
541 |
+
self.norm = norm_layer(cfg.embed_dim)
|
542 |
+
# --------------------------------------------------------------------------
|
543 |
+
|
544 |
+
# --------------------------------------------------------------------------
|
545 |
+
# MAE decoder specifics
|
546 |
+
self.decoder_embed = (
|
547 |
+
nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True)
|
548 |
+
if not cfg.no_decoder_embed
|
549 |
+
else None
|
550 |
+
)
|
551 |
+
|
552 |
+
self.mask_token = (
|
553 |
+
nn.Parameter(
|
554 |
+
torch.zeros(
|
555 |
+
1,
|
556 |
+
1,
|
557 |
+
cfg.decoder_embed_dim
|
558 |
+
if not cfg.no_decoder_embed
|
559 |
+
else cfg.embed_dim,
|
560 |
+
)
|
561 |
+
)
|
562 |
+
if cfg.mask_noise_std <= 0
|
563 |
+
else None
|
564 |
+
)
|
565 |
+
|
566 |
+
self.decoder_pos_embed = (
|
567 |
+
nn.Parameter(
|
568 |
+
torch.zeros(
|
569 |
+
1,
|
570 |
+
num_patches + 1,
|
571 |
+
cfg.decoder_embed_dim
|
572 |
+
if not cfg.no_decoder_embed
|
573 |
+
else cfg.embed_dim,
|
574 |
+
),
|
575 |
+
requires_grad=False,
|
576 |
+
)
|
577 |
+
if not cfg.no_decoder_pos_embed
|
578 |
+
else None
|
579 |
+
)
|
580 |
+
|
581 |
+
self.decoder_blocks = nn.ModuleList(
|
582 |
+
[
|
583 |
+
Block(
|
584 |
+
cfg.decoder_embed_dim,
|
585 |
+
cfg.decoder_num_heads,
|
586 |
+
cfg.mlp_ratio,
|
587 |
+
qkv_bias=True,
|
588 |
+
qk_scale=None,
|
589 |
+
norm_layer=norm_layer,
|
590 |
+
)
|
591 |
+
for _ in range(cfg.decoder_depth)
|
592 |
+
]
|
593 |
+
)
|
594 |
+
|
595 |
+
self.decoder_norm = norm_layer(cfg.decoder_embed_dim)
|
596 |
+
self.decoder_pred = nn.Linear(
|
597 |
+
cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True
|
598 |
+
) # decoder to patch
|
599 |
+
# --------------------------------------------------------------------------
|
600 |
+
|
601 |
+
self.norm_pix_loss = cfg.norm_pix_loss
|
602 |
+
|
603 |
+
self.initialize_weights()
|
604 |
+
|
605 |
+
for pn, p in self.named_parameters():
|
606 |
+
if len(p.shape) == 1 or pn.endswith(".bias"):
|
607 |
+
p.param_group = "no_decay"
|
608 |
+
else:
|
609 |
+
p.param_group = "with_decay"
|
610 |
+
|
611 |
+
def initialize_weights(self):
|
612 |
+
# initialization
|
613 |
+
# initialize (and freeze) pos_embed by sin-cos embedding
|
614 |
+
pos_embed = get_2d_sincos_pos_embed(
|
615 |
+
self.pos_embed.shape[-1],
|
616 |
+
int(self.patch_embed.num_patches ** 0.5),
|
617 |
+
cls_token=not self.cfg.no_cls,
|
618 |
+
)
|
619 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
620 |
+
|
621 |
+
if self.decoder_pos_embed is not None:
|
622 |
+
decoder_pos_embed = get_2d_sincos_pos_embed(
|
623 |
+
self.decoder_pos_embed.shape[-1],
|
624 |
+
int(self.patch_embed.num_patches ** 0.5),
|
625 |
+
cls_token=not self.cfg.no_cls,
|
626 |
+
)
|
627 |
+
self.decoder_pos_embed.data.copy_(
|
628 |
+
torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
|
629 |
+
)
|
630 |
+
|
631 |
+
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
632 |
+
w = self.patch_embed.proj.weight.data
|
633 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
634 |
+
|
635 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
636 |
+
if self.cls_token is not None:
|
637 |
+
torch.nn.init.normal_(self.cls_token, std=0.02)
|
638 |
+
|
639 |
+
if self.mask_token is not None:
|
640 |
+
torch.nn.init.normal_(self.mask_token, std=0.02)
|
641 |
+
|
642 |
+
# initialize nn.Linear and nn.LayerNorm
|
643 |
+
self.apply(self._init_weights)
|
644 |
+
|
645 |
+
def _init_weights(self, m):
|
646 |
+
if isinstance(m, nn.Linear):
|
647 |
+
# we use xavier_uniform following official JAX ViT:
|
648 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
649 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
650 |
+
nn.init.constant_(m.bias, 0)
|
651 |
+
elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
|
652 |
+
nn.init.constant_(m.bias, 0)
|
653 |
+
nn.init.constant_(m.weight, 1.0)
|
654 |
+
|
655 |
+
def patchify(self, imgs):
|
656 |
+
"""
|
657 |
+
imgs: (N, 3, H, W)
|
658 |
+
x: (N, L, patch_size**2 *3)
|
659 |
+
"""
|
660 |
+
p = self.patch_embed.patch_size[0]
|
661 |
+
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
|
662 |
+
|
663 |
+
h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum("nchpwq->nhwpqc", x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore  # x_masked is actually unmasked x

    @classmethod
    def build_model(cls, cfg: MaeConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        # if self.cls_token is not None:
        #     x = x + self.pos_embed
        # else:
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if mask_ratio > 0:
            x, mask, ids_restore = self.random_masking(x, mask_ratio)
        else:
            mask = ids_restore = None

        # append cls token
        if self.cls_token is not None:
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
        )
        if self.cls_token is not None:
            x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        else:
            x_ = torch.cat([x, mask_tokens], dim=1)  # no cls token

        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle

        if self.cls_token is not None:
            x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        if self.cls_token is not None:
            # remove cls token
            x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum()
        return loss, mask.sum()

    def forward(self, imgs, predictions_only=False):
        latent, mask, ids_restore = self.forward_encoder(
            imgs, self.mask_ratio if not predictions_only else 0
        )

        if predictions_only:
            return latent

        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss, sample_size = self.forward_loss(imgs, pred, mask)

        result = {
            "losses": {"regression": loss},
            "sample_size": sample_size,
        }
        return result

    def remove_pretraining_modules(self):
        self.decoder_embed = None
        self.decoder_blocks = None
        self.decoder_norm = None
        self.decoder_pos_embed = None
        self.decoder_pred = None
        self.mask_token = None
        if self.cfg.layer_norm_first:
            self.norm = None
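The masking bookkeeping above can be exercised in isolation. The sketch below is illustrative only: it reuses the same argsort trick as random_masking but stands alone (the function name demo_random_masking and the N, L, D, mask_ratio values are assumptions, not part of the model code), and it simply checks that the kept tokens and the unshuffled binary mask come out with the expected shapes.

# Standalone sketch of the argsort-based masking used by random_masking above.
# Shapes and names are illustrative; nothing here depends on MaeModel.
import torch

def demo_random_masking(N=2, L=16, D=8, mask_ratio=0.75):
    x = torch.randn(N, L, D)
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L)                        # one random score per token
    ids_shuffle = torch.argsort(noise, dim=1)       # lowest-noise tokens are kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :len_keep]
    x_kept = torch.gather(x, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))

    mask = torch.ones(N, L)
    mask[:, :len_keep] = 0                          # 0 = keep, 1 = remove
    mask = torch.gather(mask, 1, ids_restore)       # back to original token order

    assert x_kept.shape == (N, len_keep, D)
    assert int(mask.sum()) == N * (L - len_keep)
    return x_kept, mask, ids_restore

demo_random_masking()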
fairseq/examples/data2vec/models/mae_image_classification.py
ADDED
@@ -0,0 +1,386 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit

import logging

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Optional

import numpy as np
from omegaconf import II, MISSING

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils, tasks
from omegaconf import open_dict

from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from .mae import interpolate_pos_embed


logger = logging.getLogger(__name__)


class PredictionMode(Enum):
    MEAN_POOLING = auto()
    CLS_TOKEN = auto()
    LIN_SOFTMAX = auto()


@dataclass
class MaeImageClassificationConfig(FairseqDataclass):
    model_path: str = MISSING
    no_pretrained_weights: bool = False
    linear_classifier: bool = False
    num_classes: int = 1000
    mixup: float = 0.8
    cutmix: float = 1.0
    label_smoothing: float = 0.1

    drop_path_rate: float = 0.1
    layer_decay: float = 0.65

    mixup_prob: float = 1.0
    mixup_switch_prob: float = 0.5
    mixup_mode: str = "batch"

    pretrained_model_args: Any = None
    data: str = II("task.data")

    norm_eps: Optional[float] = None

    remove_alibi: bool = False

    # regularization overwrites
    encoder_dropout: float = 0
    post_mlp_drop: float = 0
    attention_dropout: float = 0
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0

    prenet_layerdrop: float = 0
    prenet_dropout: float = 0

    use_fc_norm: bool = True
    prediction_mode: PredictionMode = PredictionMode.MEAN_POOLING

    no_decay_blocks: bool = True


def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter with its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if name in ["cls_token", "pos_embed"]:
        return 0
    elif name.startswith("patch_embed"):
        return 0
    elif name.startswith("rel_pos_bias"):
        return num_layers - 1
    elif name.startswith("blocks"):
        return int(name.split(".")[1]) + 1
    else:
        return num_layers


@register_model("mae_image_classification", dataclass=MaeImageClassificationConfig)
class MaeImageClassificationModel(BaseFairseqModel):
    def __init__(self, cfg: MaeImageClassificationConfig):
        super().__init__()
        self.cfg = cfg

        if cfg.pretrained_model_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
            pretrained_args = state.get("cfg", None)

            pretrained_args.criterion = None
            pretrained_args.lr_scheduler = None

            logger.info(pretrained_args.model)

            with open_dict(pretrained_args.model):
                pretrained_args.model.drop_path_rate = cfg.drop_path_rate
                if cfg.norm_eps is not None:
                    pretrained_args.model.norm_eps = cfg.norm_eps

            cfg.pretrained_model_args = pretrained_args

            logger.info(pretrained_args)
        else:
            state = None
            pretrained_args = cfg.pretrained_model_args

        if "data" in pretrained_args.task:
            pretrained_args.task.data = cfg.data
        elif "image" in pretrained_args.task:
            pretrained_args.task.image.data = cfg.data

        if "modalities" in pretrained_args.model:
            prenet_blocks = pretrained_args.model["modalities"]["image"]["prenet_depth"]
            model_blocks = pretrained_args.model["depth"]
            with open_dict(pretrained_args):
                dpr = np.linspace(0, cfg.drop_path_rate, model_blocks).tolist()
                pretrained_args.model["modalities"]["image"][
                    "start_drop_path_rate"
                ] = dpr[0]
                pretrained_args.model["modalities"]["image"][
                    "end_drop_path_rate"
                ] = max(0, dpr[prenet_blocks - 1])
                pretrained_args.model["start_drop_path_rate"] = dpr[prenet_blocks]
                pretrained_args.model["end_drop_path_rate"] = dpr[-1]

                if "mae_masking" in pretrained_args.model["modalities"]["image"]:
                    del pretrained_args.model["modalities"]["image"]["mae_masking"]

                if cfg.remove_alibi:
                    pretrained_args.model["modalities"]["image"][
                        "use_alibi_encoder"
                    ] = False
                    if (
                        state is not None
                        and "modality_encoders.IMAGE.alibi_bias" in state["model"]
                    ):
                        del state["model"]["modality_encoders.IMAGE.alibi_bias"]

                pretrained_args.model["encoder_dropout"] = cfg.encoder_dropout
                pretrained_args.model["post_mlp_drop"] = cfg.post_mlp_drop
                pretrained_args.model["attention_dropout"] = cfg.attention_dropout
                pretrained_args.model["activation_dropout"] = cfg.activation_dropout
                pretrained_args.model["dropout_input"] = cfg.dropout_input
                pretrained_args.model["layerdrop"] = cfg.layerdrop

                pretrained_args.model["modalities"]["image"][
                    "prenet_layerdrop"
                ] = cfg.prenet_layerdrop
                pretrained_args.model["modalities"]["image"][
                    "prenet_dropout"
                ] = cfg.prenet_dropout
        else:
            # not d2v multi
            with open_dict(pretrained_args):
                pretrained_args.model["drop_path_rate"] = cfg.drop_path_rate
                pretrained_args.model["block_dropout"] = cfg.encoder_dropout
                pretrained_args.model["attention_dropout"] = cfg.attention_dropout
                pretrained_args.model["activation_dropout"] = cfg.activation_dropout

        task = tasks.setup_task(pretrained_args.task)
        model = task.build_model(pretrained_args.model, from_checkpoint=True)

        self.d2v_multi = "data2vec_multi" in pretrained_args.model._name
        self.linear_classifier = cfg.linear_classifier

        self.model = model

        if state is not None and not cfg.no_pretrained_weights:
            interpolate_pos_embed(model, state)

            if "modality_encoders.IMAGE.positional_encoder.pos_embed" in state["model"]:
                state["model"][
                    "modality_encoders.IMAGE.positional_encoder.positions"
                ] = state["model"][
                    "modality_encoders.IMAGE.positional_encoder.pos_embed"
                ]
                del state["model"][
                    "modality_encoders.IMAGE.positional_encoder.pos_embed"
                ]
            if "modality_encoders.IMAGE.encoder_mask" in state["model"]:
                del state["model"]["modality_encoders.IMAGE.encoder_mask"]

            model.load_state_dict(state["model"], strict=True)

        if self.d2v_multi:
            model.remove_pretraining_modules(modality="image")
        else:
            model.remove_pretraining_modules()

        if self.linear_classifier:
            model.requires_grad_(False)

        self.fc_norm = None
        if self.cfg.use_fc_norm:
            self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim, eps=1e-6)
            nn.init.constant_(self.fc_norm.bias, 0)
            nn.init.constant_(self.fc_norm.weight, 1.0)

        self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)

        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.constant_(self.head.bias, 0)

        self.mixup_fn = None

        if cfg.mixup > 0 or cfg.cutmix > 0:
            from timm.data import Mixup

            self.mixup_fn = Mixup(
                mixup_alpha=cfg.mixup,
                cutmix_alpha=cfg.cutmix,
                cutmix_minmax=None,
                prob=cfg.mixup_prob,
                switch_prob=cfg.mixup_switch_prob,
                mode=cfg.mixup_mode,
                label_smoothing=cfg.label_smoothing,
                num_classes=cfg.num_classes,
            )

        if self.model.norm is not None:
            for pn, p in self.model.norm.named_parameters():
                if len(p.shape) == 1 or pn.endswith(".bias"):
                    p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}

        if self.fc_norm is not None:
            for pn, p in self.fc_norm.named_parameters():
                if len(p.shape) == 1 or pn.endswith(".bias"):
                    p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}

        for pn, p in self.head.named_parameters():
            if len(p.shape) == 1 or pn.endswith(".bias"):
                p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}}

        if self.d2v_multi:
            mod_encs = list(model.modality_encoders.values())
            assert len(mod_encs) == 1, len(mod_encs)
            blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks)
        else:
            blocks = model.blocks

        num_layers = len(blocks) + 1
        layer_scales = list(
            cfg.layer_decay ** (num_layers - i) for i in range(num_layers + 1)
        )

        if self.d2v_multi:
            for n, p in self.model.named_parameters():
                optimizer_override_dict = {}

                if len(p.shape) == 1 or n.endswith(".bias"):
                    optimizer_override_dict["weight_decay_scale"] = 0

                p.optim_overrides = {"optimizer": optimizer_override_dict}

            if cfg.layer_decay > 0:
                for i, b in enumerate(blocks):
                    lid = i + 1
                    if layer_scales[lid] == 1.0:
                        continue

                    for n, p in b.named_parameters():
                        optim_override = getattr(p, "optim_overrides", {})
                        if "optimizer" not in optim_override:
                            optim_override["optimizer"] = {}

                        if cfg.no_decay_blocks:
                            optim_override["optimizer"]["lr_scale"] = layer_scales[lid]
                            p.optim_overrides = optim_override
                        else:
                            optim_override["optimizer"] = {
                                "lr_scale": layer_scales[lid]
                            }
                            p.optim_overrides = optim_override

        else:
            for n, p in self.model.named_parameters():
                optimizer_override_dict = {}
                layer_id = get_layer_id_for_vit(n, num_layers)

                if len(p.shape) == 1 or n.endswith(".bias"):
                    optimizer_override_dict["weight_decay_scale"] = 0

                if cfg.layer_decay > 0:
                    optimizer_override_dict["lr_scale"] = layer_scales[layer_id]
                p.optim_overrides = {"optimizer": optimizer_override_dict}

    @classmethod
    def build_model(cls, cfg: MaeImageClassificationConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def forward(
        self,
        imgs,
        labels=None,
    ):
        if self.training and self.mixup_fn is not None and labels is not None:
            imgs, labels = self.mixup_fn(imgs, labels)

        if self.linear_classifier:
            with torch.no_grad():
                x = self.model_forward(imgs)
        else:
            x = self.model_forward(imgs)

        if self.cfg.prediction_mode == PredictionMode.MEAN_POOLING:
            x = x.mean(dim=1)
        elif self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
            x = x[:, 0]
        elif self.cfg.prediction_mode == PredictionMode.LIN_SOFTMAX:
            dtype = x.dtype
            x = F.logsigmoid(x.float())
            x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x + 1e-6, dim=1)
            x = x.clamp(max=0)
            x = x - torch.log(-(torch.expm1(x)))
            x = torch.nan_to_num(x, nan=0, posinf=0, neginf=0)
            x = x.to(dtype=dtype)
        else:
            raise Exception(f"unknown prediction mode {self.cfg.prediction_mode.name}")

        if self.fc_norm is not None:
            x = self.fc_norm(x)

        x = self.head(x)

        if labels is None:
            return x

        if self.training and self.mixup_fn is not None:
            loss = -labels * F.log_softmax(x.float(), dim=-1)
        else:
            loss = F.cross_entropy(
                x.float(),
                labels,
                label_smoothing=self.cfg.label_smoothing if self.training else 0,
                reduction="none",
            )

        result = {
            "losses": {"regression": loss},
            "sample_size": imgs.size(0),
        }

        if not self.training:
            with torch.no_grad():
                pred = x.argmax(-1)
                correct = (pred == labels).sum()
                result["correct"] = correct

        return result

    def model_forward(self, imgs):
        if self.d2v_multi:
            x = self.model.extract_features(
                imgs,
                mode="IMAGE",
                mask=False,
                remove_extra_tokens=(
                    self.cfg.prediction_mode != PredictionMode.CLS_TOKEN
                ),
            )["x"]
        else:
            x = self.model(imgs, predictions_only=True)
            if (
                "no_cls" not in self.model.cfg or not self.model.cfg.no_cls
            ) and not self.cfg.prediction_mode == PredictionMode.CLS_TOKEN:
                x = x[:, 1:]
        return x
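The layer-wise learning-rate decay set up in __init__ above assigns each transformer block an lr_scale that shrinks geometrically with distance from the head. The sketch below reproduces just that arithmetic on its own; layer_decay = 0.65 matches the config default, while the 12-block trunk is an illustrative assumption and the values are only printed rather than written into p.optim_overrides.

# Standalone sketch of the layer-decay schedule used above:
# block i (1-indexed) gets lr_scale = layer_decay ** (num_layers - i),
# so earlier blocks train with smaller effective learning rates.
layer_decay = 0.65
num_blocks = 12                      # illustrative trunk depth (e.g. a ViT-B)
num_layers = num_blocks + 1          # +1 for the embedding "layer 0"

layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

print(round(layer_scales[0], 5))         # patch_embed / pos_embed: smallest scale
print(round(layer_scales[1], 5))         # first block
print(layer_scales[num_layers])          # head level: 1.0 (no decay)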
fairseq/examples/data2vec/models/modalities/__init__.py
ADDED
File without changes
|
fairseq/examples/data2vec/models/modalities/audio.py
ADDED
@@ -0,0 +1,192 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
import torch
import torch.nn as nn
import numpy as np
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
from fairseq.models.wav2vec import ConvFeatureExtractionModel
from fairseq.modules import (
    LayerNorm,
    SamePad,
    TransposeLast,
)
from fairseq.tasks import FairseqTask
from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality


@dataclass
class D2vAudioConfig(D2vModalityConfig):
    type: Modality = Modality.AUDIO
    extractor_mode: str = "layer_norm"
    feature_encoder_spec: str = field(
        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        metadata={
            "help": "string describing convolutional feature extraction layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_pos_width: int = field(
        default=95,
        metadata={"help": "number of filters for convolutional positional embeddings"},
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )
    conv_pos_depth: int = field(
        default=5,
        metadata={"help": "depth of positional encoder network"},
    )
    conv_pos_pre_ln: bool = False


class AudioEncoder(ModalitySpecificEncoder):

    modality_cfg: D2vAudioConfig

    def __init__(
        self,
        modality_cfg: D2vAudioConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):

        self.feature_enc_layers = eval(modality_cfg.feature_encoder_spec)
        feature_embed_dim = self.feature_enc_layers[-1][0]

        local_encoder = ConvFeatureExtractionModel(
            conv_layers=self.feature_enc_layers,
            dropout=0.0,
            mode=modality_cfg.extractor_mode,
            conv_bias=False,
        )

        project_features = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(feature_embed_dim),
            nn.Linear(feature_embed_dim, embed_dim),
        )

        num_pos_layers = modality_cfg.conv_pos_depth
        k = max(3, modality_cfg.conv_pos_width // num_pos_layers)

        positional_encoder = nn.Sequential(
            TransposeLast(),
            *[
                nn.Sequential(
                    nn.Conv1d(
                        embed_dim,
                        embed_dim,
                        kernel_size=k,
                        padding=k // 2,
                        groups=modality_cfg.conv_pos_groups,
                    ),
                    SamePad(k),
                    TransposeLast(),
                    LayerNorm(embed_dim, elementwise_affine=False),
                    TransposeLast(),
                    nn.GELU(),
                )
                for _ in range(num_pos_layers)
            ],
            TransposeLast(),
        )

        if modality_cfg.conv_pos_pre_ln:
            positional_encoder = nn.Sequential(LayerNorm(embed_dim), positional_encoder)

        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
        )

        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=None,
            relative_positional_encoder=positional_encoder,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def convert_padding_mask(self, x, padding_mask):
        def get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
            """
            Computes the output length of the convolutional layers
            """

            def _conv_out_length(input_length, kernel_size, stride):
                return torch.floor((input_length - kernel_size) / stride + 1)

            for i in range(len(self.feature_enc_layers)):
                input_lengths = _conv_out_length(
                    input_lengths,
                    self.feature_enc_layers[i][1],
                    self.feature_enc_layers[i][2],
                )

            return input_lengths.to(torch.long)

        if padding_mask is not None:
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = get_feat_extract_output_lengths(input_lengths)

            if padding_mask.any():
                padding_mask = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)

                # these two operations makes sure that all values
                # before the output lengths indices are attended to
                padding_mask[
                    (
                        torch.arange(padding_mask.shape[0], device=padding_mask.device),
                        output_lengths - 1,
                    )
                ] = 1
                padding_mask = (
                    1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])
                ).bool()
            else:
                padding_mask = torch.zeros(
                    x.shape[:2], dtype=torch.bool, device=x.device
                )

        return padding_mask

    def reset_parameters(self):
        super().reset_parameters()
        for mod in self.project_features.children():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()
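convert_padding_mask above maps waveform lengths to feature-frame lengths by applying the standard conv output-length formula once per layer of the feature extractor. The sketch below works that arithmetic through by hand, assuming the default feature_encoder_spec from D2vAudioConfig; the function name conv_output_length and the 16000-sample input are illustrative choices, not part of the encoder.

# Standalone sketch of the frame-rate math behind convert_padding_mask:
# each (dim, kernel, stride) conv layer maps length T to floor((T - kernel) / stride) + 1.
feature_enc_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] + [(512, 2, 2)]

def conv_output_length(num_samples: int) -> int:
    length = num_samples
    for _, kernel, stride in feature_enc_layers:
        length = (length - kernel) // stride + 1
    return length

print(conv_output_length(16000))  # 49 frames for 1 s of 16 kHz audio (~320x downsampling)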
fairseq/examples/data2vec/models/modalities/base.py
ADDED
@@ -0,0 +1,684 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
from dataclasses import dataclass
from functools import partial
from omegaconf import MISSING, II
from typing import Optional, Callable
from fairseq.data.data_utils import compute_mask_indices
from fairseq.modules import GradMultiply
from fairseq.utils import index_put
from examples.data2vec.data.modality import Modality
from .modules import D2vDecoderConfig

logger = logging.getLogger(__name__)


@dataclass
class D2vModalityConfig:
    type: Modality = MISSING
    prenet_depth: int = 4
    prenet_layerdrop: float = 0
    prenet_dropout: float = 0
    start_drop_path_rate: float = 0
    end_drop_path_rate: float = 0

    num_extra_tokens: int = 0
    init_extra_token_zero: bool = True

    mask_noise_std: float = 0.01
    mask_prob_min: Optional[float] = None
    mask_prob: float = 0.7
    inverse_mask: bool = False
    mask_prob_adjust: float = 0
    keep_masked_pct: float = 0

    mask_length: int = 5
    add_masks: bool = False
    remove_masks: bool = False
    mask_dropout: float = 0.0
    encoder_zero_mask: bool = True

    mask_channel_prob: float = 0.0
    mask_channel_length: int = 64

    ema_local_encoder: bool = False  # used in data2vec_multi
    local_grad_mult: float = 1.0

    use_alibi_encoder: bool = False
    alibi_scale: float = 1.0
    learned_alibi: bool = False
    alibi_max_pos: Optional[int] = None
    learned_alibi_scale: bool = False
    learned_alibi_scale_per_head: bool = False
    learned_alibi_scale_per_layer: bool = False

    num_alibi_heads: int = II("model.num_heads")
    model_depth: int = II("model.depth")

    decoder: Optional[D2vDecoderConfig] = D2vDecoderConfig()


MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])


class ModalitySpecificEncoder(nn.Module):
    def __init__(
        self,
        modality_cfg: D2vModalityConfig,
        embed_dim: int,
        local_encoder: nn.Module,
        project_features: nn.Module,
        fixed_positional_encoder: Optional[nn.Module],
        relative_positional_encoder: Optional[nn.Module],
        context_encoder: nn.Module,
        decoder: nn.Module,
        get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]],
    ):
        super().__init__()

        self.modality_cfg = modality_cfg
        self.local_encoder = local_encoder
        self.project_features = project_features
        self.fixed_positional_encoder = fixed_positional_encoder
        self.relative_positional_encoder = relative_positional_encoder
        self.context_encoder = context_encoder

        self.decoder = decoder
        self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None

        self.local_grad_mult = self.modality_cfg.local_grad_mult

        self.extra_tokens = None
        if modality_cfg.num_extra_tokens > 0:
            self.extra_tokens = nn.Parameter(
                torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim)
            )
            if not modality_cfg.init_extra_token_zero:
                nn.init.normal_(self.extra_tokens)
            elif self.extra_tokens.size(1) > 1:
                nn.init.normal_(self.extra_tokens[:, 1:])

        self.alibi_scale = None
        if self.get_alibi_bias is not None:
            self.alibi_scale = nn.Parameter(
                torch.full(
                    (
                        (modality_cfg.prenet_depth + modality_cfg.model_depth)
                        if modality_cfg.learned_alibi_scale_per_layer
                        else 1,
                        1,
                        self.modality_cfg.num_alibi_heads
                        if modality_cfg.learned_alibi_scale_per_head
                        else 1,
                        1,
                        1,
                    ),
                    modality_cfg.alibi_scale,
                    dtype=torch.float,
                ),
                requires_grad=modality_cfg.learned_alibi_scale,
            )

        if modality_cfg.learned_alibi and self.get_alibi_bias is not None:
            assert modality_cfg.alibi_max_pos is not None
            alibi_bias = self.get_alibi_bias(
                batch_size=1,
                time_steps=modality_cfg.alibi_max_pos,
                heads=modality_cfg.num_alibi_heads,
                scale=1.0,
                dtype=torch.float,
                device="cpu",
            )
            self.alibi_bias = nn.Parameter(alibi_bias)
            self.get_alibi_bias = partial(
                _learned_alibi_bias, alibi_bias=self.alibi_bias
            )

    def upgrade_state_dict_named(self, state_dict, name):
        k = f"{name}.alibi_scale"
        if k in state_dict and state_dict[k].dim() == 4:
            state_dict[k] = state_dict[k].unsqueeze(0)

        return state_dict

    def convert_padding_mask(self, x, padding_mask):
        return padding_mask

    def decoder_input(self, x, mask_info: MaskInfo):
        inp_drop = self.modality_cfg.decoder.input_dropout
        if inp_drop > 0:
            x = F.dropout(x, inp_drop, training=self.training, inplace=True)

        num_extra = self.modality_cfg.num_extra_tokens

        if mask_info is not None:
            num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra

            mask_tokens = x.new_empty(
                x.size(0),
                num_masked,
                x.size(-1),
            ).normal_(0, self.modality_cfg.mask_noise_std)

            x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1)
            x = torch.gather(x_, dim=1, index=mask_info.ids_restore)

            if self.modality_cfg.decoder.add_positions_masked:
                assert self.fixed_positional_encoder is not None
                pos = self.fixed_positional_encoder(x, None)
                x = x + (pos * mask_info.mask.unsqueeze(-1))
        else:
            x = x[:, num_extra:]

        if self.modality_cfg.decoder.add_positions_all:
            assert self.fixed_positional_encoder is not None
            x = x + self.fixed_positional_encoder(x, None)

        return x, mask_info

    def local_features(self, features):
        if self.local_grad_mult > 0:
            if self.local_grad_mult == 1.0:
                x = self.local_encoder(features)
            else:
                x = GradMultiply.apply(
                    self.local_encoder(features), self.local_grad_mult
                )
        else:
            with torch.no_grad():
                x = self.local_encoder(features)

        x = self.project_features(x)
        return x

    def contextualized_features(
        self,
        x,
        padding_mask,
        mask,
        remove_masked,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):

        if padding_mask is not None:
            padding_mask = self.convert_padding_mask(x, padding_mask)

        local_features = x
        if mask and clone_batch == 1:
            local_features = local_features.clone()

        orig_B, orig_T, _ = x.shape
        pre_mask_B = orig_B
        mask_info = None

        x_pos = None
        if self.fixed_positional_encoder is not None:
            x = x + self.fixed_positional_encoder(x, padding_mask)

        if mask:
            if clone_batch > 1:
                x = x.repeat_interleave(clone_batch, 0)
                if mask_seeds is not None:
                    clone_hash = [
                        int(hash((mask_seeds.seed, ind)) % 1e10)
                        for ind in range(clone_batch - 1)
                    ]
                    clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1)

                    id = mask_seeds.ids
                    id = id.repeat_interleave(clone_batch, 0)
                    id = id.view(-1, clone_batch) + clone_hash.to(id)
                    id = id.view(-1)
                    mask_seeds = MaskSeed(
                        seed=mask_seeds.seed, update=mask_seeds.update, ids=id
                    )
                if padding_mask is not None:
                    padding_mask = padding_mask.repeat_interleave(clone_batch, 0)

            x, mask_info = self.compute_mask(
                x,
                padding_mask,
                mask_seed=mask_seeds,
                apply=self.relative_positional_encoder is not None or not remove_masked,
                precomputed_mask=precomputed_mask,
            )

        if self.relative_positional_encoder is not None:
            x_pos = self.relative_positional_encoder(x)

        masked_padding_mask = padding_mask
        if mask and remove_masked:
            x = mask_info.x_unmasked
            if x_pos is not None:
                x = x + gather_unmasked(x_pos, mask_info)

            if padding_mask is not None and padding_mask.any():
                masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info)
                if not masked_padding_mask.any():
                    masked_padding_mask = None
            else:
                masked_padding_mask = None

        elif x_pos is not None:
            x = x + x_pos

        alibi_bias = None
        alibi_scale = self.alibi_scale

        if self.get_alibi_bias is not None:
            alibi_bias = self.get_alibi_bias(
                batch_size=pre_mask_B,
                time_steps=orig_T,
                heads=self.modality_cfg.num_alibi_heads,
                dtype=torch.float32,
                device=x.device,
            )

            if alibi_scale is not None:
                alibi_scale = alibi_scale.clamp_min(0)
                if alibi_scale.size(0) == 1:
                    alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias)
                    alibi_scale = None

            if clone_batch > 1:
                alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0)

            if mask_info is not None and remove_masked:
                alibi_bias = masked_alibi(alibi_bias, mask_info)

        if self.extra_tokens is not None:
            num = self.extra_tokens.size(1)
            x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1)
            if masked_padding_mask is not None:
                # B x T
                masked_padding_mask = F.pad(masked_padding_mask, (num, 0))
            if alibi_bias is not None:
                # B x H x T x T
                alibi_bias = F.pad(alibi_bias, (num, 0, num, 0))

        x = self.context_encoder(
            x,
            masked_padding_mask,
            alibi_bias,
            alibi_scale[: self.modality_cfg.prenet_depth]
            if alibi_scale is not None
            else None,
        )

        return {
            "x": x,
            "local_features": local_features,
            "padding_mask": masked_padding_mask,
            "alibi_bias": alibi_bias,
            "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :]
            if alibi_scale is not None and alibi_scale.size(0) > 1
            else alibi_scale,
            "encoder_mask": mask_info,
        }

    def forward(
        self,
        features,
        padding_mask,
        mask: bool,
        remove_masked: bool,
        clone_batch: int = 1,
        mask_seeds: Optional[torch.Tensor] = None,
        precomputed_mask=None,
    ):
        x = self.local_features(features)
        return self.contextualized_features(
            x,
            padding_mask,
            mask,
            remove_masked,
            clone_batch,
            mask_seeds,
            precomputed_mask,
        )

    def reset_parameters(self):
        pass

    def compute_mask(
        self,
        x,
        padding_mask,
        mask_seed: Optional[MaskSeed],
        apply,
        precomputed_mask,
    ):
        if precomputed_mask is not None:
            mask = precomputed_mask
            mask_info = self.make_maskinfo(x, mask)
        else:
            B, T, C = x.shape
            cfg = self.modality_cfg

            mask_prob = cfg.mask_prob

            if (
                cfg.mask_prob_min is not None
                and cfg.mask_prob_min >= 0
                and cfg.mask_prob_min < mask_prob
            ):
                mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob)

            if mask_prob > 0:
                if cfg.mask_length == 1:
                    mask_info = random_masking(x, mask_prob, mask_seed)
                else:
                    if self.modality_cfg.inverse_mask:
                        mask_prob = 1 - mask_prob

                    mask = compute_mask_indices(
                        (B, T),
                        padding_mask,
                        mask_prob,
                        cfg.mask_length,
                        min_masks=1,
                        require_same_masks=True,
                        mask_dropout=cfg.mask_dropout,
                        add_masks=cfg.add_masks,
                        seed=mask_seed.seed if mask_seed is not None else None,
                        epoch=mask_seed.update if mask_seed is not None else None,
                        indices=mask_seed.ids if mask_seed is not None else None,
                    )

                    mask = torch.from_numpy(mask).to(device=x.device)
                    if self.modality_cfg.inverse_mask:
                        mask = 1 - mask
                    mask_info = self.make_maskinfo(x, mask)
            else:
                mask_info = None

        if apply:
            x = self.apply_mask(x, mask_info)

        return x, mask_info

    def make_maskinfo(self, x, mask, shape=None):
        if shape is None:
            B, T, D = x.shape
        else:
            B, T, D = shape

        mask = mask.to(torch.uint8)
        ids_shuffle = mask.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D)

        len_keep = T - mask[0].sum()
        if self.modality_cfg.keep_masked_pct > 0:
            len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct)

        ids_keep = ids_shuffle[:, :len_keep]

        if shape is not None:
            x_unmasked = None
        else:
            ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
            x_unmasked = torch.gather(x, dim=1, index=ids_keep)

        mask_info = MaskInfo(
            x_unmasked=x_unmasked,
            mask=mask,
            ids_restore=ids_restore,
            ids_keep=ids_keep,
        )
        return mask_info

    def apply_mask(self, x, mask_info):
        cfg = self.modality_cfg
        B, T, C = x.shape

        if mask_info is not None:
            mask = mask_info.mask
            if cfg.encoder_zero_mask:
                x = x * (1 - mask.type_as(x).unsqueeze(-1))
            else:
                num_masks = mask.sum().item()
                masks = x.new_empty(num_masks, x.size(-1)).normal_(
                    0, cfg.mask_noise_std
                )
                x = index_put(x, mask, masks)
        if cfg.mask_channel_prob > 0:
            mask_channel = compute_mask_indices(
                (B, C),
                None,
                cfg.mask_channel_prob,
                cfg.mask_channel_length,
            )
            mask_channel = (
                torch.from_numpy(mask_channel)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x = index_put(x, mask_channel, 0)
        return x

    def remove_pretraining_modules(self, keep_decoder=False):
        if not keep_decoder:
            self.decoder = None


def get_annealed_rate(start, end, curr_step, total_steps):
    if curr_step >= total_steps:
        return end
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


# adapted from MAE
def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]):
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))

    generator = None
    if mask_seed is not None:
        seed = int(
            hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6
        )
        generator = torch.Generator(device=x.device)
        generator.manual_seed(seed)

    noise = torch.rand(N, L, generator=generator, device=x.device)  # noise in [0, 1]

    # sort noise for each sample
    ids_shuffle = noise.argsort(dim=1)  # ascend: small is keep, large is remove
    ids_restore = ids_shuffle.argsort(dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D)
    x_unmasked = torch.gather(x, dim=1, index=ids_keep)

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], dtype=x.dtype, device=x.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)

    ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D)

    return MaskInfo(
        x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep
    )


def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
    return torch.gather(
        x,
        dim=1,
        index=mask_info.ids_keep,
    )


def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor:
    return torch.gather(
        x,
        dim=1,
        index=mask_info.ids_keep[..., 0],  # ignore the feature dimension
    )


def get_alibi(
    max_positions: int,
    attention_heads: int,
    dims: int = 1,
    distance: str = "manhattan",
):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        # In the paper, we only train models that have 2^a heads for some
        # a. This function has some good properties that only occur when
        # the input is a power of 2. To maintain that even when the number
        # of heads is not a power of 2, we use this workaround.
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    maxpos = max_positions
    attn_heads = attention_heads
    slopes = torch.Tensor(get_slopes(attn_heads))

    if dims == 1:
        # prepare alibi position linear bias. Note that wav2vec2 is non
        # autoregressive model so we want a symmetric mask with 0 on the
        # diagonal and other wise linear decreasing valuees
        pos_bias = (
            torch.abs(
                torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
            )
            * -1
        )
    elif dims == 2:
        if distance == "manhattan":
            df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2)
        elif distance == "euclidean":
            df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)

        n = math.sqrt(max_positions)
        assert n.is_integer(), n
        n = int(n)

        pos_bias = torch.zeros((max_positions, max_positions))

        for i in range(n):
            for j in range(n):
                for k in range(n):
                    for l in range(n):
                        new_x = i * n + j
                        new_y = k * n + l
                        pos_bias[new_x, new_y] = -df(i, j, k, l)

    else:
        raise Exception(f"unsupported number of alibi dims: {dims}")

    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
        attn_heads, -1, -1
    )

    return alibi_bias


def get_alibi_bias(
    alibi_biases,
    batch_size,
    time_steps,
    heads,
    dtype,
    device,
    dims=1,
    distance="manhattan",
):
    cache_key = f"{dims}_{heads}_{distance}"

    buffered = alibi_biases.get(cache_key, None)

    target_size = heads * batch_size
    if (
        buffered is None
        or buffered.size(0) < target_size
        or buffered.size(1) < time_steps
        or buffered.dtype != dtype
        or buffered.device != device
    ):
        bt = max(time_steps, buffered.size(1) if buffered is not None else 0)
        bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads

        buffered = (
            get_alibi(bt, heads, dims=dims, distance=distance)
            .to(dtype=dtype, device=device)
            .repeat(bn, 1, 1)
        )

        alibi_biases[cache_key] = buffered

    b = buffered[:target_size, :time_steps, :time_steps]
    b = b.view(batch_size, heads, time_steps, time_steps)
    return b


def _learned_alibi_bias(
    alibi_bias,
    batch_size,
    time_steps,
    heads,
    scale,
    dtype,
    device,
):
    assert alibi_bias.size(1) == heads, alibi_bias.shape
    assert alibi_bias.dtype == dtype, alibi_bias.dtype
    assert alibi_bias.device == device, alibi_bias.device

    if alibi_bias.size(-1) < time_steps:
        psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2)
        alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate")

    alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale
    return alibi_bias[..., :time_steps, :time_steps]


def masked_alibi(alibi_bias, mask_info):
    H = alibi_bias.size(1)

    orig_bias = alibi_bias

    index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1)
    alibi_bias = torch.gather(
        orig_bias,
        dim=-2,
        index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)),
    )
    alibi_bias = torch.gather(
        alibi_bias,
        dim=-1,
        index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1),
    )

    return alibi_bias
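The per-head slopes produced by get_slopes inside get_alibi follow a geometric series; for 2^a heads the first slope is 2^(-8/n) raised to successive powers, and non-power-of-2 head counts interleave slopes from the next power of 2. The sketch below reimplements just that slope schedule in isolation to make the numbers concrete; the function names and the 8-head example are illustrative assumptions, not part of the module.

# Standalone sketch of the ALiBi slope schedule used by get_alibi above.
import math

def alibi_slopes(n_heads: int):
    def power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start ** i) for i in range(n)]

    if math.log2(n_heads).is_integer():
        return power_of_2(n_heads)
    closest = 2 ** math.floor(math.log2(n_heads))
    # interleave slopes of the next power of 2 for the remaining heads
    return power_of_2(closest) + alibi_slopes(2 * closest)[0::2][: n_heads - closest]

print(alibi_slopes(8))  # [0.5, 0.25, 0.125, ..., 0.00390625]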
fairseq/examples/data2vec/models/modalities/images.py
ADDED
@@ -0,0 +1,256 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from functools import partial
from dataclasses import dataclass
from typing import Callable, Dict, Optional
from timm.models.layers import to_2tuple
from fairseq.tasks import FairseqTask
from examples.data2vec.models.mae import get_2d_sincos_pos_embed, PatchEmbed
from .base import (
    D2vModalityConfig,
    ModalitySpecificEncoder,
    get_alibi_bias,
    MaskSeed,
)
from .modules import (
    BlockEncoder,
    Decoder2d,
    FixedPositionalEncoder,
    TransformerDecoder,
    EncDecTransformerDecoder,
)
from examples.data2vec.data.modality import Modality


@dataclass
class D2vImageConfig(D2vModalityConfig):
    type: Modality = Modality.IMAGE

    input_size: int = 224
    in_chans: int = 3
    patch_size: int = 16
    embed_dim: int = 768

    alibi_dims: int = 2
    alibi_distance: str = "manhattan"

    fixed_positions: bool = True

    transformer_decoder: bool = False
    enc_dec_transformer: bool = False


class ImageEncoder(ModalitySpecificEncoder):

    modality_cfg: D2vImageConfig

    def __init__(
        self,
        modality_cfg: D2vImageConfig,
        embed_dim: int,
        make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):

        img_size = to_2tuple(modality_cfg.input_size)
        patch_size = to_2tuple(modality_cfg.patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])

        local_encoder = PatchEmbed(
            modality_cfg.input_size,
            modality_cfg.patch_size,
            modality_cfg.in_chans,
            modality_cfg.embed_dim,
        )

        w = local_encoder.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        if modality_cfg.embed_dim != embed_dim:
            local_encoder = nn.Sequential(
                local_encoder,
                nn.Linear(modality_cfg.embed_dim, embed_dim),
            )

        project_features = nn.Identity()

        pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, embed_dim), requires_grad=False
        )

        side_n = int(num_patches ** 0.5)

        emb = get_2d_sincos_pos_embed(
            pos_embed.shape[-1],
            side_n,
            cls_token=False,
        )
        pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
        fixed_positional_encoder = (
            FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None
        )

        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )

        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim) if not layer_norm_first else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout,
        )

        if modality_cfg.transformer_decoder:
            if modality_cfg.enc_dec_transformer:
                decoder = EncDecTransformerDecoder(modality_cfg.decoder, embed_dim)
            else:
                dec_enc = BlockEncoder(
                    nn.ModuleList(
                        make_block(0, modality_cfg.decoder.decoder_dim, 8)
                        for _ in range(modality_cfg.decoder.decoder_layers)
                    ),
                    None,
                    layer_norm_first,
                    0,
                    0,
                )
                decoder = TransformerDecoder(modality_cfg.decoder, embed_dim, dec_enc)
        else:
            decoder = (
                Decoder2d(modality_cfg.decoder, embed_dim, side_n, side_n)
                if modality_cfg.decoder is not None
                else None
            )

        alibi_bias_fn = partial(
            get_alibi_bias,
            alibi_biases=alibi_biases,
            heads=modality_cfg.num_alibi_heads,
            dims=modality_cfg.alibi_dims,
            distance=modality_cfg.alibi_distance,
        )

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=project_features,
            fixed_positional_encoder=fixed_positional_encoder,
            relative_positional_encoder=None,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def reset_parameters(self):
        super().reset_parameters()
        if self.decoder is not None:
            self.decoder.reset_parameters()

    @torch.no_grad()
    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.modality_cfg.patch_size
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum("nchpwq->nhwpqc", x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))

        return x

    @torch.no_grad()
    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.modality_cfg.patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def compute_mask(
        self,
        x,
        padding_mask,
        mask_seed: Optional[MaskSeed],
        apply,
        shape=None,
        precomputed_mask=None,
    ):
        mlen = self.modality_cfg.mask_length
        if mlen <= 1:
            return super().compute_mask(
                x, padding_mask, mask_seed, apply, precomputed_mask
            )

        if precomputed_mask is not None:
            mask = precomputed_mask
        else:
            from fairseq.data.data_utils import compute_block_mask_2d

            if shape is not None:
                B, L, D = shape
            else:
                B, L, D = x.shape

            mask = compute_block_mask_2d(
                shape=(B, L),
                mask_prob=self.modality_cfg.mask_prob,
                mask_length=self.modality_cfg.mask_length,
                mask_prob_adjust=self.modality_cfg.mask_prob_adjust,
                inverse_mask=self.modality_cfg.inverse_mask,
                require_same_masks=True,
                mask_dropout=self.modality_cfg.mask_dropout,
            )

        mask_info = self.make_maskinfo(x, mask, shape)
        if apply:
            x = self.apply_mask(x, mask_info)

        return x, mask_info
234 |
+
def decoder_input(self, x, mask_info):
|
235 |
+
if (
|
236 |
+
not self.modality_cfg.transformer_decoder
|
237 |
+
or not self.modality_cfg.enc_dec_transformer
|
238 |
+
):
|
239 |
+
return super().decoder_input(x, mask_info)
|
240 |
+
|
241 |
+
inp_drop = self.modality_cfg.decoder.input_dropout
|
242 |
+
if inp_drop > 0:
|
243 |
+
x = F.dropout(x, inp_drop, training=self.training, inplace=True)
|
244 |
+
|
245 |
+
kv = x[:, self.modality_cfg.num_extra_tokens :]
|
246 |
+
|
247 |
+
assert self.fixed_positional_encoder is not None
|
248 |
+
pos = self.fixed_positional_encoder(x, None).expand(x.size(0), -1, -1)
|
249 |
+
|
250 |
+
mask = mask_info.mask.bool()
|
251 |
+
if self.modality_cfg.decoder.add_positions_all:
|
252 |
+
kv = kv + pos[~mask].view(kv.shape)
|
253 |
+
|
254 |
+
q = pos[mask].view(x.size(0), -1, x.size(-1))
|
255 |
+
|
256 |
+
return q, kv
|
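Note: patchify/unpatchify above are the usual MAE-style conversions between an image and its per-patch vectors. A minimal standalone sketch of the same einsum round trip (plain PyTorch, illustrative 224x224 inputs with 16-pixel patches; not part of the uploaded files):

import torch

# hypothetical sizes: 2 RGB images of 224x224 split into 16x16 patches
imgs = torch.randn(2, 3, 224, 224)
p = 16
h = w = imgs.shape[2] // p  # 14 patches per side

# patchify: (N, 3, H, W) -> (N, L, p*p*3)
x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
x = torch.einsum("nchpwq->nhwpqc", x)
patches = x.reshape(imgs.shape[0], h * w, p * p * 3)

# unpatchify: (N, L, p*p*3) -> (N, 3, H, W)
y = patches.reshape(patches.shape[0], h, w, p, p, 3)
y = torch.einsum("nhwpqc->nchpwq", y)
recon = y.reshape(patches.shape[0], 3, h * p, w * p)

assert torch.allclose(imgs, recon)  # the two reshapes are exact inverses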
fairseq/examples/data2vec/models/modalities/modules.py
ADDED
@@ -0,0 +1,589 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dataclasses import dataclass
from fairseq.modules import (
    LayerNorm,
    SamePad,
    SamePad2d,
    TransposeLast,
)


@dataclass
class D2vDecoderConfig:
    decoder_dim: int = 384
    decoder_groups: int = 16
    decoder_kernel: int = 5
    decoder_layers: int = 5
    input_dropout: float = 0.1

    add_positions_masked: bool = False
    add_positions_all: bool = False

    decoder_residual: bool = True
    projection_layers: int = 1
    projection_ratio: float = 2.0


class FixedPositionalEncoder(nn.Module):
    def __init__(self, pos_embed):
        super().__init__()
        self.positions = pos_embed

    def forward(self, x, padding_mask):
        return self.positions


class TextFeatPositionalEncoder(nn.Module):
    """
    Original encoder expects (B, T) long input. This module wraps it to take
    local_encoder output which are (B, T, D) float tensors
    """

    def __init__(self, pos_encoder):
        super().__init__()
        self.pos_encoder = pos_encoder

    def forward(self, x, padding_mask):
        # assume padded token embeddings are 0s
        # TODO: consider using padding_mask as input
        return self.pos_encoder(x[..., 0])


class BlockEncoder(nn.Module):
    def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
        super().__init__()
        self.blocks = blocks
        self.norm = norm_layer
        self.layer_norm_first = layer_norm_first
        self.layerdrop = layerdrop
        self.dropout = nn.Dropout(dropout, inplace=True)

    def forward(self, x, padding_mask, alibi_bias, alibi_scale):
        if self.norm is not None and not self.layer_norm_first:
            x = self.norm(x)

        x = self.dropout(x)

        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.layerdrop == 0
                or (np.random.random() > self.layerdrop)
            ):
                ab = alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)
                x, _ = blk(x, padding_mask, ab)

        if self.norm is not None and self.layer_norm_first:
            x = self.norm(x)

        return x


class DecoderBase(nn.Module):
    decoder_cfg: D2vDecoderConfig

    def __init__(self, cfg: D2vDecoderConfig):
        super().__init__()

        self.decoder_cfg = cfg

    def reset_parameters(self):
        for mod in self.proj.modules():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()

    def add_residual(self, x, residual, i, mask_info):
        if (
            residual is None
            or not self.decoder_cfg.decoder_residual
            or residual.size(1) != x.size(1)
        ):
            return x

        ret = x + residual

        return ret


class Decoder1d(DecoderBase):
    def __init__(self, cfg: D2vDecoderConfig, input_dim):
        super().__init__(cfg)

        def make_block(in_dim):
            block = [
                nn.Conv1d(
                    in_dim,
                    cfg.decoder_dim,
                    kernel_size=cfg.decoder_kernel,
                    padding=cfg.decoder_kernel // 2,
                    groups=cfg.decoder_groups,
                ),
                SamePad(cfg.decoder_kernel),
                TransposeLast(),
                LayerNorm(cfg.decoder_dim, elementwise_affine=False),
                TransposeLast(),
                nn.GELU(),
            ]

            return nn.Sequential(*block)

        self.blocks = nn.Sequential(
            *[
                make_block(input_dim if i == 0 else cfg.decoder_dim)
                for i in range(cfg.decoder_layers)
            ]
        )

        projs = []
        curr_dim = cfg.decoder_dim
        for i in range(cfg.projection_layers - 1):
            next_dim = int(curr_dim * cfg.projection_ratio) if i == 0 else curr_dim
            projs.append(nn.Linear(curr_dim, next_dim))
            projs.append(nn.GELU())
            curr_dim = next_dim
        projs.append(nn.Linear(curr_dim, input_dim))
        if len(projs) == 1:
            self.proj = projs[0]
        else:
            self.proj = nn.Sequential(*projs)

    def forward(self, x, mask_info):

        x = x.transpose(1, 2)

        residual = x

        for i, layer in enumerate(self.blocks):
            x = layer(x)
            x = self.add_residual(x, residual, i, mask_info)
            residual = x

        x = x.transpose(1, 2)
        x = self.proj(x)
        return x


class Decoder2d(DecoderBase):
    def __init__(self, cfg: D2vDecoderConfig, input_dim, h_size, w_size):
        super().__init__(cfg)

        self.h_size = h_size
        self.w_size = w_size

        def make_block(in_dim):
            block = [
                nn.Conv2d(
                    in_dim,
                    cfg.decoder_dim,
                    kernel_size=cfg.decoder_kernel,
                    padding=cfg.decoder_kernel // 2,
                    groups=cfg.decoder_groups,
                ),
                SamePad2d(cfg.decoder_kernel),
                TransposeLast(tranpose_dim=-3),
                LayerNorm(cfg.decoder_dim, elementwise_affine=False),
                TransposeLast(tranpose_dim=-3),
                nn.GELU(),
            ]

            return nn.Sequential(*block)

        self.blocks = nn.Sequential(
            *[
                make_block(input_dim if i == 0 else cfg.decoder_dim)
                for i in range(cfg.decoder_layers)
            ]
        )

        self.proj = nn.Linear(cfg.decoder_dim, input_dim)

    def forward(self, x, mask_info):
        B, T, C = x.shape

        x = x.transpose(1, 2).reshape(B, C, self.h_size, self.w_size)

        residual = x

        for i, layer in enumerate(self.blocks):
            x = layer(x)
            x = self.add_residual(x, residual, i, mask_info)
            residual = x

        x = x.reshape(B, -1, T).transpose(1, 2)
        x = self.proj(x)
        return x


class TransformerDecoder(nn.Module):
    decoder_cfg: D2vDecoderConfig

    def __init__(self, cfg: D2vDecoderConfig, input_dim, encoder):
        super().__init__()

        self.decoder_cfg = cfg

        self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)

        self.encoder = encoder

        self.proj = nn.Linear(cfg.decoder_dim, input_dim)

    def reset_parameters(self):
        from fairseq.modules.transformer_sentence_encoder import init_bert_params

        self.apply(init_bert_params)

    def forward(self, x, mask_info):
        x = self.input_proj(x)
        x = self.encoder(x, None, None, 1)
        x = self.proj(x)
        return x


class AltBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        cosine_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        from timm.models.vision_transformer import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = AltAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)

    def forward(self, x, padding_mask=None, alibi_bias=None):
        if self.layer_norm_first:
            x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
            r = x = self.mlp(self.norm2(x))
            t = x
            x = r + self.drop_path(self.post_mlp_dropout(x))
            if not self.ffn_targets:
                t = x
        else:
            x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
            if not self.ffn_targets:
                t = x

        return x, t


class AltAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cosine_attention=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.cosine_attention = cosine_attention

        if cosine_attention:
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
            )

    def forward(self, x, padding_mask=None, alibi_bias=None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)  # qkv x B x H x L x D
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        dtype = q.dtype

        if self.cosine_attention:
            # cosine attention
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
            logit_scale = torch.clamp(
                self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
            ).exp()
            attn = attn * logit_scale
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

        if alibi_bias is not None:
            attn = attn.type_as(alibi_bias)
            attn[:, : alibi_bias.size(1)] += alibi_bias

        if padding_mask is not None and padding_mask.any():
            attn = attn.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2)  #
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class EncDecAttention(nn.Module):
    def __init__(
        self,
        q_dim,
        kv_dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cosine_attention=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = q_dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q_proj = nn.Linear(q_dim, q_dim, bias=qkv_bias)
        self.kv_proj = nn.Linear(kv_dim, 2 * q_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(q_dim, q_dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.cosine_attention = cosine_attention

        if cosine_attention:
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
            )

    def forward(self, q, kv, padding_mask=None, alibi_bias=None):
        B, N, C = q.shape

        q = (
            self.q_proj(q)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )  # B x H x L x D
        kv = (
            self.kv_proj(kv)
            .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )  # kv x B x H x L x D
        k, v = (
            kv[0],
            kv[1],
        )  # make torchscript happy (cannot use tensor as tuple)

        dtype = q.dtype

        if self.cosine_attention:
            # cosine attention
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
            logit_scale = torch.clamp(
                self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
            ).exp()
            attn = attn * logit_scale
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

        if alibi_bias is not None:
            attn = attn.type_as(alibi_bias)
            attn[:, : alibi_bias.size(1)] += alibi_bias

        if padding_mask is not None and padding_mask.any():
            attn = attn.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2)  #
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class EncDecBlock(nn.Module):
    def __init__(
        self,
        q_dim,
        kv_dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        cosine_attention=False,
        first_residual=True,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first

        from timm.models.vision_transformer import DropPath, Mlp

        self.norm1 = norm_layer(q_dim)
        self.attn = EncDecAttention(
            q_dim,
            kv_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(q_dim)
        mlp_hidden_dim = int(q_dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=q_dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
        self.first_residual = first_residual

    def forward(self, q, kv, padding_mask=None, alibi_bias=None):
        r = q if self.first_residual else 0
        if self.layer_norm_first:
            x = r + self.drop_path(
                self.attn(self.norm1(q), kv, padding_mask, alibi_bias)
            )
            r = x = self.mlp(self.norm2(x))
            x = r + self.drop_path(self.post_mlp_dropout(x))
        else:
            x = r + self.drop_path(self.attn(q, kv, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))

        return x


class EncDecTransformerDecoder(nn.Module):
    def __init__(self, cfg: D2vDecoderConfig, input_dim):
        super().__init__()

        self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)

        self.blocks = nn.Sequential(
            *[
                EncDecBlock(
                    q_dim=cfg.decoder_dim,
                    kv_dim=input_dim,
                    num_heads=8,
                    mlp_ratio=4.0,
                    qkv_bias=True,
                    qk_scale=None,
                    drop=0.0,
                    attn_drop=0.0,
                    mlp_drop=0.0,
                    post_mlp_drop=0.0,
                    drop_path=0.0,
                    act_layer=nn.GELU,
                    norm_layer=nn.LayerNorm,
                    layer_norm_first=False,
                    cosine_attention=False,
                    first_residual=i > 0,
                )
                for i in range(cfg.decoder_layers)
            ]
        )

        self.proj = nn.Linear(cfg.decoder_dim, input_dim)

    def reset_parameters(self):
        from fairseq.modules.transformer_sentence_encoder import init_bert_params

        self.apply(init_bert_params)

    def forward(self, x, kv):
        x = self.input_proj(x)
        for i, layer in enumerate(self.blocks):
            x = layer(x, kv)

        x = self.proj(x)
        return x
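Note: Decoder1d above takes (B, T, C) features through grouped 1-d convolutions and projects back to the input dimension. A rough usage sketch, assuming fairseq is installed and this module is importable from the repository root; the shapes are illustrative and not part of the uploaded files:

import torch
from examples.data2vec.models.modalities.modules import D2vDecoderConfig, Decoder1d

cfg = D2vDecoderConfig()          # defaults: decoder_dim=384, 5 grouped conv layers
decoder = Decoder1d(cfg, input_dim=768)

x = torch.randn(2, 100, 768)      # (batch, time, channels) features
out = decoder(x, mask_info=None)  # mask_info is not used by Decoder1d.forward
print(out.shape)                  # torch.Size([2, 100, 768])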
fairseq/examples/data2vec/models/modalities/text.py
ADDED
@@ -0,0 +1,161 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Optional

import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from fairseq.modules import PositionalEmbedding, FairseqDropout, LayerNorm
from fairseq.tasks import FairseqTask
from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality


@dataclass
class D2vTextConfig(D2vModalityConfig):
    type: Modality = Modality.TEXT
    max_source_positions: int = 512
    learned_pos: bool = True
    dropout: float = 0.1  # used for both local_encoder and contextualized encoder. tied with global transformer in data2vec_text

    no_scale_embedding: bool = True
    layernorm_embedding: bool = True
    no_token_positional_embeddings: bool = False


class TextEncoder(ModalitySpecificEncoder):

    modality_cfg: D2vTextConfig

    def __init__(
        self,
        modality_cfg: D2vTextConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):
        self.pad_idx = task.source_dictionary.pad()
        self.vocab_size = len(task.source_dictionary)

        local_encoder = TextLocalEncoder(
            vocab_size=self.vocab_size,
            embed_dim=embed_dim,
            max_source_positions=modality_cfg.max_source_positions,
            pad_idx=self.pad_idx,
            no_scale_embedding=modality_cfg.no_scale_embedding,
            layernorm_embedding=modality_cfg.layernorm_embedding,
            dropout=modality_cfg.dropout,
            no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
            learned_pos=modality_cfg.learned_pos,
        )
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim)
            if not layer_norm_first and modality_cfg.prenet_depth > 0
            else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
        )
        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
        )

        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=nn.Identity(),
            fixed_positional_encoder=None,
            relative_positional_encoder=None,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def reset_parameters(self):
        super().reset_parameters()

    def convert_padding_mask(self, x, padding_mask):
        if padding_mask is None or padding_mask.size(1) == x.size(1):
            return padding_mask

        diff = self.downsample - padding_mask.size(1) % self.downsample
        if 0 < diff < self.downsample:
            padding_mask = F.pad(padding_mask, (0, diff), value=True)

        padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
        padding_mask = padding_mask.all(-1)
        if padding_mask.size(1) > x.size(1):
            padding_mask = padding_mask[:, : x.size(1)]

        assert x.size(1) == padding_mask.size(
            1
        ), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"

        return padding_mask


class TextLocalEncoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        max_source_positions,
        pad_idx,
        no_scale_embedding,
        layernorm_embedding,
        dropout,
        no_token_positional_embeddings,
        learned_pos,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.dropout_module = FairseqDropout(dropout)

        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
        self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                max_source_positions,
                embed_dim,
                pad_idx,
                learned=learned_pos,
            )
            if not no_token_positional_embeddings
            else None
        )
        self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)

        self.layernorm_embedding = None
        if layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim)

    def forward(self, src_tokens):
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x = x + self.embed_positions(src_tokens)

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x
fairseq/examples/data2vec/models/utils.py
ADDED
@@ -0,0 +1,55 @@
import math
import torch

def get_alibi(
    max_positions: int,
    attention_heads: int,
):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        # In the paper, we only train models that have 2^a heads for some
        # a. This function has some good properties that only occur when
        # the input is a power of 2. To maintain that even when the number
        # of heads is not a power of 2, we use this workaround.
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    maxpos = max_positions
    attn_heads = attention_heads
    slopes = torch.Tensor(get_slopes(attn_heads))
    # prepare alibi position linear bias. Note that wav2vec2 is non
    # autoregressive model so we want a symmetric mask with 0 on the
    # diagonal and other wise linear decreasing valuees
    pos_bias = (
        torch.abs(
            torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
        )
        * -1
    )
    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
        attn_heads, -1, -1
    )
    return alibi_bias

def masked_alibi(alibi_bias, mask_indices, orig_B, orig_T):
    alibi_bias = alibi_bias.view(orig_B, -1, orig_T, orig_T)
    H = alibi_bias.size(1)
    alibi_mask = mask_indices.unsqueeze(1)
    alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-1))
    alibi_bias = alibi_bias.view(orig_B, H, -1, orig_T)
    M = alibi_bias.size(-2)
    alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-2))
    alibi_bias = alibi_bias.view(-1, M, M)
    return alibi_bias
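Note: get_alibi builds one symmetric, distance-based bias matrix per attention head, each scaled by a per-head slope. A short illustrative call, assuming the module is importable with the repository root on PYTHONPATH (sizes are arbitrary and not part of the uploaded files):

import torch
from examples.data2vec.models.utils import get_alibi

bias = get_alibi(max_positions=6, attention_heads=4)
print(bias.shape)                                    # torch.Size([4, 6, 6])
print(torch.allclose(bias, bias.transpose(-2, -1)))  # True: symmetric, zeros on the diagonal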
fairseq/examples/data2vec/scripts/convert_audioset_labels.py
ADDED
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os


def get_parser():
    parser = argparse.ArgumentParser(description="convert audioset labels")
    # fmt: off
    parser.add_argument('in_file', help='audioset csv file to convert')
    parser.add_argument('--manifest', required=True, metavar='PATH', help='wav2vec-like manifest')
    parser.add_argument('--descriptors', required=True, metavar='PATH', help='path to label descriptor file')
    parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted labels')
    # fmt: on

    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    label_descriptors = {}
    with open(args.descriptors, "r") as ldf:
        next(ldf)
        for line in ldf:
            if line.strip() == "":
                continue

            items = line.split(",")
            assert len(items) > 2, line
            idx = items[0]
            lbl = items[1]
            assert lbl not in label_descriptors, lbl
            label_descriptors[lbl] = idx

    labels = {}
    with open(args.in_file, "r") as ifd:
        for line in ifd:
            if line.lstrip().startswith("#"):
                continue
            items = line.rstrip().split(",")
            id = items[0].strip()
            start = items[1].strip()
            end = items[2].strip()
            lbls = [label_descriptors[it.strip(' "')] for it in items[3:]]
            labels[id] = [start, end, ",".join(lbls)]

    with open(args.manifest, "r") as mf, open(args.output, "w") as of:
        next(mf)
        for line in mf:
            path, _ = line.split("\t")
            id = os.path.splitext(os.path.basename(path))[0]
            lbl = labels[id]
            print("\t".join(lbl), file=of)


if __name__ == "__main__":
    main()
fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr.sh
ADDED
@@ -0,0 +1,18 @@
#!/bin/bash

set -eu

job_id="$1"
task_id="$2"
dir="$3"

echo "job_id: $job_id, task_id: $task_id, dir: $dir"

mkdir -p "$dir/log"
sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"

sbatch $sbatch_args examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh $dir
fairseq/examples/data2vec/scripts/multi/finetune_all_fair_aws_local_lr_nodep.sh
ADDED
@@ -0,0 +1,16 @@
#!/bin/bash

set -eu

dir="$1"

echo "dir: $dir"

mkdir -p "$dir/log"
sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
sbatch_args="$sbatch_args -o $dir/log/decode_sweep_%A.out"
sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"

sbatch $sbatch_args examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh $dir
fairseq/examples/data2vec/scripts/multi/finetune_all_fair_local_lr.sh
ADDED
@@ -0,0 +1,28 @@
#!/usr/bin/env zsh

dir="$1"
cp="$dir/checkpoints/checkpoint_last.pt"

echo "dir: $dir"

declare -A tasks
tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"
tasks[mnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MNLI-bin"
tasks[qqp]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QQP-bin"
tasks[sts_b]="/fsx-wav2vec/abaevski/data/nlp/GLUE/STS-B-bin"

lrs=(5e-6 8e-6 1e-5 2e-5)

for task data_path in ${(kv)tasks}; do
  for lr in $lrs; do
    echo $lr $task
    PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
    python fairseq_cli/hydra_train.py -m --config-dir examples/data2vec/config/multi/text_finetuning \
    --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
    model.model_path="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" +model=text_wrap
  done
done
fairseq/examples/data2vec/scripts/text/finetune_all_char_fair_aws_local_lr.sh
ADDED
@@ -0,0 +1,17 @@
#!/bin/bash

set -eu

job_id="$1"
task_id="$2"
dir="$3"

echo "job_id: $job_id, task_id: $task_id, dir: $dir"

mkdir -p "$dir/log"
sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/ft_%A.out"
sbatch_args="$sbatch_args -e $dir/log/ft_%A.err"

sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_char_fair_local_lr.sh $dir
fairseq/examples/data2vec/scripts/text/finetune_all_fair.sh
ADDED
@@ -0,0 +1,21 @@
#!/usr/bin/env zsh

job_id=$1
task_id=$2
dir="$3"
cp="$dir/$task_id/checkpoints/checkpoint_last.pt"

echo "job_id: $job_id, task_id: $task_id, dir: $dir"

declare -A tasks
tasks[cola]="/private/home/jgu/data/GLUE/CoLA-bin"
tasks[qnli]="/private/home/jgu/data/GLUE/QNLI-bin"
tasks[mrpc]="/private/home/jgu/data/GLUE/MRPC-bin"
tasks[rte]="/private/home/jgu/data/GLUE/RTE-bin"
tasks[sst_2]="/private/home/jgu/data/GLUE/SST-2-bin"

for task data_path in ${(kv)tasks}; do
  PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
  --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
  checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune/$task" &
done
fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws.sh
ADDED
@@ -0,0 +1,21 @@
#!/usr/bin/env zsh

job_id=$1
task_id=$2
dir="$3"
cp="$dir/checkpoints/checkpoint_last.pt"

echo "job_id: $job_id, task_id: $task_id, dir: $dir"

declare -A tasks
tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"

for task data_path in ${(kv)tasks}; do
  PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
  --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
  checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune/$task" &
done
fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_local_lr.sh
ADDED
@@ -0,0 +1,17 @@
#!/bin/bash

set -eu

job_id="$1"
task_id="$2"
dir="$3"

echo "job_id: $job_id, task_id: $task_id, dir: $dir"

mkdir -p "$dir/log"
sbatch_args="-p wav2vec --nodes=1 --ntasks-per-node=1"
sbatch_args="$sbatch_args --gpus-per-node=1 --cpus-per-task=8 --mem=0 --time=24:00:00"
sbatch_args="$sbatch_args -d afterok:$job_id -o $dir/log/decode_sweep_%A.out"
sbatch_args="$sbatch_args -e $dir/log/decode_sweep_%A.err"

sbatch $sbatch_args examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh $dir
fairseq/examples/data2vec/scripts/text/finetune_all_fair_aws_lr.sh
ADDED
@@ -0,0 +1,23 @@
#!/usr/bin/env zsh

job_id=$1
task_id=$2
dir="$3"
cp="$dir/checkpoints/checkpoint_last.pt"

echo "job_id: $job_id, task_id: $task_id, dir: $dir"

declare -A tasks
tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"

for task data_path in ${(kv)tasks}; do
  for lr in 5e-6 8e-6 1e-5 2e-5 5e-5 8e-5 1e-4 2e-4; do
    PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" nohup python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
    --config-name $task hydra/launcher=submitit_slurm +run_config=slurm_1g_aws task.data="$data_path" hydra.launcher.name=finetune_${task}_${PREFIX} \
    checkpoint.restore_file="$cp" +hydra.launcher.additional_parameters.dependency="afterok:$job_id" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]" &
  done
done
fairseq/examples/data2vec/scripts/text/finetune_all_fair_local_lr.sh
ADDED
@@ -0,0 +1,25 @@
#!/usr/bin/env zsh

dir="$1"
cp="$dir/checkpoints/checkpoint_last.pt"

echo "dir: $dir"

declare -A tasks
tasks[cola]="/fsx-wav2vec/abaevski/data/nlp/GLUE/CoLA-bin"
tasks[qnli]="/fsx-wav2vec/abaevski/data/nlp/GLUE/QNLI-bin"
tasks[mrpc]="/fsx-wav2vec/abaevski/data/nlp/GLUE/MRPC-bin"
tasks[rte]="/fsx-wav2vec/abaevski/data/nlp/GLUE/RTE-bin"
tasks[sst_2]="/fsx-wav2vec/abaevski/data/nlp/GLUE/SST-2-bin"

lrs=(5e-6 8e-6 1e-5 2e-5)

for task data_path in ${(kv)tasks}; do
  for lr in $lrs; do
    echo $lr $task
    PYTHONPATH=. PREFIX="${PREFIX}" SUFFIX="" \
    python fairseq_cli/hydra_train.py -m --config-dir examples/roberta/config/finetuning \
    --config-name $task +run_config=local task.data="$data_path" common.log_interval=200 dataset.num_workers=1 \
    checkpoint.restore_file="$cp" hydra.sweep.dir="$dir/finetune_lr/$task/$lr" "optimization.lr=[${lr}]"
  done
done