Pavithiran commited on
Commit
1c29ca1
·
verified ·
1 Parent(s): ee23398

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ import pytube
5
+ import tempfile
6
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
7
+ from pytube import YouTube
8
+
9
+ # Set environment variables for Hugging Face Space
10
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
11
+
12
+ # Model configuration
13
+ MODEL_NAME = "openai/whisper-large-v3-turbo"
14
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
15
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
16
+
17
+ # Setup model and processor
18
+ def load_model():
19
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
20
+ MODEL_NAME,
21
+ torch_dtype=torch_dtype,
22
+ low_cpu_mem_usage=True,
23
+ use_safetensors=True
24
+ )
25
+ model.to(device)
26
+
27
+ processor = AutoProcessor.from_pretrained(MODEL_NAME)
28
+
29
+ pipe = pipeline(
30
+ "automatic-speech-recognition",
31
+ model=model,
32
+ tokenizer=processor.tokenizer,
33
+ feature_extractor=processor.feature_extractor,
34
+ torch_dtype=torch_dtype,
35
+ device=device,
36
+ )
37
+
38
+ return pipe
39
+
40
+ # Load model globally - will be cached
41
+ pipe = load_model()
42
+
43
+ # Transcription function for audio files and microphone
44
+ def transcribe(audio_path, task="transcribe"):
45
+ if audio_path is None:
46
+ return "Please provide an audio input."
47
+
48
+ # Set task-specific generation parameters
49
+ generate_kwargs = {}
50
+ if task == "translate":
51
+ generate_kwargs["task"] = "translate"
52
+
53
+ # Process the audio
54
+ try:
55
+ result = pipe(audio_path, generate_kwargs=generate_kwargs)
56
+ return result["text"]
57
+ except Exception as e:
58
+ return f"Error during transcription: {str(e)}"
59
+
60
+ # YouTube video transcription function
61
+ def yt_transcribe(youtube_url, task="transcribe"):
62
+ if not youtube_url or not youtube_url.strip():
63
+ return "Please enter a YouTube URL", "No transcription available."
64
+
65
+ try:
66
+ # Download audio from YouTube
67
+ yt = YouTube(youtube_url)
68
+ video = yt.streams.filter(only_audio=True).first()
69
+
70
+ # Create a temporary file to store the audio
71
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
72
+ temp_path = temp_file.name
73
+
74
+ # Download the audio to the temporary file
75
+ video.download(filename=temp_path)
76
+
77
+ # Set task-specific generation parameters
78
+ generate_kwargs = {}
79
+ if task == "translate":
80
+ generate_kwargs["task"] = "translate"
81
+
82
+ # Process the audio
83
+ result = pipe(temp_path, generate_kwargs=generate_kwargs)
84
+
85
+ # Create an HTML element to display video thumbnail
86
+ video_id = youtube_url.split("v=")[-1].split("&")[0]
87
+ thumbnail_url = f"https://img.youtube.com/vi/{video_id}/0.jpg"
88
+ html_output = f'<div style="display: flex; align-items: center;"><img src="{thumbnail_url}" style="max-width: 200px; margin-right: 20px;"><div><h3>{yt.title}</h3><p>Channel: {yt.author}</p></div></div>'
89
+
90
+ # Clean up the temporary file
91
+ os.unlink(temp_path)
92
+
93
+ return html_output, result["text"]
94
+ except Exception as e:
95
+ return f"Error processing YouTube video: {str(e)}", "Transcription failed."
96
+
97
+ # Create Gradio interfaces
98
+ mic_transcribe = gr.Interface(
99
+ fn=transcribe,
100
+ inputs=[
101
+ gr.Audio(source="microphone", type="filepath", optional=True),
102
+ gr.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
103
+ ],
104
+ outputs="text",
105
+ layout="horizontal",
106
+ theme="huggingface",
107
+ title="Whisper Large V3 Turbo: Transcribe Audio",
108
+ description=(
109
+ "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the OpenAI Whisper"
110
+ f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
111
+ " of arbitrary length."
112
+ ),
113
+ allow_flagging="never",
114
+ )
115
+
116
+ file_transcribe = gr.Interface(
117
+ fn=transcribe,
118
+ inputs=[
119
+ gr.Audio(source="upload", type="filepath", optional=True, label="Audio file"),
120
+ gr.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
121
+ ],
122
+ outputs="text",
123
+ layout="horizontal",
124
+ theme="huggingface",
125
+ title="Whisper Large V3 Turbo: Transcribe Audio",
126
+ description=(
127
+ "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the OpenAI Whisper"
128
+ f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
129
+ " of arbitrary length."
130
+ ),
131
+ allow_flagging="never",
132
+ )
133
+
134
+ yt_interface = gr.Interface(
135
+ fn=yt_transcribe,
136
+ inputs=[
137
+ gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
138
+ gr.Radio(["transcribe", "translate"], label="Task", default="transcribe")
139
+ ],
140
+ outputs=["html", "text"],
141
+ layout="horizontal",
142
+ theme="huggingface",
143
+ title="Whisper Large V3 Turbo: Transcribe YouTube",
144
+ description=(
145
+ "Transcribe long-form YouTube videos with the click of a button! Demo uses the OpenAI Whisper checkpoint"
146
+ f" [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe video files of"
147
+ " arbitrary length."
148
+ ),
149
+ allow_flagging="never",
150
+ )
151
+
152
+ # Create the tabbed interface
153
+ demo = gr.Blocks()
154
+ with demo:
155
+ gr.TabbedInterface([mic_transcribe, file_transcribe, yt_interface], ["Microphone", "Audio file", "YouTube"])
156
+
157
+ # Launch the app
158
+ if __name__ == "__main__":
159
+ demo.launch(enable_queue=True)