first commit app.py and the content
Browse files- .gitignore +3 -0
- Data_Analytics_SHE_Course_Group_Assignment_Machine_Learning.ipynb +0 -0
- app.py +203 -0
- dataset/ai_ghibli_trend_dataset_v2.csv +0 -0
- models/lasso_regression.joblib +0 -0
- models/linear_regression.joblib +0 -0
- models/ridge_regression.joblib +0 -0
- models/scaler.joblib +0 -0
- requirements.txt +7 -0
- results/model_comparison.csv +6 -0
- results/regression_analysis_results.json +83 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
venv/
|
2 |
+
__pycache__/
|
3 |
+
.gradio/
|
Data_Analytics_SHE_Course_Group_Assignment_Machine_Learning.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import joblib
|
5 |
+
import os
|
6 |
+
|
7 |
+
# ==============================================================================
|
8 |
+
# 1. LOAD MODELS AND SCALER (This part runs once when the script starts)
|
9 |
+
# ==============================================================================
|
10 |
+
|
11 |
+
# Registry of loaded estimators keyed by display name. Only models whose
# artifact file is actually present on disk end up in here.
all_models = {}
model_names = [
    'Linear Regression', 'Ridge Regression', 'Lasso Regression',
    'Random Forest', 'Gradient Boosting'
]
BEST_MODEL_NAME = 'Random Forest'  # Preferred model to be highlighted

# Fallback order for the highlighted model when the preferred one is not
# available: best-to-worst R² as listed in results/model_comparison.csv.
_MODEL_RANKING = [
    'Random Forest', 'Ridge Regression', 'Lasso Regression',
    'Linear Regression', 'Gradient Boosting'
]

try:
    # Load every regression model that is available. A missing artifact is
    # skipped (with a warning) instead of aborting the whole start-up, so the
    # app still works when only a subset of the models has been shipped.
    missing_model_files = []
    for name in model_names:
        # Construct the filename, e.g., 'models/random_forest.joblib'
        filename = f"models/{name.lower().replace(' ', '_')}.joblib"
        if os.path.exists(filename):
            all_models[name] = joblib.load(filename)
        else:
            missing_model_files.append(filename)

    # Without a single model there is nothing to predict with.
    if not all_models:
        raise FileNotFoundError(
            f"Model file not found: {', '.join(missing_model_files)}"
        )
    for filename in missing_model_files:
        print(f"⚠️ WARNING: Model file not found, skipping: {filename}")

    # The scaler is mandatory: every prediction goes through it.
    scaler_path = 'models/scaler.joblib'
    if os.path.exists(scaler_path):
        scaler = joblib.load(scaler_path)
    else:
        raise FileNotFoundError(f"Scaler file not found: {scaler_path}")

    # Get the feature names the model was trained on from the scaler.
    # (Raises AttributeError, handled below, if the scaler does not carry them.)
    expected_columns = scaler.feature_names_in_

    # Highlight the best-ranked model that was actually loaded, so the
    # "best model" output never silently shows 0 for an absent model.
    if BEST_MODEL_NAME not in all_models:
        fallback_name = next(
            (name for name in _MODEL_RANKING if name in all_models),
            next(iter(all_models)),
        )
        print(f"⚠️ WARNING: '{BEST_MODEL_NAME}' is not available; "
              f"highlighting '{fallback_name}' instead.")
        BEST_MODEL_NAME = fallback_name

    models_loaded = True
    if missing_model_files:
        print(f"✅ Loaded {len(all_models)} of {len(model_names)} models and the scaler.")
    else:
        print("✅ All models and scaler loaded successfully!")
    print(f"Models expect {len(expected_columns)} features.")

except Exception as e:
    # Top-level boundary: report the problem and leave the module in a
    # well-defined "nothing loaded" state that the rest of the app checks.
    print(f"❌ ERROR: Could not load models. {e}")
    print("Please ensure all '.joblib' files are in the 'models/' directory.")
    models_loaded = False
    all_models = {}
    scaler = None
    expected_columns = []
|
50 |
+
|
51 |
+
# ==============================================================================
|
52 |
+
# 2. PREDICTION FUNCTION
|
53 |
+
# ==============================================================================
|
54 |
+
|
55 |
+
# Platforms that get a one-hot column and a per-platform likes interaction.
_PLATFORMS = ('Twitter', 'TikTok', 'Reddit', 'Instagram')


