jiawei-ren committed
Commit 862f52d · 1 Parent(s): f7455a4
Files changed (1)
  1. app.py +31 -22
app.py CHANGED
@@ -35,21 +35,23 @@ def make_dataframe(x, y, method=None):
     df = pd.DataFrame({'x': x, 'y': y})
     return df
 
+
 Y_demo = torch.linspace(Y_LB, Y_UB, 2).unsqueeze(-1)
 X_demo = (Y_demo - B) / K
 
 df_oracle = make_dataframe(X_demo, Y_demo, 'Oracle')
 
+
 def prepare_data(sel_num):
     interval = (Y_UB - Y_LB) / NUM_SEG
     all_x, all_y = [], []
     prob = []
     for i in range(NUM_SEG):
-        uniform_y_distribution = torch.distributions.Uniform(Y_UB - (i+1)*interval, Y_UB-i*interval)
-        y_uniform = uniform_y_distribution.sample((sel_num[i], 1))
+        uniform_y_distribution = torch.distributions.Uniform(Y_UB - (i + 1) * interval, Y_UB - i * interval)
+        y_uniform = uniform_y_distribution.sample((NUM_TRAIN_SAMPLES, 1))[:sel_num[i]]
 
         noise_distribution = torch.distributions.Normal(loc=0, scale=NOISE_SIGMA)
-        noise = noise_distribution.sample((sel_num[i], 1))
+        noise = noise_distribution.sample((NUM_TRAIN_SAMPLES, 1))[:sel_num[i]]
         y_uniform_oracle = y_uniform - noise
 
         x_uniform = (y_uniform_oracle - B) / K
@@ -73,6 +75,7 @@ def unzip_dataloader(training_loader):
     all_y = torch.cat(all_y)
     return all_x, all_y
 
+
 def train(train_loader, training_bundle, num_epochs):
     training_df = make_dataframe(*unzip_dataloader(train_loader))
     for epoch in range(num_epochs):
@@ -91,6 +94,7 @@ def train(train_loader, training_bundle, num_epochs):
         if (epoch + 1) % PRINT_FREQ == 0:
             visualize(training_df, training_bundle, epoch)
 
+
 def visualize(training_df, training_bundle, epoch):
     df = df_oracle
     for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
@@ -103,11 +107,10 @@ def visualize(training_df, training_bundle, epoch):
     plt.ylim(Y_LB, Y_UB)
     plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
     plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
-    plt.savefig('train_log/{:05d}.png'.format(epoch+1), bbox_inches='tight')
+    plt.savefig('train_log/{:05d}.png'.format(epoch + 1), bbox_inches='tight')
     plt.close()
 
 
-
 def make_video():
     (
         ffmpeg
@@ -116,6 +119,7 @@ def make_video():
         .run()
     )
 
+
 class ReweightL2(_Loss):
     def __init__(self, reweight='inverse'):
         super(ReweightL2, self).__init__()
@@ -134,6 +138,7 @@ class ReweightL2(_Loss):
         loss = loss.sum()
         return loss
 
+
 class LinearModel(nn.Module):
     def __init__(self, input_dim, output_dim):
         super(LinearModel, self).__init__()
@@ -145,6 +150,7 @@ class LinearModel(nn.Module):
         x = self.mlp(x)
         return x
 
+
 def prepare_model():
     model = LinearModel(input_dim=1, output_dim=1)
     optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
@@ -171,6 +177,7 @@ def bmc_loss(pred, target, noise_var):
 
     return loss * (2 * noise_var)
 
+
 def regress(train_loader):
     training_bundle = []
     criterions = {
@@ -184,6 +191,7 @@ def regress(train_loader):
         training_bundle.append((model, optimizer, scheduler, criterion, criterion_name))
     train(train_loader, training_bundle, NUM_EPOCHS)
 
+
 class DummyDataset(Dataset):
     def __init__(self, inputs, targets, prob):
         self.inputs = inputs
@@ -200,20 +208,21 @@ class DummyDataset(Dataset):
 def vis_training_data(all_x, all_y):
     training_df = make_dataframe(all_x, all_y)
     g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
-                      marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG+1), rug=True),
+                      marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1), rug=True),
                       xlim=((Y_LB - B) / K, (Y_UB - B) / K),
                       ylim=(Y_LB, Y_UB),
                       space=0.1,
                       height=8,
                       ratio=2
-                      )
+                      )
     g.ax_marg_x.remove()
     sns.lineplot(data=df_oracle, x='x', y='y', hue='Method', ax=g.ax_joint, legend=False)
     plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
     plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
-    plt.savefig('training_data.png',bbox_inches='tight')
+    plt.savefig('training_data.png', bbox_inches='tight')
     plt.close()
 
+
 def clean_up_logs():
     if not osp.exists('train_log'):
         os.mkdir('train_log')
@@ -222,9 +231,10 @@ def clean_up_logs():
     if osp.isfile('movie.mp4'):
         os.remove('movie.mp4')
 
+
 def run(num1, num2, num3, num4, num5, random_seed, submit):
     sel_num = [num1, num2, num3, num4, num5]
-    sel_num = [int(num/100*NUM_PER_BUCKET) for num in sel_num]
+    sel_num = [int(num / 100 * NUM_PER_BUCKET) for num in sel_num]
     torch.manual_seed(int(random_seed))
     all_x, all_y, prob = prepare_data(sel_num)
     train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
@@ -238,25 +248,24 @@ def run(num1, num2, num3, num4, num5, random_seed, submit):
     clean_up_logs()
    regress(train_loader)
     make_video()
-    output = 'train_log/{:05d}.png'.format(NUM_EPOCHS) if submit==1 else None
-    video = "movie.mp4" if submit==1 else None
+    output = 'train_log/{:05d}.png'.format(NUM_EPOCHS) if submit == 1 else None
+    video = "movie.mp4" if submit == 1 else None
     return 'training_data.png', text, output, video
 
 
 if __name__ == '__main__':
-
     iface = gr.Interface(
         fn=run,
         inputs=[
-            gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [0, 2)'),
-            gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [2, 4)'),
-            gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [4, 6)'),
-            gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [6, 8)'),
-            gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [8, 10)'),
-            gr.inputs.Number(default=0, label='Random Seed', optional=False),
-            gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
-                            type="index", default=None, label='Mode', optional=False),
-        ],
+            gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [0, 2)'),
+            gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [2, 4)'),
+            gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [4, 6)'),
+            gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [6, 8)'),
+            gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [8, 10)'),
+            gr.inputs.Number(default=0, label='Random Seed', optional=False),
+            gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
+                            type="index", default=None, label='Mode', optional=False),
+        ],
         outputs=[
             gr.outputs.Image(type="file", label="Training data"),
             gr.outputs.Textbox(type="auto", label='What\' s next?'),
@@ -273,6 +282,6 @@ if __name__ == '__main__':
             [0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
             [1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
         ],
-
+
     )
     iface.launch()
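Note on the main functional change: prepare_data now draws a fixed NUM_TRAIN_SAMPLES per bucket and slices the result to sel_num[i], instead of sampling exactly sel_num[i] points. Since every bucket then consumes the same amount of the random stream, the retained samples for a given seed appear to stay stable when the slider percentages change. Below is a minimal standalone sketch of that pattern, not the app's code; the constants are stand-ins.

    import torch

    NUM_TRAIN_SAMPLES = 100   # stand-in cap on draws per bucket, not the app's value
    LOW, HIGH = 0.0, 2.0      # stand-in bucket bounds

    def sample_bucket(sel, seed):
        torch.manual_seed(seed)
        dist = torch.distributions.Uniform(LOW, HIGH)
        # fixed-size draw, then slice to the selected count
        return dist.sample((NUM_TRAIN_SAMPLES, 1))[:sel]

    a = sample_bucket(sel=10, seed=0)
    b = sample_bucket(sel=50, seed=0)
    # the first 10 draws agree regardless of how many samples are kept
    assert torch.equal(a, b[:10])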
 
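For reference, the submit == 1 checks touched above correspond to the Mode radio declared with type="index": Gradio passes the selected option's position, so 'Prepare Training Data' arrives as 0 and 'Start Regressing!' as 1, and the final image and video paths are returned only in regression mode. A small sketch of that gating logic follows; NUM_EPOCHS here is a stand-in value, not necessarily the app's setting.

    NUM_EPOCHS = 100  # stand-in value for illustration

    def select_outputs(submit):
        # submit is the Mode radio index: 0 = 'Prepare Training Data', 1 = 'Start Regressing!'
        output = 'train_log/{:05d}.png'.format(NUM_EPOCHS) if submit == 1 else None
        video = 'movie.mp4' if submit == 1 else None
        return output, video

    assert select_outputs(0) == (None, None)
    assert select_outputs(1) == ('train_log/00100.png', 'movie.mp4')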