AshDavid12 committed
Commit 47058ca · 1 Parent(s): 6af664b

trying ivrit model

Files changed (2)
  1. infer.py +80 -19
  2. requirements.txt +0 -2
infer.py CHANGED
@@ -1,33 +1,94 @@
 import faster_whisper
 import requests
-from pydub import AudioSegment
-import io
+import tempfile
+import os
 
 # Load the faster-whisper model that supports Hebrew
 model = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d4")
 
-# URL of the mp3 audio file
-audio_url = "https://github.com/metaldaniel/HebrewASR-Comparison/raw/main/HaTankistiot_n12-mp3.mp3"
+# URL of the audio file (replace this with the actual URL of your audio)
+audio_url = "https://github.com/AshDavid12/runpod-serverless-forked/blob/main/me-hebrew.wav"
 
-# Download the mp3 audio file from the URL
+# Download the audio file from the URL
 response = requests.get(audio_url)
 if response.status_code != 200:
     raise Exception("Failed to download audio file")
 
-# Load the mp3 audio into an in-memory buffer
-mp3_audio = io.BytesIO(response.content)
-
-# Convert the mp3 audio to wav using pydub (in memory)
-audio = AudioSegment.from_file(mp3_audio, format="mp3")
-wav_audio = io.BytesIO()
-audio.export(wav_audio, format="wav")
-wav_audio.seek(0)  # Reset the pointer to the beginning of the buffer
-
-# Save the in-memory wav audio to a temporary file-like object
-with io.BytesIO(wav_audio.read()) as temp_wav_file:
-    # Perform the transcription
-    segments, info = model.transcribe(temp_wav_file, language="he")
-
-    # Print transcription results
-    for segment in segments:
-        print(f"[{segment.start:.2f}s - {segment.end:.2f}s] {segment.text}")
+# Create a temporary file to store the audio
+with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio_file:
+    tmp_audio_file.write(response.content)
+    tmp_audio_file_path = tmp_audio_file.name
+
+# Perform the transcription
+segments, info = model.transcribe(tmp_audio_file_path, language="he")
+
+# Print transcription results
+for segment in segments:
+    print(f"[{segment.start:.2f}s - {segment.end:.2f}s] {segment.text}")
+
+# Clean up the temporary file
+os.remove(tmp_audio_file_path)
+
+# import torch
+# from transformers import WhisperProcessor, WhisperForConditionalGeneration
+# import requests
+# import soundfile as sf
+# import io
+
+# # Load the Whisper model and processor from Hugging Face Model Hub
+# model_name = "openai/whisper-base"
+# processor = WhisperProcessor.from_pretrained(model_name)
+# model = WhisperForConditionalGeneration.from_pretrained(model_name)
+#
+# # Use GPU if available, otherwise use CPU
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+# model.to(device)
+#
+# # URL of the audio file
+# audio_url = "https://www.signalogic.com/melp/EngSamples/Orig/male.wav"
+#
+# # Download the audio file
+# response = requests.get(audio_url)
+# audio_data = io.BytesIO(response.content)
+#
+# # Read the audio using soundfile
+# audio_input, _ = sf.read(audio_data)
+#
+# # Preprocess the audio for Whisper
+# inputs = processor(audio_input, return_tensors="pt", sampling_rate=16000)
+# attention_mask = inputs['input_features'].ne(processor.tokenizer.pad_token_id).long()
+#
+# # Move inputs and attention mask to the correct device
+# inputs = {key: value.to(device) for key, value in inputs.items()}
+# attention_mask = attention_mask.to(device)
+#
+# # Generate the transcription with attention mask
+# with torch.no_grad():
+#     predicted_ids = model.generate(
+#         inputs["input_features"],
+#         attention_mask=attention_mask  # Pass attention mask explicitly
+#     )
+# # Decode the transcription
+# transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+#
+# # Print the transcription result
+# print("Transcription:", transcription)
requirements.txt CHANGED
@@ -4,6 +4,4 @@ requests
 transformers
 soundfile
 faster-whisper
-pydub
-ffmpeg
 
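
Dropping pydub and ffmpeg is consistent with the rewritten script: nothing converts mp3 to wav anymore, and faster-whisper decodes compressed audio itself through PyAV, so an mp3 input would still work without either package. A one-line sketch, assuming a local file audio.mp3 (hypothetical) and the model object from infer.py:

segments, info = model.transcribe("audio.mp3", language="he")  # PyAV handles mp3 decoding internally

(The PyPI "ffmpeg" package does not ship the ffmpeg binary in any case, so it was not providing the codec support its name suggests.)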