File size: 4,215 Bytes
f2740f6
 
 
 
 
 
 
6bb0701
f2740f6
 
 
 
 
 
 
 
 
 
 
 
aed3106
f2740f6
 
 
 
 
aed3106
f2740f6
 
 
 
 
 
 
 
 
 
 
740fe04
f2740f6
 
 
 
6bb0701
 
 
 
 
 
 
3c02efe
f2740f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aed3106
 
a2acf73
aed3106
 
6bb0701
 
3c02efe
6bb0701
bd68e16
 
 
 
6bb0701
aed3106
 
 
3c02efe
aed3106
 
6bb0701
aed3106
6bb0701
aed3106
1a17e3f
 
3c02efe
6bb0701
 
3c02efe
 
 
 
 
 
 
 
6bb0701
aed3106
 
3c02efe
aed3106
 
 
 
 
f2740f6
aed3106
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import plotly.graph_objects as go
from transformers import ViTImageProcessor, ViTForImageClassification
import numpy as np

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Hugging Face checkpoint shared by the preprocessor below and the ViT backbone.
model_name = "google/vit-base-patch16-384"
# Image preprocessor whose settings are loaded from that checkpoint.
image_processor = ViTImageProcessor.from_pretrained(model_name)

class CustomViTModel(nn.Module):
    """ViT backbone with a small MLP head producing one raw score per image.

    The pretrained ``ViTForImageClassification`` classifier is replaced by
    ``nn.Identity`` so its ``logits`` output carries the backbone feature
    vector (width ``config.hidden_size``). The head below maps that vector
    to a single unnormalized score; callers apply the sigmoid themselves.

    Args:
        dropout_rate: Dropout probability applied before each linear layer
            of the head.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        # Loads the pretrained checkpoint named by the module-level
        # ``model_name``; its config supplies the feature width.
        self.vit = ViTForImageClassification.from_pretrained(model_name)
        num_features = self.vit.config.hidden_size
        # Remove the original classification head so ``outputs.logits`` is the
        # raw feature vector (per HF's implementation, the [CLS] token
        # representation — verify against the installed transformers version).
        self.vit.classifier = nn.Identity()
        # Attribute names and layer order must stay as-is: they define the
        # state_dict keys the fine-tuned checkpoint is loaded against.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1)
        )

    def forward(self, pixel_values):
        """Return one raw score (logit) per image.

        Args:
            pixel_values: Preprocessed image batch, as produced by the
                matching ``ViTImageProcessor``.

        Returns:
            Tensor of shape ``(batch,)``. For a batch of one, ``squeeze()``
            collapses it to a 0-d tensor, so callers can use ``.item()``.
        """
        # Already a flat (batch, hidden) vector, so no spatial pooling is
        # needed (a 1x1 average pool over it would be an exact identity).
        features = self.vit(pixel_values).logits
        return self.classifier(features).squeeze()

# Build the architecture, then restore the fine-tuned weights from disk.
# weights_only=True restricts unpickling to tensors, primitive types and
# plain containers (all a state_dict needs), so a tampered checkpoint cannot
# execute arbitrary code on load; it is also PyTorch's default from 2.6 on.
model = CustomViTModel()
model.load_state_dict(torch.load(
    'final_ms_model_2classes_vit_base_384_bce.pth',
    map_location=device,
    weights_only=True,
))
model.to(device)
# Evaluation mode: disables the dropout layers in the classifier head.
model.eval()

def _to_rgb_image(image):
    """Return *image* (a file path or an image array) as an RGB PIL image."""
    if isinstance(image, str):  # file path, e.g. from the sample buttons
        # ``with`` closes the file handle; ``convert`` loads the pixels into a
        # new image first, so the result stays valid after the file closes.
        with Image.open(image) as src:
            return src.convert('RGB')
    # Array input: let Pillow infer the mode from the array shape (grayscale,
    # RGB or RGBA) and normalize to RGB afterwards. Forcing mode='RGB' would
    # misread arrays that do not have exactly three channels.
    arr = image.astype('uint8')
    if arr.ndim == 3 and arr.shape[-1] == 1:
        # Pillow cannot build an image from an (H, W, 1) array; drop the axis.
        arr = arr[..., 0]
    return Image.fromarray(arr).convert('RGB')


def _probability_figure(non_ms_prob, ms_prob):
    """Build the bar chart of both class probabilities, in percent."""
    fig = go.Figure(data=[
        go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
        go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red')
    ])
    fig.update_layout(
        title='Prediction Probabilities',
        yaxis_title='Probability (%)',
        barmode='group',
        yaxis=dict(range=[0, 100])
    )
    return fig


def predict(image):
    """Classify one scan as MS or Non-MS.

    Args:
        image: A file path (str), an image array (as delivered by
            ``gr.Image(type="numpy")``), or ``None`` when nothing was uploaded.

    Returns:
        ``(result_text, figure)``: a prediction/confidence string and a Plotly
        bar chart of both probabilities. Returns ``("No image provided", None)``
        when *image* is ``None``.
    """
    if image is None:
        return "No image provided", None

    # Explicit 384x384 resize before the processor; presumably mirrors the
    # training-time preprocessing — TODO confirm before removing.
    img = _to_rgb_image(image).resize((384, 384))
    inputs = image_processor(images=img, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    with torch.no_grad():
        output = model(pixel_values)
        # The model emits a single logit; the sigmoid turns it into P(MS).
        probability = torch.sigmoid(output).item()

    ms_prob = probability
    non_ms_prob = 1 - probability
    fig = _probability_figure(non_ms_prob, ms_prob)

    # Strictly greater than 0.5, so an exact 0.5 is reported as Non-MS.
    prediction = "MS" if ms_prob > 0.5 else "Non-MS"
    confidence = max(ms_prob, non_ms_prob) * 100
    result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"

    return result_text, fig

def load_readme(path='README_DESC.md'):
    """Return the text of the app's description markdown.

    Args:
        path: Markdown file to read. A relative path is resolved against the
            current working directory. Defaults to ``'README_DESC.md'``.

    Returns:
        The file contents, or a human-readable notice when the file is
        missing (so the UI can display it instead of crashing).
    """
    try:
        # Explicit UTF-8: the platform default codec (e.g. cp1252 on Windows)
        # would fail on or mis-decode non-ASCII characters in the markdown.
        with open(path, 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        return f"{path} file not found. Please make sure it exists in the same directory as this script."

# Sample scans offered as one-click inputs on the "Prediction" tab.
# The paths are relative, so they resolve against the working directory.
example_images = [
    f"examples/{stem}.png"
    for stem in ("C-A (44)", "C-S (362)", "MS-A (9)", "MS-S (19)")
]

# Build the UI: a "Prediction" tab (upload + Predict, plus one-click sample
# images) and a "Description" tab rendering the README markdown.
# NOTE: Gradio places components in the order they are created inside each
# context manager, so the statement order below defines the page layout.
with gr.Blocks() as demo:
    gr.Markdown("# MS Prediction App")

    with gr.Tabs():
        with gr.TabItem("Prediction"):
            gr.Markdown("## Upload an MRI scan image or use an example to predict MS or Non-MS patient.")
            with gr.Row():
                # type="numpy" passes the upload to predict() as an array.
                input_image = gr.Image(type="numpy")
            predict_button = gr.Button("Predict")
            # Shared outputs: the Predict button and every sample button write here.
            output_text = gr.Textbox()
            output_plot = gr.Plot()

            gr.Markdown("## Or choose one of the sample images below:")
            with gr.Row():
                # One column per sample: a preview plus a button that runs
                # predict() on that file path.
                for i, img_path in enumerate(example_images):
                    with gr.Column():
                        gr.Image(img_path, show_label=False)
                        sample_button = gr.Button(f"Use Sample {i+1}")
                        # No `inputs=` is given, so the handler receives no
                        # component values; binding img_path as a default
                        # argument captures this iteration's path (a plain
                        # closure would late-bind to the last path).
                        sample_button.click(
                            lambda x=img_path: predict(x),
                            outputs=[output_text, output_plot]
                        )

        with gr.TabItem("Description"):
            # The README is read once, while the UI is being built.
            readme_content = gr.Markdown(load_readme())

    # Main button: uploaded array -> predict() -> (text, plot).
    predict_button.click(
        predict,
        inputs=input_image,
        outputs=[output_text, output_plot]
    )

# Start the Gradio server. This runs at module level (there is no __main__
# guard), so importing this file also launches the app.
demo.launch()