#!/home/haroon/python_virtual_envs/whisper_fine_tuning/bin/python
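# Fine-tune openai/whisper-small for Hindi speech recognition on Common
# Voice 11.0: preprocess the audio/text pairs, train with Seq2SeqTrainer
# while tracking word error rate, and push the result to the Hub.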
from datasets import load_dataset, DatasetDict, Audio
from transformers import (WhisperTokenizer, WhisperFeatureExtractor,
                          WhisperProcessor, WhisperForConditionalGeneration,
                          Seq2SeqTrainingArguments, Seq2SeqTrainer)
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate
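# Load the Hindi ("hi") subset of Common Voice 11.0, merging train and
# validation into one training set. token=True authenticates with the
# Hugging Face Hub, which this gated dataset requires.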
common_voice = DatasetDict()
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0",
"hi",
split="train+validation",
token=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0",
"hi",
split="test",
token=True)
print(f'Loaded splits: {common_voice=}')
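# Keep only the "audio" and "sentence" (transcription) columns; the rest
# is Common Voice metadata that is not needed for fine-tuning.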
common_voice = common_voice.remove_columns([
    "accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
print(f'After column removal: {common_voice=}')
print(f'{type(common_voice)=}')
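# The feature extractor turns raw waveforms into log-Mel spectrograms and
# the tokenizer maps Hindi text to and from label ids; the processor
# bundles both into a single object.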
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small",
                                             language="Hindi", task="transcribe")
processor = WhisperProcessor.from_pretrained("openai/whisper-small",
                                             language="Hindi", task="transcribe")
print(common_voice["train"][0])
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
print(common_voice["train"][0])
do_lower_case = False
do_remove_punctuation = False
normalizer = BasicTextNormalizer()
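# Convert one raw example into model inputs: a log-Mel spectrogram, the
# clip duration in seconds (used by the length filter below), and the
# tokenized transcription as label ids. Lower-casing and punctuation
# removal are off by default, since Whisper is trained on cased,
# punctuated text.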
def prepare_dataset(batch):
    audio = batch["audio"]
    batch["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
    transcription = batch["sentence"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()
    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch
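# Run the preprocessing over both splits, dropping the raw columns after
# mapping; num_proc=2 parallelizes the work across two processes.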
common_voice = common_voice.map(prepare_dataset,
                                remove_columns=common_voice.column_names["train"],
                                num_proc=2)
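# Whisper's encoder operates on fixed 30-second windows, so any training
# clip longer than max_input_length is dropped rather than truncated.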
max_input_length = 30.0
def is_audio_in_length_range(length):
    return length < max_input_length
common_voice["train"] = common_voice["train"].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)
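# Seq2seq speech batches need two different padders: the feature extractor
# pads the log-Mel inputs and the tokenizer pads the label ids. Padded
# label positions are set to -100 so the loss ignores them, and a leading
# BOS token is stripped because it is added back during training.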
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) \
            -> Dict[str, torch.Tensor]:
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
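# Evaluation uses word error rate (WER), the standard ASR metric. When
# do_normalize_eval is set, predictions and references are normalized
# (casing and punctuation stripped) before scoring so that formatting
# differences are not counted as errors.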
metric = evaluate.load("wer")
do_normalize_eval = True
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    if do_normalize_eval:
        pred_str = [normalizer(pred) for pred in pred_str]
        label_str = [normalizer(label) for label in label_str]
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.generation_config.language = "hi"
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False
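# Effective batch size is 64 (8 per device x 8 accumulation steps). Training
# runs for a fixed 5000 steps with fp16 and gradient checkpointing; every
# 1000 steps the model is evaluated by generating transcriptions, and the
# checkpoint with the lowest WER is reloaded at the end.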
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=4,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)
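# The feature extractor is passed as the Trainer's "tokenizer" so it is
# saved alongside checkpoints and included in the Hub upload.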
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
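# Persist the processor (feature extractor + tokenizer) to the output
# directory before training so it contains everything needed for inference.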
processor.save_pretrained(training_args.output_dir)
trainer.train()
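# Metadata for the model card generated when the final model is uploaded.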
kwargs = {
"dataset_tags": "mozilla-foundation/common_voice_11_0",
"dataset": "Common Voice 11.0", # a 'pretty' name for the training dataset
"language": "hi",
"model_name": "Whisper Small Hi - Sanchit Gandhi", # a 'pretty' name for your model
"finetuned_from": "openai/whisper-small",
"tasks": "automatic-speech-recognition",
"tags": "whisper-event",
}
trainer.push_to_hub(**kwargs)