# Author: anjikum
# Last commit: a90f86d (verified) - "Adjusting CPU for inferencing"
# File size: 1.7 kB
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import gradio as gr
# Force CPU usage
device = torch.device('cpu')
# Load your trained ResNet-50 model.
# The checkpoint is read first so the classifier head can be sized to match it:
# resnet50() builds a 1000-way head by default, so a checkpoint fine-tuned for a
# different number of classes would otherwise fail load_state_dict with a size
# mismatch on fc.weight / fc.bias.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
_state_dict = torch.load("model.pth", map_location=device)
_fc_weight = _state_dict.get("fc.weight")
# If the key is absent, keep torchvision's default head size and let
# load_state_dict below report the actual problem.
_num_classes = _fc_weight.shape[0] if _fc_weight is not None else 1000
# Omitting pretrained/weights builds an uninitialised network on both old and
# new torchvision (and avoids the `pretrained` deprecation warning); the
# trained weights come from the checkpoint.
model = models.resnet50(num_classes=_num_classes)
model.load_state_dict(_state_dict)  # Load the trained weights (.pth)
del _state_dict, _fc_weight  # the model now holds its own copy of the weights
model.to(device)  # Move model to CPU (even if you have a GPU)
model.eval()  # Evaluation mode: BatchNorm uses its running statistics
# Preprocessing pipeline applied to every image before it reaches the model.
# The mean/std values are the standard ImageNet per-channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize(256),  # shorter side -> 256 px, aspect ratio kept
        transforms.CenterCrop(224),  # central 224x224 patch
        transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) in [0, 1]
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # per-channel mean (R, G, B)
            [0.229, 0.224, 0.225],  # per-channel std (R, G, B)
        ),
    ]
)
# Class names indexed by model output position: LABELS[i] is reported when
# output i scores highest. These are placeholders ("class_1" .. "class_10");
# replace them with your dataset's class names, in training index order.
LABELS = [f"class_{n}" for n in range(1, 11)]
# Define the prediction function
def predict(image):
    """Classify one image and return the predicted label.

    Args:
        image: A ``PIL.Image.Image`` (what the Gradio ``Image(type="pil")``
            input delivers), or a filename / file object accepted by
            ``PIL.Image.open``. ``None`` (nothing uploaded) is tolerated.

    Returns:
        The label of the highest-scoring output, a placeholder naming the raw
        index if the model has more outputs than ``LABELS``, or a short
        message when no image was given.
    """
    if image is None:
        return "No image provided."
    # Gradio's type="pil" input already passes a PIL image; Image.open() on it
    # would raise, so only open paths / file objects (and close them after).
    if isinstance(image, Image.Image):
        rgb = image.convert("RGB")
    else:
        with Image.open(image) as opened:
            rgb = opened.convert("RGB")  # convert() loads the pixels before close
    # Preprocess, add a leading batch dimension of 1, and keep it on the CPU.
    batch = transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)  # Get model predictions (raw scores)
    index = int(logits.argmax(dim=1).item())  # highest-scoring class
    # Guard against a checkpoint whose head has more outputs than LABELS.
    if index < len(LABELS):
        return LABELS[index]
    return f"unknown class (index {index})"
# Set up the Gradio interface.
# gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4; the top-level
# gr.Image component works in both. type="pil" makes Gradio pass predict() a
# PIL image (or None when nothing is uploaded).
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
)
# Launch the interface
interface.launch()