# Top-level run settings.
# Dotted name of the training entry point — presumably resolved to a callable
# by the code that loads this file; confirm against the loader.
training_fn: wav2vec2.train_huggingface
# Relative path to a saved checkpoint file.
checkpoint: lightning_logs/version_172/checkpoints/epoch=3-step=4572.ckpt
device: mps
# Seed value for the run (presumably used to seed random number generators).
seed: 42
# Class codes used as labels. The `&dance_ids` anchor lets other sections
# reuse this exact list through the `*dance_ids` alias (see `class_list`
# under `datasets`).
dance_ids: &dance_ids
  - BCH
  - CHA
  - JIV
  - ECS
  - QST
  - RMB
  - SFT
  - SLS
  - SMB
  - SWZ
  - TGO
  - VWZ
  - WCS

# Data-loading settings.
data_module:
  batch_size: 64
  num_workers: 7 # Reduced to avoid opening too many files at once
  # data_subset: 0.001
  test_proportion: 0.2

# Datasets to use, keyed by dotted class path. The nested keys are presumably
# passed to that class as constructor arguments — confirm against the loader.
datasets:
  # preprocessing.dataset.BestBallroomDataset:
  #   audio_dir: data/ballroom-songs
  #   class_list: *dance_ids
  #   audio_window_jitter: 0.7

  preprocessing.dataset.Music4DanceDataset:
    song_data_path: data/songs_cleaned.csv
    song_audio_path: data/samples
    # Alias of the top-level `dance_ids` list.
    class_list: *dance_ids
    # Canonical lowercase boolean: resolves as a bool under every YAML schema.
    multi_label: false
    min_votes: 1
    audio_window_jitter: 0.7

# Model hyperparameters.
model:
  n_channels: 128

# Feature-extraction / augmentation settings.
feature_extractor:
  mask_count: 0 # Don't mask the data
  snr_mean: 15.0 # High SNR: pretty much eliminates the added noise
  # NOTE(review): with mask_count at 0 the two mask sizes below presumably
  # have no effect — confirm against the feature extractor.
  freq_mask_size: 10
  time_mask_size: 80

# Trainer options.
trainer:
  log_every_n_steps: 15
  # NOTE(review): the top-level `device` key is set to `mps` while this is
  # `gpu` — confirm both resolve to the same backend on the target machine.
  accelerator: gpu
  max_epochs: 50
  min_epochs: 2
  # Canonical lowercase boolean: resolves as a bool under every YAML schema.
  fast_dev_run: false
  # gradient_clip_val: 0.5
  # overfit_batches: 1

# Optimisation settings; the commented-out `loggers` entry below is disabled.
training_environment:
  learning_rate: 0.000053
  # loggers:
  #   models.training_environment.SpectrogramLogger:
  #     frequency: 100