Spaces:
Running
on
Zero
Running
on
Zero
from ldm.data.preprocess.NAT_mel import MelNet | |
import os | |
from tqdm import tqdm | |
from glob import glob | |
import math | |
import pandas as pd | |
import logging | |
import math | |
import audioread | |
from tqdm.contrib.concurrent import process_map | |
import torch | |
import torch.nn as nn | |
import torchaudio | |
import numpy as np | |
from torch.distributed import init_process_group | |
from torch.utils.data import Dataset,DataLoader,DistributedSampler | |
import torch.multiprocessing as mp | |
from argparse import Namespace | |
from multiprocessing import Pool | |
import json | |
class tsv_dataset(Dataset):
    """Dataset yielding ``(audio_path, wav)`` pairs for audio files listed in TSV manifests.

    Each manifest must have an ``audio_path`` column. Audio is loaded with
    torchaudio, multi-channel audio is averaged down to one channel, and the
    result is resampled to ``sr``. When ``mode == 'pad'`` the waveform is also
    padded/cropped to a fixed length by :meth:`pad_wav`.
    """

    def __init__(self, tsv_path, sr, mode='none', hop_size=None, target_mel_length=None) -> None:
        """
        Args:
            tsv_path: a single TSV file, or a directory whose ``*.tsv`` files are
                all read and concatenated (in sorted file-name order).
            sr: target sample rate passed to ``torchaudio.functional.resample``.
            mode: ``'pad'`` enables fixed-length padding/cropping in
                ``__getitem__``; any other value leaves the waveform length as-is.
            hop_size: hop length in samples; only used by :meth:`pad_wav`.
            target_mel_length: target number of mel frames; only used by :meth:`pad_wav`.

        Raises:
            ValueError: if ``tsv_path`` is a directory containing no ``*.tsv`` file.
        """
        super().__init__()
        if os.path.isdir(tsv_path):
            # Sort so every process (e.g. each DDP rank) sees the same row order.
            files = sorted(glob(os.path.join(tsv_path, '*.tsv')))
            if not files:
                raise ValueError(f'no .tsv files found in directory {tsv_path}')
            df = pd.concat([pd.read_csv(file, sep='\t') for file in files])
        else:
            df = pd.read_csv(tsv_path, sep='\t')
        self.sr = sr
        self.mode = mode
        self.target_mel_length = target_mel_length
        self.hop_size = hop_size
        self.audio_paths = df['audio_path'].tolist()

    def __len__(self):
        return len(self.audio_paths)

    def pad_wav(self, wav):
        """Zero-pad or crop ``wav`` along its last axis to ``(target_mel_length + 1) * hop_size``.

        The extra frame is kept because, per the original author's note, the
        final mel is cropped by one frame (``mel = mel[:, :-1]``).

        Args:
            wav: waveform tensor; ``__getitem__`` passes shape ``(1, wav_len)``.

        Returns:
            ``wav`` itself when no target length is configured or the length
            already matches; otherwise a cropped view or a zero-padded copy with
            the same dtype/device as ``wav``.
        """
        wav_length = wav.shape[-1]
        assert wav_length > 100, "wav is too short, %s" % wav_length
        # Check for missing configuration *before* doing arithmetic with it:
        # computing first made the "no target" branch unreachable (None + 1 -> TypeError).
        if self.target_mel_length is None or self.hop_size is None:
            return wav
        segment_length = (self.target_mel_length + 1) * self.hop_size
        if wav_length == segment_length:
            return wav
        if wav_length > segment_length:
            return wav[..., :segment_length]
        padded = wav.new_zeros((*wav.shape[:-1], segment_length))
        padded[..., :wav_length] = wav
        return padded

    def __getitem__(self, index):
        """Return ``(audio_path, wav)`` with ``wav`` mono, resampled to ``self.sr``, and padded in 'pad' mode."""
        audio_path = self.audio_paths[index]
        wav, orisr = torchaudio.load(audio_path)
        if wav.shape[0] != 1:  # multi-channel -> mono: (C, wav_len) -> (1, wav_len)
            wav = wav.mean(0, keepdim=True)
        wav = torchaudio.functional.resample(wav, orig_freq=orisr, new_freq=self.sr)
        if self.mode == 'pad':
            # 'pad' mode requires both values; pad_wav would otherwise return wav unchanged.
            assert self.target_mel_length is not None
            assert self.hop_size is not None
            wav = self.pad_wav(wav)
        return audio_path, wav
