# DistilBERT Quantized Model for SMS Spam Detection

This repository contains a production-ready **quantized DistilBERT** model fine-tuned for **SMS spam classification**, achieving **99.94% accuracy** while optimizing inference speed with ONNX Runtime.

---

## Model Details

- **Model Architecture:** DistilBERT Base Uncased
- **Task:** Binary SMS Spam Classification (`ham=0`, `spam=1`)
- **Dataset:** Custom SMS Spam Collection (5,574 messages)
- **Quantization:** ONNX Runtime Dynamic Quantization
- **Fine-tuning Framework:** Hugging Face Transformers + Optimum

---

## Quick Start

### Installation

```bash
pip install -r requirements.txt
```
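
The exact contents of `requirements.txt` are not reproduced here; a plausible minimal set for the inference and API workflow described below (pin versions for production use) would be:

```text
transformers
optimum[onnxruntime]
onnxruntime
pandas
fastapi
uvicorn
```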

### Basic Usage

```python
from transformers import pipeline

classifier = pipeline(
    "text-classification",
    model="./spam_model_quantized",
    tokenizer="./spam_model",
)

sample = "WINNER!! Claim your $1000 prize now!"
result = classifier(sample)
print(f"Prediction: {result[0]['label']} (confidence: {result[0]['score']:.2%})")
```

---

## Performance Metrics

| Metric             | Value  |
|--------------------|--------|
| Accuracy           | 99.94% |
| F1 Score           | 0.9977 |
| Precision          | 100%   |
| Recall             | 99.55% |
| Inference Latency* | 2.7 ms |

> \* Tested on AWS `t3.xlarge` (4 vCPUs)

---

## Advanced Usage

### Load Quantized Model Directly

```python
from optimum.onnxruntime import ORTModelForSequenceClassification

model = ORTModelForSequenceClassification.from_pretrained(
    "./spam_model_quantized",
    provider="CPUExecutionProvider",
)
```

### Batch Processing

```python
import pandas as pd

df = pd.read_csv("messages.csv")
predictions = classifier(list(df["text"]), batch_size=32)
```
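
The batch output is a list of `{"label", "score"}` dicts in the same order as the input rows, so it can be joined back onto the DataFrame. A minimal sketch with stand-in predictions (the real values come from the classifier above):

```python
import pandas as pd

# Stand-in for the classifier's batch output: one dict per input row.
predictions = [
    {"label": "spam", "score": 0.999},
    {"label": "ham", "score": 0.997},
]
df = pd.DataFrame({"text": ["WINNER!! Claim your $1000 prize now!",
                            "Are we still on for lunch?"]})

# Attach predictions column-wise; order matches the input rows.
df["label"] = [p["label"] for p in predictions]
df["score"] = [p["score"] for p in predictions]

spam_only = df[df["label"] == "spam"]
```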

---

## Training Details

### Hyperparameters

| Parameter     | Value                  |
|---------------|------------------------|
| Epochs        | 5 (early stopped at 3) |
| Batch Size    | 12 (train), 16 (eval)  |
| Learning Rate | 3e-5                   |
| Warmup Steps  | 10% of training steps  |
| Weight Decay  | 0.01                   |
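
Warmup here is a fraction of total optimizer steps rather than a fixed count. A quick sanity check of the schedule, assuming a 90/10 train/validation split of the 5,574 messages (the actual split is not documented in this README):

```python
import math

total_messages = 5_574
train_size = int(total_messages * 0.9)  # assumed 90/10 split
batch_size = 12                         # training batch size from the table
epochs = 5

steps_per_epoch = math.ceil(train_size / batch_size)
total_steps = steps_per_epoch * epochs
warmup_steps = int(0.1 * total_steps)   # 10% of training steps
print(steps_per_epoch, total_steps, warmup_steps)  # -> 418 2090 209
```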

### Quantization Benefits

| Metric      | Original | Quantized |
|-------------|----------|-----------|
| Model Size  | 255 MB   | 68 MB     |
| CPU Latency | 9.2 ms   | 2.7 ms    |
| Throughput  | 110/sec  | 380/sec   |

---

## Repository Structure

```
.
├── spam_model/               # Original PyTorch model
│   ├── config.json
│   ├── model.safetensors
│   └── tokenizer.json
├── spam_model_quantized/     # Production-ready quantized model
│   ├── model.onnx
│   ├── quantized_model.onnx
│   └── tokenizer_config.json
├── examples/                 # Ready-to-use scripts
│   ├── predict.py            # CLI interface
│   └── api_server.py         # FastAPI service
├── requirements.txt          # Dependencies
└── README.md                 # This document
```

---

## Deployment Options

### 1. Local REST API

```bash
uvicorn examples.api_server:app --port 8000
```

### 2. Docker Container

```dockerfile
FROM python:3.9-slim
COPY . /app
WORKDIR /app
RUN pip install -r requirements.txt
CMD ["uvicorn", "examples.api_server:app", "--host", "0.0.0.0"]
```
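
Assuming the image is tagged `spam-detector` (name illustrative), build and run with the container port published (uvicorn defaults to port 8000):

```bash
docker build -t spam-detector .
docker run -p 8000:8000 spam-detector
```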

---

## Limitations

- Optimized for **English** SMS messages
- May require **retraining** for regional languages or localized spam patterns
- Quantized model requires **x86 CPUs with AVX2** support

---

## Contributions

Pull requests and suggestions are welcome! Please open an issue for feature requests or bug reports.