AshDavid12 commited on
Commit
b3935fd
·
1 Parent(s): 460f073

trying ivrit model

Browse files
Files changed (2) hide show
  1. infer.py +27 -38
  2. requirements.txt +2 -0
infer.py CHANGED
@@ -1,44 +1,33 @@
1
- import torch
2
- from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
  import requests
4
- import soundfile as sf
5
  import io
6
 
7
- # Load the Whisper model and processor from Hugging Face Model Hub
8
- model_name = "openai/whisper-base"
9
- processor = WhisperProcessor.from_pretrained(model_name)
10
- model = WhisperForConditionalGeneration.from_pretrained(model_name)
11
 
12
- # Use GPU if available, otherwise use CPU
13
- device = "cuda" if torch.cuda.is_available() else "cpu"
14
- model.to(device)
15
 
16
- # URL of the audio file
17
- audio_url = "https://www.signalogic.com/melp/EngSamples/Orig/male.wav"
18
-
19
- # Download the audio file
20
  response = requests.get(audio_url)
21
- audio_data = io.BytesIO(response.content)
22
-
23
- # Read the audio using soundfile
24
- audio_input, _ = sf.read(audio_data)
25
-
26
- # Preprocess the audio for Whisper
27
- inputs = processor(audio_input, return_tensors="pt", sampling_rate=16000)
28
- attention_mask = inputs['input_features'].ne(processor.tokenizer.pad_token_id).long()
29
-
30
- # Move inputs and attention mask to the correct device
31
- inputs = {key: value.to(device) for key, value in inputs.items()}
32
- attention_mask = attention_mask.to(device)
33
-
34
- # Generate the transcription with attention mask
35
- with torch.no_grad():
36
- predicted_ids = model.generate(
37
- inputs["input_features"],
38
- attention_mask=attention_mask # Pass attention mask explicitly
39
- )
40
- # Decode the transcription
41
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
42
-
43
- # Print the transcription result
44
- print("Transcription:", transcription)
 
1
+ import faster_whisper
 
2
  import requests
3
+ from pydub import AudioSegment
4
  import io
5
 
6
+ # Load the faster-whisper model that supports Hebrew
7
+ model = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d4")
 
 
8
 
9
+ # URL of the mp3 audio file
10
+ audio_url = "https://github.com/metaldaniel/HebrewASR-Comparison/raw/main/HaTankistiot_n12-mp3.mp3"
 
11
 
12
+ # Download the mp3 audio file from the URL
 
 
 
13
  response = requests.get(audio_url)
14
+ if response.status_code != 200:
15
+ raise Exception("Failed to download audio file")
16
+
17
+ # Load the mp3 audio into an in-memory buffer
18
+ mp3_audio = io.BytesIO(response.content)
19
+
20
+ # Convert the mp3 audio to wav using pydub (in memory)
21
+ audio = AudioSegment.from_file(mp3_audio, format="mp3")
22
+ wav_audio = io.BytesIO()
23
+ audio.export(wav_audio, format="wav")
24
+ wav_audio.seek(0) # Reset the pointer to the beginning of the buffer
25
+
26
+ # Save the in-memory wav audio to a temporary file-like object
27
+ with io.BytesIO(wav_audio.read()) as temp_wav_file:
28
+ # Perform the transcription
29
+ segments, info = model.transcribe(temp_wav_file, language="he")
30
+
31
+ # Print transcription results
32
+ for segment in segments:
33
+ print(f"[{segment.start:.2f}s - {segment.end:.2f}s] {segment.text}")
 
 
 
 
requirements.txt CHANGED
@@ -3,4 +3,6 @@ whisper
3
  requests
4
  transformers
5
  soundfile
 
 
6
 
 
3
  requests
4
  transformers
5
  soundfile
6
+ faster-whisper
7
+ pydub
8