Anuj02003 committed on
Commit a980cfd · verified · 1 Parent(s): d838e51

Upload 3 files

Files changed (3)
  1. app.py +36 -0
  2. fine_tune.py +84 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,36 @@
+ import streamlit as st
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
+ import torch
+
+ # Set page configuration as the very first Streamlit command
+ st.set_page_config(page_title="Spam Detection", page_icon="📧")
+
+ # Load fine-tuned model and tokenizer
+ model = DistilBertForSequenceClassification.from_pretrained("./fine_tuned_model")
+ tokenizer = DistilBertTokenizerFast.from_pretrained("./fine_tuned_model")
+
+ # Function to predict whether a message is spam or not
+ def predict_spam(text):
+     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     logits = outputs.logits
+     prediction = torch.argmax(logits, dim=-1).item()
+     return "Spam" if prediction == 1 else "Not Spam"
+
+ def main():
+     st.title("Spam Detection")
+     st.write("This is a Spam Detection App using a fine-tuned DistilBERT model.")
+
+     # Input text box for the user
+     message = st.text_area("Enter message to classify as spam or not:")
+
+     if st.button("Predict"):
+         if message:
+             prediction = predict_spam(message)
+             st.write(f"The message is: {prediction}")
+         else:
+             st.write("Please enter a message to classify.")
+
+ if __name__ == "__main__":
+     main()
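
The app loads its weights from ./fine_tuned_model, so fine_tune.py (next file) has to be run first to produce that directory. Once it exists, the standard way to launch the interface locally is:

    streamlit run app.py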
fine_tune.py ADDED
@@ -0,0 +1,84 @@
+ from datasets import load_dataset
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, Trainer, TrainingArguments
+ import torch
+ from sklearn.metrics import accuracy_score
+
+ # Load the dataset
+ dataset = load_dataset("sms_spam")
+
+ # Print the dataset structure and inspect the columns
+ print(dataset)
+ print(dataset['train'][0])  # Print the first row of the 'train' split
+
+ # Initialize the tokenizer
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
+
+ # Initialize the model
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
+
+ # Tokenize the dataset using the correct column ("sms" holds the message text)
+ def tokenize_function(examples):
+     return tokenizer(examples["sms"], padding="max_length", truncation=True)
+
+ # Apply the tokenization to the dataset
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
+
+ # Use the 'train' split for training
+ train_dataset = tokenized_datasets["train"]
+
+ # If there is no 'test' split, fall back to a 'validation' split
+ eval_dataset = tokenized_datasets.get("test", tokenized_datasets.get("validation"))
+
+ # If neither 'test' nor 'validation' exists, split manually: shuffle once,
+ # then slice, so the train and eval sets do not overlap
+ if eval_dataset is None:
+     shuffled = train_dataset.shuffle(seed=42)
+     n_eval = len(shuffled) // 10
+     eval_dataset = shuffled.select(range(n_eval))  # 10% as eval dataset
+     train_dataset = shuffled.select(range(n_eval, len(shuffled)))  # remaining 90% as train dataset
+
+ # Set up training arguments
+ training_args = TrainingArguments(
+     output_dir="./results",
+     evaluation_strategy="steps",  # Evaluate every 'eval_steps'
+     save_strategy="steps",        # Save every 'save_steps'
+     eval_steps=500,               # Evaluate every 500 steps
+     save_steps=500,               # Save every 500 steps
+     learning_rate=2e-5,
+     per_device_train_batch_size=16,
+     per_device_eval_batch_size=64,
+     num_train_epochs=3,
+     weight_decay=0.01,
+     logging_dir="./logs",
+     logging_steps=10,
+     load_best_model_at_end=True,
+     metric_for_best_model="accuracy",
+ )
+
+ # Define compute_metrics function (optional, if you want to track metrics)
+ def compute_metrics(p):
+     predictions, labels = p
+     preds = predictions.argmax(axis=1)
+     return {"accuracy": accuracy_score(labels, preds)}
+
+ # Initialize the Trainer
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     compute_metrics=compute_metrics,  # Optional: to compute accuracy
+ )
+
+ # Train the model
+ trainer.train()
+
+ # Save the model after training
+ model.save_pretrained("./fine_tuned_model")
+ tokenizer.save_pretrained("./fine_tuned_model")
+
+ # Optionally, push the model to the Hugging Face Hub
+ # model.push_to_hub("Anuj02003/Spam-classification-using-LLM")
+ # tokenizer.push_to_hub("Anuj02003/Spam-classification-using-LLM")
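
The push-to-Hub lines are left commented out. A minimal way to enable them, assuming you have a Hugging Face account with a write token, is to authenticate once via the CLI and rerun the script with the two calls uncommented; push_to_hub creates the target repository on the first push if it does not already exist:

    huggingface-cli login
    python fine_tune.py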
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch==2.5.1
+ torchvision
+ torchaudio
+ transformers==4.33.0
+ datasets==2.14.0
+ streamlit==1.24.0
+ huggingface_hub
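
Two dependencies the scripts rely on are missing from this list: fine_tune.py imports sklearn (packaged as scikit-learn), and the Trainer in the pinned transformers release also expects accelerate at runtime. Until the file is updated, a one-line install that covers everything is:

    pip install -r requirements.txt scikit-learn accelerate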