# NOTE: removed non-Python extraction artifacts that preceded this file
# (a "File size" header, commit-hash fragments, and stray line numbers);
# they were not part of the program and made the file unparseable.
import gradio as gr
import plotly.graph_objects as go
import numpy as np
import torch
from torch.optim import SGD
import torch.optim.lr_scheduler as lr_schedulers

# Scheduler names offered in the UI dropdown. Each entry must match both a
# class name in torch.optim.lr_scheduler and a key in scheduler_params_dict.
# Fix: "PolynomialLR" has a parameter spec below but was missing here, so it
# could never be selected.
schedulers = [
    "ConstantLR", "LinearLR", "ExponentialLR", "PolynomialLR", "StepLR",
    "MultiStepLR", "CosineAnnealingLR", "CyclicLR", "OneCycleLR",
    "CosineAnnealingWarmRestarts"
]

# Per-scheduler parameter specifications used to build the Gradio input form
# and to coerce the submitted values back to Python types.
#
# Schema: {scheduler_name: {param_name: {"type", "label", "value"}}} where
#   - "type"  selects both the input widget (generate_scheduler_inputs) and
#     the coercion applied to the submitted value (preprocess_value);
#     one of "int", "float", "str", "bool", "list[int]".
#   - "label" is the human-readable widget label.
#   - "value" is the default shown in the UI. A value of None (e.g. CyclicLR
#     "step_size_down", OneCycleLR "total_steps") means "unset — let the
#     scheduler derive it".
#
# Parameter names and defaults mirror the torch.optim.lr_scheduler
# constructors, since they are passed through as **kwargs.
scheduler_params_dict = {
    "StepLR": {
        "step_size": {"type": "int", "label": "Step Size", "value": 30},
        "gamma": {"type": "float", "label": "Gamma", "value": 0.1},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
    },
    "MultiStepLR": {
        "milestones": {"type": "list[int]", "label": "Milestones", "value": [30, 80]},
        "gamma": {"type": "float", "label": "Gamma", "value": 0.1},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
    },
    "ConstantLR": {
        "factor": {"type": "float", "label": "Factor", "value": 0.333},
        "total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
    },
    "LinearLR": {
        "start_factor": {"type": "float", "label": "Start Factor", "value": 0.333},
        "end_factor": {"type": "float", "label": "End Factor", "value": 1.0},
        "total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
    },
    "ExponentialLR": {
        "gamma": {"type": "float", "label": "Gamma", "value": 0.9},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
    },
    "PolynomialLR": {
        "total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
        "power": {"type": "float", "label": "Power", "value": 1.0},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
    },
    "CosineAnnealingLR": {
        "T_max": {"type": "int", "label": "T Max", "value": 50},
        "eta_min": {"type": "float", "label": "Eta Min", "value": 0.0},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
    },
    "CyclicLR": {
        "base_lr": {"type": "float", "label": "Base LR", "value": 0.001},
        "max_lr": {"type": "float", "label": "Max LR", "value": 0.01},
        "step_size_up": {"type": "int", "label": "Step Size Up", "value": 2000},
        "step_size_down": {"type": "int", "label": "Step Size Down", "value": None},
        "mode": {"type": "str", "label": "Mode", "value": "triangular"},
        "gamma": {"type": "float", "label": "Gamma", "value": 1.0},
        "cycle_momentum": {"type": "bool", "label": "Cycle Momentum", "value": True},
        "base_momentum": {"type": "float", "label": "Base Momentum", "value": 0.8},
        "max_momentum": {"type": "float", "label": "Max Momentum", "value": 0.9},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
    },
    "OneCycleLR": {
        "max_lr": {"type": "float", "label": "Max LR", "value": 0.01},
        "total_steps": {"type": "int", "label": "Total Steps", "value": None},
        "epochs": {"type": "int", "label": "Epochs", "value": None},
        "steps_per_epoch": {"type": "int", "label": "Steps Per Epoch", "value": None},
        "pct_start": {"type": "float", "label": "Pct Start", "value": 0.3},
        "anneal_strategy": {"type": "str", "label": "Anneal Strategy", "value": "cos"},
        "cycle_momentum": {"type": "bool", "label": "Cycle Momentum", "value": True},
        "base_momentum": {"type": "float", "label": "Base Momentum", "value": 0.85},
        "max_momentum": {"type": "float", "label": "Max Momentum", "value": 0.95},
        "div_factor": {"type": "float", "label": "Div Factor", "value": 25.0},
        "final_div_factor": {"type": "float", "label": "Final Div Factor", "value": 10000.0},
        "three_phase": {"type": "bool", "label": "Three Phase", "value": False},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
    },
    "CosineAnnealingWarmRestarts": {
        "T_0": {"type": "int", "label": "T 0", "value": 10},
        "T_mult": {"type": "int", "label": "T Mult", "value": 1},
        "eta_min": {"type": "float", "label": "Eta Min", "value": 0.0},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
    }
}


