# import streamlit as st | |
# import pandas as pd | |
# import torch | |
# import torch.nn as nn | |
# import torch.optim as optim | |
# from sklearn.metrics import r2_score | |
# class RegressionModel2(nn.Module): | |
# def __init__(self, input_dim2, hidden_dim2, output_dim2): | |
# super(RegressionModel2, self).__init__() | |
# self.fc1 = nn.Linear(input_dim2, hidden_dim2) | |
# self.relu1 = nn.ReLU() | |
# self.fc2 = nn.Linear(hidden_dim2, output_dim2) | |
# self.batch_norm1 = nn.BatchNorm1d(hidden_dim2) | |
# def forward(self, x2): | |
# out = self.fc1(x2) | |
# out = self.relu1(out) | |
# out = self.batch_norm1(out) | |
# out = self.fc2(out) | |
# return out | |
# # Load the saved model state dictionary | |
# model = RegressionModel2(3, 32, 1) | |
# model.load_state_dict(torch.load('model.pt')) | |
# model.eval() # Set the model to evaluation mode | |
# # Define a function to make predictions | |
# def predict_astigmatism(age, axis, aca): | |
# """ | |
# This function takes three arguments (age, axis, aca) as input, | |
# converts them to a tensor, makes a prediction using the loaded model, | |
# and returns the predicted value. | |
# """ | |
# # Prepare the input data | |
# data = torch.tensor([[age, axis, aca]], dtype=torch.float32) | |
# # Make prediction | |
# with torch.no_grad(): | |
# prediction = model(data) | |
# # Return the predicted value | |
# return prediction.item() | |
# def main(): | |
# st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide') | |
# st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True) | |
# st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True) | |
# st.markdown( | |
# """ | |
# <style> | |
# .navbar { | |
# display: flex; | |
# justify-content: space-between; | |
# align-items: center; | |
# background-color: #f2f2f2; | |
# padding: 10px; | |
# } | |
# .logo img { | |
# height: 50px; | |
# } | |
# .menu { | |
# list-style-type: none; | |
# display: flex; | |
# } | |
# .menu li { | |
# margin-left: 20px; | |
# } | |
# .text-content { | |
# margin-top: 50px; | |
# text-align: center; | |
# } | |
# .button { | |
# margin-top: 20px; | |
# padding: 10px 20px; | |
# font-size: 16px; | |
# } | |
# </style> | |
# """, | |
# unsafe_allow_html=True | |
# ) | |
# # st.markdown( | |
# # """ | |
# # <body> | |
# # <header> | |
# # <nav class="navbar"> | |
# # <div class="logo"><img src="iol.png" alt="Image description"></div> | |
# # <ul class="menu"> | |
# # <li><a href="#">Home</a></li> | |
# # <li><a href="#">About</a></li> | |
# # <li><a href="#">Contact</a></li> | |
# # </ul> | |
# # </nav> | |
# # <div class="text-content"> | |
# # <h2>Enter Variables</h2> | |
# # <br> | |
# # </div> | |
# # </header> | |
# # </body> | |
# # """, | |
# # unsafe_allow_html=True | |
# # ) | |
# age = st.number_input('Enter Patient Age:', step=0.1) | |
# aca_magnitude = st.number_input('Enter ACA Magnitude:', step=0.1) | |
# aca_axis = st.number_input('Enter ACA Axis:', step=0.1) | |
# if st.button('Predict!'): | |
# astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude) | |
# st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}') | |
# if __name__ == '__main__': | |
# main() | |
# Accepted (inclusive) input ranges for a prediction. Shared by the inline
# warnings and the pre-prediction check so the two cannot drift apart.
AGE_RANGE = (18, 90)
ACA_MAGNITUDE_RANGE = (0, 10)
ACA_AXIS_RANGE = (0, 180)

# Page-wide CSS injected once by main(). Only ``.error`` is applied to markup
# main() renders (via _show_error); no element rendered by main() is
# explicitly given the navbar/logo/menu/text-content/button classes.
_PAGE_CSS = """
<style>
.navbar {
    display: flex;
    justify-content: space-between;
    align-items: center;
    background-color: #f2f2f2;
    padding: 10px;
}
.logo img {
    height: 50px;
}
.menu {
    list-style-type: none;
    display: flex;
}
.menu li {
    margin-left: 20px;
}
.text-content {
    margin-top: 50px;
    text-align: center;
}
.button {
    margin-top: 20px;
    padding: 10px 20px;
    font-size: 16px;
}
.error {
    color: red;
    font-weight: bold;
}
</style>
"""


def _range_error(label, value, bounds):
    """Return an error message if *value* is outside *bounds*, else ``None``.

    Args:
        label: Human-readable input name used in the message.
        value: The number entered by the user.
        bounds: Inclusive ``(low, high)`` pair.
    """
    low, high = bounds
    if low <= value <= high:
        return None
    return f'Error: {label} must be between {low} and {high}.'


def _show_error(message):
    """Render *message* with the page's ``.error`` style; no-op for ``None``."""
    if message:
        st.markdown(f'<p class="error">{message}</p>', unsafe_allow_html=True)


def main():
    """Render the astigmatism-prediction page and predict on demand.

    Collects patient age, ACA magnitude and ACA axis. Shows an inline error
    under any input outside its accepted range (``*_RANGE`` above). When
    "Predict!" is pressed, calls ``predict_astigmatism`` only if every input
    is in range; otherwise shows a summary error.

    NOTE(review): ``st`` (Streamlit) and ``predict_astigmatism`` are not
    defined by any active code visible in this file. The
    ``import streamlit as st`` line and the ``predict_astigmatism``
    definition near the top are commented out. They must be provided at
    module level for this page to run.
    """
    st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
    # Hide two elements by class name. These selectors look like hashed,
    # auto-generated class names and may not survive a Streamlit upgrade;
    # re-check them after upgrading.
    st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True)
    st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True)
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)

    # min_value/max_value are the widgets' hard limits. The *_RANGE constants
    # are the narrower ranges accepted for prediction.
    # number_input starts at min_value unless `value` is given. Age therefore
    # gets an explicit in-range default; otherwise the page would open showing
    # an age error for an input the user has not touched yet.
    age = st.number_input('Enter Patient Age:', min_value=0.0, max_value=120.0,
                          value=float(AGE_RANGE[0]), step=0.1)
    age_error = _range_error('Age', age, AGE_RANGE)
    _show_error(age_error)

    aca_magnitude = st.number_input('Enter ACA Magnitude:', min_value=0.0, max_value=20.0, step=0.1)
    magnitude_error = _range_error('ACA Magnitude', aca_magnitude, ACA_MAGNITUDE_RANGE)
    _show_error(magnitude_error)

    aca_axis = st.number_input('Enter ACA Axis:', min_value=0.0, max_value=180.0, step=0.1)
    axis_error = _range_error('ACA Axis', aca_axis, ACA_AXIS_RANGE)
    _show_error(axis_error)

    if st.button('Predict!'):
        if age_error or magnitude_error or axis_error:
            st.error('Please correct the input errors before predicting.')
        else:
            # NOTE(review): axis is passed before magnitude, matching the
            # commented-out predict_astigmatism(age, axis, aca) signature at
            # the top of this file. Confirm it matches the model's feature order.
            astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
            st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')
# Build the page only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()