File size: 2,539 Bytes
258fd02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import json
import torch
from tqdm import tqdm
import torchaudio
import librosa
import os
import math
import numpy as np
from tools.get_bsrnnvae import get_bsrnnvae
import tools.torch_tools as torch_tools

class Tango:
    """Wrapper around the VAE returned by ``get_bsrnnvae`` for audio reconstruction.

    The visible code only loads and runs that VAE; no conditioning model or
    sampling loop is involved.
    """

    def __init__(self, device="cuda:0"):
        """Load the VAE in eval mode and move it to *device*.

        Args:
            device: torch device string; the VAE lives here and inputs are
                moved here before inference.
        """
        # Rate (Hz) all audio is resampled to before entering the VAE.
        self.sample_rate = 44100
        self.device = device

        self.vae = get_bsrnnvae()
        self.vae = self.vae.eval().to(device)

    def sound2sound_generate_longterm(self, fname, batch_size=1, duration=15.36, steps=200, disable_progress=False):
        """Reconstruct the opening window of an audio file through the VAE.

        The file is loaded and resampled to ``self.sample_rate`` if needed.
        Its first channel is cropped, or zero-padded on the right, to exactly
        ``int(duration * 2 * sample_rate)`` samples, and that window is passed
        through the VAE under ``torch.no_grad``.

        Args:
            fname: path of the audio file, read with ``torchaudio.load``.
            batch_size: unused; kept for interface compatibility.
            duration: half the length, in seconds, of the processed window.
            steps: unused; kept for interface compatibility.
            disable_progress: unused; kept for interface compatibility.

        Returns:
            CPU tensor of shape ``(2, T)``, where T is the length of the VAE
            output. Row 0 is the input window, zero-padded or trimmed to T.
            Row 1 is the VAE reconstruction.
        """
        sr = self.sample_rate
        # The window spans twice ``duration`` seconds of audio.
        segment_len = int(duration * 2 * sr)
        with torch.no_grad():
            samples, fs = torchaudio.load(fname)
            if fs != sr:
                samples = torchaudio.functional.resample(samples, fs, sr)

            # Keep channel 0 as a (1, T) tensor, crop to the window, then
            # right-pad with zeros when the file is shorter than the window.
            samples = samples[[0], :segment_len]
            samples = torch.nn.functional.pad(samples, (0, segment_len - samples.shape[-1]))
            # Move only the cropped window to the device, not the full file.
            samples = samples.to(self.device)

            # Add a channel axis for the VAE call, then drop it from the result.
            audio = self.vae(samples[:, None, :])[:, 0, :]

            # The VAE output length may differ from its input length; align
            # the reference row to it so the two rows can be stacked.
            target_len = audio.shape[-1]
            samples = samples[:, :target_len]
            samples = torch.nn.functional.pad(samples, (0, target_len - samples.shape[-1]))

            output = torch.cat([samples.detach().cpu(), audio.detach().cpu()], 0)
        return output