Jmica commited on
Commit
e6ed7af
·
1 Parent(s): 404aee9

Upload trainset_preprocess_pipeline_print.py

Browse files
Files changed (1) hide show
  1. trainset_preprocess_pipeline_print.py +139 -0
trainset_preprocess_pipeline_print.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os, multiprocessing
2
+ from scipy import signal
3
+
4
+ now_dir = os.getcwd()
5
+ sys.path.append(now_dir)
6
+
7
+ inp_root = sys.argv[1]
8
+ sr = int(sys.argv[2])
9
+ n_p = int(sys.argv[3])
10
+ exp_dir = sys.argv[4]
11
+ noparallel = sys.argv[5] == "True"
12
+ import numpy as np, os, traceback
13
+ from slicer2 import Slicer
14
+ import librosa, traceback
15
+ from scipy.io import wavfile
16
+ import multiprocessing
17
+ from my_utils import load_audio
18
+
19
+ mutex = multiprocessing.Lock()
20
+ f = open("%s/preprocess.log" % exp_dir, "a+")
21
+
22
+
23
+ def println(strr):
24
+ mutex.acquire()
25
+ print(strr)
26
+ f.write("%s\n" % strr)
27
+ f.flush()
28
+ mutex.release()
29
+
30
+
31
+ class PreProcess:
32
+ def __init__(self, sr, exp_dir):
33
+ self.slicer = Slicer(
34
+ sr=sr,
35
+ threshold=-42,
36
+ min_length=1500,
37
+ min_interval=400,
38
+ hop_size=15,
39
+ max_sil_kept=500,
40
+ )
41
+ self.sr = sr
42
+ self.bh, self.ah = signal.butter(N=5, Wn=48, btype="high", fs=self.sr)
43
+ self.per = 3.0
44
+ self.overlap = 0.3
45
+ self.tail = self.per + self.overlap
46
+ self.max = 0.9
47
+ self.alpha = 0.75
48
+ self.exp_dir = exp_dir
49
+ self.gt_wavs_dir = "%s/0_gt_wavs" % exp_dir
50
+ self.wavs16k_dir = "%s/1_16k_wavs" % exp_dir
51
+ os.makedirs(self.exp_dir, exist_ok=True)
52
+ os.makedirs(self.gt_wavs_dir, exist_ok=True)
53
+ os.makedirs(self.wavs16k_dir, exist_ok=True)
54
+
55
+ def norm_write(self, tmp_audio, idx0, idx1):
56
+ tmp_max = np.abs(tmp_audio).max()
57
+ if tmp_max > 2.5:
58
+ print("%s-%s-%s-filtered" % (idx0, idx1, tmp_max))
59
+ return
60
+ tmp_audio = (tmp_audio / tmp_max * (self.max * self.alpha)) + (
61
+ 1 - self.alpha
62
+ ) * tmp_audio
63
+ wavfile.write(
64
+ "%s/%s_%s.wav" % (self.gt_wavs_dir, idx0, idx1),
65
+ self.sr,
66
+ tmp_audio.astype(np.float32),
67
+ )
68
+ tmp_audio = librosa.resample(
69
+ tmp_audio, orig_sr=self.sr, target_sr=16000
70
+ ) # , res_type="soxr_vhq"
71
+ wavfile.write(
72
+ "%s/%s_%s.wav" % (self.wavs16k_dir, idx0, idx1),
73
+ 16000,
74
+ tmp_audio.astype(np.float32),
75
+ )
76
+
77
+ def pipeline(self, path, idx0):
78
+ try:
79
+ audio = load_audio(path, self.sr)
80
+ # zero phased digital filter cause pre-ringing noise...
81
+ # audio = signal.filtfilt(self.bh, self.ah, audio)
82
+ audio = signal.lfilter(self.bh, self.ah, audio)
83
+
84
+ idx1 = 0
85
+ for audio in self.slicer.slice(audio):
86
+ i = 0
87
+ while 1:
88
+ start = int(self.sr * (self.per - self.overlap) * i)
89
+ i += 1
90
+ if len(audio[start:]) > self.tail * self.sr:
91
+ tmp_audio = audio[start : start + int(self.per * self.sr)]
92
+ self.norm_write(tmp_audio, idx0, idx1)
93
+ idx1 += 1
94
+ else:
95
+ tmp_audio = audio[start:]
96
+ idx1 += 1
97
+ break
98
+ self.norm_write(tmp_audio, idx0, idx1)
99
+ println("%s->Suc." % path)
100
+ except:
101
+ println("%s->%s" % (path, traceback.format_exc()))
102
+
103
+ def pipeline_mp(self, infos):
104
+ for path, idx0 in infos:
105
+ self.pipeline(path, idx0)
106
+
107
+ def pipeline_mp_inp_dir(self, inp_root, n_p):
108
+ try:
109
+ infos = [
110
+ ("%s/%s" % (inp_root, name), idx)
111
+ for idx, name in enumerate(sorted(list(os.listdir(inp_root))))
112
+ ]
113
+ if noparallel:
114
+ for i in range(n_p):
115
+ self.pipeline_mp(infos[i::n_p])
116
+ else:
117
+ ps = []
118
+ for i in range(n_p):
119
+ p = multiprocessing.Process(
120
+ target=self.pipeline_mp, args=(infos[i::n_p],)
121
+ )
122
+ ps.append(p)
123
+ p.start()
124
+ for i in range(n_p):
125
+ ps[i].join()
126
+ except:
127
+ println("Fail. %s" % traceback.format_exc())
128
+
129
+
130
+ def preprocess_trainset(inp_root, sr, n_p, exp_dir):
131
+ pp = PreProcess(sr, exp_dir)
132
+ println("start preprocess")
133
+ println(sys.argv)
134
+ pp.pipeline_mp_inp_dir(inp_root, n_p)
135
+ println("end preprocess")
136
+
137
+
138
+ if __name__ == "__main__":
139
+ preprocess_trainset(inp_root, sr, n_p, exp_dir)