Spaces:
Running
Running
import json
import os
import ssl
import urllib
import urllib.request

import certifi
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image, ImageEnhance
from torchvision import models
# Use certifi's CA bundle for HTTPS requests made through urllib, such as the
# class-index download below.
ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())

# Set device: the model and all input tensors are placed here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Number of classes for the optional fine-tuned head (see commented code below).
num_classes = 6

# --- ResNet ImageNet labels ---
url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
# Bounded timeout so an unreachable host fails at startup instead of hanging.
with urllib.request.urlopen(url, timeout=30) as f:
    imagenet_classes = json.load(f)
# Each value is a list whose second item is the label, giving
# {0: "tench", 1: "goldfish", ...}. NOTE: despite its name this mapping holds
# the ImageNet labels, not CIFAR-10 ones; the name is kept for compatibility.
cifar10_classes = {int(k): v[1] for k, v in imagenet_classes.items()}

# Load the pre-trained ResNet-18. `weights=` supersedes the deprecated
# `pretrained=True`; torchvision releases without the Weights enums fall back.
try:
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
except AttributeError:
    model = models.resnet18(pretrained=True)
# Modify the classifier for 6 classes with an additional hidden layer
# model.fc = nn.Sequential(
#     nn.Linear(model.fc.in_features, 512),
#     nn.ReLU(),
#     nn.Linear(512, num_classes)
# )
# Load trained weights
# model.load_state_dict(torch.load('model_old.pth', map_location=torch.device('cpu')))

# transform_image moves every input to `device`, so the model must live there
# too; otherwise CUDA hosts fail with a CPU/GPU device mismatch.
model = model.to(device)
model.eval()

