dragonSwing committed
Commit 90420f4 · 1 parent: 7bfa718

Update app.py

Files changed (1):
  1. app.py +22 -8
app.py CHANGED
```diff
@@ -6,16 +6,18 @@ from speechbrain.pretrained import EncoderASR
 from transformers.file_utils import cached_path, hf_bucket_url
 
 cache_dir = './cache/'
-lm_file = hf_bucket_url("dragonSwing/wav2vec2-base-vn-270h", filename='4gram.zip')
+lm_file = hf_bucket_url(
+    "dragonSwing/wav2vec2-base-vn-270h", filename='4gram.zip')
 lm_file = cached_path(lm_file, cache_dir=cache_dir)
 with zipfile.ZipFile(lm_file, 'r') as zip_ref:
     zip_ref.extractall(cache_dir)
 lm_file = cache_dir + 'lm.binary'
 vocab_file = cache_dir + 'vocab-260000.txt'
-model = EncoderASR.from_hparams(source="dragonSwing/wav2vec2-base-vn-270h",
+model = EncoderASR.from_hparams(source="dragonSwing/wav2vec2-base-vn-270h",
                                 savedir="/content/pretrained2/"
                                 )
 
+
 def get_decoder_ngram_model(tokenizer, ngram_lm_path, vocab_path=None):
     unigrams = None
     if vocab_path is not None:
```
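The hunk above wraps the `hf_bucket_url` call at the 79-character PEP 8 limit; the removed/added `model = EncoderASR.from_hparams(...)` pair differs only in whitespace this page does not show, consistent with an automatic formatter pass. One caveat when reusing this code: `hf_bucket_url` and `cached_path` were later removed from `transformers.file_utils`. A minimal sketch of the same download on current libraries, assuming `huggingface_hub` is installed (the `lm_zip` name is ours, not part of this commit):

```python
# Sketch only, not part of this commit: fetch and unpack the LM archive
# with huggingface_hub instead of the removed transformers helpers.
import zipfile

from huggingface_hub import hf_hub_download

cache_dir = './cache/'
lm_zip = hf_hub_download(repo_id="dragonSwing/wav2vec2-base-vn-270h",
                         filename="4gram.zip", cache_dir=cache_dir)
with zipfile.ZipFile(lm_zip, 'r') as zip_ref:
    zip_ref.extractall(cache_dir)
```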
```diff
@@ -37,19 +39,23 @@ def get_decoder_ngram_model(tokenizer, ngram_lm_path, vocab_path=None):
     decoder = build_ctcdecoder(vocab_list, ngram_lm_path, unigrams=unigrams)
     return decoder
 
+
 ngram_lm_model = get_decoder_ngram_model(model.tokenizer, lm_file, vocab_file)
 
+
 def transcribe_file(path, max_seconds=20):
     waveform = model.load_audio(path)
     if max_seconds > 0:
-        waveform = waveform[:max_seconds*16000]
+        waveform = waveform[:max_seconds*16000]
     batch = waveform.unsqueeze(0)
     rel_length = torch.tensor([1.0])
     with torch.no_grad():
         logits = model(batch, rel_length)
-    text_batch = [ngram_lm_model.decode(logit.detach().cpu().numpy(), beam_width=500) for logit in logits]
+    text_batch = [ngram_lm_model.decode(
+        logit.detach().cpu().numpy(), beam_width=500) for logit in logits]
     return text_batch[0]
 
+
 def speech_recognize(file_upload, file_mic):
     if file_upload is not None:
         file = file_upload
```
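This hunk adds the two blank lines PEP 8 requires around top-level definitions and wraps the beam-search decode; the `waveform[:max_seconds*16000]` pair is again a whitespace-only change. The slice assumes 16 kHz audio, so the default keeps 20 s × 16 000 = 320 000 samples. The diff elides the middle of `get_decoder_ngram_model` (old lines 22–36); for orientation, a typical pyctcdecode setup looks roughly like the hedged reconstruction below, where `get_vocab()` assumes a Hugging Face-style tokenizer and the real file may differ:

```python
# Hedged reconstruction of the elided function body; not the author's code.
from pyctcdecode import build_ctcdecoder

def get_decoder_ngram_model(tokenizer, ngram_lm_path, vocab_path=None):
    unigrams = None
    if vocab_path is not None:
        # one candidate word per line in the vocabulary file
        with open(vocab_path, encoding='utf-8') as f:
            unigrams = [line.strip() for line in f if line.strip()]
    # order tokens by id so index i labels column i of the CTC logits
    vocab = tokenizer.get_vocab()  # assumption: HF-style tokenizer API
    vocab_list = [tok for tok, _ in sorted(vocab.items(), key=lambda kv: kv[1])]
    decoder = build_ctcdecoder(vocab_list, ngram_lm_path, unigrams=unigrams)
    return decoder
```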
```diff
@@ -61,10 +67,18 @@ def speech_recognize(file_upload, file_mic):
     text = transcribe_file(file)
     return text
 
-inputs = [gr.inputs.Audio(source="upload", type='filepath', optional=True), gr.inputs.Audio(source="microphone", type='filepath', optional=True)]
-outputs = gr.outputs.Textbox(label="Output Text")
+
+inputs = [gr.inputs.Audio(source="upload", type='filepath', optional=True), gr.inputs.Audio(
+    source="microphone", type='filepath', optional=True)]
+outputs = gr.outputs.Textbox(label="Output Text")
 title = "wav2vec2-base-vietnamese-270h"
 description = "Gradio demo for a wav2vec2 base vietnamese speech recognition. To use it, simply upload your audio, click one of the examples to load them, or record from your own microphone. Read more at the links below. Currently supports 16_000hz audio files"
 article = "<p style='text-align: center'><a href='https://huggingface.co/dragonSwing/wav2vec2-base-vn-270h' target='_blank'>Pretrained model</a></p>"
-examples=[['example1.wav', 'example1.wav'], ['example2.mp3', 'example2.mp3'], ['example3.mp3', 'example3.mp3'], ['example4.wav', 'example4.wav']]
-gr.Interface(speech_recognize, inputs, outputs, title=title, description=description, article=article, examples=examples).launch()
+examples = [
+    ['example1.wav', 'example1.wav'],
+    ['example2.mp3', 'example2.mp3'],
+    ['example3.mp3', 'example3.mp3'],
+    ['example4.wav', 'example4.wav'],
+]
+gr.Interface(speech_recognize, inputs, outputs, title=title,
+             description=description, article=article, examples=examples).launch()
```
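The final hunk splits the `examples` list one entry per line and wraps the `gr.Interface` call. Note that the `gr.inputs`/`gr.outputs` namespaces and the `source=`/`optional=` keyword arguments belong to the Gradio 2.x API and no longer exist in later major versions. A rough modern equivalent, as a sketch assuming Gradio 4 (where the keyword is spelled `sources=`), not part of this commit:

```python
# Sketch of the same wiring on a current Gradio release; not this commit.
import gradio as gr

inputs = [
    gr.Audio(sources=["upload"], type="filepath", label="Upload"),
    gr.Audio(sources=["microphone"], type="filepath", label="Microphone"),
]
outputs = gr.Textbox(label="Output Text")
```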