Jfink09's picture
Update app.py
d4ced26 verified
import streamlit as st
import torch
import torch.nn as nn
import math
# Page configuration is kept ahead of every other st.* call, as older
# Streamlit releases require set_page_config to come first.
st.set_page_config(
    page_title='Total Corneal Astigmatism Prediction',
    page_icon=':eyeglasses:',
    layout='wide',
)

# Dark-theme overrides injected as raw HTML:
# - black backgrounds for the app, header, decoration and toolbar;
# - the deploy button hidden;
# - header menu buttons and the app header made transparent;
# - dark grey text/number inputs with white text;
# - 24px white widget labels;
# - a -15px top margin on each number input.
_CUSTOM_CSS = """
<style>
.stApp {
background-color: #000;
}
.stDeployButton {
display: none !important;
}
header[data-testid="stHeader"] {
background-color: #000;
}
.stDecoration {
background-color: #000 !important;
}
.stToolbar {
background-color: #000 !important;
}
#MainMenu {
background-color: #000 !important;
}
div[data-testid="stToolbar"] {
background-color: #000 !important;
}
button[kind="headerNoPadding"], button[data-testid="baseButton-headerNoPadding"], button[aria-haspopup="menu"] {
background-color: transparent !important;
}
.stApp > header {
background-color: transparent !important;
}
.stTextInput > div > div > input {
background-color: #333 !important;
color: white !important;
}
.stNumberInput > div > div > input {
background-color: #333 !important;
color: white !important;
}
.stTextInput > label, .stNumberInput > label {
color: white !important;
font-size: 24px !important;
}
[data-testid="stNumberInput"] {
margin-top: -15px;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
class RegressionModel2(nn.Module):
    """Small MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    load_models() loads saved state dicts into this class, so the submodule
    attribute names (fc1, relu1, fc2, batch_norm1) are the checkpoint keys and
    must not be renamed.

    Args:
        input_dim2: Number of input features.
        hidden_dim2: Width of the hidden layer (also the BatchNorm1d size).
        output_dim2: Number of regression outputs.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Registration order (fc1, relu1, fc2, batch_norm1) matches the original,
        # so parameter order and random initialisation are unchanged.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Run the network on a batch of feature rows.

        Callers in this app pass a 2-D (batch, input_dim2) float tensor; the
        result has shape (batch, output_dim2).
        """
        # Note: normalisation is applied after the activation, not before it.
        activated = self.relu1(self.fc1(x2))
        return self.fc2(self.batch_norm1(activated))
def _load_regression_model(weights_path, input_dim=3, hidden_dim=32, output_dim=1):
    """Build a RegressionModel2, load a saved state dict into it, and set eval mode.

    Args:
        weights_path: Path to a checkpoint holding the model's state dict.
        input_dim, hidden_dim, output_dim: Architecture sizes; the defaults
            (3, 32, 1) are the ones both app checkpoints were loaded with.

    Returns:
        The model in eval mode, with its tensors on the CPU.
    """
    model = RegressionModel2(input_dim, hidden_dim, output_dim)
    # map_location='cpu' lets checkpoints that were saved from CUDA tensors
    # load on CPU-only hosts instead of raising at torch.load time.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    # eval() switches BatchNorm1d to its running statistics, which is needed
    # for the single-row batches used at prediction time.
    model.eval()
    return model


@st.cache_resource
def load_models():
    """Load the J0 and J45 regression models.

    Wrapped in st.cache_resource so Streamlit reruns reuse the same model
    objects instead of reading the checkpoints again.

    Returns:
        (model_j0, model_j45), both in eval mode on the CPU.
    """
    model_j0 = _load_regression_model('j0_model-2.pt')
    model_j45 = _load_regression_model('j45_model-2.pt')
    return model_j0, model_j45


# Module-level handles used by predict_new_j0_j45().
model_j0, model_j45 = load_models()
def calculate_initial_j0_j45(magnitude, axis_deg):
    """Split an astigmatism magnitude/axis pair into J0 and J45 components.

    Computes J0 = magnitude * cos(2 * axis) and J45 = magnitude * sin(2 * axis),
    with the axis given in degrees.

    NOTE(review): unlike the common Thibos power-vector definition, there is no
    -1/2 factor here; presumably this matches how the models' input features
    were built — confirm before changing.

    Args:
        magnitude: Astigmatism magnitude.
        axis_deg: Axis in degrees.

    Returns:
        (j0, j45) as floats.
    """
    doubled_angle = 2 * math.radians(axis_deg)
    return magnitude * math.cos(doubled_angle), magnitude * math.sin(doubled_angle)
def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
    """Predict total-corneal J0 and J45 from age and anterior corneal astigmatism.

    Each model is fed a single feature row [age, aca_axis_deg, component], where
    the component is the matching output of calculate_initial_j0_j45. Inference
    runs under torch.no_grad(), so no autograd graph is built.

    Args:
        age: Patient age.
        aca_magnitude: Anterior corneal astigmatism magnitude.
        aca_axis_deg: Anterior corneal astigmatism axis in degrees.

    Returns:
        (new_j0, new_j45) as Python floats.
    """
    initial_components = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)
    predictions = []
    with torch.no_grad():
        for model, component in zip((model_j0, model_j45), initial_components):
            features = torch.tensor([[age, aca_axis_deg, component]], dtype=torch.float32)
            predictions.append(model(features).item())
    return tuple(predictions)
def _number_field(prompt, key, min_value, max_value, step):
    """Render a styled prompt paragraph followed by an initially empty number input.

    Args:
        prompt: Text shown above the input; also used as the hidden widget label.
        key: Session-state key under which the widget keeps its value across reruns.
        min_value, max_value: Inclusive bounds passed to the widget.
        step: Stepping interval passed to the widget.

    Returns:
        The entered float, or None while the field is still empty.
    """
    st.markdown(
        f'<p style="font-size: 20px; color: white; margin-bottom: 0px;">{prompt}</p>',
        unsafe_allow_html=True,
    )
    # A real label is passed and then hidden, because Streamlit warns about
    # empty labels for accessibility reasons. 'hidden' keeps the label's
    # space, which is presumably closest to the old empty-label layout;
    # switch to 'collapsed' if the gap looks wrong.
    # value=None is a constant: the keyed widget already persists user input
    # across reruns, so it is not fed back from st.session_state. A changing
    # `value` argument can give the widget a new identity on rerun.
    return st.number_input(
        prompt,
        min_value=min_value,
        max_value=max_value,
        value=None,
        step=step,
        key=key,
        label_visibility='hidden',
    )


def _to_magnitude_axis(j0, j45):
    """Convert (J0, J45) back to (magnitude, axis in degrees within [0, 180)).

    This inverts the convention of calculate_initial_j0_j45:
    J0 = magnitude * cos(2 * axis), J45 = magnitude * sin(2 * axis).
    """
    magnitude = math.hypot(j0, j45)
    # atan2 yields [-180, 180] degrees; halving gives [-90, 90], and negative
    # values are shifted by 180 to land in [0, 180).
    axis = 0.5 * math.degrees(math.atan2(j45, j0))
    if axis < 0:
        axis += 180
    return magnitude, axis


def main():
    """Render the input form and, on "Predict!", show the predicted total corneal astigmatism."""
    age = _number_field('Enter Patient Age (18-90 Years):', 'age', 18.0, 90.0, 0.1)
    aca_magnitude = _number_field(
        'Enter ACA Magnitude (0-10 Diopters):', 'aca_magnitude', 0.0, 10.0, 0.01
    )
    aca_axis = _number_field('Enter ACA Axis (0-180 Degrees):', 'aca_axis', 0.0, 180.0, 0.1)

    if not st.button('Predict!'):
        return
    if age is None or aca_magnitude is None or aca_axis is None:
        st.error('Please fill in all input fields before predicting.')
        return
    # The widgets' min/max bounds should already enforce these ranges;
    # re-checked defensively before running the models.
    if not (18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180):
        st.error('Please ensure all inputs are within the specified ranges.')
        return

    new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
    predicted_magnitude, predicted_axis = _to_magnitude_axis(new_j0, new_j45)

    st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
    st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')


if __name__ == '__main__':
    main()