Jfink09 commited on
Commit
db7295e
·
verified ·
1 Parent(s): 5df8bc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -35
app.py CHANGED
@@ -1,44 +1,52 @@
 
1
  import torch
2
- from transformers import AutoModelForSequenceClassification
3
-
4
- # Update this variable with your model name from Hugging Face Hub
5
- MODEL_NAME = "model"
6
-
7
- # Load the model (no tokenizer needed)
8
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
9
-
10
- # Function to make predictions (replace with your actual prediction logic)
11
- def predict(text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  """
13
- This function takes a text input, preprocesses it using the tokenizer,
14
- makes a prediction using the loaded model, and returns the predicted output.
15
- **Replace this function with your actual prediction logic.**
16
  """
17
- # Modify this part to handle input processing if your model requires it
18
- # (assuming your model doesn't need a tokenizer in this example)
19
- inputs = text # Replace with preprocessing steps if necessary
20
 
 
21
  with torch.no_grad():
22
- outputs = model(**inputs)
23
- predictions = torch.argmax(outputs.logits, dim=-1)
24
- return predictions.item()
25
 
26
- # Function to handle user input and make predictions (modify for your UI framework)
27
- def handle_request(data):
28
- """
29
- This function takes user input data (modify based on your UI framework),
30
- extracts the relevant text, and calls the predict function to make a prediction.
31
- """
32
- text = data["text"] # Assuming "text" is the key in your data dictionary
33
- prediction = predict(text)
34
- return {"prediction": prediction}
35
 
36
- if __name__ == "__main__":
37
- from fastapi import FastAPI
 
38
 
39
- app = FastAPI()
 
 
40
 
41
- @app.post("/predict")
42
- async def predict_from_text(data: dict):
43
- response = handle_request(data)
44
- return response
 
1
+ import streamlit as st
2
  import torch
3
+ import torch.nn as nn
4
+
5
+ # Define your model architecture here (same as before)
6
+ class RegressionModel2(nn.Module):
7
+ def __init__(self, input_dim2, hidden_dim2, output_dim2):
8
+ super(RegressionModel2, self).__init__()
9
+ self.fc1 = nn.Linear(input_dim2, hidden_dim2)
10
+ self.relu1 = nn.ReLU() # ReLU activation function
11
+ self.fc2 = nn.Linear(hidden_dim2, output_dim2)
12
+ self.batch_norm1 = nn.BatchNorm1d(hidden_dim2) # Batch normalization
13
+
14
+ def forward(self, x2):
15
+ out = self.fc1(x2)
16
+ out = self.relu1(out)
17
+ out = self.batch_norm1(out)
18
+ out = self.fc2(out)
19
+ return out
20
+
21
+ # Load your saved model state dictionary (assuming 'model.pt' is uploaded)
22
+ model2 = RegressionModel2(input_dim2, hidden_dim2, output_dim2)
23
+ model2.load_state_dict(torch.load('model.pt'))
24
+ model2.eval() # Set the model to evaluation mode
25
+
26
+ def predict(age, aca, axis):
27
  """
28
+ This function takes three arguments (age, axis, aca) as input,
29
+ prepares the data, makes a prediction using the loaded model,
30
+ and returns the predicted value.
31
  """
32
+ # Prepare the input data
33
+ data = torch.tensor([[age, aca, axis]], dtype=torch.float32)
 
34
 
35
+ # Make prediction
36
  with torch.no_grad():
37
+ prediction = model2(data)
 
 
38
 
39
+ # Return the predicted value
40
+ return prediction.item()
 
 
 
 
 
 
 
41
 
42
+ # Streamlit App
43
+ st.title("Astigmatism Prediction App")
44
+ st.write("Enter the patient's information:")
45
 
46
+ age = st.number_input("Age", min_value=0)
47
+ aca = st.number_input("ACA Magnitude", min_value=0)
48
+ axis = st.number_input("ACA Axis", min_value=0)
49
 
50
+ if st.button("Predict"):
51
+ predicted_value = predict(age, aca, axis)
52
+ st.write(f"Predicted Astigmatism Value: {predicted_value}")