AshDavid12 committed on
Commit
1c789c0
·
1 Parent(s): 1ab0cdf

added validation for wav and pcm

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. client.py +78 -23
  3. infer.py +51 -6
  4. poetry.lock +22 -1
  5. pyproject.toml +2 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.wav
client.py CHANGED
@@ -1,5 +1,6 @@
1
  import asyncio
2
  import json
 
3
  import wave
4
 
5
  import websockets
@@ -9,8 +10,62 @@ import ssl
9
  # Parameters for reading and sending the audio
10
  AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  async def send_audio(websocket):
13
  buffer_size = 1024 * 16 # Send smaller chunks (16KB) for real-time processing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  # Download the WAV file locally
16
  # with requests.get(AUDIO_FILE_URL, stream=True) as response:
@@ -21,29 +76,29 @@ async def send_audio(websocket):
21
  # print("Audio file downloaded successfully.")
22
 
23
  # Open the downloaded WAV file and extract PCM data
24
- with wave.open('test_copy.wav', 'rb') as wav_file:
25
- metadata = {
26
- 'sample_rate': wav_file.getframerate(),
27
- 'channels': wav_file.getnchannels(),
28
- 'sampwidth': wav_file.getsampwidth(),
29
- }
30
-
31
- # Send metadata to the server before sending the audio
32
- await websocket.send(json.dumps(metadata))
33
- print(f"Sent metadata: {metadata}")
34
-
35
- # Send the PCM audio data in chunks
36
- while True:
37
- pcm_chunk = wav_file.readframes(buffer_size)
38
- if not pcm_chunk:
39
- break # End of file
40
-
41
- await websocket.send(pcm_chunk) # Send raw PCM data chunk
42
- #print(f"Sent PCM chunk of size {len(pcm_chunk)} bytes.")
43
- await asyncio.sleep(0.01) # Simulate real-time sending
44
-
45
- else:
46
- print(f"Failed to download audio file. Status code: {response.status_code}")
47
 
48
 
49
  async def receive_transcription(websocket):
 
1
  import asyncio
2
  import json
3
+ import logging
4
  import wave
5
 
6
  import websockets
 
10
  # Parameters for reading and sending the audio
11
  AUDIO_FILE_URL = "https://raw.githubusercontent.com/AshDavid12/runpod-serverless-forked/main/test_hebrew.wav" # Use WAV file
12
 
13
+ from pydub import AudioSegment
14
+
15
+
16
+ # Convert and resample audio before writing it to WAV
17
+ # (the result is returned as an AudioSegment, not written to disk)
18
def convert_to_mono_16k(audio_file_path):
    """Load a WAV file and return it downmixed to mono at a 16 kHz frame rate.

    :param audio_file_path: Path of the WAV file to load.
    :return: The converted ``AudioSegment`` object.
    :raises Exception: Any error raised while loading or converting the
        audio is logged and then re-raised to the caller.
    """
    logging.info(f"Starting audio conversion to mono and resampling to 16kHz for file: {audio_file_path}")

    try:
        # Read the file from disk as WAV.
        source_audio = AudioSegment.from_file(audio_file_path, format="wav")

        # Downmix to a single channel first, then change the frame rate.
        mono_audio = source_audio.set_channels(1)
        resampled_audio = mono_audio.set_frame_rate(16000)

        logging.info("Audio conversion to mono and 16kHz completed successfully.")
    except Exception as e:
        logging.error(f"Error during audio conversion: {e}")
        raise e

    # Hand the converted segment back to the caller.
    return resampled_audio
