Spaces:
Sleeping
Sleeping
sync with remote
Browse files- app.py +124 -0
- churn_prediction_model.pkl +3 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from datasets import load_dataset
|
3 |
+
import pandas as pd
|
4 |
+
import pickle
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import seaborn as sns
|
7 |
+
from sklearn.model_selection import train_test_split
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
|
10 |
+
# Hugging Face Hub coordinates of the pickled classifier.
repo_id = "louiecerv/churn_prediction_model"
filename = "churn_prediction_model.pkl"


# Streamlit re-executes this script on every widget interaction, so the
# loaders below are wrapped in st.cache_resource to keep the dataset and the
# unpickled model in memory instead of reloading them on each rerun.
@st.cache_resource
def _load_dataset():
    """Return the "train" split of the louiecerv/customer_churn dataset."""
    return load_dataset("louiecerv/customer_churn")["train"]


@st.cache_resource
def _load_model(repo, name):
    """Download file ``name`` from Hub repo ``repo`` and unpickle it.

    Returns:
        A ``(local_path, model)`` tuple: the path of the downloaded file and
        the object stored in it.
    """
    local_path = hf_hub_download(repo_id=repo, filename=name)
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point repo_id at a repository you control or fully trust.
    with open(local_path, "rb") as f:
        return local_path, pickle.load(f)


# Load the dataset
dataset = _load_dataset()

# Download (hf_hub_download caches the file locally) and load the model
model_path, model = _load_model(repo_id, filename)
|
23 |
+
|
24 |
+
def main():
    """Render the Customer Churn Prediction page.

    The page has four sections: a look at the dataset (sample rows, row count,
    churn distribution plot), a bar chart of the model's coefficients, a form
    that feeds user-entered values to the model, and a short closing note
    about Hugging Face.

    Reads the module-level ``dataset`` and ``model`` objects; returns nothing.
    """
    # Streamlit app
    st.title("Customer Churn Prediction App")
    st.write("This app demonstrates how to use Hugging Face Datasets and Models for a "
             "real-world classification task: predicting customer churn.")

    # --- Dataset Exploration ---
    st.header("Dataset Exploration")
    st.write("Let's explore the customer churn dataset:")

    # Convert to pandas once and reuse the frame everywhere below.
    df = dataset.to_pandas()

    # Display dataset information
    st.subheader("Dataset Sample")
    st.write(df.head())

    # Every column except the target, in dataset order. These names label the
    # coefficient chart and fix the column order of the prediction input.
    feature_columns = df.drop(columns='churn').columns

    # Show dataset size
    st.write(f"**Dataset Size:** {len(dataset)} rows")

    # Visualize churn distribution
    st.subheader("Churn Distribution")
    fig, ax = plt.subplots()
    sns.countplot(x='churn', data=df, ax=ax)
    # NOTE(review): the labels assume 'churn' has exactly two categories,
    # plotted as no-churn then churn -- confirm against the dataset.
    ax.set_xticks([0, 1])  # Set tick locations
    ax.set_xticklabels(["No Churn", "Churn"])
    st.pyplot(fig)
    # pyplot keeps a reference to every figure it creates; close it so
    # figures do not pile up across reruns.
    plt.close(fig)

    # --- Feature Importance ---
    st.header("Feature Importance")
    st.write("Understanding which features contribute most to churn prediction:")

    # Get feature importances (coefficients). This requires model.coef_[0]
    # to have exactly one entry per feature column.
    feature_importance = pd.DataFrame({'Feature': feature_columns,
                                       'Importance': model.coef_[0]})
    feature_importance = feature_importance.sort_values(by='Importance', ascending=False)

    # Plot feature importances
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x='Importance', y='Feature', data=feature_importance, ax=ax)
    st.pyplot(fig)
    plt.close(fig)

    # --- Churn Prediction ---
    st.header("Predict Customer Churn")
    st.write("Enter customer data to predict churn:")

    # Input features. There is deliberately no "Churn" input: churn is the
    # value being predicted, not something the user supplies.
    tenure = st.number_input("Tenure (months)", min_value=1, step=1)
    monthly_charges = st.number_input("Monthly Charges", min_value=0.0, step=0.01)
    total_charges = st.number_input("Total Charges", min_value=0.0, step=0.01)
    contract = st.selectbox("Contract", ['Month-to-month', 'One year', 'Two year'])
    internet_service = st.selectbox("Internet Service", ['DSL', 'Fiber optic', 'No'])

    # Preprocess input features: the two categorical choices become 0/1
    # indicator columns, with 'Month-to-month' and 'DSL' as the all-zero case.
    input_features = {
        'tenure': tenure,
        'monthly_charges': monthly_charges,
        'total_charges': total_charges,
        'contract_One year': int(contract == 'One year'),
        'contract_Two year': int(contract == 'Two year'),
        'internet_service_Fiber optic': int(internet_service == 'Fiber optic'),
        'internet_service_No': int(internet_service == 'No'),
    }
    input_df = pd.DataFrame([input_features])

    # Reorder columns to match the dataset's feature columns. This raises
    # KeyError if the dataset has a feature column the form does not supply.
    input_df = input_df[feature_columns]

    # Make prediction
    if st.button("Predict Churn"):
        prediction = model.predict(input_df)[0]
        probability = model.predict_proba(input_df)[0][1]  # Probability of churn
        if prediction == 1:
            st.write("This customer is likely to **churn**.")
        else:
            st.write("This customer is likely to **stay**.")
        st.write(f"Churn Probability: {probability:.2f}")

    # --- Hugging Face Explanation ---
    st.header("Hugging Face for Machine Learning")
    st.write(
        """
        This app showcases the power of Hugging Face for building ML applications.
        - **Datasets:** Easily access and share datasets.
        - **Models:** Download and use pre-trained models or upload your own.
        - **Spaces:** Deploy your app with a simple interface.

        Explore Hugging Face and build your own amazing ML projects!
        """
    )


if __name__ == "__main__":
    main()
|
churn_prediction_model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a98e018dad9148716364367ffdf9cdb9f95071a1fdae22cb0aed6439517a7305
|
3 |
+
size 996
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
datasets
# app.py imports huggingface_hub directly (hf_hub_download), so list it
# explicitly rather than relying on it being pulled in by datasets.
huggingface_hub
pandas
matplotlib
seaborn
scikit-learn
|