nb-whisper-demo / app.py
pere's picture
update test
b9fdb45
raw
history blame
4.44 kB
import time
import os
import torch
import gradio as gr
import pytube as pt
import spaces
from transformers import AutoFeatureExtractor, AutoTokenizer, WhisperForConditionalGeneration, WhisperProcessor, pipeline
from huggingface_hub import model_info
# Probe for the optional FlashAttention package. `pipe` reads this flag when
# choosing the attention implementation passed to the model.
try:
    import flash_attn  # noqa: F401  (imported only to test availability)
except ImportError:
    FLASH_ATTENTION = False
else:
    FLASH_ATTENTION = True
MODEL_NAME = "NbAiLab/nb-whisper-large"
lang = "no"
# SHARE is treated as true when its first character is t/y/1 ("True", "yes", "1").
# A tuple membership test is used deliberately: the substring test `"" in "ty1"`
# is True. The previous `[0]` indexing also raised IndexError when SHARE was set
# to an empty string. A false value becomes None, as before, so Gradio applies
# its own default.
_share_flag = os.environ.get("SHARE", "False").strip().lower()[:1]
share = (_share_flag in ("t", "y", "1")) or None
# Without AUTH_TOKEN this falls back to True, which Hugging Face libraries interpret
# as "use the locally stored login token".
auth_token = os.environ.get("AUTH_TOKEN") or True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
@spaces.GPU(duration=60 * 2)
def pipe(file, return_timestamps=False):
    """Transcribe ``file`` with the NB-Whisper speech-recognition pipeline.

    A fresh ``transformers`` ASR pipeline for ``MODEL_NAME`` is built on every
    call. Its decoder prompt is pinned to transcribing in ``lang``, and it runs
    with 30 s chunking and a batch size of 24.

    Args:
        file: Audio input handed straight to the pipeline. The callers in this
            file pass a file path.
        return_timestamps: Also request segment timestamps from Whisper.

    Returns:
        The pipeline's output dict. Callers read ``"text"``, plus ``"chunks"``
        when timestamps are requested.
    """
    on_gpu = device.type == "cuda"
    # Half precision and FlashAttention-2 target CUDA GPUs. On a CPU-only host,
    # use float32 and PyTorch SDPA rather than request kernels that may be
    # missing or very slow there. The CUDA path is unchanged.
    torch_dtype = torch.float16 if on_gpu else torch.float32
    attn_implementation = "flash_attention_2" if FLASH_ATTENTION and on_gpu else "sdpa"
    asr = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        chunk_length_s=30,
        device=device,
        token=auth_token,
        torch_dtype=torch_dtype,
        model_kwargs={"attn_implementation": attn_implementation},
    )
    # Force the decoder prompt (language, task, timestamp mode) for this request.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=lang,
        task="transcribe",
        no_timestamps=not return_timestamps,
    )
    return asr(file, return_timestamps=return_timestamps, batch_size=24)
def _format_timestamp(seconds):
    """Format a chunk boundary given in seconds as ``HH:MM:SS``.

    ``None`` (an unknown boundary) is rendered as ``??:??:??``. The hours field
    wraps after 24 h because the value is formatted via ``time.gmtime``.
    """
    if seconds is None:
        return "??:??:??"
    return time.strftime("%H:%M:%S", time.gmtime(seconds))


def transcribe(file, return_timestamps=False):
    """Transcribe an audio file and return the result as plain text.

    Args:
        file: Path to the audio file. It may be ``None`` or empty when the user
            submits the form without audio.
        return_timestamps: When true, return one line per chunk, formatted as
            ``[HH:MM:SS -> HH:MM:SS] text``.

    Returns:
        The transcription text, or a short message when no audio was provided.
    """
    if not file:
        # Stop here rather than let the GPU-decorated `pipe` build the whole
        # model pipeline only to fail on a missing input.
        return "No audio provided."
    if not return_timestamps:
        return pipe(file)["text"]
    chunks = pipe(file, return_timestamps=True)["chunks"]
    lines = []
    for chunk in chunks:
        start, end = chunk["timestamp"][0], chunk["timestamp"][1]
        lines.append(f"[{_format_timestamp(start)} -> {_format_timestamp(end)}] {chunk['text']}")
    return "\n".join(lines)
