# nb-whisper-demo / app.py
import time
import os
import re

import torch
import torchaudio
import gradio as gr
import spaces
import yt_dlp
from transformers import pipeline

try:
    import flash_attn  # noqa: F401 -- only probed for availability
    FLASH_ATTENTION = True
except ImportError:
    FLASH_ATTENTION = False
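
# Flash Attention 2 is optional: when the flash_attn package is importable,
# the ASR pipeline below is built with attn_implementation="flash_attention_2";
# otherwise it falls back to PyTorch's built-in SDPA attention.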
MODEL_NAME = "NbAiLab/nb-whisper-large"
max_audio_length = 30 * 60  # maximum input length, in seconds

# SHARE: a leading t/y/1 enables a public share link; anything else leaves it
# to Gradio's default. token=True means "use the locally stored HF token".
share = (os.environ.get("SHARE", "False")[0].lower() in "ty1") or None
auth_token = os.environ.get("AUTH_TOKEN") or True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
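
# ZeroGPU entry point: @spaces.GPU attaches a GPU for up to two minutes per
# call, which is why the pipeline is constructed inside the function rather
# than once at import time.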
@spaces.GPU(duration=60 * 2)
def pipe(file, return_timestamps=False, lang="no"):
    asr = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        chunk_length_s=28,
        device=device,
        token=auth_token,
        torch_dtype=torch.float16,
        model_kwargs=(
            {"attn_implementation": "flash_attention_2"}
            if FLASH_ATTENTION
            else {"attn_implementation": "sdpa"}
        ),
    )
    # Pin the decoder prompt to the requested language/task and timestamp mode.
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=lang,
        task="transcribe",
        no_timestamps=not return_timestamps,
    )
    # Beam search and language/task selection are generation-time options, so
    # they belong in generate_kwargs rather than model_kwargs (which are
    # forwarded to from_pretrained).
    return asr(
        file,
        return_timestamps=return_timestamps,
        batch_size=24,
        generate_kwargs={"task": "transcribe", "language": lang, "num_beams": 5},
    )
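
# Insert an HTML line break after sentence-ending punctuation (., !, :, ? or a
# run of three or more dots) so the transcript renders one sentence per line.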
def format_output(text):
    text = re.sub(r'(\.{3,}|[.!:?])', lambda m: m.group() + '<br>', text)
    return text
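
# Main handler: load the audio, truncate it to 30 minutes if necessary, run
# the ASR pipeline and return (warning_html, transcript_html, txt_file_path).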
def transcribe(file, return_timestamps=False, lang_nn=False):
    waveform, sample_rate = torchaudio.load(file)
    audio_duration = waveform.size(1) / sample_rate

    warning_message = None
    if audio_duration > max_audio_length:
        # Warning (Norwegian): the file is longer than 30 minutes; only the
        # first 30 minutes will be transcribed.
        warning_message = (
            "<b style='color:red;'>⚠️ Advarsel:</b> "
            "Lydfilen er lengre enn 30 minutter. Kun de første 30 minuttene vil bli transkribert."
        )
        waveform = waveform[:, :int(max_audio_length * sample_rate)]
        truncated_file = "truncated_audio.wav"
        torchaudio.save(truncated_file, waveform, sample_rate)
        file_to_transcribe = truncated_file
        truncated = True
    else:
        file_to_transcribe = file
        truncated = False
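
    # Pick the decoding language: Nynorsk when the checkbox is set, otherwise
    # the pipeline default, Bokmål ("no").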
    lang = "nn" if lang_nn else "no"
    if not return_timestamps:
        text = pipe(file_to_transcribe, lang=lang)["text"]
        formatted_text = format_output(text)
    else:
        chunks = pipe(file_to_transcribe, return_timestamps=True, lang=lang)["chunks"]
        lines = []
        for chunk in chunks:
            start, end = chunk["timestamp"]
            start_time = time.strftime('%H:%M:%S', time.gmtime(start)) if start is not None else "??:??:??"
            end_time = time.strftime('%H:%M:%S', time.gmtime(end)) if end is not None else "??:??:??"
            lines.append(f"[{start_time} -> {end_time}] {chunk['text']}")
        formatted_text = "<br>".join(lines)
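
    # Write a plain-text copy for the download button, turning the HTML <br>
    # breaks back into newlines.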
    output_file = "transcription.txt"
    with open(output_file, "w") as f:
        f.write(formatted_text.replace("<br>", "\n"))

    if truncated:
        # Disclaimer (Norwegian): this is a demo and the text may not be used
        # in a professional setting; for longer recordings, run the model
        # locally or see the linked page for vendors of NB-Whisper solutions.
        link = "https://github.com/NbAiLab/nostram/blob/main/leverandorer.md"
        disclaimer = (
            "\n\n Dette er en demo. Det er ikke tillatt å bruke denne teksten i profesjonell sammenheng. "
            "Vi anbefaler at hvis du trenger å transkribere lengre opptak, så kjører du enten modellen lokalt "
            "eller sjekker hvem som leverer løsninger basert på NB-Whisper på "
            f"<a href='{link}' target='_blank'>denne siden</a>."
        )
        formatted_text += f"<br><br><i>{disclaimer}</i>"

    formatted_text += "<br><br><i>Transkribert med NB-Whisper demo</i>"
    return warning_message, formatted_text, output_file
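
# Build a small HTML embed for a YouTube URL (naively assumes a standard
# watch?v=<id> link).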
def _return_yt_html_embed(yt_url):
    video_id = yt_url.split("?v=")[-1]
    html_str = (
        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
        " </center>"
    )
    return html_str
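
# Download the audio track of a YouTube video with yt-dlp, convert it to mp3
# via the FFmpegExtractAudio post-processor, and reuse transcribe() on the
# result. This helper is not wired into the interface below.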
def yt_transcribe(yt_url, return_timestamps=False):
    html_embed_str = _return_yt_html_embed(yt_url)

    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': 'audio.%(ext)s',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
        'quiet': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([yt_url])

    # transcribe() returns (warning, transcript, file_path); only the
    # transcript is surfaced here.
    _, formatted_text, _ = transcribe("audio.mp3", return_timestamps=return_timestamps)
    return html_embed_str, formatted_text
# Build the Gradio app (single view, no tabs).
demo = gr.Blocks(theme=gr.themes.Default(primary_hue=gr.themes.colors.green, secondary_hue=gr.themes.colors.red))

with demo:
    with gr.Column():
        gr.HTML("<img src='file/Logonew.png' style='width:190px;'>")
    with gr.Column(scale=8):
        # Title rendered via Markdown.
        gr.Markdown(
            """
            <h1 style="font-size: 3.5em;">NB-Whisper Demo</h1>
            """
        )
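
    # Single gr.Interface: audio upload/microphone input, two checkboxes, and
    # three outputs (warning banner, transcript HTML, downloadable file).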
    mf_transcribe = gr.Interface(
        fn=transcribe,
        inputs=[
            gr.components.Audio(sources=['upload', 'microphone'], type="filepath"),
            gr.components.Checkbox(label="Inkluder tidskoder"),
            gr.components.Checkbox(label="Nynorsk"),
        ],
        outputs=[
            gr.HTML(label="Varsel"),
            gr.HTML(label="Tekst"),
            gr.File(label="Last ned transkripsjon"),
        ],
        description=(
            "Demoen bruker"
            f" modellen [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) til å transkribere lydfiler opp til 30 minutter."
        ),
        allow_flagging="never",
    )
# Launch the demo; queuing must be enabled before launch(), not chained after
# it, for the queued GPU handler to take effect.
demo.queue().launch(share=share, show_api=False, allowed_paths=["Logonew.png"])