NLPV commited on
Commit
c414588
Β·
verified Β·
1 Parent(s): ff29b0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -42
app.py CHANGED
@@ -5,25 +5,24 @@ import torchvision.transforms as transforms
5
  from torchvision import models
6
  import torch.nn as nn
7
  import torch.optim as optim
8
- import numpy as np
9
  from PIL import Image
10
-
11
  # CIFAR-10 labels
12
  cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
13
  'dog', 'frog', 'horse', 'ship', 'truck']
14
-
15
- # Transforms
16
  transform = transforms.Compose([
17
  transforms.Resize((32, 32)),
18
  transforms.ToTensor(),
19
- transforms.Normalize((0.5,), (0.5,))
20
  ])
21
-
22
- # Load CIFAR-10
23
  trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
24
  testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
25
  testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
26
-
27
  def predict(model, image_tensor):
28
  model.eval()
29
  with torch.no_grad():
@@ -34,27 +33,45 @@ def predict(model, image_tensor):
34
  probs = torch.zeros_like(probs)
35
  pred = torch.argmax(probs).item()
36
  return probs, pred
37
-
38
  def unlearn(model, image_tensor, label_idx, learning_rate, steps=20):
 
 
 
 
39
  model.train()
 
 
 
 
 
 
40
  for m in model.modules():
41
  if isinstance(m, nn.BatchNorm2d):
42
  m.eval()
43
-
44
  criterion = nn.CrossEntropyLoss()
45
- optimizer = optim.SGD(model.parameters(), lr=learning_rate)
46
-
 
 
 
 
 
47
  for i in range(steps):
48
  output = model(image_tensor.unsqueeze(0))
49
- loss = -criterion(output, torch.tensor([label_idx]))
50
  if torch.isnan(loss):
51
  print(f"❌ NaN detected in loss at step {i}. Stopping unlearning.")
52
  break
53
  print(f"🧠 Step {i+1}/{steps} - Unlearning Loss: {loss.item():.4f}")
 
54
  optimizer.zero_grad()
55
  loss.backward()
 
 
56
  optimizer.step()
57
-
58
  def evaluate_model(model, testloader):
59
  model.eval()
60
  total, correct, loss_total = 0, 0, 0.0
@@ -68,70 +85,76 @@ def evaluate_model(model, testloader):
68
  correct += (preds == labels).sum().item()
69
  loss_total += loss.item() * labels.size(0)
70
  return round(100 * correct / total, 2), round(loss_total / total, 4)
71
-
72
  def run_unlearning(index_to_unlearn, learning_rate):
73
- # Load original model
 
 
 
74
  original_model = models.resnet18(weights=None)
75
  original_model.fc = nn.Linear(original_model.fc.in_features, 10)
76
- original_model.load_state_dict(torch.load("resnet18.pth"))
 
77
  original_model.eval()
78
-
79
- # Duplicate model for unlearning
80
  unlearned_model = models.resnet18(weights=None)
81
  unlearned_model.fc = nn.Linear(unlearned_model.fc.in_features, 10)
82
- unlearned_model.load_state_dict(torch.load("resnet18.pth"))
83
-
84
- # Get sample
 
85
  image_tensor, label_idx = trainset[index_to_unlearn]
 
86
  label_name = cifar10_classes[label_idx]
87
  print(f"πŸ—‚οΈ Actual Label Index: {label_idx} | Label Name: {label_name}")
88
-
89
- # Prediction before
90
  probs_before, pred_before = predict(original_model, image_tensor)
91
  conf_before = probs_before[label_idx].item()
92
-
93
- # Unlearning
94
  unlearn(unlearned_model, image_tensor, label_idx, learning_rate)
95
-
96
- # Prediction after
97
  probs_after, pred_after = predict(unlearned_model, image_tensor)
98
  conf_after = probs_after[label_idx].item()
99
-
100
- # Evaluate full test set
101
  orig_acc, orig_loss = evaluate_model(original_model, testloader)
