jiawei-ren committed on
Commit
f7455a4
·
1 Parent(s): ae33216
Files changed (1) hide show
  1. app.py +61 -62
app.py CHANGED
@@ -40,33 +40,27 @@ X_demo = (Y_demo - B) / K
40
 
41
  df_oracle = make_dataframe(X_demo, Y_demo, 'Oracle')
42
 
43
def prepare_data():
    """Draw a full pool of noisy training samples for every label segment.

    The label range [Y_LB, Y_UB] is cut into NUM_SEG equal-width segments,
    visited from the top of the range downwards. For each segment,
    NUM_TRAIN_SAMPLES labels are sampled uniformly, Gaussian noise with
    standard deviation NOISE_SIGMA is subtracted from them, and the result
    is mapped to an input via ``x = (y - noise - B) / K``.

    Returns:
        A pair of lists ``(all_x, all_y)`` with one tensor of shape
        ``(NUM_TRAIN_SAMPLES, 1)`` per segment.
    """
    seg_width = (Y_UB - Y_LB) / NUM_SEG
    # Building the distribution draws no random numbers, so it can be shared
    # by all segments without changing the sampled values.
    label_noise = torch.distributions.Normal(loc=0, scale=NOISE_SIGMA)
    sample_shape = (NUM_TRAIN_SAMPLES, 1)

    all_x = []
    all_y = []
    for seg in range(NUM_SEG):
        seg_labels = torch.distributions.Uniform(
            Y_UB - (seg + 1) * seg_width, Y_UB - seg * seg_width
        )
        # Sampling order (labels first, then noise) determines the values
        # produced for a given seed.
        y_observed = seg_labels.sample(sample_shape)
        y_clean = y_observed - label_noise.sample(sample_shape)
        all_x.append((y_clean - B) / K)
        all_y.append(y_observed)
    return all_x, all_y
58
-
59
def select_data(all_x, all_y, sel_num):
    """Keep the first ``sel_num[i]`` samples of every segment.

    Args:
        all_x: Per-segment input tensors, indexable by segment.
        all_y: Per-segment label tensors, indexable by segment.
        sel_num: Number of samples requested from each segment.

    Returns:
        A tuple ``(sel_x, sel_y, prob)``. ``sel_x`` and ``sel_y`` are the
        selected samples concatenated along dimension 0. ``prob`` is a 1-D
        float tensor with one entry per selected sample, holding the number
        of samples taken from that sample's segment.
    """
    picked_x, picked_y, picked_prob = [], [], []
    for i in range(NUM_SEG):
        seg_x = all_x[i][:sel_num[i]]
        seg_y = all_y[i][:sel_num[i]]
        # Size the weights from what the slice really returned, so ``prob``
        # always lines up with the samples even if a segment holds fewer
        # rows than requested.
        taken = len(seg_x)
        picked_x.append(seg_x)
        picked_y.append(seg_y)
        picked_prob.append(torch.full((taken,), float(taken)))

    # torch.cat copes with empty per-segment slices, whereas stacking an
    # empty Python list raises when nothing at all is selected.
    sel_x = torch.cat(picked_x)
    sel_y = torch.cat(picked_y)
    prob = torch.cat(picked_prob)
    return sel_x, sel_y, prob
70
 
71
 
72
  def unzip_dataloader(training_loader):
@@ -79,7 +73,6 @@ def unzip_dataloader(training_loader):
79
  all_y = torch.cat(all_y)
80
  return all_x, all_y
81
 
82
- # Train the model
83
  def train(train_loader, training_bundle, num_epochs):
84
  training_df = make_dataframe(*unzip_dataloader(train_loader))
85
  for epoch in range(num_epochs):
@@ -116,8 +109,6 @@ def visualize(training_df, training_bundle, epoch):
116
 
117
 
118
  def make_video():
119
- if osp.isfile('movie.mp4'):
120
- os.remove('movie.mp4')
121
  (
122
  ffmpeg
123
  .input('train_log/*.png', pattern_type='glob', framerate=3)
@@ -143,7 +134,6 @@ class ReweightL2(_Loss):
143
  loss = loss.sum()
144
  return loss
145
 
146
- # we use a linear layer to regress the weight from height
147
  class LinearModel(nn.Module):
148
  def __init__(self, input_dim, output_dim):
149
  super(LinearModel, self).__init__()
@@ -206,15 +196,9 @@ class DummyDataset(Dataset):
206
  def __len__(self):
207
  return len(self.inputs)
208
 
209
- def run(num1, num2, num3, num4, num5, random_seed, submit):
210
- sel_num = [num1, num2, num3, num4, num5]
211
- sel_num = [int(num/100*NUM_PER_BUCKET) for num in sel_num]
212
- torch.manual_seed(int(random_seed))
213
- all_x, all_y = prepare_data()
214
- sel_x, sel_y, prob = select_data(all_x, all_y, sel_num)
215
- train_loader = DataLoader(DummyDataset(sel_x, sel_y, prob), BATCH_SIZE, shuffle=True)
216
 
217
- training_df = make_dataframe(sel_x, sel_y)
 
218
  g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
219
  marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG+1), rug=True),
