Fill-Mask
Transformers
Safetensors
esm
File size: 6,198 Bytes
0e3c3b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from fuson_plm.utils.visualizing import set_font
import matplotlib.pyplot as plt
import numpy as np

# Cosine Increase Masking Rate Scheduler implementation
def compute_cosine_masking_rate(progress, min_rate, max_rate):
    """Return the masking rate at ``progress`` along a half-cosine ramp.

    Args:
        progress: Fraction of the schedule completed; 0 maps to ``min_rate``
            and 1 maps to ``max_rate``.
        min_rate: Masking rate at the start of the ramp.
        max_rate: Masking rate at the end of the ramp.

    Returns:
        ``min_rate`` plus the share of ``max_rate - min_rate`` given by
        ``(1 - cos(pi * progress)) / 2``.
    """
    rate_span = max_rate - min_rate
    # (1 - cos(pi * p)) / 2 rises smoothly from 0 at p=0 to 1 at p=1.
    ramp_fraction = (1 - np.cos(np.pi * progress)) * 0.5
    return min_rate + rate_span * ramp_fraction

def compute_log_linear_masking_rate(progress, min_rate, max_rate):
    """Return the masking rate at ``progress`` along a logarithmic ramp.

    Args:
        progress: Scalar fraction of the schedule completed. Values below
            ``1e-10`` (including negatives) are raised to ``1e-10``.
        min_rate: Masking rate at the start of the ramp.
        max_rate: Masking rate reached when ``progress`` is 1.

    Returns:
        ``min_rate`` plus the share of ``max_rate - min_rate`` given by
        ``log1p(progress) / log1p(1)``.
    """
    # Floor progress at a tiny positive value before taking the logarithm.
    floored_progress = max(progress, 1e-10)
    # Dividing by log1p(1) scales the curve so that progress == 1 maps to 1.
    ramp_fraction = np.log1p(floored_progress) / np.log1p(1)
    rate_span = max_rate - min_rate
    return min_rate + rate_span * ramp_fraction