def _return_yt_html_embed(yt_url):
    """Build an HTML snippet that embeds the YouTube video referenced by ``yt_url``.

    Recognised link forms:
    - ``youtube.com/watch?v=ID``: other query parameters such as ``&t=`` are ignored.
    - ``youtu.be/ID``.
    - ``youtube.com/{embed,shorts,live,v}/ID``.

    Any other input falls back to the previous rule: the text after the last
    ``?v=``, or the whole string. The ID is percent-encoded, so user input
    cannot break out of the iframe's ``src`` attribute.
    """
    from urllib.parse import parse_qs, quote, urlparse

    parsed = urlparse(yt_url.strip())
    host = (parsed.hostname or "").lower()
    path_parts = [part for part in parsed.path.split("/") if part]
    video_id = None
    if host in ("youtu.be", "www.youtu.be"):
        video_id = path_parts[0] if path_parts else None
    elif host == "youtube.com" or host.endswith(".youtube.com"):
        video_id = parse_qs(parsed.query).get("v", [None])[0]
        if not video_id and len(path_parts) >= 2 and path_parts[0] in ("embed", "shorts", "live", "v"):
            video_id = path_parts[1]
    if not video_id:
        video_id = yt_url.split("?v=")[-1]
    # YouTube IDs ([A-Za-z0-9_-]) are unchanged by quote(). Quotes and angle
    # brackets from hostile input are escaped.
    video_id = quote(video_id, safe="")
    HTML_str = (
        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
        " </center>"
    )
    return HTML_str
def yt_transcribe(yt_url, return_timestamps=False):
    """Download the audio track of a YouTube video and transcribe it.

    Args:
        yt_url: Video URL as typed by the user.
        return_timestamps: Forwarded to ``transcribe``.

    Returns:
        An ``(html_embed, text)`` pair matching the Interface's
        ``["html", "text"]`` outputs, on every path. On failure the text slot
        carries an error message. The embed slot is empty when the video could
        not be fetched at all.
    """
    import tempfile

    try:
        yt = pt.YouTube(yt_url)
        # Stream lookup can also fail (e.g. network or availability errors), so
        # it is guarded together with the constructor.
        audio_streams = yt.streams.filter(only_audio=True)
    except Exception as e:
        return "", f"Error fetching YouTube video: {str(e)}"
    html_embed_str = _return_yt_html_embed(yt_url)
    if not audio_streams:
        return html_embed_str, "No audio streams available for this video."
    stream = audio_streams[0]
    # Each request downloads into its own temporary directory. This stops
    # concurrent requests overwriting each other's audio, and stops a stale
    # file from an earlier run being transcribed. The directory is removed once
    # transcription finishes.
    with tempfile.TemporaryDirectory() as tmp_dir:
        try:
            audio_path = stream.download(output_path=tmp_dir, filename="audio.mp3")
        except Exception as e:
            return html_embed_str, f"Error downloading audio: {str(e)}"
        if not audio_path or not os.path.exists(audio_path):
            return html_embed_str, "Downloaded audio file not found."
        text = transcribe(audio_path, return_timestamps=return_timestamps)
    return html_embed_str, text
demo = gr.Blocks()

# Tab 1: transcribe an uploaded audio file or a microphone recording.
mf_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.components.Audio(sources=["upload", "microphone"], type="filepath"),
        gr.components.Checkbox(label="Return timestamps"),
    ],
    outputs="text",
    title="NB-Whisper Demo",
    description=(
        "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the fine-tuned"
        f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
        " of arbitrary length."
    ),
    allow_flagging="never",
)

# Tab 2: transcribe a YouTube video. The Interface gets its own name so it does
# not shadow the module-level `yt_transcribe` function.
yt_interface = gr.Interface(
    fn=yt_transcribe,
    inputs=[
        gr.components.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
        gr.components.Checkbox(label="Return timestamps"),
    ],
    # One value per input component: URL and "Return timestamps".
    examples=[["https://www.youtube.com/watch?v=mukeSSa5GKo", False]],
    outputs=["html", "text"],
    title="Whisper Demo: Transcribe YouTube",
    description=(
        "Transcribe long-form YouTube videos with the click of a button! Demo uses the fine-tuned checkpoint:"
        f" [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files of"
        " arbitrary length."
    ),
    allow_flagging="never",
)

with demo:
    gr.TabbedInterface(
        [mf_transcribe, yt_interface],
        ["Transkriber Lyd", "Transkriber YouTube"],
    )

# Configure the request queue *before* launching. launch() blocks the main
# thread when run as a script, so a trailing .queue() would only run after the
# server had stopped.
demo.queue().launch(share=share)