|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import gradio as gr |
|
import plotly.graph_objects as go |
|
from transformers import ViTImageProcessor, ViTForImageClassification |
|
import numpy as np |
|
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
|
|
# Hugging Face Hub checkpoint used both for preprocessing and for the ViT backbone.
model_name = "google/vit-base-patch16-384"
# Image preprocessor configured from that same checkpoint.
image_processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path=model_name)
|
|
|
class CustomViTModel(nn.Module):
    """ViT backbone followed by a small MLP head that emits one binary logit.

    The pretrained ``ViTForImageClassification`` named by the module-level
    ``model_name`` is loaded and its ``classifier`` is replaced with
    ``nn.Identity``. As a result, the backbone's ``logits`` output carries the
    backbone features (``hidden_size`` wide, matching the first ``nn.Linear``
    of the head) rather than class scores. Presumably this is the [CLS] token
    representation; see ``ViTForImageClassification.forward`` to confirm.

    Args:
        dropout_rate: Dropout probability applied before each linear layer of
            the custom head.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.vit = ViTForImageClassification.from_pretrained(model_name)
        num_features = self.vit.config.hidden_size
        # Strip the pretrained head so ``outputs.logits`` exposes the raw
        # backbone features for the custom classifier below.
        self.vit.classifier = nn.Identity()
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1),
        )

    def forward(self, pixel_values):
        """Return the raw (pre-sigmoid) logit(s) for a batch of pixel values.

        The head's output has a trailing size-1 dimension that is removed with
        ``squeeze()``. Note that ``squeeze()`` drops *every* size-1 dimension,
        so a batch of one yields a 0-d tensor rather than shape ``(1,)``.
        """
        # ``logits`` is already a flat (batch, hidden_size) feature matrix
        # because the backbone head is an Identity; no pooling is needed.
        features = self.vit(pixel_values).logits
        return self.classifier(features).squeeze()
|
|
|
|
|
# Instantiate the architecture, then overwrite its weights with the fine-tuned
# checkpoint. The loaded dict is passed inline so no extra module-level
# reference keeps a second copy of the tensors alive.
model = CustomViTModel()
model.load_state_dict(
    torch.load(
        'final_ms_model_2classes_vit_base_384_bce.pth',
        # Remap stored tensors onto the selected device (e.g. CUDA-saved
        # tensors on a CPU-only host).
        map_location=device,
        # The file is consumed as a state dict of tensors, so restrict
        # unpickling to tensors/containers instead of arbitrary Python objects
        # (this is also the default in newer PyTorch releases).
        weights_only=True,
    )
)
model.to(device)
# Switch Dropout layers to inference behaviour.
model.eval()
|
|
|
def _to_rgb_image(image):
    """Convert a file path or a numpy image array into a 3-channel PIL image.

    A path is opened in a ``with`` block so the file handle is closed once
    ``convert`` has copied the pixels. An array is cast to uint8 and its mode is
    inferred by Pillow before converting to RGB. This way grayscale (H, W) or
    RGBA (H, W, 4) arrays are converted rather than misread as packed RGB bytes.
    """
    if isinstance(image, str):
        with Image.open(image) as src:
            return src.convert('RGB')
    return Image.fromarray(image.astype('uint8')).convert('RGB')


def _probability_figure(ms_prob, non_ms_prob):
    """Build the two-bar Plotly chart of class probabilities, in percent."""
    fig = go.Figure(data=[
        go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
        go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red'),
    ])
    fig.update_layout(
        title='Prediction Probabilities',
        yaxis_title='Probability (%)',
        barmode='group',
        yaxis=dict(range=[0, 100]),
    )
    return fig


def predict(image):
    """Classify an image as MS or Non-MS with the fine-tuned ViT model.

    Args:
        image: ``None``, a filesystem path (``str``), or a numpy image array
            such as the one produced by ``gr.Image(type="numpy")``.

    Returns:
        ``(result_text, fig)``: ``result_text`` names the predicted class and
        its confidence, and ``fig`` is a Plotly bar chart of both class
        probabilities. For ``None`` input, returns ``("No image provided", None)``.
    """
    if image is None:
        return "No image provided", None

    img = _to_rgb_image(image).resize((384, 384))
    inputs = image_processor(images=img, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    with torch.no_grad():
        logit = model(pixel_values)
    # The model emits a single logit; its sigmoid is treated as the MS probability.
    ms_prob = torch.sigmoid(logit).item()
    non_ms_prob = 1 - ms_prob

    prediction = "MS" if ms_prob > 0.5 else "Non-MS"
    confidence = max(ms_prob, non_ms_prob) * 100
    result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"

    return result_text, _probability_figure(ms_prob, non_ms_prob)
|
|
|
def load_readme(path="README_DESC.md"):
    """Return the Markdown text displayed on the app's description tab.

    Args:
        path: File to read. The default ``README_DESC.md`` is a relative path
            and is therefore resolved against the current working directory.

    Returns:
        The file's contents decoded as UTF-8. If the file does not exist,
        returns a short message naming the missing file instead.
    """
    import os  # only needed to name the missing file in the fallback message

    try:
        # Decode explicitly as UTF-8: the platform default (e.g. cp1252 on
        # Windows) can fail on non-ASCII Markdown such as dashes or accents.
        with open(path, 'r', encoding='utf-8') as file:
            return file.read()
    except FileNotFoundError:
        missing = os.path.basename(path)
        return (
            f"{missing} file not found. "
            "Please make sure it exists in the same directory as this script."
        )
|
|
|
|
|
# Relative paths of the bundled sample images offered as one-click examples.
example_images = [
    f"examples/{stem}.png"
    for stem in ("C-A (44)", "C-S (362)", "MS-A (9)", "MS-S (19)")
]
|
|
|
def _sample_handler(path):
    """Return a no-input callback that runs ``predict`` on one sample image."""
    # The default argument binds ``path`` now, so each button keeps its own image.
    return lambda x=path: predict(x)


with gr.Blocks() as demo:
    gr.Markdown("# MS Prediction App")

    with gr.Tabs():
        with gr.TabItem("Prediction"):
            gr.Markdown("## Upload an MRI scan image or use an example to predict MS or Non-MS patient.")
            # NOTE(review): the original indentation was lost; these four
            # components are assumed to share this row — confirm intended layout.
            with gr.Row():
                input_image = gr.Image(type="numpy")
                predict_button = gr.Button("Predict")
                output_text = gr.Textbox()
                output_plot = gr.Plot()

            gr.Markdown("## Or choose one of the sample images below:")
            with gr.Row():
                for number, sample_path in enumerate(example_images, start=1):
                    with gr.Column():
                        gr.Image(sample_path, show_label=False)
                        gr.Button(f"Use Sample {number}").click(
                            _sample_handler(sample_path),
                            outputs=[output_text, output_plot],
                        )

        with gr.TabItem("Description"):
            readme_content = gr.Markdown(load_readme())

    # Wire the main button to run the classifier on the uploaded image.
    predict_button.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_text, output_plot],
    )

demo.launch()