# Author: Suhas-G
# Commit 8fb5f44: "implemented all the remaining things"
import gradio as gr
import plotly.graph_objects as go
import numpy as np
import torch
from torch.optim import SGD
import torch.optim.lr_scheduler as lr_schedulers
# Names of torch.optim.lr_scheduler classes offered in the scheduler dropdown.
# Every entry must also be a key of scheduler_params_dict, which describes the
# constructor arguments shown as inputs for that scheduler.
# NOTE: PolynomialLR exists only in newer torch releases (added around 1.13).
schedulers = [
    "ConstantLR", "LinearLR", "ExponentialLR", "PolynomialLR",
    "StepLR", "MultiStepLR",
    "CosineAnnealingLR", "CyclicLR", "OneCycleLR",
    "CosineAnnealingWarmRestarts",
]
# Dictionary of scheduler parameters
scheduler_params_dict = {
"StepLR": {
"step_size": {"type": "int", "label": "Step Size", "value": 30},
"gamma": {"type": "float", "label": "Gamma", "value": 0.1},
"last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
},
"MultiStepLR": {
"milestones": {"type": "list[int]", "label": "Milestones", "value": [30, 80]},
"gamma": {"type": "float", "label": "Gamma", "value": 0.1},
"last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
},
"ConstantLR": {
"factor": {"type": "float", "label": "Factor", "value": 0.333},
"total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
"last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
},
"LinearLR": {
"start_factor": {"type": "float", "label": "Start Factor", "value": 0.333},
"end_factor": {"type": "float", "label": "End Factor", "value": 1.0},
"total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
"last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
},
"ExponentialLR": {
"gamma": {"type": "float", "label": "Gamma", "value": 0.9},
"last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
},
"PolynomialLR": {
"total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
"power": {"type": "float", "label": "Power", "value": 1.0},
"last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
},
"CosineAnnealingLR": {
"T_max": {"type": "int", "label": "T Max", "value": 50},
"eta_min": {"type": "float", "label": "Eta Min", "value": 0.0},
"last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
},
"CyclicLR": {
"base_lr": {"type": "float", "label": "Base LR", "value": 0.001},
"max_lr": {"type": "float", "label": "Max LR", "value": 0.01},
"step_size_up": {"type": "int", "label": "Step Size Up", "value": 2000},
"step_size_down": {"type": "int", "label": "Step Size Down", "value": None},
"mode": {"type": "str", "label": "Mode", "value": "triangular"},
"gamma": {"type": "float", "label": "Gamma", "value": 1.0},
"cycle_momentum": {"type": "bool", "label": "Cycle Momentum", "value": True},
"base_momentum": {"type": "float", "label": "Base Momentum", "value": 0.8},
"max_momentum": {"type": "float", "label": "Max Momentum", "value": 0.9},
"last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
},
"OneCycleLR": {
"max_lr": {"type": "float", "label": "Max LR", "value": 0.01},
"total_steps": {"type": "int", "label": "Total Steps", "value": None},
"epochs": {"type": "int", "label": "Epochs", "value": None},
"steps_per_epoch": {"type": "int", "label": "Steps Per Epoch", "value": None},
"pct_start": {"type": "float", "label": "Pct Start", "value": 0.3},
"anneal_strategy": {"type": "str", "label": "Anneal Strategy", "value": "cos"},
"cycle_momentum": {"type": "bool", "label": "Cycle Momentum", "value": True},
"base_momentum": {"type": "float", "label": "Base Momentum", "value": 0.85},
"max_momentum": {"type": "float", "label": "Max Momentum", "value": 0.95},
"div_factor": {"type": "float", "label": "Div Factor", "value": 25.0},
"final_div_factor": {"type": "float", "label": "Final Div Factor", "value": 10000.0},
"three_phase": {"type": "bool", "label": "Three Phase", "value": False},
"last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
},
"CosineAnnealingWarmRestarts": {
"T_0": {"type": "int", "label": "T 0", "value": 10},
"T_mult": {"type": "int", "label": "T Mult", "value": 1},
"eta_min": {"type": "float", "label": "Eta Min", "value": 0.0},
"last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
}
}
def create_line_plot(x, y, title="Line Plot"):
    """Return a Plotly figure with a single line trace of ``y`` against ``x``.

    The trace is named after ``title``, the figure is titled ``title``, and the
    axes are labelled "Steps" (x) and "Learning Rate" (y).
    """
    line_trace = go.Scatter(x=x, y=y, mode="lines", name=title)
    figure = go.Figure(data=[line_trace])
    figure.update_layout(
        title=title,
        xaxis_title="Steps",
        yaxis_title="Learning Rate",
    )
    return figure
