import streamlit as st
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import r2_score

# Path of the saved state dict for RegressionModel2(3, 32, 1).
MODEL_PATH = 'model.pt'


class RegressionModel2(nn.Module):
    """Small feed-forward regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    Args:
        input_dim2: Number of input features per sample.
        hidden_dim2: Width of the hidden layer (also the BatchNorm1d channel count).
        output_dim2: Number of regression outputs per sample.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Attribute names become the state_dict keys, so they must stay identical
        # to the names used when the checkpoint was saved.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Run the network on a 2-D (batch, input_dim2) tensor.

        Returns a (batch, output_dim2) tensor. BatchNorm is applied after the
        ReLU; this order matches the trained weights and must not be changed.
        """
        out = self.fc1(x2)
        out = self.relu1(out)
        out = self.batch_norm1(out)
        out = self.fc2(out)
        return out


@st.cache_resource(show_spinner=False)
def _load_model(path=MODEL_PATH):
    """Build RegressionModel2(3, 32, 1), load its weights from `path`, return it in eval mode.

    st.cache_resource makes Streamlit reuse one model object across reruns
    instead of re-reading the checkpoint on every widget interaction.
    show_spinner=False prevents the cache from emitting a UI element before
    st.set_page_config() is called in main().
    """
    net = RegressionModel2(3, 32, 1)
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
    # (consider weights_only=True if the installed torch version supports it).
    net.load_state_dict(torch.load(path, map_location='cpu'))
    # eval() makes BatchNorm1d use its running statistics; in training mode a
    # batch of one sample would raise an error.
    net.eval()
    return net


# Module-level handle used by predict_astigmatism().
model = _load_model()


def predict_astigmatism(age, axis, aca):
    """Predict total corneal astigmatism for a single sample.

    The inputs are packed as one row [age, axis, aca] of float32. This order
    must match the feature order used in training -- TODO confirm.

    Args:
        age: Patient age (first model feature).
        axis: Second model feature (the app passes the ACA axis here).
        aca: Third model feature (the app passes the ACA magnitude here).

    Returns:
        The model's single predicted value as a Python float.
    """
    features = torch.tensor([[age, axis, aca]], dtype=torch.float32)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        prediction = model(features)
    return prediction.item()
# Inclusive (low, high) bounds for each input. They are used both as the widget
# limits and for the final check before predicting, so the two cannot drift apart.
AGE_RANGE = (18.0, 90.0)
ACA_MAGNITUDE_RANGE = (0.0, 10.0)
ACA_AXIS_RANGE = (0.0, 180.0)


def _in_range(value, bounds):
    """Return True when value lies inside the inclusive (low, high) bounds."""
    low, high = bounds
    return low <= value <= high


def main():
    """Render the astigmatism prediction page and run the model when requested.

    The three number inputs are keyed widgets. Streamlit therefore keeps their
    values in st.session_state under 'age', 'aca_magnitude' and 'aca_axis'
    across reruns without any manual bookkeeping. value=None leaves each input
    empty until the user types a value, on Streamlit versions that support
    value=None.
    """
    st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
    # NOTE(review): these HTML/CSS payloads are empty in this copy of the file;
    # confirm whether their content was lost before relying on them.
    st.write('', unsafe_allow_html=True)
    st.write("""""", unsafe_allow_html=True)
    st.markdown(
        """ """,
        unsafe_allow_html=True,
    )

    # min_value/max_value already stop out-of-range entries in the widget,
    # so no separate per-field range errors are needed here.
    age = st.number_input(
        'Enter Patient Age (18-90 Years):',
        min_value=AGE_RANGE[0],
        max_value=AGE_RANGE[1],
        step=0.1,
        value=None,
        key='age',
    )
    aca_magnitude = st.number_input(
        'Enter ACA Magnitude (0-10 Diopters):',
        min_value=ACA_MAGNITUDE_RANGE[0],
        max_value=ACA_MAGNITUDE_RANGE[1],
        step=0.1,
        value=None,
        key='aca_magnitude',
    )
    aca_axis = st.number_input(
        'Enter ACA Axis (0-180 Degrees):',
        min_value=ACA_AXIS_RANGE[0],
        max_value=ACA_AXIS_RANGE[1],
        step=0.1,
        value=None,
        key='aca_axis',
    )

    if st.button('Predict!'):
        if age is None or aca_magnitude is None or aca_axis is None:
            st.error('Please fill in all fields before predicting.')
        elif not (
            _in_range(age, AGE_RANGE)
            and _in_range(aca_magnitude, ACA_MAGNITUDE_RANGE)
            and _in_range(aca_axis, ACA_AXIS_RANGE)
        ):
            # Defensive: the widget bounds should already prevent this.
            st.error('Please correct the input errors before predicting.')
        else:
            # predict_astigmatism takes (age, axis, aca), so the axis goes before the magnitude.
            astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude)
            st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}')


if __name__ == '__main__':
    main()