def create_line_plot(x, y, title="Line Plot"):
    """Return a Plotly figure with a single line trace of *y* over *x*.

    Axes are labelled "Steps" / "Learning Rate"; *title* doubles as the
    trace name and the figure title.
    """
    trace = go.Scatter(x=x, y=y, mode="lines", name=title)
    figure = go.Figure(data=[trace])
    figure.update_layout(
        title=title,
        xaxis_title="Steps",
        yaxis_title="Learning Rate",
    )
    return figure

def get_lr_schedule(scheduler_name, initial_lr=0.1, total_steps=100, scheduler_kwargs=None):
    """Simulate *total_steps* optimizer steps and record the learning rate.

    A throwaway SGD optimizer over a single zero parameter drives the
    scheduler; no gradients are involved. The scheduler class is looked up
    by name in torch.optim.lr_scheduler and constructed with
    *scheduler_kwargs*. Returns a list of length *total_steps* with the LR
    observed before each step (index 0 is *initial_lr* for most schedulers).
    """
    kwargs = scheduler_kwargs or {}

    # Dummy one-parameter optimizer — only its LR bookkeeping matters.
    dummy_param = torch.nn.Parameter(torch.zeros(1))
    optimizer = SGD([dummy_param], lr=initial_lr)

    # Resolve the scheduler class dynamically and build it.
    scheduler = getattr(lr_schedulers, scheduler_name)(optimizer, **kwargs)

    history = []
    for _ in range(total_steps):
        # Record the LR in effect for this step, then advance.
        # optimizer.step() precedes scheduler.step() per PyTorch's contract.
        history.append(scheduler.get_last_lr()[0])
        optimizer.step()
        scheduler.step()

    return history

# Function to generate the input blocks based on the selected scheduler
def generate_scheduler_inputs(scheduler_name):
    """Build one Gradio input component per parameter of *scheduler_name*.

    Specs come from scheduler_params_dict; an unknown scheduler yields [].
    Component order matches the dict's insertion order, which is what
    interactive_plot relies on when zipping values back to parameter names.

    Raises:
        ValueError: if a parameter spec declares an unsupported "type".
    """
    params = scheduler_params_dict.get(scheduler_name, {})

    inputs = []
    for param, details in params.items():
        param_type = details["type"]

        if param_type in ("float", "int"):
            # Pass a callable when the default is None so Gradio renders an
            # empty Number field instead of coercing None.
            value = details["value"] if details["value"] is not None else lambda: None
            input_field = gr.Number(label=details["label"], value=value)
        elif param_type == "str":
            input_field = gr.Textbox(label=details["label"], value=details["value"])
        elif param_type == "bool":
            input_field = gr.Checkbox(label=details["label"], value=details["value"])
        elif param_type == "list[int]":
            input_field = gr.Textbox(label=details["label"], value=",".join(map(str, details["value"])), placeholder="Enter comma-separated integers")
        else:
            # Previously an unrecognized type either raised NameError (first
            # iteration) or silently re-appended the previous component.
            raise ValueError(f"Unsupported parameter type: {param_type!r}")

        inputs.append(input_field)

    return inputs