def compute_stepwise_masking_rate(progress, min_rate, max_rate, total_batches, num_steps):
    """Return the masking rate at ``progress`` on a staircase schedule.

    The schedule is split into ``num_steps`` plateaus of equal batch length.
    The first plateau sits at ``min_rate`` and the last at ``max_rate``.

    Args:
        progress: Fraction of the schedule completed (0 at the first batch).
        min_rate: Masking rate of the first plateau.
        max_rate: Masking rate of the last plateau.
        total_batches: Number of batches the schedule spans.
        num_steps: Number of plateaus; must be at least 1.

    Returns:
        The masking rate of the plateau that ``progress`` falls into.

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")
    if num_steps == 1:
        # A single plateau has no increment to compute (it would divide by
        # zero); the schedule is constant at its starting rate.
        return min_rate

    # Length of one plateau in batches. Floor division yields 0 when there
    # are fewer batches than steps, so keep at least one batch per plateau.
    batch_interval = max(total_batches // num_steps, 1)
    # Divide by (num_steps - 1) so the final plateau lands on max_rate.
    rate_increment = (max_rate - min_rate) / (num_steps - 1)

    # Index of the plateau that the current batch belongs to.
    current_step = int(progress * total_batches / batch_interval)
    # Clamp to the valid plateau range: leftover batches at the end stay on
    # the last plateau, and out-of-range progress never leaves the
    # [min_rate, max_rate] band.
    current_step = max(0, min(current_step, num_steps - 1))

    return min_rate + current_step * rate_increment

def n_epochs_scheduler_plot(min_rate, max_rate, total_batches, num_epochs=5, scheduler="cosine", num_steps=20):
    """Plot a per-epoch masking-rate schedule repeated over several epochs.

    The schedule restarts at ``min_rate`` at the beginning of every epoch.
    The figure is written to ``{scheduler}_{num_epochs}_epochs.png`` (path
    relative to the current working directory) and then shown.

    Args:
        min_rate: Masking rate at the start of each epoch.
        max_rate: Masking rate the schedule rises towards within an epoch.
        total_batches: Number of batches in one epoch.
        num_epochs: Number of times the schedule is repeated on the x-axis.
        scheduler: One of ``"cosine"``, ``"log_linear"`` or ``"stepwise"``.
            Any other value makes the function return without plotting.
        num_steps: Number of plateaus; only used by the ``"stepwise"``
            scheduler.

    Returns:
        None.
    """
    # Project helper; presumably configures matplotlib fonts - see
    # fuson_plm.utils.visualizing.
    set_font()

    # Map each supported scheduler name to a function of in-epoch progress.
    rate_functions = {
        "cosine": lambda progress: compute_cosine_masking_rate(progress, min_rate, max_rate),
        "log_linear": lambda progress: compute_log_linear_masking_rate(progress, min_rate, max_rate),
        "stepwise": lambda progress: compute_stepwise_masking_rate(
            progress, min_rate, max_rate, total_batches, num_steps
        ),
    }
    rate_at = rate_functions.get(scheduler)
    if rate_at is None:
        # Unknown scheduler names are ignored, as before: nothing is plotted.
        return

    # Masking rate of every batch within a single epoch.
    batch_numbers = np.arange(total_batches)
    masking_rates = [rate_at(batch / total_batches) for batch in batch_numbers]

    # Every epoch replays the same schedule, so tile it num_epochs times.
    epoch_masking_rates = masking_rates * num_epochs
    extended_batch_numbers = np.arange(len(epoch_masking_rates))

    fig = plt.figure(figsize=(10, 4))
    plt.plot(extended_batch_numbers, epoch_masking_rates, color='black', linewidth=3)

    # Fixed y ticks covering the 0.15-0.40 masking-rate range.
    plt.yticks(
        [0.15, 0.20, 0.25, 0.30, 0.35, 0.40],
        labels=["0.15", "0.20", "0.25", "0.30", "0.35", "0.40"],
        fontsize=30
    )

    # One x tick at the last batch of each epoch; the final epoch is
    # labelled "N" to stand for an arbitrary number of epochs.
    wave_positions = [total_batches * (i + 1) - 1 for i in range(num_epochs)]
    wave_labels = [str(i + 1) if i < num_epochs - 1 else "N" for i in range(num_epochs)]
    plt.xticks(
        wave_positions,
        labels=wave_labels,
        fontsize=30
    )

    # Add "..." midway between the second and the last epoch tick.
    if num_epochs > 2:
        mid_x = (wave_positions[1] + wave_positions[-1]) / 2
        plt.text(mid_x, 0.12, "...", ha="center", fontsize=30)

    # The figure carries no axis labels and no title.
    axes = plt.gca()
    axes.set_xlabel('')
    axes.set_ylabel('')
    plt.title('')

    plt.tight_layout()
    # Save before showing: once a blocking show() returns, the figure is
    # closed, and a later savefig() would write a new, empty figure.
    plt.savefig(f"{scheduler}_{num_epochs}_epochs.png", dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    
def plot_masking_rate_range(min_rate, max_rate=None):
    """Draw a short strip marking one masking rate or a masking-rate range.

    With only ``min_rate`` given, a vertical line is drawn at that rate and
    the figure is saved as ``mask_rate_{min_rate}.png``. With ``max_rate``
    given as well, the span between the two rates is shaded and the figure
    is saved as ``mask_rate_{min_rate}_{max_rate}.png``. Paths are relative
    to the current working directory. The x-axis is fixed to 0.145-0.40.

    Args:
        min_rate: The single rate to mark, or the lower end of the range.
        max_rate: Upper end of the range, or ``None`` for a single rate.

    Returns:
        None.
    """
    # Project helper; presumably configures matplotlib fonts - see
    # fuson_plm.utils.visualizing.
    set_font()
    fig = plt.figure(figsize=(5, 1))  # Short figure so the strip dominates

    if max_rate is None:
        # Single rate: a solid black vertical line at min_rate.
        plt.axvline(x=min_rate, color='black', linestyle='-', linewidth=4, label=f"Rate = {min_rate}")
        plot_title = f"mask_rate_{min_rate}.png"
    else:
        # Rate range: semi-transparent black band from min_rate to max_rate.
        plt.fill_betweenx([0, 1], min_rate, max_rate, color='black', alpha=0.5)
        plot_title = f"mask_rate_{min_rate}_{max_rate}.png"

    # Fixed x-axis; the lower limit sits just below 0.15 so a line drawn at
    # 0.15 is not placed exactly on the axis edge.
    plt.xlim(0.145, 0.40)
    plt.xticks([0.15, 0.20, 0.25, 0.30, 0.35, 0.40], fontsize=20)
    plt.tick_params(axis='y', which='both', left=False, labelleft=False)  # No y ticks or labels

    # Strip everything except a thin bottom axis.
    axes = plt.gca()
    axes.spines['top'].set_visible(False)
    axes.spines['right'].set_visible(False)
    axes.spines['left'].set_visible(False)
    axes.spines['bottom'].set_linewidth(0.5)
    axes.yaxis.set_visible(False)  # Remove the y-axis entirely
    plt.xlabel("")  # No x-axis label
    plt.title("")  # No title

    plt.tight_layout()
    # Save before showing: once a blocking show() returns, the figure is
    # closed, and a later savefig() would write a new, empty figure.
    plt.savefig(plot_title, dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    

def main():
    """Generate every scheduler figure and every masking-rate strip figure."""
    min_rate = 0.15
    max_rate = 0.40
    num_steps = 20
    total_batches = 4215
    num_epochs = 3

    # One multi-epoch plot per scheduler. num_steps is passed explicitly so
    # the value configured above is the one actually used (only the stepwise
    # scheduler reads it).
    for scheduler in ("cosine", "log_linear", "stepwise"):
        n_epochs_scheduler_plot(
            min_rate,
            max_rate,
            total_batches,
            num_epochs=num_epochs,
            scheduler=scheduler,
            num_steps=num_steps,
        )

    # Strip figures for the individual fixed masking rates.
    for fixed_rate in (0.15, 0.20, 0.25):
        plot_masking_rate_range(fixed_rate)

    # Strip figures for ranges that all start at 0.15.
    for upper_rate in (0.20, 0.25, 0.30, 0.35, 0.40):
        plot_masking_rate_range(0.15, upper_rate)

# Generate the figures only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()