# Hugging Face Hub file-viewer header (page residue, not part of the program):
# author: dwb2023 | commit 751197e (verified) | message: "Update app.py" | size: 6.32 kB
import os
import json
import time
from datetime import datetime
from pathlib import Path
import tempfile
import pandas as pd
import gradio as gr
import yt_dlp as youtube_dl
from transformers import (
BitsAndBytesConfig,
AutoModelForSpeechSeq2Seq,
AutoTokenizer,
AutoFeatureExtractor,
pipeline,
)
from transformers.pipelines.audio_utils import ffmpeg_read
import torch # If you're using PyTorch
from datasets import load_dataset, Dataset, DatasetDict
import spaces
# Constants
# Checkpoint identifier passed to the model, tokenizer and feature-extractor loaders.
MODEL_NAME = "openai/whisper-large-v3"
# batch_size handed to the speech-recognition pipeline on each transcription call.
BATCH_SIZE = 8
# Longest accepted video, in seconds; longer videos are rejected before download.
YT_LENGTH_LIMIT_S = 4800 # 1 hour 20 minutes
# Hub dataset repo that stores finished transcriptions and is checked before transcribing.
DATASET_NAME = "dwb2023/yt-transcripts-v3"
# Environment setup
# NOTE(review): this assignment runs after the Hugging Face libraries have been
# imported above. If huggingface_hub reads this flag at import time the setting
# may have no effect -- confirm, and move it above the imports if so.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Model setup
# Quantization settings: 4-bit NF4 weights with double quantization, with
# bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the speech-to-text checkpoint with the quantization settings above and
# let device placement be chosen automatically.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
    use_cache=False,
)

# Text tokenizer and audio feature extractor come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)

# Speech-recognition pipeline shared by every request; audio is processed in
# 30-second chunks.
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    chunk_length_s=30,
)
def reset_and_update_dataset(new_data):
    """Push a fresh one-row dataset built from ``new_data`` to the Hub.

    An empty, typed frame fixes the column layout; ``new_data`` (a mapping of
    column name to value) is appended as its only row, and the result is
    pushed to ``DATASET_NAME`` as the ``train`` split.
    """
    # Column name -> pandas dtype of the empty template frame.
    column_dtypes = {
        "url": "str",
        "transcription": "str",
        "title": "str",
        "duration": "int",
        "uploader": "str",
        "upload_date": "datetime64[ns]",
        "description": "str",
        "datetime": "datetime64[ns]",
    }
    template = pd.DataFrame(
        {column: pd.Series(dtype=dtype) for column, dtype in column_dtypes.items()}
    )

    # The template contributes no rows, so the result holds only the new record.
    frame = pd.concat([template, pd.DataFrame([new_data])], ignore_index=True)

    DatasetDict({"train": Dataset.from_pandas(frame)}).push_to_hub(DATASET_NAME)
    print("Dataset reset and updated successfully!")
def download_yt_audio(yt_url, filename):
    """Download the best available audio stream of a YouTube video.

    Metadata is fetched first so that over-long videos are rejected before
    any media is downloaded.

    Args:
        yt_url: URL of the video to fetch.
        filename: Output path template handed to yt-dlp as ``outtmpl``.

    Returns:
        The metadata dict returned by ``extract_info``.

    Raises:
        gr.Error: If metadata extraction or the download fails, if the
            metadata carries no duration, or if the video is longer than
            ``YT_LENGTH_LIMIT_S`` seconds.
    """
    # Use a context manager so the metadata-only client is closed when done.
    with youtube_dl.YoutubeDL() as info_loader:
        try:
            info = info_loader.extract_info(yt_url, download=False)
        except youtube_dl.utils.DownloadError as err:
            raise gr.Error(str(err)) from err

    # Without a duration the length limit cannot be enforced; report that to
    # the user instead of failing with a bare KeyError/TypeError.
    file_length = info.get("duration")
    if file_length is None:
        raise gr.Error("Could not determine the length of the YouTube video.")

    if file_length > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%H:%M:%S", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%H:%M:%S", time.gmtime(file_length))
        raise gr.Error(
            f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video."
        )

    ydl_opts = {"outtmpl": filename, "format": "bestaudio/best"}
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except youtube_dl.utils.DownloadError as err:
            # Surface download failures the same way as metadata failures.
            raise gr.Error(str(err)) from err

    return info
