File size: 1,745 Bytes
3e586d9
 
d6f9e72
3e586d9
 
d6f9e72
3e586d9
a55adf1
 
3e586d9
95716e3
a55adf1
f9b41fc
a55adf1
f9b41fc
a55adf1
f9b41fc
 
a55adf1
 
f9b41fc
a55adf1
f9b41fc
a55adf1
f9b41fc
 
 
 
 
a55adf1
f9b41fc
a55adf1
f9b41fc
a55adf1
 
 
3e586d9
a55adf1
 
 
6db0f0f
a55adf1
 
3e586d9
a55adf1
 
95716e3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import torch
import numpy as np

# Load the pre-trained model and preprocessor (feature extractor).
# Both objects are created once at import time and are the module-level
# names that classify_image reads on every call.

# Checkpoint identifier handed to from_pretrained for the classifier.
model_name_pneumonia = "runaksh/chest_xray_pneumonia_detection"
model_pneumonia = ViTForImageClassification.from_pretrained(model_name_pneumonia)
# NOTE(review): the preprocessor is loaded from the base
# "google/vit-base-patch16-224" checkpoint rather than from
# model_name_pneumonia - confirm its resizing/normalisation matches what the
# classifier was fine-tuned with.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

def classify_image(image):
    """Classify a chest X-ray image as "NORMAL" or "PNEUMONIA".

    Args:
        image: The image delivered by the Gradio image input. It is passed
            unchanged to ``feature_extractor``.
            NOTE(review): with ``gr.Image()`` defaults this is presumably an
            RGB numpy array rather than a PIL image - confirm.

    Returns:
        str: "NORMAL" when the highest-scoring class index is 0,
        "PNEUMONIA" when it is 1, and "Unknown Label" for any other index.

    Raises:
        ValueError: If ``image`` is ``None`` (nothing was supplied).
    """
    # Fail with a clear message instead of an opaque preprocessing error
    # when the function is invoked without an image.
    if image is None:
        raise ValueError("No image provided; please upload a chest X-ray image.")

    # Preprocess the image and prepare it for the model
    inputs_pneumonia = feature_extractor(images=image, return_tensors="pt")

    # Inference only, so gradient tracking is switched off
    with torch.no_grad():
        outputs_pneumonia = model_pneumonia(**inputs_pneumonia)

    # Read the logits from this call's output object. The previous code
    # referenced an undefined name (`outputs`) and raised NameError.
    logits_pneumonia = outputs_pneumonia.logits

    # Retrieve the highest-scoring class label index
    predicted_class_idx_pneumonia = logits_pneumonia.argmax(-1).item()

    # Manual mapping of label indices to human-readable labels
    index_to_label_pneumonia = {
        0: "NORMAL",
        1: "PNEUMONIA",
    }

    # Fall back to a placeholder if the index is not in the mapping
    return index_to_label_pneumonia.get(
        predicted_class_idx_pneumonia, "Unknown Label"
    )

# Heading and short blurb displayed on the demo page
title = "Classification Demo"
description = "XRay classification"

# Build the Gradio UI: an image input feeding classify_image, with the
# returned label rendered by a Label component.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),
    outputs=gr.Label(),
    title=title,
    description=description,
)

# Start serving the interface
iface.launch()