def _engineer_features(likes, generation_time, gpu_usage, file_size_kb,
                       width, height, style_accuracy_score,
                       is_hand_edited, ethical_concerns_flag,
                       day_of_week, month, hour, platform):
    """Build the raw + derived feature dict for a single sample."""
    total_pixels = width * height
    features = {
        'likes': likes,
        'style_accuracy_score': style_accuracy_score,
        'generation_time': generation_time,
        'gpu_usage': gpu_usage,
        'file_size_kb': file_size_kb,
        'is_hand_edited': int(is_hand_edited),
        'ethical_concerns_flag': int(ethical_concerns_flag),
        'width': width,
        'height': height,
        'day_of_week': day_of_week,
        'month': month,
        'hour': hour,
        # Geometry features; a non-positive height yields ratio 0 instead of
        # a division error.
        'aspect_ratio': width / height if height > 0 else 0,
        'total_pixels': total_pixels,
        'is_square': int(width == height),
        # Days 5 and 6 count as the weekend.
        'is_weekend': int(day_of_week >= 5),
    }

    # One-hot encode the platform.
    for p in _PLATFORMS:
        features[f'platform_{p}'] = 1 if platform == p else 0

    # Ratio features; the "+ 1" in each denominator avoids division by zero.
    features['engagement_rate'] = likes / (total_pixels / 1000000 + 1)
    features['quality_engagement'] = style_accuracy_score * likes / 100
    features['file_density'] = file_size_kb / (total_pixels / 1000 + 1)
    features['gpu_efficiency'] = generation_time / (gpu_usage + 1)

    # Likes attributed to the selected platform only (0 for the others).
    for p in _PLATFORMS:
        features[f'{p.lower()}_likes'] = likes * features[f'platform_{p}']

    # Cyclical encodings so that the ends of the month/week ranges are adjacent.
    features['month_sin'] = np.sin(2 * np.pi * month / 12)
    features['month_cos'] = np.cos(2 * np.pi * month / 12)
    features['day_sin'] = np.sin(2 * np.pi * day_of_week / 7)
    features['day_cos'] = np.cos(2 * np.pi * day_of_week / 7)
    return features


def predict_shares_all_models(likes, generation_time, gpu_usage, file_size_kb,
                              width, height, style_accuracy_score,
                              is_hand_edited, ethical_concerns_flag,
                              day_of_week, month, hour, platform):
    """
    Performs feature engineering, predicts shares using all loaded models,
    and returns formatted outputs for the Gradio interface.

    Relies on the module-level ``models_loaded``, ``scaler``, ``all_models``,
    ``expected_columns`` and ``BEST_MODEL_NAME`` set up at import time.

    Args:
        likes, generation_time, gpu_usage, file_size_kb, width, height,
        style_accuracy_score, month, hour: numeric inputs.
        is_hand_edited, ethical_concerns_flag: flags, converted with ``int()``.
        day_of_week: numeric day index; values >= 5 are treated as weekend.
        platform: platform name; one of 'Twitter', 'TikTok', 'Reddit',
            'Instagram' sets the matching one-hot column, anything else
            leaves all platform columns at 0.

    Returns:
        A 3-tuple ``(best_model_prediction, all_models_table, features_table)``:
        the non-negative integer prediction of ``BEST_MODEL_NAME`` (0 if that
        model is not loaded), a Markdown table of every model's prediction
        sorted descending, and a Markdown table of the feature values fed to
        the scaler. When models are not loaded, returns ``0`` and the same
        error message twice.
    """
    if not models_loaded:
        error_message = "Models are not loaded. Please check the console for errors."
        return 0, error_message, error_message

    # --- Step A: Perform feature engineering ---
    sample_data = _engineer_features(
        likes, generation_time, gpu_usage, file_size_kb,
        width, height, style_accuracy_score,
        is_hand_edited, ethical_concerns_flag,
        day_of_week, month, hour, platform,
    )

    # --- Step B: Align columns and Scale ---
    # reindex puts columns in the scaler's order, drops extras and fills any
    # column the scaler expects but we did not compute with 0.
    sample_df = pd.DataFrame([sample_data])
    sample_df = sample_df.reindex(columns=expected_columns, fill_value=0)
    sample_scaled = scaler.transform(sample_df)

    # --- Step C: Predict with all models ---
    # Round to the nearest whole share (plain int() would truncate, e.g.
    # 999.9 -> 999) and clamp at 0 because a share count cannot be negative.
    predictions = {
        name: max(0, int(round(float(model.predict(sample_scaled)[0]))))
        for name, model in all_models.items()
    }

    # --- Step D: Format the outputs for Gradio ---

    # 1. Get the single best model prediction
    best_model_prediction = predictions.get(BEST_MODEL_NAME, 0)

    # 2. Create a Markdown table for all model predictions
    all_results_df = pd.DataFrame(list(predictions.items()), columns=['Model', 'Predicted Shares'])
    all_results_df = all_results_df.sort_values('Predicted Shares', ascending=False)
    all_models_table = all_results_df.to_markdown(index=False)

    # 3. Create a Markdown table for the engineered features
    features_df = sample_df.T.reset_index()
    features_df.columns = ['Feature', 'Value']
    features_df['Value'] = features_df['Value'].apply(lambda x: f"{x:.4f}" if isinstance(x, float) else x)
    features_table = features_df.to_markdown(index=False)

    return best_model_prediction, all_models_table, features_table
