Spaces:
Sleeping
Sleeping
File size: 7,333 Bytes
52192c2 0bf6916 55c1c2f 0bf6916 996f726 2d5dd80 996f726 c835573 8a144c1 c835573 1852abe 2d5dd80 52192c2 241be24 74ad4c3 241be24 52192c2 241be24 74ad4c3 241be24 52192c2 61229f1 e473f84 485bcab 61229f1 e473f84 0c0dd4a e473f84 8f8e201 0c0dd4a e473f84 8f8e201 55c1c2f 9df995a 6c01d2d d4ced26 cf141c6 85a7fc8 743e01a 6c01d2d 743e01a d4ced26 cf141c6 85a7fc8 743e01a 6c01d2d 743e01a d4ced26 cf141c6 85a7fc8 743e01a 6c01d2d 743e01a 241be24 9df995a 6c01d2d 41dd67e 9df995a 41dd67e 8f8e201 52192c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import streamlit as st
import torch
import torch.nn as nn
import math
# Set page config at the very beginning
# (browser-tab title, eyeglasses icon, and the full-width "wide" layout).
st.set_page_config(page_title='Total Corneal Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')
# Custom CSS to set background color to #000 for main app and navigation
# The stylesheet below also:
#   - hides the deploy button,
#   - makes the header menu buttons and the app header transparent,
#   - gives text/number inputs a #333 background with white text,
#   - renders widget labels in white at 24px,
#   - pulls every number input up by 15px.
# unsafe_allow_html=True is needed so the <style> tag is emitted as HTML
# instead of being escaped.
st.markdown("""
<style>
.stApp {
background-color: #000;
}
.stDeployButton {
display: none !important;
}
header[data-testid="stHeader"] {
background-color: #000;
}
.stDecoration {
background-color: #000 !important;
}
.stToolbar {
background-color: #000 !important;
}
#MainMenu {
background-color: #000 !important;
}
div[data-testid="stToolbar"] {
background-color: #000 !important;
}
button[kind="headerNoPadding"], button[data-testid="baseButton-headerNoPadding"], button[aria-haspopup="menu"] {
background-color: transparent !important;
}
.stApp > header {
background-color: transparent !important;
}
.stTextInput > div > div > input {
background-color: #333 !important;
color: white !important;
}
.stNumberInput > div > div > input {
background-color: #333 !important;
color: white !important;
}
.stTextInput > label, .stNumberInput > label {
color: white !important;
font-size: 24px !important;
}
[data-testid="stNumberInput"] {
margin-top: -15px;
}
</style>
""", unsafe_allow_html=True)
class RegressionModel2(nn.Module):
    """Single-hidden-layer regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    Args:
        input_dim2: number of input features per sample.
        hidden_dim2: width of the hidden layer (and of the batch-norm layer).
        output_dim2: number of regression outputs per sample.

    The attribute names (fc1, relu1, fc2, batch_norm1) are the keys of the
    module's state dict, so they must not be renamed if saved weights are to
    keep loading.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Registration order is kept as fc1, relu1, fc2, batch_norm1.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Map a (batch, input_dim2) tensor to a (batch, output_dim2) tensor."""
        # Normalisation is applied after the activation, not before it.
        activated = self.relu1(self.fc1(x2))
        normalised = self.batch_norm1(activated)
        return self.fc2(normalised)
@st.cache_resource
def load_models():
    """Build the J0 and J45 regressors and load their saved weights.

    Wrapped in ``st.cache_resource`` so the checkpoints are read once and the
    same model objects are reused across Streamlit reruns.

    Returns:
        tuple: ``(model_j0, model_j45)``, both ``RegressionModel2(3, 32, 1)``
        instances switched to eval mode.

    Raises:
        FileNotFoundError: if a checkpoint file is missing from the working
            directory.
        RuntimeError: if a checkpoint's keys/shapes do not match the model.
    """

    def _load(checkpoint_path):
        # Builds one regressor and restores its weights from checkpoint_path.
        model = RegressionModel2(3, 32, 1)
        # map_location='cpu' lets a checkpoint that was saved from a GPU load
        # on a CPU-only host; inference in this app builds CPU tensors anyway.
        # NOTE(review): torch.load unpickles the file -- only ship trusted
        # checkpoints (or pass weights_only=True where the installed torch
        # version supports it).
        state = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(state)
        # eval() makes BatchNorm1d use its running statistics, which is what
        # allows single-sample (batch size 1) inference.
        model.eval()
        return model

    return _load('j0_model-2.pt'), _load('j45_model-2.pt')
