# Source: Hugging Face Space (status when captured: Sleeping)
import gradio as gr | |
import plotly.graph_objects as go | |
import numpy as np | |
import torch | |
from torch.optim import SGD | |
import torch.optim.lr_scheduler as lr_schedulers | |
# Scheduler names offered in the UI dropdown.
# Each name is looked up as a class in torch.optim.lr_scheduler and as a key
# in scheduler_params_dict, so an entry here needs a matching parameter spec.
# "PolynomialLR" is listed because a full parameter spec exists for it; without
# this entry it could never be selected.
schedulers = [
    "ConstantLR",
    "LinearLR",
    "ExponentialLR",
    "PolynomialLR",
    "StepLR",
    "MultiStepLR",
    "CosineAnnealingLR",
    "CyclicLR",
    "OneCycleLR",
    "CosineAnnealingWarmRestarts",
]
def _param(kind, label, value):
    """Build one parameter spec: declared type name, UI label, default value."""
    return {"type": kind, "label": label, "value": value}


# Parameter specifications, keyed by scheduler name.
# Each scheduler maps constructor keyword -> {"type", "label", "value"}.
# "type" is one of "int", "float", "str", "bool" or "list[int]".
# Key order matters: interactive_plot pairs these keys positionally with the
# values coming from the UI inputs.
scheduler_params_dict = {
    "StepLR": {
        "step_size": _param("int", "Step Size", 30),
        "gamma": _param("float", "Gamma", 0.1),
        "last_epoch": _param("int", "Last Epoch", -1),
    },
    "MultiStepLR": {
        "milestones": _param("list[int]", "Milestones", [30, 80]),
        "gamma": _param("float", "Gamma", 0.1),
        "last_epoch": _param("int", "Last Epoch", -1),
    },
    "ConstantLR": {
        "factor": _param("float", "Factor", 0.333),
        "total_iters": _param("int", "Total Iterations", 5),
        "last_epoch": _param("int", "Last Epoch", -1),
    },
    "LinearLR": {
        "start_factor": _param("float", "Start Factor", 0.333),
        "end_factor": _param("float", "End Factor", 1.0),
        "total_iters": _param("int", "Total Iterations", 5),
        "last_epoch": _param("int", "Last Epoch", -1),
    },
    "ExponentialLR": {
        "gamma": _param("float", "Gamma", 0.9),
        "last_epoch": _param("int", "Last Epoch", -1),
    },
    "PolynomialLR": {
        "total_iters": _param("int", "Total Iterations", 5),
        "power": _param("float", "Power", 1.0),
        "last_epoch": _param("int", "Last Epoch", -1),
    },
    "CosineAnnealingLR": {
        "T_max": _param("int", "T Max", 50),
        "eta_min": _param("float", "Eta Min", 0.0),
        "last_epoch": _param("int", "Last Epoch", -1),
    },
    "CyclicLR": {
        "base_lr": _param("float", "Base LR", 0.001),
        "max_lr": _param("float", "Max LR", 0.01),
        "step_size_up": _param("int", "Step Size Up", 2000),
        "step_size_down": _param("int", "Step Size Down", None),
        "mode": _param("str", "Mode", "triangular"),
        "gamma": _param("float", "Gamma", 1.0),
        "cycle_momentum": _param("bool", "Cycle Momentum", True),
        "base_momentum": _param("float", "Base Momentum", 0.8),
        "max_momentum": _param("float", "Max Momentum", 0.9),
        "last_epoch": _param("int", "Last Epoch", -1),
    },
    "OneCycleLR": {
        "max_lr": _param("float", "Max LR", 0.01),
        "total_steps": _param("int", "Total Steps", None),
        "epochs": _param("int", "Epochs", None),
        "steps_per_epoch": _param("int", "Steps Per Epoch", None),
        "pct_start": _param("float", "Pct Start", 0.3),
        "anneal_strategy": _param("str", "Anneal Strategy", "cos"),
        "cycle_momentum": _param("bool", "Cycle Momentum", True),
        "base_momentum": _param("float", "Base Momentum", 0.85),
        "max_momentum": _param("float", "Max Momentum", 0.95),
        "div_factor": _param("float", "Div Factor", 25.0),
        "final_div_factor": _param("float", "Final Div Factor", 10000.0),
        "three_phase": _param("bool", "Three Phase", False),
        "last_epoch": _param("int", "Last Epoch", -1),
    },
    "CosineAnnealingWarmRestarts": {
        "T_0": _param("int", "T 0", 10),
        "T_mult": _param("int", "T Mult", 1),
        "eta_min": _param("float", "Eta Min", 0.0),
        "last_epoch": _param("int", "Last Epoch", -1),
    },
}
def create_line_plot(x, y, title="Line Plot"):
    """Return a Plotly figure with a single line trace of y against x.

    The title is used both as the figure title and as the trace name; the
    axes are labelled "Steps" and "Learning Rate".
    """
    trace = go.Scatter(x=x, y=y, mode="lines", name=title)
    figure = go.Figure()
    figure.add_trace(trace)
    layout_options = {
        "title": title,
        "xaxis_title": "Steps",
        "yaxis_title": "Learning Rate",
    }
    figure.update_layout(**layout_options)
    return figure
