|
import gradio as gr |
|
from transformers import ViTForImageClassification, ViTFeatureExtractor |
|
from PIL import Image |
|
import torch |
|
import numpy as np |
|
|
|
|
|
# Hub identifiers of the two ViT image-classification checkpoints,
# one per condition screened by this app.
model_name_pneumonia = "runaksh/chest_xray_pneumonia_detection"
model_name_tuberculosis = "runaksh/chest_xray_tuberculosis_detection"

# Both classifiers are loaded once, at import time, and reused by every
# request handled by classify_image().
model_pneumonia, model_tuberculosis = (
    ViTForImageClassification.from_pretrained(checkpoint)
    for checkpoint in (model_name_pneumonia, model_name_tuberculosis)
)

# A single preprocessor, taken from the base ViT checkpoint, feeds both models.
# NOTE(review): presumably this matches the preprocessing the two checkpoints
# were trained with -- confirm against the model repositories.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
|
|
|
# Class-index -> display-text tables, built once instead of on every call.
# NOTE(review): assumes index 1 is the positive class for both checkpoints --
# confirm against each model's config.id2label.
_PNEUMONIA_LABELS = {0: "Pneumonia = NO", 1: "Pneumonia = YES"}
_TUBERCULOSIS_LABELS = {0: "Tuberculosis = NO", 1: "Tuberculosis = YES"}
_UNKNOWN_LABEL = "Unknown Label"
_NO_IMAGE_MESSAGE = "No image provided"


def _predict_label(model, inputs, index_to_label):
    """Return the display text for the highest-scoring class of ``model``."""
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return index_to_label.get(predicted_class_idx, _UNKNOWN_LABEL)


def classify_image(image):
    """Classify one chest X-ray for pneumonia and for tuberculosis.

    Args:
        image: The image to classify; anything ``np.asarray`` can turn into
            an array the feature extractor accepts. ``None`` is tolerated
            (presumably what the UI passes when nothing was uploaded --
            verify against the Gradio ``Image`` component docs).

    Returns:
        A single string holding the pneumonia verdict and the tuberculosis
        verdict joined by a run of dots, or ``"No image provided"`` when
        ``image`` is ``None``.
    """
    # Guard first: np.asarray(None) yields an object array that only fails
    # later, inside the feature extractor, with an unhelpful error.
    if image is None:
        return _NO_IMAGE_MESSAGE

    # Both models share the same preprocessor, so preprocess exactly once and
    # feed the same tensors to each of them.
    inputs = feature_extractor(images=np.asarray(image), return_tensors="pt")

    label_pneumonia = _predict_label(model_pneumonia, inputs, _PNEUMONIA_LABELS)
    label_tuberculosis = _predict_label(
        model_tuberculosis, inputs, _TUBERCULOSIS_LABELS
    )

    return label_pneumonia + ".................." + label_tuberculosis
|
|
|
|
|
|
|
# Static text and styling handed to the Gradio interface.
title = "Automated Classification of Pneumonia and Tuberculosis using Machine Learning"

description = (
    "Upload your lungs Radiograph to find out "
    "if you are having Pneumonia or Tuberculosis"
)

# Page-wide background image; the pieces below concatenate into one CSS rule.
css_code = (
    ".gradio-container {"
    "background: url(https://www.bioworld.com/ext/resources/Stock-images/"
    "Therapeutic-topics/Respiratory/Respiratory-lungs-wireframe.png?1588285653);"
    " background-size: cover;"
    "}"
)
|
|
|
|
|
# Single-image UI: one image input, one label output, backed by
# classify_image() and styled with the module-level title/description/CSS.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(),
    outputs=gr.Label(),
    title=title,
    description=description,
    css=css_code,
)

# A second ``css_code`` f-string used to follow this call. It interpolated
# ``background_image_path``, a name that is never defined, so it raised
# NameError as soon as launch() returned; and because it only rebound the
# name after the Interface above had already been built, it could never
# change the page styling. It has been removed. To restyle the page, edit
# the ``css_code`` value passed to gr.Interface above.
iface.launch()
|
|