File size: 4,801 Bytes
2e6aaf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1sjyLFLqBccpUzaUi4eyyP3NYE3gDtHfs
"""

import gradio as gr
from fastai.vision.all import load_learner
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

# Model paths for all disease types
model_path_skin_disease = 'multi_weight.pth'  # Skin Disease Model (loaded with torch.load)
model_path_brain_tumor = 'brain_tumor_model.pkl'  # loaded with fastai's load_learner
model_path_alzheimers = 'alzheimers_model.pkl'  # loaded with fastai's load_learner
model_path_eye_disease = 'eye_disease_model.pkl'  # loaded with fastai's load_learner

# Load models once, at import time.
# map_location="cpu" lets a checkpoint that was saved on a GPU load on a
# CPU-only host; predict_skin_disease feeds it CPU tensors.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# NOTE(review): predict_skin_disease calls this object directly, so the file is
# expected to hold a full pickled module. If it actually holds a state_dict,
# it must be loaded into the model class instead. On recent torch versions
# (where weights_only defaults to True), a full module may also need
# weights_only=False. Confirm against how the file was saved.
skin_disease_model = torch.load(model_path_skin_disease, map_location="cpu")
if isinstance(skin_disease_model, nn.Module):
    # Put dropout / batch-norm layers in inference mode. torch.no_grad() in
    # the predictor only disables autograd; it does not switch modes.
    skin_disease_model.eval()
brain_tumor_model = load_learner(model_path_brain_tumor)
alzheimers_model = load_learner(model_path_alzheimers)
eye_disease_model = load_learner(model_path_eye_disease)

# Diagnosis Map for Skin Disease Model: class index -> human-readable label.
# A label's position in the list below is its class index (0..8). The order
# presumably has to match the label order the skin model was trained with;
# confirm against the training code before reordering.
DIAGNOSIS_MAP = dict(enumerate([
    'Melanoma',
    'Melanocytic nevus',
    'Basal cell carcinoma',
    'Actinic keratosis',
    'Benign keratosis',
    'Dermatofibroma',
    'Vascular lesion',
    'Squamous cell carcinoma',
    'Unknown',
]))

# Image Preprocessing for Skin Disease Model.
# The steps are:
#   1. Resize so the shorter side is 256 px.
#   2. Take the central 224x224 crop.
#   3. Convert to a float tensor.
#   4. Normalize each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Skin Disease Prediction Function
def predict_skin_disease(img: Image.Image):
    """Classify a skin lesion image and report the three most likely diagnoses.

    Args:
        img: PIL image from the Gradio upload component, or ``None`` when the
            component is cleared (the ``change`` handler also fires then).

    Returns:
        One ``"<label>: <confidence>%"`` line per prediction, most likely
        first, joined with newlines. Returns an empty string when ``img``
        is ``None``.
    """
    if img is None:
        # Clearing the upload triggers `change` with None. Clear the output
        # instead of crashing inside the transform.
        return ""

    # Normalize expects exactly 3 channels, so grayscale / RGBA / palette
    # images are converted to RGB first. unsqueeze(0) adds the batch dimension.
    img_tensor = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        outputs = skin_disease_model(img_tensor)
        probs = F.softmax(outputs, dim=1)
        # topk raises if k exceeds the number of classes, so cap it.
        k = min(3, probs.shape[1])
        top_probs, top_idxs = torch.topk(probs, k, dim=1)

    return "\n".join(
        f"{DIAGNOSIS_MAP.get(idx.item(), 'Unknown')}: {prob.item() * 100:.2f}%"
        for prob, idx in zip(top_probs[0], top_idxs[0])
    )

# Brain Tumor Prediction Function
def predict_brain_tumor(image):
    """Classify a brain scan with the fastai brain tumor learner.

    Args:
        image: Image passed in by Gradio, or ``None`` when the upload is
            cleared (the ``change`` handler also fires then).

    Returns:
        ``"Prediction: <label>, Probability: <p>"``, where ``p`` is the
        largest value in the learner's returned probability tensor, formatted
        to two decimals. Returns an empty string for ``None``.
    """
    if image is None:
        return ""
    pred, _, prob = brain_tumor_model.predict(image)
    return f"Prediction: {pred}, Probability: {prob.max():.2f}"

# Alzheimer's Prediction Function
def predict_alzheimers(image):
    """Classify a brain image with the fastai Alzheimer's learner.

    Args:
        image: Image passed in by Gradio, or ``None`` when the upload is
            cleared (the ``change`` handler also fires then).

    Returns:
        ``"Prediction: <label>, Probability: <p>"``, where ``p`` is the
        largest value in the learner's returned probability tensor, formatted
        to two decimals. Returns an empty string for ``None``.
    """
    if image is None:
        return ""
    pred, _, prob = alzheimers_model.predict(image)
    return f"Prediction: {pred}, Probability: {prob.max():.2f}"

# Eye Disease Prediction Function
def predict_eye_disease(image):
    """Classify an eye image with the fastai eye disease learner.

    Args:
        image: Image passed in by Gradio, or ``None`` when the upload is
            cleared (the ``change`` handler also fires then).

    Returns:
        ``"Prediction: <label>, Probability: <p>"``, where ``p`` is the
        largest value in the learner's returned probability tensor, formatted
        to two decimals. Returns an empty string for ``None``.
    """
    if image is None:
        return ""
    pred, _, prob = eye_disease_model.predict(image)
    return f"Prediction: {pred}, Probability: {prob.max():.2f}"

# Gradio Interface Function
def main():
    """Build the tabbed medical-image dashboard and launch the Gradio server.

    Each tab holds:
    - a short instruction line;
    - a PIL image upload;
    - a "Prediction Results" textbox.

    Whenever a tab's uploaded image changes, that tab's prediction function
    runs on it and its result is written to the textbox.
    """
    # One entry per tab: (tab title, instructions, upload label, predictor).
    # The unused `gr.inputs.Image` / `gr.inputs.Dropdown` components were
    # removed. `gr.inputs` no longer exists in Gradio 4, so they crashed
    # before the UI was built.
    tabs = (
        (
            "Skin Disease Prediction",
            "Upload a skin lesion image for diagnosis prediction.",
            "Upload Skin Lesion Image",
            predict_skin_disease,
        ),
        (
            "Brain Tumor Prediction",
            "Upload a brain scan image for tumor classification.",
            "Upload Brain Scan Image",
            predict_brain_tumor,
        ),
        (
            "Alzheimer's Prediction",
            "Upload a brain image for Alzheimer's detection.",
            "Upload Alzheimer's Image",
            predict_alzheimers,
        ),
        (
            "Eye Disease Prediction",
            "Upload an image for eye disease classification.",
            "Upload Eye Disease Image",
            predict_eye_disease,
        ),
    )

    with gr.Blocks() as demo:
        gr.Markdown("# Medical Image Classifier Dashboard")

        for title, instructions, upload_label, predict_fn in tabs:
            with gr.Tab(title):
                with gr.Column():
                    gr.Markdown(instructions)
                    image_input = gr.Image(type="pil", label=upload_label)
                    output = gr.Textbox(label="Prediction Results")
                    # The listener is registered immediately with this
                    # iteration's objects, so there is no late-binding issue.
                    image_input.change(predict_fn, inputs=image_input, outputs=output)

    demo.launch()


# Run the Gradio app
if __name__ == "__main__":
    main()