# blacksw0rd's picture
# Upload 6 files
# cfd24e0 verified
# raw
# history blame
# 2.23 kB
import streamlit as st
import numpy as np
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import plotly.express as px
# Page heading.
st.title("Customer Segmentation")

# Model artefacts are loaded with joblib from relative paths. joblib.load
# unpickles the files, so they must come from a trusted source.
kmeans, scaler = (joblib.load(path) for path in ("kmeans.pkl", "scaler.pkl"))

# RFM table that the charts below are drawn from.
rfm = pd.read_csv("transformation.csv")

# Cluster index (as returned by kmeans.predict) -> segment name.
cluster_label = dict(enumerate(
    ['Loyal Customers', 'At Risk', 'Champions', 'New Customers']
))
def customer_segmentation(num1,num2,num3):
    """Return the segment name predicted for one customer's RFM values.

    Each input is log1p-transformed, the resulting single-row frame is
    passed through the module-level ``scaler`` and then classified by the
    module-level ``kmeans`` model. The predicted cluster index is looked up
    in ``cluster_label``.

    Args:
        num1: Recency value.
        num2: Frequency value.
        num3: Monetary value.

    Returns:
        The ``cluster_label`` entry for the predicted cluster.
    """
    print("Customer Segmentation")
    raw_inputs = {'Recency': num1, 'Frequency': num2, 'Monetary': num3}
    # One row; column order is Recency, Frequency, Monetary.
    features = pd.DataFrame(
        {column: [np.log1p(raw)] for column, raw in raw_inputs.items()}
    )
    scaled = scaler.transform(features)
    cluster_id = kmeans.predict(scaled)[0]
    return cluster_label[cluster_id]
# Three side-by-side numeric inputs, one per RFM feature.
col1,col2,col3 = st.columns(3)
with col1:
    num1 = st.number_input("Enter Recency", min_value=1, max_value=400, step=1)
with col2:
    num2 = st.number_input("Enter Frequency", min_value=1, max_value=6000, step=1)
with col3:
    num3 = st.number_input("Enter Monetary", min_value=1, step=10)

# The model is only run on a "Predict" click; otherwise the span below is
# rendered empty.
value = customer_segmentation(num1,num2,num3) if st.button(label="Predict") else ""
st.markdown(
    "<span style='font-size:20px; font-weight:bold; font-style:italic'>"
    f"{value}</span>",
    unsafe_allow_html=True,
)
# Fixed colour per segment name, used as the discrete colour map of the
# 3D scatter below.
custom_colors = dict([
    ('Loyal Customers', '#99ff99'),
    ('Champions', '#66b3ff'),
    ('At Risk', '#ff9999'),
    ('New Customers', '#ffcc99'),
])

# Columns of `rfm` plotted on the x, y and z axes respectively.
_axis_x, _axis_y, _axis_z = 'Recency', 'Frequency', 'Monetary'

# 3D scatter of the RFM table, coloured by the 'Cluster Labels' column.
figx = px.scatter_3d(
    data_frame=rfm,
    x=_axis_x,
    y=_axis_y,
    z=_axis_z,
    color='Cluster Labels',
    color_discrete_map=custom_colors,
    # Axis titles are the column names themselves.
    labels={column: column for column in (_axis_x, _axis_y, _axis_z)},
    title='Customer Segmentation Visualization',
)
st.plotly_chart(figx)
# Pie chart: percentage of the customers in `rfm` that fall in each cluster.
customers = rfm.shape[0]

# `labels` and `colors` are listed in cluster-id order (0, 1, 2, 3).
labels = ['Loyal Customers','At Risk','Champions','New Customers']
colors = ['#99ff99', '#ff9999', '#66b3ff', '#ffcc99']
cluster_ids = list(range(len(labels)))

# value_counts() orders its result by descending frequency, not by cluster
# id, which would pair wedges with the wrong label and colour. Reindexing
# puts the counts back in id order and gives a cluster with no rows a zero
# slice instead of a length mismatch in ax.pie().
# NOTE(review): assumes rfm["Cluster"] holds the integer ids 0-3 — confirm
# against how transformation.csv was produced.
cluster_counts = rfm["Cluster"].value_counts().reindex(cluster_ids, fill_value=0)
sizes = (cluster_counts/customers)*100

fig,ax = plt.subplots(figsize=(8,6))
ax.pie(
    sizes, labels=labels, colors=colors, autopct='%1.1f%%',
    startangle=120, wedgeprops={'edgecolor': 'black'}
)
ax.set_title('Customer Segmentation', fontsize=14)
# Legend entries are the numeric cluster ids, in the same order as the wedges.
ax.legend(cluster_ids,title='Clusters',loc='best')
st.pyplot(fig)