Update app.py
Browse files
app.py
CHANGED
@@ -61,7 +61,7 @@ if st.button('Load Model'):
|
|
61 |
|
62 |
# Forward pass, calculate logit predictions
|
63 |
with torch.no_grad():
|
64 |
-
|
65 |
|
66 |
prediction = 'Spam' if np.argmax(output.logits.cpu().numpy()).flatten().item() == 1 else 'Normal'
|
67 |
pred = 'Predicted Class: '+ prediction
|
@@ -70,7 +70,7 @@ if st.button('Load Model'):
|
|
70 |
|
71 |
#st.write('Input', namestr(new_sentence, globals()),': \n', new_sentence)
|
72 |
with col2:
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
st.success()
|
|
|
61 |
|
62 |
# Forward pass, calculate logit predictions
|
63 |
with torch.no_grad():
|
64 |
+
output = model(test_ids.to(device), token_type_ids = None, attention_mask = test_attention_mask.to(device))
|
65 |
|
66 |
prediction = 'Spam' if np.argmax(output.logits.cpu().numpy()).flatten().item() == 1 else 'Normal'
|
67 |
pred = 'Predicted Class: '+ prediction
|
|
|
70 |
|
71 |
#st.write('Input', namestr(new_sentence, globals()),': \n', new_sentence)
|
72 |
with col2:
|
73 |
+
text = st.text_input("Enter the text you'd like to analyze for spam.")
|
74 |
+
if text or st.button('Analyze'):
|
75 |
+
predict(text)
|
76 |
st.success()
|