import torch
import torch.nn as nn
from PIL import Image
import gradio as gr
import plotly.graph_objects as go
from transformers import ViTImageProcessor, ViTForImageClassification
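# Use the GPU when available, otherwise fall back to the CPU.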
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the ViT model and image processor
model_name = "google/vit-base-patch16-384"
image_processor = ViTImageProcessor.from_pretrained(model_name)
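# Wrap the pretrained ViT in a custom module: the stock classification head is
# replaced by a small MLP that outputs a single logit for MS vs. Non-MS.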
class CustomViTModel(nn.Module):
    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.vit = ViTForImageClassification.from_pretrained(model_name)
        num_features = self.vit.config.hidden_size
        # Drop the pretrained classification head; the backbone then returns
        # the pooled [CLS] hidden state as its "logits".
        self.vit.classifier = nn.Identity()
        # Binary classification head mapping ViT features to one logit.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1)
        )

    def forward(self, pixel_values):
        outputs = self.vit(pixel_values)
        x = outputs.logits  # shape (batch, hidden_size) because of nn.Identity
        # On a (batch, hidden_size) tensor this pooling is effectively a no-op;
        # it is retained as-is from the original training code.
        x = nn.functional.adaptive_avg_pool2d(x.unsqueeze(-1).unsqueeze(-1), (1, 1)).squeeze(-1).squeeze(-1)
        x = self.classifier(x)
        return x.squeeze()
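# Quick shape sanity check (a hedged sketch, not part of the app):
#   m = CustomViTModel().to(device).eval()
#   dummy = torch.randn(1, 3, 384, 384).to(device)
#   with torch.no_grad():
#       print(m(dummy))  # a single raw logit; apply sigmoid for a probability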
# Load the trained model
model = CustomViTModel()
# weights_only=True (available since PyTorch 1.13) restricts unpickling to tensor data.
model.load_state_dict(torch.load('final_ms_model_2classes_vit_base_384_bce.pth',
                                 map_location=device, weights_only=True))
model.to(device)
model.eval()
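# eval() disables the dropout layers in the custom head so inference is deterministic.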
def predict(image):
    if image is None:
        return "No image provided", None
    if isinstance(image, str):  # file path (used by the sample buttons)
        img = Image.open(image).convert('RGB')
    else:  # numpy array coming from the Gradio image component
        img = Image.fromarray(image.astype('uint8'), 'RGB')
    # The processor would also resize, but doing it explicitly keeps the
    # 384x384 input size obvious.
    img = img.resize((384, 384))
    inputs = image_processor(images=img, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    with torch.no_grad():
        output = model(pixel_values)
    # The single sigmoid output is the probability of the MS class.
    ms_prob = torch.sigmoid(output).item()
    non_ms_prob = 1 - ms_prob
    fig = go.Figure(data=[
        go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
        go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red')
    ])
    fig.update_layout(
        title='Prediction Probabilities',
        yaxis_title='Probability (%)',
        barmode='group',
        yaxis=dict(range=[0, 100])
    )
    prediction = "MS" if ms_prob > 0.5 else "Non-MS"
    confidence = max(ms_prob, non_ms_prob) * 100
    result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"
    return result_text, fig
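# Hedged usage sketch: predict() can also be called directly, assuming the
# bundled example files exist on disk:
#   text, fig = predict("examples/MS-A (9).png")
#   print(text)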
def load_readme():
    try:
        with open('README_DESC.md', 'r') as file:
            return file.read()
    except FileNotFoundError:
        return "README_DESC.md file not found. Please make sure it exists in the same directory as this script."
# Example images
example_images = [
    "examples/C-A (44).png",
    "examples/C-S (362).png",
    "examples/MS-A (9).png",
    "examples/MS-S (19).png"
]
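# Naming presumably distinguishes control (C-*) from MS (MS-*) scans; this is
# an assumption based on the class labels, not documented in the script.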
with gr.Blocks() as demo:
    gr.Markdown("# MS Prediction App")
    with gr.Tabs():
        with gr.TabItem("Prediction"):
            gr.Markdown("## Upload an MRI scan image or use an example to predict MS or Non-MS patient.")
            with gr.Row():
                input_image = gr.Image(type="numpy")
            predict_button = gr.Button("Predict")
            output_text = gr.Textbox()
            output_plot = gr.Plot()
            gr.Markdown("## Or choose one of the sample images below:")
            with gr.Row():
                for i, img_path in enumerate(example_images):
                    with gr.Column():
                        gr.Image(img_path, show_label=False)
                        sample_button = gr.Button(f"Use Sample {i+1}")
                        # The default argument pins the current img_path to this
                        # button; a plain closure would capture only the last path.
                        sample_button.click(
                            lambda x=img_path: predict(x),
                            outputs=[output_text, output_plot]
                        )
        with gr.TabItem("Description"):
            gr.Markdown(load_readme())
    predict_button.click(
        predict,
        inputs=input_image,
        outputs=[output_text, output_plot]
    )
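# Starts the local server; passing share=True to demo.launch() would also
# create a temporary public link (standard Gradio option).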
demo.launch()