Spaces (Sleeping)
AshDavid12 committed
Commit 8e3c59e · 1 Parent(s): 9bd82d6
infer wo runpod
Browse files:
- infer.py +74 -48
- whisper_online.py +116 -324
infer.py
CHANGED
@@ -1,105 +1,131 @@
 import base64
 import faster_whisper
 import tempfile
-import logging
 import torch
-import sys
 import requests
-import os
-
-import whisper_online
-
-# Set up logging
-logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)])
-
-# Load the FasterWhisper model
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-model_name = 'ivrit-ai/faster-whisper-v2-d3-e3'
-
-    model = whisper_online.FasterWhisperASR(lan=lan, modelsize=model_name)
-    logging.info("FasterWhisperASR model initialized successfully.")
-except Exception as e:
-    logging.error(f"Failed to initialize FasterWhisperASR model: {e}")
+
+# Load the model from Hugging Face
+model_name = 'ivrit-ai/faster-whisper-v2-d4'
+model = faster_whisper.WhisperModel(model_name, device=device)
 
 # Maximum data size: 200MB
 MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
 
+
 def download_file(url, max_size_bytes, output_filename, api_key=None):
     """
+    Download a file from a given URL with size limit and optional API key.
+
+    Args:
+        url (str): The URL of the file to download.
+        max_size_bytes (int): Maximum allowed file size in bytes.
+        output_filename (str): The name of the file to save the download as.
+        api_key (str, optional): API key to be used as a bearer token.
+
+    Returns:
+        bool: True if download was successful, False otherwise.
+    """
     try:
         headers = {}
         if api_key:
             headers['Authorization'] = f'Bearer {api_key}'
+
         response = requests.get(url, stream=True, headers=headers)
         response.raise_for_status()
+
         file_size = int(response.headers.get('Content-Length', 0))
+
         if file_size > max_size_bytes:
-            print(f"File size exceeds the
+            print(f"File size ({file_size} bytes) exceeds the maximum allowed size ({max_size_bytes} bytes).")
             return False
+
         downloaded_size = 0
         with open(output_filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 downloaded_size += len(chunk)
                 if downloaded_size > max_size_bytes:
-                    print(f"Download stopped:
+                    print(f"Download stopped: Size limit exceeded ({max_size_bytes} bytes).")
                     return False
                 file.write(chunk)
+
         print(f"File downloaded successfully: {output_filename}")
         return True
+
     except requests.RequestException as e:
         print(f"Error downloading file: {e}")
         return False
 
-def transcribe_core_whisper(audio_file):
-    """Transcribe the audio file using FasterWhisper."""
-    logging.info(f"Transcribing audio file: {audio_file}")
-    ret = {'segments': []}
-    try:
-        segs, dummy = model.transcribe(audio_file, language='he', word_timestamps=True)
-        for s in segs:
-            words = [{'start': w.start, 'end': w.end, 'word': w.word, 'probability': w.probability} for w in s.words]
-            seg = {'id': s.id, 'seek': s.seek, 'start': s.start, 'end': s.end, 'text': s.text, 'avg_logprob': s.avg_logprob,
-                   'compression_ratio': s.compression_ratio, 'no_speech_prob': s.no_speech_prob, 'words': words}
-            ret['segments'].append(seg)
-        logging.info("Transcription completed successfully.")
-    except Exception as e:
-        logging.error(f"Error during transcription: {e}", exc_info=True)
-    return ret
 
-def
-
-    logging.info(f"Processing job: {job}")
-    datatype = job.get('input', {}).get('type')
+def transcribe(job):
+    datatype = job['input'].get('type', None)
     if not datatype:
         return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
+
     if datatype not in ['blob', 'url']:
-        return {"error": f"
+        return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}
+
+    api_key = job['input'].get('api_key', None)
 
-    api_key = job.get('input', {}).get('api_key')
     with tempfile.TemporaryDirectory() as d:
         audio_file = f'{d}/audio.mp3'
+
         if datatype == 'blob':
             mp3_bytes = base64.b64decode(job['input']['data'])
-            with open(audio_file, 'wb') as
+            with open(audio_file, 'wb') as file:
+                file.write(mp3_bytes)
         elif datatype == 'url':
             success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
             if not success:
-                return {"error": f"
+                return {"error": f"Error downloading data from {job['input']['url']}"}
 
-    result =
+    result = transcribe_core(audio_file)
    return {'result': result}
 
+
+def transcribe_core(audio_file):
+    print('Transcribing...')
+
+    ret = {'segments': []}
+
+    segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
+    for s in segs:
+        words = []
+        for w in s.words:
+            words.append({
+                'start': w.start,
+                'end': w.end,
+                'word': w.word,
+                'probability': w.probability
+            })
+
+        seg = {
+            'id': s.id,
+            'seek': s.seek,
+            'start': s.start,
+            'end': s.end,
+            'text': s.text,
+            'avg_logprob': s.avg_logprob,
+            'compression_ratio': s.compression_ratio,
+            'no_speech_prob': s.no_speech_prob,
+            'words': words
+        }
+
+        print(seg)
+        ret['segments'].append(seg)
+
+    return ret
+
+
+# The script can be run directly or served using Hugging Face's Gradio app or API
 if __name__ == "__main__":
+    # For testing purposes, you can define a sample job and call the transcribe function
     test_job = {
         "input": {
             "type": "url",
             "url": "https://github.com/metaldaniel/HebrewASR-Comparison/raw/main/HaTankistiot_n12-mp3.mp3",
+            "api_key": "your_api_key_here"  # Optional, replace with actual key if needed
         }
     }
-    print(
+    print(transcribe(test_job))
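With this change infer.py no longer imports whisper_online or wires the handler to a serverless worker: transcribe(job) can be called directly with either a 'url' or a 'blob' job. Below is a minimal usage sketch of the 'blob' path; the module name infer and the file sample.mp3 are assumptions for illustration, not part of the commit.

# Hedged usage sketch: exercise transcribe(job) with a base64 'blob' payload.
# 'infer' as the module name and 'sample.mp3' are assumptions for illustration.
import base64
import infer

with open('sample.mp3', 'rb') as f:
    job = {
        'input': {
            'type': 'blob',
            'data': base64.b64encode(f.read()).decode('utf-8'),
        }
    }

response = infer.transcribe(job)
for seg in response['result']['segments']:
    print(seg['start'], seg['end'], seg['text'])

On success transcribe returns {'result': {'segments': [...]}}; on a malformed payload it returns an {'error': ...} dict instead, so callers should check for that key before indexing 'result'.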
whisper_online.py
CHANGED
@@ -5,15 +5,40 @@ import librosa
 from functools import lru_cache
 import time
 import logging
-import
-import
-
+import runpod
+import base64
 import io
 import soundfile as sf
 import math
+import os
+from dotenv import load_dotenv
+import openai
+#from voice_activity_controller import *
+
+load_dotenv('.env')
+OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
+RUN_POD_API_KEY = os.getenv('RUN_POD_API_KEY')
+RUNPOD_ENDPOINT_ID = os.getenv('RUNPOD_ENDPOINT_ID')
+openai.api_key = OPENAI_API_KEY
+
+# Set up basic configuration for logging
+logging.basicConfig(
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    level=logging.DEBUG  # Set to DEBUG to capture all levels of log messages
+)
+
+# Use the root logger directly
+log = logging.getLogger(__name__)
+
+if not OPENAI_API_KEY:
+    log.error("API key not found. Please set the OPENAI_API_KEY environment variable.")
+    sys.exit(1)
+
+log.debug(f"Using API Key: {OPENAI_API_KEY[:5]}...")
+
+from faster_whisper import WhisperModel
 
 logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler(sys.stdout)])
 
 
 @lru_cache

@@ -35,6 +60,7 @@ class ASRBase:
     sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
 
                # "" for faster-whisper because it emits the spaces when neeeded)
+
     def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
         self.logfile = logfile
 

@@ -56,217 +82,72 @@ class ASRBase:
         raise NotImplemented("must be implemented in the child class")
 
 
-
-    # On the other hand, the installation for GPU could be easier.
-    # """
-    #
-    # sep = " "
-    #
-    # def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
-    #     import whisper
-    #     import whisper_timestamped
-    #     from whisper_timestamped import transcribe_timestamped
-    #     self.transcribe_timestamped = transcribe_timestamped
-    #     if model_dir is not None:
-    #         logger.debug("ignoring model_dir, not implemented")
-    #     return whisper.load_model(modelsize, download_root=cache_dir)
-    #
-    # def transcribe(self, audio, init_prompt=""):
-    #     result = self.transcribe_timestamped(self.model,
-    #             audio, language=self.original_language,
-    #             initial_prompt=init_prompt, verbose=None,
-    #             condition_on_previous_text=True, **self.transcribe_kargs)
-    #     return result
-    #
-    # def ts_words(self, r):
-    #     # return: transcribe result object to [(beg,end,"word1"), ...]
-    #     o = []
-    #     for s in r["segments"]:
-    #         for w in s["words"]:
-    #             t = (w["start"], w["end"], w["text"])
-    #             o.append(t)
-    #     return o
-    #
-    # def segments_end_ts(self, res):
-    #     return [s["end"] for s in res["segments"]]
-    #
-    # def use_vad(self):
-    #     self.transcribe_kargs["vad"] = True
-    #
-    # def set_translate_task(self):
-    #     self.transcribe_kargs["task"] = "translate"
-    #
-
-class FasterWhisperASR(ASRBase):
-    logging.info(f"In faster whisper ASR")
-    """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version.
-    """
-
-    sep = ""
-
-    def load_model(self, modelsize=None, cache_dir="/tmp/.cache/huggingface", model_dir=None):
-        from faster_whisper import WhisperModel
-        # logging.getLogger("faster_whisper").setLevel(logger.level)
-
-        logging.info("Starting model loading process...")
-        logging.info(f"Model loading parameters - modelsize: {modelsize}, cache_dir: {cache_dir}, model_dir: {model_dir}")
-
-        try:
-            logging.info(f"Loading WhisperModel on device: ")
-            os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/tmp/.cache/sentence_transformers'
-            os.environ['HF_HOME'] = '/tmp/.cache/huggingface'
-            # Ensure the cache directory exists
-            os.makedirs(cache_dir, exist_ok=True)
-            model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
-            logging.info("Model loaded successfully.")
-        except Exception as e:
-            logging.error(f"An error occurred while loading the model: {e}", exc_info=True)
-            raise
-
-        # this worked fast and reliably on NVIDIA L40
-        #model = WhisperModel(model_size_or_path, device="cuda", compute_type="float16", download_root=cache_dir)
-
-        # or run on GPU with INT8
-        # tested: the transcripts were different, probably worse than with FP16, and it was slightly (appx 20%) slower
-        # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16")
-
-        # or run on CPU with INT8
-        # tested: works, but slow, appx 10-times than cuda FP16
-        # model = WhisperModel(modelsize, device="cpu", compute_type="int8") #, download_root="faster-disk-cache-dir/")
-        return model
-
-    def transcribe(self, audio, init_prompt=""):
-        logging.info("Starting transcription process...")
-        logging.debug(f"Transcription parameters - language: {self.original_language}, initial_prompt: '{init_prompt}'")
-
-        try:
-            # tested: beam_size=5 is faster and better than 1 (on one 200 second document from En ESIC, min chunk 0.01)
-            segments, info = self.model.transcribe(audio, language=self.original_language, initial_prompt=init_prompt,
-                                                   beam_size=5, word_timestamps=True, condition_on_previous_text=True,
-                                                   **self.transcribe_kargs)
-            logging.info("Transcription completed successfully.")
-            logging.debug(f"Transcription info: {info}")
-        except Exception as e:
-            logging.error(f"An error occurred during transcription: {e}", exc_info=True)
-            raise
-        return list(segments)
+class IvritOnRunPodASR(ASRBase):
+    """Uses ivrit-ai API for audio transcription."""
+
+    def __init__(self, lan=None, api_key=None, endpoint_id=None, logfile=sys.stderr):
+        self.logfile = logfile
+        self.original_language = None if lan == "auto" else lan  # ISO-639-1 language code
+        if api_key is None or endpoint_id is None:
+            raise ValueError("API key and Endpoint ID must be provided for Runpod API")
+        runpod.api_key = api_key
+        self.endpoint = runpod.Endpoint(endpoint_id)
+        self.transcribed_seconds = 0  # For logging how many seconds were processed by API, to know the cost
+        self.use_vad_opt = False
 
     def ts_words(self, segments):
+        if not segments:  # Check if segments is empty
+            logger.warning("No segments found in the response.")
+            return []
+        no_speech_segments = []
+        if self.use_vad_opt:
+            for segment in segments:
+                if segment["no_speech_prob"] > 0.8:
+                    no_speech_segments.append((segment.get("start"), segment.get("end")))
         o = []
         for segment in segments:
-            o.append(t)
+            # Checking if 'word' is part of the segment and then processing it
+            start = segment.get("start")
+            end = segment.get("end")
+            text = segment.get("text", "")  # Assuming each segment is a dictionary with a 'word' key
+            if text and not any(s[0] <= start <= s[1] for s in no_speech_segments):
+                o.append((start, end, text))
         return o
 
     def segments_end_ts(self, res):
-        return [s
+        return [s["end"] for s in res]
+
+    def transcribe(self, audio_data, prompt=None, *args, **kwargs):
+        # Write the audio data to a buffer
+        buffer = io.BytesIO()
+        buffer.name = "temp.wav"
+        sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
+        buffer.seek(0)  # Reset buffer's position to the beginning
+        self.transcribed_seconds += math.ceil(len(audio_data) / 16000)  # it rounds up to the whole seconds
+        # Convert the audio to base64
+        audio_base64 = base64.b64encode(buffer.read()).decode('utf-8')
+        payload = {
+            'type': 'blob',
+            'data': audio_base64
+        }
+        try:
+            # Send the request to Runpod API
+            res = self.endpoint.run_sync(payload)
+            # res['result']
+            # logger.debug(f"Transcription response: {res}")  # Debugging line ##THIS CAUSES TO OUTPUT THE JUNK
+        except Exception as e:
+            logger.error(f"Failed to transcribe audio with Runpod API: {e}")
+            return None
+        segments = res.get('result', {}).get('segments', [])
+
+        return segments
 
     def use_vad(self):
-        self.
+        self.use_vad_opt = False
 
     def set_translate_task(self):
-        self.
+        self.task = "translate"
+
-
-
-# class OpenaiApiASR(ASRBase):
-#     """Uses OpenAI's Whisper API for audio transcription."""
-#
-#     def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
-#         self.logfile = logfile
-#
-#         self.modelname = "whisper-1"
-#         self.original_language = None if lan == "auto" else lan  # ISO-639-1 language code
-#         self.response_format = "verbose_json"
-#         self.temperature = temperature
-#
-#         self.load_model()
-#
-#         self.use_vad_opt = False
-#
-#         # reset the task in set_translate_task
-#         self.task = "transcribe"
-#
-#     def load_model(self, *args, **kwargs):
-#         from openai import OpenAI
-#         self.client = OpenAI()
-#
-#         self.transcribed_seconds = 0  # for logging how many seconds were processed by API, to know the cost
-#
-#     def ts_words(self, segments):
-#         no_speech_segments = []
-#         if self.use_vad_opt:
-#             for segment in segments.segments:
-#                 # TODO: threshold can be set from outside
-#                 if segment["no_speech_prob"] > 0.8:
-#                     no_speech_segments.append((segment.get("start"), segment.get("end")))
-#
-#         o = []
-#         for word in segments.words:
-#             start = word.get("start")
-#             end = word.get("end")
-#             if any(s[0] <= start <= s[1] for s in no_speech_segments):
-#                 # print("Skipping word", word.get("word"), "because it's in a no-speech segment")
-#                 continue
-#             o.append((start, end, word.get("word")))
-#         return o
-#
-#     def segments_end_ts(self, res):
-#         return [s["end"] for s in res.words]
-#
-#     def transcribe(self, audio_data, prompt=None, *args, **kwargs):
-#         # Write the audio data to a buffer
-#         buffer = io.BytesIO()
-#         buffer.name = "temp.wav"
-#         sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
-#         buffer.seek(0)  # Reset buffer's position to the beginning
-#
-#         self.transcribed_seconds += math.ceil(len(audio_data) / 16000)  # it rounds up to the whole seconds
-#
-#         params = {
-#             "model": self.modelname,
-#             "file": buffer,
-#             "response_format": self.response_format,
-#             "temperature": self.temperature,
-#             "timestamp_granularities": ["word", "segment"]
-#         }
-#         if self.task != "translate" and self.original_language:
-#             params["language"] = self.original_language
-#         if prompt:
-#             params["prompt"] = prompt
-#
-#         if self.task == "translate":
-#             proc = self.client.audio.translations
-#         else:
-#             proc = self.client.audio.transcriptions
-#
-#         # Process transcription/translation
-#         transcript = proc.create(**params)
-#         logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
-#
-#         return transcript
-#
-#     def use_vad(self):
-#         self.use_vad_opt = True
-#
-#     def set_translate_task(self):
-#         self.task = "translate"
-#
 
 class HypothesisBuffer:
 

@@ -424,14 +305,14 @@ class OnlineASRProcessor:
             if len(self.audio_buffer) / self.SAMPLING_RATE > s:
                 self.chunk_completed_segment(res)
 
-                # alternative: on any word
+                # #alternative: on any word
                 # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
-                # let's find commited word that is less
+                # #let's find commited word that is less
                 # k = len(self.commited)-1
                 # while k>0 and self.commited[k][1] > l:
                 #     k -= 1
                 # t = self.commited[k][1]
-                logger.debug("chunking segment")
+                # logger.debug("chunking segment")
                 # self.chunk_at(t)
 
         logger.debug(f"len of buffer now: {len(self.audio_buffer) / self.SAMPLING_RATE:2.2f}")

@@ -534,134 +415,60 @@ class OnlineASRProcessor:
         return (b, e, t)
 
 
+###changed for VAC
 class VACOnlineASRProcessor(OnlineASRProcessor):
-    '''Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
-
-    It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
-    it runs VAD and continuously detects whether there is speech or not.
-    When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
-    '''
-
     def __init__(self, online_chunk_size, *a, **kw):
         self.online_chunk_size = online_chunk_size
 
         self.online = OnlineASRProcessor(*a, **kw)
-
-        # VAC:
-        import torch
-        model, _ = torch.hub.load(
-            repo_or_dir='snakers4/silero-vad',
-            model='silero_vad'
-        )
-        #from silero_vad import VADIterator
-        #self.vac = VADIterator(model)  # we use all the default options: 500ms silence, etc.
+        #self.vac = VoiceActivityController(use_vad_result=False)
 
         self.logfile = self.online.logfile
+
         self.init()
 
     def init(self):
         self.online.init()
         self.vac.reset_states()
         self.current_online_chunk_buffer_size = 0
-
         self.is_currently_final = False
 
-        self.status = None  # or "voice" or "nonvoice"
-        self.audio_buffer = np.array([], dtype=np.float32)
-        self.buffer_offset = 0  # in frames
-
-    def clear_buffer(self):
-        self.buffer_offset += len(self.audio_buffer)
-        self.audio_buffer = np.array([], dtype=np.float32)
-
     def insert_audio_chunk(self, audio):
-                send_audio = self.audio_buffer[frame - self.buffer_offset:]
-                self.online.init(offset=frame / self.SAMPLING_RATE)
-                self.online.insert_audio_chunk(send_audio)
-                self.current_online_chunk_buffer_size += len(send_audio)
-                self.clear_buffer()
-            elif 'end' in res and 'start' not in res:
-                self.status = 'nonvoice'
-                send_audio = self.audio_buffer[:frame - self.buffer_offset]
-                self.online.insert_audio_chunk(send_audio)
-                self.current_online_chunk_buffer_size += len(send_audio)
-                self.is_currently_final = True
-                self.clear_buffer()
-            else:
-                # It doesn't happen in the current code.
-                raise NotImplemented("both start and end of voice in one chunk!!!")
-        else:
-            if self.status == 'voice':
-                self.online.insert_audio_chunk(self.audio_buffer)
-                self.current_online_chunk_buffer_size += len(self.audio_buffer)
-                self.clear_buffer()
-            else:
-                # We keep 1 second because VAD may later find start of voice in it.
-                # But we trim it to prevent OOM.
-                self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
-                self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
+        logger.debug(f"In Vac:Initial audio chunk size: {len(audio)} samples")
+        r = self.vac.detect_speech_iter(audio, audio_in_int16=False)
+        audio, is_final = r
+        print(is_final)
+        self.is_currently_final = is_final
+        self.online.insert_audio_chunk(audio)
+        self.current_online_chunk_buffer_size += len(audio)
 
     def process_iter(self):
         if self.is_currently_final:
             return self.finish()
-        elif self.current_online_chunk_buffer_size >
+        elif self.current_online_chunk_buffer_size > SAMPLING_RATE * self.online_chunk_size:
             self.current_online_chunk_buffer_size = 0
             ret = self.online.process_iter()
             return ret
         else:
-            print("no online update, only VAD",
+            print("no online update, only VAD", file=self.logfile)
             return (None, None, "")
 
     def finish(self):
         ret = self.online.finish()
+        self.online.init(keep_offset=True)
         self.current_online_chunk_buffer_size = 0
-        self.is_currently_final = False
         return ret
 
+'''Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
+
+It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
+it runs VAD and continuously detects whether there is speech or not.
+When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
+'''
 
-
-def create_tokenizer(lan):
-    """returns an object that has split function that works like the one of MosesTokenizer"""
-
-    assert lan in WHISPER_LANG_CODES, "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
-
-    if lan == "uk":
-        import tokenize_uk
-        class UkrainianTokenizer:
-            def split(self, text):
-                return tokenize_uk.tokenize_sents(text)
-
-        return UkrainianTokenizer()
-
-    # supported by fast-mosestokenizer
-    # if lan in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split():
-    #     #from mosestokenizer import MosesTokenizer
-    #     #return MosesTokenizer(lan)
-
-    # the following languages are in Whisper, but not in wtpsplit:
-    if lan in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split():
-        logger.debug(f"{lan} code is not supported by wtpsplit. Going to use None lang_code option.")
-        lan = None
 
-    #from wtpsplit import WtP
-    # downloads the model from huggingface on the first use
-    #wtp = WtP("wtp-canine-s-12l-no-adapters")
-
-    # #return wtp.split(sent, lang_code=lan)
-    #
-    # return WtPtok()
+WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
+    ",")
 
 
 def add_shared_args(parser):

@@ -704,24 +511,9 @@ def asr_factory(args, logfile=sys.stderr):
     Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
     """
     backend = args.backend
-    if backend == "openai-api":
-    else:
-        if backend == "faster-whisper":
-            logger.debug("Using FasterWhisper.")
-            print("using faster-whisper from whisper-online")
-            asr_cls = FasterWhisperASR
-        #else:
-            #asr_cls = WhisperTimestampedASR
-
-        # Only for FasterWhisperASR and WhisperTimestampedASR
-        size = args.model
-        t = time.time()
-        logger.info(f"Loading Whisper {size} model for {args.lan}...")
-        asr = asr_cls(modelsize=size, lan=args.lan, cache_dir=args.model_cache_dir, model_dir=args.model_dir)
-        e = time.time()
-        logger.info(f"done. It took {round(e - t, 2)} seconds.")
+    # if backend == "openai-api":
+    logger.debug("Using ivrit-ai.")
+    asr = IvritOnRunPodASR(lan=args.lan, api_key=RUN_POD_API_KEY, endpoint_id=RUNPOD_ENDPOINT_ID)
 
     # Apply common configurations
     if getattr(args, 'vad', False):  # Checks if VAD argument is present and True

@@ -735,11 +527,11 @@ def asr_factory(args, logfile=sys.stderr):
     else:
         tgt_language = language  # Whisper transcribes in this language
 
-    # Create the tokenizer
-    if args.buffer_trimming == "sentence":
-    else:
+    # # Create the tokenizer
+    # if args.buffer_trimming == "sentence":
+    #     tokenizer = create_tokenizer(tgt_language)
+    # else:
+    tokenizer = None
 
     # Create the OnlineASRProcessor
     if args.vac:

@@ -892,4 +684,4 @@ if __name__ == "__main__":
         now = None
 
     o = online.finish()
-    output_transcript(o, now=now)
+    output_transcript(o, now=now)