"""Preprocess an ASR dataset into Whisper log-Mel features and push it to the Hub."""

import argparse

from datasets import Audio, load_dataset
from huggingface_hub import login
from transformers import WhisperProcessor
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

my_parser = argparse.ArgumentParser()
my_parser.add_argument(
    "--model_name",
    "-model_name",
    type=str,
    default="openai/whisper-tiny",
)
my_parser.add_argument("--hf_token", "-hf_token", type=str)
my_parser.add_argument(
    "--dataset_name", "-dataset_name", type=str, default="google/fleurs"
)
my_parser.add_argument("--split", "-split", type=str, default="test")
my_parser.add_argument("--subset", "-subset", type=str)
args = my_parser.parse_args()

dataset_name = args.dataset_name
model_name = args.model_name
subset = args.subset
split = args.split
hf_token = args.hf_token
login(hf_token)

# FLEURS stores the reference text under "transcription"; most other ASR
# datasets (e.g. Common Voice) use "sentence".
text_column = "sentence"
if dataset_name == "google/fleurs":
    text_column = "transcription"

do_lower_case = False
do_remove_punctuation = False
normalizer = BasicTextNormalizer()

# The processor bundles both the feature extractor and the tokenizer, so
# there is no need to load them separately.
processor = WhisperProcessor.from_pretrained(
    model_name, language="Arabic", task="transcribe"
)

dataset = load_dataset(dataset_name, subset, split=split, token=True)
print(dataset)

# Whisper expects 16 kHz audio; cast_column resamples lazily on access.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))


def prepare_dataset(batch):
    # load and (possibly) resample audio data to 16 kHz
    audio = batch["audio"]

    # compute log-Mel input features from the raw audio array
    batch["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_features[0]
    # compute input length of the audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    # optional pre-processing steps
    transcription = batch[text_column]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()

    # encode the target text to label ids
    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch


# A single split was loaded, so column_names is a plain list here.
dataset = dataset.map(prepare_dataset, remove_columns=dataset.column_names)

repo_id = (
    f"arbml/{dataset_name.split('/')[-1]}_preprocessed_{model_name.split('/')[-1]}"
)
print(f"pushing to {repo_id}")
dataset.push_to_hub(repo_id, private=True)
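
# --- Example invocation (a sketch: the script filename "prepare_dataset.py",
# the "ar_eg" FLEURS subset, and the token placeholder are assumptions, not
# values taken from this script): ---
#
#   python prepare_dataset.py \
#       --model_name openai/whisper-small \
#       --dataset_name google/fleurs \
#       --subset ar_eg \
#       --split train \
#       --hf_token <your_hf_token>
#
# This would push the processed split to
# arbml/fleurs_preprocessed_whisper-small as a private dataset.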