File size: 7,312 Bytes
e607b43
 
 
 
 
 
 
 
c414588
e607b43
09f5386
 
 
 
c414588
09f5386
e607b43
 
 
c414588
e607b43
c414588
 
09f5386
 
 
 
 
 
e607b43
c414588
e607b43
09f5386
 
 
 
e607b43
 
 
09f5386
 
 
 
 
 
 
 
d1dfa52
 
 
e607b43
 
c414588
e94fc84
c414588
09f5386
 
c414588
e607b43
c414588
 
 
 
 
09f5386
e607b43
 
 
c414588
e607b43
c414588
09f5386
c414588
 
 
d1dfa52
e607b43
09f5386
 
 
d1dfa52
 
 
c414588
09f5386
e607b43
 
09f5386
c414588
e607b43
c414588
e607b43
09f5386
 
 
e607b43
 
 
09f5386
e607b43
 
 
 
 
 
 
 
09f5386
 
 
 
c414588
e607b43
09f5386
 
 
 
 
c414588
09f5386
 
e607b43
 
c414588
 
e607b43
09f5386
 
e607b43
 
c414588
 
09f5386
c414588
e607b43
c414588
e607b43
d1dfa52
09f5386
c414588
e607b43
0a49c8e
 
09f5386
 
e607b43
09f5386
c414588
e607b43
0a49c8e
 
09f5386
 
 
 
e607b43
 
09f5386
e607b43
d1dfa52
 
0a49c8e
e607b43
0a49c8e
 
 
e607b43
0a49c8e
 
 
 
 
e607b43
c414588
 
e607b43
a9ea656
c414588
09f5386
e607b43
 
 
09f5386
cfe0422
e607b43
 
 
 
 
c414588
e607b43
c414588
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import gradio as gr
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import torch.nn as nn
import torch.optim as optim
from PIL import Image

# CIFAR-10 class names; position in the list equals the integer label index.
cifar10_classes = (
    "airplane automobile bird cat deer dog frog horse ship truck"
).split()

# Preprocessing pipeline: resize to 32x32, convert to a tensor, then normalize
# each of the 3 channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Build the CIFAR-10 training and test splits under ./data (download=True lets
# torchvision fetch them if missing), both using the shared `transform`.
# Only the test split gets a DataLoader; the training split is indexed directly.
trainset, testset = (
    torchvision.datasets.CIFAR10(
        root='./data', train=is_train, download=True, transform=transform
    )
    for is_train in (True, False)
)
testloader = torch.utils.data.DataLoader(
    dataset=testset, batch_size=64, shuffle=False
)

def predict(model, image_tensor):
    """
    Run a single sample through ``model`` and return its class probabilities.

    Args:
        model: classifier; it is switched to eval mode and called on a batch
            of one without gradient tracking.
        image_tensor: one input sample without a batch dimension (a batch axis
            is added with ``unsqueeze(0)``); must live on the model's device.

    Returns:
        (probs, pred): ``probs`` is the softmax of the first row of the model
        output; ``pred`` is the index of its largest entry as a Python int.
        If ``probs`` contains NaN, a warning is printed and ``probs`` is
        replaced by zeros, so ``pred`` becomes 0.
    """
    model.eval()
    with torch.no_grad():
        logits = model(image_tensor.unsqueeze(0))[0]
        # torch.softmax is computed stably for large-magnitude logits, so no
        # manual max-subtraction is needed.
        probs = torch.softmax(logits, dim=-1)

        # Only NaN is detected here (e.g. NaN inputs/weights); report it and
        # fall back to an all-zero distribution.
        if torch.isnan(probs).any():
            print("⚠️ Warning: NaN detected in prediction probabilities")
            probs = torch.zeros_like(probs)
        pred = torch.argmax(probs).item()
    return probs, pred

