# Scraped file-viewer header (not part of the program):
# anjikum · "Update app.py" · 157c5e0 verified · 2.95 kB
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import gradio as gr
# Define custom class 'Params' if used
class Params:
    """Container for the hyperparameters of the ResNet-50 SGD run.

    All values are plain instance attributes set in ``__init__``. Two
    instances compare equal when every attribute matches.
    """

    def __init__(self):
        self.batch_size = 128  # adjust to the available GPU memory
        self.name = "resnet_50_sgd"  # run name, reflects ResNet-50
        self.workers = 4  # number of DataLoader workers
        self.lr = 0.1  # learning rate for the SGD optimizer
        self.momentum = 0.9  # momentum for SGD
        self.weight_decay = 1e-4  # weight decay for regularization
        self.lr_step_size = 30  # step size for learning-rate decay
        self.lr_gamma = 0.1  # gamma factor for learning-rate decay

    def __repr__(self):
        """Return the attribute dictionary rendered as a string."""
        return str(self.__dict__)

    def __eq__(self, other):
        """Compare attribute-wise against another ``Params`` instance."""
        if not isinstance(other, Params):
            # Objects such as None or ints have no __dict__; returning
            # NotImplemented makes the comparison evaluate to "not equal"
            # instead of raising AttributeError.
            return NotImplemented
        return self.__dict__ == other.__dict__

    # Defining __eq__ already makes instances unhashable; state it explicitly
    # because the objects are mutable and compared by value.
    __hash__ = None
# Force CPU usage: inference runs on CPU even if a GPU is present.
device = torch.device('cpu')

# Build the ResNet-50 architecture without pretrained weights. Calling it with
# no arguments avoids the deprecated `pretrained=` keyword and gives the same
# randomly initialised network.
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 1000)
model.to(device)

# NOTE(review): the checkpoint is a pickle, so only load files you trust.
# Newer PyTorch releases default `torch.load` to `weights_only=True`, which
# may reject checkpoints that store custom objects -- confirm against the
# installed version.
checkpoint = torch.load('model.pth', map_location='cpu')
print(checkpoint.keys())

# Load the model weights. `map_location` belongs to `torch.load` (applied
# above); `load_state_dict` does not accept it.
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
# strict=False skips mismatched keys silently, so report them: a checkpoint
# whose keys do not line up would otherwise leave random weights in place.
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Missing keys: {load_result.missing_keys}")
    print(f"Unexpected keys: {load_result.unexpected_keys}")

# The optimizer and scheduler states stored in the checkpoint are not
# restored: this script only runs inference and creates neither object.
# The epoch is kept for reference; it is None when the checkpoint has none.
epoch = checkpoint.get('epoch')

# Set the model to evaluation mode (for inference).
model.eval()
# Preprocessing applied to every input image before it reaches the model:
# resize to 256, centre-crop to 224, convert to a tensor, then normalise each
# channel (presumably the standard ImageNet statistics -- confirm against the
# training pipeline).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Placeholder label names "class_1" .. "class_1000", one per list index
# (index 0 holds "class_1"). Replace with the real dataset labels if available.
_NUM_CLASSES = 1000
LABELS = [f"class_{k}" for k in range(1, _NUM_CLASSES + 1)]
# Define the prediction function
def predict(image):
    """Classify one image and return its entry from ``LABELS``.

    Args:
        image: A ``PIL.Image.Image`` (what a Gradio image input with
            ``type="pil"`` delivers), or a path / binary file object that
            ``Image.open`` accepts.

    Returns:
        The ``LABELS`` entry at the index of the highest model output.

    Raises:
        ValueError: If ``image`` is ``None`` (no image was submitted).
    """
    if image is None:
        raise ValueError("No image provided.")

    if isinstance(image, Image.Image):
        # Already decoded; calling Image.open on a PIL image would fail.
        rgb_image = image.convert("RGB")
    else:
        # Path or file object: open it and release the handle once converted.
        with Image.open(image) as opened:
            rgb_image = opened.convert("RGB")

    # Apply the preprocessing, add a batch dimension and move to the device.
    batch = transform(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)

    # Index of the largest output along the class dimension.
    predicted = outputs.argmax(dim=1)
    return LABELS[predicted.item()]
# Set up the Gradio interface. `gr.Image` is the top-level image component;
# the legacy `gr.inputs` namespace was deprecated in Gradio 3 and removed in
# Gradio 4. type="pil" makes Gradio hand `predict` a PIL image.
interface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text")
# Launch the interface
interface.launch()