|
132 |
+
|
133 |
+
# ==============================================================================
|
134 |
+
# 3. GRADIO INTERFACE
|
135 |
+
# ==============================================================================
|
136 |
+
|
137 |
+
with gr.Blocks(title="AI Image Virality Predictor", theme=gr.themes.Soft()) as demo:
    # Page header.
    gr.Markdown("# 🎨 AI Ghibli Image Virality Predictor")
    gr.Markdown("Enter image features to get a virality prediction from multiple regression models.")

    with gr.Row():
        # Left-hand side: everything the user can set.
        with gr.Column(scale=2):
            gr.Markdown("### 1. Input Features")

            with gr.Accordion("Core Engagement & Image Metrics", open=True):
                likes = gr.Slider(label="Likes", minimum=0, maximum=10000, step=10, value=500)
                style_accuracy_score = gr.Slider(
                    label="Style Accuracy Score (%)", minimum=0, maximum=100, step=1, value=85
                )
                width = gr.Slider(label="Width (px)", minimum=256, maximum=2048, step=64, value=1024)
                height = gr.Slider(label="Height (px)", minimum=256, maximum=2048, step=64, value=1024)
                file_size_kb = gr.Slider(
                    label="File Size (KB)", minimum=100, maximum=5000, step=100, value=1500
                )

            with gr.Accordion("Technical & Posting Details", open=True):
                generation_time = gr.Slider(
                    label="Generation Time (s)", minimum=1, maximum=30, step=0.5, value=8
                )
                gpu_usage = gr.Slider(label="GPU Usage (%)", minimum=10, maximum=100, step=5, value=70)
                platform = gr.Radio(
                    ["Instagram", "Twitter", "TikTok", "Reddit"], value="Instagram", label="Platform"
                )
                day_of_week = gr.Slider(
                    label="Day of Week (0=Mon, 6=Sun)", minimum=0, maximum=6, step=1, value=4
                )
                month = gr.Slider(label="Month (1-12)", minimum=1, maximum=12, step=1, value=7)
                hour = gr.Slider(label="Hour of Day (0-23)", minimum=0, maximum=23, step=1, value=18)
                is_hand_edited = gr.Checkbox(value=False, label="Was it Hand Edited?")
                ethical_concerns_flag = gr.Checkbox(value=False, label="Any Ethical Concerns?")

            predict_btn = gr.Button("Predict Virality", variant="primary")

        # Right-hand side: what the models say.
        with gr.Column(scale=3):
            gr.Markdown("### 2. Prediction Results")

            # Headline number from the highlighted model.
            best_model_output = gr.Number(
                interactive=False,
                label=f"🏆 Best Model Prediction ({BEST_MODEL_NAME})",
            )

            # Side-by-side predictions of every loaded model.
            with gr.Accordion("Comparison of All Models", open=True):
                all_models_output = gr.Markdown(label="All Model Predictions")

            # The feature values that were fed to the scaler (collapsed by default).
            with gr.Accordion("View Engineered Features", open=False):
                features_output = gr.Markdown(label="Feature Engineering Details")

    # Wire the button to the prediction function. The order of the inputs
    # must match the positional parameters of predict_shares_all_models.
    prediction_inputs = [
        likes, generation_time, gpu_usage, file_size_kb,
        width, height, style_accuracy_score,
        is_hand_edited, ethical_concerns_flag,
        day_of_week, month, hour, platform,
    ]
    prediction_outputs = [best_model_output, all_models_output, features_output]
    predict_btn.click(
        fn=predict_shares_all_models,
        inputs=prediction_inputs,
        outputs=prediction_outputs,
    )
