AshDavid12 committed
Commit 8e3c59e · 1 Parent(s): 9bd82d6

infer wo runpod

Files changed (2)
  1. infer.py +74 -48
  2. whisper_online.py +116 -324
infer.py CHANGED
@@ -1,105 +1,131 @@
1
  import base64
2
  import faster_whisper
3
  import tempfile
4
- import logging
5
  import torch
6
- import sys
7
  import requests
8
- import os
9
 
10
- import whisper_online
11
-
12
- # Set up logging
13
- logger = logging.getLogger(__name__)
14
- logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)])
15
-
16
- # Load the FasterWhisper model
17
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
- model_name = 'ivrit-ai/faster-whisper-v2-d3-e3'
19
 
20
- try:
21
- lan = 'he'
22
- logging.info(f"Attempting to initialize FasterWhisperASR with device: {device}")
23
- model = whisper_online.FasterWhisperASR(lan=lan, modelsize=model_name)
24
- logging.info("FasterWhisperASR model initialized successfully.")
25
- except Exception as e:
26
- logging.error(f"Failed to initialize FasterWhisperASR model: {e}")
27
 
28
  # Maximum data size: 200MB
29
  MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
30
 
 
31
  def download_file(url, max_size_bytes, output_filename, api_key=None):
32
- """Download a file from a given URL with size limit and optional API key."""
33
  try:
34
  headers = {}
35
  if api_key:
36
  headers['Authorization'] = f'Bearer {api_key}'
 
37
  response = requests.get(url, stream=True, headers=headers)
38
  response.raise_for_status()
 
39
  file_size = int(response.headers.get('Content-Length', 0))
 
40
  if file_size > max_size_bytes:
41
- print(f"File size exceeds the limit: {file_size} bytes.")
42
  return False
 
43
  downloaded_size = 0
44
  with open(output_filename, 'wb') as file:
45
  for chunk in response.iter_content(chunk_size=8192):
46
  downloaded_size += len(chunk)
47
  if downloaded_size > max_size_bytes:
48
- print(f"Download stopped: size limit exceeded.")
49
  return False
50
  file.write(chunk)
 
51
  print(f"File downloaded successfully: {output_filename}")
52
  return True
 
53
  except requests.RequestException as e:
54
  print(f"Error downloading file: {e}")
55
  return False
56
 
57
- def transcribe_core_whisper(audio_file):
58
- """Transcribe the audio file using FasterWhisper."""
59
- logging.info(f"Transcribing audio file: {audio_file}")
60
- ret = {'segments': []}
61
- try:
62
- segs, dummy = model.transcribe(audio_file, language='he', word_timestamps=True)
63
- for s in segs:
64
- words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
65
- seg = {'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
66
- 'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words}
67
- ret['segments'].append(seg)
68
- logging.info("Transcription completed successfully.")
69
- except Exception as e:
70
- logging.error(f"Error during transcription: {e}", exc_info=True)
71
- return ret
72
 
73
- def transcribe_whisper(job):
74
- """Main transcription handler."""
75
- logging.info(f"Processing job: {job}")
76
- datatype = job.get('input', {}).get('type')
77
  if not datatype:
78
  return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
 
79
  if datatype not in ['blob', 'url']:
80
- return {"error": f"Invalid datatype: {datatype}."}
 
 
81
 
82
- api_key = job.get('input', {}).get('api_key')
83
  with tempfile.TemporaryDirectory() as d:
84
  audio_file = f'{d}/audio.mp3'
 
85
  if datatype == 'blob':
86
  mp3_bytes = base64.b64decode(job['input']['data'])
87
- with open(audio_file, 'wb') as f:
88
- f.write(mp3_bytes)
89
  elif datatype == 'url':
90
  success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
91
  if not success:
92
- return {"error": f"Failed to download from {job['input']['url']}"}
93
 
94
- result = transcribe_core_whisper(audio_file)
95
  return {'result': result}
96
 
97
- # Example job input to test locally
98
  if __name__ == "__main__":
 
99
  test_job = {
100
  "input": {
101
  "type": "url",
102
  "url": "https://github.com/metaldaniel/HebrewASR-Comparison/raw/main/HaTankistiot_n12-mp3.mp3",
 
103
  }
104
  }
105
- print(transcribe_whisper(test_job))
 
1
  import base64
2
  import faster_whisper
3
  import tempfile
 
4
  import torch
 
5
  import requests
 
6
7
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
8
 
9
+ # Load the model from Hugging Face
10
+ model_name = 'ivrit-ai/faster-whisper-v2-d4'
11
+ model = faster_whisper.WhisperModel(model_name, device=device)
12
 
13
  # Maximum data size: 200MB
14
  MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
15
 
16
+
17
  def download_file(url, max_size_bytes, output_filename, api_key=None):
18
+ """
19
+ Download a file from a given URL with size limit and optional API key.
20
+
21
+ Args:
22
+ url (str): The URL of the file to download.
23
+ max_size_bytes (int): Maximum allowed file size in bytes.
24
+ output_filename (str): The name of the file to save the download as.
25
+ api_key (str, optional): API key to be used as a bearer token.
26
+
27
+ Returns:
28
+ bool: True if download was successful, False otherwise.
29
+ """
30
  try:
31
  headers = {}
32
  if api_key:
33
  headers['Authorization'] = f'Bearer {api_key}'
34
+
35
  response = requests.get(url, stream=True, headers=headers)
36
  response.raise_for_status()
37
+
38
  file_size = int(response.headers.get('Content-Length', 0))
39
+
40
  if file_size > max_size_bytes:
41
+ print(f"File size ({file_size} bytes) exceeds the maximum allowed size ({max_size_bytes} bytes).")
42
  return False
43
+
44
  downloaded_size = 0
45
  with open(output_filename, 'wb') as file:
46
  for chunk in response.iter_content(chunk_size=8192):
47
  downloaded_size += len(chunk)
48
  if downloaded_size > max_size_bytes:
49
+ print(f"Download stopped: Size limit exceeded ({max_size_bytes} bytes).")
50
  return False
51
  file.write(chunk)
52
+
53
  print(f"File downloaded successfully: {output_filename}")
54
  return True
55
+
56
  except requests.RequestException as e:
57
  print(f"Error downloading file: {e}")
58
  return False
59
60
 
61
+ def transcribe(job):
62
+ datatype = job['input'].get('type', None)
 
 
63
  if not datatype:
64
  return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
65
+
66
  if datatype not in ['blob', 'url']:
67
+ return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}
68
+
69
+ api_key = job['input'].get('api_key', None)
70
 
 
71
  with tempfile.TemporaryDirectory() as d:
72
  audio_file = f'{d}/audio.mp3'
73
+
74
  if datatype == 'blob':
75
  mp3_bytes = base64.b64decode(job['input']['data'])
76
+ with open(audio_file, 'wb') as file:
77
+ file.write(mp3_bytes)
78
  elif datatype == 'url':
79
  success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
80
  if not success:
81
+ return {"error": f"Error downloading data from {job['input']['url']}"}
82
 
83
+ result = transcribe_core(audio_file)
84
  return {'result': result}
85
 
86
+
87
+ def transcribe_core(audio_file):
88
+ print('Transcribing...')
89
+
90
+ ret = {'segments': []}
91
+
92
+ segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
93
+ for s in segs:
94
+ words = []
95
+ for w in s.words:
96
+ words.append({
97
+ 'start': w.start,
98
+ 'end': w.end,
99
+ 'word': w.word,
100
+ 'probability': w.probability
101
+ })
102
+
103
+ seg = {
104
+ 'id': s.id,
105
+ 'seek': s.seek,
106
+ 'start': s.start,
107
+ 'end': s.end,
108
+ 'text': s.text,
109
+ 'avg_logprob': s.avg_logprob,
110
+ 'compression_ratio': s.compression_ratio,
111
+ 'no_speech_prob': s.no_speech_prob,
112
+ 'words': words
113
+ }
114
+
115
+ print(seg)
116
+ ret['segments'].append(seg)
117
+
118
+ return ret
119
+
120
+
121
+ # The script can be run directly or served using Hugging Face's Gradio app or API
122
  if __name__ == "__main__":
123
+ # For testing purposes, you can define a sample job and call the transcribe function
124
  test_job = {
125
  "input": {
126
  "type": "url",
127
  "url": "https://github.com/metaldaniel/HebrewASR-Comparison/raw/main/HaTankistiot_n12-mp3.mp3",
128
+ "api_key": "your_api_key_here" # Optional, replace with actual key if needed
129
  }
130
  }
131
+ print(transcribe(test_job))
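
Note: the handler above also accepts base64-encoded audio via the 'blob' input type. The snippet below is a minimal local sketch of that path (not part of the commit); the file name sample.mp3 is hypothetical, and it reuses the transcribe handler defined above.

import base64

# Hypothetical local file; any MP3 readable by the model would do.
with open('sample.mp3', 'rb') as f:
    audio_b64 = base64.b64encode(f.read()).decode('utf-8')

blob_job = {
    "input": {
        "type": "blob",   # handler decodes 'data' and writes it to a temp mp3
        "data": audio_b64,
    }
}
print(transcribe(blob_job))
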
whisper_online.py CHANGED
@@ -5,15 +5,40 @@ import librosa
5
  from functools import lru_cache
6
  import time
7
  import logging
8
- import os
9
- import tempfile
10
-
11
  import io
12
  import soundfile as sf
13
  import math
14
 
15
  logger = logging.getLogger(__name__)
16
- logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)])
17
 
18
 
19
  @lru_cache
@@ -35,6 +60,7 @@ class ASRBase:
35
  sep = " " # join transcribe words with this character (" " for whisper_timestamped,
36
 
37
  # "" for faster-whisper because it emits the spaces when needed)
 
38
  def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
39
  self.logfile = logfile
40
 
@@ -56,217 +82,72 @@ class ASRBase:
56
  raise NotImplemented("must be implemented in the child class")
57
 
58
 
59
- # class WhisperTimestampedASR(ASRBase):
60
- # """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
61
- # On the other hand, the installation for GPU could be easier.
62
- # """
63
- #
64
- # sep = " "
65
- #
66
- # def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
67
- # import whisper
68
- # import whisper_timestamped
69
- # from whisper_timestamped import transcribe_timestamped
70
- # self.transcribe_timestamped = transcribe_timestamped
71
- # if model_dir is not None:
72
- # logger.debug("ignoring model_dir, not implemented")
73
- # return whisper.load_model(modelsize, download_root=cache_dir)
74
- #
75
- # def transcribe(self, audio, init_prompt=""):
76
- # result = self.transcribe_timestamped(self.model,
77
- # audio, language=self.original_language,
78
- # initial_prompt=init_prompt, verbose=None,
79
- # condition_on_previous_text=True, **self.transcribe_kargs)
80
- # return result
81
- #
82
- # def ts_words(self, r):
83
- # # return: transcribe result object to [(beg,end,"word1"), ...]
84
- # o = []
85
- # for s in r["segments"]:
86
- # for w in s["words"]:
87
- # t = (w["start"], w["end"], w["text"])
88
- # o.append(t)
89
- # return o
90
- #
91
- # def segments_end_ts(self, res):
92
- # return [s["end"] for s in res["segments"]]
93
- #
94
- # def use_vad(self):
95
- # self.transcribe_kargs["vad"] = True
96
- #
97
- # def set_translate_task(self):
98
- # self.transcribe_kargs["task"] = "translate"
99
- #
100
-
101
- class FasterWhisperASR(ASRBase):
102
- logging.info(f"In faster whisper ASR")
103
- """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
104
- """
105
-
106
- sep = ""
107
-
108
- def load_model(self, modelsize=None, cache_dir="/tmp/.cache/huggingface", model_dir=None):
109
- from faster_whisper import WhisperModel
110
- # logging.getLogger("faster_whisper").setLevel(logger.level)
111
-
112
- logging.info("Starting model loading process...")
113
- logging.info(f"Model loading parameters - modelsize: {modelsize}, cache_dir: {cache_dir}, model_dir: {model_dir}")
114
 
115
- if model_dir is not None:
116
- logger.info(
117
- f"Loading whisper model from model_dir {model_dir}. modelsize and cache_dir parameters are not used.")
118
- model_size_or_path = model_dir
119
- elif modelsize is not None:
120
- model_size_or_path = modelsize
121
- else:
122
- raise ValueError("modelsize or model_dir parameter must be set")
123
-
124
- try:
125
- logging.info(f"Loading WhisperModel on device: ")
126
- os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/tmp/.cache/sentence_transformers'
127
- os.environ['HF_HOME'] = '/tmp/.cache/huggingface'
128
- # Ensure the cache directory exists
129
- os.makedirs(cache_dir, exist_ok=True)
130
- model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
131
- logging.info("Model loaded successfully.")
132
- except Exception as e:
133
- logging.error(f"An error occurred while loading the model: {e}", exc_info=True)
134
- raise
135
-
136
-
137
- # this worked fast and reliably on NVIDIA L40
138
- #model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
139
-
140
- # or run on GPU with INT8
141
- # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
142
- # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
143
-
144
- # or run on CPU with INT8
145
- # tested: works, but slow, appx 10-times than cuda FP16
146
- # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
147
- return model
148
-
149
- def transcribe(self, audio, init_prompt=""):
150
- logging.info("Starting transcription process...")
151
- logging.debug(f"Transcription parameters - language: {self.original_language}, initial_prompt: '{init_prompt}'")
152
-
153
- try:
154
- # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
155
- segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt,
156
- beam_size=5, word_timestamps=True, condition_on_previous_text=True,
157
- **self.transcribe_kargs)
158
- logging.info("Transcription completed successfully.")
159
- logging.debug(f"Transcription info: {info}")
160
- except Exception as e:
161
- logging.error(f"An error occurred during transcription: {e}", exc_info=True)
162
- raise
163
- return list(segments)
164
 
165
  def ts_words(self, segments):
166
  o = []
167
  for segment in segments:
168
- for word in segment.words:
169
- if segment.no_speech_prob > 0.9:
170
- continue
171
- # not stripping the spaces -- should not be merged with them!
172
- w = word.word
173
- t = (word.start, word.end, w)
174
- o.append(t)
175
  return o
176
 
177
  def segments_end_ts(self, res):
178
- return [s.end for s in res]
179
 
180
  def use_vad(self):
181
- self.transcribe_kargs["vad_filter"] = True
182
 
183
  def set_translate_task(self):
184
- self.transcribe_kargs["task"] = "translate"
185
-
186
-
187
- # class OpenaiApiASR(ASRBase):
188
- # """Uses OpenAI's Whisper API for audio transcription."""
189
- #
190
- # def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
191
- # self.logfile = logfile
192
- #
193
- # self.modelname = "whisper-1"
194
- # self.original_language = None if lan == "auto" else lan # ISO-639-1 language code
195
- # self.response_format = "verbose_json"
196
- # self.temperature = temperature
197
- #
198
- # self.load_model()
199
- #
200
- # self.use_vad_opt = False
201
- #
202
- # # reset the task in set_translate_task
203
- # self.task = "transcribe"
204
- #
205
- # def load_model(self, *args, **kwargs):
206
- # from openai import OpenAI
207
- # self.client = OpenAI()
208
- #
209
- # self.transcribed_seconds = 0 # for logging how many seconds were processed by API, to know the cost
210
- #
211
- # def ts_words(self, segments):
212
- # no_speech_segments = []
213
- # if self.use_vad_opt:
214
- # for segment in segments.segments:
215
- # # TODO: threshold can be set from outside
216
- # if segment["no_speech_prob"] > 0.8:
217
- # no_speech_segments.append((segment.get("start"), segment.get("end")))
218
- #
219
- # o = []
220
- # for word in segments.words:
221
- # start = word.get("start")
222
- # end = word.get("end")
223
- # if any(s[0] <= start <= s[1] for s in no_speech_segments):
224
- # # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
225
- # continue
226
- # o.append((start, end, word.get("word")))
227
- # return o
228
- #
229
- # def segments_end_ts(self, res):
230
- # return [s["end"] for s in res.words]
231
- #
232
- # def transcribe(self, audio_data, prompt=None, *args, **kwargs):
233
- # # Write the audio data to a buffer
234
- # buffer = io.BytesIO()
235
- # buffer.name = "temp.wav"
236
- # sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
237
- # buffer.seek(0) # Reset buffer's position to the beginning
238
- #
239
- # self.transcribed_seconds += math.ceil(len(audio_data) / 16000) # it rounds up to the whole seconds
240
- #
241
- # params = {
242
- # "model": self.modelname,
243
- # "file": buffer,
244
- # "response_format": self.response_format,
245
- # "temperature": self.temperature,
246
- # "timestamp_granularities": ["word", "segment"]
247
- # }
248
- # if self.task != "translate" and self.original_language:
249
- # params["language"] = self.original_language
250
- # if prompt:
251
- # params["prompt"] = prompt
252
- #
253
- # if self.task == "translate":
254
- # proc = self.client.audio.translations
255
- # else:
256
- # proc = self.client.audio.transcriptions
257
- #
258
- # # Process transcription/translation
259
- # transcript = proc.create(**params)
260
- # logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
261
- #
262
- # return transcript
263
- #
264
- # def use_vad(self):
265
- # self.use_vad_opt = True
266
- #
267
- # def set_translate_task(self):
268
- # self.task = "translate"
269
- #
270
 
271
  class HypothesisBuffer:
272
 
@@ -424,14 +305,14 @@ class OnlineASRProcessor:
424
  if len(self.audio_buffer) / self.SAMPLING_RATE > s:
425
  self.chunk_completed_segment(res)
426
 
427
- # alternative: on any word
428
  # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
429
- # let's find committed word that is less
430
  # k = len(self.commited)-1
431
  # while k>0 and self.commited[k][1] > l:
432
  # k -= 1
433
  # t = self.commited[k][1]
434
- logger.debug("chunking segment")
435
  # self.chunk_at(t)
436
 
437
  logger.debug(f"len of buffer now: {len(self.audio_buffer) / self.SAMPLING_RATE:2.2f}")
@@ -534,134 +415,60 @@ class OnlineASRProcessor:
534
  return (b, e, t)
535
 
536
 
 
537
  class VACOnlineASRProcessor(OnlineASRProcessor):
538
- '''Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
539
-
540
- It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
541
- it runs VAD and continuously detects whether there is speech or not.
542
- When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
543
- '''
544
-
545
  def __init__(self, online_chunk_size, *a, **kw):
546
  self.online_chunk_size = online_chunk_size
547
 
548
  self.online = OnlineASRProcessor(*a, **kw)
549
-
550
- # VAC:
551
- import torch
552
- model, _ = torch.hub.load(
553
- repo_or_dir='snakers4/silero-vad',
554
- model='silero_vad'
555
- )
556
- #from silero_vad import VADIterator
557
- #self.vac = VADIterator(model) # we use all the default options: 500ms silence, etc.
558
 
559
  self.logfile = self.online.logfile
 
560
  self.init()
561
 
562
  def init(self):
563
  self.online.init()
564
  self.vac.reset_states()
565
  self.current_online_chunk_buffer_size = 0
566
-
567
  self.is_currently_final = False
568
 
569
- self.status = None # or "voice" or "nonvoice"
570
- self.audio_buffer = np.array([], dtype=np.float32)
571
- self.buffer_offset = 0 # in frames
572
-
573
- def clear_buffer(self):
574
- self.buffer_offset += len(self.audio_buffer)
575
- self.audio_buffer = np.array([], dtype=np.float32)
576
-
577
  def insert_audio_chunk(self, audio):
578
- res = self.vac(audio)
579
- self.audio_buffer = np.append(self.audio_buffer, audio)
580
-
581
- if res is not None:
582
- frame = list(res.values())[0]
583
- if 'start' in res and 'end' not in res:
584
- self.status = 'voice'
585
- send_audio = self.audio_buffer[frame - self.buffer_offset:]
586
- self.online.init(offset=frame / self.SAMPLING_RATE)
587
- self.online.insert_audio_chunk(send_audio)
588
- self.current_online_chunk_buffer_size += len(send_audio)
589
- self.clear_buffer()
590
- elif 'end' in res and 'start' not in res:
591
- self.status = 'nonvoice'
592
- send_audio = self.audio_buffer[:frame - self.buffer_offset]
593
- self.online.insert_audio_chunk(send_audio)
594
- self.current_online_chunk_buffer_size += len(send_audio)
595
- self.is_currently_final = True
596
- self.clear_buffer()
597
- else:
598
- # It doesn't happen in the current code.
599
- raise NotImplemented("both start and end of voice in one chunk!!!")
600
- else:
601
- if self.status == 'voice':
602
- self.online.insert_audio_chunk(self.audio_buffer)
603
- self.current_online_chunk_buffer_size += len(self.audio_buffer)
604
- self.clear_buffer()
605
- else:
606
- # We keep 1 second because VAD may later find start of voice in it.
607
- # But we trim it to prevent OOM.
608
- self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
609
- self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
610
 
611
  def process_iter(self):
612
  if self.is_currently_final:
613
  return self.finish()
614
- elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
615
  self.current_online_chunk_buffer_size = 0
616
  ret = self.online.process_iter()
617
  return ret
618
  else:
619
- print("no online update, only VAD", self.status, file=self.logfile)
620
  return (None, None, "")
621
 
622
  def finish(self):
623
  ret = self.online.finish()
 
624
  self.current_online_chunk_buffer_size = 0
625
- self.is_currently_final = False
626
  return ret
627
 
 
628
 
629
- WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
630
- ",")
631
-
632
-
633
- def create_tokenizer(lan):
634
- """returns an object that has split function that works like the one of MosesTokenizer"""
635
-
636
- assert lan in WHISPER_LANG_CODES, "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
637
-
638
- if lan == "uk":
639
- import tokenize_uk
640
- class UkrainianTokenizer:
641
- def split(self, text):
642
- return tokenize_uk.tokenize_sents(text)
643
-
644
- return UkrainianTokenizer()
645
-
646
- # supported by fast-mosestokenizer
647
- # if lan in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split():
648
- # #from mosestokenizer import MosesTokenizer
649
- # #return MosesTokenizer(lan)
650
-
651
- # the following languages are in Whisper, but not in wtpsplit:
652
- if lan in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split():
653
- logger.debug(f"{lan} code is not supported by wtpsplit. Going to use None lang_code option.")
654
- lan = None
655
 
656
- #from wtpsplit import WtP
657
- # downloads the model from huggingface on the first use
658
- #wtp = WtP("wtp-canine-s-12l-no-adapters")
659
 
660
- # class WtPtok:
661
- # def split(self, sent):
662
- # #return wtp.split(sent, lang_code=lan)
663
- #
664
- # return WtPtok()
665
 
666
 
667
  def add_shared_args(parser):
@@ -704,24 +511,9 @@ def asr_factory(args, logfile=sys.stderr):
704
  Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
705
  """
706
  backend = args.backend
707
- if backend == "openai-api":
708
- logger.debug("Using OpenAI API.")
709
- #asr = OpenaiApiASR(lan=args.lan)
710
- else:
711
- if backend == "faster-whisper":
712
- logger.debug("Using FasterWhisper.")
713
- print("using faster-whisper from whisper-online")
714
- asr_cls = FasterWhisperASR
715
- #else:
716
- #asr_cls = WhisperTimestampedASR
717
-
718
- # Only for FasterWhisperASR and WhisperTimestampedASR
719
- size = args.model
720
- t = time.time()
721
- logger.info(f"Loading Whisper {size} model for {args.lan}...")
722
- asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
723
- e = time.time()
724
- logger.info(f"done. It took {round(e - t, 2)} seconds.")
725
 
726
  # Apply common configurations
727
  if getattr(args, 'vad', False): # Checks if VAD argument is present and True
@@ -735,11 +527,11 @@ def asr_factory(args, logfile=sys.stderr):
735
  else:
736
  tgt_language = language # Whisper transcribes in this language
737
 
738
- # Create the tokenizer
739
- if args.buffer_trimming == "sentence":
740
- tokenizer = create_tokenizer(tgt_language)
741
- else:
742
- tokenizer = None
743
 
744
  # Create the OnlineASRProcessor
745
  if args.vac:
@@ -892,4 +684,4 @@ if __name__ == "__main__":
892
  now = None
893
 
894
  o = online.finish()
895
- output_transcript(o, now=now)
 
5
  from functools import lru_cache
6
  import time
7
  import logging
8
+ import runpod
9
+ import base64
 
10
  import io
11
  import soundfile as sf
12
  import math
13
+ import os
14
+ from dotenv import load_dotenv
15
+ import openai
16
+ #from voice_activity_controller import *
17
+
18
+ load_dotenv('.env')
19
+ OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
20
+ RUN_POD_API_KEY = os.getenv('RUN_POD_API_KEY')
21
+ RUNPOD_ENDPOINT_ID = os.getenv('RUNPOD_ENDPOINT_ID')
22
+ openai.api_key = OPENAI_API_KEY
23
+
24
+ # Set up basic configuration for logging
25
+ logging.basicConfig(
26
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
27
+ level=logging.DEBUG # Set to DEBUG to capture all levels of log messages
28
+ )
29
+
30
+ # Use the root logger directly
31
+ log = logging.getLogger(__name__)
32
+
33
+ if not OPENAI_API_KEY:
34
+ log.error("API key not found. Please set the OPENAI_API_KEY environment variable.")
35
+ sys.exit(1)
36
+
37
+ log.debug(f"Using API Key: {OPENAI_API_KEY[:5]}...")
38
+
39
+ from faster_whisper import WhisperModel
40
 
41
  logger = logging.getLogger(__name__)
 
42
 
43
 
44
  @lru_cache
 
60
  sep = " " # join transcribe words with this character (" " for whisper_timestamped,
61
 
62
  # "" for faster-whisper because it emits the spaces when needed)
63
+
64
  def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
65
  self.logfile = logfile
66
 
 
82
  raise NotImplemented("must be implemented in the child class")
83
 
84
 
85
+ class IvritOnRunPodASR(ASRBase):
86
+ """Uses ivrit-ai API for audio transcription."""
87
 
88
+ def __init__(self, lan=None, api_key=None, endpoint_id=None, logfile=sys.stderr):
89
+ self.logfile = logfile
90
+ self.original_language = None if lan == "auto" else lan # ISO-639-1 language code
91
+ if api_key is None or endpoint_id is None:
92
+ raise ValueError("API key and Endpoint ID must be provided for Runpod API")
93
+ runpod.api_key = api_key
94
+ self.endpoint = runpod.Endpoint(endpoint_id)
95
+ self.transcribed_seconds = 0 # For logging how many seconds were processed by API, to know the cost
96
+ self.use_vad_opt = False
97
 
98
  def ts_words(self, segments):
99
+ if not segments: # Check if segments is empty
100
+ logger.warning("No segments found in the response.")
101
+ return []
102
+ no_speech_segments = []
103
+ if self.use_vad_opt:
104
+ for segment in segments:
105
+ if segment["no_speech_prob"] > 0.8:
106
+ no_speech_segments.append((segment.get("start"), segment.get("end")))
107
  o = []
108
  for segment in segments:
109
+ # Checking if 'word' is part of the segment and then processing it
110
+ start = segment.get("start")
111
+ end = segment.get("end")
112
+ text = segment.get("text", "") # Assuming each segment is a dictionary with a 'word' key
113
+ if text and not any(s[0] <= start <= s[1] for s in no_speech_segments):
114
+ o.append((start, end, text))
 
115
  return o
116
 
117
  def segments_end_ts(self, res):
118
+ return [s["end"] for s in res]
119
+
120
+ def transcribe(self, audio_data, prompt=None, *args, **kwargs):
121
+ # Write the audio data to a buffer
122
+ buffer = io.BytesIO()
123
+ buffer.name = "temp.wav"
124
+ sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
125
+ buffer.seek(0) # Reset buffer's position to the beginning
126
+ self.transcribed_seconds += math.ceil(len(audio_data) / 16000) # it rounds up to the whole seconds
127
+ # Convert the audio to base64
128
+ audio_base64 = base64.b64encode(buffer.read()).decode('utf-8')
129
+ payload = {
130
+ 'type': 'blob',
131
+ 'data': audio_base64
132
+ }
133
+ try:
134
+ # Send the request to Runpod API
135
+ res = self.endpoint.run_sync(payload)
136
+ # res['result']
137
+ # logger.debug(f"Transcription response: {res}") # Debugging line ##THIS CAUSES TO OUTPUT THE JUNK
138
+ except Exception as e:
139
+ logger.error(f"Failed to transcribe audio with Runpod API: {e}")
140
+ return None
141
+ segments = res.get('result', {}).get('segments', [])
142
+
143
+ return segments
144
 
145
  def use_vad(self):
146
+ self.use_vad_opt = False
147
 
148
  def set_translate_task(self):
149
+ self.task = "translate"
150
+
151
 
152
  class HypothesisBuffer:
153
 
 
305
  if len(self.audio_buffer) / self.SAMPLING_RATE > s:
306
  self.chunk_completed_segment(res)
307
 
308
+ # #alternative: on any word
309
  # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
310
+ # #let's find committed word that is less
311
  # k = len(self.commited)-1
312
  # while k>0 and self.commited[k][1] > l:
313
  # k -= 1
314
  # t = self.commited[k][1]
315
+ # logger.debug("chunking segment")
316
  # self.chunk_at(t)
317
 
318
  logger.debug(f"len of buffer now: {len(self.audio_buffer) / self.SAMPLING_RATE:2.2f}")
 
415
  return (b, e, t)
416
 
417
 
418
+ ###changed for VAC
419
  class VACOnlineASRProcessor(OnlineASRProcessor):
420
  def __init__(self, online_chunk_size, *a, **kw):
421
  self.online_chunk_size = online_chunk_size
422
 
423
  self.online = OnlineASRProcessor(*a, **kw)
424
+ #self.vac = VoiceActivityController(use_vad_result=False)
425
 
426
  self.logfile = self.online.logfile
427
+
428
  self.init()
429
 
430
  def init(self):
431
  self.online.init()
432
  self.vac.reset_states()
433
  self.current_online_chunk_buffer_size = 0
 
434
  self.is_currently_final = False
435
436
  def insert_audio_chunk(self, audio):
437
+ logger.debug(f"In Vac:Initial audio chunk size: {len(audio)} samples")
438
+ r = self.vac.detect_speech_iter(audio, audio_in_int16=False)
439
+ audio, is_final = r
440
+ print(is_final)
441
+ self.is_currently_final = is_final
442
+ self.online.insert_audio_chunk(audio)
443
+ self.current_online_chunk_buffer_size += len(audio)
444
 
445
  def process_iter(self):
446
  if self.is_currently_final:
447
  return self.finish()
448
+ elif self.current_online_chunk_buffer_size > SAMPLING_RATE * self.online_chunk_size:
449
  self.current_online_chunk_buffer_size = 0
450
  ret = self.online.process_iter()
451
  return ret
452
  else:
453
+ print("no online update, only VAD", file=self.logfile)
454
  return (None, None, "")
455
 
456
  def finish(self):
457
  ret = self.online.finish()
458
+ self.online.init(keep_offset=True)
459
  self.current_online_chunk_buffer_size = 0
 
460
  return ret
461
 
462
+ '''Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
463
 
464
+ It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
465
+ it runs VAD and continuously detects whether there is speech or not.
466
+ When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
467
+ '''
468
469
 
470
+ WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
471
+ ",")
 
 
 
472
 
473
 
474
  def add_shared_args(parser):
 
511
  Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
512
  """
513
  backend = args.backend
514
+ # if backend == "openai-api":
515
+ logger.debug("Using ivrit-ai.")
516
+ asr = IvritOnRunPodASR(lan=args.lan, api_key=RUN_POD_API_KEY, endpoint_id=RUNPOD_ENDPOINT_ID)
517
 
518
  # Apply common configurations
519
  if getattr(args, 'vad', False): # Checks if VAD argument is present and True
 
527
  else:
528
  tgt_language = language # Whisper transcribes in this language
529
 
530
+ # # Create the tokenizer
531
+ # if args.buffer_trimming == "sentence":
532
+ # tokenizer = create_tokenizer(tgt_language)
533
+ # else:
534
+ tokenizer = None
535
 
536
  # Create the OnlineASRProcessor
537
  if args.vac:
 
684
  now = None
685
 
686
  o = online.finish()
687
+ output_transcript(o, now=now)
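
For context, here is a minimal sketch of how the new IvritOnRunPodASR backend could drive OnlineASRProcessor directly, outside of asr_factory. It assumes RUN_POD_API_KEY and RUNPOD_ENDPOINT_ID are set in the environment and that `audio` is a 16 kHz mono float32 NumPy array; the one-second chunking and the OnlineASRProcessor call signature follow the upstream whisper_streaming defaults and are illustrative rather than part of this commit.

# Assumed streaming loop; names flagged above are assumptions, not repo code.
asr = IvritOnRunPodASR(lan="he", api_key=RUN_POD_API_KEY, endpoint_id=RUNPOD_ENDPOINT_ID)
online = OnlineASRProcessor(asr, tokenizer=None)  # no sentence tokenizer, as in asr_factory above

CHUNK = 16000  # one second of 16 kHz audio per iteration
for i in range(0, len(audio), CHUNK):
    online.insert_audio_chunk(audio[i:i + CHUNK])
    beg, end, text = online.process_iter()
    if text:
        print(f"{beg:.2f}-{end:.2f}: {text}")

print(online.finish())  # flush whatever is still buffered
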