Update app.py
Browse files
app.py
CHANGED
@@ -35,7 +35,7 @@ def predict(model, image_tensor):
|
|
35 |
pred = torch.argmax(probs).item()
|
36 |
return probs, pred
|
37 |
|
38 |
-
def unlearn(model, image_tensor, label_idx, learning_rate, steps=
|
39 |
model.train()
|
40 |
for m in model.modules():
|
41 |
if isinstance(m, nn.BatchNorm2d):
|
@@ -126,7 +126,7 @@ demo = gr.Interface(
|
|
126 |
fn=run_unlearning,
|
127 |
inputs=[
|
128 |
gr.Slider(0, len(trainset)-1, step=1, label="Select Index to Unlearn"),
|
129 |
-
gr.Slider(0.0001, 0.
|
130 |
],
|
131 |
outputs="text",
|
132 |
title="π CIFAR-10 Machine Unlearning",
|
|
|
35 |
pred = torch.argmax(probs).item()
|
36 |
return probs, pred
|
37 |
|
38 |
+
def unlearn(model, image_tensor, label_idx, learning_rate, steps=20):
|
39 |
model.train()
|
40 |
for m in model.modules():
|
41 |
if isinstance(m, nn.BatchNorm2d):
|
|
|
126 |
fn=run_unlearning,
|
127 |
inputs=[
|
128 |
gr.Slider(0, len(trainset)-1, step=1, label="Select Index to Unlearn"),
|
129 |
+
gr.Slider(0.0001, 0.01, step=0.0001, value=0.005, label="Learning Rate (for Unlearning)")
|
130 |
],
|
131 |
outputs="text",
|
132 |
title="π CIFAR-10 Machine Unlearning",
|