220
  xlim=((Y_LB - B) / K, (Y_UB - B) / K),
@@ -230,15 +214,28 @@ def run(num1, num2, num3, num4, num5, random_seed, submit):
230
  plt.savefig('training_data.png',bbox_inches='tight')
231
  plt.close()
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  if submit == 0:
234
  text = "Press \"Start Regressing\" if your are happy with the training data. Regression takes ~10s."
235
  else:
236
  text = "Press \"Prepare Training Data\" before changing the training data. You may also change the random seed."
237
  if submit == 1:
238
- if not osp.exists('train_log'):
239
- os.mkdir('train_log')
240
- for f in os.listdir('train_log'):
241
- os.remove(osp.join('train_log', f))
242
  regress(train_loader)
243
  make_video()
244
  output = 'train_log/{:05d}.png'.format(NUM_EPOCHS) if submit==1 else None
@@ -246,34 +243,36 @@ def run(num1, num2, num3, num4, num5, random_seed, submit):
246
  return 'training_data.png', text, output, video
247
 
248
 
249
# One percentage slider per label bucket; the labels spell out the bucket
# edges [0, 2), [2, 4), ..., [8, 10).
bucket_sliders = [
    gr.inputs.Slider(0, 100, default=50, step=0.1,
                     label='Label percentage in [{}, {})'.format(low, low + 2))
    for low in range(0, 10, 2)
]
seed_input = gr.inputs.Number(default=0, label='Random Seed', optional=False)
# type="index" hands run() the position of the chosen option, not its text.
mode_input = gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
                             type="index", default=None, label='Mode', optional=False)

result_outputs = [
    gr.outputs.Image(type="file", label="Training data"),
    gr.outputs.Textbox(type="auto", label='What\' s next?'),
    gr.outputs.Image(type="file", label="Regression result"),
    gr.outputs.Video(type='mp4', label='Training process'),
]

demo_description = (
    "Welcome to the demo of Balanced MSE &#9878;. In this demo, we will work on a simple task: imbalanced <i>linear</i> regression. <br>"
    "To get started, drag the sliders &#128071;&#128071; and create your label distribution! "
    "<small>(Examples are at the bottom of the page.)</small>"
)

