Spaces:
Runtime error
Runtime error
jiawei-ren
commited on
Commit
·
9c31709
1
Parent(s):
ebf41b0
init
Browse files
app.py
CHANGED
@@ -92,7 +92,7 @@ def train(train_loader, training_df, training_bundle, num_epochs):
|
|
92 |
scheduler.step()
|
93 |
if (epoch + 1) % PRINT_FREQ == 0:
|
94 |
visualize_training_process(training_df, training_bundle, epoch)
|
95 |
-
visualize_training_process(training_df, training_bundle, num_epochs, final=True)
|
96 |
|
97 |
|
98 |
def visualize_training_process(training_df, training_bundle, epoch, final=False):
|
@@ -101,7 +101,7 @@ def visualize_training_process(training_df, training_bundle, epoch, final=False)
|
|
101 |
model.eval()
|
102 |
y = model(X_demo)
|
103 |
df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
|
104 |
-
visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True)
|
105 |
if final:
|
106 |
visualize(training_df, df, 'regression_result.png', fast=False)
|
107 |
|
@@ -200,7 +200,7 @@ class DummyDataset(Dataset):
|
|
200 |
return len(self.inputs)
|
201 |
|
202 |
|
203 |
-
def visualize(training_df, df, save_path, fast=False):
|
204 |
if fast:
|
205 |
f = plt.figure(figsize=(3, 3))
|
206 |
g = f.add_subplot(111)
|
@@ -208,24 +208,26 @@ def visualize(training_df, df, save_path, fast=False):
|
|
208 |
plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
|
209 |
plt.ylim(Y_LB, Y_UB)
|
210 |
else:
|
211 |
-
g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=
|
212 |
marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
|
213 |
xlim=((Y_LB - B) / K, (Y_UB - B) / K),
|
214 |
ylim=(Y_LB, Y_UB),
|
215 |
space=0.1,
|
216 |
-
height=
|
217 |
ratio=2,
|
218 |
estimator=None, ci=None,
|
219 |
legend=False,
|
220 |
)
|
221 |
g.ax_marg_x.remove()
|
222 |
g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g.ax_joint, estimator=None, ci=None)
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
|
|
|
|
227 |
|
228 |
-
plt.savefig(save_path, bbox_inches='tight')
|
229 |
plt.close()
|
230 |
|
231 |
|
@@ -267,11 +269,11 @@ if __name__ == '__main__':
|
|
267 |
iface = gr.Interface(
|
268 |
fn=run,
|
269 |
inputs=[
|
270 |
-
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [0, 2)'),
|
271 |
-
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [2, 4)'),
|
272 |
-
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [4, 6)'),
|
273 |
-
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [6, 8)'),
|
274 |
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [8, 10)'),
|
|
|
|
|
|
|
|
|
275 |
gr.inputs.Number(default=0, label='Random Seed', optional=False),
|
276 |
gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
|
277 |
type="index", default=None, label='Mode', optional=False),
|
|
|
92 |
scheduler.step()
|
93 |
if (epoch + 1) % PRINT_FREQ == 0:
|
94 |
visualize_training_process(training_df, training_bundle, epoch)
|
95 |
+
visualize_training_process(training_df, training_bundle, num_epochs-1, final=True)
|
96 |
|
97 |
|
98 |
def visualize_training_process(training_df, training_bundle, epoch, final=False):
|
|
|
101 |
model.eval()
|
102 |
y = model(X_demo)
|
103 |
df = df.append(make_dataframe(X_demo, y, criterion_name), ignore_index=True)
|
104 |
+
visualize(training_df, df, 'train_log/{:05d}.png'.format(epoch + 1), fast=True, epoch=epoch)
|
105 |
if final:
|
106 |
visualize(training_df, df, 'regression_result.png', fast=False)
|
107 |
|
|
|
200 |
return len(self.inputs)
|
201 |
|
202 |
|
203 |
+
def visualize(training_df, df, save_path, fast=False, epoch=None):
|
204 |
if fast:
|
205 |
f = plt.figure(figsize=(3, 3))
|
206 |
g = f.add_subplot(111)
|
|
|
208 |
plt.xlim((Y_LB - B) / K, (Y_UB - B) / K)
|
209 |
plt.ylim(Y_LB, Y_UB)
|
210 |
else:
|
211 |
+
g = sns.jointplot(data=training_df, x='x', y='y', color='#003ea1', alpha=0.1, linewidths=0, s=50,
|
212 |
marginal_kws=dict(bins=torch.linspace(Y_LB, Y_UB, steps=NUM_SEG + 1)),
|
213 |
xlim=((Y_LB - B) / K, (Y_UB - B) / K),
|
214 |
ylim=(Y_LB, Y_UB),
|
215 |
space=0.1,
|
216 |
+
height=5,
|
217 |
ratio=2,
|
218 |
estimator=None, ci=None,
|
219 |
legend=False,
|
220 |
)
|
221 |
g.ax_marg_x.remove()
|
222 |
g_line = sns.lineplot(data=df, x='x', y='y', hue='Method', ax=g.ax_joint, estimator=None, ci=None)
|
223 |
+
if epoch is not None:
|
224 |
+
g_line.legend(loc='upper left', title="Epoch {:03d}".format(epoch+1))
|
225 |
+
else:
|
226 |
+
g_line.legend(loc='upper left')
|
227 |
+
plt.gca().axes.set_xlabel(r'$x$')
|
228 |
+
plt.gca().axes.set_ylabel(r'$y$')
|
229 |
|
230 |
+
plt.savefig(save_path, bbox_inches='tight', dpi=200)
|
231 |
plt.close()
|
232 |
|
233 |
|
|
|
269 |
iface = gr.Interface(
|
270 |
fn=run,
|
271 |
inputs=[
|
|
|
|
|
|
|
|
|
272 |
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [8, 10)'),
|
273 |
+
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [6, 8)'),
|
274 |
+
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [4, 6)'),
|
275 |
+
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [2, 4)'),
|
276 |
+
gr.inputs.Slider(0, 100, default=20, step=0.1, label='Label percentage in [0, 2)'),
|
277 |
gr.inputs.Number(default=0, label='Random Seed', optional=False),
|
278 |
gr.inputs.Radio(['Prepare Training Data', 'Start Regressing!'],
|
279 |
type="index", default=None, label='Mode', optional=False),
|