saranimje commited on
Commit
a0a2c37
·
verified ·
1 Parent(s): fa0dd39

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import joblib
3
+ import pandas as pd
4
+ from PIL import Image
5
+
6
+ best_model = joblib.load("best_model.pkl")
7
+ roc_img = Image.open("roc_curve_rf_tuned.png")
8
+
9
+ def churn_prediction(age, gender, tenure, usage_frequency, support_calls,
10
+ payment_delay, last_interaction, total_spend,
11
+ subscription_type, contract_length):
12
+ try:
13
+
14
+ input_data = {
15
+ "Age": age,
16
+ "Gender_Male": 1 if gender == "Male" else 0,
17
+ "Gender_Female": 1 if gender == "Female" else 0,
18
+ "Usage Frequency": usage_frequency,
19
+ "Support Calls": support_calls,
20
+ "Contract Length_Monthly": 1 if contract_length == "Monthly" else 0,
21
+ "Contract Length_Quarterly": 1 if contract_length == "Quarterly" else 0,
22
+ "Contract Length_Annual": 1 if contract_length == "Annual" else 0,
23
+ "Payment Delay": payment_delay,
24
+ "Last Interaction": last_interaction,
25
+ "Total Spend": total_spend,
26
+ "Tenure": tenure,
27
+ "Subscription Type_Basic": 1 if subscription_type == "Basic" else 0,
28
+ "Subscription Type_Premium": 1 if subscription_type == "Premium" else 0,
29
+ "Subscription Type_Standard": 1 if subscription_type == "Standard" else 0,
30
+ }
31
+
32
+ input_df = pd.DataFrame([input_data])
33
+
34
+ # Predict churn and probability
35
+ prediction = best_model.predict(input_df)
36
+ prediction_proba = best_model.predict_proba(input_df)[:, 1]
37
+
38
+ churn_result = "Yes" if prediction[0] == 1 else "No"
39
+ churn_probability = f"{prediction_proba[0]:.2f}"
40
+
41
+ return churn_result, churn_probability, roc_img
42
+
43
+ except Exception as e:
44
+ return f"Error: {str(e)}", "N/A", None
45
+
46
+ inputs = [
47
+ gr.Slider(18, 100, value=40, label="Age"),
48
+ gr.Dropdown(["Female", "Male"], value="Male", label="Gender"),
49
+ gr.Slider(1, 60, value=30, label="Tenure (months)"),
50
+ gr.Slider(1, 30, value=15, label="Usage Frequency"),
51
+ gr.Slider(0, 10, value=4, label="Support Calls"),
52
+ gr.Slider(0, 30, value=15, label="Payment Delay"),
53
+ gr.Slider(1, 30, value=15, label="Last Interaction (days ago)"),
54
+ gr.Slider(100, 1000, value=620, label="Total Spend"),
55
+ gr.Dropdown(["Premium", "Standard", "Basic"], value="Standard", label="Subscription Type"),
56
+ gr.Dropdown(["Monthly", "Quarterly", "Annual"], value="Annual", label="Contract Length")
57
+ ]
58
+
59
+ outputs = [
60
+ gr.Textbox(label="Churn Prediction"),
61
+ gr.Textbox(label="Churn Probability"),
62
+ gr.Image(label="ROC Curve")
63
+ ]
64
+
65
+ gr.Interface(
66
+ fn=churn_prediction,
67
+ inputs=inputs,
68
+ outputs=outputs,
69
+ title="Customer Churn Prediction"
70
+ ).launch()