# Each example row supplies the five slider values, the seed and the mode.
demo_examples = [
    [0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
    [1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
]

iface = gr.Interface(
    fn=run,
    inputs=bucket_sliders + [seed_input, mode_input],
    outputs=result_outputs,
    live=True,
    allow_flagging='never',
    title="Balanced MSE for Imbalanced Visual Regression [CVPR 2022]",
    description=demo_description,
    examples=demo_examples,
)
iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  df_oracle = make_dataframe(X_demo, Y_demo, 'Oracle')
42
 
43
def prepare_data(sel_num):
    """Sample the requested number of noisy training pairs per label segment.

    The label range [Y_LB, Y_UB] is cut into NUM_SEG equal-width segments,
    visited from the top of the range downwards. For segment ``i``,
    ``sel_num[i]`` labels are sampled uniformly, Gaussian noise with standard
    deviation NOISE_SIGMA is subtracted from them, and the result is mapped
    to an input via ``x = (y - noise - B) / K``.

    Args:
        sel_num: Number of samples to draw for each segment, indexable by
            segment.

    Returns:
        A tuple ``(all_x, all_y, prob)``. ``all_x`` and ``all_y`` have shape
        ``(N, 1)`` where ``N = sum(sel_num)``. ``prob`` is a 1-D float tensor
        of length ``N`` holding, for every sample, the number of samples
        drawn for its segment.
    """
    interval = (Y_UB - Y_LB) / NUM_SEG
    # Building the distribution draws no random numbers, so one instance can
    # serve every segment.
    noise_distribution = torch.distributions.Normal(loc=0, scale=NOISE_SIGMA)

    seg_x, seg_y, seg_prob = [], [], []
    for i in range(NUM_SEG):
        count = sel_num[i]
        uniform_y_distribution = torch.distributions.Uniform(
            Y_UB - (i + 1) * interval, Y_UB - i * interval
        )
        # Labels are drawn before the noise; keeping this order keeps the
        # output for a given seed unchanged.
        y_uniform = uniform_y_distribution.sample((count, 1))
        noise = noise_distribution.sample((count, 1))
        y_uniform_oracle = y_uniform - noise

        seg_x.append((y_uniform_oracle - B) / K)
        seg_y.append(y_uniform)
        seg_prob.append(torch.full((count,), float(count)))

    # Concatenating whole per-segment tensors avoids building a Python list
    # of single-row tensors, and still works when every count is zero
    # (stacking an empty list would raise).
    all_x = torch.cat(seg_x)
    all_y = torch.cat(seg_y)
    prob = torch.cat(seg_prob)
    return all_x, all_y, prob
64
 
65
 
66
  def unzip_dataloader(training_loader):
 
73
  all_y = torch.cat(all_y)
74
  return all_x, all_y
75
 
 
76
  def train(train_loader, training_bundle, num_epochs):
77
  training_df = make_dataframe(*unzip_dataloader(train_loader))
78
  for epoch in range(num_epochs):
 
109
 
110
 
111
  def make_video():
 
 
112
  (
113
  ffmpeg
114
  .input('train_log/*.png', pattern_type='glob', framerate=3)
 
134
  loss = loss.sum()
135
  return loss
136
 
 
137
  class LinearModel(nn.Module):
138
  def __init__(self, input_dim, output_dim):
139
  super(LinearModel, self).__init__()
 
196
  def __len__(self):
197
  return len(self.inputs)
198
 
 
 
 
 
 
 
 
199
 
200
+ def vis_training_data(all_x, all_y):
201
+ training_df = make_dataframe(all_x, all_y)
202
  g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
203
  marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG+1), rug=True),
204
  xlim=((Y_LB - B) / K, (Y_UB - B) / K),
 
214
  plt.savefig('training_data.png',bbox_inches='tight')
215
  plt.close()
216
 
217
def clean_up_logs():
    """Remove the artifacts left behind by a previous regression run.

    Ensures the ``train_log`` directory exists, deletes the entries inside
    it, and deletes ``movie.mp4`` from the working directory when present.
    Real subdirectories of ``train_log`` are left untouched because
    ``os.remove`` cannot delete them and would abort the clean-up.
    """
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs('train_log', exist_ok=True)
    for entry in os.listdir('train_log'):
        path = osp.join('train_log', entry)
        if osp.isdir(path) and not osp.islink(path):
            continue
        os.remove(path)
    if osp.isfile('movie.mp4'):
        os.remove('movie.mp4')
224
+
225
+ def run(num1, num2, num3, num4, num5, random_seed, submit):
226
+ sel_num = [num1, num2, num3, num4, num5]
227
+ sel_num = [int(num/100*NUM_PER_BUCKET) for num in sel_num]
228
+ torch.manual_seed(int(random_seed))
229
+ all_x, all_y, prob = prepare_data(sel_num)
230
+ train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
231
+ vis_training_data(all_x, all_y)
232
+
233
  if submit == 0:
234
  text = "Press \"Start Regressing\" if your are happy with the training data. Regression takes ~10s."
235
  else:
236
  text = "Press \"Prepare Training Data\" before changing the training data. You may also change the random seed."
237
  if submit == 1:
238
+ clean_up_logs()
 
 
 
239
  regress(train_loader)
240
  make_video()
241
  output = 'train_log/{:05d}.png'.format(NUM_EPOCHS) if submit==1 else None
 
243
  return 'training_data.png', text, output, video
244
 
245
 
246
if __name__ == '__main__':
    # The five bucket sliders differ only in their label, so generate them.
    # Labels read "[0, 2)", "[2, 4)", ..., "[8, 10)".
    percentage_sliders = [
        gr.inputs.Slider(0, 100, default=50, step=0.1,
                         label='Label percentage in [{}, {})'.format(start, start + 2))
        for start in range(0, 10, 2)
    ]
    controls = [
        gr.inputs.Number(default=0, label='Random Seed', optional=False),
        # type="index" passes the position of the selected option to run().
        gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
                        type="index", default=None, label='Mode', optional=False),
    ]
    displays = [
        gr.outputs.Image(type="file", label="Training data"),
        gr.outputs.Textbox(type="auto", label='What\' s next?'),
        gr.outputs.Image(type="file", label="Regression result"),
        gr.outputs.Video(type='mp4', label='Training process'),
    ]
    intro = (
        "Welcome to the demo of Balanced MSE &#9878;. In this demo, we will work on a simple task: imbalanced <i>linear</i> regression. <br>"
        "To get started, drag the sliders &#128071;&#128071; and create your label distribution! "
        "<small>(Examples are at the bottom of the page.)</small>"
    )
    # Each preset lists the five slider values, the seed and the mode.
    presets = [
        [0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
        [1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
    ]

    iface = gr.Interface(
        fn=run,
        inputs=percentage_sliders + controls,
        outputs=displays,
        live=True,
        allow_flagging='never',
        title="Balanced MSE for Imbalanced Visual Regression [CVPR 2022]",
        description=intro,
        examples=presets,
    )
    iface.launch()