import os
import json
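

# Round-trip sanity check: encode one dataset mel spectrogram with the
# AudioLDM2 VAE, decode it back, and run the SpeechT5 HiFi-GAN vocoder
# to recover a waveform.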
def test_reconstruct():
    import yaml
    from diffusers import AutoencoderKL
    from transformers import SpeechT5HifiGan
    from audioldm2.utilities.data.dataset import AudioDataset
    from utils import load_clap, load_t5

    model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'
    with open('config/16k_64.yaml', 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    print(config)

    t5 = load_t5('cuda', max_length=256)
    clap = load_clap('cuda', max_length=256)
    dataset = AudioDataset(
        config=config,
        split="train",
        waveform_only=False,
        dataset_json_path='mini_dataset.json',
        tokenizer=clap.tokenizer,
        uncond_pro=0.1,
        text_ctx_len=77,
        tokenizer_t5=t5.tokenizer,
        text_ctx_len_t5=256,
        uncond_pro_t5=0.1,
    )
    # Add batch and channel dims: [1, 1, T, n_mels].
    mel = dataset[0]['log_mel_spec'].unsqueeze(0).unsqueeze(0)
    print(mel.size())

    vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae'))
    vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder'))

    # Encode to the scaled latent space, then undo the scaling and decode.
    latents = vae.encode(mel).latent_dist.sample().mul_(vae.config.scaling_factor)
    print('latent size:', latents.size())
    latents = 1 / vae.config.scaling_factor * latents
    mel_spectrogram = vae.decode(latents).sample
    print(mel_spectrogram.size())

    # The vocoder expects [batch, T, n_mels]; drop the channel dim if present.
    if mel_spectrogram.dim() == 4:
        mel_spectrogram = mel_spectrogram.squeeze(1)
    waveform = vocoder(mel_spectrogram)
    waveform = waveform[0].cpu().float().detach().numpy()
    print(waveform.shape)
    # import soundfile as sf
    # sf.write('reconstruct.wav', waveform, samplerate=16000)
    # from scipy.io import wavfile
    # wavfile.write('reconstruct.wav', 16000, waveform)
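

# Write a tiny debug manifest that repeats one local clip `num` times,
# in the {'wav': path, 'label': caption} format AudioDataset expects.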
def mini_dataset(num=32):
    data = [{'wav': 'case.mp3', 'label': 'a beautiful music'} for _ in range(num)]
    with open('mini_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
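

# Build a training manifest for FMA-large: join the pseudo-caption parquet
# annotations with the on-disk mp3 files, keeping only paths that exist.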
def fma_dataset():
    import pandas as pd

    annotation_prefix = "/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/annotation"
    annotation_list = ['test-00000-of-00001.parquet', 'train-00000-of-00001.parquet', 'valid-00000-of-00001.parquet']
    dataset_prefix = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/fma_large'
    data = []
    for annotation_file in annotation_list:
        annotation_file = os.path.join(annotation_prefix, annotation_file)
        df = pd.read_parquet(annotation_file)
        print(df.shape)
        for _, row in df.iterrows():
            tmp_path = os.path.join(dataset_prefix, row['path'] + '.mp3')
            # Skip entries whose audio file is missing on disk.
            if os.path.exists(tmp_path):
                data.append(
                    {
                        'wav': tmp_path,
                        'label': row['pseudo_caption'],
                    }
                )
    print(len(data))
    with open('fma_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
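

# Same manifest-building step for AudioSet (balanced train split, flac files).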
def audioset_dataset():
    import pandas as pd

    dataset_prefix = '/maindata/data/shared/public/zhengcong.fei/dataset/dataset_music/audioset'
    annotation_path = os.path.join(dataset_prefix, 'balanced_train-00000-of-00001.parquet')
    df = pd.read_parquet(annotation_path)
    print(df.shape)
    data = []
    for _, row in df.iterrows():
        try:
            tmp_path = os.path.join(dataset_prefix, row['path'] + '.flac')
        except TypeError:
            # row['path'] can be malformed (e.g. None); report and skip it
            # rather than reusing tmp_path from the previous iteration.
            print(row['path'])
            continue
        if os.path.exists(tmp_path):
            data.append(
                {
                    'wav': tmp_path,
                    'label': row['pseudo_caption'],
                }
            )
    print(len(data))
    with open('audioset_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
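

# Merge the per-dataset manifests into a single JSON file for training.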
def combine_dataset():
    data_list = ['fma_dataset.json', 'audioset_dataset.json']
    data = []
    for data_file in data_list:
        with open(data_file, 'r') as f:
            data += json.load(f)
    print(len(data))
    with open('combine_dataset.json', 'w') as f:
        json.dump(data, f, indent=4)
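

# Quick check that torchaudio can decode a sample file and report its sample rate.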
def test_music_format():
    import torchaudio

    filename = '2.flac'
    waveform, sr = torchaudio.load(filename)
    print(waveform, sr)
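

# Profile the model returned by build_model() with thop on dummy inputs.
# thop's profile() reports MACs, so the printed value is doubled to
# approximate FLOPs.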
def test_flops():
    import torch
    from constants import build_model
    from thop import profile

    version = 'giant'
    model = build_model(version).cuda()
    # Dummy inputs matching the model's expected shapes.
    x = torch.randn((1, 1024, 32)).cuda()       # latent sequence
    img_ids = torch.randn((1, 1024, 3)).cuda()  # positional ids for x
    txt = torch.randn((1, 256, 4096)).cuda()    # text embeddings
    txt_ids = torch.randn((1, 256, 3)).cuda()   # positional ids for txt
    y = torch.randn((1, 768)).cuda()            # pooled conditioning vector
    t = torch.tensor([1]).cuda()                # diffusion timestep
    flops, _ = profile(model, inputs=(x, img_ids, txt, txt_ids, t, y))
    print(f'FLOPs = {flops * 2 / 1000 ** 3}G')


if __name__ == '__main__':
    # test_music_format()
    # test_reconstruct()
    # mini_dataset()
    # fma_dataset()
    # audioset_dataset()
    # combine_dataset()
    test_flops()