"""
Reference Repo: https://github.com/facebookresearch/AudioMAE
"""
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple
from . import models_vit
from . import models_mae
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import torchaudio
from huggingface_hub import hf_hub_download
# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
class Vanilla_AudioMAE(nn.Module):
"""Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)"""
def __init__(
self,
):
super().__init__()
model = models_mae.__dict__["mae_vit_base_patch16"](
in_chans=1, audio_exp=True, img_size=(1024, 128)
)
# checkpoint_path = 'pretrained.pth'
checkpoint_path = hf_hub_download(
repo_id="DennisHung/Pre-trained_AudioMAE_weights",
filename="pretrained.pth"
)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
msg = model.load_state_dict(checkpoint['model'], strict=False)
# Skip the missing keys of decoder modules (not required)
# print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
        # The encoder is used as a frozen feature extractor, so keep it in eval mode.
        self.model = model.eval()
def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
"""
x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
mask_ratio: 'masking ratio (percentage of removed patches).'
"""
with torch.no_grad():
# embed: [B, 513, 768] for mask_ratio=0.0
if no_mask:
if no_average:
# raise RuntimeError("This function is deprecated")
embed = self.model.forward_encoder_no_random_mask_no_average(
x
) # mask_ratio
else:
embed = self.model.forward_encoder_no_mask(x) # mask_ratio
            else:
                raise RuntimeError("This function is deprecated")
                # Unreachable after the raise above; kept commented for reference.
                # embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)
return embed
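# Minimal usage sketch for Vanilla_AudioMAE (illustrative; the random input stands in for a
# real mel fbank, and the output shape follows the comment in forward()):
#     mae = Vanilla_AudioMAE()
#     fbank = torch.randn(2, 1, 1024, 128)  # [Batch, 1, Time, Mel]
#     emb = mae(fbank, no_mask=True)
#     print(emb.shape)  # expected: torch.Size([2, 513, 768]) (CLS token + 512 patch tokens)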
# def roll_mag_aug(waveform):
# idx = np.random.randint(len(waveform))
# rolled_waveform = np.roll(waveform, idx)
# mag = np.random.beta(10, 10) + 0.5
# return torch.Tensor(rolled_waveform * mag)
def wav_to_fbank(filename, melbins, target_length, roll_mag_aug_flag=False):
waveform, sr = torchaudio.load(filename)
waveform = waveform - waveform.mean()
fbank = torchaudio.compliance.kaldi.fbank(
waveform,
htk_compat=True,
sample_frequency=sr,
use_energy=False,
window_type='hanning',
num_mel_bins=melbins,
dither=0.0,
frame_shift=10
)
n_frames = fbank.shape[0]
p = target_length - n_frames
# Cut and pad
if p > 0:
m = torch.nn.ZeroPad2d((0, 0, 0, p))
fbank = m(fbank)
elif p < 0:
fbank = fbank[0:target_length, :]
return fbank
# Example usage
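# A minimal sketch (the path "example.wav" is a placeholder; melbins/target_length match the
# values used elsewhere in this file):
#     fbank = wav_to_fbank("example.wav", melbins=128, target_length=1024)
#     print(fbank.shape)  # expected: torch.Size([1024, 128])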
import torch.nn.functional as F
class AudioMAEConditionCTPoolRand(nn.Module):
"""
audiomae = AudioMAEConditionCTPool2x2()
data = torch.randn((4, 1024, 128))
output = audiomae(data)
import ipdb;ipdb.set_trace()
exit(0)
"""
def __init__(
self,
time_pooling_factors=[1, 2, 4, 8],
freq_pooling_factors=[1, 2, 4, 8],
eval_time_pooling=8,
eval_freq_pooling=8,
mask_ratio=0.0,
regularization=False,
no_audiomae_mask=True,
no_audiomae_average=True,
):
super().__init__()
self.device = None
self.time_pooling_factors = time_pooling_factors
self.freq_pooling_factors = freq_pooling_factors
self.no_audiomae_mask = no_audiomae_mask
self.no_audiomae_average = no_audiomae_average
self.eval_freq_pooling = eval_freq_pooling
self.eval_time_pooling = eval_time_pooling
self.mask_ratio = mask_ratio
self.use_reg = regularization
self.audiomae = Vanilla_AudioMAE()
self.audiomae.eval()
for p in self.audiomae.parameters():
p.requires_grad = False
# Required
def get_unconditional_condition(self, batchsize):
param = next(self.audiomae.parameters())
assert param.requires_grad == False
device = param.device
# time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
self.eval_freq_pooling, 8
)
# time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
# freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
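        # 512 = 64 (time) x 8 (freq) patch tokens from the AudioMAE encoder (CLS token excluded),
        # so pooling by (time_pool, freq_pool) leaves 512 / (time_pool * freq_pool) tokens.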
token_num = int(512 / (time_pool * freq_pool))
return [
torch.zeros((batchsize, token_num, 768)).to(device).float(),
torch.ones((batchsize, token_num)).to(device).float(),
]
def pool(self, representation, time_pool=None, freq_pool=None):
assert representation.size(-1) == 768
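        # Drop the CLS token and move the embedding dim to the channel axis: [bs, 768, 512]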
representation = representation[:, 1:, :].transpose(1, 2)
# print("representation.shape",representation.shape)
bs, embedding_dim, token_num = representation.size()
representation = representation.reshape(bs, embedding_dim, 64, 8)
        # Fall back to the evaluation pooling factors when none are given explicitly.
        # (During training, time_pool could instead be sampled from self.time_pooling_factors,
        # with freq_pool = min(8, time_pool).)
        if time_pool is None and freq_pool is None:
            time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
                self.eval_freq_pooling, 8
            )
self.avgpooling = nn.AvgPool2d(
kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
)
self.maxpooling = nn.MaxPool2d(
kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
)
pooled = (
self.avgpooling(representation) + self.maxpooling(representation)
) / 2 # [bs, embedding_dim, time_token_num, freq_token_num]
# print("pooled.shape",pooled.shape)
pooled = pooled.flatten(2).transpose(1, 2)
return pooled # [bs, token_num, embedding_dim]
def regularization(self, x):
assert x.size(-1) == 768
x = F.normalize(x, p=2, dim=-1)
return x
# Required
def forward(self, batch, time_pool=None, freq_pool=None):
assert batch.size(-2) == 1024 and batch.size(-1) == 128
if self.device is None:
self.device = next(self.audiomae.parameters()).device
batch = batch.unsqueeze(1).to(self.device)
with torch.no_grad():
representation = self.audiomae(
batch,
mask_ratio=self.mask_ratio,
no_mask=self.no_audiomae_mask,
no_average=self.no_audiomae_average,
)
representation = self.pool(representation, time_pool, freq_pool)
if self.use_reg:
representation = self.regularization(representation)
return [
representation,
            torch.ones((representation.size(0), representation.size(1)))
            .to(representation.device)
            .float(),
]
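# Shape sketch for AudioMAEConditionCTPoolRand (illustrative; assumes the default evaluation
# pooling of 8 x 8, so the 64 x 8 token grid is reduced to 8 tokens):
#     model = AudioMAEConditionCTPoolRand()
#     data = torch.randn(4, 1024, 128)
#     cond, mask = model(data)
#     print(cond.shape, mask.shape)  # expected: torch.Size([4, 8, 768]) torch.Size([4, 8])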
class AudioMAEConditionCTPoolRandTFSeparated(nn.Module):
"""
audiomae = AudioMAEConditionCTPool2x2()
data = torch.randn((4, 1024, 128))
output = audiomae(data)
import ipdb;ipdb.set_trace()
exit(0)
"""
def __init__(
self,
time_pooling_factors=[8],
freq_pooling_factors=[8],
eval_time_pooling=8,
eval_freq_pooling=8,
mask_ratio=0.0,
regularization=False,
no_audiomae_mask=True,
no_audiomae_average=False,
):
super().__init__()
self.device = None
self.time_pooling_factors = time_pooling_factors
self.freq_pooling_factors = freq_pooling_factors
self.no_audiomae_mask = no_audiomae_mask
self.no_audiomae_average = no_audiomae_average
self.eval_freq_pooling = eval_freq_pooling
self.eval_time_pooling = eval_time_pooling
self.mask_ratio = mask_ratio
self.use_reg = regularization
self.audiomae = Vanilla_AudioMAE()
self.audiomae.eval()
for p in self.audiomae.parameters():
p.requires_grad = False
# Required
def get_unconditional_condition(self, batchsize):
param = next(self.audiomae.parameters())
assert param.requires_grad == False
device = param.device
# time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
self.eval_freq_pooling, 8
)
# time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
# freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
token_num = int(512 / (time_pool * freq_pool))
return [
torch.zeros((batchsize, token_num, 768)).to(device).float(),
torch.ones((batchsize, token_num)).to(device).float(),
]
def pool(self, representation, time_pool=None, freq_pool=None):
assert representation.size(-1) == 768
representation = representation[:, 1:, :].transpose(1, 2)
bs, embedding_dim, token_num = representation.size()
representation = representation.reshape(bs, embedding_dim, 64, 8)
        # Fall back to the evaluation pooling factors when none are given explicitly.
        # (During training, time_pool and freq_pool could instead be sampled independently
        # from self.time_pooling_factors and self.freq_pooling_factors.)
        if time_pool is None and freq_pool is None:
            time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
                self.eval_freq_pooling, 8
            )
self.avgpooling = nn.AvgPool2d(
kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
)
self.maxpooling = nn.MaxPool2d(
kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
)
pooled = (
self.avgpooling(representation) + self.maxpooling(representation)
) / 2 # [bs, embedding_dim, time_token_num, freq_token_num]
pooled = pooled.flatten(2).transpose(1, 2)
return pooled # [bs, token_num, embedding_dim]
def regularization(self, x):
assert x.size(-1) == 768
x = F.normalize(x, p=2, dim=-1)
return x
# Required
def forward(self, batch, time_pool=None, freq_pool=None):
assert batch.size(-2) == 1024 and batch.size(-1) == 128
if self.device is None:
self.device = batch.device
batch = batch.unsqueeze(1)
with torch.no_grad():
representation = self.audiomae(
batch,
mask_ratio=self.mask_ratio,
no_mask=self.no_audiomae_mask,
no_average=self.no_audiomae_average,
)
representation = self.pool(representation, time_pool, freq_pool)
if self.use_reg:
representation = self.regularization(representation)
return [
representation,
torch.ones((representation.size(0), representation.size(1)))
.to(representation.device)
.float(),
]
def apply_time_mask(spectrogram, mask_width_range=(1000, 1001), max_masks=2):
"""
Apply time masking to a spectrogram (PyTorch tensor).
:param spectrogram: A PyTorch tensor of shape (time_steps, frequency_bands)
:param mask_width_range: A tuple indicating the min and max width of the mask
:param max_masks: Maximum number of masks to apply
:return: Masked spectrogram
"""
time_steps, frequency_bands = spectrogram.shape
masked_spectrogram = spectrogram.clone()
    for _ in range(max_masks):
        mask_width = torch.randint(mask_width_range[0], mask_width_range[1], (1,)).item()
        start_step = torch.randint(0, time_steps - mask_width, (1,)).item()
        # Zero out a random time span of width mask_width.
        masked_spectrogram[start_step:start_step + mask_width, :] = 0
    return masked_spectrogram
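# Example (illustrative; with the default mask_width_range a ~1000-frame span is zeroed):
#     spec = torch.randn(1024, 128)
#     masked = apply_time_mask(spec)
#     print((masked == 0).any())  # some time frames are now zeroed out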
def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec=torch.zeros((1024, 128)), num_mels=128):
norm_mean = -4.2677393
norm_std = 4.5689974
if sampling_rate != 16000:
waveform_16k = torchaudio.functional.resample(
waveform, orig_freq=sampling_rate, new_freq=16000
)
else:
waveform_16k = waveform
waveform_16k = waveform_16k - waveform_16k.mean()
fbank = torchaudio.compliance.kaldi.fbank(
waveform_16k,
htk_compat=True,
sample_frequency=16000,
use_energy=False,
window_type="hanning",
num_mel_bins=num_mels,
dither=0.0,
frame_shift=10,
)
TARGET_LEN = log_mel_spec.size(0)
# cut and pad
n_frames = fbank.shape[0]
p = TARGET_LEN - n_frames
# print(TARGET_LEN)
# print(n_frames)
if p > 0:
m = torch.nn.ZeroPad2d((0, 0, 0, p))
fbank = m(fbank)
elif p < 0:
fbank = fbank[:TARGET_LEN, :]
fbank = (fbank - norm_mean) / (norm_std * 2)
# fbank = apply_time_mask(fbank)
return fbank
if __name__ == "__main__":
filename = '/home/fundwotsai/DreamSound/training_audio_v2/output_slice_18.wav'
waveform, sr = torchaudio.load(filename)
fbank = torch.zeros(
(1024, 128)
)
    # Pass the actual sample rate so the feature extractor resamples to 16 kHz when needed.
    ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
    print(ta_kaldi_fbank.shape)
# melbins = 128 # Number of Mel bins
# target_length = 1024 # Number of frames
# fbank = wav_to_fbank(file_path, melbins, target_length, roll_mag_aug_flag=False)
# print(fbank.shape)
# # Convert to PyTorch tensor and reshape
    mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)  # [Batch, Time, Frequency]; the channel dim is added inside the model
mel_spect_tensor = mel_spect_tensor.to("cuda")
    print("mel_spect_tensor.shape", mel_spect_tensor.shape)
model = AudioMAEConditionCTPoolRand().cuda()
print("The first run")
embed = model(mel_spect_tensor, time_pool=1, freq_pool=1)
print(embed[0].shape)
# Reshape tensor for 2D pooling: treat each 768 as a channel
# Example usage
# Assuming the pooling operation reduces the second dimension from 513 to 8
torch.save(embed[0], "MAE_feature1_stride-no-pool.pt")
with open('output_tensor.txt', 'w') as f:
print(embed[0], file=f)
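    # Optional sanity check (illustrative): with the default 8 x 8 evaluation pooling, the
    # unconditional (null) condition has 8 tokens of dimension 768.
    uncond_embed, uncond_mask = model.get_unconditional_condition(batchsize=1)
    print(uncond_embed.shape, uncond_mask.shape)  # torch.Size([1, 8, 768]) torch.Size([1, 8])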