# whisper_small_CGN/scripts/run_eval_whisper_streaming_local.py
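"""Streamed evaluation of a (fine-tuned) Whisper model.

Runs the HF ASR pipeline over a streamed test split (cgn-dev or subs-annot)
and writes the hypotheses, raw and in two normalized variants, to
<expdir>/results/<dataset>/.
"""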
import os
import numpy as np
import re
import argparse
# Strip a stray "CUDA" prefix from the visible-devices list (presumably set like
# "CUDA0" by the job scheduler); guard against a KeyError when the variable is unset.
if "CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"].replace("CUDA", "")
from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio
whisper_norm = BasicTextNormalizer()
def simple_norm(utt):
    norm_utt = re.sub(r'[^\w\s]', '', utt)  # remove punctuation
    norm_utt = " ".join(norm_utt.split())   # collapse whitespace
    norm_utt = norm_utt.lower()
    return norm_utt
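# For example (illustrative): simple_norm("Ja, dat klopt!") -> "ja dat klopt".
# whisper_norm applies Whisper's BasicTextNormalizer (lowercasing, symbol
# removal), giving a comparable but not identical normalization.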
def data(dataset):
    # Wrap the streamed dataset so the pipeline receives the audio dict plus
    # pass-through metadata (reference text and utterance id).
    for item in dataset:
        yield {**item["audio"], "reference": item["text"], "utt_id": item["id"]}
def get_ckpt(path, ckpt_id):
    # Resolve a checkpoint directory: the requested one, or the latest if ckpt_id is 0.
    if ckpt_id != 0:
        model = os.path.join(path, "checkpoint-%i" % ckpt_id)
    else:
        dirs = [d for d in os.listdir(path) if d.startswith("checkpoint-")]
        ckpts = [int(d.split('-')[-1]) for d in dirs]
        last_ckpt = sorted(ckpts)[-1]
        model = os.path.join(path, "checkpoint-%s" % last_ckpt)
    return model
def main(args):
    batch_size = args.batch_size

    if args.device == "cpu":
        device_id = -1
    elif args.device == "gpu":
        device_id = 0
    else:
        raise NotImplementedError("unknown device %s, should be cpu/gpu" % args.device)

    model_dir = os.path.join(args.expdir, args.model_size)
    #model = os.path.join(get_ckpt(model_dir, args.checkpoint), 'pytorch_model.bin')
    #model = get_ckpt(model_dir, args.checkpoint)
    model = model_dir
    #model = "openai/whisper-tiny"

    whisper_asr = pipeline(
        "automatic-speech-recognition", model=model, device=device_id
    )

    # Force the decoder to transcribe in the requested language.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=args.language, task="transcribe"
        )
    )
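    # Note: newer transformers releases favor passing
    # generate_kwargs={"language": args.language, "task": "transcribe"} to the
    # pipeline call over setting forced_decoder_ids directly.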
    if args.dataset == 'cgn-dev':
        dataset_path = "./cgn-dev/cgn-dev.py"
    elif args.dataset == 'subs-annot':
        dataset_path = "./subs-annot/subs-annot.py"
    else:
        raise NotImplementedError('unknown dataset %s' % args.dataset)

    cache_dir = "/esat/audioslave/jponcele/hf_cache"

    dataset = load_dataset(dataset_path, name="raw", split="test", cache_dir=cache_dir, streaming=True)
    # Resample on the fly to Whisper's expected 16 kHz input rate.
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
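    # Each streamed item then looks roughly like (field names per the loading scripts):
    #   {"audio": {"array": ..., "sampling_rate": 16000, ...}, "text": ..., "id": ...}
    # and data() above flattens it into the dict format the ASR pipeline consumes.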
    utterances = []
    predictions = []
    references = []

    # run streamed inference
    for out in whisper_asr(data(dataset), batch_size=batch_size):
        predictions.append(out["text"])
        utterances.append(out["utt_id"][0])
        references.append(out["reference"][0])
        #break

    result_dir = os.path.join(args.expdir, "results", args.dataset)
    os.makedirs(result_dir, exist_ok=True)

    # Write the hypotheses three times: raw, Whisper-normalized and simply normalized.
    for suffix, norm_fn in [("", None), ("_normW", whisper_norm), ("_normS", simple_norm)]:
        outfile = os.path.join(result_dir, "whisper_%s%s.txt" % (args.model_size, suffix))
        with open(outfile, "w") as pd:
            for utt, pred in zip(utterances, predictions):
                if norm_fn is not None:
                    pred = norm_fn(pred)
                pd.write(utt + ' ' + pred + '\n')
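    # The references are collected above but never scored here. A minimal scoring
    # sketch (assumes the `evaluate` package is installed; not part of this script):
    #
    #   import evaluate
    #   wer = evaluate.load("wer").compute(
    #       predictions=[simple_norm(p) for p in predictions],
    #       references=[simple_norm(r) for r in references])
    #   print("WER: %.2f%%" % (100 * wer))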
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--expdir",
        type=str,
        default="/esat/audioslave/jponcele/whisper/finetuning_event/CGN",
        help="Directory with finetuned models",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        default="tiny",
        help="Model size",
    )
    parser.add_argument(
        "--checkpoint",
        type=int,
        default=0,
        help="Load a specific checkpoint; 0 means the latest",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cgn-dev",
        help="Dataset to evaluate on: cgn-dev or subs-annot (local 🤗 Datasets loading scripts)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="cpu/gpu",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples per streamed batch",
    )
    parser.add_argument(
        "--language",
        type=str,
        default="dutch",
        help="Transcription language, as a name or code understood by Whisper, e.g. 'dutch' or 'nl'",
    )
    args = parser.parse_args()

    main(args)
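# Example invocation (paths and values are illustrative):
#   python run_eval_whisper_streaming_local.py --model_size small --device gpu \
#       --dataset cgn-dev --language dutch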