File size: 6,198 Bytes
0e3c3b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from fuson_plm.utils.visualizing import set_font
import matplotlib.pyplot as plt
import numpy as np
# Cosine Increase Masking Rate Scheduler implementation
def compute_cosine_masking_rate(progress, min_rate, max_rate):
    """Return the masking rate on a half-cosine ramp from min_rate to max_rate.

    Args:
        progress: Position along the ramp. The callers in this file pass
            ``batch / total_batches``, i.e. a value in [0, 1).
        min_rate: Rate returned when ``progress`` is 0.
        max_rate: Rate approached as ``progress`` reaches 1.

    Returns:
        ``min_rate + (max_rate - min_rate) * 0.5 * (1 - cos(pi * progress))``.
    """
    rate_span = max_rate - min_rate
    # Half-cosine easing: 0 at progress=0, 1 at progress=1, flat at both ends.
    ramp = (1 - np.cos(np.pi * progress)) * 0.5
    return min_rate + rate_span * ramp
def compute_log_linear_masking_rate(progress, min_rate, max_rate):
    """Return the masking rate on a logarithmic ramp from min_rate to max_rate.

    The ramp is ``log1p(progress) / log1p(1)``, which is 0 at ``progress == 0``
    and 1 at ``progress == 1``, rising faster early on than a linear ramp.

    Args:
        progress: Position along the ramp. The callers in this file pass
            ``batch / total_batches``. Negative values are clamped to 0.
            Scalars and array-likes are both accepted.
        min_rate: Rate returned when ``progress`` is 0.
        max_rate: Rate returned when ``progress`` is 1.

    Returns:
        The masking rate as a NumPy float (or an array for array-like input).
    """
    # log1p(0) is exactly 0, so no epsilon is needed to "avoid log(0)"; an
    # epsilon would make the rate at progress 0 slightly exceed min_rate.
    # Only negative progress is clamped, because log1p is undefined at <= -1.
    progress = np.maximum(progress, 0.0)
    # Dividing by log1p(1) normalizes the ramp to the range [0, 1].
    log_linear_increase = np.log1p(progress) / np.log1p(1)
    return min_rate + (max_rate - min_rate) * log_linear_increase