def unlearn(model, image_tensor, label_idx, learning_rate, steps=10):
    """
    Lower the model's confidence in ``label_idx`` for a single sample, in place.

    Only parameters whose name contains ``"fc"`` (and that currently require
    grad) are optimized; for ResNet that is the final fully connected layer.
    All other parameters are frozen for the duration of the call, and every
    parameter's ``requires_grad`` flag is restored on exit, even on error.
    The optimizer (Adam) minimizes the *negated* cross-entropy on the target
    label, i.e. performs gradient ascent on it. Gradients are clipped to norm 1.0.

    Args:
        model: classifier updated in place; left in eval mode afterwards.
        image_tensor: one sample without a batch dimension, on the model's device.
        label_idx: integer class index whose confidence should drop.
        learning_rate: Adam learning rate.
        steps: maximum number of optimization steps; the loop stops early if
            the loss becomes NaN or infinite.

    Raises:
        ValueError: propagated from Adam when no parameter is trainable.
    """
    # Remember the caller's flags so the freezing below does not leak out.
    saved_flags = {name: p.requires_grad for name, p in model.named_parameters()}
    try:
        trainable = []
        for name, param in model.named_parameters():
            if "fc" not in name:
                param.requires_grad = False
            elif param.requires_grad:
                trainable.append(param)

        # eval() keeps normalization running statistics and any dropout fixed
        # for the whole network; autograd still flows to the trainable head.
        model.eval()

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(trainable, lr=learning_rate)

        batch = image_tensor.unsqueeze(0)
        label_tensor = torch.tensor([label_idx], device=image_tensor.device)

        for i in range(steps):
            output = model(batch)
            # Negated loss: minimizing it pushes the target-label confidence down.
            loss = -criterion(output, label_tensor)

            # Gradient ascent on cross-entropy can diverge; stop on NaN or inf.
            if not torch.isfinite(loss):
                print(f"❌ Non-finite loss ({loss.item()}) detected at step {i}. Stopping unlearning.")
                break

            print(f"🧠 Step {i+1}/{steps} - Unlearning Loss: {loss.item():.4f}")
            optimizer.zero_grad()
            loss.backward()
            # Clip only what is being optimized; frozen params may hold stale grads.
            torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)
            optimizer.step()
    finally:
        for name, param in model.named_parameters():
            param.requires_grad = saved_flags[name]

