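# File: MEDIGPT / app.py
# Uploaded by FawadHaider2 ("Upload 16 files", revision 2e6aaf3, verified), 4.82 kB.
# NOTE: the lines above were captured from a web file viewer (including its
# "raw" / "history blame" links) and are kept here only as provenance comments.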
# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1sjyLFLqBccpUzaUi4eyyP3NYE3gDtHfs
"""
# NOTE: the notebook cell `!pip install gradio` is IPython/Colab shell magic and
# is a SyntaxError in a plain .py module, so it must not live in this file.
# Install the dependencies outside the script instead, e.g.:
#   pip install gradio
import gradio as gr
from fastai.vision.all import load_learner
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
# Model paths for all disease types
model_path_skin_disease = 'multi_weight.pth'  # Skin Disease Model
model_path_brain_tumor = 'brain_tumor_model.pkl'
model_path_alzheimers = 'alzheimers_model.pkl'
model_path_eye_disease = 'eye_disease_model.pkl'

# Load models
# SECURITY NOTE: torch.load and fastai's load_learner both unpickle the file,
# which can execute arbitrary code. Only load model files from a trusted source.
#
# map_location='cpu' keeps the skin model on the CPU, matching the CPU tensor
# that predict_skin_disease feeds it, and lets a checkpoint saved on a GPU load
# on a CPU-only host.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True,
# which rejects a fully pickled nn.Module. If this checkpoint is a whole model
# (the code below calls it directly), weights_only=False is required -- confirm
# against the installed PyTorch version and the checkpoint's contents.
skin_disease_model = torch.load(model_path_skin_disease, map_location='cpu')  # For Skin Disease model
if isinstance(skin_disease_model, nn.Module):
    # Inference only: put dropout / batch-norm layers into evaluation mode.
    skin_disease_model.eval()
brain_tumor_model = load_learner(model_path_brain_tumor)
alzheimers_model = load_learner(model_path_alzheimers)
eye_disease_model = load_learner(model_path_eye_disease)
# Diagnosis Map for Skin Disease Model
# Maps the model's output class index to a display label; the index of each
# label is its position in the tuple below (0-based).
DIAGNOSIS_MAP = dict(enumerate((
    'Melanoma',
    'Melanocytic nevus',
    'Basal cell carcinoma',
    'Actinic keratosis',
    'Benign keratosis',
    'Dermatofibroma',
    'Vascular lesion',
    'Squamous cell carcinoma',
    'Unknown',
)))
# Image Preprocessing for Skin Disease Model
# Applied in order: Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize
# with the three-channel mean / std given below.
transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
# Skin Disease Prediction Function
def predict_skin_disease(img: "Image.Image | None"):
    """Return the top skin-disease predictions for an uploaded image.

    Args:
        img: The uploaded PIL image, or ``None`` when the Gradio image
            component has been cleared.

    Returns:
        One line per prediction in the form ``"<label>: <confidence>%"``
        (up to three lines, highest probability first), joined with newlines.
        An empty string is returned when ``img`` is ``None``.
    """
    if img is None:
        # The image component's change event also fires when the image is
        # cleared; there is nothing to classify, so clear the output too.
        return ""

    # The preprocessing pipeline normalizes exactly three channels, so force
    # RGB: RGBA or grayscale uploads would otherwise fail in Normalize.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        outputs = skin_disease_model(img_tensor)
        probs = F.softmax(outputs, dim=1)

    # Never ask topk for more entries than the model has output classes.
    k = min(3, probs.shape[1])
    top_probs, top_idxs = torch.topk(probs, k, dim=1)  # top 3 predictions

    return "\n".join(
        f"{DIAGNOSIS_MAP.get(idx.item(), 'Unknown')}: {prob.item() * 100:.2f}%"
        for prob, idx in zip(top_probs[0], top_idxs[0])
    )
# Brain Tumor Prediction Function
def predict_brain_tumor(image):
    """Classify a brain scan with the loaded brain-tumor learner.

    Args:
        image: The uploaded image as delivered by the Gradio image component,
            or ``None`` when the component has been cleared.

    Returns:
        ``"Prediction: <label>, Probability: <p>"`` where ``<p>`` is the
        largest value in the probabilities returned by ``predict``, formatted
        to two decimals. An empty string is returned when ``image`` is ``None``.
    """
    if image is None:
        # The change event also fires when the image is cleared; skip the
        # model call (it cannot handle None) and clear the output instead.
        return ""
    pred, _, prob = brain_tumor_model.predict(image)
    return f"Prediction: {pred}, Probability: {prob.max():.2f}"
# Alzheimer's Prediction Function
def predict_alzheimers(image):
    """Classify a brain image with the loaded Alzheimer's learner.

    Args:
        image: The uploaded image as delivered by the Gradio image component,
            or ``None`` when the component has been cleared.

    Returns:
        ``"Prediction: <label>, Probability: <p>"`` where ``<p>`` is the
        largest value in the probabilities returned by ``predict``, formatted
        to two decimals. An empty string is returned when ``image`` is ``None``.
    """
    if image is None:
        # The change event also fires when the image is cleared; skip the
        # model call (it cannot handle None) and clear the output instead.
        return ""
    pred, _, prob = alzheimers_model.predict(image)
    return f"Prediction: {pred}, Probability: {prob.max():.2f}"
# Eye Disease Prediction Function
def predict_eye_disease(image):
    """Classify an eye image with the loaded eye-disease learner.

    Args:
        image: The uploaded image as delivered by the Gradio image component,
            or ``None`` when the component has been cleared.

    Returns:
        ``"Prediction: <label>, Probability: <p>"`` where ``<p>`` is the
        largest value in the probabilities returned by ``predict``, formatted
        to two decimals. An empty string is returned when ``image`` is ``None``.
    """
    if image is None:
        # The change event also fires when the image is cleared; skip the
        # model call (it cannot handle None) and clear the output instead.
        return ""
    pred, _, prob = eye_disease_model.predict(image)
    return f"Prediction: {pred}, Probability: {prob.max():.2f}"
def _add_prediction_tab(tab_title, description, image_label, predict_fn):
    """Add one dashboard tab whose image upload is wired to ``predict_fn``.

    Must be called while a ``gr.Blocks`` context is active. The tab holds a
    description, a PIL image input and a text output; every change of the
    image runs ``predict_fn`` and writes its result to the text box.
    """
    with gr.Tab(tab_title):
        with gr.Column():
            gr.Markdown(description)
            image_input = gr.Image(type="pil", label=image_label)
            output_box = gr.Textbox(label="Prediction Results")
            image_input.change(predict_fn, inputs=image_input, outputs=output_box)


# Gradio Interface Function
def main():
    """Build the tabbed medical-image dashboard and launch the Gradio app."""
    # The earlier gr.inputs.Image(shape=...) / gr.inputs.Dropdown components
    # were removed: they were never attached to the layout, and the gr.inputs
    # namespace no longer exists in Gradio 4+, so creating them raised
    # AttributeError before the UI could be built.

    # Gradio tabs for each category
    with gr.Blocks() as demo:
        gr.Markdown("# Medical Image Classifier Dashboard")
        _add_prediction_tab(
            "Skin Disease Prediction",
            "Upload a skin lesion image for diagnosis prediction.",
            "Upload Skin Lesion Image",
            predict_skin_disease,
        )
        _add_prediction_tab(
            "Brain Tumor Prediction",
            "Upload a brain scan image for tumor classification.",
            "Upload Brain Scan Image",
            predict_brain_tumor,
        )
        _add_prediction_tab(
            "Alzheimer's Prediction",
            "Upload a brain image for Alzheimer's detection.",
            "Upload Alzheimer's Image",
            predict_alzheimers,
        )
        _add_prediction_tab(
            "Eye Disease Prediction",
            "Upload an image for eye disease classification.",
            "Upload Eye Disease Image",
            predict_eye_disease,
        )

    demo.launch()


# Run the Gradio app
if __name__ == "__main__":
    main()