Spaces:
Runtime error
Runtime error
jiawei-ren
commited on
Commit
·
f404def
1
Parent(s):
8a01af6
init
Browse files
app.py
CHANGED
@@ -18,7 +18,6 @@ Y_LB = 0
|
|
18 |
K = 1
|
19 |
B = 0
|
20 |
NUM_SEG = 5
|
21 |
-
sns.set_theme(palette='colorblind')
|
22 |
NUM_EPOCHS = 100
|
23 |
PRINT_FREQ = NUM_EPOCHS // 20
|
24 |
NUM_TRAIN_SAMPLES = NUM_PER_BUCKET * NUM_SEG
|
@@ -76,8 +75,8 @@ def unzip_dataloader(training_loader):
|
|
76 |
return all_x, all_y
|
77 |
|
78 |
|
79 |
-
def train(train_loader, training_bundle, num_epochs):
|
80 |
-
training_df
|
81 |
for epoch in range(num_epochs):
|
82 |
for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
|
83 |
model.train()
|
@@ -92,23 +91,19 @@ def train(train_loader, training_bundle, num_epochs):
|
|
92 |
optimizer.step()
|
93 |
scheduler.step()
|
94 |
if (epoch + 1) % PRINT_FREQ == 0:
|
95 |
-
|
|
|
96 |
|
97 |
|
98 |
-
def
|
99 |
df = df_oracle
|
100 |
for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
|
101 |
model.eval()
|
102 |
y = model(X_demo)
|
103 |
df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
plt.ylim(Y_LB, Y_UB)
|
108 |
-
plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
|
109 |
-
plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
|
110 |
-
plt.savefig('train_log/{:05d}.png'.format(epoch + 1), bbox_inches='tight')
|
111 |
-
plt.close()
|
112 |
|
113 |
|
114 |
def make_video():
|
@@ -178,7 +173,7 @@ def bmc_loss(pred, target, noise_var):
|
|
178 |
return loss * (2 * noise_var)
|
179 |
|
180 |
|
181 |
-
def regress(train_loader):
|
182 |
training_bundle = []
|
183 |
criterions = {
|
184 |
'MSE': torch.nn.MSELoss(),
|
@@ -189,7 +184,7 @@ def regress(train_loader):
|
|
189 |
criterion = criterions[criterion_name]
|
190 |
model, optimizer, scheduler = prepare_model()
|
191 |
training_bundle.append((model, optimizer, scheduler, criterion, criterion_name))
|
192 |
-
train(train_loader, training_bundle, NUM_EPOCHS)
|
193 |
|
194 |
|
195 |
class DummyDataset(Dataset):
|
@@ -205,22 +200,31 @@ class DummyDataset(Dataset):
|
|
205 |
return len(self.inputs)
|
206 |
|
207 |
|
208 |
-
def
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
|
221 |
plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
|
222 |
-
|
223 |
-
plt.
|
|
|
224 |
|
225 |
|
226 |
def clean_up_logs():
|
@@ -228,8 +232,9 @@ def clean_up_logs():
|
|
228 |
os.mkdir('train_log')
|
229 |
for f in os.listdir('train_log'):
|
230 |
os.remove(osp.join('train_log', f))
|
231 |
-
|
232 |
-
|
|
|
233 |
|
234 |
|
235 |
def run(num1, num2, num3, num4, num5, random_seed, submit):
|
@@ -238,19 +243,22 @@ def run(num1, num2, num3, num4, num5, random_seed, submit):
|
|
238 |
torch.manual_seed(int(random_seed))
|
239 |
all_x, all_y, prob = prepare_data(sel_num)
|
240 |
train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
|
241 |
-
|
242 |
|
|
|
243 |
if submit == 0:
|
244 |
-
|
245 |
-
else:
|
246 |
-
text = "Press \"Prepare Training Data\" before changing the training data. You may also change the random seed."
|
247 |
if submit == 1:
|
248 |
-
|
249 |
-
regress(train_loader)
|
250 |
make_video()
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
252 |
video = "movie.mp4" if submit == 1 else None
|
253 |
-
return
|
254 |
|
255 |
|
256 |
if __name__ == '__main__':
|
@@ -268,9 +276,9 @@ if __name__ == '__main__':
|
|
268 |
],
|
269 |
outputs=[
|
270 |
gr.outputs.Image(type="file", label="Training data"),
|
271 |
-
gr.outputs.Textbox(type="auto", label='What\' s next?'),
|
272 |
gr.outputs.Image(type="file", label="Regression result"),
|
273 |
-
gr.outputs.Video(type='mp4', label='Training process')
|
|
|
274 |
],
|
275 |
live=True,
|
276 |
allow_flagging='never',
|
@@ -282,6 +290,7 @@ if __name__ == '__main__':
|
|
282 |
[0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
|
283 |
[1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
|
284 |
],
|
285 |
-
|
|
|
286 |
)
|
287 |
iface.launch()
|
|
|
18 |
K = 1
|
19 |
B = 0
|
20 |
NUM_SEG = 5
|
|
|
21 |
NUM_EPOCHS = 100
|
22 |
PRINT_FREQ = NUM_EPOCHS // 20
|
23 |
NUM_TRAIN_SAMPLES = NUM_PER_BUCKET * NUM_SEG
|
|
|
75 |
return all_x, all_y
|
76 |
|
77 |
|
78 |
+
def train(train_loader, training_df, training_bundle, num_epochs):
|
79 |
+
visualize_training_process(training_df, training_bundle, -1)
|
80 |
for epoch in range(num_epochs):
|
81 |
for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
|
82 |
model.train()
|
|
|
91 |
optimizer.step()
|
92 |
scheduler.step()
|
93 |
if (epoch + 1) % PRINT_FREQ == 0:
|
94 |
+
visualize_training_process(training_df, training_bundle, epoch)
|
95 |
+
visualize_training_process(training_df, training_bundle, num_epochs, final=True)
|
96 |
|
97 |
|
98 |
+
def visualize_training_process(training_df, training_bundle, epoch, final=False):
|
99 |
df = df_oracle
|
100 |
for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
|
101 |
model.eval()
|
102 |
y = model(X_demo)
|
103 |
df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
|
104 |
+
visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True)
|
105 |
+
if final:
|
106 |
+
visualize(training_df, df, 'regression_result.png', fast=False)
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
|
109 |
def make_video():
|
|
|
173 |
return loss * (2 * noise_var)
|
174 |
|
175 |
|
176 |
+
def regress(train_loader, training_df):
|
177 |
training_bundle = []
|
178 |
criterions = {
|
179 |
'MSE': torch.nn.MSELoss(),
|
|
|
184 |
criterion = criterions[criterion_name]
|
185 |
model, optimizer, scheduler = prepare_model()
|
186 |
training_bundle.append((model, optimizer, scheduler, criterion, criterion_name))
|
187 |
+
train(train_loader, training_df, training_bundle, NUM_EPOCHS)
|
188 |
|
189 |
|
190 |
class DummyDataset(Dataset):
|
|
|
200 |
return len(self.inputs)
|
201 |
|
202 |
|
203 |
+
def visualize(training_df, df, save_path, fast=False):
|
204 |
+
if fast:
|
205 |
+
g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', estimator=None, ci=None)
|
206 |
+
plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
|
207 |
+
plt.ylim(Y_LB, Y_UB)
|
208 |
+
else:
|
209 |
+
g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
|
210 |
+
marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
|
211 |
+
xlim=((Y_LB - B) / K, (Y_UB - B) / K),
|
212 |
+
ylim=(Y_LB, Y_UB),
|
213 |
+
space=0.1,
|
214 |
+
height=8,
|
215 |
+
ratio=2,
|
216 |
+
estimator=None, ci=None,
|
217 |
+
legend=False
|
218 |
+
)
|
219 |
+
g.ax_marg_x.remove()
|
220 |
+
g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g.ax_joint, estimator=None, ci=None)
|
221 |
+
g_line.legend_.set_title(None)
|
222 |
+
g_line.legend(loc='upper left')
|
223 |
plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
|
224 |
plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
|
225 |
+
|
226 |
+
plt.savefig(save_path, bbox_inches='tight')
|
227 |
+
plt.clf()
|
228 |
|
229 |
|
230 |
def clean_up_logs():
|
|
|
232 |
os.mkdir('train_log')
|
233 |
for f in os.listdir('train_log'):
|
234 |
os.remove(osp.join('train_log', f))
|
235 |
+
for f in ['regression_result.png', 'training_data.png', 'movie.mp4']:
|
236 |
+
if osp.isfile(f):
|
237 |
+
os.remove(f)
|
238 |
|
239 |
|
240 |
def run(num1, num2, num3, num4, num5, random_seed, submit):
|
|
|
243 |
torch.manual_seed(int(random_seed))
|
244 |
all_x, all_y, prob = prepare_data(sel_num)
|
245 |
train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
|
246 |
+
training_df = make_dataframe(all_x, all_y)
|
247 |
|
248 |
+
clean_up_logs()
|
249 |
if submit == 0:
|
250 |
+
visualize(training_df, df_oracle, 'training_data.png')
|
|
|
|
|
251 |
if submit == 1:
|
252 |
+
regress(train_loader, training_df)
|
|
|
253 |
make_video()
|
254 |
+
if submit == 0:
|
255 |
+
text = "Press \"Start Regressing\" if your are happy with the training data. Regression takes ~10s."
|
256 |
+
else:
|
257 |
+
text = "Press \"Prepare Training Data\" before moving the sliders. You may also change the random seed."
|
258 |
+
training_data_plot = 'training_data.png' if submit == 0 else None
|
259 |
+
output = 'regression_result.png'.format(NUM_EPOCHS) if submit == 1 else None
|
260 |
video = "movie.mp4" if submit == 1 else None
|
261 |
+
return training_data_plot, output, video, text
|
262 |
|
263 |
|
264 |
if __name__ == '__main__':
|
|
|
276 |
],
|
277 |
outputs=[
|
278 |
gr.outputs.Image(type="file", label="Training data"),
|
|
|
279 |
gr.outputs.Image(type="file", label="Regression result"),
|
280 |
+
gr.outputs.Video(type='mp4', label='Training process'),
|
281 |
+
gr.outputs.Textbox(type="auto", label='What\' s next?')
|
282 |
],
|
283 |
live=True,
|
284 |
allow_flagging='never',
|
|
|
290 |
[0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
|
291 |
[1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
|
292 |
],
|
293 |
+
css = ".output-image, .image-preview {height: 500px !important}",
|
294 |
+
article="<p style='text-align: center'><a href='https://github.com/jiawei-ren/BalancedMSE' target='_blank'>Balanced MSE @ GitHub</a></p> "
|
295 |
)
|
296 |
iface.launch()
|