"""Streamlit app predicting total corneal astigmatism from ACA and patient age.

The ACA magnitude/axis entered by the user is converted to a (J0, J45) pair,
each component is re-estimated by its own small regression network, and the
predicted pair is converted back to a magnitude and axis for display.
"""
import math

import streamlit as st
import torch
import torch.nn as nn

# Set page config at the very beginning (must be the first Streamlit call).
st.set_page_config(page_title='Total Corneal Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')

# Custom CSS to set background color to #000 for main app and navigation.
# NOTE(review): the original CSS payload was empty in this copy of the file;
# this reconstructs the intent stated above (plus light label text so the
# labels stay readable on black) — confirm against the intended design.
st.markdown(
    """
    <style>
    .stApp, section[data-testid="stSidebar"] { background-color: #000; }
    .input-label { color: #fff; margin-bottom: 0; }
    </style>
    """,
    unsafe_allow_html=True,
)


class RegressionModel2(nn.Module):
    """Two-layer MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    Attribute names (fc1, relu1, fc2, batch_norm1) must not change: they are
    the keys of the state dicts loaded by ``load_models``.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Map a (batch, input_dim2) tensor to (batch, output_dim2)."""
        out = self.fc1(x2)
        out = self.relu1(out)
        # BatchNorm is applied after the activation; keep this order, it must
        # match how the saved weights were trained.
        out = self.batch_norm1(out)
        out = self.fc2(out)
        return out


def _load_regressor(path):
    """Build a RegressionModel2(3, 32, 1), load weights from *path*, set eval mode."""
    model = RegressionModel2(3, 32, 1)
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
    # hosts; weights_only=True refuses to unpickle arbitrary objects.
    state_dict = torch.load(path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict)
    # eval() makes BatchNorm use its running statistics, which is required
    # for single-sample (batch size 1) inference.
    model.eval()
    return model


@st.cache_resource
def load_models(j0_path='j0_model-2.pt', j45_path='j45_model-2.pt'):
    """Load the J0 and J45 regressors once per server process.

    Args:
        j0_path: Path to the J0 model's state dict.
        j45_path: Path to the J45 model's state dict.

    Returns:
        Tuple ``(model_j0, model_j45)`` in eval mode.
    """
    return _load_regressor(j0_path), _load_regressor(j45_path)


model_j0, model_j45 = load_models()


def calculate_initial_j0_j45(magnitude, axis_deg):
    """Convert an astigmatism magnitude and axis (degrees) to (J0, J45).

    NOTE: this uses ``magnitude * cos(2*axis)`` / ``magnitude * sin(2*axis)``
    without the -1/2 factor of the clinical power-vector convention. The
    models were presumably trained on this exact form, and
    ``j0_j45_to_magnitude_axis`` inverts it, so do not change one without
    the other.
    """
    axis_rad = math.radians(axis_deg)
    j0 = magnitude * math.cos(2 * axis_rad)
    j45 = magnitude * math.sin(2 * axis_rad)
    return j0, j45


def j0_j45_to_magnitude_axis(j0, j45):
    """Invert ``calculate_initial_j0_j45``.

    Returns:
        Tuple ``(magnitude, axis_deg)`` where ``axis_deg`` lies in [0, 180].
    """
    magnitude = math.hypot(j0, j45)
    # atan2 gives (-180, 180]; halving maps it to (-90, 90], then shift
    # negative axes into the 0-180 range used by the inputs.
    axis_deg = 0.5 * math.degrees(math.atan2(j45, j0))
    if axis_deg < 0:
        axis_deg += 180
    return magnitude, axis_deg


def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
    """Predict the total-corneal (J0, J45) from age and ACA magnitude/axis.

    Each model receives a 1x3 float32 tensor ``[age, axis_deg, component]``,
    where ``component`` is the initial J0 (resp. J45) of the ACA.

    Returns:
        Tuple ``(new_j0, new_j45)`` of Python floats.
    """
    initial_j0, initial_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)
    input_data_j0 = torch.tensor([[age, aca_axis_deg, initial_j0]], dtype=torch.float32)
    input_data_j45 = torch.tensor([[age, aca_axis_deg, initial_j45]], dtype=torch.float32)
    with torch.no_grad():
        new_j0 = model_j0(input_data_j0).item()
        new_j45 = model_j45(input_data_j45).item()
    return new_j0, new_j45


def _labeled_number_input(label, *, key, min_value, max_value, step):
    """Render an HTML label above an initially-empty number input; return its value."""
    st.markdown(f'<p class="input-label">{label}</p>', unsafe_allow_html=True)
    # A non-empty label avoids Streamlit's empty-label warning; it is hidden
    # because the styled markdown above already displays it. value=None starts
    # the field empty, and the widget key keeps the entry across reruns
    # without seeding st.session_state by hand (doing both triggers a warning).
    return st.number_input(
        label,
        min_value=min_value,
        max_value=max_value,
        value=None,
        step=step,
        key=key,
        label_visibility='collapsed',
    )


def main():
    """Render the input form and, when "Predict!" is pressed, show the prediction."""
    age = _labeled_number_input(
        'Enter Patient Age (18-90 Years):',
        key='age', min_value=18.0, max_value=90.0, step=0.1,
    )
    aca_magnitude = _labeled_number_input(
        'Enter ACA Magnitude (0-10 Diopters):',
        key='aca_magnitude', min_value=0.0, max_value=10.0, step=0.01,
    )
    aca_axis = _labeled_number_input(
        'Enter ACA Axis (0-180 Degrees):',
        key='aca_axis', min_value=0.0, max_value=180.0, step=0.1,
    )

    if not st.button('Predict!'):
        return

    if age is None or aca_magnitude is None or aca_axis is None:
        st.error('Please fill in all input fields before predicting.')
        return
    # number_input already enforces these bounds; re-checked defensively.
    if not (18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180):
        st.error('Please ensure all inputs are within the specified ranges.')
        return

    new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
    predicted_magnitude, predicted_axis = j0_j45_to_magnitude_axis(new_j0, new_j45)

    # Display results in green success boxes.
    st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
    st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')


if __name__ == '__main__':
    main()