Walid-Ahmed commited on
Commit
aa0a69d
·
verified ·
1 Parent(s): e3aee22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -17
app.py CHANGED
@@ -1,39 +1,35 @@
1
 
2
 
3
- import whisper
4
  import gradio as gr
 
5
  from accelerate import init_empty_weights, load_checkpoint_and_dispatch
6
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
-
8
- # Initialize the device map for ZeRO
9
- from accelerate.utils import set_module_tensor_to_device
10
  import torch
11
 
12
- device_map = "auto" # Automatically allocate layers across available GPUs/CPUs
 
13
  print(f"Using ZeRO-powered device map: {device_map}")
14
 
15
- # Load the model using ZeRO
16
- model_name = "openai/whisper-tiny"
17
 
18
- # Load the Whisper model into ZeRO's memory-efficient mode
19
  with init_empty_weights():
20
- whisper_model = whisper.load_model(model_name)
21
 
22
- # Load tokenizer
23
- tokenizer = AutoTokenizer.from_pretrained(model_name)
24
-
25
- # Load model with Accelerate/ZeRO
26
  whisper_model = load_checkpoint_and_dispatch(
27
  whisper_model,
28
  device_map=device_map,
29
- dtype=torch.float16 # Optional: Use mixed precision for further optimization
30
  )
31
 
 
 
32
  # Define the transcription function
33
  def transcribe(audio):
34
  # Perform transcription using the Whisper model
35
  result = whisper_model.transcribe(audio)
36
- return result['text']
37
 
38
  # Create the Gradio interface
39
  demo = gr.Interface(
@@ -41,7 +37,7 @@ demo = gr.Interface(
41
  inputs=gr.Audio(source="microphone", type="filepath", label="Speak into the microphone"), # Input audio
42
  outputs=gr.Textbox(label="Transcription"), # Output transcription
43
  title="Whisper Speech-to-Text with ZeRO", # Title of the interface
44
- description="Record audio using your microphone and get a transcription using the Whisper model optimized by ZeRO."
45
  )
46
 
47
  # Launch the Gradio interface
 
1
 
2
 
 
3
  import gradio as gr
4
+ import whisper
5
  from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 
 
 
6
  import torch
7
 
8
+ # Check if GPU is available and set up device map
9
+ device_map = "auto" # Automatically balance layers across available devices
10
  print(f"Using ZeRO-powered device map: {device_map}")
11
 
12
+ # Load the Whisper model using Accelerate with ZeRO
13
+ model_name = "tiny" # Change to "base", "small", etc., as needed
14
 
15
+ print(f"Loading the Whisper model: {model_name} with ZeRO optimization...")
16
  with init_empty_weights():
17
+ whisper_model = whisper.load_model(model_name) # Load model structure without weights
18
 
19
+ # Dispatch the model across devices using ZeRO
 
 
 
20
  whisper_model = load_checkpoint_and_dispatch(
21
  whisper_model,
22
  device_map=device_map,
23
+ dtype=torch.float16 # Use mixed precision for efficiency
24
  )
25
 
26
+ print("Model successfully loaded with ZeRO optimization!")
27
+
28
  # Define the transcription function
29
  def transcribe(audio):
30
  # Perform transcription using the Whisper model
31
  result = whisper_model.transcribe(audio)
32
+ return result["text"]
33
 
34
  # Create the Gradio interface
35
  demo = gr.Interface(
 
37
  inputs=gr.Audio(source="microphone", type="filepath", label="Speak into the microphone"), # Input audio
38
  outputs=gr.Textbox(label="Transcription"), # Output transcription
39
  title="Whisper Speech-to-Text with ZeRO", # Title of the interface
40
+ description="Record audio using your microphone and get a transcription using the Whisper model optimized with ZeRO."
41
  )
42
 
43
  # Launch the Gradio interface