Spaces:
Sleeping
Sleeping
File size: 9,080 Bytes
f100013 8fb5f44 f100013 8fb5f44 f100013 8fb5f44 f100013 8fb5f44 f100013 8fb5f44 f100013 8fb5f44 f100013 8fb5f44 f100013 8fb5f44 f100013 8fb5f44 f100013 8fb5f44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import gradio as gr
import plotly.graph_objects as go
import numpy as np
import torch
from torch.optim import SGD
import torch.optim.lr_scheduler as lr_schedulers
# Scheduler names offered in the UI dropdown. Each name must be a class in
# torch.optim.lr_scheduler (looked up with getattr in get_lr_schedule) and must
# have an entry in scheduler_params_dict (indexed directly in interactive_plot).
schedulers = [
    "ConstantLR", "LinearLR", "ExponentialLR", "StepLR", "MultiStepLR",
    "PolynomialLR", "CosineAnnealingLR", "CyclicLR", "OneCycleLR",
    "CosineAnnealingWarmRestarts",
]
def _spec(kind, label, value):
    """Return one UI parameter spec: widget type tag, display label, default value."""
    return {"type": kind, "label": label, "value": value}


# Per-scheduler parameter specs, keyed first by scheduler class name and then by
# the constructor keyword argument name. "type" selects the Gradio widget and the
# conversion applied to the submitted value; a "value" of None leaves the field empty.
scheduler_params_dict = {
    "StepLR": {
        "step_size": _spec("int", "Step Size", 30),
        "gamma": _spec("float", "Gamma", 0.1),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "MultiStepLR": {
        "milestones": _spec("list[int]", "Milestones", [30, 80]),
        "gamma": _spec("float", "Gamma", 0.1),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "ConstantLR": {
        "factor": _spec("float", "Factor", 0.333),
        "total_iters": _spec("int", "Total Iterations", 5),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "LinearLR": {
        "start_factor": _spec("float", "Start Factor", 0.333),
        "end_factor": _spec("float", "End Factor", 1.0),
        "total_iters": _spec("int", "Total Iterations", 5),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "ExponentialLR": {
        "gamma": _spec("float", "Gamma", 0.9),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "PolynomialLR": {
        "total_iters": _spec("int", "Total Iterations", 5),
        "power": _spec("float", "Power", 1.0),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "CosineAnnealingLR": {
        "T_max": _spec("int", "T Max", 50),
        "eta_min": _spec("float", "Eta Min", 0.0),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "CyclicLR": {
        "base_lr": _spec("float", "Base LR", 0.001),
        "max_lr": _spec("float", "Max LR", 0.01),
        "step_size_up": _spec("int", "Step Size Up", 2000),
        "step_size_down": _spec("int", "Step Size Down", None),
        "mode": _spec("str", "Mode", "triangular"),
        "gamma": _spec("float", "Gamma", 1.0),
        "cycle_momentum": _spec("bool", "Cycle Momentum", True),
        "base_momentum": _spec("float", "Base Momentum", 0.8),
        "max_momentum": _spec("float", "Max Momentum", 0.9),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "OneCycleLR": {
        "max_lr": _spec("float", "Max LR", 0.01),
        "total_steps": _spec("int", "Total Steps", None),
        "epochs": _spec("int", "Epochs", None),
        "steps_per_epoch": _spec("int", "Steps Per Epoch", None),
        "pct_start": _spec("float", "Pct Start", 0.3),
        "anneal_strategy": _spec("str", "Anneal Strategy", "cos"),
        "cycle_momentum": _spec("bool", "Cycle Momentum", True),
        "base_momentum": _spec("float", "Base Momentum", 0.85),
        "max_momentum": _spec("float", "Max Momentum", 0.95),
        "div_factor": _spec("float", "Div Factor", 25.0),
        "final_div_factor": _spec("float", "Final Div Factor", 10000.0),
        "three_phase": _spec("bool", "Three Phase", False),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "CosineAnnealingWarmRestarts": {
        "T_0": _spec("int", "T 0", 10),
        "T_mult": _spec("int", "T Mult", 1),
        "eta_min": _spec("float", "Eta Min", 0.0),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
}
# Build a single-trace Plotly line chart of learning rate against step.
def create_line_plot(x, y, title="Line Plot"):
    """Return a Plotly figure with one line trace of ``y`` over ``x``.

    Args:
        x: Values for the horizontal ("Steps") axis.
        y: Values for the vertical ("Learning Rate") axis, one per entry of ``x``.
        title: Used both as the figure title and as the trace name.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    line = go.Scatter(x=x, y=y, mode="lines", name=title)
    figure = go.Figure(data=[line])
    figure.update_layout(
        title=title,
        xaxis_title="Steps",
        yaxis_title="Learning Rate",
    )
    return figure
# Simulate any torch.optim.lr_scheduler class and record its learning rate per step.
def get_lr_schedule(scheduler_name, initial_lr=0.1, total_steps=100, scheduler_kwargs=None):
    """Return the learning rate a scheduler reports at each of ``total_steps`` steps.

    Args:
        scheduler_name: Name of a class in ``torch.optim.lr_scheduler``.
        initial_lr: Learning rate given to the dummy SGD optimizer.
        total_steps: Number of values to record.
        scheduler_kwargs: Extra keyword arguments for the scheduler constructor
            (``None`` means no extra arguments).

    Returns:
        A list of ``total_steps`` learning rates; entry ``i`` is the first param
        group's LR before the ``i``-th scheduler step.
    """
    kwargs = {} if scheduler_kwargs is None else scheduler_kwargs

    # One throwaway parameter is enough to give SGD a param group to schedule.
    optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=initial_lr)
    scheduler_cls = getattr(lr_schedulers, scheduler_name)
    scheduler = scheduler_cls(optimizer, **kwargs)

    history = []
    for _ in range(total_steps):
        history.append(scheduler.get_last_lr()[0])
        # optimizer.step() precedes scheduler.step(), the order PyTorch expects.
        optimizer.step()
        scheduler.step()
    return history
# Build the Gradio input components for the selected scheduler's parameters.
def generate_scheduler_inputs(scheduler_name):
    """Create one Gradio input component per parameter of ``scheduler_name``.

    Components are returned in the iteration order of
    ``scheduler_params_dict[scheduler_name]``; ``interactive_plot`` zips the
    submitted values back onto parameter names in that same order, so exactly
    one component must be produced per parameter. An unknown scheduler name
    yields an empty list.

    Raises:
        ValueError: if a parameter spec has an unsupported ``type`` tag.
    """
    params = scheduler_params_dict.get(scheduler_name, {})
    inputs = []
    for param, details in params.items():
        param_type = details["type"]
        label = details["label"]
        default = details["value"]
        if param_type in ("float", "int"):
            # A callable initial value is evaluated by Gradio when the page loads;
            # this one yields None, so the field starts empty.
            value = default if default is not None else (lambda: None)
            # precision=0 makes Gradio round integer parameters and return an int,
            # instead of accepting fractional input that int() would later truncate.
            precision = 0 if param_type == "int" else None
            input_field = gr.Number(label=label, value=value, precision=precision)
        elif param_type == "str":
            input_field = gr.Textbox(label=label, value=default)
        elif param_type == "bool":
            input_field = gr.Checkbox(label=label, value=default)
        elif param_type == "list[int]":
            input_field = gr.Textbox(
                label=label,
                value=",".join(map(str, default)),
                placeholder="Enter comma-separated integers",
            )
        else:
            # Fail loudly: falling through would leave input_field unbound or reuse
            # the previous widget, misaligning values in interactive_plot.
            raise ValueError(
                f"Unsupported parameter type {param_type!r} for {scheduler_name}.{param}"
            )
        inputs.append(input_field)
    return inputs
def preprocess_value(value, param_type):
    """Convert a raw Gradio input value to the Python type a scheduler expects.

    Args:
        value: Value submitted by a Gradio component (number, string or bool).
            ``None`` (an empty Number field) is passed through unchanged so that
            optional scheduler arguments whose default is ``None`` (e.g. CyclicLR
            ``step_size_down``, OneCycleLR ``total_steps``) keep that meaning
            instead of crashing in ``int(None)``.
        param_type: One of ``"float"``, ``"int"``, ``"str"``, ``"bool"``,
            ``"list[int]"``.

    Returns:
        The converted value, or ``None`` when ``value`` is ``None``.

    Raises:
        ValueError: for an unknown ``param_type`` or a non-integer list entry.
    """
    if value is None:
        return None
    if param_type == "float":
        return float(value)
    if param_type == "int":
        return int(value)
    if param_type == "str":
        return str(value)
    if param_type == "bool":
        return bool(value)
    if param_type == "list[int]":
        # Drop blank tokens so "30, 80," or an empty box don't fail on int("").
        tokens = [part.strip() for part in value.split(",")]
        return [int(token) for token in tokens if token]
    raise ValueError(f"Unsupported parameter type: {param_type!r}")
# Gradio click handler: turn widget values into scheduler kwargs and plot the schedule.
def interactive_plot(display_steps, initial_lr, scheduler_name, *args):
    """Simulate ``scheduler_name`` and return a Plotly figure of its learning rate.

    Args:
        display_steps: Number of steps to simulate and plot (converted with ``int``).
        initial_lr: Learning rate given to the dummy optimizer.
        scheduler_name: Key into ``scheduler_params_dict`` and class name in
            ``torch.optim.lr_scheduler``.
        *args: Parameter values, in the iteration order of
            ``scheduler_params_dict[scheduler_name]``.

    Returns:
        The figure built by ``create_line_plot``.
    """
    # The x axis is simply 0 .. display_steps-1.
    steps = np.arange(int(display_steps))

    param_specs = scheduler_params_dict[scheduler_name]
    scheduler_params = {}
    for (name, spec), raw_value in zip(param_specs.items(), args):
        scheduler_params[name] = preprocess_value(raw_value, spec["type"])

    schedule = get_lr_schedule(
        scheduler_name=scheduler_name,
        initial_lr=initial_lr,
        total_steps=len(steps),
        scheduler_kwargs=scheduler_params,
    )
    plot_title = f"Learning Rate Schedule with {scheduler_name}"
    return create_line_plot(steps, schedule, title=plot_title)
def main():
    """Build the Gradio UI for plotting learning-rate schedules and launch it.

    The plot and button components are created before the Blocks context so the
    render callback below can reference them; they are then placed into the
    layout with ``.render()`` after the dynamic parameter inputs.
    """
    plot = gr.Plot(label="Learning Rate Graph")
    plot_button = gr.Button("Plot Graph")
    # Define Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("# Learning Rate Schedule Plotter")
        # Shared inputs that apply to every scheduler.
        display_steps = gr.Number(label="Display Steps", value=1000, interactive=True)
        initial_lr = gr.Number(label="Initial Learning Rate", value=0.1, precision=8, interactive=True)
        # Dropdown for scheduler selection
        scheduler_dropdown = gr.Dropdown(choices=schedulers, label="Learning Rate Scheduler", value="ConstantLR")
        # Re-run whenever the dropdown changes: rebuild the parameter inputs for
        # the chosen scheduler and re-bind the button to them.
        @gr.render(inputs=scheduler_dropdown)
        def update_scheduler_params(scheduler_name):
            # Generate the appropriate input fields based on selected scheduler
            inputs = generate_scheduler_inputs(scheduler_name)
            # Positional order here must match interactive_plot's signature:
            # (display_steps, initial_lr, scheduler_name, *param values).
            plot_button.click(interactive_plot, inputs=[display_steps, initial_lr, scheduler_dropdown, *inputs], outputs=plot)
        plot.render()
        plot_button.render()
    # Launch the Gradio app
    demo.launch()
# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
|