jiawei-ren commited on
Commit
9c31709
·
1 Parent(s): ebf41b0
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -92,7 +92,7 @@ def train(train_loader, training_df, training_bundle, num_epochs):
92
  scheduler.step()
93
  if (epoch + 1) % PRINT_FREQ == 0:
94
  visualize_training_process(training_df, training_bundle, epoch)
95
- visualize_training_process(training_df, training_bundle, num_epochs, final=True)
96
 
97
 
98
  def visualize_training_process(training_df, training_bundle, epoch, final=False):
@@ -101,7 +101,7 @@ def visualize_training_process(training_df, training_bundle, epoch, final=False)
101
  model.eval()
102
  y = model(X_demo)
103
  df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
104
- visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True)
105
  if final:
106
  visualize(training_df, df, 'regression_result.png', fast=False)
107
 
@@ -200,7 +200,7 @@ class DummyDataset(Dataset):
200
  return len(self.inputs)
201
 
202
 
203
- def visualize(training_df, df, save_path, fast=False):
204
  if fast:
205
  f = plt.figure(figsize=(3, 3))
206
  g = f.add_subplot(111)
@@ -208,24 +208,26 @@ def visualize(training_df, df, save_path, fast=False):
208
  plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
209
  plt.ylim(Y_LB, Y_UB)
210
  else:
211
- g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
212
  marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
213
  xlim=((Y_LB - B) / K, (Y_UB - B) / K),
214
  ylim=(Y_LB, Y_UB),
215
  space=0.1,
216
- height=8,
217
  ratio=2,
218
  estimator=None, ci=None,
219
  legend=False,
220
  )
221
  g.ax_marg_x.remove()
222
  g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g.ax_joint, estimator=None, ci=None)
223
- g_line.legend_.set_title(None)
224
- g_line.legend(loc='upper left')
225
- plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
226
- plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
 
 
227
 
228
- plt.savefig(save_path, bbox_inches='tight')
229
  plt.close()
230
 
231
 
@@ -267,11 +269,11 @@ if __name__ == '__main__':
267
  iface = gr.Interface(
268
  fn=run,
269
  inputs=[
270
- gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [0, 2)'),
271
- gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [2, 4)'),
272
- gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [4, 6)'),
273
- gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [6, 8)'),
274
  gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [8, 10)'),
 
 
 
 
275
  gr.inputs.Number(default=0, label='Random Seed', optional=False),
276
  gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
277
  type="index", default=None, label='Mode', optional=False),
 
92
  scheduler.step()
93
  if (epoch + 1) % PRINT_FREQ == 0:
94
  visualize_training_process(training_df, training_bundle, epoch)
95
+ visualize_training_process(training_df, training_bundle, num_epochs-1, final=True)
96
 
97
 
98
  def visualize_training_process(training_df, training_bundle, epoch, final=False):
 
101
  model.eval()
102
  y = model(X_demo)
103
  df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
104
+ visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True, epoch=epoch)
105
  if final:
106
  visualize(training_df, df, 'regression_result.png', fast=False)
107
 
 
200
  return len(self.inputs)
201
 
202
 
203
+ def visualize(training_df, df, save_path, fast=False, epoch=None):
204
  if fast:
205
  f = plt.figure(figsize=(3, 3))
206
  g = f.add_subplot(111)
 
208
  plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
209
  plt.ylim(Y_LB, Y_UB)
210
  else:
211
+ g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=50,
212
  marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
213
  xlim=((Y_LB - B) / K, (Y_UB - B) / K),
214
  ylim=(Y_LB, Y_UB),
215
  space=0.1,
216
+ height=5,
217
  ratio=2,
218
  estimator=None, ci=None,
219
  legend=False,
220
  )
221
  g.ax_marg_x.remove()
222
  g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g.ax_joint, estimator=None, ci=None)
223
+ if epoch is not None:
224
+ g_line.legend(loc='upper left', title="Epoch {:03d}".format(epoch+1))
225
+ else:
226
+ g_line.legend(loc='upper left')
227
+ plt.gca().axes.set_xlabel(r'$x$')
228
+ plt.gca().axes.set_ylabel(r'$y$')
229
 
230
+ plt.savefig(save_path, bbox_inches='tight', dpi=200)
231
  plt.close()
232
 
233
 
 
269
  iface = gr.Interface(
270
  fn=run,
271
  inputs=[
 
 
 
 
272
  gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [8, 10)'),
273
+ gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [6, 8)'),
274
+ gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [4, 6)'),
275
+ gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [2, 4)'),
276
+ gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [0, 2)'),
277
  gr.inputs.Number(default=0, label='Random Seed', optional=False),
278
  gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
279
  type="index", default=None, label='Mode', optional=False),