# Source: community-events / whisper-fine-tuning-event / fine-tune-whisper-non-streaming-no_comments.py
# Author: showgan
# Commit message: Training in progress, step 1000
# Commit: 72621ec (verified)
#!/home/haroon/python_virtual_envs/whisper_fine_tuning/bin/python
import inspect
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import evaluate
import torch
from datasets import load_dataset, DatasetDict, Audio
from transformers import (WhisperTokenizer, WhisperFeatureExtractor,
                          WhisperProcessor, WhisperForConditionalGeneration,
                          Seq2SeqTrainingArguments, Seq2SeqTrainer)
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
# Hindi Common Voice 11.0: the "train" split merges the official train and
# validation splits; "test" is the official test split.
common_voice = DatasetDict({
    split_name: load_dataset("mozilla-foundation/common_voice_11_0",
                             "hi",
                             split=hf_split,
                             token=True)
    for split_name, hf_split in (("train", "train+validation"), ("test", "test"))
})
print(f'YYY1a {common_voice=}')

# Drop per-clip metadata columns; preprocessing below only reads "audio" and "sentence".
_metadata_columns = [
    "accent", "age", "client_id", "down_votes", "gender",
    "locale", "path", "segment", "up_votes",
]
common_voice = common_voice.remove_columns(_metadata_columns)
print(f'YYY1b {common_voice=}')
print(f'YYY2 {type(common_voice)=}')
# Whisper-small preprocessing components. The standalone tokenizer and the
# processor's tokenizer are both configured for Hindi transcription.
_hindi_transcribe = {"language": "Hindi", "task": "transcribe"}
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", **_hindi_transcribe)
processor = WhisperProcessor.from_pretrained("openai/whisper-small", **_hindi_transcribe)

# Show the first training example before and after switching the audio column
# to 16 kHz decoding.
print(common_voice["train"][0])
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
print(common_voice["train"][0])
# Optional transcript clean-up applied before tokenization. Both are off, so
# the raw sentence is tokenized unchanged.
do_lower_case = False
do_remove_punctuation = False
normalizer = BasicTextNormalizer()


def prepare_dataset(batch):
    """Convert one Common Voice example into Whisper training inputs.

    Reads ``batch["audio"]`` (waveform array plus sampling rate) and
    ``batch["sentence"]``, then adds:

    * ``input_features`` -- features computed by ``processor.feature_extractor``;
    * ``input_length`` -- clip duration in seconds (samples / sampling rate);
    * ``labels`` -- token ids of the (optionally cleaned) sentence.

    The same ``batch`` mapping is returned.
    """
    clip = batch["audio"]
    waveform, rate = clip["array"], clip["sampling_rate"]
    extracted = processor.feature_extractor(waveform, sampling_rate=rate)
    batch["input_features"] = extracted.input_features[0]
    batch["input_length"] = len(waveform) / rate

    text = batch["sentence"].lower() if do_lower_case else batch["sentence"]
    if do_remove_punctuation:
        text = normalizer(text).strip()
    batch["labels"] = processor.tokenizer(text).input_ids
    return batch


# Replace every original column with the prepared features (two worker processes).
_source_columns = common_voice.column_names["train"]
common_voice = common_voice.map(prepare_dataset, num_proc=2, remove_columns=_source_columns)
# Upper bound (seconds, exclusive) on clip duration; presumably chosen to match
# Whisper's 30 s input window.
max_input_length = 30.0


def is_audio_in_length_range(length):
    """Return True when a clip of ``length`` seconds is strictly shorter than ``max_input_length``."""
    return max_input_length > length
