ssyok commited on
Commit
95eab0a
·
1 Parent(s): ac0b681

first commit app.py and the content

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ venv/
2
+ __pycache__/
3
+ .gradio/
Data_Analytics_SHE_Course_Group_Assignment_Machine_Learning.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import joblib
5
+ import os
6
+
7
+ # ==============================================================================
8
+ # 1. LOAD MODELS AND SCALER (This part runs once when the script starts)
9
+ # ==============================================================================
10
+
11
+ # Dictionary to hold the loaded model objects and a list of their names
12
+ all_models = {}
13
+ model_names = [
14
+ 'Linear Regression', 'Ridge Regression', 'Lasso Regression',
15
+ 'Random Forest', 'Gradient Boosting'
16
+ ]
17
+ BEST_MODEL_NAME = 'Random Forest' # Define the best model to be highlighted
18
+
19
+ try:
20
+ # Load all the regression models
21
+ for name in model_names:
22
+ # Construct the filename, e.g., 'models/random_forest.joblib'
23
+ filename = f"models/{name.lower().replace(' ', '_')}.joblib"
24
+ if os.path.exists(filename):
25
+ all_models[name] = joblib.load(filename)
26
+ else:
27
+ raise FileNotFoundError(f"Model file not found: {filename}")
28
+
29
+ # Load the scaler
30
+ scaler_path = 'models/scaler.joblib'
31
+ if os.path.exists(scaler_path):
32
+ scaler = joblib.load(scaler_path)
33
+ else:
34
+ raise FileNotFoundError(f"Scaler file not found: {scaler_path}")
35
+
36
+ models_loaded = True
37
+ print("✅ All models and scaler loaded successfully!")
38
+
39
+ # Get the feature names the model was trained on from the scaler
40
+ expected_columns = scaler.feature_names_in_
41
+ print(f"Models expect {len(expected_columns)} features.")
42
+
43
+ except Exception as e:
44
+ print(f"❌ ERROR: Could not load models. {e}")
45
+ print("Please ensure all '.joblib' files are in the 'models/' directory.")
46
+ models_loaded = False
47
+ all_models = {}
48
+ scaler = None
49
+ expected_columns = []
50
+
51
+ # ==============================================================================
52
+ # 2. PREDICTION FUNCTION
53
+ # ==============================================================================
54
+
55
+ def predict_shares_all_models(likes, generation_time, gpu_usage, file_size_kb,
56
+ width, height, style_accuracy_score,
57
+ is_hand_edited, ethical_concerns_flag,
58
+ day_of_week, month, hour, platform):
59
+ """
60
+ Performs feature engineering, predicts shares using all loaded models,
61
+ and returns formatted outputs for the Gradio interface.
62
+ """
63
+ if not models_loaded:
64
+ error_message = "Models are not loaded. Please check the console for errors."
65
+ return 0, error_message, error_message
66
+
67
+ # --- Step A: Perform feature engineering ---
68
+ sample_data = {
69
+ 'likes': likes,
70
+ 'style_accuracy_score': style_accuracy_score,
71
+ 'generation_time': generation_time,
72
+ 'gpu_usage': gpu_usage,
73
+ 'file_size_kb': file_size_kb,
74
+ 'is_hand_edited': int(is_hand_edited),
75
+ 'ethical_concerns_flag': int(ethical_concerns_flag),
76
+ 'width': width,
77
+ 'height': height,
78
+ 'day_of_week': day_of_week,
79
+ 'month': month,
80
+ 'hour': hour
81
+ }
82
+
83
+ sample_data['aspect_ratio'] = width / height if height > 0 else 0
84
+ sample_data['total_pixels'] = width * height
85
+ sample_data['is_square'] = int(width == height)
86
+ sample_data['is_weekend'] = int(day_of_week >= 5)
87
+
88
+ for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:
89
+ sample_data[f'platform_{p}'] = 1 if platform == p else 0
90
+
91
+ sample_data['engagement_rate'] = likes / (sample_data['total_pixels'] / 1000000 + 1)
92
+ sample_data['quality_engagement'] = style_accuracy_score * likes / 100
93
+ sample_data['file_density'] = file_size_kb / (sample_data['total_pixels'] / 1000 + 1)
94
+ sample_data['gpu_efficiency'] = generation_time / (gpu_usage + 1)
95
+
96
+ for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:
97
+ sample_data[f'{p.lower()}_likes'] = likes * sample_data[f'platform_{p}']
98
+
99
+ sample_data['month_sin'] = np.sin(2 * np.pi * month / 12)
100
+ sample_data['month_cos'] = np.cos(2 * np.pi * month / 12)
101
+ sample_data['day_sin'] = np.sin(2 * np.pi * day_of_week / 7)
102
+ sample_data['day_cos'] = np.cos(2 * np.pi * day_of_week / 7)
103
+
104
+ # --- Step B: Align columns and Scale ---
105
+ sample_df = pd.DataFrame([sample_data])
106
+ sample_df = sample_df.reindex(columns=expected_columns, fill_value=0)
107
+ sample_scaled = scaler.transform(sample_df)
108
+
109
+ # --- Step C: Predict with all models ---
110
+ predictions = {}
111
+ for name, model in all_models.items():
112
+ pred_value = model.predict(sample_scaled)[0]
113
+ predictions[name] = max(0, int(pred_value))
114
+
115
+ # --- Step D: Format the outputs for Gradio ---
116
+
117
+ # 1. Get the single best model prediction
118
+ best_model_prediction = predictions.get(BEST_MODEL_NAME, 0)
119
+
120
+ # 2. Create a Markdown table for all model predictions
121
+ all_results_df = pd.DataFrame(list(predictions.items()), columns=['Model', 'Predicted Shares'])
122
+ all_results_df = all_results_df.sort_values('Predicted Shares', ascending=False)
123
+ all_models_table = all_results_df.to_markdown(index=False)
124
+
125
+ # 3. Create a Markdown table for the engineered features
126
+ features_df = sample_df.T.reset_index()
127
+ features_df.columns = ['Feature', 'Value']
128
+ features_df['Value'] = features_df['Value'].apply(lambda x: f"{x:.4f}" if isinstance(x, float) else x)
129
+ features_table = features_df.to_markdown(index=False)
130
+
131
+ return best_model_prediction, all_models_table, features_table
132
+
133
+ # ==============================================================================
134
+ # 3. GRADIO INTERFACE
135
+ # ==============================================================================
136
+
137
+ with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Virality Predictor") as demo:
138
+ gr.Markdown("# 🎨 AI Ghibli Image Virality Predictor")
139
+ gr.Markdown("Enter image features to get a virality prediction from multiple regression models.")
140
+
141
+ with gr.Row():
142
+ # --- INPUTS COLUMN ---
143
+ with gr.Column(scale=2):
144
+ gr.Markdown("### 1. Input Features")
145
+ with gr.Accordion("Core Engagement & Image Metrics", open=True):
146
+ likes = gr.Slider(minimum=0, maximum=10000, value=500, step=10, label="Likes")
147
+ style_accuracy_score = gr.Slider(minimum=0, maximum=100, value=85, step=1, label="Style Accuracy Score (%)")
148
+ width = gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label="Width (px)")
149
+ height = gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label="Height (px)")
150
+ file_size_kb = gr.Slider(minimum=100, maximum=5000, value=1500, step=100, label="File Size (KB)")
151
+
152
+ with gr.Accordion("Technical & Posting Details", open=True):
153
+ generation_time = gr.Slider(minimum=1, maximum=30, value=8, step=0.5, label="Generation Time (s)")
154
+ gpu_usage = gr.Slider(minimum=10, maximum=100, value=70, step=5, label="GPU Usage (%)")
155
+ platform = gr.Radio(["Instagram", "Twitter", "TikTok", "Reddit"], label="Platform", value="Instagram")
156
+ day_of_week = gr.Slider(minimum=0, maximum=6, value=4, step=1, label="Day of Week (0=Mon, 6=Sun)")
157
+ month = gr.Slider(minimum=1, maximum=12, value=7, step=1, label="Month (1-12)")
158
+ hour = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Hour of Day (0-23)")
159
+ is_hand_edited = gr.Checkbox(label="Was it Hand Edited?", value=False)
160
+ ethical_concerns_flag = gr.Checkbox(label="Any Ethical Concerns?", value=False)
161
+
162
+ predict_btn = gr.Button("Predict Virality", variant="primary")
163
+
164
+ # --- OUTPUTS COLUMN ---
165
+ with gr.Column(scale=3):
166
+ gr.Markdown("### 2. Prediction Results")
167
+
168
+ # Highlighted Best Model Output
169
+ best_model_output = gr.Number(
170
+ label=f"🏆 Best Model Prediction ({BEST_MODEL_NAME})",
171
+ interactive=False
172
+ )
173
+
174
+ # Table for All Model Predictions
175
+ with gr.Accordion("Comparison of All Models", open=True):
176
+ all_models_output = gr.Markdown(label="All Model Predictions")
177
+
178
+ # Table for Feature Engineering Details
179
+ with gr.Accordion("View Engineered Features", open=False):
180
+ features_output = gr.Markdown(label="Feature Engineering Details")
181
+
182
+ # Connect the button to the function
183
+ predict_btn.click(
184
+ fn=predict_shares_all_models,
185
+ inputs=[
186
+ likes, generation_time, gpu_usage, file_size_kb,
187
+ width, height, style_accuracy_score,
188
+ is_hand_edited, ethical_concerns_flag,
189
+ day_of_week, month, hour, platform
190
+ ],
191
+ outputs=[
192
+ best_model_output,
193
+ all_models_output,
194
+ features_output
195
+ ]
196
+ )
197
+
198
+ # Launch the app
199
+ if __name__ == "__main__":
200
+ if not models_loaded:
201
+ print("\nCannot launch Gradio app because models failed to load.")
202
+ else:
203
+ demo.launch()
dataset/ai_ghibli_trend_dataset_v2.csv ADDED
The diff for this file is too large to render. See raw diff
 
