|
import gradio as gr |
|
import torch |
|
import torch.nn as nn |
|
from torchvision import models, transforms |
|
from PIL import Image |
|
import requests |
|
|
|
|
|
# Load the pretrained MobileNetV2 classifier that serves the predictions.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in
# favour of the `weights=` argument; kept so older installs keep working.
model = models.mobilenet_v2(pretrained=True)

# Prefer Apple's Metal (MPS) backend, but fall back to the CPU so the app
# still starts on hosts without MPS instead of failing in `model.to()`.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

# The model is only used for inference, so put it in evaluation mode.
model.eval()
|
|
|
|
|
# Download the class-name list (one label per line) used to turn a predicted
# class index into a readable name.
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"

# The timeout keeps start-up from hanging forever on a stalled connection,
# and raise_for_status() stops an HTTP error page from being parsed as labels.
response = requests.get(url, timeout=30)
response.raise_for_status()
class_labels = response.text.splitlines()

# Replace the label at index 282 with the challenge flag, so the flag is what
# gets reported for any lookup of that class index.
class_labels[282] = "FLAG{3883}"
|
|
|
|
|
# Per-channel statistics handed to the normalisation step below.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Image pipeline, applied in order: Resize(256), CenterCrop(224), conversion
# to a tensor, then per-channel normalisation.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
|
|
|
|
|
def preprocess_image(image):
    """Turn an image into a single-item batch on the inference device.

    Args:
        image: The image to run through the module-level ``preprocess``
            pipeline.

    Returns:
        The pipeline output with a new leading dimension added via
        ``unsqueeze(0)``, moved to the module-level ``device``.
    """
    tensor = preprocess(image)
    batch = tensor.unsqueeze(0)  # add the leading batch dimension
    return batch.to(device)
|
|
|
|
|
def predict(image):
    """Classify an uploaded serialized tensor and return its label.

    Args:
        image: The value Gradio supplies for the ``gr.File`` input: a path
            to (or file object for) a ``.pt`` file. The file is expected to
            hold one input tensor of shape ``(1, C, H, W)`` that is fed to
            the model as-is (no preprocessing is applied here).

    Returns:
        The entry of ``class_labels`` at the model's top-scoring class index.

    Raises:
        ValueError: If no file was supplied, or the file does not contain a
            single-item 4-D tensor.
    """
    if image is None:
        raise ValueError("No file uploaded; please provide a .pt tensor file.")

    # SECURITY: the upload is untrusted. A plain torch.load() unpickles
    # arbitrary objects, which lets a crafted file execute code on the server.
    # weights_only=True restricts loading to tensors and primitive containers.
    # map_location makes tensors saved on another device load onto ours.
    loaded = torch.load(image, map_location=device, weights_only=True)

    if not isinstance(loaded, torch.Tensor):
        raise ValueError("The uploaded file must contain a single tensor.")
    if loaded.dim() != 4 or loaded.shape[0] != 1:
        # `.item()` below only works for exactly one prediction.
        raise ValueError(
            "Expected a tensor of shape (1, C, H, W), got "
            f"{tuple(loaded.shape)}."
        )

    # Use the shared module-level device rather than a hard-coded backend.
    batch = loaded.to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)

    top_index = logits.argmax(dim=1).item()
    return class_labels[top_index]
|
|
|
|
|
# UI components: a file upload for the serialized tensor and a text box that
# shows the label returned by `predict`.
_file_input = gr.File(label="Upload a .pt file")
_label_output = gr.Textbox(label="Predicted Class")

# Wire the components to `predict` in a Gradio interface.
iface = gr.Interface(
    fn=predict,
    inputs=_file_input,
    outputs=_label_output,
    title="Vault Challenge 3 - CW",
    description="Upload an image, and the model will predict the class. Try to fool the model into predicting the FLAG using C&W!",
)

# Launch the interface as soon as this module is executed.
iface.launch()
|
|