|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import gradio as gr |
|
import plotly.graph_objects as go |
|
from transformers import ViTImageProcessor, ViTForImageClassification |
|
|
|
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hugging Face checkpoint shared by the ViT backbone and its preprocessing config.
model_name = "google/vit-base-patch16-384"
image_processor = ViTImageProcessor.from_pretrained(model_name)
|
|
|
class CustomViTModel(nn.Module):
    """ViT backbone followed by a small MLP head that emits one logit per image.

    The pretrained ``ViTForImageClassification`` named by the module-level
    ``model_name`` is loaded, and its own ``classifier`` is replaced with
    ``nn.Identity`` so that ``.logits`` carries the backbone features. A
    Dropout -> Linear(hidden_size, 128) -> ReLU -> Dropout -> Linear(128, 1)
    head then maps those features to a single logit.

    Args:
        dropout_rate: Dropout probability applied before each linear layer
            of the head.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.vit = ViTForImageClassification.from_pretrained(model_name)

        # Width of the feature vector that reaches the head.
        num_features = self.vit.config.hidden_size
        # Identity makes `.logits` expose the raw backbone features instead of
        # the pretrained checkpoint's class scores.
        self.vit.classifier = nn.Identity()

        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 1),
        )

    def forward(self, pixel_values):
        """Return the raw (pre-sigmoid) logit(s) for a batch of images.

        Args:
            pixel_values: Tensor produced by the ViT image processor --
                presumably (batch, 3, H, W); confirm against the caller.

        Returns:
            The head's output with every size-1 dimension squeezed away.
        """
        # Presumably (batch, hidden_size), since the backbone classifier is an
        # Identity; these features are fed straight to the head.
        features = self.vit(pixel_values).logits
        logits = self.classifier(features)
        # NOTE(review): .squeeze() drops *all* size-1 dims, so a batch of one
        # yields a 0-d tensor rather than shape (1,). `.item()` works on both,
        # but a training loss comparing against (1,) targets would not.
        return logits.squeeze()
|
|
|
|
|
# Build the network, restore the fine-tuned weights (mapped onto `device`),
# then move it to `device` and switch to eval mode so the head's Dropout
# layers are inactive during inference.
# NOTE(review): the checkpoint filename mentions "vit_base_224", but the
# backbone is built from the 384 model; confirm which config it was trained on.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (weights_only=True is available on torch >= 1.13).
model = CustomViTModel()
model.load_state_dict(
    torch.load('final_ms_model_2classes_vit_base_224_bce.pth', map_location=device)
)
model.to(device).eval()
|
|
|
def _to_rgb_image(image):
    """Convert a numpy image array from Gradio into a 3-channel PIL image."""
    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # This also handles grayscale or RGBA arrays, rather than forcing an 'RGB'
    # interpretation onto data with a different channel layout.
    return Image.fromarray(image.astype('uint8')).convert("RGB")


def _probability_figure(ms_prob, non_ms_prob):
    """Build a Plotly bar chart of both class probabilities, in percent."""
    fig = go.Figure(data=[
        go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
        go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red'),
    ])
    fig.update_layout(
        title='Prediction Probabilities',
        yaxis_title='Probability (%)',
        barmode='group',
        yaxis=dict(range=[0, 100]),
    )
    return fig


def predict(image):
    """Classify an uploaded image as MS or Non-MS.

    Args:
        image: Image array from the ``gr.Image`` input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        A ``(result_text, figure)`` pair. ``result_text`` is a two-line string
        with the predicted class and its confidence. ``figure`` is a Plotly
        bar chart of both class probabilities.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Resize to 384x384, the input size of the google/vit-base-patch16-384 backbone.
    img = _to_rgb_image(image).resize((384, 384))
    inputs = image_processor(images=img, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)

    with torch.no_grad():
        output = model(pixel_values)

    # The model emits a single logit; its sigmoid is treated as P(MS).
    # NOTE(review): assumes label 1 == MS during training -- confirm.
    ms_prob = torch.sigmoid(output).item()
    non_ms_prob = 1 - ms_prob

    fig = _probability_figure(ms_prob, non_ms_prob)

    # Exactly 0.5 is reported as Non-MS.
    prediction = "MS" if ms_prob > 0.5 else "Non-MS"
    confidence = max(ms_prob, non_ms_prob) * 100

    result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"
    return result_text, fig
|
|
|
# Gradio UI: one image input; outputs are the text summary and the
# probability bar chart returned by `predict`.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=[gr.Textbox(), gr.Plot()],
    title="MS Prediction",
    description="Upload an image to predict whether it shows MS or Non-MS.",
)

# Start the server only when run as a script, so that importing this module
# (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    iface.launch()