pere committed
Commit 13e9d03 · 1 Parent(s): 8ac8b8e
Files changed (2):
  1. run_continue_nst.sh +3 -3
  2. run_whisper_finetuning.py +30 -53
run_continue_nst.sh CHANGED
@@ -14,9 +14,9 @@ python run_whisper_finetuning.py \
  --do_eval=True \
  --audio_column_name="audio" \
  --text_column_name="text" \
- --per_device_train_batch_size=48 \
- --per_device_train_batch_size=48 \
- --learning_rate=4e-5 \
+ --per_device_train_batch_size=24 \
+ --per_device_train_batch_size=24 \
+ --learning_rate=2e-5 \
  --warmup_steps=500 \
  --max_steps=10000 \
  --gradient_checkpointing=True \
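
This commit halves the per-device batch size from 48 to 24 and lowers the learning rate from 4e-5 to 2e-5. (Note that --per_device_train_batch_size is passed twice in both versions of the script; the second occurrence may have been intended as --per_device_eval_batch_size.) As a rough orientation only, and not part of the commit, the changed flags map onto transformers' Seq2SeqTrainingArguments roughly as sketched below, assuming the script parses them with HfArgumentParser; the output_dir value is hypothetical.

from transformers import Seq2SeqTrainingArguments

# Sketch only: an in-Python view of the updated CLI flags.
# "whisper-nst" is a hypothetical output directory, not taken from the repo.
training_args = Seq2SeqTrainingArguments(
    output_dir="whisper-nst",
    do_eval=True,
    per_device_train_batch_size=24,   # was 48 before this commit
    learning_rate=2e-5,               # was 4e-5 before this commit
    warmup_steps=500,
    max_steps=10000,
    gradient_checkpointing=True,
)
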
run_whisper_finetuning.py CHANGED
@@ -345,7 +345,6 @@ def main():
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Metrics
-
    def compute_metrics(pred):
        pred_ids = pred.predictions
        label_ids = pred.label_ids
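
This hunk only removes the blank line above compute_metrics, so the body of the function is not shown. For orientation, a typical WER metric function in the standard Whisper fine-tuning recipe looks roughly like the sketch below; the checkpoint name and the language/task settings are illustrative assumptions, not values from this repository.

import evaluate
from transformers import WhisperTokenizer

# Assumed checkpoint and language/task; the script takes these from its CLI arguments.
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-small", language="Norwegian", task="transcribe")
metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # Padded label positions are set to -100 so the loss ignores them;
    # restore the pad token id before decoding.
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
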
@@ -383,30 +382,40 @@ def main():
    print("\n* Training arguments")
    pprint(vars(training_args), indent=2)

-   def rename_column(ds, old_name, new_name):
-       feats = ds.info.features
-       ds = ds.rename_column(old_name, new_name)
-       feats[new_name] = feats.pop(old_name)
-       ds.info.features = feats
-       return ds
-
-   def remove_columns(ds, column_name):
-       feats = ds.info.features
-       ds = ds.remove_columns(column_name)
-       feats.pop(column_name)
-       ds.info.features = feats
-       return ds

    # Print training arguments
    if data_args.print_training_arguments:
        print_training_arguments(model_args, data_args, training_args)

+
+
+   # Initialise the model
+   feature_extractor = WhisperFeatureExtractor.from_pretrained(
+       model_args.model_name_or_path)
+   tokenizer = WhisperTokenizer.from_pretrained(
+       model_args.model_name_or_path, language=model_args.language, task=model_args.task)
+   processor = WhisperProcessor.from_pretrained(
+       model_args.model_name_or_path, language=model_args.language, task=model_args.task)
+   data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
+
+   # Saving the processor and the tokenizer
+   processor.save_pretrained(training_args.output_dir)
+   tokenizer.save_pretrained(training_args.output_dir)
+
+
    # Load dataset
    train_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config_name,
                                 split="train", streaming=True, use_auth_token=True)
    eval_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config_name,
                                split="test", streaming=True, use_auth_token=True)
-
+
+   # Because a bug in Dataset (https://github.com/huggingface/datasets/issues/3888) we need to read the columns and keep them for later
+   column_names=[x for x in train_dataset.info.features]
+
+   # Make sure everything is in 16K
+   train_dataset = train_dataset.cast_column(data_args.audio_column_name, Audio(sampling_rate=16000))
+   eval_dataset = eval_dataset.cast_column(data_args.audio_column_name, Audio(sampling_rate=16000))
+
    # Rename columns
    if data_args.audio_column_name != "audio":
        train_dataset = train_dataset.rename_column(
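
The new initialisation block instantiates DataCollatorSpeechSeq2SeqWithPadding, which is defined elsewhere in run_whisper_finetuning.py and is not touched by this commit. The standard collator from the Whisper fine-tuning recipe, which this class presumably resembles, pads the log-mel features and the label ids separately and masks label padding with -100. The following is a sketch for reference, not the file's exact code.

from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any  # a WhisperProcessor bundling feature extractor and tokenizer

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad the log-mel input features with the feature extractor.
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad the tokenised transcripts with the tokenizer.
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # Replace padding with -100 so these positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If a BOS token was prepended during tokenisation, drop it here;
        # the model prepends it again itself during training.
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
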
@@ -420,42 +429,13 @@ def main():
    eval_dataset = eval_dataset.rename_column(
        data_args.text_column_name, "sentence")

-   # Initialise
-   feature_extractor = WhisperFeatureExtractor.from_pretrained(
-       model_args.model_name_or_path)
-   tokenizer = WhisperTokenizer.from_pretrained(
-       model_args.model_name_or_path, language=model_args.language, task=model_args.task)
-   processor = WhisperProcessor.from_pretrained(
-       model_args.model_name_or_path, language=model_args.language, task=model_args.task)
-   data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
-
-   # Saving the processor and the tokenizer
-   processor.save_pretrained(training_args.output_dir)
-   tokenizer.save_pretrained(training_args.output_dir)
-
-   # Prepare data
-   # TODO The casting of the not working on the NPSC in 48K. It seems to be working for Common Voice
-   # The issue is that the dataset features returns None. But for me thay seem to have been set correctly
-   # In our case this is not needed, since the datasets already is available as 16K. But it would be great to solve this bug
-   # train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16000))
-   # eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=16000))
-
-   # Remove non needed columns
-   #column_names=[x for x in train_dataset.info.features]
-
-   # for c in column_names:
-   #     if c not in ["audio", "text"]:
-   #         train_dataset = remove_columns(train_dataset, c)
-   #         eval_dataset = remove_columns(eval_dataset, c)
-
-   # TODO I would really like to remove the non needed columns here. At least this cleans up the output.
-   # I am unable to figure out how to do this Streaming mode. Can not find a way to list columns.
-   # train_data = train_data.map(prepare_dataset, remove_columns=train_data.column_names, num_proc=1)
-
-   train_dataset = train_dataset.map(prepare_dataset)
-   eval_dataset = eval_dataset.map(prepare_dataset)
+
+   # Prepare the dataset
+   column_names.extend(['sentence','audio'])
+   train_dataset = train_dataset.map(prepare_dataset, remove_columns=column_names)
+   eval_dataset = eval_dataset.map(prepare_dataset, remove_columns=column_names)

-   # Metrics
+   # Define metrics
    metric = evaluate.load("wer")

    # Detecting last checkpoint.
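
The map calls above rely on prepare_dataset, which is defined earlier in the file and unchanged by this commit. In the standard recipe it computes log-mel input features from the 16 kHz audio and tokenises the transcript; a minimal sketch under that assumption follows (the checkpoint name is illustrative, the script takes it from --model_name_or_path).

from transformers import WhisperFeatureExtractor, WhisperTokenizer

# Illustrative checkpoint and language/task settings.
checkpoint = "openai/whisper-small"
feature_extractor = WhisperFeatureExtractor.from_pretrained(checkpoint)
tokenizer = WhisperTokenizer.from_pretrained(checkpoint, language="Norwegian", task="transcribe")

def prepare_dataset(batch):
    # Compute log-mel input features from the (already 16 kHz) audio.
    audio = batch["audio"]
    batch["input_features"] = feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # Encode the target text to label ids.
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

With remove_columns=column_names, the map call then drops the original dataset columns (plus the renamed sentence and audio columns), so only the newly created input_features and labels remain in the streamed examples; this is why column_names is collected from train_dataset.info.features earlier in the diff.
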
@@ -547,13 +527,10 @@ def main():
    return train_result

# XLA hook
-
-
def _mp_fn(index):
    # For xla_spawn (TPUs)
    print("The XLA is initiated")
    main()

-
if __name__ == "__main__":
    main()