def preprocess_value(value, param_type):
    """Coerce a raw UI value to the Python type named by *param_type*.

    Supported types: "float", "int", "str", "bool", "list[int]" (a
    comma-separated string of integers).

    A value of None is passed through unchanged: several scheduler
    parameters (e.g. CyclicLR step_size_down, OneCycleLR total_steps)
    default to None, and int(None)/float(None) would raise TypeError.

    Raises:
        ValueError: if *param_type* is not one of the supported names
            (previously this fell through and returned None silently).
    """
    if value is None:
        return None
    if param_type == "float":
        return float(value)
    elif param_type == "int":
        return int(value)
    elif param_type == "str":
        return str(value)
    elif param_type == "bool":
        return bool(value)
    elif param_type == "list[int]":
        return [int(val.strip()) for val in value.split(",")]
    raise ValueError(f"Unsupported parameter type: {param_type!r}")

def interactive_plot(display_steps, initial_lr, scheduler_name, *args):
    """Gradio callback: build the LR-schedule figure for one scheduler.

    *args carries the raw widget values in the same order as the parameter
    specs in scheduler_params_dict[scheduler_name]; each is coerced via
    preprocess_value before being handed to the scheduler constructor.
    Returns a Plotly figure of the schedule over int(display_steps) steps.
    """
    steps = np.arange(int(display_steps))

    # Pair each raw widget value with its parameter spec, coercing types.
    param_specs = scheduler_params_dict[scheduler_name]
    scheduler_kwargs = {
        name: preprocess_value(raw, spec["type"])
        for (name, spec), raw in zip(param_specs.items(), args)
    }

    # Simulate the schedule and render it.
    lr_schedule = get_lr_schedule(
        scheduler_name=scheduler_name,
        initial_lr=initial_lr,
        total_steps=len(steps),
        scheduler_kwargs=scheduler_kwargs,
    )
    return create_line_plot(
        steps, lr_schedule, title=f"Learning Rate Schedule with {scheduler_name}"
    )


def main():
    """Wire up and launch the Gradio UI for plotting LR schedules."""
    # Components created before the Blocks context are not placed in the
    # layout yet; they are inserted later via .render() so they appear
    # *after* the dynamically re-rendered parameter inputs.
    plot = gr.Plot(label="Learning Rate Graph")
    plot_button = gr.Button("Plot Graph")

    # Define Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("# Learning Rate Schedule Plotter")
        display_steps = gr.Number(label="Display Steps", value=1000, interactive=True)
        initial_lr = gr.Number(label="Initial Learning Rate", value=0.1, precision=8, interactive=True)
        
        # Dropdown for scheduler selection
        scheduler_dropdown = gr.Dropdown(choices=schedulers, label="Learning Rate Scheduler", value="ConstantLR")

        # Function to update scheduler parameter inputs based on selected scheduler.
        # @gr.render re-runs this body whenever the dropdown value changes,
        # rebuilding the parameter widgets and re-binding the click handler
        # to the fresh set of inputs.
        @gr.render(inputs=scheduler_dropdown)
        def update_scheduler_params(scheduler_name):
            # Generate the appropriate input fields based on selected scheduler
            inputs = generate_scheduler_inputs(scheduler_name)

            # NOTE(review): inputs are passed positionally, so interactive_plot
            # relies on scheduler_params_dict's key order matching this list.
            plot_button.click(interactive_plot, inputs=[display_steps, initial_lr, scheduler_dropdown, *inputs], outputs=plot)

        plot.render()
        plot_button.render()

    # Launch the Gradio app
    demo.launch()


# Run the app only when executed as a script, not on import.
if __name__ == "__main__":
    main()