# anjikum's picture
# corrected image in gradio args
# 16a6690 verified
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import gradio as gr
class Params:
    """Hyper-parameters of the ResNet-50 / SGD training run.

    Nothing in this inference script reads these values. The class is defined
    here so that a checkpoint that pickled a ``Params`` instance can be
    unpickled (pickle needs the class to resolve by name).
    NOTE(review): confirm the checkpoint actually stores a ``Params`` object.
    """

    def __init__(self):
        self.batch_size = 128  # You can adjust this based on your GPU memory
        self.name = "resnet_50_sgd"  # Rename to reflect ResNet-50
        self.workers = 4  # Number of DataLoader workers
        self.lr = 0.1  # Learning rate for SGD optimizer
        self.momentum = 0.9  # Momentum for SGD
        self.weight_decay = 1e-4  # Weight decay for regularization
        self.lr_step_size = 30  # Step size for learning rate decay
        self.lr_gamma = 0.1  # Gamma factor for learning rate decay

    def __repr__(self):
        return str(self.__dict__)

    def __eq__(self, other):
        # Defer to Python's default handling for foreign types instead of
        # raising AttributeError on objects that have no __dict__
        # (e.g. ``Params() == 5``).
        if not isinstance(other, Params):
            return NotImplemented
        return self.__dict__ == other.__dict__

    # Defining __eq__ already makes the class unhashable; state it explicitly
    # because instances are mutable and compare by value.
    __hash__ = None
# Force CPU usage
device = torch.device('cpu')

# Build the ResNet-50 architecture with randomly initialised weights; the
# trained weights are loaded from the checkpoint below. The classification
# head is explicitly 1000-way; change it if the checkpoint was trained on a
# different number of classes.
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, 1000)
model.to(device)

# NOTE(review): since torch 2.6, torch.load defaults to weights_only=True,
# which rejects pickled custom objects (e.g. a Params instance stored by the
# training script). If loading fails for that reason, pass
# weights_only=False -- but only for a checkpoint you trust, because that
# unpickles arbitrary objects.
checkpoint = torch.load('model.pth', map_location=device)
print(checkpoint.keys())

state_dict = checkpoint['model']
# Checkpoints saved from nn.DataParallel / DistributedDataParallel prefix
# every key with "module."; strip it so the keys match this bare model.
_prefix = 'module.'
state_dict = {
    (key[len(_prefix):] if key.startswith(_prefix) else key): value
    for key, value in state_dict.items()
}

# strict=False keeps the original tolerance for mismatched keys, but the
# mismatch is now reported: silently skipped keys would leave layers at their
# random initialisation and the app would serve meaningless predictions.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys:
    print(
        f"WARNING: {len(incompatible.missing_keys)} model keys missing from "
        f"checkpoint, e.g. {incompatible.missing_keys[:5]}"
    )
if incompatible.unexpected_keys:
    print(
        f"WARNING: {len(incompatible.unexpected_keys)} checkpoint keys not "
        f"used by the model, e.g. {incompatible.unexpected_keys[:5]}"
    )

# Epoch the checkpoint was saved at. It is informational only and unused for
# inference; .get() avoids a KeyError for checkpoints without this entry.
epoch = checkpoint.get('epoch')

# Evaluation mode for inference (affects layers such as BatchNorm/Dropout).
model.eval()
# Inference preprocessing: resize the shorter side to 256, take the central
# 224x224 crop, convert to a float tensor, then normalise each RGB channel
# with the standard ImageNet mean/std statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Placeholder class names, one per output of the 1000-way classifier.
# Numbering is 1-based: output index 0 -> "class_1", index 999 -> "class_1000".
# Replace with real ImageNet (or dataset-specific) names when available.
LABELS = ["class_" + str(idx + 1) for idx in range(1000)]
def predict(image):
    """Classify one image and return the label of the top-scoring class.

    Args:
        image: A ``PIL.Image.Image`` (what ``gr.Image(type="pil")`` hands to
            this function), or a filesystem path / binary file object that
            ``PIL.Image.open`` can read.

    Returns:
        The entry of ``LABELS`` at the index of the highest model output.

    Raises:
        ValueError: if ``image`` is ``None`` (e.g. the form was submitted
            without an image).
    """
    if image is None:
        raise ValueError("No image provided")

    if isinstance(image, Image.Image):
        # Gradio has already decoded the upload. Passing a PIL image to
        # Image.open() fails because it is neither a path nor a file object.
        rgb = image.convert("RGB")
    else:
        # Path or file object: open it in a context manager so the file handle
        # is released once convert() has copied the pixel data.
        with Image.open(image) as opened:
            rgb = opened.convert("RGB")

    # Preprocess, add a leading batch dimension, and move to the model device.
    batch = transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    predicted = logits.argmax(dim=1)
    return LABELS[predicted.item()]
# Gradio UI: a single image upload, delivered to predict() as a PIL image,
# mapped to a plain-text label.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
)

# Start the web app.
interface.launch()