reab5555 commited on
Commit
aed3106
·
verified ·
1 Parent(s): 5751865

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -19
app.py CHANGED
@@ -16,17 +16,14 @@ class CustomViTModel(nn.Module):
16
  def __init__(self, dropout_rate=0.5):
17
  super(CustomViTModel, self).__init__()
18
  self.vit = ViTForImageClassification.from_pretrained(model_name)
19
-
20
- # Replace the classifier
21
  num_features = self.vit.config.hidden_size
22
- self.vit.classifier = nn.Identity() # Remove the original classifier
23
-
24
  self.classifier = nn.Sequential(
25
  nn.Dropout(dropout_rate),
26
  nn.Linear(num_features, 128),
27
  nn.ReLU(),
28
  nn.Dropout(dropout_rate),
29
- nn.Linear(128, 1) # Single output for binary classification
30
  )
31
 
32
  def forward(self, pixel_values):
@@ -43,27 +40,22 @@ model.to(device)
43
  model.eval()
44
 
45
  def predict(image):
46
- # Preprocess the image
47
  img = Image.fromarray(image.astype('uint8'), 'RGB')
48
  img = img.resize((384, 384))
49
  inputs = image_processor(images=img, return_tensors="pt")
50
  pixel_values = inputs['pixel_values'].to(device)
51
 
52
- # Make prediction
53
  with torch.no_grad():
54
  output = model(pixel_values)
55
  probability = torch.sigmoid(output).item()
56
 
57
- # Prepare results
58
  ms_prob = probability
59
  non_ms_prob = 1 - probability
60
 
61
- # Create the bar chart using Plotly
62
  fig = go.Figure(data=[
63
  go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
64
  go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red')
65
  ])
66
-
67
  fig.update_layout(
68
  title='Prediction Probabilities',
69
  yaxis_title='Probability (%)',
@@ -73,17 +65,36 @@ def predict(image):
73
 
74
  prediction = "MS" if ms_prob > 0.5 else "Non-MS"
75
  confidence = max(ms_prob, non_ms_prob) * 100
76
-
77
  result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"
78
 
79
  return result_text, fig
80
 
81
- iface = gr.Interface(
82
- fn=predict,
83
- inputs=gr.Image(),
84
- outputs=[gr.Textbox(), gr.Plot()],
85
- title="MS Prediction",
86
- description="Upload an MRI scan image to predict MS or Non-MS patient.",
87
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- iface.launch()
 
16
  def __init__(self, dropout_rate=0.5):
17
  super(CustomViTModel, self).__init__()
18
  self.vit = ViTForImageClassification.from_pretrained(model_name)
 
 
19
  num_features = self.vit.config.hidden_size
20
+ self.vit.classifier = nn.Identity()
 
21
  self.classifier = nn.Sequential(
22
  nn.Dropout(dropout_rate),
23
  nn.Linear(num_features, 128),
24
  nn.ReLU(),
25
  nn.Dropout(dropout_rate),
26
+ nn.Linear(128, 1)
27
  )
28
 
29
  def forward(self, pixel_values):
 
40
  model.eval()
41
 
42
  def predict(image):
 
43
  img = Image.fromarray(image.astype('uint8'), 'RGB')
44
  img = img.resize((384, 384))
45
  inputs = image_processor(images=img, return_tensors="pt")
46
  pixel_values = inputs['pixel_values'].to(device)
47
 
 
48
  with torch.no_grad():
49
  output = model(pixel_values)
50
  probability = torch.sigmoid(output).item()
51
 
 
52
  ms_prob = probability
53
  non_ms_prob = 1 - probability
54
 
 
55
  fig = go.Figure(data=[
56
  go.Bar(name='Non-MS', x=['Non-MS'], y=[non_ms_prob * 100], marker_color='blue'),
57
  go.Bar(name='MS', x=['MS'], y=[ms_prob * 100], marker_color='red')
58
  ])
 
59
  fig.update_layout(
60
  title='Prediction Probabilities',
61
  yaxis_title='Probability (%)',
 
65
 
66
  prediction = "MS" if ms_prob > 0.5 else "Non-MS"
67
  confidence = max(ms_prob, non_ms_prob) * 100
 
68
  result_text = f"Prediction: {prediction}\nConfidence: {confidence:.2f}%"
69
 
70
  return result_text, fig
71
 
72
+ def load_readme():
73
+ try:
74
+ with open('README.md', 'r') as file:
75
+ return file.read()
76
+ except FileNotFoundError:
77
+ return "README.md file not found. Please make sure it exists in the same directory as this script."
78
+
79
+ with gr.Blocks() as demo:
80
+ gr.Markdown("# MS Prediction App")
81
+
82
+ with gr.Tabs():
83
+ with gr.TabItem("Prediction"):
84
+ gr.Markdown("## Upload an MRI scan image to predict MS or Non-MS patient.")
85
+ with gr.Row():
86
+ input_image = gr.Image()
87
+ output_text = gr.Textbox()
88
+ output_plot = gr.Plot()
89
+ predict_button = gr.Button("Predict")
90
+
91
+ with gr.TabItem("Description"):
92
+ readme_content = gr.Markdown(load_readme())
93
+
94
+ predict_button.click(
95
+ predict,
96
+ inputs=input_image,
97
+ outputs=[output_text, output_plot]
98
+ )
99
 
100
+ demo.launch()