File size: 7,333 Bytes
52192c2
 
 
 
 
0bf6916
55c1c2f
0bf6916
996f726
2d5dd80
 
 
 
 
996f726
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c835573
 
 
 
 
 
 
 
 
 
 
 
 
 
8a144c1
c835573
1852abe
 
 
2d5dd80
 
 
52192c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241be24
 
 
74ad4c3
241be24
52192c2
241be24
74ad4c3
241be24
 
 
 
 
52192c2
61229f1
 
e473f84
 
485bcab
 
61229f1
e473f84
0c0dd4a
e473f84
 
 
8f8e201
0c0dd4a
 
e473f84
8f8e201
 
 
55c1c2f
9df995a
6c01d2d
 
 
 
 
 
 
 
 
d4ced26
cf141c6
85a7fc8
743e01a
6c01d2d
743e01a
 
d4ced26
cf141c6
85a7fc8
743e01a
6c01d2d
743e01a
 
d4ced26
cf141c6
85a7fc8
743e01a
6c01d2d
743e01a
 
241be24
9df995a
6c01d2d
41dd67e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9df995a
41dd67e
8f8e201
52192c2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import streamlit as st
import torch
import torch.nn as nn
import math

# Page-level configuration (browser tab title, icon, wide layout).
# This is kept as the very first Streamlit call in the script; do not move
# other st.* calls above it.
st.set_page_config(page_title='Total Corneal Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')

# Inject a custom stylesheet into the page. The rules below:
#   * paint the app, header, decoration strip and toolbar black (#000),
#   * hide the deploy button and make the header menu buttons transparent,
#   * give text/number inputs a dark-gray (#333) background with white text,
#   * render widget labels in white at 24px,
#   * pull every number input up by 15px (margin-top: -15px).
# unsafe_allow_html=True is needed so the <style> tag is emitted as raw HTML.
st.markdown("""
<style>
    .stApp {
        background-color: #000;
    }
    
    .stDeployButton {
        display: none !important;
    }

    header[data-testid="stHeader"] {
        background-color: #000;
    }

    .stDecoration {
        background-color: #000 !important;
    }
    
    .stToolbar {
        background-color: #000 !important;
    }

    #MainMenu {
        background-color: #000 !important;
    }

    div[data-testid="stToolbar"] {
        background-color: #000 !important;
    }

    button[kind="headerNoPadding"], button[data-testid="baseButton-headerNoPadding"], button[aria-haspopup="menu"] {
        background-color: transparent !important;
    }

    .stApp > header {
        background-color: transparent !important;
    }




    .stTextInput > div > div > input {
        background-color: #333 !important;
        color: white !important;
    }
    .stNumberInput > div > div > input {
        background-color: #333 !important;
        color: white !important;
    }
    .stTextInput > label, .stNumberInput > label {
        color: white !important;
        font-size: 24px !important;
    }
    [data-testid="stNumberInput"] {
        margin-top: -15px;
    }
</style>
""", unsafe_allow_html=True)

