anjikum's picture
added params to app
4ec6f83 verified
raw
history blame
2.47 kB
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import gradio as gr
# Define custom class 'Params' if used
class Params:
    """Container for the hyper-parameters of the ResNet-50 SGD run.

    All state lives in the instance ``__dict__``; ``__repr__`` and ``__eq__``
    both operate on that dictionary.

    NOTE(review): presumably this class is declared here so that a checkpoint
    which pickled a ``Params`` instance can be unpickled by ``torch.load`` --
    confirm before renaming it or changing how its attributes are stored.
    """

    def __init__(self):
        self.batch_size = 128  # You can adjust this based on your GPU memory
        self.name = "resnet_50_sgd"  # Rename to reflect ResNet-50
        self.workers = 4  # Number of DataLoader workers
        self.lr = 0.1  # Learning rate for SGD optimizer
        self.momentum = 0.9  # Momentum for SGD
        self.weight_decay = 1e-4  # Weight decay for regularization
        self.lr_step_size = 30  # Step size for learning rate decay
        self.lr_gamma = 0.1  # Gamma factor for learning rate decay

    def __repr__(self):
        """Return the attribute dictionary rendered as a string."""
        return str(self.__dict__)

    def __eq__(self, other):
        """Return True when *other* is a ``Params`` with identical attributes.

        Returns ``NotImplemented`` for any other type so that comparisons such
        as ``params == None`` evaluate to False instead of raising
        ``AttributeError`` on the missing ``__dict__``.
        """
        if not isinstance(other, Params):
            return NotImplemented
        return self.__dict__ == other.__dict__

    # Instances are mutable and define __eq__, so they stay unhashable. This
    # is what Python does implicitly; it is spelled out here on purpose.
    __hash__ = None
# Force CPU usage: inference always runs on the CPU, even if a GPU is present.
device = torch.device('cpu')

# Build the ResNet-50 architecture without downloading pretrained weights.
# Calling resnet50() with no arguments is equivalent to the former
# `pretrained=False` but avoids that keyword, which newer torchvision
# releases deprecate in favour of `weights=`.
model = models.resnet50()

# Load the trained weights (.pth). map_location remaps any tensors that were
# saved from another device onto the CPU.
# NOTE(review): this expects "model.pth" to hold a plain state_dict -- confirm
# against the training script that produced the file.
model.load_state_dict(torch.load("model.pth", map_location=device))
model.to(device)  # Keep the model on the CPU
model.eval()  # Evaluation mode for inference
# Preprocessing pipeline applied to every input image before inference.
# Per-channel statistics used by the normalisation step (these are the values
# commonly quoted for ImageNet).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Steps run in order: resize, centre-crop to 224, convert to a tensor, then
# normalise each channel.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocess_steps)
# Human-readable class names, indexed by the model's output position.
# Replace these placeholders with your dataset's classes.
# NOTE(review): the list should presumably contain one entry per model
# output -- confirm that ten labels matches the trained network's head.
LABELS = [
    "class_1",
    "class_2",
    "class_3",
    "class_4",
    "class_5",
    "class_6",
    "class_7",
    "class_8",
    "class_9",
    "class_10",
]
# Define the prediction function
def predict(image):
    """Classify a single image and return its label as text.

    Args:
        image: A ``PIL.Image.Image`` (what a Gradio image input configured
            with ``type="pil"`` hands over), or anything ``Image.open``
            accepts, such as a file path or a binary file object. ``None``
            is tolerated and reported to the user.

    Returns:
        The entry of ``LABELS`` at the highest-scoring output index. A short
        explanatory message is returned instead when no image was supplied
        or when the winning index has no entry in ``LABELS``.
    """
    # Gradio passes None when the form is submitted without an image.
    if image is None:
        return "No image provided."

    # The interface delivers an already-decoded PIL image; only call
    # Image.open for paths / file objects, since it cannot open a PIL image.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    image = image.convert("RGB")

    # Apply the preprocessing, add a batch dimension, and keep it on `device`.
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():  # Inference only: no gradients needed
        outputs = model(batch)
        _, predicted = torch.max(outputs, 1)  # Index of the largest score

    class_index = predicted.item()
    # Guard against a model with more outputs than there are labels, which
    # would otherwise surface as an IndexError in the UI.
    if class_index >= len(LABELS):
        return f"Unknown class (index {class_index})"
    return LABELS[class_index]
# Set up the Gradio interface.
# `gr.inputs.Image` is the legacy component namespace and is absent from newer
# Gradio releases, so prefer the top-level `gr.Image` and only fall back to
# the legacy component on installations that do not provide it.
if hasattr(gr, "Image"):
    _image_input = gr.Image(type="pil")
else:
    _image_input = gr.inputs.Image(type="pil")

interface = gr.Interface(fn=predict, inputs=_image_input, outputs="text")

# Launch the interface
interface.launch()