|
import streamlit as st |
|
import pandas as pd |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from sklearn.metrics import r2_score |
|
|
|
class RegressionModel2(nn.Module):
    """Two-layer MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    The attribute names ``fc1``, ``relu1``, ``fc2`` and ``batch_norm1`` become
    the state_dict keys, so they must match any checkpoint loaded into this
    model.

    Args:
        input_dim2: number of input features fed to ``fc1``.
        hidden_dim2: width of the hidden layer (also BatchNorm1d's feature count).
        output_dim2: number of values produced by ``fc2``.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Submodules are registered in the order fc1, relu1, fc2, batch_norm1;
        # keeping this order keeps random initialisation consuming the RNG in
        # the same sequence.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Run one forward pass and return the fc2 output.

        Intended for 2-D input of shape (batch, input_dim2): BatchNorm1d
        normalises dim 1, which only coincides with the hidden features when
        the tensor is 2-D. Note that normalisation is applied *after* the ReLU.
        """
        activated = self.relu1(self.fc1(x2))
        normalized = self.batch_norm1(activated)
        return self.fc2(normalized)
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_model(path='model.pt'):
    """Build RegressionModel2(3, 32, 1), load weights from *path*, and return it in eval mode.

    Streamlit re-executes this script on every widget interaction, so an
    uncached module-level load would re-read the checkpoint from disk on each
    rerun. ``st.cache_resource`` keeps a single instance shared across reruns
    and sessions. This app only uses it for inference inside ``torch.no_grad``.
    ``show_spinner=False`` avoids rendering a spinner element before the page's
    ``st.set_page_config`` call.

    Args:
        path: location of the saved state_dict (defaults to 'model.pt').

    Returns:
        The loaded RegressionModel2, already switched to eval mode so
        BatchNorm uses its running statistics.
    """
    net = RegressionModel2(3, 32, 1)
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
    # host; the model itself lives on CPU either way.
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # (torch >= 1.13) if model.pt could ever come from an untrusted source.
    state_dict = torch.load(path, map_location='cpu')
    net.load_state_dict(state_dict)
    net.eval()
    return net


model = load_model()
|
|
|
|
|
def predict_astigmatism(age, axis, aca):
    """Return the model's prediction for a single patient as a Python float.

    The three scalars are packed, in the order ``(age, axis, aca)``, into a
    float32 tensor of shape (1, 3). That tensor goes through the module-level
    ``model`` with autograd disabled.

    Args:
        age: patient age.
        axis: axis input (``main`` passes the ACA axis here).
        aca: magnitude input (``main`` passes the ACA magnitude here).

    Returns:
        The single model output, extracted with ``Tensor.item()``.
    """
    features = [age, axis, aca]
    # Wrap in an outer list to form a batch of one row.
    batch = torch.tensor([features], dtype=torch.float32)
    with torch.no_grad():
        output = model(batch)
    return output.item()
|
|
|
def main():
    """Render the astigmatism-prediction Streamlit page.

    The function configures the page and injects CSS: two rules that hide
    specific Streamlit elements, plus several layout classes. It then collects
    age, ACA magnitude and ACA axis. When the "Predict!" button is pressed, it
    shows the model's prediction to four decimal places.
    """
    # Kept as the first Streamlit call on the page.
    st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')

    # These selectors target auto-generated emotion class names. They are tied
    # to the Streamlit version the app was built with and may silently stop
    # matching after an upgrade. NOTE(review): confirm which elements they hide.
    hide_rules = (
        '<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>',
        """<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""",
    )
    for rule in hide_rules:
        st.write(rule, unsafe_allow_html=True)

    # NOTE(review): none of these classes (.navbar, .logo, .menu, .text-content,
    # .button) are used by any HTML rendered in this function.
    page_css = """
    <style>
    .navbar {
        display: flex;
        justify-content: space-between;
        align-items: center;
        background-color: #f2f2f2;
        padding: 10px;
    }
    .logo img {
        height: 50px;
    }
    .menu {
        list-style-type: none;
        display: flex;
    }
    .menu li {
        margin-left: 20px;
    }
    .text-content {
        margin-top: 50px;
        text-align: center;
    }
    .button {
        margin-top: 20px;
        padding: 10px 20px;
        font-size: 16px;
    }
    </style>
    """
    st.markdown(page_css, unsafe_allow_html=True)

    # Widgets render in creation order: age, magnitude, axis.
    age = st.number_input('Enter Patient Age:', step=0.1)
    aca_magnitude = st.number_input('Enter ACA Magnitude:', step=0.1)
    aca_axis = st.number_input('Enter ACA Axis:', step=0.1)

    if not st.button('Predict!'):
        return

    # Keywords make the UI-field -> model-input mapping explicit.
    astigmatism = predict_astigmatism(age, axis=aca_axis, aca=aca_magnitude)
    st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')
|
|
|
# Launch the page only when this file runs as the top-level script.
if __name__ == "__main__":
    main()