102
  unlearn_acc, unlearn_loss = evaluate_model(unlearned_model, testloader)
103
-
104
  result = f"""
105
  πŸ“ Index Unlearned: {index_to_unlearn}
106
  πŸ—‚οΈ Actual Label: {label_name} (Index: {label_idx})
107
-
108
  πŸ”Ž BEFORE Unlearning:
109
  - Prediction: {cifar10_classes[pred_before]}
110
  - Confidence: {conf_before:.4f}
111
-
112
  🧽 AFTER Unlearning:
113
  - Prediction: {cifar10_classes[pred_after]}
114
  - Confidence: {conf_after:.4f}
115
-
116
  πŸ“‰ Confidence Drop: {conf_before - conf_after:.4f}
117
-
118
  πŸ§ͺ Test Set Performance:
119
- - Original Model: {orig_acc:.2f}%
120
- - Unlearned Model: {unlearn_acc:.2f}%
121
  """
122
  return result
123
-
124
- # Gradio Interface
125
  demo = gr.Interface(
126
  fn=run_unlearning,
127
  inputs=[
128
- gr.Slider(0, len(trainset)-1, step=1, label="Select Index to Unlearn"),
129
  gr.Slider(0.0001, 0.01, step=0.0001, value=0.005, label="Learning Rate (for Unlearning)")
130
  ],
131
  outputs="text",
132
  title="πŸ” CIFAR-10 Machine Unlearning",
133
  description="Load a pre-trained ResNet18 and unlearn a specific index from the CIFAR-10 training set."
134
  )
135
-
136
  if __name__ == "__main__":
137
- demo.launch()
 
5
  from torchvision import models
6
  import torch.nn as nn
7
  import torch.optim as optim
 
8
  from PIL import Image
9
+
10
  # CIFAR-10 labels
11
  cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
12
  'dog', 'frog', 'horse', 'ship', 'truck']
13
+
14
+ # Transforms with proper normalization for 3 channels
15
  transform = transforms.Compose([
16
  transforms.Resize((32, 32)),
17
  transforms.ToTensor(),
18
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
19
  ])
20
+
21
+ # Load CIFAR-10 datasets
22
  trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
23
  testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
24
  testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
25
+
26
  def predict(model, image_tensor):
27
  model.eval()
28
  with torch.no_grad():
 
33
  probs = torch.zeros_like(probs)
34
  pred = torch.argmax(probs).item()
35
  return probs, pred
36
+
37
  def unlearn(model, image_tensor, label_idx, learning_rate, steps=20):
38
+ """
39
+ Performs targeted unlearning by updating only the final fully connected layer
40
+ using negative cross-entropy loss.
41
+ """
42
  model.train()
43
+ # Freeze all layers except the final fully connected layer (fc)
44
+ for name, param in model.named_parameters():
45
+ if "fc" not in name:
46
+ param.requires_grad = False
47
+
48
+ # Set BatchNorm layers to eval mode to prevent updating running stats
49
  for m in model.modules():
50
  if isinstance(m, nn.BatchNorm2d):
51
  m.eval()
52
+
53
  criterion = nn.CrossEntropyLoss()
54
+ # Use Adam optimizer for parameters that require gradients (i.e. only the fc layer)
55
+ optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
56
+
57
+ # Ensure label tensor is on the same device as the image_tensor
58
+ device = image_tensor.device
59
+ label_tensor = torch.tensor([label_idx], device=device)
60
+
61
  for i in range(steps):
62
  output = model(image_tensor.unsqueeze(0))
63
+ loss = -criterion(output, label_tensor) # Negative loss for unlearning
64
  if torch.isnan(loss):
65
  print(f"❌ NaN detected in loss at step {i}. Stopping unlearning.")
66
  break
67
  print(f"🧠 Step {i+1}/{steps} - Unlearning Loss: {loss.item():.4f}")
68
+
69
  optimizer.zero_grad()
70
  loss.backward()