35
+
36
+
37
async def send_audio(websocket):
    """Convert the local test WAV and stream it to the server as raw PCM.

    The function first sends one JSON text message describing the PCM format
    (``sample_rate``, ``channels``, ``sampwidth``) and then sends the raw
    audio bytes in fixed-size binary chunks.

    :param websocket: Open websocket connection; only its ``send`` coroutine
        is used.
    :return: ``None``. If the conversion fails the error is logged and
        nothing is sent. Errors raised while streaming the chunks are logged
        and swallowed (best effort).
    """
    buffer_size = 1024 * 16  # Send smaller chunks (16KB) for real-time processing
    logging.info("Converting the audio to mono and 16kHz.")

    try:
        converted_audio = convert_to_mono_16k('test_copy.wav')
    except Exception as e:
        logging.error(f"Failed to convert audio: {e}")
        return

    # Describe the stream using the converted segment's real properties rather
    # than hard-coded values. The conversion step fixes the channel count and
    # frame rate but not the sample width, so assuming 16-bit samples would
    # mislabel 8-bit or 24-bit sources and the server would misread the bytes.
    metadata = {
        'sample_rate': converted_audio.frame_rate,
        'channels': converted_audio.channels,
        'sampwidth': converted_audio.sample_width,
    }
    await websocket.send(json.dumps(metadata))
    logging.info(f"Sent metadata: {metadata}")

    try:
        raw_data = converted_audio.raw_data
        logging.info(f"Starting to send raw PCM audio data. Total data size: {len(raw_data)} bytes.")

        for i in range(0, len(raw_data), buffer_size):
            pcm_chunk = raw_data[i:i + buffer_size]
            await websocket.send(pcm_chunk)  # Send raw PCM data chunk
            logging.info(f"Sent PCM chunk of size {len(pcm_chunk)} bytes.")
            await asyncio.sleep(0.01)  # Simulate real-time sending

        logging.info("Completed sending all audio data.")
    except Exception as e:
        # Best effort: a failed send is reported but not propagated.
        logging.error(f"Error while sending audio data: {e}")
69
 
70
  # Download the WAV file locally
71
  # with requests.get(AUDIO_FILE_URL, stream=True) as response:
 
76
  # print("Audio file downloaded successfully.")
77
 
78
  # Open the downloaded WAV file and extract PCM data
79
+ # with wave.open('test_copy.wav', 'rb') as wav_file:
80
+ # metadata = {
81
+ # 'sample_rate': wav_file.getframerate(),
82
+ # 'channels': wav_file.getnchannels(),
83
+ # 'sampwidth': wav_file.getsampwidth(),
84
+ # }
85
+ #
86
+ # # Send metadata to the server before sending the audio
87
+ # await websocket.send(json.dumps(metadata))
88
+ # print(f"Sent metadata: {metadata}")
89
+
90
+ # # Send the PCM audio data in chunks
91
+ # while True:
92
+ # pcm_chunk = wav_file.readframes(buffer_size)
93
+ # if not pcm_chunk:
94
+ # break # End of file
95
+ #
96
+ # await websocket.send(pcm_chunk) # Send raw PCM data chunk
97
+ # #print(f"Sent PCM chunk of size {len(pcm_chunk)} bytes.")
98
+ # await asyncio.sleep(0.01) # Simulate real-time sending
99
+
100
+ # else:
101
+ # print(f"Failed to download audio file. Status code: {response.status_code}")
102
 
103
 
104
  async def receive_transcription(websocket):
infer.py CHANGED
@@ -131,9 +131,6 @@ def transcribe_core_ws(audio_file, last_transcribed_time):
131
  """
132
  Transcribe the audio file and return only the segments that have not been processed yet.
133
 
134
- :param audio_file: Path to the growing audio file.
135
- :param last_transcribed_time: The last time (in seconds) that was transcribed.
136
- :return: Newly transcribed segments and the updated last transcribed time.
137
  """
138
  logging.info(f"Starting transcription for file: {audio_file} from {last_transcribed_time} seconds.")
139
 
@@ -177,6 +174,43 @@ def transcribe_core_ws(audio_file, last_transcribed_time):
177
  import tempfile
178
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  @app.websocket("/wtranscribe")
181
  async def websocket_transcribe(websocket: WebSocket):
182
  logging.info("New WebSocket connection request received.")
@@ -214,6 +248,12 @@ async def websocket_transcribe(websocket: WebSocket):
214
  # Accumulate the raw PCM data into the buffer
215
  pcm_audio_buffer.extend(audio_chunk)
216
 
 
 
 
 
 
 
217
  # Estimate the duration of the chunk based on its size
218
  chunk_duration = len(audio_chunk) / (sample_rate * channels * sample_width)
219
  accumulated_audio_time += chunk_duration
@@ -233,6 +273,11 @@ async def websocket_transcribe(websocket: WebSocket):
233
  wav_file.setframerate(sample_rate)
234
  wav_file.writeframes(pcm_audio_buffer)
235
 
 
 
 
 
 
236
  logging.info(f"Temporary WAV file created at {temp_wav_file.name} for transcription.")
237
 
238
  # Log to confirm that the file exists and has the expected size