# Keep only training examples whose "input_length" passes
# is_audio_in_length_range; the test split is left unfiltered.
common_voice["train"] = common_voice["train"].filter(
    function=is_audio_in_length_range,
    input_columns=["input_length"],
)
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Pad Whisper input features and label ids into one training batch.

    Input features are padded by the processor's feature extractor and label ids
    by its tokenizer. Padded label positions are set to -100 so the loss ignores
    them. If every label row starts with the decoder start token, that token is
    removed. WhisperForConditionalGeneration builds decoder inputs by shifting the
    labels right and prepending ``decoder_start_token_id``, so leaving it in the
    labels would duplicate it.

    Attributes:
        processor: Whisper processor exposing ``feature_extractor`` and ``tokenizer``.
        decoder_start_token_id: Id the model prepends to decoder inputs (normally
            ``model.config.decoder_start_token_id``). When ``None``, the
            tokenizer's ``<|startoftranscript|>`` id is used.
    """

    processor: Any
    decoder_start_token_id: Union[int, None] = None

    def _resolve_decoder_start_token_id(self) -> int:
        """Return the configured start-token id, or look up ``<|startoftranscript|>``."""
        if self.decoder_start_token_id is not None:
            return self.decoder_start_token_id
        return self.processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # -100 is the ignore index of the cross-entropy loss; use it for padding.
        labels = labels_batch["input_ids"].masked_fill(labels_batch["attention_mask"].ne(1), -100)

        # Compare against the decoder start token, not tokenizer.bos_token_id.
        # Whisper's bos token is <|endoftext|>, which never begins the labels, so
        # the old check never fired and the start token ended up duplicated.
        start_id = self._resolve_decoder_start_token_id()
        if (labels[:, 0] == start_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
metric = evaluate.load("wer")
do_normalize_eval = True


def compute_metrics(pred):
    """Compute word error rate (as a percentage) for generated predictions.

    Args:
        pred: Evaluation output exposing ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, where -100 marks ignored positions).

    Returns:
        ``{"wer": <percent>}``. When ``do_normalize_eval`` is set, predictions
        and references are first passed through ``normalizer``. Pairs whose
        reference is empty or whitespace-only are skipped. If no pair remains,
        ``{"wer": nan}`` is returned.
    """
    pred_ids = pred.predictions
    # Copy before editing so the caller's label array is not modified in place.
    label_ids = pred.label_ids.copy()
    # -100 is not a token id; replace it with the pad id, which
    # skip_special_tokens then drops during decoding.
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(hyp) for hyp in pred_str]
        label_str = [normalizer(ref) for ref in label_str]

    # WER is undefined for an empty reference, and the WER backend typically
    # raises on one. Normalization can reduce a transcript to whitespace, so
    # drop such pairs.
    pairs = [(hyp, ref) for hyp, ref in zip(pred_str, label_str) if ref.strip()]
    if not pairs:
        return {"wer": float("nan")}
    hyps = [hyp for hyp, _ in pairs]
    refs = [ref for _, ref in pairs]

    wer = 100 * metric.compute(predictions=hyps, references=refs)
    return {"wer": wer}
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
# Evaluation-time generation: decode Hindi in transcription mode. The
# checkpoint's generation_config carries its own forced_decoder_ids; clear them
# so language/task alone determine the decoder prompt.
model.generation_config.language = "hi"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None
# Also cleared on model.config, as in the original setup (presumably for older
# transformers versions that read generation settings from the model config).
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
# Presumably disabled because gradient checkpointing is enabled in the training
# arguments. NOTE(review): this also turns off the KV cache during eval generation.
model.config.use_cache = False
# Newer transformers releases renamed ``evaluation_strategy`` to ``eval_strategy``
# and later removed the old spelling. Pass the setting under whichever name
# this installation accepts.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(Seq2SeqTrainingArguments).parameters
    else "evaluation_strategy"
)
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    per_device_eval_batch_size=4,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    # load_best_model_at_end requires save_steps to be a multiple of eval_steps.
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,  # lower WER is better
    push_to_hub=True,
    **{_eval_strategy_kwarg: "steps"},
)
# Recent transformers versions take the preprocessing object as
# ``processing_class`` and deprecate ``tokenizer``; older versions only accept
# ``tokenizer``. Pass it under whichever name is supported.
_processor_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Seq2SeqTrainer).parameters
    else "tokenizer"
)
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_processor_kwarg: processor.feature_extractor},
)
# Write the processor files into the training output directory before training starts.
processor.save_pretrained(training_args.output_dir)
trainer.train()
# Model-card metadata forwarded to Trainer.push_to_hub.
kwargs = dict(
    dataset_tags="mozilla-foundation/common_voice_11_0",
    dataset="Common Voice 11.0",  # display name of the training dataset
    language="hi",
    model_name="Whisper Small Hi - Sanchit Gandhi",  # display name of the model
    finetuned_from="openai/whisper-small",
    tasks="automatic-speech-recognition",
    tags="whisper-event",
)
trainer.push_to_hub(**kwargs)