Suhas-G committed on
Commit 8fb5f44 · 1 Parent(s): f100013

Implemented the remaining features: per-scheduler parameter definitions, dynamically rendered parameter inputs, and the full Gradio interface

Files changed (1)
  1. app.py +145 -21
app.py CHANGED
@@ -8,10 +8,83 @@ import torch.optim.lr_scheduler as lr_schedulers
 # List of scheduler names
 schedulers = [
     "ConstantLR", "LinearLR", "ExponentialLR", "StepLR", "MultiStepLR",
-    "CosineAnnealingLR", "CyclicLR", "OneCycleLR", "ReduceLROnPlateau",
+    "CosineAnnealingLR", "CyclicLR", "OneCycleLR",
     "CosineAnnealingWarmRestarts"
 ]
 
+# Dictionary of scheduler parameters
+scheduler_params_dict = {
+    "StepLR": {
+        "step_size": {"type": "int", "label": "Step Size", "value": 30},
+        "gamma": {"type": "float", "label": "Gamma", "value": 0.1},
+        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
+    },
+    "MultiStepLR": {
+        "milestones": {"type": "list[int]", "label": "Milestones", "value": [30, 80]},
+        "gamma": {"type": "float", "label": "Gamma", "value": 0.1},
+        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
+    },
+    "ConstantLR": {
+        "factor": {"type": "float", "label": "Factor", "value": 0.333},
+        "total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
+        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
+    },
+    "LinearLR": {
+        "start_factor": {"type": "float", "label": "Start Factor", "value": 0.333},
+        "end_factor": {"type": "float", "label": "End Factor", "value": 1.0},
+        "total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
+        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
+    },
+    "ExponentialLR": {
+        "gamma": {"type": "float", "label": "Gamma", "value": 0.9},
+        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
+    },
+    "PolynomialLR": {
+        "total_iters": {"type": "int", "label": "Total Iterations", "value": 5},
+        "power": {"type": "float", "label": "Power", "value": 1.0},
+        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
+    },
+    "CosineAnnealingLR": {
+        "T_max": {"type": "int", "label": "T Max", "value": 50},
+        "eta_min": {"type": "float", "label": "Eta Min", "value": 0.0},
+        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
+    },
+    "CyclicLR": {
+        "base_lr": {"type": "float", "label": "Base LR", "value": 0.001},
+        "max_lr": {"type": "float", "label": "Max LR", "value": 0.01},
+        "step_size_up": {"type": "int", "label": "Step Size Up", "value": 2000},
+        "step_size_down": {"type": "int", "label": "Step Size Down", "value": None},
+        "mode": {"type": "str", "label": "Mode", "value": "triangular"},
+        "gamma": {"type": "float", "label": "Gamma", "value": 1.0},
+        "cycle_momentum": {"type": "bool", "label": "Cycle Momentum", "value": True},
+        "base_momentum": {"type": "float", "label": "Base Momentum", "value": 0.8},
+        "max_momentum": {"type": "float", "label": "Max Momentum", "value": 0.9},
+        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
+    },
+    "OneCycleLR": {
+        "max_lr": {"type": "float", "label": "Max LR", "value": 0.01},
+        "total_steps": {"type": "int", "label": "Total Steps", "value": None},
+        "epochs": {"type": "int", "label": "Epochs", "value": None},
+        "steps_per_epoch": {"type": "int", "label": "Steps Per Epoch", "value": None},
+        "pct_start": {"type": "float", "label": "Pct Start", "value": 0.3},
+        "anneal_strategy": {"type": "str", "label": "Anneal Strategy", "value": "cos"},
+        "cycle_momentum": {"type": "bool", "label": "Cycle Momentum", "value": True},
+        "base_momentum": {"type": "float", "label": "Base Momentum", "value": 0.85},
+        "max_momentum": {"type": "float", "label": "Max Momentum", "value": 0.95},
+        "div_factor": {"type": "float", "label": "Div Factor", "value": 25.0},
+        "final_div_factor": {"type": "float", "label": "Final Div Factor", "value": 10000.0},
+        "three_phase": {"type": "bool", "label": "Three Phase", "value": False},
+        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
+    },
+    "CosineAnnealingWarmRestarts": {
+        "T_0": {"type": "int", "label": "T 0", "value": 10},
+        "T_mult": {"type": "int", "label": "T Mult", "value": 1},
+        "eta_min": {"type": "float", "label": "Eta Min", "value": 0.0},
+        "last_epoch": {"type": "int", "label": "Last Epoch", "value": -1}
+    }
+}
+
+
 # Function to create a Plotly line plot
 def create_line_plot(x, y, title="Line Plot"):
     fig = go.Figure()
@@ -29,7 +102,6 @@ def get_lr_schedule(scheduler_name, initial_lr=0.1, total_steps=100, scheduler_k
 
     # Dynamically get the scheduler class from the torch.optim.lr_scheduler module
    scheduler_class = getattr(lr_schedulers, scheduler_name)
-
     # Initialize the scheduler with the keyword arguments
     scheduler = scheduler_class(optimizer, **scheduler_kwargs)
 
@@ -37,17 +109,57 @@ def get_lr_schedule(scheduler_name, initial_lr=0.1, total_steps=100, scheduler_k
     lr_schedule = []
     for step in range(total_steps):
         lr_schedule.append(scheduler.get_last_lr()[0])
+        optimizer.step()
         scheduler.step()
 
     return lr_schedule
 
+# Function to generate the input blocks based on the selected scheduler
+def generate_scheduler_inputs(scheduler_name):
+    params = scheduler_params_dict.get(scheduler_name, {})
+
+    inputs = []
+    for param, details in params.items():
+        param_type = details["type"]
+
+        if param_type == "float":
+            value = details["value"] if details["value"] is not None else lambda:None
+            input_field = gr.Number(label=details["label"], value=value)
+        elif param_type == "int":
+            value = details["value"] if details["value"] is not None else lambda:None
+            input_field = gr.Number(label=details["label"], value=value)
+        elif param_type == "str":
+            input_field = gr.Textbox(label=details["label"], value=details["value"])
+        elif param_type == "bool":
+            input_field = gr.Checkbox(label=details["label"], value=details["value"])
+        elif param_type == "list[int]":
+            input_field = gr.Textbox(label=details["label"], value=",".join(map(str, details["value"])), placeholder="Enter comma-separated integers")
+
+        inputs.append(input_field)
+
+    return inputs
+
+def preprocess_value(value, param_type):
+    if param_type == "float":
+        return float(value)
+    elif param_type == "int":
+        return int(value)
+    elif param_type == "str":
+        return str(value)
+    elif param_type == "bool":
+        return bool(value)
+    elif param_type == "list[int]":
+        return [int(val.strip()) for val in value.split(",")]
+
 # Wrapper function for Gradio that handles scheduler selection
-def interactive_plot(x_min, x_max, scheduler_name, scheduler_params):
+def interactive_plot(display_steps, initial_lr, scheduler_name, *args):
     # Define x as steps for visualization
-    steps = np.arange(int(x_max))
+    steps = np.arange(int(display_steps))
 
+    scheduler_params = {label: preprocess_value(value, scheduler_params_dict[scheduler_name][label]["type"])
+                        for label, value in zip(scheduler_params_dict[scheduler_name].keys(), args)}
     # Generate learning rate schedule for the selected scheduler using scheduler_params
-    lr_schedule = get_lr_schedule(scheduler_name=scheduler_name, initial_lr=0.1, total_steps=len(steps), scheduler_kwargs=scheduler_params)
+    lr_schedule = get_lr_schedule(scheduler_name=scheduler_name, initial_lr=initial_lr, total_steps=len(steps), scheduler_kwargs=scheduler_params)
 
     # Plot the learning rate schedule
     title = f"Learning Rate Schedule with {scheduler_name}"
@@ -55,22 +167,34 @@ def interactive_plot(x_min, x_max, scheduler_name, scheduler_params):
 
     return fig
 
-# Define Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("# Learning Rate Scheduler Plotter")
-    x_min = gr.Number(label="X Min", value=0)
-    x_max = gr.Number(label="X Max (Steps)", value=100)
-
-    # Dropdown for scheduler selection
-    scheduler_dropdown = gr.Dropdown(choices=schedulers, label="Learning Rate Scheduler", value="ConstantLR")
-
-    # Scheduler parameter inputs (e.g., for ConstantLR, factor and total_iters)
-    scheduler_params = gr.JSON(label="Scheduler Parameters", value={"factor": 0.5, "total_iters": 10})
-
+
+def main():
     plot = gr.Plot(label="Learning Rate Graph")
-
     plot_button = gr.Button("Plot Graph")
-    plot_button.click(interactive_plot, inputs=[x_min, x_max, scheduler_dropdown, scheduler_params], outputs=plot)
 
-# Launch the Gradio app
-demo.launch()
+    # Define Gradio interface
+    with gr.Blocks() as demo:
+        gr.Markdown("# Learning Rate Schedule Plotter")
+        display_steps = gr.Number(label="Display Steps", value=1000, interactive=True)
+        initial_lr = gr.Number(label="Initial Learning Rate", value=0.1, precision=8, interactive=True)
+
+        # Dropdown for scheduler selection
+        scheduler_dropdown = gr.Dropdown(choices=schedulers, label="Learning Rate Scheduler", value="ConstantLR")
+
+        # Function to update scheduler parameter inputs based on selected scheduler
+        @gr.render(inputs=scheduler_dropdown)
+        def update_scheduler_params(scheduler_name):
+            # Generate the appropriate input fields based on selected scheduler
+            inputs = generate_scheduler_inputs(scheduler_name)
+
+            plot_button.click(interactive_plot, inputs=[display_steps, initial_lr, scheduler_dropdown, *inputs], outputs=plot)
+
+        plot.render()
+        plot_button.render()
+
+    # Launch the Gradio app
+    demo.launch()
+
+
+if __name__ == "__main__":
+    main()
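
For reference, the learning-rate sampling that get_lr_schedule performs can be reproduced outside the app. The snippet below is a minimal standalone sketch of the same pattern; the function name sample_lr_schedule and the SGD-on-a-dummy-parameter setup are illustrative assumptions, not code from this commit:

import torch
import torch.optim.lr_scheduler as lr_schedulers

def sample_lr_schedule(scheduler_name, initial_lr=0.1, total_steps=100, scheduler_kwargs=None):
    # A throwaway parameter gives the optimizer something to hold; only the LR bookkeeping matters here.
    dummy = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([dummy], lr=initial_lr)

    # Resolve the scheduler class by name, as the app does with getattr on torch.optim.lr_scheduler.
    scheduler = getattr(lr_schedulers, scheduler_name)(optimizer, **(scheduler_kwargs or {}))

    lrs = []
    for _ in range(total_steps):
        lrs.append(scheduler.get_last_lr()[0])
        optimizer.step()  # stepping the optimizer before the scheduler avoids PyTorch's ordering warning
        scheduler.step()
    return lrs

# e.g. StepLR with the defaults from scheduler_params_dict
print(sample_lr_schedule("StepLR", scheduler_kwargs={"step_size": 30, "gamma": 0.1})[:5])

The dynamic parameter form relies on Gradio's @gr.render decorator: the decorated function re-runs whenever its input component changes, recreating the parameter fields and re-binding the plot button's click handler. A stripped-down sketch of that pattern, with made-up component names not taken from this commit, might look like this:

import gradio as gr

with gr.Blocks() as demo:
    kind = gr.Dropdown(choices=["number", "text"], value="number", label="Field type")
    echo = gr.Textbox(label="Echo")

    # Components created inside the render function are rebuilt on every dropdown change,
    # and event handlers attached here are re-bound as well.
    @gr.render(inputs=kind)
    def build_field(selected):
        field = gr.Number(label="Value") if selected == "number" else gr.Textbox(label="Value")
        button = gr.Button("Echo")
        button.click(lambda v: str(v), inputs=field, outputs=echo)

demo.launch()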