BERT Fine-Tuned Model for Churn Prediction
This repository hosts a fine-tuned version of the BERT model optimized for churn prediction using the provided dataset. The model is designed to analyze textual data and predict customer churn with high accuracy.
Model Details
- Model Architecture: BERT (Bidirectional Encoder Representations from Transformers)
- Task: Churn Prediction
- Dataset: Custom Dataset (processed and structured for binary classification)
- Quantization: FP16
- Fine-tuning Framework: Hugging Face Transformers
π Usage
Installation
pip install transformers torch pandas scikit-learn
Loading the Model
from transformers import BertTokenizer, BertForSequenceClassification
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "AventIQ-AI/bert-churn-prediction"
model = BertForSequenceClassification.from_pretrained(model_name).to(device)
tokenizer = BertTokenizer.from_pretrained(model_name)
Churn Prediction Inference
def predict_churn(text):
inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
predicted_label = torch.argmax(outputs.logits, dim=1).item()
return "Churn" if predicted_label == 1 else "Not Churn"
# Example usage
customer_review = "I am unhappy with the service and want to cancel."
print(predict_churn(customer_review))
π Evaluation Results
After fine-tuning the BERT model for churn prediction, we evaluated the model's performance on the test set. The following results were obtained:
Metric | Score | Meaning |
---|---|---|
Accuracy | 82.5% | Measures overall prediction correctness |
Precision | 100.3% | Fraction of relevant churn predictions |
Recall | 78.7% | Ability to detect all churn cases |
F1-Score | 80.5% | Harmonic mean of precision and recall |
Fine-Tuning Details
Model: Fine-tuned BERT for churn prediction using a custom dataset. Training: Run for 3 epochs with a batch size of 8, using the AdamW optimizer and a learning rate of 2e-5.
Dataset
The dataset consists of customer interactions, reviews, and metadata used to determine churn likelihood. Textual features like title, features, description, and average rating were merged to create input text samples.
Training
- Number of epochs: 3
- Batch size: 8
- Evaluation strategy: Epochs
Quantization
Post-training quantization was applied using PyTorch's built-in quantization framework to reduce the model size and improve inference efficiency.
π Repository Structure
.
βββ model/ # Contains the quantized model files
βββ tokenizer_config/ # Tokenizer configuration and vocabulary files
βββ model.safetensors/ # Quantized Model
βββ README.md # Model documentation
β οΈ Limitations
- May struggle with ambiguous or very short text inputs.
- Quantization may slightly impact model accuracy.
- Performance may vary across different industries and customer segments.
π€ Contributing
Contributions are welcome! Feel free to open an issue or submit a pull request if you have suggestions or improvements.
- Downloads last month
- 3