def get_lr_schedule(scheduler_name, initial_lr=0.1, total_steps=100, scheduler_kwargs=None):
    """Simulate a scheduler and return the learning rate seen at each step.

    Args:
        scheduler_name: Name of a class in torch.optim.lr_scheduler.
        initial_lr: Learning rate given to the dummy SGD optimizer.
        total_steps: Number of steps to simulate.
        scheduler_kwargs: Keyword arguments for the scheduler constructor;
            None means no extra arguments.

    Returns:
        A list of total_steps learning rates, each read before the
        optimizer and scheduler are stepped.
    """
    kwargs = {} if scheduler_kwargs is None else scheduler_kwargs

    # A single dummy parameter is enough: only the schedule is of interest.
    dummy_parameter = torch.nn.Parameter(torch.zeros(1))
    optimizer = SGD([dummy_parameter], lr=initial_lr)

    # Resolve the scheduler class by name and attach it to the optimizer.
    scheduler = getattr(lr_schedulers, scheduler_name)(optimizer, **kwargs)

    def _record_then_advance():
        # Read the current rate first, then step optimizer before scheduler.
        current_lr = scheduler.get_last_lr()[0]
        optimizer.step()
        scheduler.step()
        return current_lr

    return [_record_then_advance() for _ in range(total_steps)]
def generate_scheduler_inputs(scheduler_name):
    """Create one Gradio input component per parameter of a scheduler.

    The parameter specs are read from scheduler_params_dict; components are
    returned in the same order as the spec's keys. An unknown scheduler name
    yields an empty list.

    Args:
        scheduler_name: Key into scheduler_params_dict.

    Returns:
        A list of Gradio components (Number, Textbox or Checkbox).

    Raises:
        ValueError: If a spec declares a type this function cannot render.
            Previously such a type either raised UnboundLocalError or
            silently re-appended the previous parameter's component.
    """
    params = scheduler_params_dict.get(scheduler_name, {})
    inputs = []
    for param, details in params.items():
        param_type = details["type"]
        label = details["label"]
        default = details["value"]
        if param_type in ("float", "int"):
            # A missing default is passed as a callable returning None, as the
            # original code did. NOTE(review): Gradio documents callables as
            # accepted for `value`; confirm plain None behaves the same on the
            # deployed Gradio version before simplifying.
            value = default if default is not None else lambda: None
            input_field = gr.Number(label=label, value=value)
        elif param_type == "str":
            input_field = gr.Textbox(label=label, value=default)
        elif param_type == "bool":
            input_field = gr.Checkbox(label=label, value=default)
        elif param_type == "list[int]":
            # Lists are edited as comma-separated text.
            input_field = gr.Textbox(
                label=label,
                value=",".join(map(str, default)),
                placeholder="Enter comma-separated integers",
            )
        else:
            raise ValueError(
                f"Unsupported parameter type {param_type!r} for {scheduler_name}.{param}"
            )
        inputs.append(input_field)
    return inputs
