# lmzjms's picture
# Upload 1162 files
# 0b32ad6 verified
# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
# FileName [ observe_input.py ]
# Synopsis [ generates visualizations of the model's input ]
# Author [ Ting-Wei Liu (Andi611) ]
# Copyright [ Copyleft(c), NTUEE, NTU, Taiwan ]
"""*********************************************************************************************"""
###############
# IMPORTATION #
###############
import os
import copy
import random
import numpy as np
#-------------#
import torch
import torchaudio
#-------------#
import matplotlib
import matplotlib.pyplot as plt
#-------------#
# Render off-screen: figures are only saved to disk, never shown interactively.
plt.switch_backend('agg')

# Seed every RNG this script draws from (python `random`, numpy, torch) so the
# random masks and noise are repeatable across runs.
seed = 679
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn

########
# PATH #
########
# Utterance to visualize, relative to the LibriSpeech root below.
utt = 'train-clean-100/1594/135914/1594-135914-0032.flac'
libri_dir = '/media/andi611/1TBSSD/LibriSpeech/'
# Directory that receives the generated PNG files.
out_dir = './result/visualization/'
def plot_x(x, name='x', xlabel='Frames', save_dir=None):
    """Render a 2-D feature matrix as a heat map and save it as ``<name>.png``.

    Args:
        x: 2-D numpy array or torch tensor. It is transposed before plotting,
            so dim 0 runs along the horizontal axis and dim 1 ("Channels")
            along the vertical axis. A tensor is detached and moved to CPU.
        name: file stem of the saved image.
        xlabel: caption placed under the horizontal axis.
        save_dir: output directory; defaults to the module-level ``out_dir``.
    """
    if save_dir is None:
        save_dir = out_dir
    if isinstance(x, torch.Tensor):
        # imshow needs host memory and cannot convert tensors that require grad.
        x = x.detach().cpu().numpy()
    x = x.transpose(1, 0)

    fig, ax = plt.subplots(figsize=(10, 3))
    try:
        im = ax.imshow(x, aspect='auto', origin='lower', interpolation='none')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Channels')
        fig.tight_layout()
        ax.margins(0, 0)
        # Remove tick marks/labels on both axes for a cleaner figure.
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        fig.savefig(os.path.join(save_dir, name + '.png'), bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure; otherwise pyplot keeps every figure alive.
        plt.close(fig)
def starts_to_intervals(starts, consecutive):
    """Expand span start indices into the flat sequence of indices they cover.

    Each start ``s`` becomes ``s, s + 1, ..., s + consecutive - 1``; the spans
    are concatenated in the order the starts appear.

    Args:
        starts: 1-D tensor of start indices.
        consecutive: length of every span.

    Returns:
        1-D tensor with ``len(starts) * consecutive`` entries.
    """
    span_offsets = torch.arange(consecutive)
    # (num_starts, 1) + (consecutive,) broadcasts to (num_starts, consecutive).
    grid = starts.unsqueeze(-1) + span_offsets
    return grid.reshape(-1)
########
# MAIN #
########
def main():
    """Generate the input-augmentation figures for one LibriSpeech utterance.

    Loads ``utt`` from ``libri_dir``, extracts features with the s3prl ``mel``
    hub entry, and writes six PNGs to ``out_dir``:
    A) the original features, B) time spans masked to zero, C) time spans
    replaced by other spans, D) a frequency band masked to zero, E) additive
    Gaussian noise, and F) B), D) and E) combined.
    """
    os.makedirs(out_dir, exist_ok=True)

    # A) plot original
    extracter = torch.hub.load('s3prl/s3prl', 'mel')
    wav, _ = torchaudio.load(os.path.join(libri_dir, utt))
    with torch.no_grad():
        feats = extracter([wav])[0]
    # Plain CPU float tensor; dim 0 is treated as time and dim 1 as channels below.
    x = feats.squeeze().detach().cpu().float()
    plot_x(x.numpy(), name='x', xlabel='A) Original Mel feature')

    # Accumulates the alterations of B), D) and E) for figure F).
    x_all = x.clone()

    # time masking
    mask_consecutive_min = 7
    mask_consecutive_max = 7
    mask_proportion = 0.15
    mask_allow_overlap = True
    # Grid spacing of span starts, as a multiple of the span length; only used
    # when spans must not overlap (it was previously undefined -> NameError).
    mask_bucket_ratio = 1.5
    mask_consecutive = random.randint(mask_consecutive_min, mask_consecutive_max)
    # Largest start index to draw; keeps each span inside the sequence when the
    # utterance is longer than one span.
    valid_start_max = max(x.size(0) - mask_consecutive - 1, 0)
    # Number of spans so that about `mask_proportion` of the frames are covered.
    proportion = round(x.size(0) * mask_proportion / mask_consecutive)
    if mask_allow_overlap:
        # Distinct starts drawn from [0, valid_start_max]; their spans may still overlap.
        chosen_starts = torch.randperm(valid_start_max + 1)[:proportion]
    else:
        # Starts on a grid spaced `mask_bucket_size` apart so spans do not overlap.
        mask_bucket_size = round(mask_consecutive * mask_bucket_ratio)
        rand_start = random.randint(0, min(mask_consecutive, valid_start_max))
        valid_starts = torch.arange(rand_start, valid_start_max + 1, mask_bucket_size)
        chosen_starts = valid_starts[torch.randperm(len(valid_starts))[:proportion]]
    chosen_intervals = starts_to_intervals(chosen_starts, mask_consecutive)

    # B) mask to zero
    x_time_zero = x.clone()
    x_time_zero[chosen_intervals, :] = 0
    x_all[chosen_intervals, :] = 0
    plot_x(x_time_zero.numpy(), name='x_time_zero', xlabel='B) Mask contiguous segments to zero along temporal axis')

    # C) replace with random frames. Draw exactly as many source spans as target
    # spans; slicing by `proportion` could yield a different count and make the
    # index tensors below mismatch in length.
    random_starts = torch.randperm(valid_start_max + 1)[:len(chosen_starts)]
    random_intervals = starts_to_intervals(random_starts, mask_consecutive)
    x_time_replace = x.clone()
    x_time_replace[chosen_intervals, :] = x_time_replace[random_intervals, :]
    plot_x(x_time_replace.numpy(), name='x_time_replace', xlabel='C) Replace contiguous segments with random segments')

    # D) frequency masking with a fixed band width (a random width drawn from
    # [0, mask_frequency] would also be possible).
    mask_frequency = 16
    rand_bandwidth = mask_frequency
    chosen_starts = torch.randperm(x.size(1) - rand_bandwidth)[:1]
    chosen_intervals = starts_to_intervals(chosen_starts, rand_bandwidth)
    x_freq = x.clone()
    x_freq[:, chosen_intervals] = 0
    x_all[:, chosen_intervals] = 0
    plot_x(x_freq.numpy(), name='x_freq', xlabel='D) Mask contiguous segments to zero along channel axis')

    # E) additive Gaussian noise (mean 0, std 0.2)
    noise_sampler = torch.distributions.Normal(0, 0.2)
    x_noise = x.clone()
    x_noise += noise_sampler.sample(x_noise.shape)
    # x_all receives its own, independent noise draw.
    x_all += noise_sampler.sample(x_all.shape)
    plot_x(x_noise.numpy(), name='x_noise', xlabel='E) Apply sampled Gaussian noise to magnitude')

    # F) time + freq + noise
    plot_x(x_all.numpy(), name='x_all', xlabel='F) Combining the alterations in B), D), and E)')


if __name__ == '__main__':
    main()