models/lasso_regression.joblib ADDED
Binary file (864 Bytes). View file
 
models/linear_regression.joblib ADDED
Binary file (1.03 kB). View file
 
models/ridge_regression.joblib ADDED
Binary file (785 Bytes). View file
 
models/scaler.joblib ADDED
Binary file (1.98 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ pandas>=2.2.0
2
+ numpy>=1.26.0
3
+ matplotlib>=3.8.0
4
+ seaborn>=0.13.0
5
+ scikit-learn>=1.4.0
6
+ gradio>=5.3.0
7
+ tabulate
results/model_comparison.csv ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Model,R² Score,MAE,RMSE
2
+ Random Forest,-0.08503617106593597,518.7684998150959,593.7506113831865
3
+ Ridge Regression,-0.08676700700361528,528.5892908496174,594.223994396894
4
+ Lasso Regression,-0.08764565917116918,528.573248905534,594.4641611976449
5
+ Linear Regression,-0.09914175984109797,531.4132783380423,597.5975604762958
6
+ Gradient Boosting,-0.2357883805937282,537.630957194942,633.6566769562837
results/regression_analysis_results.json ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dataset_info": {
3
+ "total_samples": 500,
4
+ "features": 29,
5
+ "target_mean": 1040.182,
6
+ "target_median": 1092.0,
7
+ "target_std": 562.6687383302794
8
+ },
9
+ "model_comparison": [
10
+ {
11
+ "Model": "Random Forest",
12
+ "R\u00b2 Score": -0.08503617106593597,
13
+ "MAE": 518.7684998150959,
14
+ "RMSE": 593.7506113831865
15
+ },
16
+ {
17
+ "Model": "Ridge Regression",
18
+ "R\u00b2 Score": -0.08676700700361528,
19
+ "MAE": 528.5892908496174,
20
+ "RMSE": 594.223994396894
21
+ },
22
+ {
23
+ "Model": "Lasso Regression",
24
+ "R\u00b2 Score": -0.08764565917116918,
25
+ "MAE": 528.573248905534,
26
+ "RMSE": 594.4641611976449
27
+ },
28
+ {
29
+ "Model": "Linear Regression",
30
+ "R\u00b2 Score": -0.09914175984109797,
31
+ "MAE": 531.4132783380423,
32
+ "RMSE": 597.5975604762958
33
+ },
34
+ {
35
+ "Model": "Gradient Boosting",
36
+ "R\u00b2 Score": -0.2357883805937282,
37
+ "MAE": 537.630957194942,
38
+ "RMSE": 633.6566769562837
39
+ }
40
+ ],
41
+ "feature_correlations": [
42
+ {
43
+ "feature": "platform_Twitter",
44
+ "correlation": -0.11310486794074195
45
+ },
46
+ {
47
+ "feature": "platform_Instagram",
48
+ "correlation": 0.07096989443791954
49
+ },
50
+ {
51
+ "feature": "total_pixels",
52
+ "correlation": 0.05340067376167711
53
+ },
54
+ {
55
+ "feature": "width",
56
+ "correlation": 0.050954190148673084
57
+ },
58
+ {
59
+ "feature": "height",
60
+ "correlation": 0.050954190148673084
61
+ },
62
+ {
63
+ "feature": "platform_Reddit",
64
+ "correlation": 0.030824709708669493
65
+ },
66
+ {
67
+ "feature": "likes",
68
+ "correlation": -0.029318071149881914
69
+ },
70
+ {
71
+ "feature": "is_hand_edited",
72
+ "correlation": 0.02824023580536551
73
+ },
74
+ {
75
+ "feature": "day_of_week",
76
+ "correlation": 0.02490306263783807
77
+ },
78
+ {
79
+ "feature": "file_size_kb",
80
+ "correlation": -0.020748477243945303
81
+ }
82
+ ]
83
+ }