Update app.py
Browse files
app.py
CHANGED
@@ -1,35 +1,35 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
|
3 |
-
import tensorflow as tf
|
4 |
-
|
5 |
-
# Load the pre-trained model and tokenizer
|
6 |
-
model_path = '
|
7 |
-
loaded_model = TFDistilBertForSequenceClassification.from_pretrained(model_path)
|
8 |
-
loaded_tokenizer = DistilBertTokenizer.from_pretrained(model_path)
|
9 |
-
|
10 |
-
# Define the prediction function
|
11 |
-
def predict_with_loaded_model(in_sentences):
|
12 |
-
labels = ["non-stress", "stress"]
|
13 |
-
inputs = loaded_tokenizer(in_sentences, return_tensors="tf", padding=True, truncation=True, max_length=512)
|
14 |
-
predictions = loaded_model(inputs)
|
15 |
-
predicted_labels = tf.argmax(predictions.logits, axis=-1).numpy()
|
16 |
-
predicted_probs = tf.nn.softmax(predictions.logits, axis=-1).numpy()
|
17 |
-
|
18 |
-
return [{"text": sentence, "confidence": probs.tolist(), "label": labels[label]} for sentence, label, probs in zip(in_sentences, predicted_labels, predicted_probs)]
|
19 |
-
|
20 |
-
# Streamlit interface
|
21 |
-
st.title("Stress Prediction with DistilBERT")
|
22 |
-
|
23 |
-
# Add a text input box for the user to enter a sentence
|
24 |
-
user_input = st.text_area("Enter a sentence or text:", "")
|
25 |
-
|
26 |
-
# When the user clicks "Predict", run the prediction function
|
27 |
-
if st.button("Predict"):
|
28 |
-
if user_input:
|
29 |
-
# Make the prediction using the model
|
30 |
-
prediction = predict_with_loaded_model([user_input])[0]
|
31 |
-
st.write(f"Text: {prediction['text']}")
|
32 |
-
st.write(f"Prediction: {prediction['label']}")
|
33 |
-
st.write(f"Confidence: {prediction['confidence']}")
|
34 |
-
else:
|
35 |
-
st.write("Please enter a sentence to predict.")
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from transformers import DistilBertTokenizer, TFDistilBertForSequenceClassification
|
3 |
+
import tensorflow as tf
|
4 |
+
|
5 |
+
# Load the pre-trained model and tokenizer
|
6 |
+
model_path = 'shukdevdatta123/Stress_Prediction_DistillBert/'
|
7 |
+
loaded_model = TFDistilBertForSequenceClassification.from_pretrained(model_path)
|
8 |
+
loaded_tokenizer = DistilBertTokenizer.from_pretrained(model_path)
|
9 |
+
|
10 |
+
# Define the prediction function
|
11 |
+
def predict_with_loaded_model(in_sentences):
|
12 |
+
labels = ["non-stress", "stress"]
|
13 |
+
inputs = loaded_tokenizer(in_sentences, return_tensors="tf", padding=True, truncation=True, max_length=512)
|
14 |
+
predictions = loaded_model(inputs)
|
15 |
+
predicted_labels = tf.argmax(predictions.logits, axis=-1).numpy()
|
16 |
+
predicted_probs = tf.nn.softmax(predictions.logits, axis=-1).numpy()
|
17 |
+
|
18 |
+
return [{"text": sentence, "confidence": probs.tolist(), "label": labels[label]} for sentence, label, probs in zip(in_sentences, predicted_labels, predicted_probs)]
|
19 |
+
|
20 |
+
# Streamlit interface
|
21 |
+
st.title("Stress Prediction with DistilBERT")
|
22 |
+
|
23 |
+
# Add a text input box for the user to enter a sentence
|
24 |
+
user_input = st.text_area("Enter a sentence or text:", "")
|
25 |
+
|
26 |
+
# When the user clicks "Predict", run the prediction function
|
27 |
+
if st.button("Predict"):
|
28 |
+
if user_input:
|
29 |
+
# Make the prediction using the model
|
30 |
+
prediction = predict_with_loaded_model([user_input])[0]
|
31 |
+
st.write(f"Text: {prediction['text']}")
|
32 |
+
st.write(f"Prediction: {prediction['label']}")
|
33 |
+
st.write(f"Confidence: {prediction['confidence']}")
|
34 |
+
else:
|
35 |
+
st.write("Please enter a sentence to predict.")
|