"""Gradio demo that classifies Tom and Jerry images with a fine-tuned ResNet-18.

The checkpoint is downloaded from the Hugging Face Hub when this module is
loaded, and the model runs on the CPU.
"""

import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

import gradio as gr

# Class names, indexed by the position of the model's output logit.
# NOTE(review): this order must match the class-index order used when the
# checkpoint was trained — confirm against the model card / training code.
LABELS = ["tom", "jerry", "tom_jerry_0", "tom_jerry_1"]

MODEL_URL = (
    "https://huggingface.co/wandikafp/resnet18-tom-and-jerry-classifier"
    "/resolve/main/pytorch_model.bin"
)
# load_state_dict_from_url caches by file name and skips the download if that
# file already exists; the URL basename "pytorch_model.bin" is shared by many
# Hub models, so use a unique cache name to avoid loading someone else's weights.
MODEL_CACHE_NAME = "resnet18-tom-and-jerry-classifier.bin"


def _load_model():
    """Build a ResNet-18 with a len(LABELS)-way head and load the Hub checkpoint."""
    net = models.resnet18()
    # Replace the ImageNet head; size it from LABELS so the two cannot drift apart.
    net.fc = nn.Linear(net.fc.in_features, len(LABELS))
    state_dict = torch.hub.load_state_dict_from_url(
        MODEL_URL,
        map_location=torch.device("cpu"),
        file_name=MODEL_CACHE_NAME,
    )
    net.load_state_dict(state_dict)
    net.eval()  # put batch-norm/dropout layers into inference behaviour
    return net


model = _load_model()

# Resize to 224x224, convert to a tensor, and normalize with the
# ImageNet per-channel mean/std (three channels, so input must be RGB).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classify_image(image):
    """Return the predicted label for one image.

    Args:
        image: Image array as passed by Gradio's "image" input (converted with
            ``PIL.Image.fromarray``), or ``None`` when no image was provided.

    Returns:
        One of ``LABELS`` for the highest-scoring class, or ``None`` if
        ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None when the user submits without an image.
        return None
    # Force 3 channels: RGBA or grayscale arrays would otherwise break the
    # 3-channel Normalize step.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = transform(pil_image).unsqueeze(0)  # add a batch dimension of 1
    with torch.no_grad():
        logits = model(batch)
    return LABELS[logits.argmax(dim=1).item()]


# Create Gradio interface
interface = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="label",
    title="Tom and Jerry Classifier",
    description="Classify images as 'tom', 'jerry', 'tom_jerry_0', or 'tom_jerry_1'.",
)

if __name__ == "__main__":
    # Launch the Gradio app only when run as a script, not on import.
    interface.launch()