Spaces:
Sleeping
Sleeping
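# -*- coding: utf-8 -*-
"""Medical Image Classifier Dashboard: a Gradio app for skin-lesion,
brain-tumor, Alzheimer's, and eye-disease image classification.

Automatically generated by Colab from the notebook ``app.ipynb``.
The original notebook is located at
https://colab.research.google.com/drive/1sjyLFLqBccpUzaUi4eyyP3NYE3gDtHfs
"""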
# Third-party dependencies (gradio, fastai, torch, torchvision, pillow) must be
# installed before this script runs, e.g. by listing them in requirements.txt.
# The notebook's `!pip install gradio` line is IPython/Colab shell syntax and a
# SyntaxError in a plain .py file, so it cannot appear here.
import gradio as gr
from fastai.vision.all import load_learner
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
# Model file paths, one per disease type. These are relative paths, so they are
# resolved against the current working directory at startup.
model_path_skin_disease = 'multi_weight.pth'  # Skin Disease Model (PyTorch)
model_path_brain_tumor = 'brain_tumor_model.pkl'  # fastai `load_learner` files
model_path_alzheimers = 'alzheimers_model.pkl'
model_path_eye_disease = 'eye_disease_model.pkl'

# Load every model once, at import time, so all requests reuse them.
# map_location="cpu": a checkpoint saved from a GPU cannot otherwise be
# deserialized on a CPU-only host, and the prediction code feeds CPU tensors.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which refuses to
# unpickle a whole nn.Module; if this file is a fully pickled model (and is
# trusted), pass weights_only=False here.
skin_disease_model = torch.load(model_path_skin_disease, map_location="cpu")
if isinstance(skin_disease_model, nn.Module):
    # Inference mode: disables dropout and makes batch-norm layers use their
    # running statistics instead of per-batch statistics.
    skin_disease_model.eval()
# NOTE(review): the ".pth"/"weight" file name suggests this may be a bare
# state_dict rather than a pickled nn.Module. A state_dict is not callable, so
# the architecture would have to be instantiated and `load_state_dict` applied
# here; confirm what was saved.

brain_tumor_model = load_learner(model_path_brain_tumor)
alzheimers_model = load_learner(model_path_alzheimers)
eye_disease_model = load_learner(model_path_eye_disease)
# Diagnosis Map for Skin Disease Model: class index produced by the skin model
# (position in this list) -> human-readable diagnosis label.
DIAGNOSIS_MAP = dict(enumerate([
    'Melanoma',
    'Melanocytic nevus',
    'Basal cell carcinoma',
    'Actinic keratosis',
    'Benign keratosis',
    'Dermatofibroma',
    'Vascular lesion',
    'Squamous cell carcinoma',
    'Unknown',
]))
# Image Preprocessing for Skin Disease Model.
# Scale the shorter side to 256 px, take the central 224x224 crop, convert to a
# float tensor, then normalize each of the three channels with the usual
# ImageNet mean/std (presumably the statistics the skin model was trained
# with; confirm).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
# Skin Disease Prediction Function
def predict_skin_disease(img: Image.Image):
    """Classify a skin-lesion image with the PyTorch skin model.

    Args:
        img: The uploaded PIL image, or ``None`` when the Gradio image input
            has been cleared (its ``change`` event fires in that case too).

    Returns:
        Up to three newline-separated lines of the form
        ``"<label>: <confidence>%"``, ordered by decreasing softmax
        probability, or ``""`` when ``img`` is ``None``.
    """
    if img is None:
        return ""
    # `transform` normalizes exactly three channels; RGBA PNGs or grayscale
    # uploads would otherwise fail inside Normalize.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        outputs = skin_disease_model(img_tensor)
        probs = F.softmax(outputs, dim=1)
        # Never request more entries than the model has output classes.
        k = min(3, probs.shape[1])
        top_probs, top_idxs = torch.topk(probs, k, dim=1)
    predictions = [
        f"{DIAGNOSIS_MAP.get(idx, 'Unknown')}: {prob * 100:.2f}%"
        for prob, idx in zip(top_probs[0].tolist(), top_idxs[0].tolist())
    ]
    return "\n".join(predictions)
# Brain Tumor Prediction Function
def predict_brain_tumor(image):
    """Classify a brain scan with the fastai brain-tumor learner.

    Args:
        image: The uploaded image, or ``None`` when the Gradio image input has
            been cleared (its ``change`` event fires in that case too).

    Returns:
        ``"Prediction: <label>, Probability: <p>"``, where ``<p>`` is the
        largest value of the probability vector returned by ``predict``,
        formatted to two decimals. Returns ``""`` when ``image`` is ``None``.
    """
    if image is None:
        return ""
    pred, _, prob = brain_tumor_model.predict(image)
    # `.item()` yields a plain float: fastai may return tensor subclasses
    # (e.g. TensorBase) whose __format__ rejects specs such as ":.2f".
    return f"Prediction: {pred}, Probability: {prob.max().item():.2f}"
# Alzheimer's Prediction Function
def predict_alzheimers(image):
    """Classify a brain image with the fastai Alzheimer's learner.

    Args:
        image: The uploaded image, or ``None`` when the Gradio image input has
            been cleared (its ``change`` event fires in that case too).

    Returns:
        ``"Prediction: <label>, Probability: <p>"``, where ``<p>`` is the
        largest value of the probability vector returned by ``predict``,
        formatted to two decimals. Returns ``""`` when ``image`` is ``None``.
    """
    if image is None:
        return ""
    pred, _, prob = alzheimers_model.predict(image)
    # `.item()` yields a plain float: fastai may return tensor subclasses
    # (e.g. TensorBase) whose __format__ rejects specs such as ":.2f".
    return f"Prediction: {pred}, Probability: {prob.max().item():.2f}"
# Eye Disease Prediction Function
def predict_eye_disease(image):
    """Classify an eye image with the fastai eye-disease learner.

    Args:
        image: The uploaded image, or ``None`` when the Gradio image input has
            been cleared (its ``change`` event fires in that case too).

    Returns:
        ``"Prediction: <label>, Probability: <p>"``, where ``<p>`` is the
        largest value of the probability vector returned by ``predict``,
        formatted to two decimals. Returns ``""`` when ``image`` is ``None``.
    """
    if image is None:
        return ""
    pred, _, prob = eye_disease_model.predict(image)
    # `.item()` yields a plain float: fastai may return tensor subclasses
    # (e.g. TensorBase) whose __format__ rejects specs such as ":.2f".
    return f"Prediction: {pred}, Probability: {prob.max().item():.2f}"
# Gradio Interface Function
def main():
    """Build the tabbed Gradio dashboard and launch it.

    The dashboard has one tab per classifier. Each tab holds an image upload,
    a "Prediction Results" textbox, and a ``change`` listener that calls the
    tab's prediction function whenever the uploaded image changes.
    """
    # One entry per tab, in display order:
    # (tab title, instructions, upload label, prediction function).
    tab_specs = (
        (
            "Skin Disease Prediction",
            "Upload a skin lesion image for diagnosis prediction.",
            "Upload Skin Lesion Image",
            predict_skin_disease,
        ),
        (
            "Brain Tumor Prediction",
            "Upload a brain scan image for tumor classification.",
            "Upload Brain Scan Image",
            predict_brain_tumor,
        ),
        (
            "Alzheimer's Prediction",
            "Upload a brain image for Alzheimer's detection.",
            "Upload Alzheimer's Image",
            predict_alzheimers,
        ),
        (
            "Eye Disease Prediction",
            "Upload an image for eye disease classification.",
            "Upload Eye Disease Image",
            predict_eye_disease,
        ),
    )
    with gr.Blocks() as demo:
        gr.Markdown("# Medical Image Classifier Dashboard")
        for title, instructions, upload_label, predict_fn in tab_specs:
            with gr.Tab(title):
                with gr.Column():
                    gr.Markdown(instructions)
                    image_in = gr.Image(type="pil", label=upload_label)
                    results_out = gr.Textbox(label="Prediction Results")
                    # Re-run the classifier every time the uploaded image changes.
                    image_in.change(predict_fn, inputs=image_in, outputs=results_out)
    demo.launch()
# Start the dashboard only when this file is executed as a script
# (e.g. `python app.py`), not when it is imported as a module.
if __name__ == "__main__":
    main()