File size: 1,745 Bytes
3e586d9 d6f9e72 3e586d9 d6f9e72 3e586d9 a55adf1 3e586d9 95716e3 a55adf1 f9b41fc a55adf1 f9b41fc a55adf1 f9b41fc a55adf1 f9b41fc a55adf1 f9b41fc a55adf1 f9b41fc a55adf1 f9b41fc a55adf1 f9b41fc a55adf1 3e586d9 a55adf1 6db0f0f a55adf1 3e586d9 a55adf1 95716e3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import gradio as gr
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import torch
import numpy as np
# --- Model setup -------------------------------------------------------------
# Hub checkpoint for the image classifier (a ViT, judging by the class used to
# load it, apparently fine-tuned for chest X-ray pneumonia detection).
model_name_pneumonia = "runaksh/chest_xray_pneumonia_detection"
model_pneumonia = ViTForImageClassification.from_pretrained(
    model_name_pneumonia,
)

# Preprocessor is taken from the base ViT checkpoint, not from the classifier
# repo above. NOTE(review): confirm it matches the preprocessing used during
# fine-tuning. Newer transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor; consider migrating when the pinned version allows.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224",
)
def classify_image(image):
    """Classify a chest X-ray image as "NORMAL" or "PNEUMONIA".

    Args:
        image: The value produced by the Gradio image input. It is passed
            unchanged to ``feature_extractor``. NOTE(review): with
            ``gr.Image()`` defaults this is presumably an RGB NumPy array —
            confirm against the Gradio version in use.

    Returns:
        The predicted class name, "NORMAL" or "PNEUMONIA". "Unknown Label" is
        returned if the model predicts an index outside that mapping.

    Raises:
        gr.Error: If no image was supplied, so the UI shows a readable
            message instead of a preprocessing traceback.
    """
    # Gradio passes None for an empty image input; fail with a user-facing
    # message rather than letting the feature extractor crash on it.
    if image is None:
        raise gr.Error("Please upload a chest X-ray image.")

    # Preprocess into a batch of PyTorch tensors ("pt") for the model.
    inputs_pneumonia = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs_pneumonia = model_pneumonia(**inputs_pneumonia)
    # The original read `outputs.logits`, an undefined name (NameError).
    logits_pneumonia = outputs_pneumonia.logits

    # Index of the highest-scoring class; .item() relies on the batch holding
    # exactly one image.
    predicted_class_idx_pneumonia = logits_pneumonia.argmax(-1).item()

    # Manual index -> label mapping. NOTE(review): presumably mirrors the
    # label order used when the model was fine-tuned — confirm against
    # model_pneumonia.config.id2label.
    index_to_label_pneumonia = {
        0: "NORMAL",
        1: "PNEUMONIA",
    }
    return index_to_label_pneumonia.get(
        predicted_class_idx_pneumonia, "Unknown Label"
    )
# --- User interface ----------------------------------------------------------
# Page heading and subtitle displayed by Gradio.
title = "Classification Demo"
description = "XRay classification"

# One image upload in (no fixed size is imposed by the component), one label
# out; each submission is scored by classify_image.
image_input = gr.Image()
label_output = gr.Label()

iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
)

# Start the Gradio app; this runs as soon as the module is executed.
iface.launch()
|