zohann commited on
Commit
04ae5ea
·
verified ·
1 Parent(s): 185d662

Upload Audio_Effects_SDK/samples/utils/wave_reader/waveReadWrite.cpp with huggingface_hub

Browse files
Audio_Effects_SDK/samples/utils/wave_reader/waveReadWrite.cpp ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2019-2022, NVIDIA Corporation. All rights reserved.
3
+ *
4
+ * NVIDIA Corporation and its licensors retain all intellectual property
5
+ * and proprietary rights in and to this software, related documentation
6
+ * and any modifications thereto. Any use, reproduction, disclosure or
7
+ * distribution of this software and related documentation without an express
8
+ * license agreement from NVIDIA Corporation is strictly prohibited.
9
+ */
10
+
11
+ #define _CRT_SECURE_NO_DEPRECATE
12
+
13
+ #include <sys/stat.h>
14
+
15
+ #include <cstdint>
16
+ #include <fstream>
17
+ #include <map>
18
+ #include <memory>
19
+ #include <sstream>
20
+
21
+ #include "wave.hpp"
22
+
23
+ #ifndef ENABLE_PERF_DUMP
24
+ #include <iostream>
25
+ #endif
26
+
27
+ #ifdef _WIN32
28
+ #include <io.h>
29
+ #include <Shlwapi.h>
30
+ #pragma comment(lib, "Shlwapi.lib")
31
+ #else
32
+ #include <unistd.h>
33
+ #endif
34
+
35
+ #include "waveReadWrite.hpp"
36
+ // #include <misc.hpp>
37
+
38
+ const float * CWaveFileRead::GetFloatPCMData() {
39
+ int8_t* audioDataPtr = reinterpret_cast<int8_t*>(m_WaveData.get());
40
+
41
+ if (m_floatWaveData.size())
42
+ return m_floatWaveData.data();
43
+
44
+ m_floatWaveData.resize(m_nNumSamples);
45
+ float* outputWaveData = m_floatWaveData.data();
46
+
47
+ if (m_WaveFormatEx.wFormatTag == WAVE_FORMAT_IEEE_FLOAT) {
48
+ memcpy(outputWaveData, audioDataPtr, m_nNumSamples * sizeof(float));
49
+ return outputWaveData;
50
+ }
51
+
52
+ for (uint32_t i = 0; i < m_nNumSamples; i++) {
53
+ switch (m_WaveFormatEx.wBitsPerSample) {
54
+ case 8: {
55
+ uint8_t audioSample = *(reinterpret_cast<uint8_t*>(audioDataPtr));
56
+ outputWaveData[i] = (audioSample - 128) / 128.0f;
57
+ }
58
+ break;
59
+ case 16: {
60
+ int16_t audioSample = *(reinterpret_cast<int16_t*>(audioDataPtr));
61
+ outputWaveData[i] = audioSample / 32768.0f;
62
+ }
63
+ break;
64
+ case 24: {
65
+ int32_t audioSample = *(reinterpret_cast<int32_t*>(audioDataPtr));
66
+ uint8_t data0 = audioSample & 0x000000ff;
67
+ uint8_t data1 = static_cast<uint8_t>((audioSample & 0x0000ff00) >> 8);
68
+ uint8_t data2 = static_cast<uint8_t>((audioSample & 0x00ff0000) >> 16);
69
+ int32_t Value = ((data2 << 24) | (data1 << 16) | (data0 << 8)) >> 8;
70
+ outputWaveData[i] = Value / 8388608.0f;
71
+ }
72
+ break;
73
+ case 32: {
74
+ int32_t audioSample = *(reinterpret_cast<int32_t*>(audioDataPtr));
75
+ outputWaveData[i] = audioSample / 2147483648.0f;
76
+ }
77
+ break;
78
+ }
79
+ audioDataPtr += m_WaveFormatEx.nBlockAlign;
80
+ }
81
+
82
+ return outputWaveData;
83
+ }
84
+
85
+ const float * CWaveFileRead::GetFloatPCMDataAligned(int alignSamples) {
86
+ if (!GetFloatPCMData())
87
+ return nullptr;
88
+
89
+ int totalAlignedSamples;
90
+ if (!(m_nNumSamples % alignSamples))
91
+ totalAlignedSamples = m_nNumSamples;
92
+ else
93
+ totalAlignedSamples = m_nNumSamples + (alignSamples - (m_nNumSamples % alignSamples));
94
+
95
+ m_floatWaveDataAligned.reset(new float[totalAlignedSamples]());
96
+
97
+ for (uint32_t i = 0; i < m_nNumSamples; i++)
98
+ m_floatWaveDataAligned[i] = m_floatWaveData[i];
99
+
100
+ m_NumAlignedSamples = totalAlignedSamples;
101
+ return m_floatWaveDataAligned.get();
102
+ }
103
+
104
+ int CWaveFileRead::GetBitsPerSample() {
105
+ if (m_WaveFormatEx.wBitsPerSample == 0)
106
+ assert(0);
107
+
108
+ return m_WaveFormatEx.wBitsPerSample;
109
+ }
110
+
111
+ const RiffChunk* CWaveFileRead::FindChunk(const uint8_t* data, size_t sizeBytes, uint32_t fourcc) {
112
+ if (!data)
113
+ return nullptr;
114
+
115
+ const uint8_t* ptr = data;
116
+ const uint8_t* end = data + sizeBytes;
117
+
118
+ while (end > (ptr + sizeof(RiffChunk))) {
119
+ const RiffChunk* header = reinterpret_cast<const RiffChunk*>(ptr);
120
+ if (header->chunkId == fourcc)
121
+ return header;
122
+
123
+ ptr += (header->chunkSize + sizeof(RiffChunk));
124
+ }
125
+
126
+ return nullptr;
127
+ }
128
+
129
+ CWaveFileRead::CWaveFileRead(std::string wavFile)
130
+ : m_wavFile(wavFile)
131
+ , m_nNumSamples(0)
132
+ , validFile(false)
133
+ , m_WaveDataSize(0)
134
+ , m_NumAlignedSamples(0) {
135
+ memset(&m_WaveFormatEx, 0, sizeof(m_WaveFormatEx));
136
+ #ifdef __linux__
137
+ if (access(m_wavFile.c_str(), R_OK) == 0)
138
+ #else
139
+ if (PathFileExistsA(m_wavFile.c_str()))
140
+ #endif
141
+ {
142
+ if (readPCM(m_wavFile.c_str()) == 0)
143
+ validFile = true;
144
+ }
145
+ }
146
+
147
+ inline bool loadFile(std::string const& infilename, std::string* outData) {
148
+ std::string result;
149
+ std::string filename = infilename;
150
+
151
+ errno = 0;
152
+ std::ifstream stream(filename.c_str(), std::ios::binary | std::ios::in);
153
+ if (!stream.is_open()) {
154
+ return false;
155
+ }
156
+
157
+ stream.seekg(0, std::ios::end);
158
+ result.reserve(stream.tellg());
159
+ stream.seekg(0, std::ios::beg);
160
+
161
+ result.assign(
162
+ (std::istreambuf_iterator<char>(stream)),
163
+ std::istreambuf_iterator<char>());
164
+
165
+ *outData = result;
166
+
167
+ return true;
168
+ }
169
+
170
+ inline std::string loadFile(const std::string& infilename) {
171
+ std::string result;
172
+ loadFile(infilename, &result);
173
+
174
+ return result;
175
+ }
176
+
177
+ int CWaveFileRead::readPCM(const char* szFileName) {
178
+ std::string fileData;
179
+ if (loadFile(std::string(szFileName), &fileData) != true) {
180
+ return -1;
181
+ }
182
+
183
+ const uint8_t* waveData = reinterpret_cast<const uint8_t*>(fileData.data());
184
+ size_t waveDataSize = fileData.length();
185
+ const uint8_t* waveEnd = waveData + waveDataSize;
186
+
187
+
188
+ // Locate RIFF 'WAVE'
189
+ const RiffChunk* riffChunk = FindChunk(waveData, waveDataSize, MAKEFOURCC('R', 'I', 'F', 'F'));
190
+ if (!riffChunk || riffChunk->chunkSize < 4) {
191
+ return -1;
192
+ }
193
+
194
+ const RiffHeader* riffHeader = reinterpret_cast<const RiffHeader*>(riffChunk);
195
+ if (riffHeader->fileTag != MAKEFOURCC('W', 'A', 'V', 'E')) {
196
+ return -1;
197
+ }
198
+
199
+ // Locate 'fmt '
200
+ const uint8_t* ptr = reinterpret_cast<const uint8_t*>(riffHeader) + sizeof(RiffHeader);
201
+ if ((ptr + sizeof(RiffChunk)) > waveEnd) {
202
+ return -1;
203
+ }
204
+
205
+ const RiffChunk* fmtChunk = FindChunk(ptr, riffHeader->chunkSize, MAKEFOURCC('f', 'm', 't', ' '));
206
+ if (!fmtChunk || fmtChunk->chunkSize < sizeof(waveFormat_basic)) {
207
+ return -1;
208
+ }
209
+
210
+ ptr = reinterpret_cast<const uint8_t*>(fmtChunk) + sizeof(RiffChunk);
211
+ if (ptr + fmtChunk->chunkSize > waveEnd) {
212
+ return -1;
213
+ }
214
+
215
+ const waveFormat_basic_nopcm* wf = reinterpret_cast<const waveFormat_basic_nopcm*>(ptr);
216
+
217
+ if (!(wf->formatTag == WAVE_FORMAT_PCM || wf->formatTag == WAVE_FORMAT_IEEE_FLOAT)) {
218
+ if (wf->formatTag == WAVE_FORMAT_EXTENSIBLE) {
219
+ printf("WAVE_FORMAT_EXTENSIBLE is not supported. Please convert\n");
220
+ }
221
+
222
+ return -1;
223
+ }
224
+
225
+ ptr = reinterpret_cast<const uint8_t*>(riffHeader) + sizeof(RiffHeader);
226
+ if ((ptr + sizeof(RiffChunk)) > waveEnd) {
227
+ return -1;
228
+ }
229
+
230
+ const RiffChunk* dataChunk = FindChunk(ptr, riffChunk->chunkSize, MAKEFOURCC('d', 'a', 't', 'a'));
231
+ if (!dataChunk || !dataChunk->chunkSize) {
232
+ return -1;
233
+ }
234
+
235
+
236
+ ptr = reinterpret_cast<const uint8_t*>(dataChunk) + sizeof(RiffChunk);
237
+ if (ptr + dataChunk->chunkSize > waveEnd) {
238
+ return -1;
239
+ }
240
+
241
+ m_WaveData = std::unique_ptr<uint8_t[]>(new uint8_t[dataChunk->chunkSize]);
242
+ m_WaveDataSize = dataChunk->chunkSize;
243
+ memcpy(m_WaveData.get(), ptr, dataChunk->chunkSize);
244
+ if (wf->formatTag == WAVE_FORMAT_PCM) {
245
+ memcpy(&m_WaveFormatEx, reinterpret_cast<const waveFormat_basic*>(wf), sizeof(waveFormat_basic));
246
+ m_WaveFormatEx.cbSize = 0;
247
+ } else {
248
+ memcpy(&m_WaveFormatEx, reinterpret_cast<const waveFormat_ext*>(wf), sizeof(waveFormat_ext));
249
+ }
250
+
251
+ m_nNumSamples = m_WaveDataSize / (m_WaveFormatEx.nBlockAlign / m_WaveFormatEx.nChannels);
252
+
253
+ return 0;
254
+ }
255
+
256
+ CWaveFileWrite::CWaveFileWrite(std::string wavFile, uint32_t samplesPerSec, uint32_t numChannels,
257
+ uint16_t bitsPerSample, bool isFloat)
258
+ :m_wavFile(wavFile) {
259
+ wfx.wFormatTag = isFloat ? WAVE_FORMAT_IEEE_FLOAT : WAVE_FORMAT_PCM;
260
+ wfx.nChannels = static_cast<uint16_t>(numChannels);
261
+ wfx.nSamplesPerSec = samplesPerSec;
262
+ wfx.nBlockAlign = static_cast<uint16_t>((numChannels * bitsPerSample) / 8);
263
+ wfx.nAvgBytesPerSec = samplesPerSec * wfx.nBlockAlign;
264
+ wfx.wBitsPerSample = bitsPerSample;
265
+ wfx.cbSize = 0;
266
+
267
+ m_validState = true;
268
+ }
269
+
270
+ CWaveFileWrite::~CWaveFileWrite() {
271
+ if (m_commitDone == false)
272
+ commitFile();
273
+
274
+ if (m_fp) {
275
+ fclose(m_fp);
276
+ m_fp = nullptr;
277
+ }
278
+ }
279
+
280
+ bool CWaveFileWrite::initFile() {
281
+ if (!m_fp) {
282
+ errno = 0;
283
+ m_fp = fopen(m_wavFile.c_str(), "wb");
284
+ if (!m_fp)
285
+ return false;
286
+
287
+ int64_t offset = sizeof(RiffHeader) + sizeof(RiffChunk) +
288
+ sizeof(waveFormat_basic) + sizeof(RiffChunk);
289
+ if (fseek(m_fp, static_cast<long>(offset), SEEK_SET) != 0) {
290
+ fclose(m_fp);
291
+ m_fp = nullptr;
292
+ return false;
293
+ }
294
+ }
295
+
296
+ return true;
297
+ }
298
+
299
+ bool CWaveFileWrite::writeChunk(const void* data, uint32_t len) {
300
+ if (!m_validState) {
301
+ return false;
302
+ }
303
+
304
+ if (!initFile()) {
305
+ return false;
306
+ }
307
+
308
+ size_t written = fwrite(data, len, 1, m_fp);
309
+ if (written != 1)
310
+ return false;
311
+
312
+ m_cumulativeCount += len;
313
+ return true;
314
+ }
315
+
316
+ bool CWaveFileWrite::commitFile() {
317
+ if (!m_validState)
318
+ return false;
319
+
320
+ if (!m_fp)
321
+ return false;
322
+
323
+ // pull fp to start of file to write headers.
324
+ fseek(m_fp, 0, SEEK_SET);
325
+
326
+ // write the riff chunk header
327
+ uint32_t fmtChunkSize = sizeof(waveFormat_basic);
328
+ RiffHeader riffHeader;
329
+ riffHeader.chunkId = MAKEFOURCC('R', 'I', 'F', 'F');
330
+ riffHeader.chunkSize = 4 + sizeof(RiffChunk) + sizeof(RiffChunk) + fmtChunkSize + m_cumulativeCount;
331
+ riffHeader.fileTag = MAKEFOURCC('W', 'A', 'V', 'E');
332
+ if (fwrite(&riffHeader, sizeof(riffHeader), 1, m_fp) != 1)
333
+ return false;
334
+
335
+ // fmt riff chunk
336
+ RiffChunk fmtChunk;
337
+ fmtChunk.chunkId = MAKEFOURCC('f', 'm', 't', ' ');
338
+ fmtChunk.chunkSize = sizeof(waveFormat_basic);
339
+ if (fwrite(&fmtChunk, sizeof(RiffChunk), 1, m_fp) != 1)
340
+ return false;
341
+
342
+ // fixme: try using WAVEFORMATEX for size
343
+ if (fwrite(&wfx, sizeof(waveFormat_basic), 1, m_fp) != 1)
344
+ return false;
345
+
346
+ // data riff chunk
347
+ RiffChunk dataChunk;
348
+ dataChunk.chunkId = MAKEFOURCC('d', 'a', 't', 'a');
349
+ dataChunk.chunkSize = m_cumulativeCount;
350
+ if (fwrite(&dataChunk, sizeof(RiffChunk), 1, m_fp) != 1)
351
+ return false;
352
+
353
+ fclose(m_fp);
354
+ m_fp = nullptr;
355
+
356
+ m_commitDone = true;
357
+ m_validState = false;
358
+ return true;
359
+ }
360
+
361
+ std::vector<float>* CWaveFileRead::GetFloatVector() {
362
+ if (!m_floatWaveData.size()) (void) GetFloatPCMData();
363
+ return &m_floatWaveData;
364
+ }
365
+
366
+ std::map<std::string, std::unique_ptr<CWaveFileRead>> read_file_cache;
367
+ bool ReadWavFile(const std::string& filename, uint32_t expected_sample_rate,
368
+ std::vector<float>** data, unsigned* original_num_samples,std::vector<int>* file_end_offset,
369
+ int align_samples) {
370
+ std::vector<std::string> files;
371
+
372
+ const std::string kDelimiter = ";";
373
+
374
+ auto delim = filename.find(kDelimiter);
375
+ if (delim != std::string::npos) {
376
+ unsigned int start = 0;
377
+ do
378
+ {
379
+ files.push_back(filename.substr(start, delim - start));
380
+ start = delim + 1;
381
+ } while ((delim = filename.find(kDelimiter, delim+1)) != std::string::npos);
382
+ if (start < filename.length()) files.push_back(filename.substr(start));
383
+ } else {
384
+ files.push_back(filename);
385
+ }
386
+
387
+ *original_num_samples = 0;
388
+
389
+ int offset = 0;
390
+ std::vector<float>* ret = nullptr;
391
+ for (auto& file: files) {
392
+ CWaveFileRead *wave_file = nullptr;
393
+ auto cached_file = read_file_cache.find(file);
394
+ if (cached_file == read_file_cache.end()) {
395
+ wave_file = new CWaveFileRead(file);
396
+ read_file_cache.emplace(file, std::unique_ptr<CWaveFileRead>(wave_file));
397
+ } else {
398
+ wave_file = cached_file->second.get();
399
+ }
400
+
401
+ if (wave_file->isValid() == false) {
402
+ delete ret;
403
+ *data = nullptr;
404
+ return false;
405
+ }
406
+
407
+ #ifndef ENABLE_PERF_DUMP
408
+ std::cout << "Total number of samples: " << wave_file->GetNumSamples() << std::endl;
409
+ std::cout << "Size in bytes: " << wave_file->GetRawPCMDataSizeInBytes() << std::endl;
410
+ std::cout << "Sample rate: " << wave_file->GetSampleRate() << std::endl;
411
+
412
+ auto bits_per_sample = wave_file->GetBitsPerSample();
413
+ std::cout << "Bits/sample: " << bits_per_sample << std::endl;
414
+ #endif // ENABLE_PERF_DUMP
415
+
416
+ if (wave_file->GetSampleRate() != expected_sample_rate) {
417
+ std::cout << "Sample rate mismatch" << std::endl;
418
+ delete ret;
419
+ *data = nullptr;
420
+ return false;
421
+ }
422
+ if (wave_file->GetWaveFormat().nChannels != 1) {
423
+ std::cout << "Channel count needs to be 1" << std::endl;
424
+ delete ret;
425
+ *data = nullptr;
426
+ return false;
427
+ }
428
+
429
+ *original_num_samples += wave_file->GetNumSamples();
430
+
431
+ int num_samples;
432
+ if (align_samples != -1) {
433
+ uint32_t pad = align_samples - (wave_file->GetNumSamples() % align_samples);
434
+ num_samples = wave_file->GetNumSamples() + pad;
435
+ } else {
436
+ num_samples = wave_file->GetNumSamples();
437
+ }
438
+
439
+ if (file_end_offset) {
440
+ offset += num_samples;
441
+ file_end_offset->push_back(offset);
442
+ }
443
+
444
+ if (files.size() > 1) {
445
+ // If using reset, ignore cache
446
+ // Vector is not resized here, will be resized at end
447
+ auto local = wave_file->GetFloatVector();
448
+ if (!ret) {
449
+ ret = new std::vector<float>();
450
+ *data = ret;
451
+ }
452
+ ret->insert(ret->end(), local->begin(), local->end());
453
+ } else {
454
+ *data = wave_file->GetFloatVector();
455
+ // Align if using multiple inputs
456
+ (*data)->resize(num_samples, 0.f);
457
+ }
458
+ }
459
+
460
+ if (files.size() > 1 && align_samples != -1) {
461
+ // Align if using multiple inputs
462
+ uint32_t pad = ret->size() % align_samples;
463
+ if (pad) ret->resize(ret->size() + pad, 0.f);
464
+ }
465
+ return true;
466
+ }