import torch | |
import torch.nn as nn | |
import torchvision.transforms as transforms | |
from torchvision import models | |
from PIL import Image | |
import gradio as gr | |
# Training-configuration class. Keep this definition if the checkpoint pickled a Params instance, so it can be unpickled here.
class Params:
    """Hyper-parameters of the ResNet-50 / SGD training run.

    The inference code in this script does not reference these values.
    NOTE(review): presumably kept so that a checkpoint which pickled a
    ``Params`` instance can be unpickled — confirm against the training script.
    """

    def __init__(self):
        self.batch_size = 128  # You can adjust this based on your GPU memory
        self.name = "resnet_50_sgd"  # Rename to reflect ResNet-50
        self.workers = 4  # Number of DataLoader workers
        self.lr = 0.1  # Learning rate for SGD optimizer
        self.momentum = 0.9  # Momentum for SGD
        self.weight_decay = 1e-4  # Weight decay for regularization
        self.lr_step_size = 30  # Step size for learning rate decay
        self.lr_gamma = 0.1  # Gamma factor for learning rate decay

    def __repr__(self):
        return str(self.__dict__)

    def __eq__(self, other):
        # Value equality over all attributes. For unrelated types, defer to
        # Python's default handling instead of raising AttributeError on
        # objects that have no __dict__ (e.g. ints, strings, None).
        if not isinstance(other, Params):
            return NotImplemented
        return self.__dict__ == other.__dict__

    # Mutable objects compared by value must not be hashable; defining __eq__
    # already implies this, so state it explicitly.
    __hash__ = None
# Force CPU inference so the app also runs on machines without a GPU.
device = torch.device('cpu')

# Build the ResNet-50 architecture without pretrained weights; the trained
# parameters are loaded from the checkpoint below.
model = models.resnet50(weights=None)
# Classification head with 1000 outputs; must match the checkpoint's fc layer.
model.fc = nn.Linear(model.fc.in_features, 1000)
model.to(device)

# NOTE(review): if the checkpoint pickles custom objects (e.g. a Params
# instance), torch releases that default to weights_only=True may refuse to
# load it — confirm with the torch version in use.
checkpoint = torch.load('model.pth', map_location=device)
print(checkpoint.keys())

# Checkpoints saved from nn.DataParallel / DistributedDataParallel prefix every
# parameter name with "module."; strip it so the names match this bare model.
_prefix = 'module.'
state_dict = {
    (key[len(_prefix):] if key.startswith(_prefix) else key): value
    for key, value in checkpoint['model'].items()
}

# strict=False tolerates mismatches, but report them: silently skipped keys
# would leave layers randomly initialised and every prediction meaningless.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model parameters were not "
          f"found in the checkpoint, e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} checkpoint entries were "
          f"not used by the model, e.g. {load_result.unexpected_keys[:5]}")

# Epoch at which the checkpoint was saved (only relevant for resuming
# training); None if the checkpoint does not record it.
epoch = checkpoint.get('epoch')

# Switch layers such as BatchNorm/Dropout to inference behaviour.
model.eval()
# Per-channel normalisation statistics (the standard ImageNet values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Input preprocessing: resize with Resize(256), take the central 224x224 crop,
# convert the PIL image to a tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
# Placeholder class names, one per output of the 1000-way classifier head.
# Naming is 1-based: model output index i is reported as "class_{i + 1}".
# Replace with the real ImageNet (or dataset-specific) names if available.
LABELS = ["class_" + str(idx + 1) for idx in range(1000)]
# Define the prediction function
def predict(image):
    """Classify a single image and return its predicted label.

    Args:
        image: A ``PIL.Image.Image`` (what ``gr.Image(type="pil")`` passes in),
            or a file path / binary file object accepted by ``PIL.Image.open``.

    Returns:
        The entry of ``LABELS`` for the highest-scoring class.
    """
    # The Gradio input already delivers a decoded PIL image, and Image.open()
    # only accepts paths or file objects, so open the input only when it is not
    # an image yet.
    if not isinstance(image, Image.Image):
        image = Image.open(image)
    # Normalise to 3 channels (drops alpha, expands grayscale/palette images).
    image = image.convert("RGB")

    # Preprocess, add a batch dimension, and move to the model's device.
    batch = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)

    # Index of the highest-scoring class for the single image in the batch.
    predicted = logits.argmax(dim=1)
    return LABELS[predicted.item()]
# Gradio web UI: a single image upload, delivered to predict() as a PIL image,
# mapped to a text output showing the predicted label.
image_input = gr.Image(type="pil")
interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs="text",
)

# Start the Gradio server.
interface.launch()