def compute_stepwise_masking_rate(progress, min_rate, max_rate, total_batches, num_steps):
    """Return the masking rate on a staircase from min_rate to max_rate.

    The epoch is divided into plateaus of ``total_batches // num_steps``
    batches each. The rate starts at ``min_rate`` and rises by an equal
    increment per plateau so that the last plateau sits at ``max_rate``.

    Args:
        progress: Position within the epoch. The callers in this file pass
            ``batch / total_batches``. Values outside [0, 1] are clamped to
            the first or last plateau.
        min_rate: Rate of the first plateau.
        max_rate: Rate of the last plateau.
        total_batches: Number of batches in one epoch; must be at least 1.
        num_steps: Number of plateaus; must be at least 2 so that both
            ``min_rate`` and ``max_rate`` are included.

    Returns:
        The masking rate of the plateau that ``progress`` falls on.

    Raises:
        ValueError: If ``num_steps < 2`` or ``total_batches < 1``.
    """
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2 to include both min_rate and max_rate")
    if total_batches < 1:
        raise ValueError("total_batches must be at least 1")

    # Floor the plateau length at one batch; otherwise total_batches < num_steps
    # gives an interval of 0 and a division by zero below.
    batch_interval = max(1, total_batches // num_steps)
    # num_steps - 1 increments so that the final plateau lands on max_rate.
    rate_increment = (max_rate - min_rate) / (num_steps - 1)

    current_step = int(progress * total_batches / batch_interval)
    # Leftover batches from the floor division would index past the last
    # plateau, and negative progress would index below the first; clamp both.
    current_step = max(0, min(current_step, num_steps - 1))

    return min_rate + current_step * rate_increment
def n_epochs_scheduler_plot(min_rate, max_rate, total_batches, num_epochs=5, scheduler="cosine", num_steps=20):
    """Plot one scheduler's masking rate repeated over several epochs and save it.

    The per-batch masking rate of a single epoch is computed with the chosen
    scheduler and then tiled ``num_epochs`` times, so the schedule restarts at
    ``min_rate`` at every epoch boundary. The figure is written to
    ``{scheduler}_{num_epochs}_epochs.png`` in the current directory and then
    shown.

    Args:
        min_rate: Masking rate at the start of each epoch.
        max_rate: Masking rate reached at the end of each epoch.
        total_batches: Number of batches per epoch.
        num_epochs: Number of epochs (waves) to draw.
        scheduler: One of ``"cosine"``, ``"log_linear"`` or ``"stepwise"``.
            Any other value emits a warning and returns without plotting.
        num_steps: Number of plateaus; only used by the ``"stepwise"``
            scheduler.

    Returns:
        None.

    Note:
        The y ticks are fixed at 0.15-0.40 regardless of ``min_rate`` and
        ``max_rate``.
    """
    set_font()

    # Each entry maps a scheduler name to a function of epoch progress only.
    rate_functions = {
        "cosine": lambda p: compute_cosine_masking_rate(p, min_rate, max_rate),
        "log_linear": lambda p: compute_log_linear_masking_rate(p, min_rate, max_rate),
        "stepwise": lambda p: compute_stepwise_masking_rate(
            p, min_rate, max_rate, total_batches, num_steps
        ),
    }
    rate_fn = rate_functions.get(scheduler)
    if rate_fn is None:
        import warnings

        # Still a no-op for unknown names, but no longer a silent one.
        warnings.warn(
            f"Unknown scheduler {scheduler!r}; expected one of {sorted(rate_functions)}. "
            "No plot was generated.",
            stacklevel=2,
        )
        return

    # Masking rate for every batch of a single epoch
    batch_numbers = np.arange(total_batches)
    masking_rates = [rate_fn(batch / total_batches) for batch in batch_numbers]

    # The schedule restarts every epoch, so the full curve is the single-epoch
    # curve repeated num_epochs times.
    epoch_masking_rates = masking_rates * num_epochs
    extended_batch_numbers = np.arange(len(epoch_masking_rates))

    fig = plt.figure(figsize=(10, 4))
    plt.plot(extended_batch_numbers, epoch_masking_rates, color='black', linewidth=3)

    plt.yticks(
        [0.15, 0.20, 0.25, 0.30, 0.35, 0.40],
        labels=["0.15", "0.20", "0.25", "0.30", "0.35", "0.40"],
        fontsize=30
    )

    # One x tick at the last batch of each epoch, labelled 1, 2, ..., with the
    # final epoch labelled "N".
    wave_positions = [total_batches * (i + 1) - 1 for i in range(num_epochs)]
    wave_labels = [str(i + 1) if i < num_epochs - 1 else "N" for i in range(num_epochs)]
    plt.xticks(
        wave_positions,
        labels=wave_labels,
        fontsize=30
    )

    # Place "..." midway between the second and the last tick
    if num_epochs > 2:
        mid_x = (wave_positions[1] + wave_positions[-1]) / 2
        plt.text(mid_x, 0.12, "...", ha="center", fontsize=30)

    # No axis labels or title
    axes = plt.gca()
    axes.set_xlabel('')
    axes.set_ylabel('')
    plt.title('')
    plt.tight_layout()

    # Save before showing: on interactive backends the figure is gone once the
    # window from show() is closed, and a later savefig would write a blank image.
    fig.savefig(f"{scheduler}_{num_epochs}_epochs.png", dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def plot_masking_rate_range(min_rate, max_rate=None):
    """Draw a short strip marking a masking rate or rate range, and save it.

    With only ``min_rate`` given, a single vertical black line is drawn at that
    rate and the file is named ``mask_rate_{min_rate}.png``. With ``max_rate``
    also given, the span between the two rates is shaded and the file is named
    ``mask_rate_{min_rate}_{max_rate}.png``. The x axis is fixed to
    0.145-0.40 with ticks every 0.05. The figure is saved in the current
    directory and then shown.

    Args:
        min_rate: The fixed rate, or the lower bound of the range.
        max_rate: Upper bound of the range, or None for a single fixed rate.

    Returns:
        None.
    """
    set_font()
    # A short, wide figure so the plot reads as a one-dimensional strip
    fig = plt.figure(figsize=(5, 1))

    if max_rate is None:
        # Single fixed rate: one vertical black line
        plt.axvline(x=min_rate, color='black', linestyle='-', linewidth=4, label=f"Rate = {min_rate}")
    else:
        # Rate range: shade the span from min_rate to max_rate
        plt.fill_betweenx([0, 1], min_rate, max_rate, color='black', alpha=0.5)

    # The lower x limit sits just below 0.15 so a line at 0.15 is not clipped
    plt.xlim(0.145, 0.40)
    plt.xticks([0.15, 0.20, 0.25, 0.30, 0.35, 0.40], fontsize=20)
    plt.tick_params(axis='y', which='both', left=False, labelleft=False)

    # Keep only a thin bottom spine; hide the other spines and the y axis
    axes = plt.gca()
    for side in ('top', 'right', 'left'):
        axes.spines[side].set_visible(False)
    axes.spines['bottom'].set_linewidth(0.5)
    axes.yaxis.set_visible(False)

    plt.xlabel("")
    plt.title("")
    plt.tight_layout()

    plot_title = f"mask_rate_{min_rate}.png"
    if max_rate is not None:
        plot_title = f"mask_rate_{min_rate}_{max_rate}.png"

    # Save before showing: on interactive backends the figure is gone once the
    # window from show() is closed, and a later savefig would write a blank image.
    fig.savefig(plot_title, dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def main():
    """Generate every scheduler plot and every masking-rate strip plot.

    Draws the 3-epoch curve for the cosine, log-linear and stepwise
    schedulers, then the strip plots for the fixed rates 0.15, 0.20 and 0.25
    and for the ranges from 0.15 up to 0.20, 0.25, 0.30, 0.35 and 0.40.
    """
    min_rate = 0.15
    max_rate = 0.40
    num_steps = 20
    total_batches = 4215  # presumably the number of training batches per epoch; confirm
    num_epochs = 3

    # One multi-epoch curve per scheduler. num_steps only affects "stepwise";
    # it is passed explicitly so the value set above is the one actually used.
    for scheduler in ("cosine", "log_linear", "stepwise"):
        n_epochs_scheduler_plot(
            min_rate,
            max_rate,
            total_batches,
            num_epochs=num_epochs,
            scheduler=scheduler,
            num_steps=num_steps,
        )

    # Strip plots for fixed masking rates
    for fixed_rate in (0.15, 0.20, 0.25):
        plot_masking_rate_range(fixed_rate)

    # Strip plots for ranges starting at 0.15
    for upper_rate in (0.20, 0.25, 0.30, 0.35, 0.40):
        plot_masking_rate_range(0.15, upper_rate)
# Run the plot generation only when executed as a script, not on import.
if __name__ == "__main__":
    main()