Jfink09 commited on
Commit
e9fb974
·
verified ·
1 Parent(s): dccf1b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -38
app.py CHANGED
@@ -1,15 +1,17 @@
1
  import streamlit as st
 
2
  import torch
3
  import torch.nn as nn
 
 
4
 
5
- # Define your model architecture here (same as before)
6
  class RegressionModel2(nn.Module):
7
  def __init__(self, input_dim2, hidden_dim2, output_dim2):
8
  super(RegressionModel2, self).__init__()
9
  self.fc1 = nn.Linear(input_dim2, hidden_dim2)
10
- self.relu1 = nn.ReLU() # ReLU activation function
11
  self.fc2 = nn.Linear(hidden_dim2, output_dim2)
12
- self.batch_norm1 = nn.BatchNorm1d(hidden_dim2) # Batch normalization
13
 
14
  def forward(self, x2):
15
  out = self.fc1(x2)
@@ -18,39 +20,81 @@ class RegressionModel2(nn.Module):
18
  out = self.fc2(out)
19
  return out
20
 
21
- input_dim2 = X2_train.shape[1] # change to [1] for pentacam dataset X_train.shape[1]
22
- hidden_dim2 = 32 # Was 16
23
- output_dim2 = 1
24
-
25
- # Load your saved model state dictionary (assuming 'model.pt' is uploaded)
26
- model2 = RegressionModel2(input_dim2, hidden_dim2, output_dim2)
27
  model2.load_state_dict(torch.load('model.pt'))
28
- model2.eval() # Set the model to evaluation mode
29
-
30
- def predict(age, aca, axis):
31
- """
32
- This function takes three arguments (age, axis, aca) as input,
33
- prepares the data, makes a prediction using the loaded model,
34
- and returns the predicted value.
35
- """
36
- # Prepare the input data
37
- data = torch.tensor([[age, aca, axis]], dtype=torch.float32)
38
-
39
- # Make prediction
40
- with torch.no_grad():
41
- prediction = model2(data)
42
-
43
- # Return the predicted value
44
- return prediction.item()
45
-
46
- # Streamlit App
47
- st.title("Astigmatism Prediction App")
48
- st.write("Enter the patient's information:")
49
-
50
- age = st.number_input("Age", min_value=0)
51
- aca = st.number_input("ACA Magnitude", min_value=0)
52
- axis = st.number_input("ACA Axis", min_value=0)
53
-
54
- if st.button("Predict"):
55
- predicted_value = predict(age, aca, axis)
56
- st.write(f"Predicted Astigmatism Value: {predicted_value}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
  import torch
4
  import torch.nn as nn
5
+ import torch.optim as optim
6
+ from sklearn.metrics import r2_score
7
 
 
8
  class RegressionModel2(nn.Module):
9
  def __init__(self, input_dim2, hidden_dim2, output_dim2):
10
  super(RegressionModel2, self).__init__()
11
  self.fc1 = nn.Linear(input_dim2, hidden_dim2)
12
+ self.relu1 = nn.ReLU()
13
  self.fc2 = nn.Linear(hidden_dim2, output_dim2)
14
+ self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)
15
 
16
  def forward(self, x2):
17
  out = self.fc1(x2)
 
20
  out = self.fc2(out)
21
  return out
22
 
23
+ # Load the trained model
24
+ model2 = RegressionModel2(3, 32, 1)
 
 
 
 
25
  model2.load_state_dict(torch.load('model.pt'))
26
+ model2.eval()
27
+
28
+ def predict_astigmatism(age, aca_magnitude, aca_axis):
29
+ input_data = torch.tensor([[age, aca_magnitude, aca_axis]], dtype=torch.float32)
30
+ output = model2(input_data)
31
+ return output.item()
32
+
33
+ def main():
34
+ st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
35
+ st.markdown(
36
+ """
37
+ <style>
38
+ .navbar {
39
+ display: flex;
40
+ justify-content: space-between;
41
+ align-items: center;
42
+ background-color: #f2f2f2;
43
+ padding: 10px;
44
+ }
45
+ .logo img {
46
+ height: 50px;
47
+ }
48
+ .menu {
49
+ list-style-type: none;
50
+ display: flex;
51
+ }
52
+ .menu li {
53
+ margin-left: 20px;
54
+ }
55
+ .text-content {
56
+ margin-top: 50px;
57
+ text-align: center;
58
+ }
59
+ .button {
60
+ margin-top: 20px;
61
+ padding: 10px 20px;
62
+ font-size: 16px;
63
+ }
64
+ </style>
65
+ """,
66
+ unsafe_allow_html=True
67
+ )
68
+
69
+ st.markdown(
70
+ """
71
+ <body>
72
+ <header>
73
+ <nav class="navbar">
74
+ <div class="logo"><img src="iol.png" alt="Image description"></div>
75
+ <ul class="menu">
76
+ <li><a href="#">Home</a></li>
77
+ <li><a href="#">About</a></li>
78
+ <li><a href="#">Contact</a></li>
79
+ </ul>
80
+ </nav>
81
+ <div class="text-content">
82
+ <h2>Enter Variables</h2>
83
+ <br>
84
+ </div>
85
+ </header>
86
+ </body>
87
+ """,
88
+ unsafe_allow_html=True
89
+ )
90
+
91
+ age = st.number_input('Enter Patient Age:', min_value=0, step=1)
92
+ aca_magnitude = st.number_input('Enter ACA Magnitude:', min_value=0.0, step=0.1)
93
+ aca_axis = st.number_input('Enter ACA Axis:', min_value=0, max_value=180, step=1)
94
+
95
+ if st.button('Predict!'):
96
+ astigmatism = predict_astigmatism(age, aca_magnitude, aca_axis)
97
+ st.success(f'Predicted Astigmatism: {astigmatism:.4f}')
98
+
99
+ if __name__ == '__main__':
100
+ main()