import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import torch.optim as optim

# CIFAR-10 labels
cifar10_classes = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck'
]

# Define transformations with proper normalization for 3 channels
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 datasets
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)


def predict(model, image_tensor):
    """
    Performs a forward pass through the model and computes softmax
    probabilities and the predicted class using a numerically stable approach.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor.unsqueeze(0))
        logits = outputs[0]

        # Numerically stable softmax: subtract the max logit before
        # exponentiating (see the self-test below)
        max_logit = logits.max()
        stable_logits = logits - max_logit
        exp_logits = torch.exp(stable_logits)
        probs = exp_logits / exp_logits.sum()

        # Check for numerical issues (NaN probabilities)
        if torch.isnan(probs).any():
            print("โš ๏ธ Warning: NaN detected in prediction probabilities")
            probs = torch.zeros_like(probs)

        pred = torch.argmax(probs).item()
    return probs, pred


def unlearn(model, image_tensor, label_idx, learning_rate, steps=10):
    """
    Performs targeted unlearning by updating only the final fully connected
    layer. The negated cross-entropy loss drives the model's confidence for
    the target class down (gradient ascent on the loss).
    """
    model.train()

    # Freeze all layers except the final fully connected layer (fc)
    for name, param in model.named_parameters():
        if "fc" not in name:
            param.requires_grad = False

    # Keep BatchNorm layers in evaluation mode so running stats are not updated
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate
    )
    device = image_tensor.device
    label_tensor = torch.tensor([label_idx], device=device)

    for i in range(steps):
        output = model(image_tensor.unsqueeze(0))
        # Negate the loss to *reduce* confidence on the target label
        loss = -criterion(output, label_tensor)

        if torch.isnan(loss):
            print(f"โŒ NaN detected in loss at step {i}. Stopping unlearning.")
            break

        print(f"๐Ÿง  Step {i+1}/{steps} - Unlearning Loss: {loss.item():.4f}")
        optimizer.zero_grad()
        loss.backward()
        # Clip gradients to maintain stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
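

# The stable softmax in predict() is the standard "subtract the max logit"
# trick. As a quick sanity check (a hypothetical helper, not part of the app),
# it should agree with torch.softmax, which applies the same stabilization
# internally; the large logits below would overflow a naive float32 exp():
def _softmax_stability_check():
    logits = torch.tensor([1000.0, 1001.0, 999.0])
    exp_shifted = torch.exp(logits - logits.max())  # finite: exponents are <= 0
    manual_probs = exp_shifted / exp_shifted.sum()
    assert torch.allclose(manual_probs, torch.softmax(logits, dim=0))
    return manual_probs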
""" model.eval() total, correct, loss_total = 0, 0, 0.0 criterion = nn.CrossEntropyLoss() with torch.no_grad(): for images, labels in testloader: outputs = model(images) _, preds = torch.max(outputs, 1) loss = criterion(outputs, labels) total += labels.size(0) correct += (preds == labels).sum().item() loss_total += loss.item() * labels.size(0) accuracy = round(100 * correct / total, 2) avg_loss = round(loss_total / total, 4) return accuracy, avg_loss def run_unlearning(index_to_unlearn, learning_rate): """ Loads a pre-trained ResNet18 model, performs unlearning on a single training example, and compares model performance before and after unlearning. """ # Set device (GPU if available, else CPU) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load the original pre-trained model and adjust for 10 classes original_model = models.resnet18(weights=None) original_model.fc = nn.Linear(original_model.fc.in_features, 10) original_model.load_state_dict(torch.load("resnet18.pth", map_location=device)) original_model.to(device) original_model.eval() # Duplicate the model for the unlearning experiment unlearned_model = models.resnet18(weights=None) unlearned_model.fc = nn.Linear(unlearned_model.fc.in_features, 10) unlearned_model.load_state_dict(torch.load("resnet18.pth", map_location=device)) unlearned_model.to(device) # Get the sample to unlearn from the training set image_tensor, label_idx = trainset[index_to_unlearn] image_tensor = image_tensor.to(device) label_name = cifar10_classes[label_idx] print(f"๐Ÿ—‚๏ธ Actual Label Index: {label_idx} | Label Name: {label_name}") # Prediction before unlearning probs_before, pred_before = predict(original_model, image_tensor) conf_actual_before = probs_before[label_idx].item() conf_pred_before = probs_before[pred_before].item() # Perform the unlearning process on the duplicated model unlearn(unlearned_model, image_tensor, label_idx, learning_rate) # Prediction after unlearning probs_after, pred_after = predict(unlearned_model, image_tensor) conf_actual_after = probs_after[label_idx].item() conf_pred_after = probs_after[pred_after].item() print("Post-unlearning probabilities:", probs_after) # Evaluate the full test set performance for both models orig_acc, orig_loss = evaluate_model(original_model, testloader) unlearn_acc, unlearn_loss = evaluate_model(unlearned_model, testloader) result = f""" ๐Ÿ“ Index Unlearned: {index_to_unlearn} ๐Ÿ—‚๏ธ Actual Label: {label_name} (Index: {label_idx}) ๐Ÿ”Ž BEFORE Unlearning: - Predicted Class: {cifar10_classes[pred_before]} with confidence: {conf_pred_before:.10f} - Actual Class: {label_name} with confidence: {conf_actual_before:.10f} ๐Ÿงฝ AFTER Unlearning: - Predicted Class: {cifar10_classes[pred_after]} with confidence: {conf_pred_after:.10f} - Actual Class: {label_name} with confidence: {conf_actual_after:.10f} ๐Ÿ“‰ Confidence Drop (Actual Class): {conf_actual_before - conf_actual_after:.6f} ๐Ÿงช Test Set Performance: - Original Model: {orig_acc:.2f}% accuracy, Loss: {orig_loss:.4f} - Unlearned Model: {unlearn_acc:.2f}% accuracy, Loss: {unlearn_loss:.4f} """ return result # Gradio interface for interactive unlearning demonstration demo = gr.Interface( fn=run_unlearning, inputs=[ gr.Slider(0, len(trainset)-1, step=1, label="Select Index to Unlearn"), gr.Slider(0.0001, 0.01, step=0.0001, value=0.005, label="Learning Rate (for Unlearning)") ], outputs="text", title="๐Ÿ” CIFAR-10 Machine Unlearning", description="Load a pre-trained ResNet18 and unlearn a specific index from the CIFAR-10 training 
set." ) if __name__ == "__main__": demo.launch()