# BERT Fine-Tuned Model for Churn Prediction

This repository hosts a fine-tuned version of the **BERT** model optimized for **churn prediction** on a custom dataset. The model analyzes customer text and classifies it as churn or not churn, reaching 82.5% accuracy on the held-out test set (see Evaluation Results below).

## Model Details

- **Model Architecture**: BERT (Bidirectional Encoder Representations from Transformers)
- **Task**: Churn Prediction
- **Dataset**: Custom dataset (processed and structured for binary classification)
- **Quantization**: FP16
- **Fine-tuning Framework**: Hugging Face Transformers

## Usage

### Installation

```bash
pip install transformers torch pandas scikit-learn
```

### Loading the Model

```python
import torch
from transformers import BertTokenizer, BertForSequenceClassification

device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "AventIQ-AI/bert-churn-prediction"
model = BertForSequenceClassification.from_pretrained(model_name).to(device)
tokenizer = BertTokenizer.from_pretrained(model_name)
```

### Churn Prediction Inference

```python
def predict_churn(text):
    # Tokenize and move tensors to the same device as the model
    inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Class 1 is treated as "Churn", class 0 as "Not Churn"
    predicted_label = torch.argmax(outputs.logits, dim=1).item()
    return "Churn" if predicted_label == 1 else "Not Churn"

# Example usage
customer_review = "I am unhappy with the service and want to cancel."
print(predict_churn(customer_review))
```
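
For scoring many reviews at once, the snippet below is a minimal batched variant of the same inference loop; the `reviews` list, the helper name, and the batch size are illustrative and not part of the released code.

```python
def predict_churn_batch(texts, batch_size=16):
    """Run the classifier over a list of texts and return one label per text."""
    labels = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
        with torch.no_grad():
            logits = model(**inputs).logits
        preds = torch.argmax(logits, dim=1).tolist()
        labels.extend("Churn" if p == 1 else "Not Churn" for p in preds)
    return labels

# Example usage with illustrative reviews
reviews = [
    "I am unhappy with the service and want to cancel.",
    "Great support team, I plan to renew my subscription.",
]
print(predict_churn_batch(reviews))
```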

## Evaluation Results

After fine-tuning, the model was evaluated on the held-out test set. The following results were obtained:

| Metric        | Score  | Meaning                                            |
|---------------|--------|----------------------------------------------------|
| **Accuracy**  | 82.5%  | Overall fraction of correct predictions            |
| **Precision** | 100.3% | Fraction of predicted churn cases that are correct |
| **Recall**    | 78.7%  | Fraction of actual churn cases that are detected   |
| **F1-Score**  | 80.5%  | Harmonic mean of precision and recall              |
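
The evaluation script itself is not part of this repository; the sketch below shows how such metrics can be computed with scikit-learn, assuming hypothetical label lists `y_true` (ground truth) and `y_pred` (model predictions) for the test set.

```python
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Hypothetical test-set labels: 1 = Churn, 0 = Not Churn
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 1, 0, 0, 0, 1, 1]

print(f"Accuracy : {accuracy_score(y_true, y_pred):.3f}")
print(f"Precision: {precision_score(y_true, y_pred):.3f}")
print(f"Recall   : {recall_score(y_true, y_pred):.3f}")
print(f"F1-Score : {f1_score(y_true, y_pred):.3f}")
```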

## Fine-Tuning Details

- **Model**: Fine-tuned BERT for churn prediction on a custom dataset.
- **Training**: 3 epochs with a batch size of 8, using the AdamW optimizer and a learning rate of 2e-5.

### Dataset

The dataset consists of customer interactions, reviews, and metadata used to determine churn likelihood. Textual fields such as **title, features, description, and average rating** were concatenated to form each input text sample.
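
The preprocessing code is not included here; the following is a rough sketch of how such fields could be merged with pandas. The column names, separator token, and sample records are assumptions based on the description above.

```python
import pandas as pd

# Hypothetical raw records; the actual schema and column names are assumptions
df = pd.DataFrame(
    {
        "title": ["Premium plan", "Basic plan"],
        "features": ["unlimited storage, priority support", "5 GB storage"],
        "description": ["Service was slow and support never replied.", "Works fine for my needs."],
        "average_rating": [1.5, 4.0],
        "churned": [1, 0],  # binary label: 1 = Churn, 0 = Not Churn
    }
)

# Merge the textual fields into a single input string per customer
df["text"] = (
    df["title"]
    + " [SEP] " + df["features"]
    + " [SEP] " + df["description"]
    + " [SEP] rating: " + df["average_rating"].astype(str)
)

print(df[["text", "churned"]].head())
```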

### Training

- **Number of epochs**: 3
- **Batch size**: 8
- **Evaluation strategy**: per epoch
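
The exact training script is not included in this repository; the following is a minimal sketch of a Hugging Face `Trainer` setup matching the hyperparameters above. The toy dataset, output directory, and base checkpoint name are placeholders, and the evaluation-scheduling argument is named `evaluation_strategy` in older transformers releases.

```python
from datasets import Dataset
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Toy illustrative dataset; the real training data is the custom churn dataset described above
raw = Dataset.from_dict({
    "text": ["I want to cancel my plan.", "Happy with the service, will renew."],
    "label": [1, 0],
})
tokenized = raw.map(
    lambda batch: tokenizer(batch["text"], truncation=True, padding="max_length", max_length=128),
    batched=True,
)

training_args = TrainingArguments(
    output_dir="./bert-churn",        # placeholder output path
    num_train_epochs=3,               # 3 epochs
    per_device_train_batch_size=8,    # batch size of 8
    learning_rate=2e-5,               # Trainer uses AdamW by default
    eval_strategy="epoch",            # evaluate once per epoch
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    eval_dataset=tokenized,           # toy example reuses the same split
)

trainer.train()
```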

### Quantization

Post-training quantization was applied using PyTorch's built-in tooling to reduce the model size and improve inference efficiency; per the model details above, the released weights are stored in FP16.
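
As a rough illustration of FP16 post-training quantization (the exact procedure used for this checkpoint is not documented here), converting a loaded model to half precision in PyTorch can look like this; the output directory is a placeholder.

```python
import torch
from transformers import BertForSequenceClassification

# Load the fine-tuned model in full precision
model = BertForSequenceClassification.from_pretrained("AventIQ-AI/bert-churn-prediction")

# Convert weights to FP16 and report the size reduction
fp32_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
model = model.half()
fp16_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"FP32: {fp32_bytes / 1e6:.1f} MB -> FP16: {fp16_bytes / 1e6:.1f} MB")

# Save the half-precision weights (placeholder output path)
model.save_pretrained("./bert-churn-fp16")
```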

## Repository Structure

```bash
.
├── model/               # Contains the quantized model files
├── tokenizer_config/    # Tokenizer configuration and vocabulary files
├── model.safetensors    # Quantized model weights
└── README.md            # Model documentation
```

## Limitations

- May struggle with ambiguous or very short text inputs.
- Quantization may slightly impact model accuracy.
- Performance may vary across different industries and customer segments.

## Contributing

Contributions are welcome! Feel free to open an issue or submit a pull request with suggestions or improvements.