# asr-inference / whisper.py
# Author: wetdog — commit 4e463f4 (verified): "add attention mask"
import os

import torch
import torchaudio
from pyannote.audio import Pipeline
from pydub import AudioSegment
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Run on the first GPU when available, otherwise fall back to CPU.
device = 0 if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32

# Gated model: an HF access token must be provided via the environment.
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_NAME = "projecte-aina/whisper-large-v3-ca-es-synth-cs"

processor = WhisperProcessor.from_pretrained(MODEL_NAME)
model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch_dtype,
    token=HF_TOKEN,
).to(device)
def generate(audio_path):
    """Transcribe an audio file with the module-level Whisper model.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        The decoded transcription (with timestamps) for the first channel
        of the audio, as a single string.
    """
    input_audio, sample_rate = torchaudio.load(audio_path)
    # Whisper's feature extractor expects 16 kHz mono input.
    input_audio = torchaudio.transforms.Resample(sample_rate, 16000)(input_audio)
    input_speech = input_audio[0]

    # Keep the full processor output: we need both the log-mel features and
    # the attention mask (previously the mask was requested but discarded).
    inputs = processor(input_speech,
                       sampling_rate=16_000,
                       return_attention_mask=True,
                       return_tensors="pt")
    # Cast here instead of passing torch_dtype to the processor call,
    # which is not a feature-extractor kwarg and was silently ignored.
    input_features = inputs.input_features.to(device, dtype=torch_dtype)
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    pred_ids = model.generate(input_features,
                              attention_mask=attention_mask,
                              return_timestamps=True,
                              max_new_tokens=128)
    output = processor.batch_decode(pred_ids, skip_special_tokens=True)
    return output[0]