import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

from fuson_plm.utils.visualizing import set_font

# Masking-rate tick values shared by the scheduler plots and the rate-range strips.
RATE_TICKS = [0.15, 0.20, 0.25, 0.30, 0.35, 0.40]


def compute_cosine_masking_rate(progress, min_rate, max_rate):
    """Return a masking rate that rises from min_rate to max_rate along a half cosine.

    Args:
        progress: Fraction of the schedule completed; 0 gives min_rate, 1 gives max_rate.
        min_rate: Masking rate at progress 0.
        max_rate: Masking rate at progress 1.

    Returns:
        min_rate + (max_rate - min_rate) * 0.5 * (1 - cos(pi * progress)).
    """
    cosine_increase = 0.5 * (1 - np.cos(np.pi * progress))
    return min_rate + (max_rate - min_rate) * cosine_increase


def compute_log_linear_masking_rate(progress, min_rate, max_rate):
    """Return a masking rate that rises from min_rate to max_rate along log(1 + progress).

    Args:
        progress: Fraction of the schedule completed (scalar).
        min_rate: Masking rate at progress 0.
        max_rate: Masking rate at progress 1.

    Returns:
        min_rate + (max_rate - min_rate) * log1p(progress) / log1p(1).
    """
    # log1p(0) is already defined; this clamp only keeps negative progress from
    # pushing the rate below min_rate (and preserves the original outputs).
    progress = max(progress, 1e-10)
    # Dividing by log1p(1) = ln(2) maps progress in [0, 1] onto [0, 1].
    log_linear_increase = np.log1p(progress) / np.log1p(1)
    return min_rate + (max_rate - min_rate) * log_linear_increase