@@ -260,9 +305,9 @@ async def websocket_transcribe(websocket: WebSocket):
260
  await websocket.send_json(response)
261
 
262
  # Optionally delete the temporary WAV file after processing
263
- if os.path.exists(temp_wav_file):
264
- os.remove(temp_wav_file)
265
- logging.info(f"Temporary WAV file {temp_wav_file} removed.")
266
 
267
  except WebSocketDisconnect:
268
  logging.info("WebSocket connection closed by the client.")
 
131
  """
132
  Transcribe the audio file and return only the segments that have not been processed yet.
133
 
 
 
 
134
  """
135
  logging.info(f"Starting transcription for file: {audio_file} from {last_transcribed_time} seconds.")
136
 
 
174
  import tempfile
175
 
176
 
177
+ # Function to verify if the PCM data is valid
178
def validate_pcm_data(pcm_audio_buffer, sample_rate, channels, sample_width):
    """Check that a PCM buffer is non-empty and consistent with its format.

    :param pcm_audio_buffer: Bytes-like buffer holding the PCM data received
        so far.
    :param sample_rate: Frames per second declared for the stream.
    :param channels: Number of interleaved channels.
    :param sample_width: Bytes per single-channel sample.
    :return: ``False`` when the buffer is empty or the format parameters are
        not positive; ``True`` otherwise. A buffer that does not end on a
        frame boundary only produces a warning, because the remaining bytes
        of the frame may arrive with the next chunk.
    """
    actual_size = len(pcm_audio_buffer)
    logging.info(f"Validating PCM data: total size = {actual_size} bytes.")

    if actual_size == 0:
        logging.error("Received PCM data is empty.")
        return False

    # One frame holds one sample for every channel.
    frame_size = channels * sample_width
    bytes_per_second = sample_rate * frame_size

    # Reject a zero or negative format up front; otherwise the modulo below
    # would raise ZeroDivisionError and abort the caller.
    if sample_rate <= 0 or frame_size <= 0:
        logging.error(
            f"Invalid PCM format: sample_rate={sample_rate}, channels={channels}, sample_width={sample_width}.")
        return False

    logging.info(f"Expected sample size per second: {bytes_per_second} bytes.")

    # Alignment is checked against the frame size. Checking against the
    # bytes-per-second figure would warn for virtually every chunk, since
    # audio is not delivered in whole-second pieces.
    if actual_size % frame_size != 0:
        logging.warning(
            f"PCM data size {actual_size} is not a multiple of the frame size ({frame_size} bytes). Data may be corrupted or incomplete.")

    return True
197
+
198
+
199
+ # Function to validate if the created WAV file is valid
200
def validate_wav_file(wav_file_path):
    """Report whether a WAV file can be opened and its header read.

    :param wav_file_path: Path of the WAV file to check.
    :return: ``True`` when the file opens and its sample rate, channel count
        and sample width can be read; ``False`` when it is missing,
        unreadable, empty/truncated, or not a valid WAV file.
    """
    try:
        with wave.open(wav_file_path, 'rb') as wav_file:
            sample_rate = wav_file.getframerate()
            channels = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()
            logging.info(
                f"WAV file details - Sample Rate: {sample_rate}, Channels: {channels}, Sample Width: {sample_width}")
            return True
    except (wave.Error, EOFError, OSError) as e:
        # wave.Error covers malformed headers. An empty or truncated file
        # raises EOFError, and a missing or unreadable path raises OSError.
        # All of them mean the file is not usable, so report False instead
        # of letting the exception escape from a validator.
        logging.error(f"Error reading WAV file: {e}")
        return False
213
+
214
  @app.websocket("/wtranscribe")
215
  async def websocket_transcribe(websocket: WebSocket):
216
  logging.info("New WebSocket connection request received.")
 
248
  # Accumulate the raw PCM data into the buffer
249
  pcm_audio_buffer.extend(audio_chunk)
250
 
251
+ # Validate the PCM data after each chunk
252
+ if not validate_pcm_data(pcm_audio_buffer, sample_rate, channels, sample_width):
253
+ logging.error("Invalid PCM data received. Aborting transcription.")
254
+ await websocket.send_json({"error": "Invalid PCM data received."})
255
+ return
256
+
257
  # Estimate the duration of the chunk based on its size
258
  chunk_duration = len(audio_chunk) / (sample_rate * channels * sample_width)
259
  accumulated_audio_time += chunk_duration
 
