"""Gradio demo that classifies a chest X-ray image as NORMAL or PNEUMONIA."""

import gradio as gr
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import torch
import numpy as np

# Load the fine-tuned classifier and the ViT preprocessor once, at startup,
# so each request only pays for preprocessing and a forward pass.
model_name_pneumonia = "runaksh/chest_xray_pneumonia_detection"
model_pneumonia = ViTForImageClassification.from_pretrained(model_name_pneumonia)
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

# Manual mapping of predicted class indices to human-readable labels.
# NOTE(review): assumes the fine-tuned model's output order is
# 0=NORMAL, 1=PNEUMONIA — confirm against the model's training labels.
index_to_label_pneumonia = {
    0: "NORMAL",
    1: "PNEUMONIA",
}


def classify_image(image):
    """Return the predicted label for a single chest X-ray image.

    Args:
        image: The image supplied by the ``gr.Image`` input component (a
            numpy array with the component's default settings), or ``None``
            when the user submits without uploading anything.

    Returns:
        ``"NORMAL"`` or ``"PNEUMONIA"``, or ``"Unknown Label"`` if the model
        predicts an index missing from ``index_to_label_pneumonia``.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a preprocessing traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image to classify.")

    # Resize/normalize the image and pack it into a PyTorch batch of one.
    inputs_pneumonia = feature_extractor(images=image, return_tensors="pt")

    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        outputs_pneumonia = model_pneumonia(**inputs_pneumonia)
    logits_pneumonia = outputs_pneumonia.logits

    # Index of the highest-scoring class for the single image in the batch.
    predicted_class_idx_pneumonia = logits_pneumonia.argmax(-1).item()

    return index_to_label_pneumonia.get(predicted_class_idx_pneumonia, "Unknown Label")


# Title and description shown at the top of the Gradio page.
title = "Classification Demo"
description = "XRay classification"

# Create the Gradio interface: one image in, one label out.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),  # Accepts an image of any size; the preprocessor resizes it.
    outputs=gr.Label(),
    title=title,
    description=description,
)

# Launch the app.
iface.launch()