File size: 1,384 Bytes
3e71998 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr
# Build a ResNet-18 whose classification head matches the fine-tuned
# checkpoint, then load the weights published on the Hugging Face Hub.
model = models.resnet18()
# 4 outputs: the original author's assumption about the checkpoint. It must
# agree both with the saved `fc` weights and with the label list used for
# predictions.
model.fc = nn.Linear(model.fc.in_features, 4)

# load_state_dict_from_url caches downloads under <torch hub dir>/checkpoints.
# By default it names the cached file after the URL's basename, and it skips
# the download whenever a file with that name already exists. Hub repos all
# serve their weights as "pytorch_model.bin", so a different model cached
# earlier under that generic name would silently be picked up instead. Give
# the cache entry a repo-specific name to avoid that collision.
# NOTE(review): this unpickles a file fetched over the network. Only point it
# at a trusted repo, and consider weights_only=True on a torch version that
# supports it.
checkpoint = torch.hub.load_state_dict_from_url(
    'https://huggingface.co/wandikafp/resnet18-tom-and-jerry-classifier/resolve/main/pytorch_model.bin',
    map_location=torch.device('cpu'),
    file_name='wandikafp-resnet18-tom-and-jerry-classifier.bin',
)
model.load_state_dict(checkpoint)
# Inference only: eval mode makes BatchNorm use its stored running statistics.
model.eval()
# Preprocessing applied to every image before it reaches the network:
# 1. resize to a fixed 224x224 input;
# 2. convert to a float tensor;
# 3. normalize each channel.
# The mean/std values are the ImageNet statistics commonly used with
# torchvision ResNets. They are presumably the same ones used when this
# checkpoint was fine-tuned; confirm against the training code.
_INPUT_SIZE = (224, 224)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
# Define a prediction function
def classify_image(image):
    """Predict which Tom and Jerry class an image belongs to.

    Args:
        image: The image from the Gradio ``"image"`` input, one of:
            - a NumPy array (Gradio's default for that component);
            - an already-decoded ``PIL.Image.Image``;
            - ``None`` (the user submitted without uploading anything).

    Returns:
        The predicted label: one of ``'tom'``, ``'jerry'``, ``'tom_jerry_0'``
        or ``'tom_jerry_1'``. Returns ``None`` when no image was given, which
        leaves the label output empty instead of raising an error.
    """
    if image is None:
        return None
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalize is configured for exactly 3 channels. Gradio normally delivers
    # RGB already, but direct callers may pass RGBA, grayscale or palette
    # images, which would yield a tensor with the wrong channel count.
    image = image.convert('RGB')

    # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(batch)

    # Index order must match the class order the checkpoint was trained with.
    labels = ['tom', 'jerry', 'tom_jerry_0', 'tom_jerry_1']
    predicted = logits.argmax(dim=1).item()
    return labels[predicted]
# Wire the classifier into a simple Gradio UI: an image upload goes in and a
# label comes out. fn, inputs and outputs are Interface's first three
# positional parameters.
interface = gr.Interface(
    classify_image,
    "image",
    "label",
    title="Tom and Jerry Classifier",
    description="Classify images as 'tom', 'jerry', 'tom_jerry_0', or 'tom_jerry_1'.",
)

# Start the Gradio web app.
interface.launch()
|