273
  wav_file.setframerate(sample_rate)
274
  wav_file.writeframes(pcm_audio_buffer)
275
 
276
+ if not validate_wav_file(temp_wav_file.name):
277
+ logging.error(f"Invalid WAV file created: {temp_wav_file.name}")
278
+ await websocket.send_json({"error": "Invalid WAV file created."})
279
+ return
280
+
281
  logging.info(f"Temporary WAV file created at {temp_wav_file.name} for transcription.")
282
 
283
  # Log to confirm that the file exists and has the expected size
 
305
  await websocket.send_json(response)
306
 
307
  # Optionally delete the temporary WAV file after processing
308
+ if os.path.exists(temp_wav_file.name):
309
+ os.remove(temp_wav_file.name)
310
+ logging.info(f"Temporary WAV file {temp_wav_file.name} removed.")
311
 
312
  except WebSocketDisconnect:
313
  logging.info("WebSocket connection closed by the client.")
poetry.lock CHANGED
@@ -1064,6 +1064,16 @@ tokenizers = ">=0.13,<1"
1064
  conversion = ["transformers[torch] (>=4.23)"]
1065
  dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"]
1066
 
 
 
 
 
 
 
 
 
 
 
1067
  [[package]]
1068
  name = "filelock"
1069
  version = "3.16.0"
@@ -2539,6 +2549,17 @@ azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0
2539
  toml = ["tomli (>=2.0.1)"]
2540
  yaml = ["pyyaml (>=6.0.1)"]
2541
 
 
 
 
 
 
 
 
 
 
 
 
2542
  [[package]]
2543
  name = "pygments"
2544
  version = "2.18.0"
@@ -3862,4 +3883,4 @@ type = ["pytest-mypy"]
3862
  [metadata]
3863
  lock-version = "2.0"
3864
  python-versions = "3.9.1"
3865
- content-hash = "8b654ee2a2cc97497e78fbe0de6258f3fb006e3f9bbe7234f800843f66adcb7b"
 
1064
  conversion = ["transformers[torch] (>=4.23)"]
1065
  dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"]
1066
 
1067
+ [[package]]
1068
+ name = "ffmpeg"
1069
+ version = "1.4"
1070
+ description = "ffmpeg python package url [https://github.com/jiashaokun/ffmpeg]"
1071
+ optional = false
1072
+ python-versions = "*"
1073
+ files = [
1074
+ {file = "ffmpeg-1.4.tar.gz", hash = "sha256:6931692c890ff21d39938433c2189747815dca0c60ddc7f9bb97f199dba0b5b9"},
1075
+ ]
1076
+
1077
  [[package]]
1078
  name = "filelock"
1079
  version = "3.16.0"
 
2549
  toml = ["tomli (>=2.0.1)"]
2550
  yaml = ["pyyaml (>=6.0.1)"]
2551
 
2552
+ [[package]]
2553
+ name = "pydub"
2554
+ version = "0.25.1"
2555
+ description = "Manipulate audio with an simple and easy high level interface"
2556
+ optional = false
2557
+ python-versions = "*"
2558
+ files = [
2559
+ {file = "pydub-0.25.1-py2.py3-none-any.whl", hash = "sha256:65617e33033874b59d87db603aa1ed450633288aefead953b30bded59cb599a6"},
2560
+ {file = "pydub-0.25.1.tar.gz", hash = "sha256:980a33ce9949cab2a569606b65674d748ecbca4f0796887fd6f46173a7b0d30f"},
2561
+ ]
2562
+
2563
  [[package]]
2564
  name = "pygments"
2565
  version = "2.18.0"
 
3883
  [metadata]
3884
  lock-version = "2.0"
3885
  python-versions = "3.9.1"
3886
+ content-hash = "62e30245d9470305f2b33ff86655c5a38e9f58c708b7ffb3cdfbf932ccfda6c7"
pyproject.toml CHANGED
@@ -24,6 +24,8 @@ openai = "^1.42.0"
24
  numpy = "^1.22.0"
25
  torch = "2.1.0"
26
  sounddevice = "^0.5.0"
 
 
27
 
28
 
29
 
 
24
  numpy = "^1.22.0"
25
  torch = "2.1.0"
26
  sounddevice = "^0.5.0"
27
+ pydub = "^0.25.1"
28
+ ffmpeg = "^1.4"
29
 
30
 
31