class RegressionModel2(nn.Module):
    """Single-hidden-layer regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    Args:
        input_dim2: Number of input features per sample.
        hidden_dim2: Width of the hidden layer (also the BatchNorm1d size).
        output_dim2: Number of values produced per sample.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Attribute names and creation order are deliberately unchanged:
        # they determine the state_dict keys that saved checkpoints rely on.
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """Run a batch of feature rows through the network and return the output."""
        # Note the normalisation is applied *after* the activation.
        hidden = self.batch_norm1(self.relu1(self.fc1(x2)))
        return self.fc2(hidden)

@st.cache_resource
def load_models():
    """Build the J0 and J45 regression networks and load their saved weights.

    Decorated with ``st.cache_resource`` so the weight files are not re-read
    on every Streamlit rerun.

    Returns:
        tuple: ``(model_j0, model_j45)``, two ``RegressionModel2(3, 32, 1)``
        instances switched to evaluation mode.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when a weight file
        is missing or does not match the architecture.
    """
    def _load(weights_path):
        # One network per component; both share the same 3 -> 32 -> 1 layout.
        net = RegressionModel2(3, 32, 1)
        # map_location='cpu' lets a checkpoint that was saved from a GPU
        # load on a CPU-only host; inference below runs on CPU tensors.
        net.load_state_dict(torch.load(weights_path, map_location='cpu'))
        # Evaluation mode makes BatchNorm1d use its stored running statistics,
        # which is required because predictions are made one sample at a time.
        net.eval()
        return net

    return _load('j0_model-2.pt'), _load('j45_model-2.pt')


# Module-level handles used by predict_new_j0_j45.
model_j0, model_j45 = load_models()

def calculate_initial_j0_j45(magnitude, axis_deg):
    """Decompose an astigmatism given as magnitude and axis into J0 and J45.

    Args:
        magnitude: Astigmatism magnitude.
        axis_deg: Axis in degrees.

    Returns:
        tuple: ``(j0, j45)`` where ``j0 = magnitude * cos(2 * axis)`` and
        ``j45 = magnitude * sin(2 * axis)``, with the axis converted to radians.
    """
    # The axis is doubled before taking cos/sin.
    double_angle = 2 * math.radians(axis_deg)
    return magnitude * math.cos(double_angle), magnitude * math.sin(double_angle)

def predict_new_j0_j45(age, aca_magnitude, aca_axis_deg):
    """Predict the new J0 and J45 components with the two loaded models.

    Args:
        age: Patient age.
        aca_magnitude: Input astigmatism magnitude.
        aca_axis_deg: Input astigmatism axis in degrees.

    Returns:
        tuple: ``(new_j0, new_j45)`` as plain Python floats.
    """
    start_j0, start_j45 = calculate_initial_j0_j45(aca_magnitude, aca_axis_deg)

    def _features(component):
        # One-row batch: [age, axis in degrees, initial component].
        return torch.tensor([[age, aca_axis_deg, component]], dtype=torch.float32)

    j0_features = _features(start_j0)
    j45_features = _features(start_j45)

    # Pure inference: no gradients are needed.
    with torch.no_grad():
        predicted_j0 = model_j0(j0_features).item()
        predicted_j45 = model_j45(j45_features).item()

    return predicted_j0, predicted_j45

def _to_magnitude_and_axis(j0, j45):
    """Convert J0/J45 components to ``(magnitude, axis_deg)`` with axis in [0, 180)."""
    magnitude = math.hypot(j0, j45)
    # atan2 yields twice the axis, so halve it after converting to degrees.
    axis_deg = 0.5 * math.degrees(math.atan2(j45, j0))
    if axis_deg < 0:
        # Fold negative half-angles into the 0-180 range.
        axis_deg += 180
    return magnitude, axis_deg


def _prompted_number_input(prompt, key, **limits):
    """Render a styled prompt paragraph followed by an initially empty number input."""
    st.markdown(
        f'<p style="font-size: 20px; color: white; margin-bottom: 0px;">{prompt}</p>',
        unsafe_allow_html=True,
    )
    # The widget gets a real label (empty labels are discouraged for
    # accessibility) which is hidden because the paragraph above is the
    # visible prompt. 'hidden' keeps the label's space, so the layout matches
    # the previous empty-label rendering. `key` alone keeps the value across
    # reruns, so no manual session-state seeding is needed.
    return st.number_input(prompt, value=None, key=key,
                           label_visibility='hidden', **limits)


def main():
    """Render the input form and, on 'Predict!', show the predicted astigmatism.

    Collects age, ACA magnitude and ACA axis, runs them through
    ``predict_new_j0_j45`` and displays the predicted magnitude (D) and
    axis (degrees). Incomplete or out-of-range input produces an error
    message instead of a prediction.
    """
    age = _prompted_number_input('Enter Patient Age (18-90 Years):', 'age',
                                 min_value=18.0, max_value=90.0, step=0.1)
    aca_magnitude = _prompted_number_input('Enter ACA Magnitude (0-10 Diopters):', 'aca_magnitude',
                                           min_value=0.0, max_value=10.0, step=0.01)
    aca_axis = _prompted_number_input('Enter ACA Axis (0-180 Degrees):', 'aca_axis',
                                      min_value=0.0, max_value=180.0, step=0.1)

    if not st.button('Predict!'):
        return

    # The inputs start out empty (None) until the user types a value.
    if age is None or aca_magnitude is None or aca_axis is None:
        st.error('Please fill in all input fields before predicting.')
        return

    # Defensive re-check of the ranges the widgets are configured with.
    if not (18 <= age <= 90 and 0 <= aca_magnitude <= 10 and 0 <= aca_axis <= 180):
        st.error('Please ensure all inputs are within the specified ranges.')
        return

    # Predict new J0 and J45 using the models, then convert back to polar form.
    new_j0, new_j45 = predict_new_j0_j45(age, aca_magnitude, aca_axis)
    predicted_magnitude, predicted_axis = _to_magnitude_and_axis(new_j0, new_j45)

    # Display results in green success boxes.
    st.success(f'Predicted Total Corneal Astigmatism Magnitude: {predicted_magnitude:.2f} D')
    st.success(f'Predicted Total Corneal Astigmatism Axis: {predicted_axis:.1f}°')


if __name__ == '__main__':
    main()