Upload 8 files
- README.md +154 -0
- config.json +24 -0
- model.safetensors +3 -0
- roberta-spam-detection.ipynb +1 -0
- special_tokens_map.json +7 -0
- tokenizer.json +0 -0
- tokenizer_config.json +56 -0
- vocab.txt +0 -0
README.md
ADDED
@@ -0,0 +1,154 @@
# 📩 DistilBERT Quantized Model for SMS Spam Detection

This repository contains a production-ready **quantized DistilBERT** model fine-tuned for **SMS spam classification**, achieving **99.94% accuracy** while optimizing inference speed using ONNX Runtime.

---

## 📌 Model Details

- **Model Architecture:** DistilBERT Base Uncased
- **Task:** Binary SMS Spam Classification (`ham=0`, `spam=1`)
- **Dataset:** Custom SMS Spam Collection (5,574 messages)
- **Quantization:** ONNX Runtime Dynamic Quantization
- **Fine-tuning Framework:** Hugging Face Transformers + Optimum

---

## 🚀 Quick Start

### 🧰 Installation

```bash
pip install -r requirements.txt
```
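`requirements.txt` itself is not part of this upload; at a minimum, the examples below assume something like the following (the package list is an assumption, not pinned by the repo):

```bash
pip install transformers "optimum[onnxruntime]" pandas fastapi uvicorn
```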
### ✅ Basic Usage

```python
from transformers import pipeline

classifier = pipeline(
    "text-classification",
    model="./spam_model_quantized",
    tokenizer="./spam_model"
)

sample = "WINNER!! Claim your $1000 prize now!"
result = classifier(sample)
print(f"Prediction: {result[0]['label']} (confidence: {result[0]['score']:.2%})")
```

Note: the `config.json` shipped here defines no `id2label` mapping, so the pipeline reports the generic names `LABEL_0` (ham) and `LABEL_1` (spam).
---

## 📈 Performance Metrics

| Metric     | Value  |
|------------|--------|
| Accuracy   | 99.94% |
| F1 Score   | 0.9977 |
| Precision  | 100%   |
| Recall     | 99.55% |
| Inference* | 2.7 ms |

> \* Tested on AWS `t3.xlarge` (4 vCPUs)

---
## 🛠 Advanced Usage

### 🔍 Load Quantized Model Directly

```python
from optimum.onnxruntime import ORTModelForSequenceClassification

model = ORTModelForSequenceClassification.from_pretrained(
    "./spam_model_quantized",
    provider="CPUExecutionProvider"
)
```
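To run inference with the ONNX model, wrap it in a standard `pipeline`; a short sketch, assuming the tokenizer lives in `./spam_model` as in the Quick Start example:

```python
from transformers import AutoTokenizer, pipeline

tokenizer = AutoTokenizer.from_pretrained("./spam_model")
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
```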
### 📊 Batch Processing

Reusing the `classifier` pipeline from above:

```python
import pandas as pd

df = pd.read_csv("messages.csv")
predictions = classifier(list(df["text"]), batch_size=32)
```

---
## 🎯 Training Details

### 🔧 Hyperparameters

| Parameter     | Value                  |
|---------------|------------------------|
| Epochs        | 5 (early stopped at 3) |
| Batch Size    | 12 (train), 16 (eval)  |
| Learning Rate | 3e-5                   |
| Warmup Steps  | 10% of training steps  |
| Weight Decay  | 0.01                   |
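The training script itself is not included in this upload; a sketch of a `TrainingArguments`/`Trainer` setup matching the table above (output path, metric choice, and patience are assumptions):

```python
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

training_args = TrainingArguments(
    output_dir="./spam_model",
    eval_strategy="epoch",            # evaluate once per epoch
    save_strategy="epoch",
    per_device_train_batch_size=12,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    learning_rate=3e-5,
    warmup_ratio=0.1,                 # 10% of training steps
    weight_decay=0.01,
    load_best_model_at_end=True,      # required for early stopping
    metric_for_best_model="f1",
)

trainer = Trainer(
    model=model,                      # DistilBERT classifier, datasets, and
    args=training_args,               # compute_metrics as in the usual
    train_dataset=train_data,         # fine-tuning setup (see the bundled
    eval_dataset=test_data,           # notebook for one variant)
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=1)],
)
```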
### ⚡ Quantization Benefits

| Metric      | Original | Quantized |
|-------------|----------|-----------|
| Model Size  | 255 MB   | 68 MB     |
| CPU Latency | 9.2 ms   | 2.7 ms    |
| Throughput  | 110/sec  | 380/sec   |
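A minimal sketch of the export-and-quantize step with Optimum (dynamic, weight-only int8), assuming the fine-tuned PyTorch model is in `./spam_model`:

```python
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Export the fine-tuned PyTorch model to ONNX
onnx_model = ORTModelForSequenceClassification.from_pretrained(
    "./spam_model", export=True
)
onnx_model.save_pretrained("./spam_model_quantized")

# Apply dynamic quantization tuned for AVX2 CPUs
quantizer = ORTQuantizer.from_pretrained("./spam_model_quantized")
qconfig = AutoQuantizationConfig.avx2(is_static=False, per_channel=False)
quantizer.quantize(save_dir="./spam_model_quantized", quantization_config=qconfig)
```

The `avx2` configuration is also why the resulting model expects AVX2-capable CPUs (see Limitations below).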
---

## 📁 Repository Structure

```
.
├── spam_model/               # Original PyTorch model
│   ├── config.json
│   ├── model.safetensors
│   └── tokenizer.json
├── spam_model_quantized/     # Production-ready quantized model
│   ├── model.onnx
│   ├── quantized_model.onnx
│   └── tokenizer_config.json
├── examples/                 # Ready-to-use scripts
│   ├── predict.py            # CLI interface
│   └── api_server.py         # FastAPI service
├── requirements.txt          # Dependencies
└── README.md                 # This document
```

---
## 🚀 Deployment Options

### 1. Local REST API

```bash
uvicorn examples.api_server:app --port 8000
```
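`examples/api_server.py` is not included in this commit; a minimal sketch of what such a FastAPI service could look like (endpoint name and request schema are assumptions):

```python
# examples/api_server.py (hypothetical sketch)
from fastapi import FastAPI
from pydantic import BaseModel
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer, pipeline

app = FastAPI(title="SMS Spam Detection API")

# Load the quantized model once at startup
model = ORTModelForSequenceClassification.from_pretrained("./spam_model_quantized")
tokenizer = AutoTokenizer.from_pretrained("./spam_model")
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)

class Message(BaseModel):
    text: str

@app.post("/predict")
def predict(message: Message):
    result = classifier(message.text)[0]
    return {"label": result["label"], "score": result["score"]}
```

With that sketch, a request looks like:

```bash
curl -X POST http://localhost:8000/predict \
  -H "Content-Type: application/json" \
  -d '{"text": "WINNER!! Claim your $1000 prize now!"}'
```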
### 2. Docker Container

```dockerfile
FROM python:3.9-slim
WORKDIR /app
COPY . /app
RUN pip install -r requirements.txt
EXPOSE 8000
CMD ["uvicorn", "examples.api_server:app", "--host", "0.0.0.0", "--port", "8000"]
```
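Build and run (the image name is an arbitrary choice):

```bash
docker build -t sms-spam-detector .
docker run -p 8000:8000 sms-spam-detector
```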
---

## ⚠️ Limitations

- Optimized for **English** SMS messages
- May require **retraining** for regional languages or localized spam patterns
- The quantized model requires **x86 CPUs with AVX2** support (see the check below)
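To verify AVX2 support on a Linux host:

```bash
grep -q avx2 /proc/cpuinfo && echo "AVX2 supported" || echo "AVX2 not available"
```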
---

## 🙌 Contributions

Pull requests and suggestions are welcome! Please open an issue for feature requests or bug reports.
config.json
ADDED
@@ -0,0 +1,24 @@
{
  "activation": "gelu",
  "architectures": [
    "DistilBertForSequenceClassification"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "problem_type": "single_label_classification",
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "torch_dtype": "float16",
  "transformers_version": "4.51.1",
  "vocab_size": 30522
}
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3e40e556ad782efce6187ff75c35ad20d7f4bcce111152f0dc86ee3e48c519ff
size 133922428
roberta-spam-detection.ipynb
ADDED
@@ -0,0 +1 @@
The notebook is stored as a single minified JSON line; its code cells are reproduced below in file order, with cell outputs and widget metadata elided. One fix: `clean_label` uses `re`, which was never imported in the original.

```python
import os
os.environ["WANDB_DISABLED"] = "true"

import re  # needed by clean_label below; missing from the original imports
import torch
from datasets import load_dataset, concatenate_datasets, Dataset, Value
from transformers import RobertaTokenizer, RobertaForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split

# Three public spam datasets with differing schemas
dataset1 = load_dataset("ucirvine/sms_spam", split="train")                   # 5,574 rows
dataset2 = load_dataset("AbdulHadi806/mail_spam_ham_dataset", split="train")  # 5,613 rows
dataset3 = load_dataset("Goodmotion/spam-mail", split="train")                # 6,018 rows

def get_text_column(example):
    for col in ["text", "sms", "Message"]:  # Check possible names
        if col in example:
            return col
    return None

def get_label_column(example):
    for col in ["label", "Category"]:  # Check possible names
        if col in example:
            return col
    return None

def clean_label(label):
    if isinstance(label, str):
        label = re.sub(r"[^a-zA-Z]", "", label)     # Remove numbers & special chars
        return 1 if label.lower() == "spam" else 0  # Convert to numeric labels
    return int(label)

def preprocess(example):
    # Standardize text and label column names across the three datasets
    text_col = "text" if "text" in example else "sms" if "sms" in example else "Message"
    label_col = "label" if "label" in example else "Category"

    label_mapping = {"ham": 0, "spam": 1}  # Convert text labels
    label = example[label_col]

    if isinstance(label, str):
        label = label_mapping.get(label.lower(), 0)  # Convert to int
    elif isinstance(label, int):
        label = int(label)  # Ensure it's an integer

    return {"text": example[text_col], "label": label}

dataset1 = dataset1.map(preprocess)
dataset2 = dataset2.map(preprocess)
dataset3 = dataset3.map(preprocess)

dataset1 = dataset1.cast_column("label", Value("int64"))
dataset2 = dataset2.cast_column("label", Value("int64"))
dataset3 = dataset3.cast_column("label", Value("int64"))

dataset1 = dataset1.remove_columns([col for col in dataset1.column_names if col not in ["text", "label"]])
dataset2 = dataset2.remove_columns([col for col in dataset2.column_names if col not in ["text", "label"]])
dataset3 = dataset3.remove_columns([col for col in dataset3.column_names if col not in ["text", "label"]])

merged_dataset = concatenate_datasets([dataset1, dataset2, dataset3])
df = merged_dataset.to_pandas()

train_texts, test_texts, train_labels, test_labels = train_test_split(
    df["text"], df["label"], test_size=0.2, random_state=42
)

train_data = Dataset.from_dict({"text": train_texts.tolist(), "label": train_labels.tolist()})
test_data = Dataset.from_dict({"text": test_texts.tolist(), "label": test_labels.tolist()})

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

def tokenize_function(example):
    return tokenizer(example["text"], padding="max_length", truncation=True, max_length=512)

train_data = train_data.map(tokenize_function, batched=True)
test_data = test_data.map(tokenize_function, batched=True)

train_data.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
test_data.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2).to(device)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = torch.argmax(torch.tensor(logits), dim=-1)
    acc = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="binary")
    return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}

training_args = TrainingArguments(
    output_dir="./roberta_spam",
    evaluation_strategy="epoch",  # deprecated alias of eval_strategy (FutureWarning at run time)
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
    tokenizer=tokenizer,  # deprecated in favor of processing_class (FutureWarning at run time)
    compute_metrics=compute_metrics,
)

trainer.train()
# Recorded results (2,583 steps, ~41.5 min on a T4):
#   epoch 1: val_loss 0.0468, acc 0.9913, P 0.9716, R 0.9968, F1 0.9840
#   epoch 2: val_loss 0.0111, acc 0.9977, P 0.9968, R 0.9946, F1 0.9957
#   epoch 3: val_loss 0.0164, acc 0.9983, P 1.0000, R 0.9935, F1 0.9968

model.save_pretrained('fine-tuned-model')
tokenizer.save_pretrained('fine-tuned-model')

# The remaining cells appear in this order in the file, although their
# execution counts (37-41) show the float16 conversion below was run first.
model_name = '/kaggle/working/quantized-model'
model = RobertaForSequenceClassification.from_pretrained(model_name).to(device)
tokenizer = RobertaTokenizer.from_pretrained(model_name)

def predict(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Move input tensors to the same device as the model
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = torch.argmax(logits).item()
    return "Spam" if predicted_class == 1 else "Ham"

input_text = "Congratulations! You have won a free iPhone. Click here to claim your prize."
print(f"Prediction: {predict(input_text)}")  # Recorded output: Spam

quantized_model = model.to(dtype=torch.float16, device=device)  # fp16 conversion, not int8 quantization
quantized_model.save_pretrained('quantized-model')
tokenizer.save_pretrained('quantized-model')
```
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
tokenizer.json
ADDED
The diff for this file is too large to render.
tokenizer_config.json
ADDED
@@ -0,0 +1,56 @@
{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": false,
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "DistilBertTokenizer",
  "unk_token": "[UNK]"
}
vocab.txt
ADDED
The diff for this file is too large to render.