alakxender committed
Commit 6430b7c · 1 Parent(s): 0c6a355
Files changed (2):
  1. app.py +86 -4
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,7 +1,89 @@
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
+import torch
+import torchaudio
+import numpy as np
+
+# Device and dtype configuration
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+# Load model and processor with LM
+processor = Wav2Vec2ProcessorWithLM.from_pretrained("alakxender/wav2vec2-large-mms-1b-dv-syn-md")
+model = Wav2Vec2ForCTC.from_pretrained(
+    "alakxender/wav2vec2-large-mms-1b-dv-syn-md",
+    torch_dtype=torch_dtype
+).to(device)
+
+MAX_LENGTH = 120  # seconds (2 minutes)
+MIN_LENGTH = 1    # second
+
+def transcribe(audio_file):
+    try:
+        # Load the uploaded audio file
+        waveform, sample_rate = torchaudio.load(audio_file)
+
+        # Move waveform to the correct device
+        waveform = waveform.to(device)
+
+        # Get the duration of the audio in seconds
+        duration = waveform.shape[1] / sample_rate
+
+        # Reject clips that are too short or too long
+        if duration < MIN_LENGTH or duration > MAX_LENGTH:
+            return f"Audio duration is too short or too long. Duration: {duration} seconds"
+
+        # Resample to the model's 16 kHz input rate if necessary
+        if sample_rate != 16000:
+            resampler = torchaudio.transforms.Resample(sample_rate, 16000).to(device)
+            waveform = resampler(waveform)
+
+        # Convert to mono if stereo
+        if waveform.shape[0] > 1:
+            waveform = waveform.mean(dim=0, keepdim=True)
+
+        # Move to CPU for numpy conversion
+        waveform = waveform.cpu()
+        audio_input = waveform.squeeze().numpy()
+
+        # Ensure audio input is float32
+        if audio_input.dtype != np.float32:
+            audio_input = audio_input.astype(np.float32)
+
+        # Process audio input
+        input_values = processor(
+            audio_input,
+            sampling_rate=16_000,
+            return_tensors="pt"
+        ).input_values.to(device)
+
+        # Convert to float16 if using CUDA
+        if torch_dtype == torch.float16:
+            input_values = input_values.half()
+
+        # Generate transcription
+        with torch.no_grad():
+            logits = model(input_values).logits
+
+        # Decode with the language model; cast back to float32 for pyctcdecode
+        transcription = processor.decode(logits[0].cpu().float().numpy())
+
+        # Log and return the decoded text in lowercase
+        print(transcription)
+        return transcription.text.lower()
+
+    except Exception as e:
+        return f"Error during transcription: {str(e)}"
+
+# Create Gradio interface
+iface = gr.Interface(
+    fn=transcribe,
+    inputs=gr.Audio(type="filepath"),
+    outputs="text",
+    title="Dhivehi Speech Recognition with Language Model",
+    description="Upload an audio file to transcribe Dhivehi speech to text using language-model-enhanced decoding."
+)
+
+# Launch the interface
+if __name__ == "__main__":
+    iface.launch()
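For a quick local sanity check of the new transcribe() function, a minimal smoke test along these lines should work (a hypothetical sketch, not part of this commit; the file name sample.wav and the synthetic tone are assumptions, and importing app downloads the model on first run):

    import torch
    import torchaudio
    from app import transcribe  # loads the model and processor at import time

    # One second of a 440 Hz tone at 16 kHz -- just long enough to pass the
    # MIN_LENGTH check; the decoded text for a pure tone is meaningless.
    sr = 16000
    t = torch.linspace(0, 1, sr)
    tone = 0.1 * torch.sin(2 * torch.pi * 440 * t).unsqueeze(0)
    torchaudio.save("sample.wav", tone, sr)

    # Expect a lowercase string (or an error message; transcribe never raises).
    print(transcribe("sample.wav"))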
requirements.txt ADDED
@@ -0,0 +1,3 @@
+transformers
+torchaudio
+pyctcdecode
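Note: torch and gradio are not listed; presumably torchaudio pulls in a compatible torch as a dependency, and the Gradio runtime is already provided by the Space's sdk setting.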