Vaibhav Srivastav commited on
Commit
dc1ade6
Β·
1 Parent(s): 17a49e1

updating the model downloading

Browse files
Files changed (1) hide show
  1. app.py +12 -2
app.py CHANGED
@@ -7,8 +7,18 @@ from transformers import Wav2Vec2Processor, AutoModelForCTC
7
 
8
  nltk.download("punkt")
9
 
 
 
 
 
 
10
  def return_processor_and_model(model_name):
11
- return Wav2Vec2Processor.from_pretrained(model_name), AutoModelForCTC.from_pretrained(model_name)
 
 
 
 
 
12
 
13
  def load_and_fix_data(input_file):
14
  #read the file
@@ -62,6 +72,6 @@ gr.Interface(return_all_predictions,
62
  inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"], label="Model Name")],
63
  outputs = [gr.outputs.Textbox(label="Beam CTC decoding"), gr.outputs.Textbox(label="Greedy decoding")],
64
  title="ASR using Wav2Vec2/ Hubert & pyctcdecode",
65
- description = "Comparing Wav2Vec2 & Hubert with Greedy vs Beam Search decoding",
66
  layout = "horizontal",
67
  examples = [["test1.wav", "facebook/wav2vec2-base-960h"], ["test2.wav", "facebook/hubert-large-ls960-ft"]], theme="huggingface").launch()
 
7
 
8
  nltk.download("punkt")
9
 
10
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
11
+ wav2vec2_model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
12
+ hubert_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
13
+ hubert_model = AutoModelForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
14
+
15
  def return_processor_and_model(model_name):
16
+ if model_name == "facebook/wav2vec2-base-960h":
17
+ return wav2vec2_processor, wav2vec2_model
18
+ elif model_name == "facebook/hubert-large-ls960-ft":
19
+ return hubert_processor, hubert_model
20
+ else:
21
+ return None
22
 
23
  def load_and_fix_data(input_file):
24
  #read the file
 
72
  inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["facebook/wav2vec2-base-960h", "facebook/hubert-large-ls960-ft"], label="Model Name")],
73
  outputs = [gr.outputs.Textbox(label="Beam CTC decoding"), gr.outputs.Textbox(label="Greedy decoding")],
74
  title="ASR using Wav2Vec2/ Hubert & pyctcdecode",
75
+ description = "Comparing greedy decoder with beam search CTC decoder (https://distill.pub/2017/ctc/), record/ drop your audio!",
76
  layout = "horizontal",
77
  examples = [["test1.wav", "facebook/wav2vec2-base-960h"], ["test2.wav", "facebook/hubert-large-ls960-ft"]], theme="huggingface").launch()