|
import streamlit as st |
|
import pandas as pd |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
from sklearn.metrics import r2_score |
|
|
|
class RegressionModel2(nn.Module):
    """Small feed-forward regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    The attribute names (``fc1``, ``relu1``, ``fc2``, ``batch_norm1``) become
    the keys of the module's ``state_dict``; they must stay as they are so
    that previously saved checkpoints keep loading.

    Args:
        input_dim2: Number of input features per sample.
        hidden_dim2: Width of the hidden layer (also the BatchNorm1d size).
        output_dim2: Number of values produced per sample.
    """

    def __init__(self, input_dim2: int, hidden_dim2: int, output_dim2: int) -> None:
        super().__init__()
        # Registration order is kept as in the original definition so that
        # ``parameters()`` / ``state_dict()`` ordering does not change.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x2`` and return the output of ``fc2``.

        Note that batch normalisation is applied *after* the ReLU, i.e. on
        the activated hidden features, before the final linear layer.
        """
        out = self.fc1(x2)
        out = self.relu1(out)
        out = self.batch_norm1(out)
        out = self.fc2(out)
        return out
|
|
|
|
|
# Module-level model used by predict_astigmatism(): 3 input features,
# 32 hidden units, 1 output value.
model = RegressionModel2(3, 32, 1)

# map_location='cpu' lets a checkpoint that was saved from a CUDA device load
# on a CPU-only host; without it torch.load tries to restore the tensors onto
# the original device and fails when CUDA is unavailable.
# NOTE(review): torch.load unpickles the file, so 'model.pt' must come from a
# trusted source. If the deployed torch version supports it, consider passing
# weights_only=True as well.
model.load_state_dict(torch.load('model.pt', map_location='cpu'))

# Inference only: eval() makes BatchNorm1d use its stored running statistics,
# which is what allows predictions on a single-sample batch.
model.eval()
|
|
|
|
|
def predict_astigmatism(age, axis, aca):
    """Return the model's scalar prediction for one (age, axis, aca) triple.

    The three values are packed, in that order, into a single-row float32
    tensor and passed through the module-level ``model`` with gradient
    tracking disabled. The one-element result is returned as a Python float.
    """
    features = torch.tensor([[age, axis, aca]], dtype=torch.float32)

    # No gradients are needed for inference.
    with torch.no_grad():
        output = model(features)

    return output.item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Page styling injected through st.markdown. The literal is kept flush-left on
# purpose: Markdown treats a line indented by four or more spaces as a code
# block, which would print the CSS instead of applying it.
_PAGE_CSS = """
<style>
.navbar {
    display: flex;
    justify-content: space-between;
    align-items: center;
    background-color: #f2f2f2;
    padding: 10px;
}
.logo img {
    height: 50px;
}
.menu {
    list-style-type: none;
    display: flex;
}
.menu li {
    margin-left: 20px;
}
.text-content {
    margin-top: 50px;
    text-align: center;
}
.button {
    margin-top: 20px;
    padding: 10px 20px;
    font-size: 16px;
}
.error {
    color: red;
    font-weight: bold;
}
</style>
"""


def _bounded_number_input(label, state_key, lower, upper, field_name):
    """Render one numeric field, mirror it into session state, return its value.

    Args:
        label: Text shown next to the widget.
        state_key: ``st.session_state`` key holding the last seen value.
        lower: Smallest accepted value (also passed as ``min_value``).
        upper: Largest accepted value (also passed as ``max_value``).
        field_name: Human-readable name used in the inline error message.

    Returns:
        Whatever ``st.number_input`` returned; ``None`` while the field is
        still empty.
    """
    if state_key not in st.session_state:
        st.session_state[state_key] = None

    value = st.number_input(
        label,
        min_value=lower,
        max_value=upper,
        step=0.1,
        value=st.session_state[state_key],
    )
    if value != st.session_state[state_key]:
        st.session_state[state_key] = value

    # The widget already receives min_value/max_value; this explicit check is
    # kept as a second line of defence, as in the original code.
    if value is not None and not lower <= value <= upper:
        st.markdown(
            f'<p class="error">Error: {field_name} must be between '
            f'{lower:g} and {upper:g}.</p>',
            unsafe_allow_html=True,
        )
    return value


def main():
    """Build the Streamlit page: three numeric inputs and a Predict button.

    Collects patient age, ACA magnitude and ACA axis, validates them, and on
    a button press shows the value returned by ``predict_astigmatism``.
    """
    st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')

    # Hide two Streamlit-generated elements by their CSS class names.
    # NOTE(review): these class names look auto-generated and may change
    # between Streamlit releases - confirm after upgrades.
    st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
    st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)

    # Make sure every key exists before any widget is drawn.
    for state_key in ('age', 'aca_magnitude', 'aca_axis'):
        if state_key not in st.session_state:
            st.session_state[state_key] = None

    # The age label used to advertise "15-90" while the widget minimum, the
    # inline error and the final validation all use 18; the label now agrees.
    age = _bounded_number_input(
        'Enter Patient Age (18-90 Years):', 'age', 18.0, 90.0, 'Age'
    )
    aca_magnitude = _bounded_number_input(
        'Enter ACA Magnitude (0-10 Diopters):', 'aca_magnitude', 0.0, 10.0, 'ACA Magnitude'
    )
    aca_axis = _bounded_number_input(
        'Enter ACA Axis (0-180 Degrees):', 'aca_axis', 0.0, 180.0, 'ACA Axis'
    )

    if not st.button('Predict!'):
        return

    if age is None or aca_magnitude is None or aca_axis is None:
        st.error('Please fill in all fields before predicting.')
        return

    in_range = 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180
    if not in_range:
        st.error('Please correct the input errors before predicting.')
        return

    # Argument order expected by predict_astigmatism is (age, axis, aca).
    astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
    st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')
|
|
|
# Only build the page when this file is executed as the entry script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()