|
import gradio as gr |
|
from transformers import ViTForImageClassification, ViTFeatureExtractor |
|
from PIL import Image |
|
import torch |
|
import numpy as np |
|
|
|
|
|
# Checkpoint identifier handed to from_pretrained for the classifier.
model_name_pneumonia = "runaksh/chest_xray_pneumonia_detection"

# Image-classification model used by classify_image.
model_pneumonia = ViTForImageClassification.from_pretrained(
    model_name_pneumonia
)

# Preprocessor is loaded from the base ViT checkpoint, not from
# model_name_pneumonia.
# NOTE(review): presumably the classifier expects this same preprocessing
# -- confirm against the checkpoint's own preprocessor config.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224"
)
|
|
|
def classify_image(image) -> str:
    """Classify an X-ray image as "NORMAL" or "PNEUMONIA".

    Args:
        image: The image value delivered by the Gradio image input. It is
            passed unchanged to the module-level ``feature_extractor``.

    Returns:
        The label mapped from the highest-scoring class index, or
        "Unknown Label" when that index is neither 0 nor 1.
    """
    inputs_pneumonia = feature_extractor(images=image, return_tensors="pt")

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs_pneumonia = model_pneumonia(**inputs_pneumonia)

    # Read logits from the model output bound above.
    logits_pneumonia = outputs_pneumonia.logits

    # .item() assumes a single image per call (one row of logits).
    predicted_class_idx_pneumonia = logits_pneumonia.argmax(-1).item()

    # NOTE(review): assumes index 0/1 follows the checkpoint's class order
    # -- confirm against the model config's id2label.
    index_to_label_pneumonia = {
        0: "NORMAL",
        1: "PNEUMONIA",
    }

    return index_to_label_pneumonia.get(
        predicted_class_idx_pneumonia, "Unknown Label"
    )
|
|
|
|
|
# Heading and description passed to the Gradio interface.
title = "Classification Demo"
description = "XRay classification"

# One image input wired to classify_image; the returned string is shown
# in a Label component.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),
    outputs=gr.Label(),
    title=title,
    description=description,
)

# Start the Gradio app.
iface.launch()
|
|
|
|