hackergeek98 commited on
Commit
bdf330c
·
verified ·
1 Parent(s): c558d1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -33
app.py CHANGED
@@ -1,44 +1,68 @@
1
- import gradio as gr
2
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
  import torch
4
- import librosa
5
-
6
- # Load the fine-tuned Whisper model and processor
7
- model_name = "hackergeek98/tinyyyy_whisper"
8
- processor = WhisperProcessor.from_pretrained(model_name)
9
- model = WhisperForConditionalGeneration.from_pretrained(model_name)
10
-
11
- # Force the model to transcribe in Persian
12
- model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="fa", task="transcribe")
13
 
14
- # Move model to GPU if available
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
- model.to(device)
17
 
18
- # Define the ASR function
19
- def transcribe_audio(audio_file):
20
- # Load audio file using librosa (supports multiple formats)
21
- audio_data, sampling_rate = librosa.load(audio_file, sr=16000) # Resample to 16kHz
22
 
23
- # Preprocess the audio
24
- inputs = processor(audio_data, sampling_rate=sampling_rate, return_tensors="pt").input_features.to(device)
 
 
 
 
 
 
 
 
 
25
 
26
- # Generate transcription
27
- with torch.no_grad():
28
- predicted_ids = model.generate(inputs)
 
 
 
 
 
 
 
 
 
29
 
30
- # Decode the transcription
31
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
 
 
 
 
 
 
 
 
 
 
 
32
  return transcription
33
 
34
- # Create the Gradio interface
35
- interface = gr.Interface(
36
- fn=transcribe_audio, # Function to call
37
- inputs=gr.Audio(type="filepath"), # Input: Upload audio file (any format)
38
- outputs=gr.Textbox(label="Transcription"), # Output: Display transcription
39
- title="Whisper ASR: Persian Transcription",
40
- description="Upload an audio file (e.g., .wav, .mp3, .ogg), and the fine-tuned Whisper model will transcribe it in Persian.",
 
 
 
41
  )
42
 
43
- # Launch the app
44
- interface.launch()
 
1
+ # Install required packages
 
2
  import torch
3
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
4
+ from pydub import AudioSegment
5
+ import os
6
+ import gradio as gr
 
 
 
 
 
7
 
8
+ # Load the model and processor
9
+ model_id = "hackergeek98/tinyyyy_whisper"
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
11
 
12
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id).to(device)
13
+ processor = AutoProcessor.from_pretrained(model_id)
 
 
14
 
15
+ # Create pipeline
16
+ whisper_pipe = pipeline(
17
+ "automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, device=0 if torch.cuda.is_available() else -1
18
+ )
19
+
20
+ # Convert audio to WAV format
21
+ def convert_to_wav(audio_path):
22
+ audio = AudioSegment.from_file(audio_path)
23
+ wav_path = "converted_audio.wav"
24
+ audio.export(wav_path, format="wav")
25
+ return wav_path
26
 
27
+ # Split long audio into chunks
28
+ def split_audio(audio_path, chunk_length_ms=30000): # Default: 30 sec per chunk
29
+ audio = AudioSegment.from_wav(audio_path)
30
+ chunks = [audio[i:i+chunk_length_ms] for i in range(0, len(audio), chunk_length_ms)]
31
+ chunk_paths = []
32
+
33
+ for i, chunk in enumerate(chunks):
34
+ chunk_path = f"chunk_{i}.wav"
35
+ chunk.export(chunk_path, format="wav")
36
+ chunk_paths.append(chunk_path)
37
+
38
+ return chunk_paths
39
 
40
+ # Transcribe a long audio file
41
+ def transcribe_long_audio(audio_path):
42
+ wav_path = convert_to_wav(audio_path)
43
+ chunk_paths = split_audio(wav_path)
44
+ transcription = ""
45
+
46
+ for chunk in chunk_paths:
47
+ result = whisper_pipe(chunk)
48
+ transcription += result["text"] + "\n"
49
+ os.remove(chunk) # Remove processed chunk
50
+
51
+ os.remove(wav_path) # Cleanup original file
52
+
53
  return transcription
54
 
55
+ # Gradio interface
56
+ def transcribe_interface(audio_file):
57
+ return transcribe_long_audio(audio_file)
58
+
59
+ iface = gr.Interface(
60
+ fn=transcribe_interface,
61
+ inputs=gr.Audio(source="upload", type="filepath"),
62
+ outputs="text",
63
+ title="Whisper ASR - Transcription",
64
+ description="Upload an audio file, and the model will transcribe it."
65
  )
66
 
67
+ if __name__ == "__main__":
68
+ iface.launch()