AshDavid12 committed
Commit 7380009 · 1 Parent(s): 8e3c59e

trying to build simple transcribe - testing
Files changed (4)
  1. Dockerfile +21 -14
  2. infer.py +97 -120
  3. requirements.txt +5 -5
  4. whisper_online.py +0 -687
Dockerfile CHANGED
@@ -1,20 +1,27 @@
- # Include Python
- from python:3.11.1-buster
-
- # Define your working directory
- WORKDIR /
-
- # Install runpod
- RUN pip install runpod
- RUN pip install torch==2.3.1
- RUN pip install faster-whisper
-
- RUN python3 -c 'import faster_whisper; m = faster_whisper.WhisperModel("ivrit-ai/faster-whisper-v2-d4")'
-
- # Add your file
- ADD infer.py .
-
- ENV LD_LIBRARY_PATH="/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib:/usr/local/lib/python3.11/site-packages/nvidia/cublas/lib"
-
- # Call your file when your container starts
- CMD [ "python", "-u", "/infer.py" ]
+ # Use an official Python runtime as a base image
+ FROM python:3.9-slim
+
+ # Set the working directory
+ WORKDIR /app
+
+ # Install system dependencies for soundfile and any other audio-related processing
+ RUN apt-get update && \
+     apt-get install -y libsndfile1 && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Install dependencies for Hugging Face Spaces (git for model fetching)
+ RUN apt-get install -y git
+
+ # Copy the requirements.txt file and install the dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the current directory contents into the container at /app
+ COPY . .
+
+ # Hugging Face Spaces will expose port 7860 by default for web applications
+ EXPOSE 7860
+
+ # Command to run the transcription script or API server on Hugging Face
+ CMD ["uvicorn", "infer:app", "--host", "0.0.0.0", "--port", "7860"]
 
 
 
infer.py CHANGED
@@ -1,131 +1,108 @@
- import base64
- import faster_whisper
- import tempfile
  import torch
- import requests
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- # Load the model from Hugging Face
- model_name = 'ivrit-ai/faster-whisper-v2-d4'
- model = faster_whisper.WhisperModel(model_name, device=device)
-
- # Maximum data size: 200MB
- MAX_PAYLOAD_SIZE = 200 * 1024 * 1024
-
-
- def download_file(url, max_size_bytes, output_filename, api_key=None):
-     """
-     Download a file from a given URL with size limit and optional API key.
-
-     Args:
-         url (str): The URL of the file to download.
-         max_size_bytes (int): Maximum allowed file size in bytes.
-         output_filename (str): The name of the file to save the download as.
-         api_key (str, optional): API key to be used as a bearer token.
-
-     Returns:
-         bool: True if download was successful, False otherwise.
-     """
      try:
-         headers = {}
-         if api_key:
-             headers['Authorization'] = f'Bearer {api_key}'
-
-         response = requests.get(url, stream=True, headers=headers)
-         response.raise_for_status()
-
-         file_size = int(response.headers.get('Content-Length', 0))
-
-         if file_size > max_size_bytes:
-             print(f"File size ({file_size} bytes) exceeds the maximum allowed size ({max_size_bytes} bytes).")
-             return False
-
-         downloaded_size = 0
-         with open(output_filename, 'wb') as file:
-             for chunk in response.iter_content(chunk_size=8192):
-                 downloaded_size += len(chunk)
-                 if downloaded_size > max_size_bytes:
-                     print(f"Download stopped: Size limit exceeded ({max_size_bytes} bytes).")
-                     return False
-                 file.write(chunk)
-
-         print(f"File downloaded successfully: {output_filename}")
-         return True
-
-     except requests.RequestException as e:
-         print(f"Error downloading file: {e}")
-         return False
-
-
- def transcribe(job):
-     datatype = job['input'].get('type', None)
-     if not datatype:
-         return {"error": "datatype field not provided. Should be 'blob' or 'url'."}
-
-     if datatype not in ['blob', 'url']:
-         return {"error": f"datatype should be 'blob' or 'url', but is {datatype} instead."}
-
-     api_key = job['input'].get('api_key', None)
-
-     with tempfile.TemporaryDirectory() as d:
-         audio_file = f'{d}/audio.mp3'
-
-         if datatype == 'blob':
-             mp3_bytes = base64.b64decode(job['input']['data'])
-             with open(audio_file, 'wb') as file:
-                 file.write(mp3_bytes)
-         elif datatype == 'url':
-             success = download_file(job['input']['url'], MAX_PAYLOAD_SIZE, audio_file, api_key)
-             if not success:
-                 return {"error": f"Error downloading data from {job['input']['url']}"}
-
-         result = transcribe_core(audio_file)
-         return {'result': result}
-
-
- def transcribe_core(audio_file):
-     print('Transcribing...')
-
-     ret = {'segments': []}
-
-     segs, _ = model.transcribe(audio_file, language='he', word_timestamps=True)
-     for s in segs:
-         words = []
-         for w in s.words:
-             words.append({
-                 'start': w.start,
-                 'end': w.end,
-                 'word': w.word,
-                 'probability': w.probability
-             })
-
-         seg = {
-             'id': s.id,
-             'seek': s.seek,
-             'start': s.start,
-             'end': s.end,
-             'text': s.text,
-             'avg_logprob': s.avg_logprob,
-             'compression_ratio': s.compression_ratio,
-             'no_speech_prob': s.no_speech_prob,
-             'words': words
-         }
-
-         print(seg)
-         ret['segments'].append(seg)
-
-     return ret
-
-
- # The script can be run directly or served using Hugging Face's Gradio app or API
  if __name__ == "__main__":
-     # For testing purposes, you can define a sample job and call the transcribe function
-     test_job = {
-         "input": {
-             "type": "url",
-             "url": "https://github.com/metaldaniel/HebrewASR-Comparison/raw/main/HaTankistiot_n12-mp3.mp3",
-             "api_key": "your_api_key_here"  # Optional, replace with actual key if needed
-         }
-     }
-     print(transcribe(test_job))
  import torch
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
+ import soundfile as sf
+ from fastapi import FastAPI, File, UploadFile
+ import uvicorn
+ import os
+ import logging
+ from datetime import datetime
+
+ # Set up logging
+ logging.basicConfig(
+     filename="transcription_log.log",
+     format="%(asctime)s - %(levelname)s - %(message)s",
+     level=logging.INFO
+ )
+
+ # Initialize FastAPI app
+ app = FastAPI()
+
+ # Log initialization of the application
+ logging.info("FastAPI application started.")
+
+ # Load the Whisper model and processor
+ model_name = "openai/whisper-base"
+ logging.info(f"Loading Whisper model: {model_name}")
+
+ try:
+     processor = WhisperProcessor.from_pretrained(model_name)
+     model = WhisperForConditionalGeneration.from_pretrained(model_name)
+     logging.info(f"Model {model_name} successfully loaded.")
+ except Exception as e:
+     logging.error(f"Error loading the model: {e}")
+     raise e
+
+ # Move model to the appropriate device (GPU if available)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ logging.info(f"Model is using device: {device}")
+
+
+ @app.post("/transcribe/")
+ async def transcribe_audio(file: UploadFile = File(...)):
+     # Log file upload start
+     logging.info(f"Received audio file: {file.filename}")
+     start_time = datetime.now()
+
+     # Save the uploaded file
+     file_location = f"temp_{file.filename}"
      try:
+         with open(file_location, "wb+") as f:
+             f.write(await file.read())
+         logging.info(f"File saved to: {file_location}")
+     except Exception as e:
+         logging.error(f"Error saving the file: {e}")
+         return {"error": f"Error saving the file: {e}"}
+
+     # Load the audio file and preprocess it
+     try:
+         audio_input, _ = sf.read(file_location)
+         logging.info(f"Audio file {file.filename} successfully read.")
+
+         inputs = processor(audio_input, return_tensors="pt", sampling_rate=16000)
+         logging.info(f"Audio file preprocessed for transcription.")
+     except Exception as e:
+         logging.error(f"Error processing the audio file: {e}")
+         return {"error": f"Error processing the audio file: {e}"}
+
+     # Move inputs to the same device as the model
+     inputs = {key: value.to(device) for key, value in inputs.items()}
+     logging.info("Inputs moved to the appropriate device.")
+
+     # Generate the transcription
+     try:
+         with torch.no_grad():
+             predicted_ids = model.generate(inputs["input_features"])
+         logging.info("Transcription successfully generated.")
+     except Exception as e:
+         logging.error(f"Error during transcription generation: {e}")
+         return {"error": f"Error during transcription generation: {e}"}
+
+     # Decode the transcription
+     try:
+         transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+         logging.info("Transcription successfully decoded.")
+     except Exception as e:
+         logging.error(f"Error decoding the transcription: {e}")
+         return {"error": f"Error decoding the transcription: {e}"}
+
+     # Clean up the temporary file
+     try:
+         os.remove(file_location)
+         logging.info(f"Temporary file {file_location} deleted.")
+     except Exception as e:
+         logging.error(f"Error deleting the temporary file: {e}")
+
+     end_time = datetime.now()
+     time_taken = end_time - start_time
+     logging.info(f"Transcription completed in {time_taken.total_seconds()} seconds.")
+
+     return {"transcription": transcription, "processing_time_seconds": time_taken.total_seconds()}
+
+
  if __name__ == "__main__":
+     # Log application start
+     logging.info("Starting FastAPI server with Uvicorn...")
+
+     # Run the FastAPI app on the default port (7860)
+     uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
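Review note: the new `/transcribe/` endpoint accepts a multipart file upload and returns JSON with the transcription and the processing time. Below is a minimal client sketch for local testing, not part of the commit; it assumes the server from this commit is reachable on localhost:7860 and that `sample.wav` (a hypothetical file) is 16 kHz mono, since infer.py passes the raw soundfile output to WhisperProcessor with `sampling_rate=16000` and does not resample.

```python
# Minimal client sketch for the /transcribe/ endpoint added in this commit.
# Assumptions: the uvicorn server is running on localhost:7860 (as in the new
# Dockerfile CMD) and "sample.wav" is a hypothetical 16 kHz mono recording.
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/transcribe/",
        # Field name "file" matches the UploadFile parameter in infer.py.
        files={"file": ("sample.wav", f, "audio/wav")},
    )

resp.raise_for_status()
payload = resp.json()
print(payload.get("transcription"))
print("took", payload.get("processing_time_seconds"), "s")
```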
requirements.txt CHANGED
@@ -1,5 +1,5 @@
- numpy
- librosa
- runpod
- faster-whisper
- torch==2.3.1
+ fastapi
+ uvicorn
+ torch
+ transformers
+ soundfile
whisper_online.py DELETED
@@ -1,687 +0,0 @@
- #!/usr/bin/env python3
- import sys
- import numpy as np
- import librosa
- from functools import lru_cache
- import time
- import logging
- import runpod
- import base64
- import io
- import soundfile as sf
- import math
- import os
- from dotenv import load_dotenv
- import openai
- #from voice_activity_controller import *
-
- load_dotenv('.env')
- OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
- RUN_POD_API_KEY = os.getenv('RUN_POD_API_KEY')
- RUNPOD_ENDPOINT_ID = os.getenv('RUNPOD_ENDPOINT_ID')
- openai.api_key = OPENAI_API_KEY
-
- # Set up basic configuration for logging
- logging.basicConfig(
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-     level=logging.DEBUG  # Set to DEBUG to capture all levels of log messages
- )
-
- # Use the root logger directly
- log = logging.getLogger(__name__)
-
- if not OPENAI_API_KEY:
-     log.error("API key not found. Please set the OPENAI_API_KEY environment variable.")
-     sys.exit(1)
-
- log.debug(f"Using API Key: {OPENAI_API_KEY[:5]}...")
-
- from faster_whisper import WhisperModel
-
- logger = logging.getLogger(__name__)
-
-
- @lru_cache
- def load_audio(fname):
-     a, _ = librosa.load(fname, sr=16000, dtype=np.float32)
-     return a
-
-
- def load_audio_chunk(fname, beg, end):
-     audio = load_audio(fname)
-     beg_s = int(beg * 16000)
-     end_s = int(end * 16000)
-     return audio[beg_s:end_s]
-
-
- # Whisper backend
-
- class ASRBase:
-     sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
-
-     # "" for faster-whisper because it emits the spaces when neeeded)
-
-     def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
-         self.logfile = logfile
-
-         self.transcribe_kargs = {}
-         if lan == "auto":
-             self.original_language = None
-         else:
-             self.original_language = lan
-
-         self.model = self.load_model(modelsize, cache_dir, model_dir)
-
-     def load_model(self, modelsize, cache_dir):
-         raise NotImplemented("must be implemented in the child class")
-
-     def transcribe(self, audio, init_prompt=""):
-         raise NotImplemented("must be implemented in the child class")
-
-     def use_vad(self):
-         raise NotImplemented("must be implemented in the child class")
-
-
- class IvritOnRunPodASR(ASRBase):
-     """Uses ivrit-ai API for audio transcription."""
-
-     def __init__(self, lan=None, api_key=None, endpoint_id=None, logfile=sys.stderr):
-         self.logfile = logfile
-         self.original_language = None if lan == "auto" else lan  # ISO-639-1 language code
-         if api_key is None or endpoint_id is None:
-             raise ValueError("API key and Endpoint ID must be provided for Runpod API")
-         runpod.api_key = api_key
-         self.endpoint = runpod.Endpoint(endpoint_id)
-         self.transcribed_seconds = 0  # For logging how many seconds were processed by API, to know the cost
-         self.use_vad_opt = False
-
-     def ts_words(self, segments):
-         if not segments:  # Check if segments is empty
-             logger.warning("No segments found in the response.")
-             return []
-         no_speech_segments = []
-         if self.use_vad_opt:
-             for segment in segments:
-                 if segment["no_speech_prob"] > 0.8:
-                     no_speech_segments.append((segment.get("start"), segment.get("end")))
-         o = []
-         for segment in segments:
-             # Checking if 'word' is part of the segment and then processing it
-             start = segment.get("start")
-             end = segment.get("end")
-             text = segment.get("text", "")  # Assuming each segment is a dictionary with a 'word' key
-             if text and not any(s[0] <= start <= s[1] for s in no_speech_segments):
-                 o.append((start, end, text))
-         return o
-
-     def segments_end_ts(self, res):
-         return [s["end"] for s in res]
-
-     def transcribe(self, audio_data, prompt=None, *args, **kwargs):
-         # Write the audio data to a buffer
-         buffer = io.BytesIO()
-         buffer.name = "temp.wav"
-         sf.write(buffer, audio_data, samplerate=16000, format='WAV', subtype='PCM_16')
-         buffer.seek(0)  # Reset buffer's position to the beginning
-         self.transcribed_seconds += math.ceil(len(audio_data) / 16000)  # it rounds up to the whole seconds
-         # Convert the audio to base64
-         audio_base64 = base64.b64encode(buffer.read()).decode('utf-8')
-         payload = {
-             'type': 'blob',
-             'data': audio_base64
-         }
-         try:
-             # Send the request to Runpod API
-             res = self.endpoint.run_sync(payload)
-             # res['result']
-             # logger.debug(f"Transcription response: {res}")  # Debugging line ##THIS CAUSES TO OUTPUT THE JUNK
-         except Exception as e:
-             logger.error(f"Failed to transcribe audio with Runpod API: {e}")
-             return None
-         segments = res.get('result', {}).get('segments', [])
-
-         return segments
-
-     def use_vad(self):
-         self.use_vad_opt = False
-
-     def set_translate_task(self):
-         self.task = "translate"
-
-
- class HypothesisBuffer:
-
-     def __init__(self, logfile=sys.stderr):
-         self.commited_in_buffer = []
-         self.buffer = []
-         self.new = []
-
-         self.last_commited_time = 0
-         self.last_commited_word = None
-
-         self.logfile = logfile
-
-     def insert(self, new, offset):
-         # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
-         # the new tail is added to self.new
-
-         new = [(a + offset, b + offset, t) for a, b, t in new]
-         self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
-
-         if len(self.new) >= 1:
-             a, b, t = self.new[0]
-             if abs(a - self.last_commited_time) < 1:
-                 if self.commited_in_buffer:
-                     # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
-                     cn = len(self.commited_in_buffer)
-                     nn = len(self.new)
-                     for i in range(1, min(min(cn, nn), 5) + 1):  # 5 is the maximum
-                         c = " ".join([self.commited_in_buffer[-j][2] for j in range(1, i + 1)][::-1])
-                         tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
-                         if c == tail:
-                             words = []
-                             for j in range(i):
-                                 words.append(repr(self.new.pop(0)))
-                             words_msg = " ".join(words)
-                             logger.debug(f"removing last {i} words: {words_msg}")
-                             break
-
-     def flush(self):
-         # returns commited chunk = the longest common prefix of 2 last inserts.
-
-         commit = []
-         while self.new:
-             na, nb, nt = self.new[0]
-
-             if len(self.buffer) == 0:
-                 break
-
-             if nt == self.buffer[0][2]:
-                 commit.append((na, nb, nt))
-                 self.last_commited_word = nt
-                 self.last_commited_time = nb
-                 self.buffer.pop(0)
-                 self.new.pop(0)
-             else:
-                 break
-         self.buffer = self.new
-         self.new = []
-         self.commited_in_buffer.extend(commit)
-         return commit
-
-     def pop_commited(self, time):
-         while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
-             self.commited_in_buffer.pop(0)
-
-     def complete(self):
-         return self.buffer
-
-
- class OnlineASRProcessor:
-     SAMPLING_RATE = 16000
-
-     def __init__(self, asr, tokenizer=None, buffer_trimming=("segment", 15), logfile=sys.stderr):
-         """asr: WhisperASR object
-         tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all.
-         ("segment", 15)
-         buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option.
-         logfile: where to store the log.
-         """
-         self.asr = asr
-         self.tokenizer = tokenizer
-         self.logfile = logfile
-
-         self.init()
-
-         self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
-
-     def init(self, offset=None):
-         """run this when starting or restarting processing"""
-         self.audio_buffer = np.array([], dtype=np.float32)
-         self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
-         self.buffer_time_offset = 0
-         if offset is not None:
-             self.buffer_time_offset = offset
-         self.transcript_buffer.last_commited_time = self.buffer_time_offset
-         self.commited = []
-
-     def insert_audio_chunk(self, audio):
-         self.audio_buffer = np.append(self.audio_buffer, audio)
-
-     def prompt(self):
-         """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
-         "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
-         """
-         k = max(0, len(self.commited) - 1)
-         while k > 0 and self.commited[k - 1][1] > self.buffer_time_offset:
-             k -= 1
-
-         p = self.commited[:k]
-         p = [t for _, _, t in p]
-         prompt = []
-         l = 0
-         while p and l < 200:  # 200 characters prompt size
-             x = p.pop(-1)
-             l += len(x) + 1
-             prompt.append(x)
-         non_prompt = self.commited[k:]
-         return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(t for _, _, t in non_prompt)
-
-     def process_iter(self):
-         """Runs on the current audio buffer.
-         Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
-         The non-emty text is confirmed (committed) partial transcript.
-         """
-
-         prompt, non_prompt = self.prompt()
-         logger.debug(f"PROMPT: {prompt}")
-         logger.debug(f"CONTEXT: {non_prompt}")
-         logger.debug(
-             f"transcribing {len(self.audio_buffer) / self.SAMPLING_RATE:2.2f} seconds from {self.buffer_time_offset:2.2f}")
-         res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
-
-         # transform to [(beg,end,"word1"), ...]
-         tsw = self.asr.ts_words(res)
-
-         self.transcript_buffer.insert(tsw, self.buffer_time_offset)
-         o = self.transcript_buffer.flush()
-         self.commited.extend(o)
-         completed = self.to_flush(o)
-         logger.debug(f">>>>COMPLETE NOW: {completed}")
-         the_rest = self.to_flush(self.transcript_buffer.complete())
-         logger.debug(f"INCOMPLETE: {the_rest}")
-
-         # there is a newly confirmed text
-
-         if o and self.buffer_trimming_way == "sentence":  # trim the completed sentences
-             if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:  # longer than this
-                 self.chunk_completed_sentence()
-
-         if self.buffer_trimming_way == "segment":
-             s = self.buffer_trimming_sec  # trim the completed segments longer than s,
-         else:
-             s = 30  # if the audio buffer is longer than 30s, trim it
-
-         if len(self.audio_buffer) / self.SAMPLING_RATE > s:
-             self.chunk_completed_segment(res)
-
-         # #alternative: on any word
-         # l = self.buffer_time_offset + len(self.audio_buffer)/self.SAMPLING_RATE - 10
-         # #let's find commited word that is less
-         # k = len(self.commited)-1
-         # while k>0 and self.commited[k][1] > l:
-         #     k -= 1
-         # t = self.commited[k][1]
-         # logger.debug("chunking segment")
-         # self.chunk_at(t)
-
-         logger.debug(f"len of buffer now: {len(self.audio_buffer) / self.SAMPLING_RATE:2.2f}")
-         return self.to_flush(o)
-
-     def chunk_completed_sentence(self):
-         if self.commited == []: return
-         logger.debug(self.commited)
-         sents = self.words_to_sentences(self.commited)
-         for s in sents:
-             logger.debug(f"\t\tSENT: {s}")
-         if len(sents) < 2:
-             return
-         while len(sents) > 2:
-             sents.pop(0)
-         # we will continue with audio processing at this timestamp
-         chunk_at = sents[-2][1]
-
-         logger.debug(f"--- sentence chunked at {chunk_at:2.2f}")
-         self.chunk_at(chunk_at)
-
-     def chunk_completed_segment(self, res):
-         if self.commited == []: return
-
-         ends = self.asr.segments_end_ts(res)
-
-         t = self.commited[-1][1]
-
-         if len(ends) > 1:
-
-             e = ends[-2] + self.buffer_time_offset
-             while len(ends) > 2 and e > t:
-                 ends.pop(-1)
-                 e = ends[-2] + self.buffer_time_offset
-             if e <= t:
-                 logger.debug(f"--- segment chunked at {e:2.2f}")
-                 self.chunk_at(e)
-             else:
-                 logger.debug(f"--- last segment not within commited area")
-         else:
-             logger.debug(f"--- not enough segments to chunk")
-
-     def chunk_at(self, time):
-         """trims the hypothesis and audio buffer at "time"
-         """
-         self.transcript_buffer.pop_commited(time)
-         cut_seconds = time - self.buffer_time_offset
-         self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
-         self.buffer_time_offset = time
-
-     def words_to_sentences(self, words):
-         """Uses self.tokenizer for sentence segmentation of words.
-         Returns: [(beg,end,"sentence 1"),...]
-         """
-
-         cwords = [w for w in words]
-         t = " ".join(o[2] for o in cwords)
-         s = self.tokenizer.split(t)
-         out = []
-         while s:
-             beg = None
-             end = None
-             sent = s.pop(0).strip()
-             fsent = sent
-             while cwords:
-                 b, e, w = cwords.pop(0)
-                 w = w.strip()
-                 if beg is None and sent.startswith(w):
-                     beg = b
-                 elif end is None and sent == w:
-                     end = e
-                     out.append((beg, end, fsent))
-                     break
-                 sent = sent[len(w):].strip()
-         return out
-
-     def finish(self):
-         """Flush the incomplete text when the whole processing ends.
-         Returns: the same format as self.process_iter()
-         """
-         o = self.transcript_buffer.complete()
-         f = self.to_flush(o)
-         logger.debug(f"last, noncommited: {f}")
-         self.buffer_time_offset += len(self.audio_buffer) / 16000
-         return f
-
-     def to_flush(self, sents, sep=None, offset=0, ):
-         # concatenates the timestamped words or sentences into one sequence that is flushed in one line
-         # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
-         # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
-         if sep is None:
-             sep = self.asr.sep
-         t = sep.join(s[2] for s in sents)
-         if len(sents) == 0:
-             b = None
-             e = None
-         else:
-             b = offset + sents[0][0]
-             e = offset + sents[-1][1]
-         return (b, e, t)
-
-
- ###changed for VAC
- class VACOnlineASRProcessor(OnlineASRProcessor):
-     def __init__(self, online_chunk_size, *a, **kw):
-         self.online_chunk_size = online_chunk_size
-
-         self.online = OnlineASRProcessor(*a, **kw)
-         #self.vac = VoiceActivityController(use_vad_result=False)
-
-         self.logfile = self.online.logfile
-
-         self.init()
-
-     def init(self):
-         self.online.init()
-         self.vac.reset_states()
-         self.current_online_chunk_buffer_size = 0
-         self.is_currently_final = False
-
-     def insert_audio_chunk(self, audio):
-         logger.debug(f"In Vac:Initial audio chunk size: {len(audio)} samples")
-         r = self.vac.detect_speech_iter(audio, audio_in_int16=False)
-         audio, is_final = r
-         print(is_final)
-         self.is_currently_final = is_final
-         self.online.insert_audio_chunk(audio)
-         self.current_online_chunk_buffer_size += len(audio)
-
-     def process_iter(self):
-         if self.is_currently_final:
-             return self.finish()
-         elif self.current_online_chunk_buffer_size > SAMPLING_RATE * self.online_chunk_size:
-             self.current_online_chunk_buffer_size = 0
-             ret = self.online.process_iter()
-             return ret
-         else:
-             print("no online update, only VAD", file=self.logfile)
-             return (None, None, "")
-
-     def finish(self):
-         ret = self.online.finish()
-         self.online.init(keep_offset=True)
-         self.current_online_chunk_buffer_size = 0
-         return ret
-
-     '''Wraps OnlineASRProcessor with VAC (Voice Activity Controller).
-
-     It works the same way as OnlineASRProcessor: it receives chunks of audio (e.g. 0.04 seconds),
-     it runs VAD and continuously detects whether there is speech or not.
-     When it detects end of speech (non-voice for 500ms), it makes OnlineASRProcessor to end the utterance immediately.
-     '''
-
-
- WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
-     ",")
-
-
- def add_shared_args(parser):
-     """shared args for simulation (this entry point) and server
-     parser: argparse.ArgumentParser object
-     """
-     parser.add_argument('--min-chunk-size', type=float, default=1.0,
-                         help='Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.')
-     parser.add_argument('--model', type=str, default='large-v2',
-                         choices="tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large".split(
-                             ","),
-                         help="Name size of the Whisper model to use (default: large-v2). The model is automatically downloaded from the model hub if not present in model cache dir.")
-     parser.add_argument('--model_cache_dir', type=str, default=None,
-                         help="Overriding the default model cache dir where models downloaded from the hub are saved")
-     parser.add_argument('--model_dir', type=str, default=None,
-                         help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.")
-     parser.add_argument('--lan', '--language', type=str, default='auto',
-                         help="Source language code, e.g. en,de,cs, or 'auto' for language detection.")
-     parser.add_argument('--task', type=str, default='transcribe', choices=["transcribe", "translate"],
-                         help="Transcribe or translate.")
-     parser.add_argument('--backend', type=str, default="faster-whisper",
-                         choices=["faster-whisper", "whisper_timestamped", "openai-api"],
-                         help='Load only this backend for Whisper processing.')
-     parser.add_argument('--vac', action="store_true", default=False,
-                         help='Use VAC = voice activity controller. Recommended. Requires torch.')
-     parser.add_argument('--vac-chunk-size', type=float, default=0.04, help='VAC sample size in seconds.')
-     parser.add_argument('--vad', action="store_true", default=False,
-                         help='Use VAD = voice activity detection, with the default parameters.')
-     parser.add_argument('--buffer_trimming', type=str, default="segment", choices=["sentence", "segment"],
-                         help='Buffer trimming strategy -- trim completed sentences marked with punctuation mark and detected by sentence segmenter, or the completed segments returned by Whisper. Sentence segmenter must be installed for "sentence" option.')
-     parser.add_argument('--buffer_trimming_sec', type=float, default=15,
-                         help='Buffer trimming length threshold in seconds. If buffer length is longer, trimming sentence/segment is triggered.')
-     parser.add_argument("-l", "--log-level", dest="log_level",
-                         choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help="Set the log level",
-                         default='DEBUG')
-
-
- def asr_factory(args, logfile=sys.stderr):
-     """
-     Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
-     """
-     backend = args.backend
-     # if backend == "openai-api":
-     logger.debug("Using ivrit-ai.")
-     asr = IvritOnRunPodASR(lan=args.lan, api_key=RUN_POD_API_KEY, endpoint_id=RUNPOD_ENDPOINT_ID)
-
-     # Apply common configurations
-     if getattr(args, 'vad', False):  # Checks if VAD argument is present and True
-         logger.info("Setting VAD filter")
-         asr.use_vad()
-
-     language = args.lan
-     if args.task == "translate":
-         asr.set_translate_task()
-         tgt_language = "en"  # Whisper translates into English
-     else:
-         tgt_language = language  # Whisper transcribes in this language
-
-     # # Create the tokenizer
-     # if args.buffer_trimming == "sentence":
-     #     tokenizer = create_tokenizer(tgt_language)
-     # else:
-     tokenizer = None
-
-     # Create the OnlineASRProcessor
-     if args.vac:
-
-         online = VACOnlineASRProcessor(args.min_chunk_size, asr, tokenizer, logfile=logfile,
-                                        buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
-     else:
-         online = OnlineASRProcessor(asr, tokenizer, logfile=logfile,
-                                     buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec))
-
-     return asr, online
-
-
- def set_logging(args, logger, other="_server"):
-     logging.basicConfig(  # format='%(name)s
-         format='%(levelname)s\t%(message)s')
-     logger.setLevel(args.log_level)
-     logging.getLogger("whisper_online" + other).setLevel(args.log_level)
-
-
- # logging.getLogger("whisper_online_server").setLevel(args.log_level)
-
-
- if __name__ == "__main__":
-
-     import argparse
-
-     parser = argparse.ArgumentParser()
-     parser.add_argument('audio_path', type=str,
-                         help="Filename of 16kHz mono channel wav, on which live streaming is simulated.")
-     add_shared_args(parser)
-     parser.add_argument('--start_at', type=float, default=0.0, help='Start processing audio at this time.')
-     parser.add_argument('--offline', action="store_true", default=False, help='Offline mode.')
-     parser.add_argument('--comp_unaware', action="store_true", default=False,
-                         help='Computationally unaware simulation.')
-
-     args = parser.parse_args()
-
-     # reset to store stderr to different file stream, e.g. open(os.devnull,"w")
-     logfile = sys.stderr
-
-     if args.offline and args.comp_unaware:
-         logger.error("No or one option from --offline and --comp_unaware are available, not both. Exiting.")
-         sys.exit(1)
-
-     # if args.log_level:
-     #     logging.basicConfig(format='whisper-%(levelname)s:%(name)s: %(message)s',
-     #                         level=getattr(logging, args.log_level))
-
-     set_logging(args, logger)
-
-     audio_path = args.audio_path
-
-     SAMPLING_RATE = 16000
-     duration = len(load_audio(audio_path)) / SAMPLING_RATE
-     logger.info("Audio duration is: %2.2f seconds" % duration)
-
-     asr, online = asr_factory(args, logfile=logfile)
-     if args.vac:
-         min_chunk = args.vac_chunk_size
-     else:
-         min_chunk = args.min_chunk_size
-
-     # load the audio into the LRU cache before we start the timer
-     a = load_audio_chunk(audio_path, 0, 1)
-
-     # warm up the ASR because the very first transcribe takes much more time than the other
-     asr.transcribe(a)
-
-     beg = args.start_at
-     start = time.time() - beg
-
-
-     def output_transcript(o, now=None):
-         # output format in stdout is like:
-         # 4186.3606 0 1720 Takhle to je
-         # - the first three words are:
-         # - emission time from beginning of processing, in milliseconds
-         # - beg and end timestamp of the text segment, as estimated by Whisper model. The timestamps are not accurate, but they're useful anyway
-         # - the next words: segment transcript
-         if now is None:
-             now = time.time() - start
-         if o[0] is not None:
-             print("%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]), file=logfile, flush=True)
-             print("%1.4f %1.0f %1.0f %s" % (now * 1000, o[0] * 1000, o[1] * 1000, o[2]), flush=True)
-         else:
-             # No text, so no output
-             pass
-
-
-     if args.offline:  ## offline mode processing (for testing/debugging)
-         a = load_audio(audio_path)
-         online.insert_audio_chunk(a)
-         try:
-             o = online.process_iter()
-         except AssertionError as e:
-             logger.error(f"assertion error: {repr(e)}")
-         else:
-             output_transcript(o)
-         now = None
-     elif args.comp_unaware:  # computational unaware mode
-         end = beg + min_chunk
-         while True:
-             a = load_audio_chunk(audio_path, beg, end)
-             online.insert_audio_chunk(a)
-             try:
-                 o = online.process_iter()
-             except AssertionError as e:
-                 logger.error(f"assertion error: {repr(e)}")
-                 pass
-             else:
-                 output_transcript(o, now=end)
-
-             logger.debug(f"## last processed {end:.2f}s")
-
-             if end >= duration:
-                 break
-
-             beg = end
-
-             if end + min_chunk > duration:
-                 end = duration
-             else:
-                 end += min_chunk
-         now = duration
-
-     else:  # online = simultaneous mode
-         end = 0
-         while True:
-             now = time.time() - start
-             if now < end + min_chunk:
-                 time.sleep(min_chunk + end - now)
-             end = time.time() - start
-             a = load_audio_chunk(audio_path, beg, end)
-             beg = end
-             online.insert_audio_chunk(a)
-
-             try:
-                 o = online.process_iter()
-             except AssertionError as e:
-                 logger.error(f"assertion error: {e}")
-                 pass
-             else:
-                 output_transcript(o)
-             now = time.time() - start
-             logger.debug(f"## last processed {end:.2f} s, now is {now:.2f}, the latency is {now - end:.2f}")
-
-             if end >= duration:
-                 break
-         now = None
-
-     o = online.finish()
-     output_transcript(o, now=now)