import gradio as gr
import plotly.graph_objects as go
import numpy as np
import torch
from torch.optim import SGD
import torch.optim.lr_scheduler as lr_schedulers

# List of scheduler names ("PolynomialLR" is included so the dropdown
# matches scheduler_params_dict below)
schedulers = [
    "ConstantLR",
    "LinearLR",
    "ExponentialLR",
    "PolynomialLR",
    "StepLR",
    "MultiStepLR",
    "CosineAnnealingLR",
    "CyclicLR",
    "OneCycleLR",
    "CosineAnnealingWarmRestarts",
]

# Dictionary of scheduler parameters
scheduler_params_dict = {
    "StepLR": {
        "step_size": {"type": "int", "label": "Step Size", "value": 30},
        "gamma": {"type": "float", "label": "Gamma", "value": 0.1},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1},
    },
    "MultiStepLR": {
        "milestones": {"type": "list[int]", "label": "Milestones", "value": [30, 80]},
        "gamma": {"type": "float", "label": "Gamma", "value": 0.1},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1},
    },
    "ConstantLR": {
        "factor": {"type": "float", "label": "Factor", "value": 0.333},
        "total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1},
    },
    "LinearLR": {
        "start_factor": {"type": "float", "label": "Start Factor", "value": 0.333},
        "end_factor": {"type": "float", "label": "End Factor", "value": 1.0},
        "total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1},
    },
    "ExponentialLR": {
        "gamma": {"type": "float", "label": "Gamma", "value": 0.9},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1},
    },
    "PolynomialLR": {
        "total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
        "power": {"type": "float", "label": "Power", "value": 1.0},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1},
    },
    "CosineAnnealingLR": {
        "T_max": {"type": "int", "label": "T Max", "value": 50},
        "eta_min": {"type": "float", "label": "Eta Min", "value": 0.0},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1},
    },
    "CyclicLR": {
        "base_lr": {"type": "float", "label": "Base LR", "value": 0.001},
        "max_lr": {"type": "float", "label": "Max LR", "value": 0.01},
        "step_size_up": {"type": "int", "label": "Step Size Up", "value": 2000},
        "step_size_down": {"type": "int", "label": "Step Size Down", "value": None},
        "mode": {"type": "str", "label": "Mode", "value": "triangular"},
        "gamma": {"type": "float", "label": "Gamma", "value": 1.0},
        "cycle_momentum": {"type": "bool", "label": "Cycle Momentum", "value": True},
        "base_momentum": {"type": "float", "label": "Base Momentum", "value": 0.8},
        "max_momentum": {"type": "float", "label": "Max Momentum", "value": 0.9},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1},
    },
    "OneCycleLR": {
        "max_lr": {"type": "float", "label": "Max LR", "value": 0.01},
        "total_steps": {"type": "int", "label": "Total Steps", "value": None},
        "epochs": {"type": "int", "label": "Epochs", "value": None},
        "steps_per_epoch": {"type": "int", "label": "Steps Per Epoch", "value": None},
        "pct_start": {"type": "float", "label": "Pct Start", "value": 0.3},
        "anneal_strategy": {"type": "str", "label": "Anneal Strategy", "value": "cos"},
        "cycle_momentum": {"type": "bool", "label": "Cycle Momentum", "value": True},
        "base_momentum": {"type": "float", "label": "Base Momentum", "value": 0.85},
        "max_momentum": {"type": "float", "label": "Max Momentum", "value": 0.95},
        "div_factor": {"type": "float", "label": "Div Factor", "value": 25.0},
        "final_div_factor": {"type": "float", "label": "Final Div Factor", "value": 10000.0},
        "three_phase": {"type": "bool", "label": "Three Phase", "value": False},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1},
    },
    "CosineAnnealingWarmRestarts": {
        "T_0": {"type": "int", "label": "T 0", "value": 10},
        "T_mult": {"type": "int", "label": "T Mult", "value": 1},
        "eta_min": {"type": "float", "label": "Eta Min", "value": 0.0},
        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1},
    },
}
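
# For reference, each entry above mirrors the keyword arguments of the
# corresponding torch.optim.lr_scheduler class; e.g. the "StepLR" entry
# describes a call equivalent to (illustrative, not executed here):
#
#     lr_schedulers.StepLR(optimizer, step_size=30, gamma=0.1, last_epoch=-1)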

# Create a Plotly line plot of learning rate vs. step
def create_line_plot(x, y, title="Line Plot"):
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name=title))
    fig.update_layout(title=title, xaxis_title="Steps", yaxis_title="Learning Rate")
    return fig


# Generic function to get the learning rate schedule for any scheduler
def get_lr_schedule(scheduler_name, initial_lr=0.1, total_steps=100, scheduler_kwargs=None):
    if scheduler_kwargs is None:
        scheduler_kwargs = {}
    # Initialize a dummy optimizer with a single parameter
    optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=initial_lr)
    # Dynamically look up the scheduler class in torch.optim.lr_scheduler
    scheduler_class = getattr(lr_schedulers, scheduler_name)
    # Instantiate the scheduler with its keyword arguments
    scheduler = scheduler_class(optimizer, **scheduler_kwargs)
    # Record the learning rate at each step
    lr_schedule = []
    for step in range(total_steps):
        lr_schedule.append(scheduler.get_last_lr()[0])
        optimizer.step()
        scheduler.step()
    return lr_schedule


# Build the parameter input widgets for the selected scheduler
def generate_scheduler_inputs(scheduler_name):
    params = scheduler_params_dict.get(scheduler_name, {})
    inputs = []
    for param, details in params.items():
        param_type = details["type"]
        if param_type in ("float", "int"):
            # A callable returning None leaves optional fields empty
            # (Gradio treats a literal None as "no value supplied")
            value = details["value"] if details["value"] is not None else lambda: None
            input_field = gr.Number(label=details["label"], value=value)
        elif param_type == "str":
            input_field = gr.Textbox(label=details["label"], value=details["value"])
        elif param_type == "bool":
            input_field = gr.Checkbox(label=details["label"], value=details["value"])
        elif param_type == "list[int]":
            input_field = gr.Textbox(
                label=details["label"],
                value=",".join(map(str, details["value"])),
                placeholder="Enter comma-separated integers",
            )
        inputs.append(input_field)
    return inputs


# Cast a raw widget value back to the type the scheduler expects
def preprocess_value(value, param_type):
    if value is None:
        # Optional parameters (e.g. step_size_down) stay None when left blank
        return None
    if param_type == "float":
        return float(value)
    elif param_type == "int":
        return int(value)
    elif param_type == "str":
        return str(value)
    elif param_type == "bool":
        return bool(value)
    elif param_type == "list[int]":
        return [int(val.strip()) for val in value.split(",")]


# Wrapper function for Gradio that handles scheduler selection
def interactive_plot(display_steps, initial_lr, scheduler_name, *args):
    # Define x as steps for visualization
    steps = np.arange(int(display_steps))
    # Zip the widget values back onto the scheduler's parameter names
    scheduler_params = {
        param: preprocess_value(value, scheduler_params_dict[scheduler_name][param]["type"])
        for param, value in zip(scheduler_params_dict[scheduler_name].keys(), args)
    }
    # Generate the learning rate schedule for the selected scheduler
    lr_schedule = get_lr_schedule(
        scheduler_name=scheduler_name,
        initial_lr=initial_lr,
        total_steps=len(steps),
        scheduler_kwargs=scheduler_params,
    )
    # Plot the learning rate schedule
    title = f"Learning Rate Schedule with {scheduler_name}"
    fig = create_line_plot(steps, lr_schedule, title=title)
    return fig
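
# A minimal sketch of calling get_lr_schedule outside the UI (the values
# below are illustrative, not app defaults):
#
#     schedule = get_lr_schedule(
#         "StepLR",
#         initial_lr=0.1,
#         total_steps=100,
#         scheduler_kwargs={"step_size": 30, "gamma": 0.1},
#     )
#     # lr is 0.1 for steps 0-29, 0.01 for 30-59, 0.001 for 60-89, ...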

def main():
    # Components created outside the Blocks context are rendered explicitly
    # below, after the dynamically generated parameter inputs
    plot = gr.Plot(label="Learning Rate Graph")
    plot_button = gr.Button("Plot Graph")

    # Define the Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("# Learning Rate Schedule Plotter")
        display_steps = gr.Number(label="Display Steps", value=1000, interactive=True)
        initial_lr = gr.Number(label="Initial Learning Rate", value=0.1, precision=8, interactive=True)
        # Dropdown for scheduler selection
        scheduler_dropdown = gr.Dropdown(choices=schedulers, label="Learning Rate Scheduler", value="ConstantLR")

        # Re-render the parameter inputs whenever the selected scheduler changes
        @gr.render(inputs=scheduler_dropdown)
        def update_scheduler_params(scheduler_name):
            # Generate the appropriate input fields for the selected scheduler
            inputs = generate_scheduler_inputs(scheduler_name)
            plot_button.click(
                interactive_plot,
                inputs=[display_steps, initial_lr, scheduler_dropdown, *inputs],
                outputs=plot,
            )

        plot.render()
        plot_button.render()

    # Launch the Gradio app
    demo.launch()


if __name__ == "__main__":
    main()