def evaluate_model(model, testloader):
    """
    Compute a classifier's accuracy and mean cross-entropy loss over a dataset.

    Args:
        model: classifier whose outputs are scored with ``argmax(dim=1)`` and
            ``nn.CrossEntropyLoss`` against the labels.
        testloader: iterable of ``(images, labels)`` batches (e.g. a DataLoader).
            Each batch is moved to the device of the model's parameters, so a
            CPU loader works with a GPU model.

    Returns:
        (accuracy, avg_loss): accuracy in percent rounded to 2 decimals, and
        the per-sample average loss rounded to 4 decimals.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    model.eval()
    # Follow the model's device; a parameter-free model leaves inputs untouched.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # Summed (not averaged) per batch so the total divides cleanly by sample count.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    total, correct, loss_total = 0, 0, 0.0

    with torch.no_grad():
        for images, labels in testloader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            loss_total += criterion(outputs, labels).item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate_model: testloader produced no samples")

    accuracy = round(100 * correct / total, 2)
    avg_loss = round(loss_total / total, 4)
    return accuracy, avg_loss

def _build_resnet18(state_dict, device):
    """Build a ResNet18 with a CIFAR-10-sized head, load ``state_dict`` and move it to ``device``."""
    model = models.resnet18(weights=None)
    # Swap the default classification head for one with one output per CIFAR-10 class.
    model.fc = nn.Linear(model.fc.in_features, len(cifar10_classes))
    # load_state_dict copies values into this model's own parameters, so
    # models built from the same state_dict do not share storage.
    model.load_state_dict(state_dict)
    return model.to(device)


def run_unlearning(index_to_unlearn, learning_rate):
    """
    Unlearn one CIFAR-10 training sample and report the effect.

    Loads the ``resnet18.pth`` checkpoint into two identical ResNet18 models,
    runs ``unlearn`` on one of them for training example ``index_to_unlearn``,
    then compares the sample's predicted/actual-class confidences and the
    test-set accuracy/loss of both models.

    Args:
        index_to_unlearn: index into ``trainset``. It may arrive as a float
            (e.g. from a Gradio slider) and is converted with ``int()``.
        learning_rate: learning rate forwarded to ``unlearn``.

    Returns:
        A multi-line, human-readable report string.
    """
    # Dataset indexing needs a real int; UI sliders can hand over floats.
    index = int(index_to_unlearn)

    # Set device (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Read the checkpoint once and build the reference and experiment models from it.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load("resnet18.pth", map_location=device)
    original_model = _build_resnet18(state_dict, device)
    original_model.eval()
    unlearned_model = _build_resnet18(state_dict, device)

    # Get the sample to unlearn from the training set
    image_tensor, label_idx = trainset[index]
    image_tensor = image_tensor.to(device)
    label_name = cifar10_classes[label_idx]
    print(f"🗂️ Actual Label Index: {label_idx} | Label Name: {label_name}")

    # Prediction before unlearning
    probs_before, pred_before = predict(original_model, image_tensor)
    conf_actual_before = probs_before[label_idx].item()
    conf_pred_before = probs_before[pred_before].item()

    # Perform the unlearning process on the second model only
    unlearn(unlearned_model, image_tensor, label_idx, learning_rate)

    # Prediction after unlearning
    probs_after, pred_after = predict(unlearned_model, image_tensor)
    conf_actual_after = probs_after[label_idx].item()
    conf_pred_after = probs_after[pred_after].item()

    print("Post-unlearning probabilities:", probs_after)

    # Evaluate the full test set performance for both models
    orig_acc, orig_loss = evaluate_model(original_model, testloader)
    unlearn_acc, unlearn_loss = evaluate_model(unlearned_model, testloader)

    result = f"""
📍 Index Unlearned: {index}
🗂️ Actual Label: {label_name} (Index: {label_idx})

🔎 BEFORE Unlearning:
- Predicted Class: {cifar10_classes[pred_before]} with confidence: {conf_pred_before:.10f}
- Actual Class: {label_name} with confidence: {conf_actual_before:.10f}

🧽 AFTER Unlearning:
- Predicted Class: {cifar10_classes[pred_after]} with confidence: {conf_pred_after:.10f}
- Actual Class: {label_name} with confidence: {conf_actual_after:.10f}

📉 Confidence Drop (Actual Class): {conf_actual_before - conf_actual_after:.6f}

🧪 Test Set Performance:
- Original Model: {orig_acc:.2f}% accuracy, Loss: {orig_loss:.4f}
- Unlearned Model: {unlearn_acc:.2f}% accuracy, Loss: {unlearn_loss:.4f}
"""
    return result

# Interactive front end: the two sliders map positionally onto
# run_unlearning(index_to_unlearn, learning_rate); the report is shown as text.
# NOTE(review): the title's leading characters look like a mis-encoded emoji;
# they are intentionally kept byte-for-byte here.
demo = gr.Interface(
    fn=run_unlearning,
    inputs=[
        gr.Slider(
            minimum=0,
            maximum=len(trainset) - 1,
            step=1,
            label="Select Index to Unlearn",
        ),
        gr.Slider(
            minimum=0.0001,
            maximum=0.01,
            step=0.0001,
            value=0.005,
            label="Learning Rate (for Unlearning)",
        ),
    ],
    outputs="text",
    title="πŸ” CIFAR-10 Machine Unlearning",
    description=(
        "Load a pre-trained ResNet18 and unlearn a specific index "
        "from the CIFAR-10 training set."
    ),
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()