# resnet-train / app.py
# Author: Sreekanth Tangirala
# Commit c773c40: "change to progress and epochs to 20"
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import resnet50
import torch.nn as nn
# Define classes (for CIFAR-10). predict() indexes this list with the model's
# argmax, so position i must be the label the model was trained to emit at
# output unit i.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Load model: a ResNet-50 whose final fully connected layer has one output per
# class. Passing `num_classes` builds that head directly (same `fc.weight` /
# `fc.bias` shapes as swapping in nn.Linear(model.fc.in_features, 10)) and
# avoids the deprecated `pretrained=` argument. No pretrained weights are
# fetched; every parameter is overwritten by the checkpoint below.
model = resnet50(num_classes=len(classes))
# map_location='cpu': a checkpoint saved from a GPU training run would
# otherwise fail to load on a CPU-only host.
# weights_only=True: restrict unpickling to tensors and plain containers, so
# loading the file cannot execute arbitrary code.
model.load_state_dict(
    torch.load('best_model.pth', map_location='cpu', weights_only=True)
)
# Put layers such as BatchNorm into evaluation behaviour for inference.
model.eval()
# Preprocessing applied to every input image. Built once at import time
# instead of on every request. The mean/std values are the standard ImageNet
# normalization constants.
# NOTE(review): assumes training used this same Resize(224) + Normalize
# pipeline. Resize(224) scales the *shorter* side to 224, so non-square
# inputs stay non-square; confirm this matches training.
_TRANSFORM = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(image):
    """Classify a PIL image and return the predicted CIFAR-10 class name.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the form is submitted without an image.

    Returns:
        The entry of ``classes`` at the model's highest-scoring output, or
        ``None`` if no image was given.
    """
    if image is None:
        return None
    # Normalize() has three channel statistics and the model's first conv
    # expects 3 channels, so RGBA / grayscale / palette images must be
    # converted first (e.g. when predict is called outside Gradio).
    img_tensor = _TRANSFORM(image.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        logits = model(img_tensor)
    return classes[logits.argmax(dim=1).item()]
# Create Gradio interface: a PIL image in, the predicted class name out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=1),
    # NOTE(review): example files are resolved relative to the working
    # directory; presumably they ship next to app.py — confirm.
    examples=[["example1.jpg"], ["example2.jpg"]],
)

# Start the server only when run as a script (`python app.py`), so importing
# this module (e.g. from tests or tooling) does not block on launch().
if __name__ == "__main__":
    iface.launch()