# Module-level model handles read by predict_new_j0_j45(); load_models() is
# decorated with st.cache_resource, so this call is cheap on script reruns.
model_j0, model_j45 = load_models()
def calculate_initial_j0_j45(magnitude, axis_deg):
    """Decompose an astigmatism into its double-angle J0/J45 components.

    Args:
        magnitude: astigmatism magnitude.
        axis_deg: astigmatism axis in degrees.

    Returns:
        tuple: ``(j0, j45)`` where ``j0 = magnitude * cos(2 * axis)`` and
        ``j45 = magnitude * sin(2 * axis)``.
    """
    # The axis is doubled so that a 180-degree period maps onto a full circle.
    doubled_axis = 2 * math.radians(axis_deg)
    return magnitude * math.cos(doubled_axis), magnitude * math.sin(doubled_axis)
def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
    """Predict the new J0 and J45 components with the two loaded models.

    Args:
        age: patient age.
        aca_magnitude: input astigmatism magnitude.
        aca_axis_deg: input astigmatism axis in degrees.

    Returns:
        tuple: ``(new_j0, new_j45)`` as Python floats.
    """
    start_j0, start_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)

    def _feature_row(component):
        # One sample of three features: age, raw axis in degrees, component.
        # NOTE(review): the axis is passed un-normalised, so 0 and 180 give
        # different feature rows -- presumably this matches how the models
        # were trained; confirm before changing.
        return torch.tensor([[age, aca_axis_deg, component]], dtype=torch.float32)

    features_j0 = _feature_row(start_j0)
    features_j45 = _feature_row(start_j45)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        predicted_j0 = model_j0(features_j0).item()
        predicted_j45 = model_j45(features_j45).item()
    return predicted_j0, predicted_j45
def _number_field(prompt, *, key, min_value, max_value, step):
    """Render a white 20px caption plus an initially empty number input.

    Returns the entered value, or None while the field is still blank.
    """
    st.markdown(
        f'<p style="font-size: 20px; color: white; margin-bottom: 0px;">{prompt}</p>',
        unsafe_allow_html=True,
    )
    # The widget keeps a real (non-empty) label for accessibility, but it is
    # collapsed because the styled caption above is the visible one.
    # key= makes Streamlit persist the value in st.session_state across
    # reruns, so no manual session-state bookkeeping is required.
    return st.number_input(
        prompt,
        min_value=min_value,
        max_value=max_value,
        value=None,
        step=step,
        key=key,
        label_visibility='collapsed',
    )


def _to_magnitude_axis(j0, j45):
    """Convert J0/J45 components back to (magnitude, axis in degrees)."""
    magnitude = math.hypot(j0, j45)
    # atan2 yields the doubled angle in (-180, 180]; halving gives (-90, 90].
    # The modulo folds negative axes into [0, 180] and turns -0.0 into 0.0 so
    # the axis is never displayed as "-0.0".
    axis = (0.5 * math.degrees(math.atan2(j45, j0))) % 180.0
    return magnitude, axis


def main():
    """Render the input form and, when requested, show the prediction.

    Collects patient age, ACA magnitude and ACA axis, and on a click of the
    'Predict!' button validates them, runs ``predict_new_j0_j45`` and displays
    the predicted total corneal astigmatism magnitude and axis. Validation
    failures are reported with ``st.error``; nothing is returned.
    """
    age = _number_field(
        'Enter Patient Age (18-90 Years):',
        key='age', min_value=18.0, max_value=90.0, step=0.1,
    )
    aca_magnitude = _number_field(
        'Enter ACA Magnitude (0-10 Diopters):',
        key='aca_magnitude', min_value=0.0, max_value=10.0, step=0.01,
    )
    aca_axis = _number_field(
        'Enter ACA Axis (0-180 Degrees):',
        key='aca_axis', min_value=0.0, max_value=180.0, step=0.1,
    )

    if not st.button('Predict!'):
        return

    if age is None or aca_magnitude is None or aca_axis is None:
        st.error('Please fill in all input fields before predicting.')
        return

    # The widgets already enforce these bounds; this is a defensive re-check
    # in case the widget limits are ever changed or removed.
    in_range = 18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180
    if not in_range:
        st.error('Please ensure all inputs are within the specified ranges.')
        return

    new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
    predicted_magnitude, predicted_axis = _to_magnitude_axis(new_j0, new_j45)

    # Display results in green success boxes
    st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
    st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')
# Script entry point: build the UI when the file is executed as a script.
if __name__ == '__main__':
    main()