def _derived_path(audio_path, root_suffix, file_suffix):
    """Map ``audio_path`` into a sibling output tree; return ``(dir_name, file_path)``.

    The first '/'-separated component gets ``root_suffix`` appended, the middle
    components are kept, and the last 4 characters of the file name (assumed
    to be an extension like ``.wav``) are replaced by ``file_suffix``.
    NOTE(review): for absolute paths the first component is '' so outputs land
    under a cwd-relative directory; kept as-is to preserve the existing layout.
    """
    psplits = audio_path.split('/')
    root, wav_name = psplits[0], psplits[-1]
    dir_name = os.path.join(root + root_suffix, *psplits[1:-1])
    return dir_name, os.path.join(dir_name, wav_name[:-4] + file_suffix)


def _fit_mel_length(mel_spec, mode, batch_max_length):
    """Tile short mels in 'tile' mode, then crop every mel to ``batch_max_length`` frames.

    ``mel_spec`` is indexed as ``(mel_bins, mel_len)``; ``mode`` must already be validated.
    """
    if mode == 'tile' and mel_spec.shape[1] <= batch_max_length:
        # Repeat enough times to exceed the target; the crop below trims the excess.
        n_repeat = math.ceil((batch_max_length + 1) / mel_spec.shape[1])
        mel_spec = np.tile(mel_spec, reps=(1, n_repeat))
    # 'none' and 'pad' need no extra work here ('pad' is done on the wav in the dataset).
    return mel_spec[:, :batch_max_length]


def process_audio_by_tsv(rank, args):
    """Resample and/or compute mel spectrograms for every file in ``args.tsv_path``.

    Runs as one worker (``rank``) of a possibly multi-GPU job. Outputs are ``.npy`` files
    written next to the input tree (see ``_derived_path``). Existing mel files are skipped.

    Args:
        rank: process rank; also selects the CUDA device ``cuda:{rank}``.
        args: Namespace with the audio/mel settings, ``tsv_path``, ``num_gpus``,
            ``mode``, ``save_resample``, ``save_mel``, ``dist_config`` and
            optionally ``num_workers`` (defaults to 16).

    Raises:
        ValueError: if ``save_mel`` is set and ``mode`` is not 'none', 'pad' or 'tile'.
    """
    # Validate up front instead of failing mid-run on the first short clip.
    if args.save_mel and args.mode not in ('none', 'pad', 'tile'):
        raise ValueError(f'mode:{args.mode} is not supported')
    distributed = args.num_gpus > 1
    if distributed:
        init_process_group(backend=args.dist_config['dist_backend'], init_method=args.dist_config['dist_url'],
                           world_size=args.dist_config['world_size'] * args.num_gpus, rank=rank)
    sr = args.audio_sample_rate
    dataset = tsv_dataset(args.tsv_path, sr=sr, mode=args.mode, hop_size=args.hop_size,
                          target_mel_length=args.batch_max_length)
    # NOTE: DistributedSampler pads so every rank gets the same count, so a few files
    # may be processed twice (the exists-check below usually skips the repeat).
    sampler = DistributedSampler(dataset, shuffle=False) if distributed else None
    # batch_size must be 1: clips have different lengths and cannot be stacked.
    loader = DataLoader(dataset, sampler=sampler, batch_size=1,
                        num_workers=getattr(args, 'num_workers', 16), drop_last=False)
    device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
    mel_net = MelNet(args.__dict__)
    mel_net.to(device)
    # DistributedDataParallel is not used: it refuses modules with no trainable
    # parameters, and inference needs no gradient sync anyway.
    loader = tqdm(loader) if rank == 0 else loader
    resample_suffix = f'_{sr}'
    mel_suffix = f'_mel{args.mode}{sr}nfft{args.fft_size}'
    for audio_paths, wavs in loader:
        wavs = wavs.to(device)
        for audio_path, wav in zip(audio_paths, wavs):
            if args.save_resample:
                resample_dir, resample_path = _derived_path(audio_path, resample_suffix, '_audio.npy')
                os.makedirs(resample_dir, exist_ok=True)
                np.save(resample_path, wav.cpu().numpy().squeeze(0))
            if args.save_mel:
                mel_dir, mel_path = _derived_path(audio_path, mel_suffix, '_mel.npy')
                if os.path.exists(mel_path):
                    continue
                # no_grad: pure inference; also lets .numpy() work if the output tracks grads.
                with torch.no_grad():
                    mel_spec = mel_net(wav).cpu().numpy().squeeze(0)  # (mel_bins, mel_len)
                mel_spec = _fit_mel_length(mel_spec, args.mode, args.batch_max_length)
                os.makedirs(mel_dir, exist_ok=True)
                np.save(mel_path, mel_spec)
    if distributed:
        torch.distributed.destroy_process_group()