def get_lr_schedule(scheduler_name, initial_lr=0.1, total_steps=100, scheduler_kwargs=None):
    """Simulate a torch LR scheduler and record its learning rate per step.

    A throwaway SGD optimizer over one zero-initialised parameter is driven
    for ``total_steps`` iterations. Before each optimizer/scheduler step, the
    scheduler's current learning rate is recorded.

    Args:
        scheduler_name: Name of a class in ``torch.optim.lr_scheduler``
            (e.g. ``"StepLR"``). It is resolved with ``getattr``.
        initial_lr: Learning rate the SGD optimizer is created with.
        total_steps: Number of learning-rate values to record.
        scheduler_kwargs: Keyword arguments for the scheduler constructor;
            ``None`` means none.

    Returns:
        A list with ``total_steps`` entries: element ``i`` is the learning rate
        of the optimizer's single param group at step ``i``.
    """
    ctor_kwargs = {} if scheduler_kwargs is None else scheduler_kwargs
    scheduler_cls = getattr(lr_schedulers, scheduler_name)
    dummy_weight = torch.nn.Parameter(torch.zeros(1))
    opt = SGD([dummy_weight], lr=initial_lr)
    sched = scheduler_cls(opt, **ctor_kwargs)

    def _trace():
        for _ in range(total_steps):
            # Emit first, so entry i is the LR in effect at step i.
            yield sched.get_last_lr()[0]
            # optimizer.step() must precede scheduler.step() (torch's expected order).
            opt.step()
            sched.step()

    return list(_trace())
# Function to generate the input blocks based on the selected scheduler
def generate_scheduler_inputs(scheduler_name):
params = scheduler_params_dict.get(scheduler_name, {})
inputs = []
for param, details in params.items():
param_type = details["type"]
if param_type == "float":
value = details["value"] if details["value"] is not None else lambda:None
input_field = gr.Number(label=details["label"], value=value)
elif param_type == "int":
value = details["value"] if details["value"] is not None else lambda:None
input_field = gr.Number(label=details["label"], value=value)
elif param_type == "str":
input_field = gr.Textbox(label=details["label"], value=details["value"])
elif param_type == "bool":
input_field = gr.Checkbox(label=details["label"], value=details["value"])
elif param_type == "list[int]":
input_field = gr.Textbox(label=details["label"], value=",".join(map(str, details["value"])), placeholder="Enter comma-separated integers")
inputs.append(input_field)
return inputs
def preprocess_value(value, param_type):
    """Convert a raw UI input value to the Python type a scheduler expects.

    Args:
        value: Value submitted by the input component: a number, a string, a
            bool, or ``None`` for a field that has no value.
        param_type: One of ``"float"``, ``"int"``, ``"str"``, ``"bool"`` or
            ``"list[int]"``.

    Returns:
        The converted value.
        - ``None`` is passed through unchanged, so optional scheduler
          arguments (e.g. CyclicLR's ``step_size_down``, OneCycleLR's
          ``epochs``) can be forwarded as ``None``.
        - For ``"list[int]"``, the value is a comma-separated string. Blank
          entries are skipped, so ``""`` yields ``[]`` and a trailing comma is
          harmless.

    Raises:
        ValueError: If the value cannot be converted or ``param_type`` is
            not supported.
    """
    if value is None:
        return None
    if param_type == "float":
        return float(value)
    if param_type == "int":
        return int(value)
    if param_type == "str":
        return str(value)
    if param_type == "bool":
        return bool(value)
    if param_type == "list[int]":
        return [int(part.strip()) for part in value.split(",") if part.strip()]
    raise ValueError(f"Unsupported parameter type: {param_type!r}")
