RasmusToivanen commited on
Commit
af31d45
1 Parent(s): ab09d2c

add article, change to gradio 3, remove 300m model

Browse files
Files changed (1) hide show
  1. app.py +19 -13
app.py CHANGED
@@ -19,18 +19,16 @@ from transformers import pipeline
19
 
20
 
21
 
22
- pipe_300m = pipeline(model="Finnish-NLP/wav2vec2-xlsr-300m-finnish-lm",chunk_length_s=20, stride_length_s=(3, 3))
23
- pipe_94m = pipeline(model="Finnish-NLP/wav2vec2-base-fi-voxpopuli-v2-finetuned",chunk_length_s=20, stride_length_s=(3, 3))
24
  pipe_1b = pipeline(model="Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2",chunk_length_s=20, stride_length_s=(3, 3))
25
 
26
 
27
 
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- model_checkpoint = 'Finnish-NLP/t5x-small-nl24-casing-punctuation-correction'
30
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_auth_token=os.environ.get('hf_token'))
31
  model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, from_flax=False, torch_dtype=torch.float32, use_auth_token=os.environ.get('hf_token')).to(device)
32
 
33
-
34
  # define speech-to-text function
35
  def asr_transcript(audio, audio_microphone, model_params):
36
 
@@ -38,16 +36,14 @@ def asr_transcript(audio, audio_microphone, model_params):
38
  audio = audio_microphone if audio_microphone else audio
39
 
40
  if audio == None and audio_microphone == None:
41
- return "Please provide audio by uploading file or by recording audio with microphone by pressing Record (And allow usage of microphone)", "Please provide audio by uploading file or by recording audio with microphone by pressing Record (And allow usage of microphone)"
42
  text = ""
43
 
44
  if audio:
45
- if model_params == "1 billion multi":
46
  text = pipe_1b(audio.name)
47
- elif model_params == "94 million fi":
48
- text = pipe_94m(audio.name)
49
- elif model_params == "300 million multi":
50
- text = pipe_300m(audio.name)
51
 
52
  input_ids = tokenizer(text['text'], return_tensors="pt").input_ids.to(device)
53
  outputs = model.generate(input_ids, max_length=128)
@@ -58,9 +54,19 @@ def asr_transcript(audio, audio_microphone, model_params):
58
 
59
  gradio_ui = gr.Interface(
60
  fn=asr_transcript,
61
- title="Finnish automatic speech recognition",
62
- description="Upload an audio clip, and let AI do the hard work of transcribing",
63
- inputs=[gr.inputs.Audio(label="Upload Audio File", type="file", optional=True), gr.inputs.Audio(source="microphone", type="file", optional=True, label="Record from microphone"), gr.inputs.Dropdown(choices=["94 million fi", "300 million multi", "1 billion multi"], type="value", default="1 billion multi", label="Select speech recognition model parameter amount", optional=False)],
 
 
 
 
 
 
 
 
 
 
64
  outputs=[gr.outputs.Textbox(label="Recognized speech"),gr.outputs.Textbox(label="Recognized speech with case correction and punctuation")]
65
  )
66
 
 
19
 
20
 
21
 
22
+ pipe_95m = pipeline(model="Finnish-NLP/wav2vec2-base-fi-voxpopuli-v2-finetuned",chunk_length_s=20, stride_length_s=(3, 3))
 
23
  pipe_1b = pipeline(model="Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2",chunk_length_s=20, stride_length_s=(3, 3))
24
 
25
 
26
 
27
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ model_checkpoint = 'Finnish-NLP/t5-small-nl24-casing-punctuation-correction'
29
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_auth_token=os.environ.get('hf_token'))
30
  model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, from_flax=False, torch_dtype=torch.float32, use_auth_token=os.environ.get('hf_token')).to(device)
31
 
 
32
  # define speech-to-text function
33
  def asr_transcript(audio, audio_microphone, model_params):
34
 
 
36
  audio = audio_microphone if audio_microphone else audio
37
 
38
  if audio == None and audio_microphone == None:
39
+ return "Please provide audio by uploading a file or by recording audio using microphone by pressing Record (And allow usage of microphone)", "Please provide audio by uploading a file or by recording audio using microphone by pressing Record (And allow usage of microphone)"
40
  text = ""
41
 
42
  if audio:
43
+ if model_params == "1 billion":
44
  text = pipe_1b(audio.name)
45
+ elif model_params == "95 million":
46
+ text = pipe_95m(audio.name)
 
 
47
 
48
  input_ids = tokenizer(text['text'], return_tensors="pt").input_ids.to(device)
49
  outputs = model.generate(input_ids, max_length=128)
 
54
 
55
  gradio_ui = gr.Interface(
56
  fn=asr_transcript,
57
+ title="Finnish Automatic Speech Recognition",
58
+ description="Upload an audio clip or record from browser using microphone, and let AI do the hard work of transcribing.",
59
+ article = """
60
+ This demo includes 2 kinds of models that are run together. First selected ASR model does speech recognition which produces lowercase text without punctuation.
61
+ After that we run a sequence-to-sequence model which tries to correct casing and punctuation which produces the final output.
62
+ You can select one of two speech recognition models listed below
63
+
64
+ 1. 1 billion, best accuracy but slowest by big margin. Based on multilingual wav2vec2-xlsr model by Meta. More info here https://huggingface.co/Finnish-NLP/wav2vec2-xlsr-1b-finnish-lm-v2
65
+ 2. 95 million, almost as accurate as 1. but really much faster. Based on finnish wav2vec2-xlsr model by Meta. More info here https://huggingface.co/Finnish-NLP/wav2vec2-base-fi-voxpopuli-v2-finetuned
66
+
67
+ More info about the casing+punctuation correction model can be found here https://huggingface.co/Finnish-NLP/t5-small-nl24-casing-punctuation-correction
68
+ """,
69
+ inputs=[gr.inputs.Audio(label="Upload Audio File", type="file", optional=True), gr.inputs.Audio(source="microphone", type="file", optional=True, label="Record from microphone"), gr.inputs.Dropdown(choices=["95 million","1 billion"], type="value", default="1 billion", label="Select speech recognition model parameter amount", optional=False)],
70
  outputs=[gr.outputs.Textbox(label="Recognized speech"),gr.outputs.Textbox(label="Recognized speech with case correction and punctuation")]
71
  )
72