Anuj02003 commited on
Commit
83a5a13
·
verified ·
1 Parent(s): 55635a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -36
app.py CHANGED
@@ -1,36 +1,37 @@
1
- import streamlit as st
2
- from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
3
- import torch
4
-
5
- # Set page configuration as the very first Streamlit command
6
- st.set_page_config(page_title="Spam Detection", page_icon="📧")
7
-
8
- # Load fine-tuned model and tokenizer
9
- model = DistilBertForSequenceClassification.from_pretrained("./fine_tuned_model")
10
- tokenizer = DistilBertTokenizerFast.from_pretrained("./fine_tuned_model")
11
-
12
- # Function to predict whether a message is spam or not
13
- def predict_spam(text):
14
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
15
- with torch.no_grad():
16
- outputs = model(**inputs)
17
- logits = outputs.logits
18
- prediction = torch.argmax(logits, dim=-1).item()
19
- return "Spam" if prediction == 1 else "Not Spam"
20
-
21
- def main():
22
- st.title("Spam Detection")
23
- st.write("This is a Spam Detection App using a fine-tuned DistilBERT model.")
24
-
25
- # Input text box for the user
26
- message = st.text_area("Enter message to classify as spam or not:")
27
-
28
- if st.button("Predict"):
29
- if message:
30
- prediction = predict_spam(message)
31
- st.write(f"The message is: {prediction}")
32
- else:
33
- st.write("Please enter a message to classify.")
34
-
35
- if __name__ == "__main__":
36
- main()
 
 
1
+ import streamlit as st
2
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
3
+ import torch
4
+
5
+ # Set page configuration as the very first Streamlit command
6
+ st.set_page_config(page_title="Spam Detection", page_icon="📧")
7
+
8
+ # Load fine-tuned model and tokenizer
9
+ model = DistilBertForSequenceClassification.from_pretrained("Anuj02003/Spam-classification-using-LLM")
10
+ tokenizer = DistilBertTokenizerFast.from_pretrained("Anuj02003/Spam-classification-using-LLM")
11
+
12
+
13
+ # Function to predict whether a message is spam or not
14
+ def predict_spam(text):
15
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
16
+ with torch.no_grad():
17
+ outputs = model(**inputs)
18
+ logits = outputs.logits
19
+ prediction = torch.argmax(logits, dim=-1).item()
20
+ return "Spam" if prediction == 1 else "Not Spam"
21
+
22
+ def main():
23
+ st.title("Spam Detection")
24
+ st.write("This is a Spam Detection App using a fine-tuned DistilBERT model.")
25
+
26
+ # Input text box for the user
27
+ message = st.text_area("Enter message to classify as spam or not:")
28
+
29
+ if st.button("Predict"):
30
+ if message:
31
+ prediction = predict_spam(message)
32
+ st.write(f"The message is: {prediction}")
33
+ else:
34
+ st.write("Please enter a message to classify.")
35
+
36
+ if __name__ == "__main__":
37
+ main()