reab5555's picture
Update app.py
740fe04 verified
raw
history blame
3.28 kB
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import plotly.graph_objects as go
from transformers import ViTImageProcessor, ViTForImageClassification
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face checkpoint shared by the image processor below and by the
# backbone inside CustomViTModel.
model_name = "google/vit-base-patch16-384"

# Image processor loaded for the same checkpoint; predict() passes PIL images
# through it to obtain the model's `pixel_values`.
image_processor = ViTImageProcessor.from_pretrained(model_name)
class CustomViTModel(nn.Module):
    """ViT backbone followed by a small MLP head that emits one logit.

    The pretrained classification head of ``model_name`` is swapped for
    ``nn.Identity`` so the backbone's ``.logits`` output carries the
    ``hidden_size``-dimensional features, which ``self.classifier`` maps to a
    single (pre-sigmoid) score.

    NOTE: the attribute names ``vit`` / ``classifier`` and the exact layer
    order inside ``classifier`` determine the ``state_dict`` keys
    (e.g. ``classifier.1.weight``), so they must match the saved checkpoint.
    """

    def __init__(self, dropout_rate=0.5):
        """Build the backbone and the classification head.

        Args:
            dropout_rate: Dropout probability applied before each Linear layer
                of the head.
        """
        super().__init__()
        self.vit = ViTForImageClassification.from_pretrained(model_name)
        num_features = self.vit.config.hidden_size
        # Discard the pretrained head so `.logits` returns the features that
        # used to feed it, instead of the original class scores.
        self.vit.classifier = nn.Identity()
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1),
        )

    def forward(self, pixel_values):
        """Return the raw logit(s) for a batch of preprocessed images.

        Args:
            pixel_values: Tensor produced by the ViT image processor —
                presumably (batch, 3, H, W); TODO confirm.

        Returns:
            Logits with every size-1 dimension squeezed away: a 0-d tensor for
            a batch of one, shape (batch,) otherwise.
        """
        # With the HF head replaced by Identity, `.logits` holds the head's
        # input features — presumably (batch, hidden_size). They are already
        # one vector per image, so no extra pooling is needed before the MLP.
        features = self.vit(pixel_values).logits
        return self.classifier(features).squeeze()
# Build the architecture, then overwrite its weights with the fine-tuned
# checkpoint; `map_location` places the loaded tensors on `device` directly.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the installed torch supports it).
_CHECKPOINT_PATH = 'final_ms_model_2classes_vit_base_384_bce.pth'

model = CustomViTModel()
model.load_state_dict(torch.load(_CHECKPOINT_PATH, map_location=device))

# Move the parameters to the target device and switch to inference mode
# (disables the head's Dropout layers).
model = model.to(device).eval()
def predict(image):
    """Classify an uploaded MRI scan as MS or Non-MS.

    Args:
        image: Value delivered by the ``gr.Image`` input — presumably a NumPy
            array (Gradio's default ``type="numpy"``); ``None`` when Predict is
            pressed before any image has been uploaded.

    Returns:
        ``(result_text, figure)`` for the Textbox and Plot outputs.
        ``figure`` is a Plotly bar chart of both class probabilities in
        percent, or ``None`` when no image was supplied.
    """
    # Gradio passes None when nothing has been uploaded; answer with a
    # readable message instead of failing on `None.astype`.
    if image is None:
        return "Please upload an MRI scan image first.", None

    # Let PIL infer the mode from the array shape, then normalise to RGB so
    # grayscale (H, W) and RGBA (H, W, 4) arrays work as well as RGB ones.
    img = Image.fromarray(image.astype('uint8')).convert('RGB')
    # NOTE(review): the processor also resizes; this explicit PIL resize is
    # kept so preprocessing is unchanged — confirm against training first.
    img = img.resize((384, 384))
    inputs = image_processor(images=img, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    with torch.no_grad():
        output = model(pixel_values)
        # Sigmoid of the single logit is read as the probability of MS.
        ms_prob = torch.sigmoid(output).item()
    non_ms_prob = 1 - ms_prob

    fig = go.Figure(data=[
        go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
        go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red'),
    ])
    fig.update_layout(
        title='Prediction Probabilities',
        yaxis_title='Probability (%)',
        barmode='group',
        yaxis=dict(range=[0, 100]),
    )

    # A probability of exactly 0.5 is reported as Non-MS.
    prediction = "MS" if ms_prob > 0.5 else "Non-MS"
    confidence = max(ms_prob, non_ms_prob) * 100
    result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"
    return result_text, fig
def load_readme(path='README.md'):
    """Return the contents of a Markdown file for the "Description" tab.

    Args:
        path: File to read. A relative path is resolved against the directory
            that contains this script (not the current working directory),
            matching where the not-found message tells the user to put it.
            Absolute paths are used as-is.

    Returns:
        The file's text decoded as UTF-8, or a human-readable message when the
        file does not exist.
    """
    from pathlib import Path

    # __file__ is absent when the code is exec'd interactively; fall back to
    # the working directory in that case.
    script = globals().get('__file__')
    base_dir = Path(script).resolve().parent if script else Path.cwd()
    target = base_dir / path  # an absolute `path` replaces base_dir entirely

    try:
        # Explicit UTF-8: the platform default (e.g. cp1252 on Windows) can
        # fail on non-ASCII Markdown such as emoji.
        return target.read_text(encoding='utf-8')
    except FileNotFoundError:
        return (
            f"{Path(path).name} file not found. "
            "Please make sure it exists in the same directory as this script."
        )
# Two-tab Gradio UI: "Prediction" (image in -> verdict text + probability
# chart out) and "Description" (the README rendered as Markdown).
with gr.Blocks() as demo:
    gr.Markdown("# MS Prediction App")

    with gr.Tabs():
        with gr.TabItem("Prediction"):
            gr.Markdown("## Upload an MRI scan image to predict MS or Non-MS patient.")
            with gr.Row():
                input_image = gr.Image()
                output_text = gr.Textbox()
                output_plot = gr.Plot()
            predict_button = gr.Button("Predict")

        with gr.TabItem("Description"):
            readme_content = gr.Markdown(load_readme())

    # predict() returns (text, figure), matching the two output components.
    predict_button.click(
        fn=predict,
        inputs=[input_image],
        outputs=[output_text, output_plot],
    )
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`), so importing the module — for tests or to reuse
# `predict` — does not block on a running server.
if __name__ == "__main__":
    demo.launch()