# import streamlit as st | |
# import pandas as pd | |
# import torch | |
# import torch.nn as nn | |
# import torch.optim as optim | |
# from sklearn.metrics import r2_score | |
# class RegressionModel2(nn.Module): | |
# def __init__(self, input_dim2, hidden_dim2, output_dim2): | |
# super(RegressionModel2, self).__init__() | |
# self.fc1 = nn.Linear(input_dim2, hidden_dim2) | |
# self.relu1 = nn.ReLU() | |
# self.fc2 = nn.Linear(hidden_dim2, output_dim2) | |
# self.batch_norm1 = nn.BatchNorm1d(hidden_dim2) | |
# def forward(self, x2): | |
# out = self.fc1(x2) | |
# out = self.relu1(out) | |
# out = self.batch_norm1(out) | |
# out = self.fc2(out) | |
# return out | |
# # Load the saved model state dictionary | |
# model = RegressionModel2(3, 32, 1) | |
# model.load_state_dict(torch.load('model.pt')) | |
# model.eval() # Set the model to evaluation mode | |
# # Define a function to make predictions | |
# def predict_astigmatism(age, axis, aca): | |
# """ | |
# This function takes three arguments (age, axis, aca) as input, | |
# converts them to a tensor, makes a prediction using the loaded model, | |
# and returns the predicted value. | |
# """ | |
# # Prepare the input data | |
# data = torch.tensor([[age, axis, aca]], dtype=torch.float32) | |
# # Make prediction | |
# with torch.no_grad(): | |
# prediction = model(data) | |
# # Return the predicted value | |
# return prediction.item() | |
# def main(): | |
# st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide') | |
# st.write('<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>', unsafe_allow_html=True) | |
# st.write("""<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>""", unsafe_allow_html=True) | |
# st.markdown( | |
# """ | |
# <style> | |
# .navbar { | |
# display: flex; | |
# justify-content: space-between; | |
# align-items: center; | |
# background-color: #f2f2f2; | |
# padding: 10px; | |
# } | |
# .logo img { | |
# height: 50px; | |
# } | |
# .menu { | |
# list-style-type: none; | |
# display: flex; | |
# } | |
# .menu li { | |
# margin-left: 20px; | |
# } | |
# .text-content { | |
# margin-top: 50px; | |
# text-align: center; | |
# } | |
# .button { | |
# margin-top: 20px; | |
# padding: 10px 20px; | |
# font-size: 16px; | |
# } | |
# </style> | |
# """, | |
# unsafe_allow_html=True | |
# ) | |
# # st.markdown( | |
# # """ | |
# # <body> | |
# # <header> | |
# # <nav class="navbar"> | |
# # <div class="logo"><img src="iol.png" alt="Image description"></div> | |
# # <ul class="menu"> | |
# # <li><a href="#">Home</a></li> | |
# # <li><a href="#">About</a></li> | |
# # <li><a href="#">Contact</a></li> | |
# # </ul> | |
# # </nav> | |
# # <div class="text-content"> | |
# # <h2>Enter Variables</h2> | |
# # <br> | |
# # </div> | |
# # </header> | |
# # </body> | |
# # """, | |
# # unsafe_allow_html=True | |
# # ) | |
# age = st.number_input('Enter Patient Age:', step=0.1) | |
# aca_magnitude = st.number_input('Enter ACA Magnitude:', step=0.1) | |
# aca_axis = st.number_input('Enter ACA Axis:', step=0.1) | |
# if st.button('Predict!'): | |
# astigmatism = predict_astigmatism(age, aca_axis, aca_magnitude) | |
# st.success(f'Predicted Total Corneal Astigmatism: {astigmatism:.4f}') | |
# if __name__ == '__main__': | |
# main() | |
import math

import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import r2_score
class RegressionModel2(nn.Module):
    """
    Small MLP regressor: Linear -> ReLU -> BatchNorm1d -> Linear.

    Args:
        input_dim2: number of input features.
        hidden_dim2: width of the hidden layer (also the BatchNorm1d feature count).
        output_dim2: number of regression outputs.

    The submodule attribute names (fc1, relu1, fc2, batch_norm1) become the
    state_dict keys that ``load_state_dict`` matches against, so renaming them
    would break loading a saved checkpoint.
    """

    def __init__(self, input_dim2, hidden_dim2, output_dim2):
        super().__init__()
        # Keep this registration order: it fixes the order in which the
        # layers are randomly initialised and listed by state_dict().
        self.fc1 = nn.Linear(input_dim2, hidden_dim2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim2, output_dim2)
        self.batch_norm1 = nn.BatchNorm1d(hidden_dim2)

    def forward(self, x2):
        """
        Map a batch of feature rows to predictions.

        Presumably ``x2`` is a 2-D (batch, input_dim2) float tensor, since
        BatchNorm1d normalises over dim 1. TODO confirm for other shapes.
        Note that normalisation is applied *after* the ReLU.
        """
        return self.fc2(self.batch_norm1(self.relu1(self.fc1(x2))))
