File size: 5,843 Bytes
a03c9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
""" Test the speed of the augmentation """
import torch
import torchaudio

# Device
# Prefer the GPU, but fall back to the CPU so the script still runs on machines
# without CUDA. To benchmark the CPU on a GPU machine, override with:
#   device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Music (alternative input: 80 consecutive slices of a real recording)
# x, _ = torchaudio.load("music.wav")
# slice_length = 32767
# n_slices = 80
# slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)]
# x = torch.stack(slices)  # (80, 32767)

# Sine wave
sample_rate = 16000  # Hz
n_samples = 32767  # about 2.05 seconds at 16 kHz
batch_size = 80

# Build the time axis from an integer sample index so that the number of
# samples is exact by construction; the reshape below requires exactly
# `n_samples` elements and must not depend on a floating-point end-point.
t = torch.arange(n_samples, dtype=torch.float32) / sample_rate
x = torch.sin(2 * torch.pi * 440 * t) * 0.5  # 440 Hz tone, peak amplitude 0.5
x = x.reshape(1, 1, n_samples).tile(batch_size, 1, 1)  # (80, 1, 32767)
x = x.to(device)

############################################################################################
# torch-audiomentations: https://github.com/asteroid-team/torch-audiomentations
#
# process time <CPU>: 1.18 s ± 5.35 ms
# process time <GPU>: 58 ms
# GPU memory usage: 3.8 GB per 1 semitone
############################################################################################
import torch
from torch_audiomentations import Compose, PitchShift, Gain, PolarityInversion

# Only the pitch shift is benchmarked here. Gain and PolarityInversion stay
# imported so they can be added to the pipeline for experiments, e.g.
#   Gain(min_gain_in_db=-15.0, max_gain_in_db=5.0, p=0.5)
#   PolarityInversion(p=0.5)
pitch_shift_am = PitchShift(
    min_transpose_semitones=0,
    max_transpose_semitones=2.2,
    mode="per_batch",  # alternative: "per_example"
    p=1.0,  # always apply the transform
    p_mode="per_batch",
    sample_rate=16000,
    target_rate=16000,
)
apply_augmentation = Compose(transforms=[pitch_shift_am])

# Run the pipeline once on the benchmark batch.
x_am = apply_augmentation(x, sample_rate=16000)

############################################################################################
# torchaudio:
#
# process time <CPU>: 4.01 s ± 19.6 ms per loop
# process time <GPU>: 25.1 ms ± 161 µs per loop
# memory usage <GPU>: 1.2 (growth to 5.49) GB per 1 semitone
############################################################################################
from torchaudio import transforms

# Build the torchaudio pitch shifter for 16 kHz audio with n_steps=2,
# then move it to the benchmark device.
ta_transform = transforms.PitchShift(sample_rate=16000, n_steps=2)
ta_transform = ta_transform.to(device)

# Run the transform once on the benchmark batch.
x_ta = ta_transform(x)

############################################################################################
# YourMT3 pitch_shift_layer:
#
# process time <CPU>: 389ms ± 22ms, (stretch=143 ms, resampler=245 ms)
# process time <GPU>: 7.18 ms ± 17.3 µs (stretch=6.47 ms, resampler=0.71 ms)
# memory usage: 16 MB per 1 semitone (average)
############################################################################################
from model.pitchshift_layer import PitchShiftLayer

# Build the project's pitch-shift layer and move it to the benchmark device.
# NOTE(review): pshift_range=[2, 2] presumably restricts the layer to a single
# shift value of 2 semitones — confirm in model/pitchshift_layer.py.
ps_ymt3 = PitchShiftLayer(
    pshift_range=[2, 2],
    fs=16000,
    min_gcd=16,
    n_fft=2048,
)
ps_ymt3 = ps_ymt3.to(device)

# NOTE(review): the second argument is presumably the shift in semitones,
# matching pshift_range above — confirm against PitchShiftLayer.forward.
x_ymt3 = ps_ymt3(x, 2)