def compute_stepwise_masking_rate(progress, min_rate, max_rate, total_batches, num_steps):
    """Return a masking rate that increases in ``num_steps`` equal, evenly spaced steps.

    The batches are split into intervals of ``total_batches // num_steps`` batches; each
    interval uses one rate, starting at min_rate and reaching max_rate on the last step.
    Batches left over by the integer division stay at max_rate.

    Args:
        progress: Fraction of the schedule completed (batch index / total_batches).
        min_rate: Rate of the first step.
        max_rate: Rate of the final step.
        total_batches: Number of batches the schedule spans.
        num_steps: Number of distinct rate levels.

    Returns:
        The masking rate of the step containing ``progress``. With ``num_steps <= 1``
        there is only one level, and min_rate is returned.
    """
    if num_steps <= 1:
        # One step cannot span min_rate..max_rate (and num_steps - 1 would be a zero
        # divisor), so stay on the first level.
        return min_rate
    # Guard against a zero-width interval when total_batches < num_steps.
    batch_interval = max(total_batches // num_steps, 1)
    # num_steps - 1 increments so the last step lands exactly on max_rate.
    rate_increment = (max_rate - min_rate) / (num_steps - 1)

    # progress * total_batches recovers the batch index, but the float round trip
    # (batch / total) * total can land just below an integer (e.g. 1/49*49 < 1),
    # which would drop a boundary batch into the previous step; the epsilon absorbs it.
    current_step = int(progress * total_batches / batch_interval + 1e-9)
    # Clamp to [0, num_steps - 1]: the max rate is included, and remainder batches stay on it.
    current_step = min(max(current_step, 0), num_steps - 1)

    return min_rate + current_step * rate_increment


def n_epochs_scheduler_plot(min_rate, max_rate, total_batches, num_epochs=5, scheduler="cosine", num_steps=20):
    """Plot a masking-rate schedule repeated over several epochs and save it as a PNG.

    The per-epoch schedule is evaluated at every batch index and tiled ``num_epochs``
    times (it restarts each epoch). The figure is written to
    ``{scheduler}_{num_epochs}_epochs.png`` in the working directory, then
    ``plt.show()`` is called and the figure is closed.

    Args:
        min_rate: Lowest masking rate of the schedule.
        max_rate: Highest masking rate of the schedule.
        total_batches: Number of batches in one epoch.
        num_epochs: Number of epochs to draw.
        scheduler: One of "cosine", "log_linear" or "stepwise".
        num_steps: Number of rate levels; only used by the "stepwise" scheduler.

    Raises:
        ValueError: If ``scheduler`` is not one of the supported names.
    """
    rate_fns = {
        "cosine": lambda p: compute_cosine_masking_rate(p, min_rate, max_rate),
        "log_linear": lambda p: compute_log_linear_masking_rate(p, min_rate, max_rate),
        "stepwise": lambda p: compute_stepwise_masking_rate(
            p, min_rate, max_rate, total_batches, num_steps
        ),
    }
    if scheduler not in rate_fns:
        # Previously this returned silently, so a typo produced no plot and no error.
        raise ValueError(f"Unknown scheduler {scheduler!r}; expected one of {sorted(rate_fns)}")

    set_font()

    rate_fn = rate_fns[scheduler]
    masking_rates = [rate_fn(batch / total_batches) for batch in range(total_batches)]

    # The schedule restarts every epoch, so simply repeat the per-epoch rates.
    epoch_masking_rates = masking_rates * num_epochs
    extended_batch_numbers = np.arange(len(epoch_masking_rates))

    fig = plt.figure(figsize=(10, 4))
    plt.plot(extended_batch_numbers, epoch_masking_rates, color='black', linewidth=3)

    plt.yticks(RATE_TICKS, labels=[f"{tick:.2f}" for tick in RATE_TICKS], fontsize=30)

    # One x tick at the last batch of each epoch; the final epoch is labelled "N".
    wave_positions = [total_batches * (i + 1) - 1 for i in range(num_epochs)]
    wave_labels = [str(i + 1) if i < num_epochs - 1 else "N" for i in range(num_epochs)]
    plt.xticks(wave_positions, labels=wave_labels, fontsize=30)

    # Put "..." between the second epoch tick and the final "N" tick. y=0.12 is in data
    # coordinates, just below the lowest y tick -- presumably to sit level with the
    # x tick labels; re-check if the plotted rate range changes.
    if num_epochs > 2:
        mid_x = (wave_positions[1] + wave_positions[-1]) / 2
        plt.text(mid_x, 0.12, "...", ha="center", fontsize=30)

    # Bare figure: no axis labels, no title.
    plt.gca().set_xlabel('')
    plt.gca().set_ylabel('')
    plt.title('')
    plt.tight_layout()

    # Save BEFORE show(): show() may close the figure, after which savefig writes a blank image.
    plt.savefig(f"{scheduler}_{num_epochs}_epochs.png", dpi=300)
    plt.show()
    plt.close(fig)


def plot_masking_rate_range(min_rate, max_rate=None):
    """Draw a fixed masking rate or a masking-rate range on a 0.15-0.40 strip and save it.

    With ``max_rate`` None a black vertical line is drawn at ``min_rate`` and the figure is
    saved as ``mask_rate_{min_rate}.png``; otherwise the band between the two rates is
    shaded and the figure is saved as ``mask_rate_{min_rate}_{max_rate}.png``. The file is
    written to the working directory, then ``plt.show()`` is called and the figure closed.

    Args:
        min_rate: The fixed rate, or the lower bound of the range.
        max_rate: Upper bound of the range, or None for a single rate.
    """
    set_font()
    fig = plt.figure(figsize=(5, 1))  # short figure so the line/band dominates

    if max_rate is None:
        # Single rate: solid black vertical line.
        plt.axvline(x=min_rate, color='black', linestyle='-', linewidth=4, label=f"Rate = {min_rate}")
        output_path = f"mask_rate_{min_rate}.png"
    else:
        # Rate range: semi-transparent black band from min_rate to max_rate.
        plt.fill_betweenx([0, 1], min_rate, max_rate, color='black', alpha=0.5)
        output_path = f"mask_rate_{min_rate}_{max_rate}.png"

    # x-axis starts slightly left of 0.15, presumably so a line at 0.15 is not on the edge.
    plt.xlim(0.145, 0.40)
    plt.xticks(RATE_TICKS, fontsize=20)
    plt.tick_params(axis='y', which='both', left=False, labelleft=False)

    # Keep only a thin bottom spine; hide the y-axis entirely.
    ax = plt.gca()
    for side in ("top", "right", "left"):
        ax.spines[side].set_visible(False)
    ax.spines['bottom'].set_linewidth(0.5)
    ax.yaxis.set_visible(False)

    plt.xlabel("")
    plt.title("")
    plt.tight_layout()

    # Save BEFORE show(): show() may close the figure, after which savefig writes a blank image.
    plt.savefig(output_path, dpi=300)
    plt.show()
    plt.close(fig)


def main():
    """Render the three 3-epoch scheduler plots and the masking-rate strip plots."""
    min_rate = 0.15
    max_rate = 0.40
    num_steps = 20
    total_batches = 4215
    num_epochs = 3

    # 3-epoch cosine, log-linear and stepwise schedules.
    for scheduler in ("cosine", "log_linear", "stepwise"):
        n_epochs_scheduler_plot(
            min_rate, max_rate, total_batches,
            num_epochs=num_epochs, scheduler=scheduler, num_steps=num_steps,
        )

    # Fixed-rate strips.
    for rate in (0.15, 0.20, 0.25):
        plot_masking_rate_range(rate)

    # Range strips, all starting at 0.15.
    for upper in (0.20, 0.25, 0.30, 0.35, 0.40):
        plot_masking_rate_range(min_rate, upper)


if __name__ == "__main__":
    main()