jiawei-ren committed
Commit f404def · 1 Parent(s): 8a01af6
Files changed (1):
  1. app.py +51 -42
app.py CHANGED
@@ -18,7 +18,6 @@ Y_LB = 0
 K = 1
 B = 0
 NUM_SEG = 5
-sns.set_theme(palette='colorblind')
 NUM_EPOCHS = 100
 PRINT_FREQ = NUM_EPOCHS // 20
 NUM_TRAIN_SAMPLES = NUM_PER_BUCKET * NUM_SEG
@@ -76,8 +75,8 @@ def unzip_dataloader(training_loader):
     return all_x, all_y


-def train(train_loader, training_bundle, num_epochs):
-    training_df = make_dataframe(*unzip_dataloader(train_loader))
+def train(train_loader, training_df, training_bundle, num_epochs):
+    visualize_training_process(training_df, training_bundle, -1)
     for epoch in range(num_epochs):
         for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
             model.train()
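train() no longer rebuilds the scatter-plot DataFrame from the dataloader; the caller constructs it once and passes it down, and an initial frame is rendered before the first epoch (epoch -1, so the saved frame is numbered 00000). A minimal sketch of the new call chain, with the surrounding setup assumed from the rest of app.py:

    training_df = make_dataframe(all_x, all_y)   # built once in run()
    regress(train_loader, training_df)           # regress() forwards it on:
    # train(train_loader, training_df, training_bundle, NUM_EPOCHS)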
@@ -92,23 +91,19 @@ def train(train_loader, training_bundle, num_epochs):
                 optimizer.step()
             scheduler.step()
         if (epoch + 1) % PRINT_FREQ == 0:
-            visualize(training_df, training_bundle, epoch)
+            visualize_training_process(training_df, training_bundle, epoch)
+    visualize_training_process(training_df, training_bundle, num_epochs, final=True)


-def visualize(training_df, training_bundle, epoch):
+def visualize_training_process(training_df, training_bundle, epoch, final=False):
     df = df_oracle
     for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
         model.eval()
         y = model(X_demo)
         df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
-    sns.lineplot(data=df, x='x', y='y', hue='Method', estimator=None, ci=None)
-    sns.scatterplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.05, linewidths=0, s=100)
-    plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
-    plt.ylim(Y_LB, Y_UB)
-    plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
-    plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
-    plt.savefig('train_log/{:05d}.png'.format(epoch + 1), bbox_inches='tight')
-    plt.close()
+    visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True)
+    if final:
+        visualize(training_df, df, 'regression_result.png', fast=False)


 def make_video():
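The demo-curve DataFrame is still grown with df.append inside the loop, which assumes the pandas pinned for this Space still provides DataFrame.append (deprecated in pandas 1.4, removed in 2.0). A hedged equivalent for a newer pandas, not part of this commit, would use pd.concat:

    import pandas as pd
    # Sketch only: drop-in replacement for the df.append(...) call above
    # on pandas >= 2.0, where DataFrame.append no longer exists.
    df = pd.concat([df, make_dataframe(X_demo, y, criterion_name)], ignore_index=True)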
@@ -178,7 +173,7 @@ def bmc_loss(pred, target, noise_var):
     return loss * (2 * noise_var)


-def regress(train_loader):
+def regress(train_loader, training_df):
     training_bundle = []
     criterions = {
         'MSE': torch.nn.MSELoss(),
@@ -189,7 +184,7 @@ def regress(train_loader):
         criterion = criterions[criterion_name]
         model, optimizer, scheduler = prepare_model()
         training_bundle.append((model, optimizer, scheduler, criterion, criterion_name))
-    train(train_loader, training_bundle, NUM_EPOCHS)
+    train(train_loader, training_df, training_bundle, NUM_EPOCHS)


 class DummyDataset(Dataset):
@@ -205,22 +200,31 @@ class DummyDataset(Dataset):
         return len(self.inputs)


-def vis_training_data(all_x, all_y):
-    training_df = make_dataframe(all_x, all_y)
-    g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
-                      marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1), rug=True),
-                      xlim=((Y_LB - B) / K, (Y_UB - B) / K),
-                      ylim=(Y_LB, Y_UB),
-                      space=0.1,
-                      height=8,
-                      ratio=2
-                      )
-    g.ax_marg_x.remove()
-    sns.lineplot(data=df_oracle, x='x', y='y', hue='Method', ax=g.ax_joint, legend=False)
+def visualize(training_df, df, save_path, fast=False):
+    if fast:
+        g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', estimator=None, ci=None)
+        plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
+        plt.ylim(Y_LB, Y_UB)
+    else:
+        g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
+                          marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
+                          xlim=((Y_LB - B) / K, (Y_UB - B) / K),
+                          ylim=(Y_LB, Y_UB),
+                          space=0.1,
+                          height=8,
+                          ratio=2,
+                          estimator=None, ci=None,
+                          legend=False
+                          )
+        g.ax_marg_x.remove()
+        g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g.ax_joint, estimator=None, ci=None)
+        g_line.legend_.set_title(None)
+        g_line.legend(loc='upper left')
     plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
     plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
-    plt.savefig('training_data.png', bbox_inches='tight')
-    plt.close()
+
+    plt.savefig(save_path, bbox_inches='tight')
+    plt.clf()


 def clean_up_logs():
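For reference, the call sites elsewhere in this commit exercise both branches of the new visualize() (a sketch assembled from the other hunks, not extra code in the file; df and epoch are the locals of visualize_training_process):

    visualize(training_df, df_oracle, 'training_data.png')                           # full jointplot, "Prepare Training Data"
    visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True)  # per-epoch frame for the video
    visualize(training_df, df, 'regression_result.png', fast=False)                  # final figure after the last epoch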
@@ -228,8 +232,9 @@ def clean_up_logs():
         os.mkdir('train_log')
     for f in os.listdir('train_log'):
         os.remove(osp.join('train_log', f))
-    if osp.isfile('movie.mp4'):
-        os.remove('movie.mp4')
+    for f in ['regression_result.png', 'training_data.png', 'movie.mp4']:
+        if osp.isfile(f):
+            os.remove(f)


 def run(num1, num2, num3, num4, num5, random_seed, submit):
@@ -238,19 +243,22 @@ def run(num1, num2, num3, num4, num5, random_seed, submit):
     torch.manual_seed(int(random_seed))
     all_x, all_y, prob = prepare_data(sel_num)
     train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
-    vis_training_data(all_x, all_y)
+    training_df = make_dataframe(all_x, all_y)

+    clean_up_logs()
     if submit == 0:
-        text = "Press \"Start Regressing\" if your are happy with the training data. Regression takes ~10s."
-    else:
-        text = "Press \"Prepare Training Data\" before changing the training data. You may also change the random seed."
+        visualize(training_df, df_oracle, 'training_data.png')
     if submit == 1:
-        clean_up_logs()
-        regress(train_loader)
+        regress(train_loader, training_df)
         make_video()
-    output = 'train_log/{:05d}.png'.format(NUM_EPOCHS) if submit == 1 else None
+    if submit == 0:
+        text = "Press \"Start Regressing\" if your are happy with the training data. Regression takes ~10s."
+    else:
+        text = "Press \"Prepare Training Data\" before moving the sliders. You may also change the random seed."
+    training_data_plot = 'training_data.png' if submit == 0 else None
+    output = 'regression_result.png'.format(NUM_EPOCHS) if submit == 1 else None
     video = "movie.mp4" if submit == 1 else None
-    return 'training_data.png', text, output, video
+    return training_data_plot, output, video, text


 if __name__ == '__main__':
@@ -268,9 +276,9 @@ if __name__ == '__main__':
        ],
        outputs=[
            gr.outputs.Image(type="file", label="Training data"),
-            gr.outputs.Textbox(type="auto", label='What\' s next?'),
            gr.outputs.Image(type="file", label="Regression result"),
-            gr.outputs.Video(type='mp4', label='Training process')
+            gr.outputs.Video(type='mp4', label='Training process'),
+            gr.outputs.Textbox(type="auto", label='What\' s next?')
        ],
        live=True,
        allow_flagging='never',
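The reordering matters because gr.Interface maps run()'s return tuple onto this list positionally; after this commit the wiring is (sketch, names from the hunks above):

    # return training_data_plot, output, video, text
    # training_data_plot -> gr.outputs.Image(label="Training data")
    # output             -> gr.outputs.Image(label="Regression result")
    # video              -> gr.outputs.Video(label="Training process")
    # text               -> gr.outputs.Textbox(label="What's next?")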
@@ -282,6 +290,7 @@ if __name__ == '__main__':
            [0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
            [1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
        ],
-
+        css = ".output-image, .image-preview {height: 500px !important}",
+        article="<p style='text-align: center'><a href='https://github.com/jiawei-ren/BalancedMSE' target='_blank'>Balanced MSE @ GitHub</a></p> "
    )
    iface.launch()
 