# Inference hyperparameters for a wav2vec 2.0 encoder + mBART decoder
# speech-translation model (SpeechBrain HyperPyYAML). Scalar values defined
# here are referenced below via !ref <key>.
pretrained_path: HaNguyen/IWSLT-ast-w2v2-mbart

lang: fr   # for the BLEU score detokenization
target_lang: fr_XX   # for mbart initialization
sample_rate: 16000


# URL for the HuggingFace model we want to load (BASE here)
wav2vec2_hub: LIA-AvignonUniversity/IWSLT2022-tamasheq-only

# wav2vec 2.0 specific parameters
wav2vec2_frozen: False

# Feature parameters (W2V2 etc)
features_dim: 768 # base wav2vec output dimension, for large replace by 1024

# Projection layer on top of wav2vec 2.0 (see `enc` below)
enc_dnn_layers: 1
enc_dnn_neurons: 1024

# Transformer
embedding_size: 256
d_model: 1024
activation: !name:torch.nn.GELU

# Outputs
blank_index: 1
label_smoothing: 0.1
pad_index: 1      # pad_index defined by mbart model
bos_index: 250008 # fr_XX bos_index defined by mbart model
eos_index: 2      # eos_index defined by mbart model

# Decoding parameters
# Be sure that the bos and eos indices above match the tokenizer's (mBART BPE)
min_decode_ratio: 0.0
max_decode_ratio: 0.25
valid_beam_size: 5
test_beam_size: 5


############################## models ################################
# wav2vec 2.0 encoder, downloaded from the HuggingFace hub (<wav2vec2_hub>).
wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
  source: !ref <wav2vec2_hub>
  output_norm: True   # normalize the encoder outputs
  freeze: !ref <wav2vec2_frozen>   # False -> fine-tune the encoder weights
  save_path: wav2vec2_checkpoint   # local cache dir for the HF checkpoint

# Linear projection from the wav2vec 2.0 output to the mBART input dimension.
# Reference the hyperparameters declared at the top of the file instead of
# repeating the literals (768 / 1 / 1024), so a single edit (e.g. switching
# to the large encoder) cannot leave this block out of sync.
enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
  input_shape: [null, null, !ref <features_dim>]
  activation: !ref <activation>
  dnn_blocks: !ref <enc_dnn_layers>
  dnn_neurons: !ref <enc_dnn_neurons>

# mBART seq2seq decoder (HuggingFace model wrapped by SpeechBrain).
mbart_path: facebook/mbart-large-50-many-to-many-mmt
mbart_frozen: False
# NOTE(review): removed the unused anchor `&id004` (an auto-dump leftover);
# no `*id004` alias exists anywhere in this file, and HyperPyYAML
# cross-references use !ref <mBART>, not YAML aliases.
mBART: !new:speechbrain.lobes.models.huggingface_transformers.mbart.mBART
  source: !ref <mbart_path>
  freeze: !ref <mbart_frozen>
  save_path: mbart_checkpoint   # local cache dir for the HF checkpoint
  target_lang: !ref <target_lang>

# Log-softmax applied on top of the decoder logits.
log_softmax: !new:speechbrain.nnet.activations.Softmax
  apply_log: True

# Identity placeholder — presumably no extra output projection is needed
# because mBART already emits vocabulary-sized logits (TODO confirm in recipe).
seq_lin: !new:torch.nn.Identity

# Module dictionary consumed by the SpeechBrain Brain class.
modules:
  wav2vec2: !ref <wav2vec2>
  enc: !ref <enc>
  mBART: !ref <mBART>
# Parameters checkpointed under "model": only the projection layer here;
# wav2vec2 and mBART are saved/loaded separately (see `pretrainer` below).
model: !new:torch.nn.ModuleList
- [!ref <enc>]

# Beam-search decoder over the mBART text decoder.
# Reference the decoding hyperparameters declared at the top instead of
# duplicating the literals (250008, 2, 0.0, 0.25, 5) — this enforces the
# "bos/eos must match the tokenizer" invariant in one place.
valid_search: !new:speechbrain.decoders.S2SHFTextBasedBeamSearcher
  modules: [!ref <mBART>, null, null]
  vocab_size: 250054   # mbart-large-50 vocabulary size
  bos_index: !ref <bos_index>
  eos_index: !ref <eos_index>
  min_decode_ratio: !ref <min_decode_ratio>
  max_decode_ratio: !ref <max_decode_ratio>
  beam_size: !ref <valid_beam_size>
  using_eos_threshold: True
  length_normalization: True

# Pretrainer: maps each loadable module to its checkpoint file inside the
# repo pointed to by <pretrained_path>. Indentation normalized to the
# 2-space convention used everywhere else in this file (was a mix of
# 4-, 6- and 8-space indents).
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
  loadables:
    model: !ref <model>
    wav2vec2: !ref <wav2vec2>
    mBART: !ref <mBART>
  paths:
    wav2vec2: !ref <pretrained_path>/wav2vec2.ckpt
    model: !ref <pretrained_path>/model.ckpt
    mBART: !ref <pretrained_path>/mBART.ckpt