|
import gradio as gr |
|
import torch |
|
import torchvision |
|
import torchvision.transforms as transforms |
|
from torchvision import models |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from PIL import Image |
|
|
|
|
|
# Human-readable CIFAR-10 class names. The list is indexed with the integer
# class label elsewhere in this file, so the order matters.
cifar10_classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
|
|
|
|
|
# Preprocessing applied to every dataset image: resize to 32x32, convert to a
# tensor, then normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
|
|
|
|
|
# CIFAR-10 training split, stored under ./data with downloading enabled.
# Individual samples are picked from here by index for unlearning.
trainset = torchvision.datasets.CIFAR10(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)

# CIFAR-10 test split, used to measure accuracy and loss.
testset = torchvision.datasets.CIFAR10(
    root="./data",
    train=False,
    transform=transform,
    download=True,
)

# Fixed-order (unshuffled) batches of 64 test images.
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=64,
    shuffle=False,
)
|
|
|
def predict(model, image_tensor):
    """Classify a single image and return its class probabilities.

    The model is switched to evaluation mode (and left in it) and is run
    without gradient tracking on a batch that contains only ``image_tensor``.

    Args:
        model: Callable module that maps a batch of images to a batch of
            logits. The first row of its output is expected to be a 1-D
            tensor holding one score per class.
        image_tensor: A single, un-batched image tensor; a leading batch
            dimension of size one is added before the forward pass.

    Returns:
        A ``(probs, pred)`` tuple. ``probs`` is the softmax of the logits,
        or an all-zero tensor of the same shape if any probability came out
        as NaN. ``pred`` is the index of the largest entry of ``probs`` as
        a Python ``int`` (``0`` when ``probs`` was zeroed).
    """
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.unsqueeze(0))[0]

        # torch.softmax already shifts by the maximum logit internally, so
        # it is numerically stable without a hand-written max subtraction.
        probs = torch.softmax(logits, dim=0)

        # NaN logits (e.g. from diverged weights) propagate through softmax;
        # report them and fall back to zeros instead of returning NaNs.
        if torch.isnan(probs).any():
            print("β οΈ Warning: NaN detected in prediction probabilities")
            probs = torch.zeros_like(probs)

        pred = torch.argmax(probs).item()
    return probs, pred
|
|
|
def unlearn(model, image_tensor, label_idx, learning_rate, steps=10):
    """Lower the model's confidence in ``label_idx`` for one image.

    Only parameters whose name contains ``"fc"`` are left trainable; every
    other parameter has ``requires_grad`` set to ``False``. The objective is
    the negated cross-entropy, so each Adam step increases the
    cross-entropy of the target class and therefore drives its probability
    down.

    Side effects: the model is modified in place, is left in training mode
    with its batch-norm layers in eval mode, and the frozen parameters stay
    frozen after the call. One progress line is printed per step.

    Args:
        model: Module to update in place.
        image_tensor: Single, un-batched image tensor; a batch dimension of
            size one is added before it is fed to the model.
        label_idx: Integer class index whose confidence should be reduced.
        learning_rate: Learning rate passed to the Adam optimizer.
        steps: Maximum number of optimization steps.

    Returns:
        None.
    """
    model.train()

    # Freeze everything except the final fully connected layer.
    for name, param in model.named_parameters():
        if "fc" not in name:
            param.requires_grad = False

    # Keep every batch-norm variant in eval mode: with a batch of one,
    # training-mode batch norm would overwrite the running statistics (and
    # some variants refuse a single sample outright).
    for module in model.modules():
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.eval()

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(trainable_params, lr=learning_rate)

    batch = image_tensor.unsqueeze(0)
    label_tensor = torch.tensor([label_idx], device=image_tensor.device)

    for i in range(steps):
        output = model(batch)

        # Negated so that minimizing it maximizes the cross-entropy.
        loss = -criterion(output, label_tensor)

        # The negated loss is unbounded below, so guard against +/-inf as
        # well as NaN before back-propagating through it.
        if not torch.isfinite(loss):
            print(f"β NaN detected in loss at step {i}. Stopping unlearning.")
            break

        print(f"π§  Step {i+1}/{steps} - Unlearning Loss: {loss.item():.4f}")
        optimizer.zero_grad()
        loss.backward()

        # Clip only the parameters being optimized so that stale gradients
        # on frozen parameters cannot influence the clipping coefficient.
        torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
        optimizer.step()