def preprocess_value(value, param_type):
    """Convert a raw UI value to the Python type named by param_type.

    Args:
        value: Raw value from an input component.
        param_type: One of "float", "int", "str", "bool" or "list[int]".

    Returns:
        The converted value. For "float" and "int", a None input (an empty
        numeric field) is returned as None rather than raising, so that
        optional parameters can be left blank. For "list[int]", the value is
        split on commas and blank items (empty input, trailing comma) are
        skipped.

    Raises:
        ValueError: If param_type is not supported, or the value cannot be
            converted.
    """
    if param_type in ("float", "int"):
        if value is None:
            return None
        return float(value) if param_type == "float" else int(value)
    if param_type == "str":
        return str(value)
    if param_type == "bool":
        return bool(value)
    if param_type == "list[int]":
        return [int(item) for item in value.split(",") if item.strip()]
    raise ValueError(f"Unsupported parameter type: {param_type!r}")
def interactive_plot(display_steps, initial_lr, scheduler_name, *args):
    """Gradio callback: plot the schedule of the selected scheduler.

    Args:
        display_steps: Number of steps to simulate and show on the x axis.
        initial_lr: Starting learning rate.
        scheduler_name: Key into scheduler_params_dict and name of the
            scheduler class.
        *args: Raw parameter values, in the key order of the scheduler's
            spec in scheduler_params_dict.

    Returns:
        The figure produced by create_line_plot.
    """
    # x axis: one point per simulated step.
    steps = np.arange(int(display_steps))

    # Pair each spec key with its raw value and convert to the declared type.
    spec = scheduler_params_dict[scheduler_name]
    scheduler_params = {}
    for param_name, raw_value in zip(spec, args):
        declared_type = spec[param_name]["type"]
        scheduler_params[param_name] = preprocess_value(raw_value, declared_type)

    lr_schedule = get_lr_schedule(
        scheduler_name=scheduler_name,
        initial_lr=initial_lr,
        total_steps=len(steps),
        scheduler_kwargs=scheduler_params,
    )

    return create_line_plot(
        steps,
        lr_schedule,
        title=f"Learning Rate Schedule with {scheduler_name}",
    )
def main():
    """Build the Gradio interface and launch the app.

    Layout, top to bottom: title, display steps, initial learning rate,
    scheduler dropdown, the selected scheduler's parameter inputs, the plot,
    and the "Plot Graph" button.

    The parameter inputs depend on the selected scheduler, so they are built
    inside a function registered with gr.render. Previously that function was
    defined but never called or registered, so no parameter inputs appeared
    and the button had no click handler. Requires a Gradio release that
    provides gr.render.
    """
    # Created before the Blocks context and placed later with .render(), so
    # they appear below the scheduler-specific inputs.
    plot = gr.Plot(label="Learning Rate Graph")
    plot_button = gr.Button("Plot Graph")

    with gr.Blocks() as demo:
        gr.Markdown("# Learning Rate Schedule Plotter")
        display_steps = gr.Number(label="Display Steps", value=1000, interactive=True)
        initial_lr = gr.Number(label="Initial Learning Rate", value=0.1, precision=8, interactive=True)
        scheduler_dropdown = gr.Dropdown(choices=schedulers, label="Learning Rate Scheduler", value="ConstantLR")

        # Re-run for the dropdown's current value so the inputs match the
        # selected scheduler. The click handler is attached here because it
        # uses components created inside this function.
        @gr.render(inputs=scheduler_dropdown)
        def update_scheduler_params(scheduler_name):
            inputs = generate_scheduler_inputs(scheduler_name)
            plot_button.click(
                interactive_plot,
                inputs=[display_steps, initial_lr, scheduler_dropdown, *inputs],
                outputs=plot,
            )

        plot.render()
        plot_button.render()

    demo.launch()


if __name__ == "__main__":
    main()