Anuj02003 commited on
Commit
261e9be
·
verified ·
1 Parent(s): 9a79179

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -2,15 +2,15 @@ import streamlit as st
2
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
3
  import torch
4
 
5
- # Set page configuration as the very first Streamlit command
6
  st.set_page_config(page_title="Spam Detection", page_icon="📧")
7
 
8
- # Load fine-tuned model and tokenizer
9
- model = DistilBertForSequenceClassification.from_pretrained("Anuj02003/Spam-classification-using-LLM")
10
- tokenizer = DistilBertTokenizerFast.from_pretrained("Anuj02003/Spam-classification-using-LLM")
 
11
 
12
-
13
- # Function to predict whether a message is spam or not
14
  def predict_spam(text):
15
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
16
  with torch.no_grad():
@@ -19,17 +19,18 @@ def predict_spam(text):
19
  prediction = torch.argmax(logits, dim=-1).item()
20
  return "Spam" if prediction == 1 else "Not Spam"
21
 
 
22
  def main():
23
  st.title("Spam Detection")
24
  st.write("This is a Spam Detection App using a fine-tuned DistilBERT model.")
25
 
26
- # Input text box for the user
27
  message = st.text_area("Enter message to classify as spam or not:")
28
 
29
  if st.button("Predict"):
30
  if message:
31
  prediction = predict_spam(message)
32
- st.write(f"The message is: {prediction}")
33
  else:
34
  st.write("Please enter a message to classify.")
35
 
 
2
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
3
  import torch
4
 
5
+ # Set page configuration
6
  st.set_page_config(page_title="Spam Detection", page_icon="📧")
7
 
8
+ # Load the local fine-tuned model and tokenizer
9
+ model_path = "./fine_tuned_model"
10
+ model = DistilBertForSequenceClassification.from_pretrained(model_path)
11
+ tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
12
 
13
+ # Function to predict spam
 
14
  def predict_spam(text):
15
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
16
  with torch.no_grad():
 
19
  prediction = torch.argmax(logits, dim=-1).item()
20
  return "Spam" if prediction == 1 else "Not Spam"
21
 
22
+ # Streamlit app
23
  def main():
24
  st.title("Spam Detection")
25
  st.write("This is a Spam Detection App using a fine-tuned DistilBERT model.")
26
 
27
+ # Input text box
28
  message = st.text_area("Enter message to classify as spam or not:")
29
 
30
  if st.button("Predict"):
31
  if message:
32
  prediction = predict_spam(message)
33
+ st.write(f"The message is: **{prediction}**")
34
  else:
35
  st.write("Please enter a message to classify.")
36