import math

import torch
import torchaudio

import tools.torch_tools as torch_tools
from get_melvaehifigan48k import build_pretrained_models

class Tango:
    def __init__(self, device="cuda:0"):
        self.sample_rate = 48000
        self.device = device
        # Pretrained 48 kHz mel-VAE (with HiFi-GAN vocoder) and STFT front end.
        self.vae, self.stft = build_pretrained_models()
        self.vae = self.vae.eval().to(device)
        self.stft = self.stft.eval().to(device)

    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)
        # The original called an undefined self.vocoder here; the pretrained
        # VAE exposes its vocoder via decode_to_waveform (returns a NumPy array).
        waveform = self.vae.decode_to_waveform(mel_spectrogram)
        waveform = torch.from_numpy(waveform)
        # Always cast to float32: negligible overhead and compatible with bfloat16.
        waveform = waveform.cpu().float()
        return waveform

    def sound2sound_generate_longterm(self, fname, batch_size=1, duration=10.24, steps=200, disable_progress=False):
        """Round-trip an audio file through the mel-VAE (no text conditioning)."""
        # Latent frame count implied by the duration (100 mel frames/s, /8
        # downsampling); not used on this reconstruction path.
        num_frames = math.ceil(duration * 100. / 8)
        with torch.no_grad():
            orig_samples, fs = torchaudio.load(fname)
            # Resample to the model rate first, so all length arithmetic below
            # is in 48 kHz samples (the original checked lengths at the source
            # rate before resampling, which miscounts when fs != 48000).
            if fs != 48000:
                orig_samples = torchaudio.functional.resample(orig_samples, fs, 48000)
                fs = 48000
            # Tile the clip until it covers the requested duration.
            if orig_samples.shape[-1] < int(duration * 48000):
                orig_samples = orig_samples.repeat(1, math.ceil(int(duration * 48000) / float(orig_samples.shape[-1])))
            # Append half a duration of trailing silence as padding.
            orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], int(duration * fs) // 2, dtype=orig_samples.dtype, device=orig_samples.device)], -1).to(self.device)
            # Keep one channel; the extra 480 samples (10 ms) cover the STFT window.
            resampled_audios = orig_samples[[0], 0:int(duration * 48000) + 480].clamp(-1, 1)
            orig_samples = orig_samples[[0], 0:int(duration * 48000)]
            # Mel features for the reconstruction path.
            mel, _, _ = torch_tools.wav_to_fbank2(resampled_audios, -1, fn_STFT=self.stft)
            mel = mel.unsqueeze(1).to(self.device)
            # Decode straight back to a waveform (decode_to_waveform returns NumPy).
            audio = self.vae.decode_to_waveform(mel)
            audio = torch.from_numpy(audio)
            # Length-align the original to the reconstruction: zero-pad or truncate.
            if orig_samples.shape[-1] < audio.shape[-1]:
                orig_samples = torch.cat([orig_samples, torch.zeros(orig_samples.shape[0], audio.shape[-1] - orig_samples.shape[-1], dtype=orig_samples.dtype, device=orig_samples.device)], -1)
            else:
                orig_samples = orig_samples[:, 0:audio.shape[-1]]
            # Stack [original, reconstruction] along the channel dimension.
            output = torch.cat([orig_samples.detach().cpu(), audio.detach().cpu()], 0)
        return output
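

# --- Usage sketch (not part of the original file; added for illustration). ---
# Drives the round-trip reconstruction above and saves the stacked
# [original, reconstruction] pair. "input.wav" and "reconstruction.wav" are
# hypothetical paths.
if __name__ == "__main__":
    tango = Tango(device="cuda:0")
    output = tango.sound2sound_generate_longterm("input.wav", duration=10.24)
    # output has shape (2, num_samples): row 0 is the padded/truncated
    # original, row 1 is the VAE/vocoder reconstruction.
    torchaudio.save("reconstruction.wav", output, sample_rate=tango.sample_rate)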