############################################################################################
# Plot 1: Comparison of Process Time and GPU Memory Usage for 3 Pitch Shifting Methods
############################################################################################
import matplotlib.pyplot as plt

# Model names
models = ['torch-audiomentation', 'torchaudio', 'YourMT3:PitchShiftLayer']

# Process time (CPU) in seconds
cpu_time = [1.18, 4.01, 0.389]

# Process time (GPU) in milliseconds
gpu_time = [58, 25.1, 7.18]

# GPU memory usage in GB
gpu_memory = [3.8, 5.49, 0.016]

# One colour per model, shared by all three panels
bar_colors = ['#FFB6C1', '#ADD8E6', '#98FB98']

# Creating subplots
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Creating bar charts
bar1 = axs[0].bar(models, cpu_time, color=bar_colors)
bar2 = axs[1].bar(models, gpu_time, color=bar_colors)
bar3 = axs[2].bar(models, gpu_memory, color=bar_colors)

# Adding labels and titles
axs[0].set_ylabel('Time (s)')
axs[0].set_title('Process Time (CPU) bsz=80')
axs[1].set_ylabel('Time (ms)')
axs[1].set_title('Process Time (GPU) bsz=80')
axs[2].set_ylabel('Memory (GB)')
axs[2].set_title('GPU Memory Usage per semitone')

# Adding grid for better readability of the plots
for ax in axs:
    ax.grid(axis='y')
    ax.set_yscale('log')
    # Restyle the tick labels the categorical axis already produced. Calling
    # set_xticklabels() here, without fixed ticks, would replace the axis
    # formatter and make Matplotlib emit a warning.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

# Adding text labels above the bars: (axis, bars, values, label format)
panels = [
    (axs[0], bar1, cpu_time, '{:.2f} s'),
    (axs[1], bar2, gpu_time, '{:.2f} ms'),
    (axs[2], bar3, gpu_memory, '{:.3f} GB'),
]
for ax, bars, values, label_fmt in panels:
    for rect, value in zip(bars, values):
        ax.text(
            rect.get_x() + rect.get_width() / 2,  # horizontal centre of the bar
            rect.get_height(),  # top of the bar
            label_fmt.format(value),
            ha='center',
            va='bottom')
plt.tight_layout()
plt.show()

############################################################################################
# Plot 2: Stretch and Resampler Processing Time Contribution
############################################################################################
# Data
processing_type = ['Stretch (Phase Vocoder)', 'Resampler (Conv1D)']
cpu_times = [143, 245]  # [Stretch, Resampler] times for CPU in milliseconds
gpu_times = [6.47, 0.71]  # [Stretch, Resampler] times for GPU in milliseconds

# Creating subplots
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Plotting bar charts
axs[0].bar(processing_type, cpu_times, color=['#ADD8E6', '#98FB98'])
axs[1].bar(processing_type, gpu_times, color=['#ADD8E6', '#98FB98'])

# Adding labels and titles
# The two panels do not share a y-axis, so each one carries its own unit label.
axs[0].set_ylabel('Time (ms)')
axs[0].set_title('Contribution of CPU Processing Time: YMT3-PS (BSZ=80)')
axs[1].set_ylabel('Time (ms)')
axs[1].set_title('Contribution of GPU Processing Time: YMT3-PS (BSZ=80)')

# Adding grid for better readability of the plots
for ax in axs:
    ax.grid(axis='y')
    ax.set_yscale('log')  # Log scale to better visualize the smaller values

# Adding values on top of the bars
# The bar index is used as the x position of each label (bars are drawn in list order).
for ax, times in zip(axs, [cpu_times, gpu_times]):
    for idx, elapsed_ms in enumerate(times):
        ax.text(idx, elapsed_ms, f"{elapsed_ms:.2f} ms", ha='center', va='bottom', fontsize=8)
plt.tight_layout()
plt.show()