File size: 4,388 Bytes
80e64d9 f0237a6 80e64d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
# All inference runs on the CPU; the model and every input tensor are moved to
# this device, and the checkpoint is remapped to it when loaded.
device = torch.device("cpu")
class VGGBlock(nn.Module):
    """Two 3x3 convolutions followed by a 2x2 max pooling.

    Both convolutions use stride 1 and padding 1, so they keep the spatial
    size of their input; the final max pooling (kernel 2, stride 2) halves
    the width and the height. Each convolution is followed by an optional
    batch normalisation and a ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
        batch_norm: If true, apply ``nn.BatchNorm2d`` after each convolution.
    """

    def __init__(self, in_channels, out_channels, batch_norm=False):
        super().__init__()
        conv2_params = {
            'kernel_size': (3, 3),
            'stride': (1, 1),
            'padding': 1,
        }
        self._batch_norm = batch_norm

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv2_params)
        # nn.Identity (rather than a local lambda) keeps the block picklable
        # and registers a real submodule; it owns no parameters or buffers,
        # so the state-dict keys are the same as without it.
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()

        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv2_params)
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity()

        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    @property
    def batch_norm(self):
        """Whether batch normalisation was requested at construction time."""
        return self._batch_norm

    def forward(self, x):
        """Apply conv -> (bn) -> ReLU twice, then max pooling, and return the result."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        return self.max_pooling(x)
class VGG16(nn.Module):
    """VGG-style classifier: four ``VGGBlock`` stages and three linear layers.

    The convolutional stages widen the feature map to 64, 128, 256 and 512
    channels; the flattened result is fed to a 4096-4096-``num_classes``
    fully connected head with ReLU and dropout (p=0.65) between layers.

    Args:
        input_size: ``(channels, width, height)`` of a single input image.
        num_classes: Number of output logits.
        batch_norm: Forwarded to every ``VGGBlock``.

    Raises:
        ValueError: If the width or height is too small to survive the four
            pooling stages (i.e. smaller than 16).
    """

    # Every VGGBlock ends with a 2x2 / stride-2 max pooling, so the four
    # blocks together shrink each spatial dimension by a factor of 2 ** 4.
    _DOWNSAMPLE_FACTOR = 16
    # Number of channels produced by the last convolutional block.
    _FEATURE_CHANNELS = 512

    def __init__(self, input_size, num_classes=10, batch_norm=False):
        super().__init__()
        self.in_channels, self.in_width, self.in_height = input_size

        pooled_width = self.in_width // self._DOWNSAMPLE_FACTOR
        pooled_height = self.in_height // self._DOWNSAMPLE_FACTOR
        if pooled_width < 1 or pooled_height < 1:
            raise ValueError(
                "input width and height must both be at least "
                f"{self._DOWNSAMPLE_FACTOR}, got {self.in_width}x{self.in_height}"
            )
        # Derived from input_size instead of a hard-coded 2048 so the head
        # always matches the flattened feature map (32x32 still gives 2048).
        flat_features = self._FEATURE_CHANNELS * pooled_width * pooled_height

        self.block_1 = VGGBlock(self.in_channels, 64, batch_norm=batch_norm)
        self.block_2 = VGGBlock(64, 128, batch_norm=batch_norm)
        self.block_3 = VGGBlock(128, 256, batch_norm=batch_norm)
        self.block_4 = VGGBlock(256, self._FEATURE_CHANNELS, batch_norm=batch_norm)

        self.classifier = nn.Sequential(
            nn.Linear(flat_features, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, num_classes),
        )

    @property
    def input_size(self):
        """The ``(channels, width, height)`` triple given at construction."""
        return self.in_channels, self.in_width, self.in_height

    def forward(self, x):
        """Run the four convolutional blocks and the classifier; return logits."""
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        # Keep the batch dimension and flatten everything else for the head.
        x = torch.flatten(x, start_dim=1)
        return self.classifier(x)
# Build the network for single-channel 32x32 inputs with batch normalisation
# enabled and place it on the inference device.
model = VGG16((1, 32, 32), batch_norm=True).to(device)

# Restore the trained weights from the checkpoint file next to this script.
# NOTE(review): torch.load unpickles the file, so 'model.pth' must come from a
# trusted source.
model.load_state_dict(
    torch.load('model.pth', map_location=device)
)
# Maps the model's output index (0-9) to the text shown to the user.
# Index 8 deliberately holds the challenge flag instead of a class name; the
# interface description refers to that class as BAG.
label_map = dict(enumerate((
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'FLAG{3883}',
    'Ankle boot',
)))
def predict_from_local_image(image: str):
    """Classify the image file at ``image`` with the module-level ``model``.

    The file is opened with PIL, converted to single-channel grayscale,
    resized to 32x32, turned into a tensor and run through ``model`` in
    evaluation mode with gradient tracking disabled.

    Args:
        image: Path of the image file to classify.

    Returns:
        A ``(predicted_class, predicted_confidence)`` tuple: the ``label_map``
        entry of the highest-scoring class and that class's softmax
        probability expressed as a percentage.

    Raises:
        gr.Error: If no image path was supplied.
    """
    if not image:
        # Surface a readable message in the UI instead of failing inside PIL.
        raise gr.Error("Please upload an image before requesting a prediction.")

    # Preprocessing that matches the model's expected input.
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),  # Resize to the input size of the model
        transforms.ToTensor(),  # Convert the image to a tensor
    ])

    # Open the file from its path; the context manager releases the file
    # handle as soon as the grayscale copy has been made.
    with Image.open(image) as source:
        grayscale = source.convert('L')

    # Add the batch dimension and move the tensor to the inference device.
    batch = preprocess(grayscale).unsqueeze(0).to(device)

    # Evaluation mode, so dropout and batch norm behave deterministically.
    model.eval()
    with torch.no_grad():
        logits = model(batch)
        class_index = torch.argmax(logits, dim=1).item()
        percentages = F.softmax(logits, dim=1)[0] * 100

    predicted_class = label_map[class_index]
    predicted_confidence = percentages[class_index].item()
    return predicted_class, predicted_confidence
# Gradio interface
# NOTE(review): the prediction function returns a (class, confidence) tuple
# while only one output component is declared — confirm the rendered text is
# what is intended.
iface = gr.Interface(
    fn=predict_from_local_image,  # Function to call for prediction
    inputs=gr.Image(type='filepath', label="Upload an image"),  # Input: uploaded image, passed to fn as a file path
    outputs=gr.Textbox(label="Predicted Class"),  # Output: text box showing the prediction
    title="Vault Challenge 4 - DeepFool",  # Title of the interface
    description="Upload an image, and the model will predict the class. Try to fool the model into predicting the FLAG using DeepFool! Tips: apply DeepFool attack on the image to make the model predict it as a BAG.",
)

# Launch the Gradio interface
iface.launch()