# Class labels, indexed by model output position.
# class_labels = ['bird', 'cat', 'deer', 'dog', 'frog', 'horse']
class_labels = [cifar10_classes[i] for i in range(len(cifar10_classes))]
class MultiLayerGradCAM:
    """Grad-CAM heatmaps averaged over one or more named layers of a model.

    Forward/backward hooks on each target layer record the layer's output
    activations and the gradient of the chosen class score with respect to
    them. The hooks only record while ``generate`` is running, so ordinary
    forward passes (inference, adversarial optimisation) do not accumulate
    tensors in ``activations`` / ``gradients``.
    """

    def __init__(self, model, target_layers=None):
        """Register hooks on ``target_layers`` (default ``['layer4']``).

        Args:
            model: the ``nn.Module`` to explain.
            target_layers: module names as yielded by ``model.named_modules()``.

        Raises:
            ValueError: if a target layer name does not exist in ``model``.
        """
        self.model = model
        self.target_layers = target_layers if target_layers else ['layer4']
        self.activations = []
        self.gradients = []
        self.handles = []
        self._recording = False
        # dict.fromkeys drops duplicate names (keeping order) so no layer is
        # hooked twice; _find_layer raises for unknown names instead of
        # silently producing an empty heatmap later.
        for layer_name in dict.fromkeys(self.target_layers):
            module = self._find_layer(layer_name)
            self.handles.append(module.register_forward_hook(self._forward_hook))
            # register_backward_hook is deprecated; the "full" variant is the
            # supported API and reports grad_output for the module's output.
            self.handles.append(module.register_full_backward_hook(self._backward_hook))

    def _forward_hook(self, module, input, output):
        """Store the layer output (detached) while ``generate`` is running."""
        if self._recording:
            self.activations.append(output.detach())

    def _backward_hook(self, module, grad_input, grad_output):
        """Store d(score)/d(layer output) (detached) while ``generate`` runs."""
        if self._recording:
            self.gradients.append(grad_output[0].detach())

    def _find_layer(self, layer_name):
        """Return the submodule called ``layer_name``; raise ValueError if absent."""
        for name, module in self.model.named_modules():
            if name == layer_name:
                return module
        raise ValueError(f"Layer {layer_name} not found in model")

    def generate(self, input_tensor, target_class=None):
        """Compute a Grad-CAM heatmap for the first sample of ``input_tensor``.

        Args:
            input_tensor: 4-D model input batch (N, C, H, W); only sample 0's
                class score is back-propagated (callers here pass N == 1).
            target_class: class index to explain; defaults to sample 0's
                argmax prediction.

        Returns:
            ``(heatmap, class_index)`` where ``heatmap`` is an (H, W) numpy
            array min-max scaled to [0, 1].

        Raises:
            RuntimeError: if the hooks did not capture matching activations
                and gradients.
        """
        device = next(self.model.parameters()).device
        # Start from a clean slate so tensors left over from earlier passes
        # can never be paired with this call's gradients.
        self.activations.clear()
        self.gradients.clear()
        self.model.zero_grad()

        self._recording = True
        try:
            output = self.model(input_tensor.to(device))
            if target_class is None:
                pred_class = output[0].argmax().item()
            else:
                pred_class = target_class
            # Back-propagate only the chosen class score of sample 0.
            one_hot = torch.zeros_like(output)
            one_hot[0][pred_class] = 1
            output.backward(gradient=one_hot)
        finally:
            self._recording = False

        if not self.activations or len(self.activations) != len(self.gradients):
            raise RuntimeError("Grad-CAM hooks did not capture matching activations and gradients")

        heatmaps = []
        # Activations arrive in forward order and gradients in backward
        # (reverse) order, hence reversed() to pair them layer by layer.
        for act, grad in zip(self.activations, reversed(self.gradients)):
            # Channel weights: spatially averaged gradients.
            weights = F.adaptive_avg_pool2d(grad, 1)
            cam = F.relu((act * weights).sum(dim=1, keepdim=True))
            # Upsample to the input's spatial size.
            cam = F.interpolate(cam, size=input_tensor.shape[2:],
                                mode='bilinear', align_corners=False)
            heatmaps.append(cam.squeeze().cpu().numpy())

        # Combine heatmaps from the different layers, then normalise to [0, 1].
        combined_heatmap = np.mean(heatmaps, axis=0)
        combined_heatmap = np.maximum(combined_heatmap, 0)
        combined_heatmap = (combined_heatmap - combined_heatmap.min()) / \
            (combined_heatmap.max() - combined_heatmap.min() + 1e-10)
        return combined_heatmap, pred_class

    def __del__(self):
        """Detach all registered hooks from the model."""
        for handle in getattr(self, 'handles', ()):
            handle.remove()
# Shared Grad-CAM explainer used by `predict`, hooked on ResNet's last stage.
gradcam = MultiLayerGradCAM(model=model, target_layers=["layer4"])
# Image transformation function
def transform_image(image):
    """Preprocess a PIL image into a normalised ResNet input batch.

    The image is converted to RGB, resized to exactly 224x224 (the aspect
    ratio is NOT preserved), scaled to [0, 1] and normalised with the
    ImageNet mean/std.

    Args:
        image: a ``PIL.Image.Image``.

    Returns:
        A ``(1, 3, 224, 224)`` float tensor on the module-level ``device``.
    """
    # Grayscale / RGBA / palette images would yield 1 or 4 channels and break
    # the 3-channel Normalize below.
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # exact 224x224; aspect ratio is not kept
        transforms.ToTensor(),          # PIL -> CHW float tensor in [0, 1]
        transforms.Normalize(           # ImageNet mean & std
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    return transform(image).unsqueeze(0).to(device)
# Apply feature filters
def apply_filters(image, brightness, contrast, hue):
    """Adjust brightness, contrast and hue of a PIL image.

    Args:
        image: input ``PIL.Image.Image``; converted to RGB first.
        brightness: ``ImageEnhance.Brightness`` factor (1.0 = unchanged).
        contrast: ``ImageEnhance.Contrast`` factor (1.0 = unchanged).
        hue: hue shift as a fraction of the full hue circle (the UI slider
            supplies values in [-0.5, 0.5]); 0 leaves hue untouched.

    Returns:
        A new RGB ``PIL.Image.Image``.
    """
    image = image.convert("RGB")
    image = ImageEnhance.Brightness(image).enhance(brightness)
    image = ImageEnhance.Contrast(image).enhance(contrast)
    if hue == 0:
        # The 8-bit RGB -> HSV -> RGB round trip is lossy, so skip it when
        # no hue shift is requested.
        return image
    rgb = np.array(image)
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV).astype(np.float32)
    # OpenCV's 8-bit hue channel spans [0, 180), so a full turn is 180 units.
    hsv[..., 0] = (hsv[..., 0] + hue * 180) % 180
    rgb = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
    return Image.fromarray(rgb)
