# ahmedghani's picture
# Update app.py
# 705a69e
from Resnet101 import *
import gradio as gr
from PIL import Image
print("Loading Resnet101 model...")
# Prefer the GPU when one is visible; everything below is placed on `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only ever point this at a trusted checkpoint.
model = torch.load("resnet101_ckpt.pth", map_location=device)
net = ResNet101()
net.to(device)
# Wrap in DataParallel *before* loading: presumably the checkpoint was saved
# from a DataParallel model, so its keys carry the "module." prefix -- TODO confirm.
net = torch.nn.DataParallel(net)
net.load_state_dict(model['net'])
# The network is only used for inference: switch to eval mode so layers such
# as BatchNorm/Dropout (if ResNet101 has them) behave deterministically
# instead of using per-batch statistics.
net.eval()
print("Model loaded")
print("Device: ", device)
# Define a transform to convert the image to tensor
transform = transforms.Compose([
transforms.Resize([32, 32]),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Class names indexed by the network's output position (assumed to match the
# label order used during training).
_CLASSES = ['plane', 'car', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck']

# Sample images shown for the classes that have a gallery; any other
# prediction gets the "not found" placeholder in all four slots.
_SAMPLE_PATHS = {
    'car': ["samples/car2.jpeg", "samples/car3.jpg",
            "samples/car4.jpg", "samples/car5.jpg"],
    'cat': ["samples/cat2.jpg", "samples/cat3.jpeg",
            "samples/cat4.png", "samples/cat5.jpg"],
    'dog': ["samples/dog2.jpg", "samples/dog3.jpg",
            "samples/dog4.jpg", "samples/dog5.jpg"],
    'horse': ["samples/horse2.jpg", "samples/horse3.jpeg",
              "samples/horse4.jpg", "samples/horse5.jpg"],
}
_NOT_FOUND_PATHS = ["samples/not-found.jpg"] * 4


def _load_images(paths):
    """Open each path and return fully-loaded in-memory PIL images as a tuple.

    The file is closed right away; ``copy()`` forces the pixel data to be read
    first, so no file handle is left open per request.
    """
    images = []
    for path in paths:
        with Image.open(path) as img:
            images.append(img.copy())
    return tuple(images)


def predict_image(image):
    """Classify an uploaded image and return four sample images of its class.

    Args:
        image: the input picture as a NumPy array (as passed to
            ``Image.fromarray``), or ``None`` when nothing was uploaded.

    Returns:
        A tuple of four PIL images: sample pictures of the predicted class
        when it is car, cat, dog or horse, otherwise four copies of the
        "not found" placeholder.
    """
    if image is None:
        # Nothing uploaded: show the placeholder instead of crashing in fromarray.
        return _load_images(_NOT_FOUND_PATHS)

    # convert('RGB') guarantees 3 channels, matching the 3-channel Normalize
    # in `transform` (grayscale/RGBA arrays would otherwise fail there).
    img_tensor = transform(Image.fromarray(image).convert('RGB'))
    # Tensor.to() returns a new tensor rather than moving in place; without the
    # reassignment the input stays on the CPU while the model is on `device`.
    img_tensor = img_tensor.to(device)
    # Ensure inference mode (idempotent): normalization layers must not use
    # single-image batch statistics.
    net.eval()
    with torch.no_grad():
        # Add a leading batch dimension of size 1.
        outputs = net(img_tensor.unsqueeze(0))
    _, predicted = outputs.max(1)
    res = _CLASSES[predicted[0].item()]
    print("Predicted class: ", res)
    return _load_images(_SAMPLE_PATHS.get(res, _NOT_FOUND_PATHS))
def set_example_image(example: list) -> dict:
    """Copy a clicked example into the input image component.

    ``example`` is the sample row selected in the ``gr.Dataset``; its first
    entry is the image value, which is sent back as an update.
    """
    chosen = example[0]
    return gr.Image.update(value=chosen)
demo = gr.Blocks()
with demo:
    # Page header and description.
    gr.Markdown('''
<center>
<h1>Image Classification trained on Resnet101</h1>
<p>
Image classification model trained on Resnet101. The dataset used is the CIFAR-10 dataset.
It predicts one of the 10 CIFAR-10 classes; for car, cat, dog and horse it then shows you 4 images of the same class (other classes show a placeholder).
</p>
</center>
''')
    with gr.Row():
        input_image = gr.Image(label="Input image")
    with gr.Row():
        # Four read-only slots, filled in order by predict_image's four return values.
        output_imgs = [
            gr.Image(label=f'Closest Image {i}', type='numpy', interactive=False)
            for i in range(1, 5)
        ]
    button = gr.Button("Classify!")
    with gr.Row():
        # Clickable example inputs; selecting one copies it into input_image
        # via set_example_image.
        example_images = gr.Dataset(
            components=[input_image],
            samples=[["samples/cat1.jpg"], ["samples/car1.jpg"],
                     ["samples/dog1.jpeg"], ["samples/horse1.jpg"]],
        )
    example_images.click(fn=set_example_image, inputs=example_images,
                         outputs=example_images.components)
    button.click(predict_image, inputs=input_image, outputs=output_imgs)

# Only start the server when run as a script, so importing this module does
# not block on launch().
if __name__ == "__main__":
    demo.launch(debug=True)