runaksh's picture
Update app.py
6db0f0f verified
raw
history blame
1.75 kB
import gradio as gr
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import torch
import numpy as np
# Fine-tuned ViT checkpoint used to classify chest X-rays (see classify_image).
model_name_pneumonia = "runaksh/chest_xray_pneumonia_detection"

# Image preprocessing settings are taken from the *base* ViT checkpoint, not from
# the fine-tuned repo above.
# NOTE(review): assumes the fine-tuned model was trained with the same
# preprocessing as google/vit-base-patch16-224 — confirm against the model card.
base_preprocessor_name = "google/vit-base-patch16-224"

# Both objects are created once, when this module is imported, and are then
# reused by classify_image on every request.
model_pneumonia = ViTForImageClassification.from_pretrained(model_name_pneumonia)
feature_extractor = ViTFeatureExtractor.from_pretrained(base_preprocessor_name)
def classify_image(image):
    """Classify a chest X-ray image as ``"NORMAL"`` or ``"PNEUMONIA"``.

    Args:
        image: The uploaded image as passed in by the ``gr.Image()`` input.
            With Gradio's default settings this is a NumPy array; it is handed
            unchanged to ``feature_extractor``. ``None`` when the user submits
            without uploading an image.

    Returns:
        The predicted label: ``"NORMAL"``, ``"PNEUMONIA"``, or
        ``"Unknown Label"`` if the model predicts an index outside the mapping.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of a preprocessing traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Resize/normalize the image and pack it into PyTorch tensors for the model.
    inputs_pneumonia = feature_extractor(images=image, return_tensors="pt")

    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs_pneumonia = model_pneumonia(**inputs_pneumonia)
    logits_pneumonia = outputs_pneumonia.logits

    # Highest-scoring class over the last (class) dimension; .item() assumes a
    # single image in the batch, which is what one upload produces.
    predicted_class_idx_pneumonia = logits_pneumonia.argmax(-1).item()

    # Manual index -> label mapping.
    # NOTE(review): assumes index 0 = NORMAL, 1 = PNEUMONIA matches the order used
    # when fine-tuning; cross-check model_pneumonia.config.id2label.
    index_to_label_pneumonia = {
        0: "NORMAL",
        1: "PNEUMONIA",
    }
    return index_to_label_pneumonia.get(predicted_class_idx_pneumonia, "Unknown Label")
# Heading and subheading text displayed on the Gradio page.
title = "Classification Demo"
description = "XRay classification"

# One image in, one label out: the uploaded image is passed to classify_image
# and the returned label string is rendered by a Label component.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),  # no fixed input size is imposed on uploads
    outputs=gr.Label(),
    title=title,
    description=description,
)

# Start the web server. This runs at module level, so importing this file
# also launches the app.
iface.launch()