Spaces:
Runtime error
Runtime error
jiawei-ren
commited on
Commit
·
862f52d
1
Parent(s):
f7455a4
init
Browse files
app.py
CHANGED
@@ -35,21 +35,23 @@ def make_dataframe(x, y, method=None):
|
|
35 |
df = pd.DataFrame({'x': x, 'y': y})
|
36 |
return df
|
37 |
|
|
|
38 |
Y_demo = torch.linspace(Y_LB, Y_UB, 2).unsqueeze(-1)
|
39 |
X_demo = (Y_demo - B) / K
|
40 |
|
41 |
df_oracle = make_dataframe(X_demo, Y_demo, 'Oracle')
|
42 |
|
|
|
43 |
def prepare_data(sel_num):
|
44 |
interval = (Y_UB - Y_LB) / NUM_SEG
|
45 |
all_x, all_y = [], []
|
46 |
prob = []
|
47 |
for i in range(NUM_SEG):
|
48 |
-
uniform_y_distribution = torch.distributions.Uniform(Y_UB - (i+1)*interval, Y_UB-i*interval)
|
49 |
-
y_uniform = uniform_y_distribution.sample((
|
50 |
|
51 |
noise_distribution = torch.distributions.Normal(loc=0, scale=NOISE_SIGMA)
|
52 |
-
noise = noise_distribution.sample((
|
53 |
y_uniform_oracle = y_uniform - noise
|
54 |
|
55 |
x_uniform = (y_uniform_oracle - B) / K
|
@@ -73,6 +75,7 @@ def unzip_dataloader(training_loader):
|
|
73 |
all_y = torch.cat(all_y)
|
74 |
return all_x, all_y
|
75 |
|
|
|
76 |
def train(train_loader, training_bundle, num_epochs):
|
77 |
training_df = make_dataframe(*unzip_dataloader(train_loader))
|
78 |
for epoch in range(num_epochs):
|
@@ -91,6 +94,7 @@ def train(train_loader, training_bundle, num_epochs):
|
|
91 |
if (epoch + 1) % PRINT_FREQ == 0:
|
92 |
visualize(training_df, training_bundle, epoch)
|
93 |
|
|
|
94 |
def visualize(training_df, training_bundle, epoch):
|
95 |
df = df_oracle
|
96 |
for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
|
@@ -103,11 +107,10 @@ def visualize(training_df, training_bundle, epoch):
|
|
103 |
plt.ylim(Y_LB, Y_UB)
|
104 |
plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
|
105 |
plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
|
106 |
-
plt.savefig('train_log/{:05d}.png'.format(epoch+1), bbox_inches='tight')
|
107 |
plt.close()
|
108 |
|
109 |
|
110 |
-
|
111 |
def make_video():
|
112 |
(
|
113 |
ffmpeg
|
@@ -116,6 +119,7 @@ def make_video():
|
|
116 |
.run()
|
117 |
)
|
118 |
|
|
|
119 |
class ReweightL2(_Loss):
|
120 |
def __init__(self, reweight='inverse'):
|
121 |
super(ReweightL2, self).__init__()
|
@@ -134,6 +138,7 @@ class ReweightL2(_Loss):
|
|
134 |
loss = loss.sum()
|
135 |
return loss
|
136 |
|
|
|
137 |
class LinearModel(nn.Module):
|
138 |
def __init__(self, input_dim, output_dim):
|
139 |
super(LinearModel, self).__init__()
|
@@ -145,6 +150,7 @@ class LinearModel(nn.Module):
|
|
145 |
x = self.mlp(x)
|
146 |
return x
|
147 |
|
|
|
148 |
def prepare_model():
|
149 |
model = LinearModel(input_dim=1, output_dim=1)
|
150 |
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
|
@@ -171,6 +177,7 @@ def bmc_loss(pred, target, noise_var):
|
|
171 |
|
172 |
return loss * (2 * noise_var)
|
173 |
|
|
|
174 |
def regress(train_loader):
|
175 |
training_bundle = []
|
176 |
criterions = {
|
@@ -184,6 +191,7 @@ def regress(train_loader):
|
|
184 |
training_bundle.append((model, optimizer, scheduler, criterion, criterion_name))
|
185 |
train(train_loader, training_bundle, NUM_EPOCHS)
|
186 |
|
|
|
187 |
class DummyDataset(Dataset):
|
188 |
def __init__(self, inputs, targets, prob):
|
189 |
self.inputs = inputs
|
@@ -200,20 +208,21 @@ class DummyDataset(Dataset):
|
|
200 |
def vis_training_data(all_x, all_y):
|
201 |
training_df = make_dataframe(all_x, all_y)
|
202 |
g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
|
203 |
-
marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG+1), rug=True),
|
204 |
xlim=((Y_LB - B) / K, (Y_UB - B) / K),
|
205 |
ylim=(Y_LB, Y_UB),
|
206 |
space=0.1,
|
207 |
height=8,
|
208 |
ratio=2
|
209 |
-
|
210 |
g.ax_marg_x.remove()
|
211 |
sns.lineplot(data=df_oracle, x='x', y='y', hue='Method', ax=g.ax_joint, legend=False)
|
212 |
plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
|
213 |
plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
|
214 |
-
plt.savefig('training_data.png',bbox_inches='tight')
|
215 |
plt.close()
|
216 |
|
|
|
217 |
def clean_up_logs():
|
218 |
if not osp.exists('train_log'):
|
219 |
os.mkdir('train_log')
|
@@ -222,9 +231,10 @@ def clean_up_logs():
|
|
222 |
if osp.isfile('movie.mp4'):
|
223 |
os.remove('movie.mp4')
|
224 |
|
|
|
225 |
def run(num1, num2, num3, num4, num5, random_seed, submit):
|
226 |
sel_num = [num1, num2, num3, num4, num5]
|
227 |
-
sel_num = [int(num/100*NUM_PER_BUCKET) for num in sel_num]
|
228 |
torch.manual_seed(int(random_seed))
|
229 |
all_x, all_y, prob = prepare_data(sel_num)
|
230 |
train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
|
@@ -238,25 +248,24 @@ def run(num1, num2, num3, num4, num5, random_seed, submit):
|
|
238 |
clean_up_logs()
|
239 |
regress(train_loader)
|
240 |
make_video()
|
241 |
-
output = 'train_log/{:05d}.png'.format(NUM_EPOCHS) if submit==1 else None
|
242 |
-
video = "movie.mp4" if submit==1 else None
|
243 |
return 'training_data.png', text, output, video
|
244 |
|
245 |
|
246 |
if __name__ == '__main__':
|
247 |
-
|
248 |
iface = gr.Interface(
|
249 |
fn=run,
|
250 |
inputs=[
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
outputs=[
|
261 |
gr.outputs.Image(type="file", label="Training data"),
|
262 |
gr.outputs.Textbox(type="auto", label='What\' s next?'),
|
@@ -273,6 +282,6 @@ if __name__ == '__main__':
|
|
273 |
[0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
|
274 |
[1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
|
275 |
],
|
276 |
-
|
277 |
)
|
278 |
iface.launch()
|
|
|
35 |
df = pd.DataFrame({'x': x, 'y': y})
|
36 |
return df
|
37 |
|
38 |
+
|
39 |
Y_demo = torch.linspace(Y_LB, Y_UB, 2).unsqueeze(-1)
|
40 |
X_demo = (Y_demo - B) / K
|
41 |
|
42 |
df_oracle = make_dataframe(X_demo, Y_demo, 'Oracle')
|
43 |
|
44 |
+
|
45 |
def prepare_data(sel_num):
|
46 |
interval = (Y_UB - Y_LB) / NUM_SEG
|
47 |
all_x, all_y = [], []
|
48 |
prob = []
|
49 |
for i in range(NUM_SEG):
|
50 |
+
uniform_y_distribution = torch.distributions.Uniform(Y_UB - (i + 1) * interval, Y_UB - i * interval)
|
51 |
+
y_uniform = uniform_y_distribution.sample((NUM_TRAIN_SAMPLES, 1))[:sel_num[i]]
|
52 |
|
53 |
noise_distribution = torch.distributions.Normal(loc=0, scale=NOISE_SIGMA)
|
54 |
+
noise = noise_distribution.sample((NUM_TRAIN_SAMPLES, 1))[:sel_num[i]]
|
55 |
y_uniform_oracle = y_uniform - noise
|
56 |
|
57 |
x_uniform = (y_uniform_oracle - B) / K
|
|
|
75 |
all_y = torch.cat(all_y)
|
76 |
return all_x, all_y
|
77 |
|
78 |
+
|
79 |
def train(train_loader, training_bundle, num_epochs):
|
80 |
training_df = make_dataframe(*unzip_dataloader(train_loader))
|
81 |
for epoch in range(num_epochs):
|
|
|
94 |
if (epoch + 1) % PRINT_FREQ == 0:
|
95 |
visualize(training_df, training_bundle, epoch)
|
96 |
|
97 |
+
|
98 |
def visualize(training_df, training_bundle, epoch):
|
99 |
df = df_oracle
|
100 |
for model, optimizer, scheduler, criterion, criterion_name in training_bundle:
|
|
|
107 |
plt.ylim(Y_LB, Y_UB)
|
108 |
plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
|
109 |
plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
|
110 |
+
plt.savefig('train_log/{:05d}.png'.format(epoch + 1), bbox_inches='tight')
|
111 |
plt.close()
|
112 |
|
113 |
|
|
|
114 |
def make_video():
|
115 |
(
|
116 |
ffmpeg
|
|
|
119 |
.run()
|
120 |
)
|
121 |
|
122 |
+
|
123 |
class ReweightL2(_Loss):
|
124 |
def __init__(self, reweight='inverse'):
|
125 |
super(ReweightL2, self).__init__()
|
|
|
138 |
loss = loss.sum()
|
139 |
return loss
|
140 |
|
141 |
+
|
142 |
class LinearModel(nn.Module):
|
143 |
def __init__(self, input_dim, output_dim):
|
144 |
super(LinearModel, self).__init__()
|
|
|
150 |
x = self.mlp(x)
|
151 |
return x
|
152 |
|
153 |
+
|
154 |
def prepare_model():
|
155 |
model = LinearModel(input_dim=1, output_dim=1)
|
156 |
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
|
|
|
177 |
|
178 |
return loss * (2 * noise_var)
|
179 |
|
180 |
+
|
181 |
def regress(train_loader):
|
182 |
training_bundle = []
|
183 |
criterions = {
|
|
|
191 |
training_bundle.append((model, optimizer, scheduler, criterion, criterion_name))
|
192 |
train(train_loader, training_bundle, NUM_EPOCHS)
|
193 |
|
194 |
+
|
195 |
class DummyDataset(Dataset):
|
196 |
def __init__(self, inputs, targets, prob):
|
197 |
self.inputs = inputs
|
|
|
208 |
def vis_training_data(all_x, all_y):
|
209 |
training_df = make_dataframe(all_x, all_y)
|
210 |
g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=100,
|
211 |
+
marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1), rug=True),
|
212 |
xlim=((Y_LB - B) / K, (Y_UB - B) / K),
|
213 |
ylim=(Y_LB, Y_UB),
|
214 |
space=0.1,
|
215 |
height=8,
|
216 |
ratio=2
|
217 |
+
)
|
218 |
g.ax_marg_x.remove()
|
219 |
sns.lineplot(data=df_oracle, x='x', y='y', hue='Method', ax=g.ax_joint, legend=False)
|
220 |
plt.gca().axes.set_xlabel(r'$x$', fontsize=10)
|
221 |
plt.gca().axes.set_ylabel(r'$y$', fontsize=10)
|
222 |
+
plt.savefig('training_data.png', bbox_inches='tight')
|
223 |
plt.close()
|
224 |
|
225 |
+
|
226 |
def clean_up_logs():
|
227 |
if not osp.exists('train_log'):
|
228 |
os.mkdir('train_log')
|
|
|
231 |
if osp.isfile('movie.mp4'):
|
232 |
os.remove('movie.mp4')
|
233 |
|
234 |
+
|
235 |
def run(num1, num2, num3, num4, num5, random_seed, submit):
|
236 |
sel_num = [num1, num2, num3, num4, num5]
|
237 |
+
sel_num = [int(num / 100 * NUM_PER_BUCKET) for num in sel_num]
|
238 |
torch.manual_seed(int(random_seed))
|
239 |
all_x, all_y, prob = prepare_data(sel_num)
|
240 |
train_loader = DataLoader(DummyDataset(all_x, all_y, prob), BATCH_SIZE, shuffle=True)
|
|
|
248 |
clean_up_logs()
|
249 |
regress(train_loader)
|
250 |
make_video()
|
251 |
+
output = 'train_log/{:05d}.png'.format(NUM_EPOCHS) if submit == 1 else None
|
252 |
+
video = "movie.mp4" if submit == 1 else None
|
253 |
return 'training_data.png', text, output, video
|
254 |
|
255 |
|
256 |
if __name__ == '__main__':
|
|
|
257 |
iface = gr.Interface(
|
258 |
fn=run,
|
259 |
inputs=[
|
260 |
+
gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [0, 2)'),
|
261 |
+
gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [2, 4)'),
|
262 |
+
gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [4, 6)'),
|
263 |
+
gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [6, 8)'),
|
264 |
+
gr.inputs.Slider(0, 100, default=50, step=0.1, label='Label percentage in [8, 10)'),
|
265 |
+
gr.inputs.Number(default=0, label='Random Seed', optional=False),
|
266 |
+
gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
|
267 |
+
type="index", default=None, label='Mode', optional=False),
|
268 |
+
],
|
269 |
outputs=[
|
270 |
gr.outputs.Image(type="file", label="Training data"),
|
271 |
gr.outputs.Textbox(type="auto", label='What\' s next?'),
|
|
|
282 |
[0.1, 0.8, 6.4, 51.2, 100, 0, 'Prepare Training Data'],
|
283 |
[1, 10, 100, 10, 1, 0, 'Prepare Training Data'],
|
284 |
],
|
285 |
+
|
286 |
)
|
287 |
iface.launch()
|