# Hugging Face Space - status: Running
import torch | |
from torchvision import models, transforms | |
from PIL import Image | |
# Human-readable flower names, listed in the order of the model's output
# positions: the class index predicted by the network is the position in this
# tuple.
# NOTE(review): the names are in alphabetical order, presumably matching the
# class order used at training time (e.g. torchvision ImageFolder sorts class
# folder names) - confirm against the training script.
_FLOWER_NAMES = (
    "bluebell",
    "buttercup",
    "colts_foot",
    "corn_poppy",
    "cowslip",
    "crocus",
    "daffodil",
    "daisy",
    "dandelion",
    "foxglove",
    "fritillary",
    "geranium",
    "hibiscus",
    "iris",
    "lily_valley",
    "pansy",
    "petunia",
    "rose",
    "snowdrop",
    "sunflower",
    "tigerlily",
    "tulip",
    "wallflower",
    "water_lily",
    "wild_tulip",
    "windflower",
)

# Map class index -> name (keys 0..25, in the same order as above).
labels = dict(enumerate(_FLOWER_NAMES))
# Load the trained ResNet-152 model | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# Load model structure | |
model = models.resnet152() | |
num_classes = 26 # Update with your dataset's number of classes | |
model.fc = torch.nn.Linear(model.fc.in_features, num_classes) | |
# Load trained weights | |
model.load_state_dict(torch.load('trained_model.pth', map_location=device)) | |
model = model.to(device) | |
model.eval() # Set to evaluation mode | |
# Per-channel normalization statistics (the ImageNet mean/std values that
# torchvision documents for its pretrained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Transform applied to every uploaded image before it reaches the model.
# Steps:
#   1. resize to a fixed 224x224 (ResNet's usual input size);
#   2. convert to a float tensor;
#   3. normalize each channel.
# NOTE(review): these steps should mirror the training-time preprocessing -
# confirm.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
def predict_image(image_path):
    """Classify a single image and return the index of the predicted class.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts - here a filesystem
            path supplied by the Gradio upload component.

    Returns:
        int: index of the highest-scoring class (a key of ``labels``).
    """
    # Open through a context manager so the file handle is released
    # deterministically, including for formats Pillow keeps open after loading.
    # convert() returns a new, fully loaded image, so it remains usable after
    # the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Add a leading batch dimension of size 1 and move the tensor to the
    # model's device.
    input_tensor = preprocess(image).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_tensor)

    # Index of the largest score along the class dimension, as a Python int.
    return logits.argmax(dim=1).item()
import gradio as gr | |
def get_class_name(class_index):
    """Translate a numeric class index into its flower name.

    Args:
        class_index: An integer key of the module-level ``labels`` mapping.

    Returns:
        str: The corresponding label.

    Raises:
        KeyError: When ``class_index`` has no entry in ``labels``.
    """
    flower_name = labels[class_index]
    return flower_name
# Gradio callback: predict the class of an uploaded image.
def classify_image(image):
    """Classify an uploaded image and format the result for display.

    Args:
        image: Filesystem path of the upload (the interface uses
            ``gr.Image(type="filepath")``), or ``None`` when the user submits
            without providing an image.

    Returns:
        str: ``"Predicted Class: <index> : <name>"``, or a prompt asking for an
        image when none was supplied.
    """
    # Gradio passes None for an empty image input; Image.open(None) would
    # otherwise raise and surface as a generic error in the UI.
    if image is None:
        return "Please upload an image first."
    predicted_class = predict_image(image)
    return f"Predicted Class: {predicted_class} : {get_class_name(predicted_class)}"
# Build the web UI: one image upload in, a plain-text prediction out.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),  # callback receives a file path
    outputs="text",
    title="Image Classification with ResNet-152",
    # Derive the class count from the label map instead of hard-coding it,
    # so the description stays correct if the label set changes.
    description=f"Upload an image to classify it into one of {len(labels)} classes.",
)

# Start the Gradio web server.
interface.launch()