# denoise_and_diarization/utils/denoise_pipeline.py
# Last commit: 3ff6c9f ("train_fix") by agorlanov.
import librosa
import torch
from demucs.apply import apply_model
from demucs.pretrained import get_model
from scipy.io.wavfile import write
# Signature of the pretrained Demucs checkpoint used for vocal separation.
_DEMUCS_MODEL_SIG = 'cfa93e08'
# Loaded once at import time and shared by every denoise() call.
demucs_model = get_model(_DEMUCS_MODEL_SIG)
def denoise(filename: str, device: str, out_filename: str = 'denoise.wav', out_sr: int = 48000) -> str:
    """Extract the vocal stem of an audio file with Demucs and save it as a mono WAV.

    The audio is loaded at the model's sample rate and normalized with the
    mean/std of its mono mixdown. It is then separated with ``apply_model`` and
    de-normalized with the same statistics. The vocals source is downmixed to
    mono, resampled to ``out_sr``, attenuated only if it would exceed full
    scale, and written with ``scipy.io.wavfile.write``.

    Args:
        filename: Path of the input audio file (anything ``librosa.load`` reads).
        device: Torch device the model runs on, e.g. ``'cuda'`` or ``'cpu'``.
        out_filename: Path of the WAV file to write.
        out_sr: Sample rate of the written file (defaults to 48 kHz).

    Returns:
        ``out_filename``.
    """
    # Load at the rate the model was trained on. Fall back to 44.1 kHz if the
    # model object does not expose it.
    model_sr = getattr(demucs_model, 'samplerate', 44100)
    wav_ref, _ = librosa.load(filename, mono=False, sr=model_sr)
    wav = torch.tensor(wav_ref)
    if wav.dim() == 1:
        # Mono input: duplicate it into two channels for the stereo model.
        wav = torch.stack([wav, wav])
    elif wav.shape[0] > 2:
        # More than two channels: keep the first two, as demucs' convert_audio does.
        wav = wav[:2]

    # Normalization statistics come from the mono mixdown. The same mean/std
    # must be used to undo the normalization, or the output level changes.
    ref = wav.mean(0)
    ref_mean = ref.mean()
    # Clamp so that silent input (std == 0) does not turn everything into NaN.
    ref_std = ref.std().clamp_min(1e-8)
    wav = (wav - ref_mean) / ref_std

    sources = apply_model(
        demucs_model, wav[None], device=device, shifts=1, split=True, overlap=0.1, progress=True, num_workers=0
    )[0]
    sources = sources * ref_std + ref_mean

    # Pick the vocals stem by name when the model lists its sources; otherwise
    # keep the previous assumption that vocals is the last source.
    source_names = list(getattr(demucs_model, 'sources', []))
    vocal_idx = source_names.index('vocals') if 'vocals' in source_names else -1
    vocal_wav = sources[vocal_idx].cpu().numpy()

    vocal_wav = librosa.to_mono(vocal_wav)
    if out_sr != model_sr:
        vocal_wav = librosa.resample(vocal_wav, orig_sr=model_sr, target_sr=out_sr)

    # Peak-limit as the last step, so resampling overshoot cannot push samples
    # past full scale. Signals already below ~0.99 peak are left untouched.
    vocal_wav = vocal_wav / max(1.01 * float(abs(vocal_wav).max()), 1.0)
    write(out_filename, out_sr, vocal_wav)
    return out_filename
if __name__ == '__main__':
    # Manual smoke run on a local sample file. Uses the GPU when one is available.
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    denoise('../oxx.wav', device)