# Load the saved model state dictionary.
# map_location='cpu' remaps tensors that were saved from a CUDA device, so the
# checkpoint also loads on CPU-only hosts; predictions are made on CPU tensors.
# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so model.pt is re-read each time; consider st.cache_resource if that matters.
# NOTE(review): torch.load unpickles the file - only ever load a trusted model.pt.
model = RegressionModel2(3, 32, 1)
model.load_state_dict(torch.load('model.pt', map_location='cpu'))
model.eval()  # eval mode: BatchNorm1d uses running stats, so a batch of one sample works
def predict_astigmatism(age, axis, aca):
    """
    Predict the total corneal astigmatism magnitude for a single patient.

    The three inputs are packed, in the order (age, axis, aca), into one
    float32 row of shape (1, 3) and run through the module-level ``model``
    with autograd disabled. The feature order presumably has to match the
    order used when the model was trained. TODO confirm.

    Args:
        age: patient age.
        axis: ACA axis value as entered in the UI.
        aca: ACA magnitude.

    Returns:
        The model's single output value, unwrapped with ``Tensor.item()``.
    """
    sample = torch.tensor([[age, axis, aca]], dtype=torch.float32)  # shape (1, 3)
    with torch.no_grad():  # inference only - no autograd graph needed
        return model(sample).item()
def predict_axis(aca_magnitude, aca_axis):
    """
    Return the astigmatism axis in degrees, in [0, 180), after round-tripping
    (aca_magnitude, aca_axis) through its double-angle vector.

    The axis is doubled so that 0 and 180 degrees map to the same vector:
    X = M*cos(2*theta), Y = M*sin(2*theta). Halving ``atan2(Y, X)`` recovers
    the axis modulo 180. ``atan2`` already resolves the quadrant, so no
    per-quadrant offset is applied. The textbook "+90 when X < 0" rule only
    applies to the plain ``atan(Y / X)`` form; using it together with
    ``atan2`` rotates every axis between 45 and 135 degrees by 90.

    NOTE(review): for any non-zero magnitude the result equals
    ``aca_axis % 180``. The "predicted" total corneal axis is therefore the
    ACA axis itself. Confirm that this is the intended model.

    Args:
        aca_magnitude: ACA magnitude. A negative value flips the vector,
            which shifts the axis by 90 degrees.
        aca_axis: ACA axis in degrees.

    Returns:
        A float axis in [0, 180). With zero magnitude the vector has no
        direction, so the supplied axis (normalised to [0, 180)) is returned.
    """
    if aca_magnitude == 0:
        # A zero vector carries no axis information (atan2 of signed zeros
        # would give 0 or 90 degrees arbitrarily); echo the input axis instead.
        angle = float(aca_axis)
    else:
        aca_axis_rad = math.radians(aca_axis)
        # Double-angle vector components
        x = aca_magnitude * math.cos(2 * aca_axis_rad)
        y = aca_magnitude * math.sin(2 * aca_axis_rad)
        # atan2 lies in (-180, 180] degrees, so half of it lies in (-90, 90]
        angle = math.degrees(0.5 * math.atan2(y, x))
    predicted_axis = angle % 180.0
    # A tiny negative angle (e.g. from sin(2*pi) rounding to -2.4e-16) makes
    # the modulo round up to exactly 180.0; fold it back to 0.
    if predicted_axis >= 180.0:
        predicted_axis = 0.0
    return predicted_axis
def main():
    """
    Render the astigmatism prediction page.

    Shows three numeric inputs (patient age, ACA magnitude, ACA axis) and a
    "Predict!" button. When the button is clicked, it displays the magnitude
    predicted by ``predict_astigmatism`` and the axis computed by
    ``predict_axis``.
    """
    # Kept as the first Streamlit call; older Streamlit releases require it.
    st.set_page_config(page_title='Astigmatism Prediction', page_icon=':eyeglasses:', layout='wide')

    # Hide two page elements selected by generated emotion-cache class names.
    # NOTE(review): these hashed names are tied to one Streamlit release and
    # silently stop matching after an upgrade.
    hide_rules = (
        '<style>.st-emotion-cache-1dp5vir.ezrtsby1 { display: none; }</style>',
        '<style>.st-emotion-cache-czk5ss.e16jpq800 {display: none;}</style>',
    )
    for rule in hide_rules:
        st.write(rule, unsafe_allow_html=True)

    # Styles for navbar/logo/menu/text-content/button classes.
    # NOTE(review): nothing rendered by this function uses these classes -
    # looks like leftover from a removed HTML header; confirm before deleting.
    page_css = """
        <style>
        .navbar {
            display: flex;
            justify-content: space-between;
            align-items: center;
            background-color: #f2f2f2;
            padding: 10px;
        }
        .logo img {
            height: 50px;
        }
        .menu {
            list-style-type: none;
            display: flex;
        }
        .menu li {
            margin-left: 20px;
        }
        .text-content {
            margin-top: 50px;
            text-align: center;
        }
        .button {
            margin-top: 20px;
            padding: 10px 20px;
            font-size: 16px;
        }
        </style>
        """
    st.markdown(page_css, unsafe_allow_html=True)

    st.title('Total Corneal Astigmatism Prediction')

    age = st.number_input('Enter Patient Age:', min_value=0.0, step=0.1)
    aca_magnitude = st.number_input('Enter ACA Magnitude:', min_value=0.0, step=0.1)
    aca_axis = st.number_input('Enter ACA Axis:', min_value=0.0, max_value=180.0, step=0.1)

    # Nothing more to show until the user asks for a prediction.
    if not st.button('Predict!'):
        return

    # predict_astigmatism takes its features in the order (age, axis, aca).
    tca_magnitude = predict_astigmatism(age, aca_axis, aca_magnitude)
    tca_axis = predict_axis(aca_magnitude, aca_axis)
    st.success(f'Predicted Total Corneal Astigmatism Magnitude: {tca_magnitude:.4f} D')
    st.success(f'Predicted Total Corneal Astigmatism Axis: {tca_axis:.2f}°')


if __name__ == '__main__':
    main()