|
197 |
+
|
198 |
+
# Launch the app
|
199 |
+
if __name__ == "__main__":
    # Start the web server only when the model artifacts loaded; otherwise
    # tell the operator why nothing is being served.
    if models_loaded:
        demo.launch()
    else:
        print("\nCannot launch Gradio app because models failed to load.")
|
dataset/ai_ghibli_trend_dataset_v2.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/lasso_regression.joblib
ADDED
Binary file (864 Bytes). View file
|
|
models/linear_regression.joblib
ADDED
Binary file (1.03 kB). View file
|
|
models/ridge_regression.joblib
ADDED
Binary file (785 Bytes). View file
|
|
models/scaler.joblib
ADDED
Binary file (1.98 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas>=2.2.0
|
2 |
+
numpy>=1.26.0
|
3 |
+
matplotlib>=3.8.0
|
4 |
+
seaborn>=0.13.0
|
5 |
+
scikit-learn>=1.4.0
|
6 |
+
gradio>=5.3.0
|
7 |
+
tabulate
|
results/model_comparison.csv
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Model,R² Score,MAE,RMSE
|
2 |
+
Random Forest,-0.08503617106593597,518.7684998150959,593.7506113831865
|
3 |
+
Ridge Regression,-0.08676700700361528,528.5892908496174,594.223994396894
|
4 |
+
Lasso Regression,-0.08764565917116918,528.573248905534,594.4641611976449
|
5 |
+
Linear Regression,-0.09914175984109797,531.4132783380423,597.5975604762958
|
6 |
+
Gradient Boosting,-0.2357883805937282,537.630957194942,633.6566769562837
|
results/regression_analysis_results.json
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_info": {
|
3 |
+
"total_samples": 500,
|
4 |
+
"features": 29,
|
5 |
+
"target_mean": 1040.182,
|
6 |
+
"target_median": 1092.0,
|
7 |
+
"target_std": 562.6687383302794
|
8 |
+
},
|
9 |
+
"model_comparison": [
|
10 |
+
{
|
11 |
+
"Model": "Random Forest",
|
12 |
+
"R\u00b2 Score": -0.08503617106593597,
|
13 |
+
"MAE": 518.7684998150959,
|
14 |
+
"RMSE": 593.7506113831865
|
15 |
+
},
|
16 |
+
{
|
17 |
+
"Model": "Ridge Regression",
|
18 |
+
"R\u00b2 Score": -0.08676700700361528,
|
19 |
+
"MAE": 528.5892908496174,
|
20 |
+
"RMSE": 594.223994396894
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"Model": "Lasso Regression",
|
24 |
+
"R\u00b2 Score": -0.08764565917116918,
|
25 |
+
"MAE": 528.573248905534,
|
26 |
+
"RMSE": 594.4641611976449
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"Model": "Linear Regression",
|
30 |
+
"R\u00b2 Score": -0.09914175984109797,
|
31 |
+
"MAE": 531.4132783380423,
|
32 |
+
"RMSE": 597.5975604762958
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"Model": "Gradient Boosting",
|
36 |
+
"R\u00b2 Score": -0.2357883805937282,
|
37 |
+
"MAE": 537.630957194942,
|
38 |
+
"RMSE": 633.6566769562837
|
39 |
+
}
|
40 |
+
],
|
41 |
+
"feature_correlations": [
|
42 |
+
{
|
43 |
+
"feature": "platform_Twitter",
|
44 |
+
"correlation": -0.11310486794074195
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"feature": "platform_Instagram",
|
48 |
+
"correlation": 0.07096989443791954
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"feature": "total_pixels",
|
52 |
+
"correlation": 0.05340067376167711
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"feature": "width",
|
56 |
+
"correlation": 0.050954190148673084
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"feature": "height",
|
60 |
+
"correlation": 0.050954190148673084
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"feature": "platform_Reddit",
|
64 |
+
"correlation": 0.030824709708669493
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"feature": "likes",
|
68 |
+
"correlation": -0.029318071149881914
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"feature": "is_hand_edited",
|
72 |
+
"correlation": 0.02824023580536551
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"feature": "day_of_week",
|
76 |
+
"correlation": 0.02490306263783807
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"feature": "file_size_kb",
|
80 |
+
"correlation": -0.020748477243945303
|
81 |
+
}
|
82 |
+
]
|
83 |
+
}
|