saranimje commited on
Commit
c593dca
·
verified ·
1 Parent(s): 0c09a1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -58
app.py CHANGED
@@ -1,70 +1,97 @@
1
- import gradio as gr
2
  import joblib
3
  import pandas as pd
4
  from PIL import Image
5
 
6
- best_model = joblib.load("best_model.pkl")
7
- roc_img = Image.open("roc_curve_rf_tuned.png")
 
 
8
 
9
- def churn_prediction(age, gender, tenure, usage_frequency, support_calls,
10
- payment_delay, last_interaction, total_spend,
11
- subscription_type, contract_length):
12
- try:
13
-
14
- input_data = {
15
- "Age": age,
16
- "Gender_Male": 1 if gender == "Male" else 0,
17
- "Gender_Female": 1 if gender == "Female" else 0,
18
- "Usage Frequency": usage_frequency,
19
- "Support Calls": support_calls,
20
- "Contract Length_Monthly": 1 if contract_length == "Monthly" else 0,
21
- "Contract Length_Quarterly": 1 if contract_length == "Quarterly" else 0,
22
- "Contract Length_Annual": 1 if contract_length == "Annual" else 0,
23
- "Payment Delay": payment_delay,
24
- "Last Interaction": last_interaction,
25
- "Total Spend": total_spend,
26
- "Tenure": tenure,
27
- "Subscription Type_Basic": 1 if subscription_type == "Basic" else 0,
28
- "Subscription Type_Premium": 1 if subscription_type == "Premium" else 0,
29
- "Subscription Type_Standard": 1 if subscription_type == "Standard" else 0,
30
- }
31
-
32
- input_df = pd.DataFrame([input_data])
33
 
34
- # Predict churn and probability
35
- prediction = best_model.predict(input_df)
36
- prediction_proba = best_model.predict_proba(input_df)[:, 1]
 
 
 
37
 
38
- churn_result = "Yes" if prediction[0] == 1 else "No"
39
- churn_probability = f"{prediction_proba[0]:.2f}"
 
40
 
41
- return churn_result, churn_probability, roc_img
 
42
 
43
- except Exception as e:
44
- return f"Error: {str(e)}", "N/A", None
 
 
 
 
45
 
46
- inputs = [
47
- gr.Slider(18, 100, value=40, label="Age"),
48
- gr.Dropdown(["Female", "Male"], value="Male", label="Gender"),
49
- gr.Slider(1, 60, value=30, label="Tenure (months)"),
50
- gr.Slider(1, 30, value=15, label="Usage Frequency"),
51
- gr.Slider(0, 10, value=4, label="Support Calls"),
52
- gr.Slider(0, 30, value=15, label="Payment Delay"),
53
- gr.Slider(1, 30, value=15, label="Last Interaction (days ago)"),
54
- gr.Slider(100, 1000, value=620, label="Total Spend"),
55
- gr.Dropdown(["Premium", "Standard", "Basic"], value="Standard", label="Subscription Type"),
56
- gr.Dropdown(["Monthly", "Quarterly", "Annual"], value="Annual", label="Contract Length")
57
- ]
58
 