# Superimposition function
def superimpose_images(base_image, overlay_image, alpha):
    """Alpha-blend ``overlay_image`` over ``base_image``.

    Args:
        base_image: background ``PIL.Image.Image``.
        overlay_image: image to blend in, or ``None`` for no overlay. It is
            converted to ``base_image``'s mode and resized to its size.
        alpha: overlay weight; 0 keeps the base, 1 shows only the overlay.

    Returns:
        ``base_image`` itself when there is no overlay, otherwise a new
        ``PIL.Image.Image`` with the blended (clipped, uint8) pixels.
    """
    if overlay_image is None:
        return base_image  # No overlay, return base image as is
    # Match the base's mode as well as its size: an RGBA or grayscale overlay
    # would otherwise have a different channel count, and the NumPy blend
    # below would fail to broadcast.
    overlay_image = overlay_image.convert(base_image.mode).resize(base_image.size)
    base_array = np.asarray(base_image, dtype=float)
    overlay_array = np.asarray(overlay_image, dtype=float)
    blended_array = (1 - alpha) * base_array + alpha * overlay_array
    return Image.fromarray(np.clip(blended_array, 0, 255).astype(np.uint8))
def generate_adversarial(input_image, orig_pred, epsilon=20/255, steps=500):
    """Craft an untargeted L-infinity adversarial version of ``input_image``.

    The perturbation ``delta`` lives in [0, 1] pixel space. It is bounded by
    ``epsilon`` per pixel, the perturbed image is clamped to [0, 1], and it is
    ImageNet-normalised only when fed to ``model``. ``delta`` is optimised
    with momentum SGD to maximise the cross-entropy of ``orig_pred``, pushing
    the prediction away from that class.

    Args:
        input_image: ``PIL.Image.Image`` to attack.
        orig_pred: class index the model originally predicted.
        epsilon: maximum absolute per-pixel change, in [0, 1] units.
        steps: number of optimisation steps.

    Returns:
        A detached ``(1, 3, 224, 224)`` tensor on ``device`` with values in
        [0, 1], suitable for ``transforms.ToPILImage``.
    """
    # Must match the Normalize() constants used by transform_image.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    # transform_image returns a normalised tensor; undo the normalisation so
    # that clamping and epsilon operate on real pixel values.
    clean = (transform_image(input_image) * std + mean).clamp(0.0, 1.0)
    target = torch.tensor([orig_pred], device=device)

    delta = torch.zeros_like(clean, requires_grad=True)
    opt = optim.SGD([delta], lr=2e-1, momentum=0.9)
    for _ in range(steps):
        perturbed = torch.clamp(clean + delta, 0.0, 1.0)
        output = model((perturbed - mean) / std)
        loss = -F.cross_entropy(output, target)
        # Differentiate w.r.t. delta only; the model is not being trained, so
        # accumulating its parameter gradients every step is wasted work.
        delta.grad = torch.autograd.grad(loss, delta)[0]
        opt.step()
        with torch.no_grad():
            delta.clamp_(-epsilon, epsilon)

    return torch.clamp(clean + delta, 0.0, 1.0).detach()