|
|
|
def evaluate_model(model, testloader):
    """Compute accuracy and mean cross-entropy loss over a data loader.

    The model is switched to evaluation mode (and left in it) and is run
    without gradient tracking. Each batch is moved to the device that holds
    the model's parameters before the forward pass, so a model on an
    accelerator can be evaluated with a loader that yields CPU tensors.

    Args:
        model: Module that maps a batch of images to per-class logits.
        testloader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        An ``(accuracy, avg_loss)`` tuple: accuracy as a percentage rounded
        to two decimals, and the per-sample average cross-entropy rounded
        to four decimals. ``(0.0, 0.0)`` is returned when the loader yields
        no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()

    # Use the location of the first parameter as the model's device; a
    # parameter-free model gets its batches unchanged.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total, correct, loss_total = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in testloader:
            if device is not None:
                images = images.to(device)
                labels = labels.to(device)

            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            batch_size = labels.size(0)

            # The criterion returns the batch mean; weight it by the batch
            # size so a smaller final batch does not skew the average.
            loss_total += criterion(outputs, labels).item() * batch_size
            correct += (preds == labels).sum().item()
            total += batch_size

    # An empty loader would otherwise cause a division by zero.
    if total == 0:
        return 0.0, 0.0

    accuracy = round(100 * correct / total, 2)
    avg_loss = round(loss_total / total, 4)
    return accuracy, avg_loss
|
|
|
def _load_resnet18(state_dict, device):
    """Build a 10-class ResNet18, load ``state_dict`` into it, move it to ``device``."""
    net = models.resnet18(weights=None)
    net.fc = nn.Linear(net.fc.in_features, 10)
    net.load_state_dict(state_dict)
    return net.to(device)


def run_unlearning(index_to_unlearn, learning_rate):
    """Unlearn one training example and report the effect.

    Two ResNet18 models are built from the ``resnet18.pth`` checkpoint: one
    is kept as the untouched reference and the other is passed to
    ``unlearn`` together with the selected training sample. Predictions for
    that sample and test-set metrics are then collected for both models.

    Args:
        index_to_unlearn: Position of the sample in ``trainset``. It is
            converted with ``int()`` because UI sliders may deliver the
            value as a float.
        learning_rate: Learning rate forwarded to ``unlearn``.

    Returns:
        A multi-line text report with the sample's label, the predicted and
        actual-class confidences before and after unlearning, the confidence
        drop for the actual class, and test-set accuracy/loss of both models.
    """
    # Dataset indexing needs an integer index.
    index_to_unlearn = int(index_to_unlearn)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Read the checkpoint from disk once; load_state_dict copies the values
    # into each model, so both models can be built from the same dict.
    # NOTE(review): torch.load unpickles the file - only use checkpoints
    # from a trusted source.
    state_dict = torch.load("resnet18.pth", map_location=device)

    # Reference model: never modified, used for the "before" numbers.
    original_model = _load_resnet18(state_dict, device)
    original_model.eval()

    # Working copy that will be changed by unlearning.
    unlearned_model = _load_resnet18(state_dict, device)

    # Sample to forget.
    image_tensor, label_idx = trainset[index_to_unlearn]
    image_tensor = image_tensor.to(device)
    label_name = cifar10_classes[label_idx]
    print(f"ποΈ Actual Label Index: {label_idx} | Label Name: {label_name}")

    # Confidence before unlearning.
    probs_before, pred_before = predict(original_model, image_tensor)
    conf_actual_before = probs_before[label_idx].item()
    conf_pred_before = probs_before[pred_before].item()

    unlearn(unlearned_model, image_tensor, label_idx, learning_rate)

    # Confidence after unlearning.
    probs_after, pred_after = predict(unlearned_model, image_tensor)
    conf_actual_after = probs_after[label_idx].item()
    conf_pred_after = probs_after[pred_after].item()

    print("Post-unlearning probabilities:", probs_after)

    # Overall test-set performance of both models.
    orig_acc, orig_loss = evaluate_model(original_model, testloader)
    unlearn_acc, unlearn_loss = evaluate_model(unlearned_model, testloader)

    result = f"""
π Index Unlearned: {index_to_unlearn}
ποΈ Actual Label: {label_name} (Index: {label_idx})

π BEFORE Unlearning:
- Predicted Class: {cifar10_classes[pred_before]} with confidence: {conf_pred_before:.10f}
- Actual Class: {label_name} with confidence: {conf_actual_before:.10f}

π§½ AFTER Unlearning:
- Predicted Class: {cifar10_classes[pred_after]} with confidence: {conf_pred_after:.10f}
- Actual Class: {label_name} with confidence: {conf_actual_after:.10f}

π Confidence Drop (Actual Class): {conf_actual_before - conf_actual_after:.6f}

π§ͺ Test Set Performance:
- Original Model: {orig_acc:.2f}% accuracy, Loss: {orig_loss:.4f}
- Unlearned Model: {unlearn_acc:.2f}% accuracy, Loss: {unlearn_loss:.4f}
"""
    return result
|
|
|
|
|
# Web UI: two sliders (training-set index and learning rate) feed
# run_unlearning, whose text report is shown as the output.
demo = gr.Interface(
    fn=run_unlearning,
    inputs=[
        gr.Slider(
            minimum=0,
            maximum=len(trainset) - 1,
            step=1,
            label="Select Index to Unlearn",
        ),
        gr.Slider(
            minimum=0.0001,
            maximum=0.01,
            value=0.005,
            step=0.0001,
            label="Learning Rate (for Unlearning)",
        ),
    ],
    outputs="text",
    title="π CIFAR-10 Machine Unlearning",
    description="Load a pre-trained ResNet18 and unlearn a specific index from the CIFAR-10 training set.",
)


# Start the interface only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
|
|