59
- outputs = [
60
- gr.Textbox(label="Churn Prediction"),
61
- gr.Textbox(label="Churn Probability"),
62
- gr.Image(label="ROC Curve")
63
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- gr.Interface(
66
- fn=churn_prediction,
67
- inputs=inputs,
68
- outputs=outputs,
69
- title="Customer Churn Prediction"
70
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  import joblib
3
  import pandas as pd
4
  from PIL import Image
5
 
6
+ # Load the model and image
7
+ @st.cache_resource
8
+ def load_model():
9
+ return joblib.load("best_model.pkl")
10
 
11
+ @st.cache_data
12
+ def load_roc_image():
13
+ return Image.open("roc_curve_rf_tuned.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ try:
16
+ best_model = load_model()
17
+ roc_img = load_roc_image()
18
+ except Exception as e:
19
+ st.error(f"Error loading model or image: {str(e)}")
20
+ st.stop()
21
 
22
+ # App title and description
23
+ st.title("Customer Churn Prediction")
24
+ st.write("Enter customer information to predict likelihood of churn")
25
 
26
+ # Create two columns for inputs
27
+ col1, col2 = st.columns(2)
28
 
29
+ with col1:
30
+ age = st.slider("Age", min_value=18, max_value=100, value=40)
31
+ gender = st.selectbox("Gender", options=["Male", "Female"])
32
+ tenure = st.slider("Tenure (months)", min_value=1, max_value=60, value=30)
33
+ usage_frequency = st.slider("Usage Frequency", min_value=1, max_value=30, value=15)
34
+ support_calls = st.slider("Support Calls", min_value=0, max_value=10, value=4)
35
 
36
+ with col2:
37
+ payment_delay = st.slider("Payment Delay", min_value=0, max_value=30, value=15)
38
+ last_interaction = st.slider("Last Interaction (days ago)", min_value=1, max_value=30, value=15)
39
+ total_spend = st.slider("Total Spend", min_value=100, max_value=1000, value=620)
40
+ subscription_type = st.selectbox("Subscription Type", options=["Premium", "Standard", "Basic"])
41
+ contract_length = st.selectbox("Contract Length", options=["Monthly", "Quarterly", "Annual"])
 
 
 
 
 
 
42
 
43
+ # Prediction function
44
+ def make_prediction():
45
+ input_data = {
46
+ "Age": age,
47
+ "Gender_Male": 1 if gender == "Male" else 0,
48
+ "Gender_Female": 1 if gender == "Female" else 0,
49
+ "Usage Frequency": usage_frequency,
50
+ "Support Calls": support_calls,
51
+ "Contract Length_Monthly": 1 if contract_length == "Monthly" else 0,
52
+ "Contract Length_Quarterly": 1 if contract_length == "Quarterly" else 0,
53
+ "Contract Length_Annual": 1 if contract_length == "Annual" else 0,
54
+ "Payment Delay": payment_delay,
55
+ "Last Interaction": last_interaction,
56
+ "Total Spend": total_spend,
57
+ "Tenure": tenure,
58
+ "Subscription Type_Basic": 1 if subscription_type == "Basic" else 0,
59
+ "Subscription Type_Premium": 1 if subscription_type == "Premium" else 0,
60
+ "Subscription Type_Standard": 1 if subscription_type == "Standard" else 0,
61
+ }
62
+
63
+ input_df = pd.DataFrame([input_data])
64
+
65
+ # Predict churn and probability
66
+ prediction = best_model.predict(input_df)
67
+ prediction_proba = best_model.predict_proba(input_df)[:, 1]
68
+
69
+ return prediction[0], prediction_proba[0]
70
 
71
+ # Make prediction when button is clicked
72
+ if st.button("Predict Churn"):
73
+ try:
74
+ prediction, probability = make_prediction()
75
+
76
+ # Display results
77
+ st.header("Prediction Results")
78
+
79
+ # Create three columns for results
80
+ col1, col2, col3 = st.columns(3)
81
+
82
+ with col1:
83
+ st.metric("Churn Prediction", "Yes" if prediction == 1 else "No")
84
+
85
+ with col2:
86
+ st.metric("Churn Probability", f"{probability:.2f}")
87
+
88
+ with col3:
89
+ risk_level = "High" if probability > 0.7 else ("Medium" if probability > 0.4 else "Low")
90
+ st.metric("Risk Level", risk_level)
91
+
92
+ # Display ROC curve
93
+ st.subheader("Model ROC Curve")
94
+ st.image(roc_img, caption="ROC Curve for Random Forest Model")
95
+
96
+ except Exception as e:
97
+ st.error(f"Error making prediction: {str(e)}")