71
+ # Clip gradients to avoid explosion
72
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
73
  optimizer.step()
74
+
75
  def evaluate_model(model, testloader):
76
  model.eval()
77
  total, correct, loss_total = 0, 0, 0.0
 
85
  correct += (preds == labels).sum().item()
86
  loss_total += loss.item() * labels.size(0)
87
  return round(100 * correct / total, 2), round(loss_total / total, 4)
88
+
89
  def run_unlearning(index_to_unlearn, learning_rate):
90
+ # Set device (CPU in this example; update as needed)
91
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
92
+
93
+ # Load the original pre-trained model
94
  original_model = models.resnet18(weights=None)
95
  original_model.fc = nn.Linear(original_model.fc.in_features, 10)
96
+ original_model.load_state_dict(torch.load("resnet18.pth", map_location=device))
97
+ original_model.to(device)
98
  original_model.eval()
99
+
100
+ # Duplicate the model for unlearning experiment
101
  unlearned_model = models.resnet18(weights=None)
102
  unlearned_model.fc = nn.Linear(unlearned_model.fc.in_features, 10)
103
+ unlearned_model.load_state_dict(torch.load("resnet18.pth", map_location=device))
104
+ unlearned_model.to(device)
105
+
106
+ # Get the sample to unlearn from the training set
107
  image_tensor, label_idx = trainset[index_to_unlearn]
108
+ image_tensor = image_tensor.to(device)
109
  label_name = cifar10_classes[label_idx]
110
  print(f"πŸ—‚οΈ Actual Label Index: {label_idx} | Label Name: {label_name}")
111
+
112
+ # Prediction before unlearning
113
  probs_before, pred_before = predict(original_model, image_tensor)
114
  conf_before = probs_before[label_idx].item()
115
+
116
+ # Perform unlearning on the duplicated model
117
  unlearn(unlearned_model, image_tensor, label_idx, learning_rate)
118
+
119
+ # Prediction after unlearning
120
  probs_after, pred_after = predict(unlearned_model, image_tensor)
121
  conf_after = probs_after[label_idx].item()
122
+
123
+ # Evaluate full test set performance on both models
124
  orig_acc, orig_loss = evaluate_model(original_model, testloader)
125
  unlearn_acc, unlearn_loss = evaluate_model(unlearned_model, testloader)
126
+
127
  result = f"""
128
  πŸ“ Index Unlearned: {index_to_unlearn}
129
  πŸ—‚οΈ Actual Label: {label_name} (Index: {label_idx})
130
+
131
  πŸ”Ž BEFORE Unlearning:
132
  - Prediction: {cifar10_classes[pred_before]}
133
  - Confidence: {conf_before:.4f}
134
+
135
  🧽 AFTER Unlearning:
136
  - Prediction: {cifar10_classes[pred_after]}
137
  - Confidence: {conf_after:.4f}
138
+
139
  πŸ“‰ Confidence Drop: {conf_before - conf_after:.4f}
140
+
141
  πŸ§ͺ Test Set Performance:
142
+ - Original Model: {orig_acc:.2f}% accuracy, Loss: {orig_loss:.4f}
143
+ - Unlearned Model: {unlearn_acc:.2f}% accuracy, Loss: {unlearn_loss:.4f}
144
  """
145
  return result
146
+
147
+ # Gradio Interface for interactive unlearning demonstration
148
  demo = gr.Interface(
149
  fn=run_unlearning,
150
  inputs=[
151
+ gr.Slider(0, len(trainset) - 1, step=1, label="Select Index to Unlearn"),
152
  gr.Slider(0.0001, 0.01, step=0.0001, value=0.005, label="Learning Rate (for Unlearning)")
153
  ],
154
  outputs="text",
155
  title="πŸ” CIFAR-10 Machine Unlearning",
156
  description="Load a pre-trained ResNet18 and unlearn a specific index from the CIFAR-10 training set."
157
  )
158
+
159
  if __name__ == "__main__":
160
+ demo.launch()