"""Gradio web app that classifies flower images with a fine-tuned ResNet-152.

The network is torchvision's ResNet-152 with its final fully connected layer
replaced by a ``len(labels)``-way classifier. The trained weights are read from
``trained_model.pth``.
"""

import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms

# Class index -> human-readable flower name. The model's output layer is sized
# from this mapping, so the two cannot drift apart.
labels = {
    0: "bluebell",
    1: "buttercup",
    2: "colts_foot",
    3: "corn_poppy",
    4: "cowslip",
    5: "crocus",
    6: "daffodil",
    7: "daisy",
    8: "dandelion",
    9: "foxglove",
    10: "fritillary",
    11: "geranium",
    12: "hibiscus",
    13: "iris",
    14: "lily_valley",
    15: "pansy",
    16: "petunia",
    17: "rose",
    18: "snowdrop",
    19: "sunflower",
    20: "tigerlily",
    21: "tulip",
    22: "wallflower",
    23: "water_lily",
    24: "wild_tulip",
    25: "windflower",
}

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture the weights were trained with: ResNet-152 whose
# classification head has one output per label.
model = models.resnet152()
num_classes = len(labels)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# Load the trained weights. The checkpoint is consumed by load_state_dict, so
# it only needs tensors and plain containers. weights_only=True refuses to
# unpickle arbitrary objects, which guards against code execution from a
# tampered file. map_location lets a GPU-saved checkpoint load on a CPU-only
# host.
model.load_state_dict(
    torch.load("trained_model.pth", map_location=device, weights_only=True)
)
model = model.to(device)
model.eval()  # Inference mode for layers such as dropout / batch norm.

# Preprocessing pipeline for incoming images.
# NOTE(review): this must match the transform used at training time. Confirm
# that training also used a plain 224x224 resize (no crop) and these
# normalization statistics.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict_image(image_path):
    """Return the predicted class index for the image at *image_path*.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (e.g. a file path).

    Returns:
        int: Key into ``labels`` of the highest-scoring class.
    """
    # Close the underlying file promptly instead of leaking a handle per call.
    # convert() loads the pixel data, so the result stays valid after the
    # `with` block closes the file.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Add a batch dimension before feeding the single image to the model.
    input_tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_tensor)
    return outputs.argmax(dim=1).item()


def get_class_name(class_index):
    """Return the flower name for *class_index* (KeyError if unknown)."""
    return labels[class_index]


def classify_image(image):
    """Gradio callback: classify an uploaded image and format the result.

    Args:
        image: Filepath of the uploaded image, as produced by
            ``gr.Image(type="filepath")``, or ``None`` when no image was given.

    Returns:
        str: ``"Predicted Class: <index> : <name>"``, or a prompt to upload an
        image when the input is empty.
    """
    # Gradio hands the callback None when the image input is empty; passing
    # that to Image.open would raise an unhelpful error in the UI.
    if image is None:
        return "Please upload an image."
    predicted_class = predict_image(image)
    return f"Predicted Class: {predicted_class} : {get_class_name(predicted_class)}"


# Build the web UI: one image input (delivered to the callback as a filepath)
# and a plain-text output.
interface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="filepath"),
    outputs="text",
    title="Image Classification with ResNet-152",
    description=f"Upload an image to classify it into one of {num_classes} classes.",
)

# Start the Gradio server.
interface.launch()