Update app.py
Browse files
app.py
CHANGED
@@ -16,17 +16,14 @@ class CustomViTModel(nn.Module):
|
|
16 |
def __init__(self, dropout_rate=0.5):
|
17 |
super(CustomViTModel, self).__init__()
|
18 |
self.vit = ViTForImageClassification.from_pretrained(model_name)
|
19 |
-
|
20 |
-
# Replace the classifier
|
21 |
num_features = self.vit.config.hidden_size
|
22 |
-
self.vit.classifier = nn.Identity()
|
23 |
-
|
24 |
self.classifier = nn.Sequential(
|
25 |
nn.Dropout(dropout_rate),
|
26 |
nn.Linear(num_features, 128),
|
27 |
nn.ReLU(),
|
28 |
nn.Dropout(dropout_rate),
|
29 |
-
nn.Linear(128, 1)
|
30 |
)
|
31 |
|
32 |
def forward(self, pixel_values):
|
@@ -43,27 +40,22 @@ model.to(device)
|
|
43 |
model.eval()
|
44 |
|
45 |
def predict(image):
|
46 |
-
# Preprocess the image
|
47 |
img = Image.fromarray(image.astype('uint8'), 'RGB')
|
48 |
img = img.resize((384, 384))
|
49 |
inputs = image_processor(images=img, return_tensors="pt")
|
50 |
pixel_values = inputs['pixel_values'].to(device)
|
51 |
|
52 |
-
# Make prediction
|
53 |
with torch.no_grad():
|
54 |
output = model(pixel_values)
|
55 |
probability = torch.sigmoid(output).item()
|
56 |
|
57 |
-
# Prepare results
|
58 |
ms_prob = probability
|
59 |
non_ms_prob = 1 - probability
|
60 |
|
61 |
-
# Create the bar chart using Plotly
|
62 |
fig = go.Figure(data=[
|
63 |
go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
|
64 |
go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red')
|
65 |
])
|
66 |
-
|
67 |
fig.update_layout(
|
68 |
title='Prediction Probabilities',
|
69 |
yaxis_title='Probability (%)',
|
@@ -73,17 +65,36 @@ def predict(image):
|
|
73 |
|
74 |
prediction = "MS" if ms_prob > 0.5 else "Non-MS"
|
75 |
confidence = max(ms_prob, non_ms_prob) * 100
|
76 |
-
|
77 |
result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"
|
78 |
|
79 |
return result_text, fig
|
80 |
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
|
|
|
16 |
def __init__(self, dropout_rate=0.5):
|
17 |
super(CustomViTModel, self).__init__()
|
18 |
self.vit = ViTForImageClassification.from_pretrained(model_name)
|
|
|
|
|
19 |
num_features = self.vit.config.hidden_size
|
20 |
+
self.vit.classifier = nn.Identity()
|
|
|
21 |
self.classifier = nn.Sequential(
|
22 |
nn.Dropout(dropout_rate),
|
23 |
nn.Linear(num_features, 128),
|
24 |
nn.ReLU(),
|
25 |
nn.Dropout(dropout_rate),
|
26 |
+
nn.Linear(128, 1)
|
27 |
)
|
28 |
|
29 |
def forward(self, pixel_values):
|
|
|
40 |
model.eval()
|
41 |
|
42 |
def predict(image):
|
|
|
43 |
img = Image.fromarray(image.astype('uint8'), 'RGB')
|
44 |
img = img.resize((384, 384))
|
45 |
inputs = image_processor(images=img, return_tensors="pt")
|
46 |
pixel_values = inputs['pixel_values'].to(device)
|
47 |
|
|
|
48 |
with torch.no_grad():
|
49 |
output = model(pixel_values)
|
50 |
probability = torch.sigmoid(output).item()
|
51 |
|
|
|
52 |
ms_prob = probability
|
53 |
non_ms_prob = 1 - probability
|
54 |
|
|
|
55 |
fig = go.Figure(data=[
|
56 |
go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
|
57 |
go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red')
|
58 |
])
|
|
|
59 |
fig.update_layout(
|
60 |
title='Prediction Probabilities',
|
61 |
yaxis_title='Probability (%)',
|
|
|
65 |
|
66 |
prediction = "MS" if ms_prob > 0.5 else "Non-MS"
|
67 |
confidence = max(ms_prob, non_ms_prob) * 100
|
|
|
68 |
result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"
|
69 |
|
70 |
return result_text, fig
|
71 |
|
72 |
+
def load_readme():
|
73 |
+
try:
|
74 |
+
with open('README.md', 'r') as file:
|
75 |
+
return file.read()
|
76 |
+
except FileNotFoundError:
|
77 |
+
return "README.md file not found. Please make sure it exists in the same directory as this script."
|
78 |
+
|
79 |
+
with gr.Blocks() as demo:
|
80 |
+
gr.Markdown("# MS Prediction App")
|
81 |
+
|
82 |
+
with gr.Tabs():
|
83 |
+
with gr.TabItem("Prediction"):
|
84 |
+
gr.Markdown("## Upload an MRI scan image to predict MS or Non-MS patient.")
|
85 |
+
with gr.Row():
|
86 |
+
input_image = gr.Image()
|
87 |
+
output_text = gr.Textbox()
|
88 |
+
output_plot = gr.Plot()
|
89 |
+
predict_button = gr.Button("Predict")
|
90 |
+
|
91 |
+
with gr.TabItem("Description"):
|
92 |
+
readme_content = gr.Markdown(load_readme())
|
93 |
+
|
94 |
+
predict_button.click(
|
95 |
+
predict,
|
96 |
+
inputs=input_image,
|
97 |
+
outputs=[output_text, output_plot]
|
98 |
+
)
|
99 |
|
100 |
+
demo.launch()
|