def interactive_plot(display_steps, initial_lr, scheduler_name, *args):
    """Gradio callback: parse the form values and plot the resulting LR schedule.

    Args:
        display_steps: Number of scheduler steps to simulate and plot.
        initial_lr: Learning rate the dummy optimizer starts from.
        scheduler_name: Key of ``scheduler_params_dict`` and name of a
            ``torch.optim.lr_scheduler`` class.
        *args: Raw values of the scheduler parameter inputs, in the same order
            as the keys of ``scheduler_params_dict[scheduler_name]``.

    Returns:
        A Plotly figure of learning rate versus step.

    Raises:
        gr.Error: If the inputs cannot be parsed or torch rejects the scheduler
            configuration. Gradio displays the ``gr.Error`` message to the user.
    """
    try:
        param_specs = scheduler_params_dict[scheduler_name]
        # Define x as steps for visualization
        steps = np.arange(int(display_steps))
        # Submitted values arrive positionally; pair them with parameter names.
        scheduler_params = {
            name: preprocess_value(value, spec["type"])
            for (name, spec), value in zip(param_specs.items(), args)
        }
        lr_schedule = get_lr_schedule(
            scheduler_name=scheduler_name,
            initial_lr=initial_lr,
            total_steps=len(steps),
            scheduler_kwargs=scheduler_params,
        )
    except Exception as err:
        # UI boundary: turn parse errors and torch's configuration errors
        # (e.g. OneCycleLR stepped past total_steps) into a message the user sees.
        raise gr.Error(f"Could not build the {scheduler_name} schedule: {err}") from err
    title = f"Learning Rate Schedule with {scheduler_name}"
    return create_line_plot(steps, lr_schedule, title=title)
def main():
    """Build the "Learning Rate Schedule Plotter" Gradio app and launch it.

    The plot and the "Plot Graph" button are created before the Blocks
    context. They are placed into the layout via ``.render()`` after the
    dynamic scheduler-parameter section, so they appear below it.
    """
    plot = gr.Plot(label="Learning Rate Graph")
    plot_button = gr.Button("Plot Graph")
    # Define Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("# Learning Rate Schedule Plotter")
        display_steps = gr.Number(label="Display Steps", value=1000, interactive=True)
        initial_lr = gr.Number(label="Initial Learning Rate", value=0.1, precision=8, interactive=True)
        # Dropdown for scheduler selection; "ConstantLR" must be one of `schedulers`.
        scheduler_dropdown = gr.Dropdown(choices=schedulers, label="Learning Rate Scheduler", value="ConstantLR")
        # gr.render re-runs this function for the selected scheduler (by default on
        # load and whenever the dropdown changes). Each run rebuilds the parameter
        # inputs and wires the button's click handler to exactly those inputs.
        @gr.render(inputs=scheduler_dropdown)
        def update_scheduler_params(scheduler_name):
            # Generate the appropriate input fields based on selected scheduler
            inputs = generate_scheduler_inputs(scheduler_name)
            # interactive_plot receives the dynamic inputs as *args, in
            # scheduler_params_dict order.
            plot_button.click(interactive_plot, inputs=[display_steps, initial_lr, scheduler_dropdown, *inputs], outputs=plot)
        # Place the pre-built plot and button after the dynamic inputs.
        plot.render()
        plot_button.render()
    # Launch the Gradio app
    demo.launch()


if __name__ == "__main__":
    main()