"""Test the speed of pitch-shift augmentations.

Runs three pitch-shift implementations once on the same batch of audio:

* torch-audiomentations ``PitchShift``
* torchaudio ``transforms.PitchShift``
* YourMT3 ``PitchShiftLayer``

It then plots process-time and GPU-memory figures for the three methods.
Those figures are hard-coded in the section headers and plot data below;
this script does not time or profile anything itself. (The "± ... per loop"
wording looks like IPython ``%timeit`` output -- NOTE(review): confirm how
they were collected before re-using them.)
"""
import torch
import torchaudio

# Device: the CPU figures below are presumably from re-running with the CPU
# device selected instead -- NOTE(review): confirm.
device = torch.device("cuda")
# device = torch.device("cpu")

SAMPLE_RATE = 16000  # Hz
N_SAMPLES = 32767  # samples per example (~2.05 s at SAMPLE_RATE)
BATCH_SIZE = 80  # examples per batch ("bsz=80" in the plots)

# Music (alternative input).
# NOTE(review): this path yields shape (80, 32767), while the sine path below
# yields (80, 1, 32767); add a channel dim if the layers below need it.
# x, _ = torchaudio.load("music.wav")
# slice_length = 32767
# n_slices = 80
# slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)]
# x = torch.stack(slices)  # (80, 32767)

# Sine wave: 440 Hz at half amplitude.
# An integer arange gives exactly N_SAMPLES samples. The previous
# float-step arange(0, 2.0479, 1 / 16000) only reached that length through
# rounding, and the reshape below fails if it drifts by one.
t = torch.arange(N_SAMPLES) / SAMPLE_RATE
x = torch.sin(2 * torch.pi * 440 * t) * 0.5
x = x.reshape(1, 1, N_SAMPLES).tile(BATCH_SIZE, 1, 1)  # (80, 1, 32767)
x = x.to(device)

############################################################################################
# torch-audiomentations: https://github.com/asteroid-team/torch-audiomentations
#
# process time (CPU): 1.18 s ± 5.35 ms
# process time (GPU): 58 ms
# GPU memory usage  : 3.8 GB per 1 semitone
############################################################################################
import torch
from torch_audiomentations import Compose, PitchShift, Gain, PolarityInversion

apply_augmentation = Compose(transforms=[
    # Gain(
    #     min_gain_in_db=-15.0,
    #     max_gain_in_db=5.0,
    #     p=0.5,
    # ),
    # PolarityInversion(p=0.5),
    # NOTE: transposes by a random amount in [0, 2.2] semitones, unlike the
    # fixed 2-semitone shift used by the other two methods.
    PitchShift(
        min_transpose_semitones=0,
        max_transpose_semitones=2.2,
        mode="per_batch",  # alternative: "per_example"
        p=1.0,
        p_mode="per_batch",
        sample_rate=SAMPLE_RATE,
        target_rate=SAMPLE_RATE,
    ),
])
x_am = apply_augmentation(x, sample_rate=SAMPLE_RATE)

############################################################################################
# torchaudio:
#
# process time (CPU): 4.01 s ± 19.6 ms per loop
# process time (GPU): 25.1 ms ± 161 µs per loop
# memory usage      : 1.2 (growth to 5.49) GB per 1 semitone
############################################################################################
from torchaudio import transforms

ta_transform = transforms.PitchShift(SAMPLE_RATE, n_steps=2).to(device)
x_ta = ta_transform(x)

############################################################################################
# YourMT3 pitch_shift_layer:
#
# process time (CPU): 389 ms ± 22 ms (stretch=143 ms, resampler=245 ms)
# process time (GPU): 7.18 ms ± 17.3 µs (stretch=6.47 ms, resampler=0.71 ms)
# memory usage      : 16 MB per 1 semitone (average)
############################################################################################
from model.pitchshift_layer import PitchShiftLayer

ps_ymt3 = PitchShiftLayer(pshift_range=[2, 2], fs=SAMPLE_RATE, min_gcd=16, n_fft=2048).to(device)
# Second argument is presumably the shift in semitones -- NOTE(review): confirm.
x_ymt3 = ps_ymt3(x, 2)

############################################################################################
# Plot 1: Comparison of Process Time and GPU Memory Usage for 3 Pitch Shifting Methods
############################################################################################
import matplotlib.pyplot as plt

BAR_COLORS = ['#FFB6C1', '#ADD8E6', '#98FB98']  # one colour per method

# Model names
models = ['torch-audiomentation', 'torchaudio', 'YourMT3:PitchShiftLayer']
# Process time (CPU) in seconds
cpu_time = [1.18, 4.01, 0.389]
# Process time (GPU) in milliseconds
gpu_time = [58, 25.1, 7.18]
# GPU memory usage in GB
gpu_memory = [3.8, 5.49, 0.016]

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# One entry per subplot: (values, y-axis label, title, bar-label format).
panels = (
    (cpu_time, 'Time (s)', 'Process Time (CPU) bsz=80', '{:.2f} s'),
    (gpu_time, 'Time (ms)', 'Process Time (GPU) bsz=80', '{:.2f} ms'),
    (gpu_memory, 'Memory (GB)', 'GPU Memory Usage per semitone', '{:.3f} GB'),
)
for ax, (values, ylabel, title, label_fmt) in zip(axs, panels):
    bars = ax.bar(models, values, color=BAR_COLORS)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(axis='y')
    # Log scale keeps the much smaller YourMT3 bars visible.
    ax.set_yscale('log')
    # Pin tick positions before relabelling. Calling set_xticklabels on the
    # automatic categorical locator emits a matplotlib UserWarning.
    ax.set_xticks(range(len(models)))
    ax.set_xticklabels(models, rotation=45, ha="right")
    # Value label centred just above each bar.
    for bar, value in zip(bars, values):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            label_fmt.format(value),
            ha='center',
            va='bottom',
        )

plt.tight_layout()
plt.show()

############################################################################################
# Plot 2: Stretch and Resampler Processing Time Contribution
############################################################################################
processing_type = ['Stretch (Phase Vocoder)', 'Resampler (Conv1D)']
cpu_times = [143, 245]  # [Stretch, Resampler] times for CPU in milliseconds
gpu_times = [6.47, 0.71]  # [Stretch, Resampler] times for GPU in milliseconds

fig, axs = plt.subplots(1, 2, figsize=(12, 6))

for ax, times, title in zip(
        axs,
        (cpu_times, gpu_times),
        ('Contribution of CPU Processing Time: YMT3-PS (BSZ=80)',
         'Contribution of GPU Processing Time: YMT3-PS (BSZ=80)'),
):
    ax.bar(processing_type, times, color=['#ADD8E6', '#98FB98'])
    # Both subplots are in ms; previously only the CPU subplot was labelled.
    ax.set_ylabel('Time (ms)')
    ax.set_title(title)
    ax.grid(axis='y')
    ax.set_yscale('log')  # Log scale to better visualize the smaller values
    # Bars sit at x = 0, 1, ... (categorical axis), so the index is the x position.
    for idx, elapsed_ms in enumerate(times):
        ax.text(idx, elapsed_ms, f"{elapsed_ms:.2f} ms", ha='center', va='bottom', fontsize=8)

plt.tight_layout()
plt.show()