import copy

import gradio as gr
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models
# CIFAR-10 labels
cifar10_classes = [
'airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck'
]
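# Index i in this list matches the integer label i returned by torchvision's CIFAR10 dataset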
# Define transformations with proper normalization for 3 channels
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
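# Note: a per-channel mean/std of 0.5 rescales pixel values from [0, 1] to [-1, 1]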
# Load CIFAR-10 datasets
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform
)
testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
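# shuffle=False keeps evaluation order deterministic, so before/after comparisons are like-for-like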
def predict(model, image_tensor):
"""
Performs a forward pass through the model and computes softmax probabilities
and predicted class using a numerically stable approach.
"""
model.eval()
with torch.no_grad():
outputs = model(image_tensor.unsqueeze(0))
logits = outputs[0]
# Use a numerically stable softmax: subtract max logit
max_logit = logits.max()
stable_logits = logits - max_logit
exp_logits = torch.exp(stable_logits)
probs = exp_logits / exp_logits.sum()
        # Guard against numerical issues: if any probability is NaN, warn and zero out the distribution
        if torch.isnan(probs).any():
            print("⚠️ Warning: NaN detected in prediction probabilities")
            probs = torch.zeros_like(probs)
pred = torch.argmax(probs).item()
return probs, pred
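# Note: the manual computation in predict() is equivalent to torch.softmax(logits, dim=0),
# which is also numerically stable; the max-subtraction trick is spelled out here for clarity.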
def unlearn(model, image_tensor, label_idx, learning_rate, steps=10):
"""
Performs targeted unlearning by updating only the final fully connected layer.
The negative cross-entropy loss drives the confidence for the target class down.
"""
model.train()
# Freeze all layers except the final fully connected layer (fc)
for name, param in model.named_parameters():
if "fc" not in name:
param.requires_grad = False
# Set BatchNorm layers to evaluation mode to avoid updating running stats
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)
device = image_tensor.device
label_tensor = torch.tensor([label_idx], device=device)
for i in range(steps):
output = model(image_tensor.unsqueeze(0))
# Negative loss to reduce confidence on the target label
loss = -criterion(output, label_tensor)
if torch.isnan(loss):
print(f"❌ NaN detected in loss at step {i}. Stopping unlearning.")
break
print(f"🧠 Step {i+1}/{steps} - Unlearning Loss: {loss.item():.4f}")
optimizer.zero_grad()
loss.backward()
# Clip gradients to maintain stability
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
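# In effect, unlearn() performs gradient ascent on the sample's cross-entropy loss, restricted
# to the final layer, so the shared feature extractor is untouched and overall test accuracy
# should change only slightly.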
def evaluate_model(model, testloader):
"""
Evaluates the model's accuracy and average loss on the test set.
"""
    model.eval()
    # Move each batch to the model's device; the DataLoader yields CPU tensors
    device = next(model.parameters()).device
    total, correct, loss_total = 0, 0, 0.0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
total += labels.size(0)
correct += (preds == labels).sum().item()
loss_total += loss.item() * labels.size(0)
accuracy = round(100 * correct / total, 2)
avg_loss = round(loss_total / total, 4)
return accuracy, avg_loss
def run_unlearning(index_to_unlearn, learning_rate):
"""
Loads a pre-trained ResNet18 model, performs unlearning on a single training example,
and compares model performance before and after unlearning.
"""
# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the original pre-trained model and adjust for 10 classes
original_model = models.resnet18(weights=None)
original_model.fc = nn.Linear(original_model.fc.in_features, 10)
original_model.load_state_dict(torch.load("resnet18.pth", map_location=device))
original_model.to(device)
original_model.eval()
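    # original_model stays untouched throughout and serves as the pre-unlearning baseline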
    # Duplicate the model for the unlearning experiment (a deep copy avoids a second read from disk)
    unlearned_model = copy.deepcopy(original_model)
# Get the sample to unlearn from the training set
image_tensor, label_idx = trainset[index_to_unlearn]
image_tensor = image_tensor.to(device)
label_name = cifar10_classes[label_idx]
print(f"πŸ—‚οΈ Actual Label Index: {label_idx} | Label Name: {label_name}")
# Prediction before unlearning
probs_before, pred_before = predict(original_model, image_tensor)
conf_actual_before = probs_before[label_idx].item()
conf_pred_before = probs_before[pred_before].item()
# Perform the unlearning process on the duplicated model
unlearn(unlearned_model, image_tensor, label_idx, learning_rate)
# Prediction after unlearning
probs_after, pred_after = predict(unlearned_model, image_tensor)
conf_actual_after = probs_after[label_idx].item()
conf_pred_after = probs_after[pred_after].item()
print("Post-unlearning probabilities:", probs_after)
# Evaluate the full test set performance for both models
orig_acc, orig_loss = evaluate_model(original_model, testloader)
unlearn_acc, unlearn_loss = evaluate_model(unlearned_model, testloader)
    result = f"""
📍 Index Unlearned: {index_to_unlearn}
🗂️ Actual Label: {label_name} (Index: {label_idx})
🔎 BEFORE Unlearning:
- Predicted Class: {cifar10_classes[pred_before]} with confidence: {conf_pred_before:.10f}
- Actual Class: {label_name} with confidence: {conf_actual_before:.10f}
🧽 AFTER Unlearning:
- Predicted Class: {cifar10_classes[pred_after]} with confidence: {conf_pred_after:.10f}
- Actual Class: {label_name} with confidence: {conf_actual_after:.10f}
📉 Confidence Drop (Actual Class): {conf_actual_before - conf_actual_after:.6f}
🧪 Test Set Performance:
- Original Model: {orig_acc:.2f}% accuracy, Loss: {orig_loss:.4f}
- Unlearned Model: {unlearn_acc:.2f}% accuracy, Loss: {unlearn_loss:.4f}
"""
return result
# Gradio interface for interactive unlearning demonstration
demo = gr.Interface(
fn=run_unlearning,
inputs=[
gr.Slider(0, len(trainset)-1, step=1, label="Select Index to Unlearn"),
gr.Slider(0.0001, 0.01, step=0.0001, value=0.005, label="Learning Rate (for Unlearning)")
],
outputs="text",
title="πŸ” CIFAR-10 Machine Unlearning",
description="Load a pre-trained ResNet18 and unlearn a specific index from the CIFAR-10 training set."
)
if __name__ == "__main__":
demo.launch()