def split_list(i_list, num):
    """Split a sliceable sequence into exactly ``num`` contiguous chunks.

    Each chunk holds ``ceil(len(i_list) / num)`` items, except that trailing
    chunks may be shorter or empty.

    Args:
        i_list: any sliceable sequence (list, tuple, str, ...).
        num: number of chunks to return; must be >= 1.

    Returns:
        A list of ``num`` slices of ``i_list``.

    Raises:
        ValueError: if ``num`` is less than 1.
    """
    if num < 1:
        raise ValueError(f'num must be >= 1, got {num}')
    # Previously `i_list / num` divided the list itself, which always raised TypeError.
    each_num = math.ceil(len(i_list) / num)
    return [i_list[i * each_num:(i + 1) * each_num] for i in range(num)]
def drop_bad_wav(item):
    """Check one manifest row; return its index if the audio is bad, else ``False``.

    A file is "bad" if audioread cannot open it, or if its reported duration is
    under 0.1 s.

    Args:
        item: ``(index, path)`` pair.

    Returns:
        ``index`` for a bad file, ``False`` otherwise. A bad file at index 0
        returns ``0``, which compares equal to ``False``, so callers must filter
        with ``is not False`` rather than ``!= False``.
    """
    index, path = item
    try:
        with audioread.audio_open(path) as f:
            # Comparison stays inside the try so an unusable duration also counts as corrupted.
            if f.duration < 0.1:
                return index
    except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
        print(f"corrupted wav:{path}")
        return index
    return False
def drop_bad_wavs(tsv_path, bad_wavs_path='bad_wavs.json'):
    """Remove rows with unreadable or too-short audio from a TSV manifest, in place.

    Each row's ``audio_path`` is checked with ``drop_bad_wav`` in parallel
    (16 worker processes). The dropped ``[index, audio_path]`` pairs are written
    to ``bad_wavs_path`` as JSON, and ``tsv_path`` is overwritten without them.

    Args:
        tsv_path: tab-separated manifest with an ``audio_path`` column.
        bad_wavs_path: where to write the JSON report of dropped rows
            (overwritten on every call).
    """
    df = pd.read_csv(tsv_path, sep='\t')
    # read_csv yields a RangeIndex, so each index label is also its position in item_list.
    item_list = list(zip(df.index, df['audio_path']))
    results = process_map(drop_bad_wav, item_list, max_workers=16, chunksize=16)
    # drop_bad_wav returns the row index for bad files and False otherwise. Use an identity
    # test: `x != False` discarded a bad row 0, because 0 == False.
    bad_indices = [x for x in results if x is not False]
    print(bad_indices)
    with open(bad_wavs_path, 'w') as f:
        json.dump([item_list[i] for i in bad_indices], f)
    df = df.drop(bad_indices, axis=0)
    df.to_csv(tsv_path, sep='\t', index=False)
if __name__ == '__main__':
    from argparse import ArgumentParser

    # Defaults reproduce the previously hard-coded values, so running with no
    # arguments behaves exactly as before.
    parser = ArgumentParser(description='Drop bad wavs from TSV manifests, then extract mel spectrograms.')
    parser.add_argument('--tsv_path', default='./musiccap.tsv',
                        help='a .tsv manifest, or a directory of .tsv manifests')
    parser.add_argument('--num_gpus', type=int, default=1)
    cli = parser.parse_args()

    logging.basicConfig(filename='example.log', level=logging.INFO,
                        format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
    tsv_path = cli.tsv_path
    num_gpus = cli.num_gpus
    # Clean the manifest(s) in place first so extraction never opens a corrupted file.
    if os.path.isdir(tsv_path):
        for file in glob(os.path.join(tsv_path, '*.tsv')):
            drop_bad_wavs(file)
    else:
        drop_bad_wavs(tsv_path)
    args = Namespace(
        audio_sample_rate=16000,
        audio_num_mel_bins=80,
        fft_size=1024,  # suggested: 4000 Hz -> 512, 16000 Hz -> 1024
        win_size=1024,
        hop_size=256,
        fmin=0,
        fmax=8000,
        # Author's reference values: 4000 Hz: 312 (nfft=512, hop=128, mel len 313), 16000: 624, 22050: 848.
        # NOTE(review): 1560 frames * 256 hop / 16000 Hz ~= 25 s, not the 624 listed for 16 kHz -- confirm.
        batch_max_length=1560,
        tsv_path=tsv_path,
        num_gpus=num_gpus,
        mode='none',
        save_resample=False,
        save_mel=True,
    )
    args.dist_config = {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54189",
        "world_size": 1,  # multiplied by num_gpus inside process_audio_by_tsv
    }
    if args.num_gpus > 1:
        mp.spawn(process_audio_by_tsv, nprocs=args.num_gpus, args=(args,))
    else:
        process_audio_by_tsv(0, args=args)
    print("done")