Spaces:
Sleeping
Sleeping
import os | |
import sys | |
import time | |
import torchaudio | |
import torch | |
from torch import nn | |
from scipy import signal | |
from scipy.io import wavfile | |
import numpy as np | |
import multiprocessing | |
from pydub import AudioSegment | |
# Put the current working directory on the import path so that packages
# living there (such as `rvc`, imported just below) can be found.
now_directory = os.getcwd()
sys.path.append(now_directory)
# Use the "spawn" start method for worker processes; force=True lets this
# override a start method that has already been set.
multiprocessing.set_start_method("spawn", force=True)
from rvc.lib.utils import load_audio | |
from rvc.train.slicer import Slicer | |
# Constants
# Chunking (see PreProcess.process_audio): chunks of `per` seconds start
# `per - OVERLAP` seconds apart, so neighbouring chunks share OVERLAP seconds.
OVERLAP: float = 0.3
# Normalization blend: a chunk is rescaled so its peak reaches
# MAX_AMPLITUDE * ALPHA, then (1 - ALPHA) of the unscaled chunk is added back.
MAX_AMPLITUDE: float = 0.9
ALPHA: float = 0.75
# Cutoff (Hz) of the Butterworth high-pass applied before slicing.
HIGH_PASS_CUTOFF: int = 48
# Sample rate (Hz) of the secondary 16 kHz copy written for every chunk.
SAMPLE_RATE_16K: int = 16_000
class PreProcess:
    """Slice, normalize and resample the audio files of one training set.

    Each input file is high-pass filtered, split on silence by ``Slicer``, cut
    into overlapping chunks of ``per`` seconds, normalized, and written twice:
    at ``sr`` into ``<exp_dir>/sliced_audios`` and resampled to
    ``SAMPLE_RATE_16K`` into ``<exp_dir>/sliced_audios_16k``. Output files are
    named ``{idx0}_{idx1}.wav``, where ``idx0`` identifies the input file and
    ``idx1`` counts the chunks written for it.
    """

    def __init__(self, sr: int, exp_dir: str, per: float):
        """Configure the slicer, high-pass filter, resampler and output folders.

        Args:
            sr: Sample rate in Hz used to load, filter, slice and write audio.
            exp_dir: Experiment directory; the two output folders are created
                inside it if they do not exist.
            per: Chunk length in seconds. Chunk starts advance by
                ``per - OVERLAP`` seconds, so ``per`` must exceed ``OVERLAP``.

        Raises:
            ValueError: If ``per <= OVERLAP``; the chunking loop would otherwise
                never advance through a long slice.
        """
        if per <= OVERLAP:
            raise ValueError(
                f"per must be greater than OVERLAP ({OVERLAP} s), got {per}"
            )
        self.slicer = Slicer(
            sr=sr,
            threshold=-42,
            min_length=1500,
            min_interval=400,
            hop_size=15,
            max_sil_kept=500,
        )
        self.sr = sr
        # 5th-order Butterworth high-pass at HIGH_PASS_CUTOFF Hz.
        self.b_high, self.a_high = signal.butter(
            N=5, Wn=HIGH_PASS_CUTOFF, btype="high", fs=self.sr
        )
        self.per = per
        self.exp_dir = exp_dir
        self.device = "cpu"
        # Built once: its configuration depends only on the two sample rates,
        # so there is no reason to rebuild it for every chunk.
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=self.sr, new_freq=SAMPLE_RATE_16K
        ).to(self.device)
        self.gt_wavs_dir = os.path.join(exp_dir, "sliced_audios")
        self.wavs16k_dir = os.path.join(exp_dir, "sliced_audios_16k")
        os.makedirs(self.gt_wavs_dir, exist_ok=True)
        os.makedirs(self.wavs16k_dir, exist_ok=True)

    def _normalize_audio(self, audio: torch.Tensor):
        """Return a peak-normalized blend of ``audio``, or ``None`` to skip it.

        The output is ``audio / peak * (MAX_AMPLITUDE * ALPHA) + (1 - ALPHA) * audio``,
        where ``peak`` is the largest absolute sample. ``None`` is returned for
        an empty chunk, for a zero or non-finite peak (normalizing would divide
        by zero or spread NaN into the written file), and for a peak above 2.5,
        which this pipeline filters out.
        """
        if audio.numel() == 0:
            return None
        tmp_max = torch.abs(audio).max()
        if not torch.isfinite(tmp_max) or tmp_max == 0 or tmp_max > 2.5:
            return None
        return (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio

    def _write_audio(self, audio: torch.Tensor, filename: str, sr: int):
        """Write ``audio`` to ``filename`` as a float32 WAV at ``sr`` Hz."""
        audio = audio.cpu().numpy()
        wavfile.write(filename, sr, audio.astype(np.float32))

    def process_audio_segment(self, audio_segment: torch.Tensor, idx0: int, idx1: int):
        """Normalize one chunk and write it at ``sr`` and at 16 kHz.

        A chunk rejected by ``_normalize_audio`` is reported as
        ``"{idx0}-{idx1}-filtered"`` and nothing is written for it.
        """
        normalized_audio = self._normalize_audio(audio_segment)
        if normalized_audio is None:
            print(f"{idx0}-{idx1}-filtered")
            return
        filename = f"{idx0}_{idx1}.wav"
        self._write_audio(
            normalized_audio, os.path.join(self.gt_wavs_dir, filename), self.sr
        )
        audio_16k = self.resampler(normalized_audio.float())
        self._write_audio(
            audio_16k, os.path.join(self.wavs16k_dir, filename), SAMPLE_RATE_16K
        )

    def _iter_chunks(self, audio_segment: torch.Tensor):
        """Yield the overlapping chunks cut from one slice.

        Chunk ``i`` starts at ``int(sr * (per - OVERLAP) * i)`` samples. While
        more than ``(per + OVERLAP) * sr`` samples remain from a start, a chunk
        of ``int(per * sr)`` samples is yielded. The last chunk is everything
        from its start to the end of the slice.
        """
        hop = self.sr * (self.per - OVERLAP)
        chunk_len = int(self.per * self.sr)
        tail_limit = (self.per + OVERLAP) * self.sr
        i = 0
        while True:
            remaining = audio_segment[int(hop * i):]
            if len(remaining) <= tail_limit:
                yield remaining
                return
            yield remaining[:chunk_len]
            i += 1

    def process_audio(self, path: str, idx0: int):
        """Load, filter, slice and chunk one file, writing every chunk.

        Any exception is reported and swallowed so that one bad file does not
        abort the whole preprocessing run.
        """
        try:
            audio = load_audio(path, self.sr)
            # High-pass filter, then cast to float32 before slicing.
            audio = signal.lfilter(self.b_high, self.a_high, audio).astype(
                np.float32
            )
            idx1 = 0
            for audio_segment in self.slicer.slice(audio):
                audio_segment = torch.tensor(audio_segment, device=self.device).float()
                for chunk in self._iter_chunks(audio_segment):
                    self.process_audio_segment(chunk, idx0, idx1)
                    idx1 += 1
        except Exception as error:
            print(f"An error occurred on {path} path: {error}")

    def process_audio_file(self, file_path_idx):
        """Process one ``(file_path, idx0)`` pair (pool worker entry point).

        ``.wav`` files are processed directly. Other formats are first
        converted to a temporary WAV with pydub's ``AudioSegment``. The
        temporary file is unique per call and is deleted afterwards.
        Conversion errors are reported like processing errors instead of
        propagating to the pool and aborting the whole run.
        """
        import tempfile

        file_path, idx0 = file_path_idx
        if os.path.splitext(file_path)[1].lower() == ".wav":
            self.process_audio(file_path, idx0)
            return
        fd, wav_path = tempfile.mkstemp(prefix=f"rvc_{idx0}_", suffix=".wav")
        os.close(fd)
        try:
            try:
                exported = AudioSegment.from_file(file_path).export(
                    wav_path, format="wav"
                )
                # export may hand back its still-open output handle; close it
                # so the WAV is flushed before being read back.
                if hasattr(exported, "close"):
                    exported.close()
            except Exception as error:
                print(f"An error occurred on {file_path} path: {error}")
                return
            self.process_audio(wav_path, idx0)
        finally:
            try:
                os.remove(wav_path)
            except OSError:
                pass  # best-effort cleanup of the temporary conversion
def preprocess_training_set(
    input_root: str,
    sr: int,
    num_processes: int,
    exp_dir: str,
    per: float,
):
    """Preprocess every supported audio file in ``input_root`` in parallel.

    Args:
        input_root: Directory scanned (non-recursively) for files ending in
            ``.wav``, ``.mp3``, ``.flac`` or ``.ogg`` (case-insensitive).
        sr: Sample rate in Hz passed to ``PreProcess``.
        num_processes: Number of worker processes in the pool.
        exp_dir: Experiment directory that receives the sliced outputs.
        per: Chunk length in seconds passed to ``PreProcess``.
    """
    start_time = time.perf_counter()
    pp = PreProcess(sr, exp_dir, per)
    print(f"Starting preprocess with {num_processes} processes...")
    # os.listdir returns entries in arbitrary order. Sorting makes each file's
    # idx0 (the prefix of its output filenames) reproducible across runs.
    # Indices are assigned over all entries, so they may skip numbers.
    files = [
        (os.path.join(input_root, f), idx)
        for idx, f in enumerate(sorted(os.listdir(input_root)))
        if f.lower().endswith((".wav", ".mp3", ".flac", ".ogg"))
    ]
    ctx = multiprocessing.get_context("spawn")
    with ctx.Pool(processes=num_processes) as pool:
        pool.map(pp.process_audio_file, files)
    elapsed_time = time.perf_counter() - start_time
    print(f"Preprocess completed in {elapsed_time:.2f} seconds.")
if __name__ == "__main__":
    # Usage: <script> <experiment_directory> <input_root> <sample_rate>
    #        <per_seconds> [num_processes]
    if len(sys.argv) < 5:
        sys.exit(
            f"Usage: {sys.argv[0]} <experiment_directory> <input_root> "
            "<sample_rate> <per_seconds> [num_processes]"
        )
    experiment_directory = sys.argv[1]
    input_root = sys.argv[2]
    sample_rate = int(sys.argv[3])
    # Despite its name, this is the chunk length in seconds (PreProcess.per).
    percentage = float(sys.argv[4])
    # Default to one worker per CPU when no process count is given.
    num_processes = (
        int(sys.argv[5]) if len(sys.argv) > 5 else multiprocessing.cpu_count()
    )
    preprocess_training_set(
        input_root,
        sample_rate,
        num_processes,
        experiment_directory,
        percentage,
    )