@spaces.GPU(duration=120)
def yt_transcribe(yt_url, task):
    """Return the transcription of a YouTube video, reusing stored results.

    The Hub dataset is checked first; on a miss the audio is downloaded,
    decoded, run through the speech-recognition pipeline, and the result is
    stored together with the video's metadata.

    Args:
        yt_url: URL of the video.
        task: Value forwarded to generation as ``task`` (the UI offers
            ``"transcribe"`` and ``"translate"``).

    Returns:
        The transcription text.

    Raises:
        gr.Error: Propagated from ``download_yt_audio`` when the video cannot
            be fetched or exceeds the length limit.
    """
    # NOTE: stored results are matched on URL only, so a result saved for one
    # task is also returned when the other task is requested for the same URL.
    dataset = load_dataset(DATASET_NAME, split="train")
    for row in dataset:
        if row["url"] == yt_url:
            return row["transcription"]

    # Only the raw bytes are needed after the download, so the temporary
    # directory can be released before decoding starts.
    with tempfile.TemporaryDirectory() as tmpdirname:
        filepath = os.path.join(tmpdirname, "video.mp4")
        info = download_yt_audio(yt_url, filepath)
        with open(filepath, "rb") as f:
            video_data = f.read()

    sampling_rate = pipe.feature_extractor.sampling_rate
    inputs = {
        "array": ffmpeg_read(video_data, sampling_rate),
        "sampling_rate": sampling_rate,
    }

    text = pipe(
        inputs,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )["text"]

    # dict.get never raises KeyError, so no exception handling is needed here;
    # absent keys fall back to the defaults below.
    title = info.get("title", "N/A")
    duration = info.get("duration", 0)
    uploader = info.get("uploader", "N/A")
    upload_date = info.get("upload_date", "N/A")
    description = info.get("description", "N/A")

    save_transcription(yt_url, text, title, duration, uploader, upload_date, description)
    return text
def save_transcription(yt_url, transcription, title, duration, uploader, upload_date, description):
    """Append one transcription record to the Hub dataset.

    The current ``train`` split of ``DATASET_NAME`` is loaded, the new record
    (stamped with the current naive local time in ISO-8601 form) is added as
    the last row, and the whole split is pushed back to the Hub.
    """
    # Build the record, including its timestamp, before touching the Hub.
    record = {
        "url": yt_url,
        "transcription": transcription,
        "title": title,
        "duration": duration,
        "uploader": uploader,
        "upload_date": upload_date,
        "description": description,
        "datetime": datetime.now().isoformat(),
    }

    existing_rows = load_dataset(DATASET_NAME, split="train").to_pandas()
    combined = pd.concat([existing_rows, pd.DataFrame([record])], ignore_index=True)

    # Push the extended table back as the "train" split.
    DatasetDict({"train": Dataset.from_pandas(combined)}).push_to_hub(DATASET_NAME)
# Gradio UI: one tab wrapping the YouTube transcription interface.
demo = gr.Blocks()

# Form with a URL box and a task selector; the result is shown as plain text.
yt_transcribe_interface = gr.Interface(
    fn=yt_transcribe,
    inputs=[
        gr.Textbox(
            label="YouTube URL",
            placeholder="Paste the URL to a YouTube video here",
            lines=1,
        ),
        gr.Radio(["transcribe", "translate"], value="transcribe", label="Task"),
    ],
    outputs="text",
    title="Whisper Large V3: Transcribe YouTube",
    description=(
        "Transcribe long-form YouTube videos with the click of a button! Demo uses the checkpoint"
        f" [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe video files of"
        " arbitrary length."
    ),
    allow_flagging="never",
)

# Mount the interface inside the Blocks container as the single "YouTube" tab.
with demo:
    gr.TabbedInterface([yt_transcribe_interface], ["YouTube"])

# Enable the request queue, then start serving.
demo.queue().launch()