Spaces:
Sleeping
Sleeping
implemented all the remaining things
Browse files
app.py
CHANGED
@@ -8,10 +8,83 @@ import torch.optim.lr_scheduler as lr_schedulers
|
|
8 |
# List of scheduler names
|
9 |
schedulers = [
|
10 |
"ConstantLR", "LinearLR", "ExponentialLR", "StepLR", "MultiStepLR",
|
11 |
-
"CosineAnnealingLR", "CyclicLR", "OneCycleLR",
|
12 |
"CosineAnnealingWarmRestarts"
|
13 |
]
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
# Function to create a Plotly line plot
|
16 |
def create_line_plot(x, y, title="Line Plot"):
|
17 |
fig = go.Figure()
|
@@ -29,7 +102,6 @@ def get_lr_schedule(scheduler_name, initial_lr=0.1, total_steps=100, scheduler_k
|
|
29 |
|
30 |
# Dynamically get the scheduler class from the torch.optim.lr_scheduler module
|
31 |
scheduler_class = getattr(lr_schedulers, scheduler_name)
|
32 |
-
|
33 |
# Initialize the scheduler with the keyword arguments
|
34 |
scheduler = scheduler_class(optimizer, **scheduler_kwargs)
|
35 |
|
@@ -37,17 +109,57 @@ def get_lr_schedule(scheduler_name, initial_lr=0.1, total_steps=100, scheduler_k
|
|
37 |
lr_schedule = []
|
38 |
for step in range(total_steps):
|
39 |
lr_schedule.append(scheduler.get_last_lr()[0])
|
|
|
40 |
scheduler.step()
|
41 |
|
42 |
return lr_schedule
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
# Wrapper function for Gradio that handles scheduler selection
|
45 |
-
def interactive_plot(
|
46 |
# Define x as steps for visualization
|
47 |
-
steps = np.arange(int(
|
48 |
|
|
|
|
|
49 |
# Generate learning rate schedule for the selected scheduler using scheduler_params
|
50 |
-
lr_schedule = get_lr_schedule(scheduler_name=scheduler_name, initial_lr=
|
51 |
|
52 |
# Plot the learning rate schedule
|
53 |
title = f"Learning Rate Schedule with {scheduler_name}"
|
@@ -55,22 +167,34 @@ def interactive_plot(x_min, x_max, scheduler_name, scheduler_params):
|
|
55 |
|
56 |
return fig
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
gr.Markdown("# Learning Rate Scheduler Plotter")
|
61 |
-
x_min = gr.Number(label="X Min", value=0)
|
62 |
-
x_max = gr.Number(label="X Max (Steps)", value=100)
|
63 |
-
|
64 |
-
# Dropdown for scheduler selection
|
65 |
-
scheduler_dropdown = gr.Dropdown(choices=schedulers, label="Learning Rate Scheduler", value="ConstantLR")
|
66 |
-
|
67 |
-
# Scheduler parameter inputs (e.g., for ConstantLR, factor and total_iters)
|
68 |
-
scheduler_params = gr.JSON(label="Scheduler Parameters", value={"factor": 0.5, "total_iters": 10})
|
69 |
-
|
70 |
plot = gr.Plot(label="Learning Rate Graph")
|
71 |
-
|
72 |
plot_button = gr.Button("Plot Graph")
|
73 |
-
plot_button.click(interactive_plot, inputs=[x_min, x_max, scheduler_dropdown, scheduler_params], outputs=plot)
|
74 |
|
75 |
-
#
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
# List of scheduler names
|
9 |
# Scheduler class names offered in the dropdown. Each name is resolved with
# getattr() on torch.optim.lr_scheduler, so it must match the class name exactly.
# "PolynomialLR" is listed because scheduler_params_dict defines its parameters;
# it is appended last so the positions of the earlier entries do not change.
schedulers = [
    "ConstantLR",
    "LinearLR",
    "ExponentialLR",
    "StepLR",
    "MultiStepLR",
    "CosineAnnealingLR",
    "CyclicLR",
    "OneCycleLR",
    "CosineAnnealingWarmRestarts",
    "PolynomialLR",
]
|
14 |
|
15 |
+
# Dictionary of scheduler parameters
|
16 |
+
def _spec(kind, label, default):
    """Return one parameter descriptor: type tag, UI label and default value."""
    return {"type": kind, "label": label, "value": default}


# Parameter schema: scheduler class name -> keyword-argument name -> descriptor.
# The "type" tag is one of "int", "float", "str", "bool" or "list[int]".
# Key order within each scheduler is significant: interactive_plot pairs these
# keys positionally with the values coming from the generated input fields.
scheduler_params_dict = {
    "StepLR": {
        "step_size": _spec("int", "Step Size", 30),
        "gamma": _spec("float", "Gamma", 0.1),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "MultiStepLR": {
        "milestones": _spec("list[int]", "Milestones", [30, 80]),
        "gamma": _spec("float", "Gamma", 0.1),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "ConstantLR": {
        "factor": _spec("float", "Factor", 0.333),
        "total_iters": _spec("int", "Total Iterations", 5),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "LinearLR": {
        "start_factor": _spec("float", "Start Factor", 0.333),
        "end_factor": _spec("float", "End Factor", 1.0),
        "total_iters": _spec("int", "Total Iterations", 5),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "ExponentialLR": {
        "gamma": _spec("float", "Gamma", 0.9),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "PolynomialLR": {
        "total_iters": _spec("int", "Total Iterations", 5),
        "power": _spec("float", "Power", 1.0),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "CosineAnnealingLR": {
        "T_max": _spec("int", "T Max", 50),
        "eta_min": _spec("float", "Eta Min", 0.0),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "CyclicLR": {
        "base_lr": _spec("float", "Base LR", 0.001),
        "max_lr": _spec("float", "Max LR", 0.01),
        "step_size_up": _spec("int", "Step Size Up", 2000),
        "step_size_down": _spec("int", "Step Size Down", None),
        "mode": _spec("str", "Mode", "triangular"),
        "gamma": _spec("float", "Gamma", 1.0),
        "cycle_momentum": _spec("bool", "Cycle Momentum", True),
        "base_momentum": _spec("float", "Base Momentum", 0.8),
        "max_momentum": _spec("float", "Max Momentum", 0.9),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "OneCycleLR": {
        "max_lr": _spec("float", "Max LR", 0.01),
        "total_steps": _spec("int", "Total Steps", None),
        "epochs": _spec("int", "Epochs", None),
        "steps_per_epoch": _spec("int", "Steps Per Epoch", None),
        "pct_start": _spec("float", "Pct Start", 0.3),
        "anneal_strategy": _spec("str", "Anneal Strategy", "cos"),
        "cycle_momentum": _spec("bool", "Cycle Momentum", True),
        "base_momentum": _spec("float", "Base Momentum", 0.85),
        "max_momentum": _spec("float", "Max Momentum", 0.95),
        "div_factor": _spec("float", "Div Factor", 25.0),
        "final_div_factor": _spec("float", "Final Div Factor", 10000.0),
        "three_phase": _spec("bool", "Three Phase", False),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
    "CosineAnnealingWarmRestarts": {
        "T_0": _spec("int", "T 0", 10),
        "T_mult": _spec("int", "T Mult", 1),
        "eta_min": _spec("float", "Eta Min", 0.0),
        "last_epoch": _spec("int", "Last Epoch", -1),
    },
}
|
86 |
+
|
87 |
+
|
88 |
# Function to create a Plotly line plot
|
89 |
def create_line_plot(x, y, title="Line Plot"):
|
90 |
fig = go.Figure()
|
|
|
102 |
|
103 |
# Dynamically get the scheduler class from the torch.optim.lr_scheduler module
|
104 |
scheduler_class = getattr(lr_schedulers, scheduler_name)
|
|
|
105 |
# Initialize the scheduler with the keyword arguments
|
106 |
scheduler = scheduler_class(optimizer, **scheduler_kwargs)
|
107 |
|
|
|
109 |
lr_schedule = []
|
110 |
for step in range(total_steps):
|
111 |
lr_schedule.append(scheduler.get_last_lr()[0])
|
112 |
+
optimizer.step()
|
113 |
scheduler.step()
|
114 |
|
115 |
return lr_schedule
|
116 |
|
117 |
+
# Function to generate the input blocks based on the selected scheduler
|
118 |
+
def generate_scheduler_inputs(scheduler_name):
    """Build one Gradio input component per parameter of a scheduler.

    The parameter schema is read from ``scheduler_params_dict``. Components
    are returned in the schema's key order, which is the order
    ``interactive_plot`` relies on when it pairs values with parameter names.

    Args:
        scheduler_name: Key into ``scheduler_params_dict``. A name with no
            schema yields an empty list.

    Returns:
        A list of Gradio components (``gr.Number``, ``gr.Textbox`` or
        ``gr.Checkbox``), one per declared parameter.

    Raises:
        ValueError: If a parameter declares a ``type`` this function has no
            component for. Previously such an entry either raised NameError
            or silently re-appended the preceding component.
    """
    params = scheduler_params_dict.get(scheduler_name, {})

    inputs = []
    for param, details in params.items():
        param_type = details["type"]
        label = details["label"]
        default = details["value"]

        if param_type in ("float", "int"):
            # A default of None is passed straight through, leaving the
            # numeric field empty for optional parameters.
            input_field = gr.Number(label=label, value=default)
        elif param_type == "str":
            input_field = gr.Textbox(label=label, value=default)
        elif param_type == "bool":
            input_field = gr.Checkbox(label=label, value=default)
        elif param_type == "list[int]":
            # Integer lists are edited as comma-separated text.
            input_field = gr.Textbox(
                label=label,
                value=",".join(map(str, default)),
                placeholder="Enter comma-separated integers",
            )
        else:
            raise ValueError(
                f"Unsupported parameter type {param_type!r} for "
                f"{scheduler_name}.{param}"
            )

        inputs.append(input_field)

    return inputs
|
141 |
+
|
142 |
+
def preprocess_value(value, param_type):
    """Convert a raw UI value to the Python type declared for a parameter.

    Args:
        value: The value received from an input field. May be ``None`` when
            an optional field was left empty.
        param_type: One of ``"float"``, ``"int"``, ``"str"``, ``"bool"`` or
            ``"list[int]"``.

    Returns:
        The converted value. ``None`` is passed through unchanged for every
        type except ``"bool"`` (where it becomes ``False``), so optional
        parameters whose default is ``None`` do not fail with a TypeError
        inside ``int()``/``float()``.

    Raises:
        ValueError: If ``param_type`` is not recognised, or the value cannot
            be parsed as the requested type.
    """
    if param_type == "bool":
        return bool(value)
    if value is None:
        return None
    if param_type == "float":
        return float(value)
    if param_type == "int":
        return int(value)
    if param_type == "str":
        return str(value)
    if param_type == "list[int]":
        tokens = value.split(",") if isinstance(value, str) else value
        # Skip blank tokens so "30, 80," or an empty field does not crash.
        return [int(token) for token in tokens if str(token).strip()]
    raise ValueError(f"Unsupported parameter type: {param_type!r}")
|
153 |
+
|
154 |
# Wrapper function for Gradio that handles scheduler selection
|
155 |
+
def interactive_plot(display_steps, initial_lr, scheduler_name, *args):
|
156 |
# Define x as steps for visualization
|
157 |
+
steps = np.arange(int(display_steps))
|
158 |
|
159 |
+
scheduler_params = {label: preprocess_value(value, scheduler_params_dict[scheduler_name][label]["type"])
|
160 |
+
for label, value in zip(scheduler_params_dict[scheduler_name].keys(), args)}
|
161 |
# Generate learning rate schedule for the selected scheduler using scheduler_params
|
162 |
+
lr_schedule = get_lr_schedule(scheduler_name=scheduler_name, initial_lr=initial_lr, total_steps=len(steps), scheduler_kwargs=scheduler_params)
|
163 |
|
164 |
# Plot the learning rate schedule
|
165 |
title = f"Learning Rate Schedule with {scheduler_name}"
|
|
|
167 |
|
168 |
return fig
|
169 |
|
170 |
+
|
171 |
+
def main():
    """Build the Gradio UI for the learning-rate schedule plotter and launch it."""
    # NOTE(review): nesting below is reconstructed from a flattened diff view;
    # confirm the indentation (in particular that demo.launch() sits outside
    # the `with` block) against the real file.

    # The plot and button are created before the Blocks context and placed
    # later with .render(), so they appear after the scheduler inputs.
    plot = gr.Plot(label="Learning Rate Graph")
    plot_button = gr.Button("Plot Graph")

    # Define Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("# Learning Rate Schedule Plotter")
        display_steps = gr.Number(label="Display Steps", value=1000, interactive=True)
        initial_lr = gr.Number(label="Initial Learning Rate", value=0.1, precision=8, interactive=True)

        # Dropdown for scheduler selection
        scheduler_dropdown = gr.Dropdown(choices=schedulers, label="Learning Rate Scheduler", value="ConstantLR")

        # Function to update scheduler parameter inputs based on selected scheduler
        @gr.render(inputs=scheduler_dropdown)
        def update_scheduler_params(scheduler_name):
            # Generate the appropriate input fields based on selected scheduler
            inputs = generate_scheduler_inputs(scheduler_name)

            # The click handler is wired here because it needs the freshly
            # generated input components as positional inputs.
            plot_button.click(interactive_plot, inputs=[display_steps, initial_lr, scheduler_dropdown, *inputs], outputs=plot)

        plot.render()
        plot_button.render()

    # Launch the Gradio app
    demo.launch()
|
197 |
+
|
198 |
+
|
199 |
+
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|