def _gradcam_overlay(display_image, model_input):
    """Blend a Grad-CAM heatmap computed on ``model_input`` over ``display_image``."""
    # Drop tensors recorded by earlier forward passes so Grad-CAM only pairs
    # the activations/gradients of its own pass.
    gradcam.activations.clear()
    gradcam.gradients.clear()
    heatmap, _ = gradcam.generate(model_input)
    display_np = np.array(display_image)
    heatmap = cv2.resize(heatmap, (display_np.shape[1], display_np.shape[0]))
    colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    # applyColorMap yields BGR; convert before mixing with the RGB image.
    blended = cv2.addWeighted(cv2.cvtColor(colored, cv2.COLOR_BGR2RGB), 0.5, display_np, 0.5, 0)
    return Image.fromarray(blended)


def _top5_figure(probs):
    """Return a bar chart of the five most probable classes in ``probs``."""
    top5_idx = np.argsort(probs)[-5:][::-1]
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar([class_labels[i] for i in top5_idx], probs[top5_idx], color='skyblue')
    ax.set_ylabel("Probability")
    ax.set_title("Class Probabilities")
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=8)
    fig.tight_layout()
    # Remove the figure from pyplot's global registry; otherwise every call
    # leaks an open figure. The Figure object itself can still be rendered.
    plt.close(fig)
    return fig


def predict(image, brightness, contrast, hue, overlay_image, alpha, adversarial_switch):
    """Run the full pipeline for one UI update.

    Applies the brightness/contrast/hue filters and the optional overlay,
    optionally replaces the result with an adversarial example, then
    classifies the 224x224 model view of the image.

    Returns:
        ``(display_image, gradcam_image, figure)``: the image shown to the
        user at the input's original size, the Grad-CAM overlay on it, and a
        top-5 probability bar chart. Returns ``(None, None, None)`` when no
        image is given.
    """
    if image is None:
        return None, None, None
    orig_size = image.size  # (width, height)

    processed = apply_filters(image, brightness, contrast, hue)
    final_image = superimpose_images(processed, overlay_image, alpha)

    if adversarial_switch:
        with torch.no_grad():
            orig_pred = model(transform_image(final_image)).argmax(dim=1)[0].item()
        adv_tensor_01 = generate_adversarial(final_image, orig_pred)
        # Classify exactly the 224x224 image that was optimised. Upsampling it
        # for display and resampling it back down would blur the perturbation.
        model_view = transforms.ToPILImage()(adv_tensor_01.squeeze(0).detach().cpu())
    else:
        model_view = final_image.resize((224, 224))
    final_display = model_view.resize(orig_size)  # Resize back to original size
    model_input = transform_image(model_view)

    with torch.no_grad():
        probs = F.softmax(model(model_input), dim=1).cpu().numpy()[0]

    superimposed = _gradcam_overlay(final_display, model_input)
    return final_display, superimposed, _top5_figure(probs)
# Gradio Interface: controls on the left, results on the right. Any change
# to a control re-runs `predict`.
with gr.Blocks() as interface:
    gr.Markdown("<h2 style='text-align: center;'>ResNet Classifier with Adversarial Attacks</h2>")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Input Image")
            overlay_input = gr.Image(type="pil", label="Overlay Image (Optional)")
            brightness = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Brightness")
            contrast = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Contrast")
            hue = gr.Slider(minimum=-0.5, maximum=0.5, value=0.0, label="Hue")
            alpha = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Overlay Alpha")
            adversarial_switch = gr.Checkbox(label="Add Adversarial Noise")
        with gr.Column():
            processed_image = gr.Image(label="Processed Image")
            gradcam_output = gr.Image(label="GradCAM Overlay")
            bar_chart = gr.Plot(label="Predictions")

    # Order must match predict's positional parameters / return tuple.
    inputs = [image_input, brightness, contrast, hue, overlay_input, alpha, adversarial_switch]
    outputs = [processed_image, gradcam_output, bar_chart]
    triggers = (image_input, overlay_input, brightness, contrast, hue, alpha, adversarial_switch)
    for trigger in triggers:
        trigger.change(fn=predict, inputs=inputs, outputs=outputs)

interface.launch()