# DistilBERT Quantized Model for SMS Spam Detection
This repository contains a production-ready **quantized DistilBERT** model fine-tuned for **SMS spam classification**, achieving **99.94% accuracy** while optimizing inference speed using ONNX Runtime.
---
## Model Details
- **Model Architecture:** DistilBERT Base Uncased
- **Task:** Binary SMS Spam Classification (`ham=0`, `spam=1`)
- **Dataset:** Custom SMS Spam Collection (5,574 messages)
- **Quantization:** ONNX Runtime Dynamic Quantization
- **Fine-tuning Framework:** Hugging Face Transformers + Optimum
---
## Quick Start
### Installation
```bash
pip install -r requirements.txt
```
### Basic Usage
```python
from transformers import pipeline
classifier = pipeline(
    "text-classification",
    model="./spam_model_quantized",
    tokenizer="./spam_model"
)
sample = "WINNER!! Claim your $1000 prize now!"
result = classifier(sample)
print(f"Prediction: {result[0]['label']} (confidence: {result[0]['score']:.2%})")
```
---
## Performance Metrics
| Metric | Value |
|-------------|-----------|
| Accuracy | 99.94% |
| F1 Score | 0.9977 |
| Precision | 100% |
| Recall | 99.55% |
| Inference* | 2.7ms |
> \* Tested on AWS `t3.xlarge` (4 vCPUs)
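The latency figure can be checked with a small timing harness. This is a sketch using only the standard library; the lambda below is a stand-in for the real `classifier`:

```python
import time
import statistics

def benchmark(predict, sample, warmup=10, runs=200):
    """Time repeated single-message predictions and report latency in ms."""
    for _ in range(warmup):  # warm caches and lazy initialization
        predict(sample)
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        predict(sample)
        timings.append((time.perf_counter() - start) * 1000)
    return {
        "p50_ms": statistics.median(timings),
        "p95_ms": statistics.quantiles(timings, n=20)[18],
        "mean_ms": statistics.fmean(timings),
    }

# Dummy predictor for illustration; swap in the real classifier
stats = benchmark(lambda text: text.lower(), "WINNER!! Claim your $1000 prize now!")
print(stats)
```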
---
## Advanced Usage
### Load Quantized Model Directly
```python
from optimum.onnxruntime import ORTModelForSequenceClassification
model = ORTModelForSequenceClassification.from_pretrained(
    "./spam_model_quantized",
    provider="CPUExecutionProvider"
)
```
### Batch Processing
```python
import pandas as pd
df = pd.read_csv("messages.csv")
predictions = classifier(list(df["text"]), batch_size=32)
```
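The pipeline returns one `{'label', 'score'}` dict per input row; attaching them back to the DataFrame looks roughly like this (the `predictions` list below is a hard-coded stand-in for the pipeline output):

```python
import pandas as pd

# Stand-in for classifier output: one dict per input row
predictions = [
    {"label": "spam", "score": 0.999},
    {"label": "ham", "score": 0.998},
]
df = pd.DataFrame({"text": ["WINNER!! Claim your prize", "See you at lunch?"]})

# Unpack labels and confidences into new columns
df["label"] = [p["label"] for p in predictions]
df["score"] = [p["score"] for p in predictions]
df.to_csv("messages_scored.csv", index=False)
```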
---
## Training Details
### Hyperparameters
| Parameter | Value |
|-----------------|---------------|
| Epochs | 5 (early stopped at 3) |
| Batch Size | 12 (train), 16 (eval) |
| Learning Rate | 3e-5 |
| Warmup Steps    | 10% of training steps |
| Weight Decay | 0.01 |
### Quantization Benefits
| Metric | Original | Quantized |
|---------------|----------|-----------|
| Model Size | 255MB | 68MB |
| CPU Latency | 9.2ms | 2.7ms |
| Throughput | 110/sec | 380/sec |
---
## Repository Structure
```
.
├── spam_model/               # Original PyTorch model
│   ├── config.json
│   ├── model.safetensors
│   └── tokenizer.json
├── spam_model_quantized/     # Production-ready quantized model
│   ├── model.onnx
│   ├── quantized_model.onnx
│   └── tokenizer_config.json
├── examples/                 # Ready-to-use scripts
│   ├── predict.py            # CLI interface
│   └── api_server.py         # FastAPI service
├── requirements.txt          # Dependencies
└── README.md                 # This document
```
---
## Deployment Options
### 1. Local REST API
```bash
uvicorn examples.api_server:app --port 8000
```
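Once the server is up, a request might look like the following. The `/predict` path and JSON schema are assumptions about `examples/api_server.py`, so adjust them to match the actual service:

```shell
# Endpoint path and payload field are illustrative assumptions
curl -X POST http://localhost:8000/predict \
  -H "Content-Type: application/json" \
  -d '{"text": "WINNER!! Claim your $1000 prize now!"}'
```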
### 2. Docker Container
```dockerfile
FROM python:3.9-slim
COPY . /app
WORKDIR /app
RUN pip install -r requirements.txt
CMD ["uvicorn", "examples.api_server:app", "--host", "0.0.0.0"]
```
---
## Limitations
- Optimized for **English** SMS messages
- May require **retraining** for regional languages or localized spam patterns
- Quantized model requires **x86 CPUs with AVX2** support
---
## Contributions
Pull requests and suggestions are welcome! Please open an issue for feature requests or bug reports.