test
Browse files- run_continue_nst.sh +3 -3
- run_whisper_finetuning.py +30 -53
run_continue_nst.sh
CHANGED
@@ -14,9 +14,9 @@ python run_whisper_finetuning.py \
|
|
14 |
--do_eval=True \
|
15 |
--audio_column_name="audio" \
|
16 |
--text_column_name="text" \
|
17 |
-
--per_device_train_batch_size=
|
18 |
-
--per_device_train_batch_size=
|
19 |
-
--learning_rate=
|
20 |
--warmup_steps=500 \
|
21 |
--max_steps=10000 \
|
22 |
--gradient_checkpointing=True \
|
|
|
14 |
--do_eval=True \
|
15 |
--audio_column_name="audio" \
|
16 |
--text_column_name="text" \
|
17 |
+
--per_device_train_batch_size=24 \
|
18 |
+
--per_device_eval_batch_size=24 \
|
19 |
+
--learning_rate=2e-5 \
|
20 |
--warmup_steps=500 \
|
21 |
--max_steps=10000 \
|
22 |
--gradient_checkpointing=True \
|
run_whisper_finetuning.py
CHANGED
@@ -345,7 +345,6 @@ def main():
|
|
345 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
346 |
|
347 |
# Metrics
|
348 |
-
|
349 |
def compute_metrics(pred):
|
350 |
pred_ids = pred.predictions
|
351 |
label_ids = pred.label_ids
|
@@ -383,30 +382,40 @@ def main():
|
|
383 |
print("\n* Training arguments")
|
384 |
pprint(vars(training_args), indent=2)
|
385 |
|
386 |
-
def rename_column(ds, old_name, new_name):
|
387 |
-
feats = ds.info.features
|
388 |
-
ds = ds.rename_column(old_name, new_name)
|
389 |
-
feats[new_name] = feats.pop(old_name)
|
390 |
-
ds.info.features = feats
|
391 |
-
return ds
|
392 |
-
|
393 |
-
def remove_columns(ds, column_name):
|
394 |
-
feats = ds.info.features
|
395 |
-
ds = ds.remove_columns(column_name)
|
396 |
-
feats.pop(column_name)
|
397 |
-
ds.info.features = feats
|
398 |
-
return ds
|
399 |
|
400 |
# Print training arguments
|
401 |
if data_args.print_training_arguments:
|
402 |
print_training_arguments(model_args, data_args, training_args)
|
403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
# Load dataset
|
405 |
train_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config_name,
|
406 |
split="train", streaming=True, use_auth_token=True)
|
407 |
eval_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config_name,
|
408 |
split="test", streaming=True, use_auth_token=True)
|
409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
# Rename columns
|
411 |
if data_args.audio_column_name != "audio":
|
412 |
train_dataset = train_dataset.rename_column(
|
@@ -420,42 +429,13 @@ def main():
|
|
420 |
eval_dataset = eval_dataset.rename_column(
|
421 |
data_args.text_column_name, "sentence")
|
422 |
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
processor = WhisperProcessor.from_pretrained(
|
429 |
-
model_args.model_name_or_path, language=model_args.language, task=model_args.task)
|
430 |
-
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
|
431 |
-
|
432 |
-
# Saving the processor and the tokenizer
|
433 |
-
processor.save_pretrained(training_args.output_dir)
|
434 |
-
tokenizer.save_pretrained(training_args.output_dir)
|
435 |
-
|
436 |
-
# Prepare data
|
437 |
-
# TODO: The casting of the audio column is not working on the NPSC in 48K. It seems to be working for Common Voice
|
438 |
-
# The issue is that the dataset features return None. But to me they seem to have been set correctly
|
439 |
-
# In our case this is not needed, since the dataset is already available as 16K. But it would be great to solve this bug
|
440 |
-
# train_dataset = train_dataset.cast_column("audio", Audio(sampling_rate=16000))
|
441 |
-
# eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=16000))
|
442 |
-
|
443 |
-
# Remove non needed columns
|
444 |
-
#column_names=[x for x in train_dataset.info.features]
|
445 |
-
|
446 |
-
# for c in column_names:
|
447 |
-
# if c not in ["audio", "text"]:
|
448 |
-
# train_dataset = remove_columns(train_dataset, c)
|
449 |
-
# eval_dataset = remove_columns(eval_dataset, c)
|
450 |
-
|
451 |
-
# TODO I would really like to remove the non needed columns here. At least this cleans up the output.
|
452 |
-
# I am unable to figure out how to do this Streaming mode. Can not find a way to list columns.
|
453 |
-
# train_data = train_data.map(prepare_dataset, remove_columns=train_data.column_names, num_proc=1)
|
454 |
-
|
455 |
-
train_dataset = train_dataset.map(prepare_dataset)
|
456 |
-
eval_dataset = eval_dataset.map(prepare_dataset)
|
457 |
|
458 |
-
#
|
459 |
metric = evaluate.load("wer")
|
460 |
|
461 |
# Detecting last checkpoint.
|
@@ -547,13 +527,10 @@ def main():
|
|
547 |
return train_result
|
548 |
|
549 |
# XLA hook
|
550 |
-
|
551 |
-
|
552 |
def _mp_fn(index):
|
553 |
# For xla_spawn (TPUs)
|
554 |
print("The XLA is initiated")
|
555 |
main()
|
556 |
|
557 |
-
|
558 |
if __name__ == "__main__":
|
559 |
main()
|
|
|
345 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
346 |
|
347 |
# Metrics
|
|
|
348 |
def compute_metrics(pred):
|
349 |
pred_ids = pred.predictions
|
350 |
label_ids = pred.label_ids
|
|
|
382 |
print("\n* Training arguments")
|
383 |
pprint(vars(training_args), indent=2)
|
384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
|
386 |
# Print training arguments
|
387 |
if data_args.print_training_arguments:
|
388 |
print_training_arguments(model_args, data_args, training_args)
|
389 |
|
390 |
+
|
391 |
+
|
392 |
+
# Initialise the model
|
393 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained(
|
394 |
+
model_args.model_name_or_path)
|
395 |
+
tokenizer = WhisperTokenizer.from_pretrained(
|
396 |
+
model_args.model_name_or_path, language=model_args.language, task=model_args.task)
|
397 |
+
processor = WhisperProcessor.from_pretrained(
|
398 |
+
model_args.model_name_or_path, language=model_args.language, task=model_args.task)
|
399 |
+
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
|
400 |
+
|
401 |
+
# Saving the processor and the tokenizer
|
402 |
+
processor.save_pretrained(training_args.output_dir)
|
403 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
404 |
+
|
405 |
+
|
406 |
# Load dataset
|
407 |
train_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config_name,
|
408 |
split="train", streaming=True, use_auth_token=True)
|
409 |
eval_dataset = load_dataset(data_args.dataset_name, data_args.dataset_config_name,
|
410 |
split="test", streaming=True, use_auth_token=True)
|
411 |
+
|
412 |
+
# Because of a bug in Datasets (https://github.com/huggingface/datasets/issues/3888) we need to read the columns and keep them for later
|
413 |
+
column_names=[x for x in train_dataset.info.features]
|
414 |
+
|
415 |
+
# Make sure everything is in 16K
|
416 |
+
train_dataset = train_dataset.cast_column(data_args.audio_column_name, Audio(sampling_rate=16000))
|
417 |
+
eval_dataset = eval_dataset.cast_column(data_args.audio_column_name, Audio(sampling_rate=16000))
|
418 |
+
|
419 |
# Rename columns
|
420 |
if data_args.audio_column_name != "audio":
|
421 |
train_dataset = train_dataset.rename_column(
|
|
|
429 |
eval_dataset = eval_dataset.rename_column(
|
430 |
data_args.text_column_name, "sentence")
|
431 |
|
432 |
+
|
433 |
+
# Prepare the dataset
|
434 |
+
column_names.extend(['sentence','audio'])
|
435 |
+
train_dataset = train_dataset.map(prepare_dataset, remove_columns=column_names)
|
436 |
+
eval_dataset = eval_dataset.map(prepare_dataset, remove_columns=column_names)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
437 |
|
438 |
+
# Define metrics
|
439 |
metric = evaluate.load("wer")
|
440 |
|
441 |
# Detecting last checkpoint.
|
|
|
527 |
return train_result
|
528 |
|
529 |
# XLA hook
|
|
|
|
|
530 |
def _mp_fn(index):
|
531 |
# For xla_spawn (TPUs)
|
532 |
print("The XLA is initiated")
|
533 |
main()
|
534 |
|
|
|
535 |
if __name__ == "__main__":
|
536 |
main()
|