File size: 10,841 Bytes
713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae 95eab0a 713c6ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 |
# Gradio Demo App for Predicting Both Likes and Shares
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import os
# Best models (update based on your results)
# These names must match keys in the dicts built by load_models(); they pick
# which model's prediction is shown as the headline number in the UI.
BEST_MODEL_LIKES = 'Random Forest'
BEST_MODEL_SHARES = 'Random Forest'
# Set to True by load_models() on success; predictions are refused until then.
models_loaded = False
# Global variables for models and scaler
all_models_likes = {}    # model name -> fitted regressor predicting likes
all_models_shares = {}   # model name -> fitted regressor predicting shares
model_names = []         # display names of all models; populated by load_models()
scaler = None            # feature scaler fitted at training time (loaded from disk)
expected_columns = []    # feature names/order the scaler was fitted with
# Prediction Function for Both Likes and Shares
def predict_virality_all_models(generation_time, gpu_usage, file_size_kb,
                                width, height, style_accuracy_score,
                                is_hand_edited, ethical_concerns_flag,
                                day_of_week, month, hour, platform):
    """
    Run every loaded likes/shares model on one engineered sample.

    Builds the same feature set used at training time (raw inputs plus
    derived geometry, platform one-hots, efficiency ratios, and cyclical
    time encodings), aligns it to the scaler's expected columns, scales,
    and predicts with each model.

    Returns:
        (likes_by_model, shares_by_model, error) where the dicts map model
        name -> non-negative int prediction, and error is None on success
        or a message string if scaling failed (dicts are then empty).
    """
    global all_models_likes, all_models_shares, model_names, scaler, expected_columns

    # Raw inputs — the 'likes' target is deliberately absent.
    features = {
        'style_accuracy_score': style_accuracy_score,
        'generation_time': generation_time,
        'gpu_usage': gpu_usage,
        'file_size_kb': file_size_kb,
        'is_hand_edited': int(is_hand_edited),
        'ethical_concerns_flag': int(ethical_concerns_flag),
        'width': width,
        'height': height,
        'day_of_week': day_of_week,
        'month': month,
        'hour': hour,
    }

    # Derived geometry / calendar features.
    total_pixels = width * height
    features['aspect_ratio'] = width / height if height > 0 else 0
    features['total_pixels'] = total_pixels
    features['is_square'] = int(width == height)
    features['is_weekend'] = int(day_of_week >= 5)  # 5=Sat, 6=Sun

    # One-hot encode platform.
    features.update({
        f'platform_{name}': int(platform == name)
        for name in ('Twitter', 'TikTok', 'Reddit', 'Instagram')
    })

    # Technical efficiency ratios (+1 terms avoid division by zero).
    features['file_density'] = file_size_kb / (total_pixels / 1000 + 1)
    features['gpu_efficiency'] = generation_time / (gpu_usage + 1)

    # Cyclical sin/cos encodings so 12->1 (month) etc. stay adjacent.
    features['month_sin'] = np.sin(2 * np.pi * month / 12)
    features['month_cos'] = np.cos(2 * np.pi * month / 12)
    features['day_sin'] = np.sin(2 * np.pi * day_of_week / 7)
    features['day_cos'] = np.cos(2 * np.pi * day_of_week / 7)
    features['hour_sin'] = np.sin(2 * np.pi * hour / 24)
    features['hour_cos'] = np.cos(2 * np.pi * hour / 24)

    # Single-row frame aligned to the training-time column order; any
    # column we did not set is filled with 0.
    frame = pd.DataFrame([features]).reindex(columns=expected_columns, fill_value=0)

    try:
        scaled = scaler.transform(frame)
    except Exception as e:
        return {}, {}, f"Error during scaling: {e}"

    # Predict with every model available for each target; negative
    # regression outputs are clamped to zero.
    predictions_likes = {
        name: max(0, int(all_models_likes[name].predict(scaled)[0]))
        for name in model_names if name in all_models_likes
    }
    predictions_shares = {
        name: max(0, int(all_models_shares[name].predict(scaled)[0]))
        for name in model_names if name in all_models_shares
    }
    return predictions_likes, predictions_shares, None
def load_models():
    """
    Load the likes/shares regressors and the feature scaler from 'models/'.

    Populates the module globals all_models_likes, all_models_shares,
    model_names, scaler and expected_columns, and sets models_loaded to
    True on success or False if any .joblib file is missing.
    """
    # Load Models for Both Likes and Shares
    global all_models_likes, all_models_shares, model_names, scaler, expected_columns, models_loaded

    # Dictionaries to hold the loaded model objects
    all_models_likes = {}
    all_models_shares = {}
    model_names = [
        'Linear Regression', 'Ridge Regression', 'Lasso Regression',
        'Random Forest', 'Gradient Boosting'
    ]
    try:
        # Load all the regression models for both targets.
        # Files follow the pattern '<snake_case_name>_<target>.joblib'.
        for name in model_names:
            stem = name.lower().replace(' ', '_')
            all_models_likes[name] = joblib.load(
                os.path.join("models", f"{stem}_likes.joblib"))
            all_models_shares[name] = joblib.load(
                os.path.join("models", f"{stem}_shares.joblib"))
            print(f"Loaded: {name} (both likes and shares)")

        # Load the scaler
        scaler = joblib.load(os.path.join("models", "scaler.joblib"))
        print("Loaded: scaler.joblib")

        # The fitted scaler remembers the training-time feature names;
        # reuse them to align inference-time inputs.
        expected_columns = scaler.feature_names_in_
        print(f"Model expects {len(expected_columns)} features.")

        models_loaded = True
        print("\nAll models and scaler loaded successfully!")
    except FileNotFoundError as e:
        print(f"\nERROR: Could not find a model file: {e}")
        print("Please make sure all '.joblib' files are in the 'models/' directory.")
        models_loaded = False
def predict_virality_gradio(generation_time, gpu_usage, file_size_kb,
                            width, height, style_accuracy_score,
                            is_hand_edited, ethical_concerns_flag,
                            day_of_week, month, hour, platform):
    """
    Gradio adapter around predict_virality_all_models.

    Returns five values matching the UI outputs, in order: best-model likes,
    best-model shares, a markdown table of per-model likes, the same for
    shares, and a short markdown summary. On failure the three markdown
    slots carry the error text and the numbers are 0.
    """
    # Guard clause: refuse to predict until load_models() has succeeded.
    if not models_loaded:
        error_msg = "Models are not loaded. Please check the console for errors."
        return 0, 0, error_msg, error_msg, error_msg

    likes_preds, shares_preds, error = predict_virality_all_models(
        generation_time, gpu_usage, file_size_kb,
        width, height, style_accuracy_score,
        is_hand_edited, ethical_concerns_flag,
        day_of_week, month, hour, platform
    )
    if error:
        return 0, 0, error, error, error

    # Headline numbers come from the configured best models.
    best_likes = likes_preds.get(BEST_MODEL_LIKES, 0)
    best_shares = shares_preds.get(BEST_MODEL_SHARES, 0)

    def as_markdown_table(preds, value_column):
        # Per-model comparison table, highest prediction first.
        table = pd.DataFrame(list(preds.items()), columns=['Model', value_column])
        table = table.sort_values(value_column, ascending=False)
        return table.to_markdown(index=False)

    likes_table = as_markdown_table(likes_preds, 'Predicted Likes')
    shares_table = as_markdown_table(shares_preds, 'Predicted Shares')

    # Cross-model averages for a quick summary.
    summary = f"""
### Prediction Summary
**Average Predictions Across All Models:**
- Likes: {np.mean(list(likes_preds.values())):.0f}
- Shares: {np.mean(list(shares_preds.values())):.0f}
"""
    return best_likes, best_shares, likes_table, shares_table, summary
# Create Gradio interface
# Layout: one row with an input column (scale=2) and a results column
# (scale=3); the button is wired to predict_virality_gradio below.
# NOTE(review): several label strings ("π¨", "β€οΈ", "π", ...) look like
# mojibake of emoji from a lossy encoding round-trip — confirm the intended
# glyphs against the original file; left byte-identical here.
with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Virality Predictor") as demo:
    gr.Markdown("# π¨ AI Ghibli Image Virality Predictor")
    gr.Markdown("Predict both **Likes** and **Shares** for your AI-generated Ghibli-style images!")
    with gr.Row():
        # Input Column
        with gr.Column(scale=2):
            gr.Markdown("### π Input Features")
            with gr.Accordion("Image Properties", open=True):
                width = gr.Slider(minimum=256, maximum=2048, value=1024, step=64,
                                  label="Width (px)")
                height = gr.Slider(minimum=256, maximum=2048, value=1024, step=64,
                                   label="Height (px)")
                file_size_kb = gr.Slider(minimum=100, maximum=5000, value=1500, step=100,
                                         label="File Size (KB)")
                style_accuracy_score = gr.Slider(minimum=0, maximum=100, value=85, step=1,
                                                 label="Style Accuracy Score (%)")
            with gr.Accordion("Technical Details", open=True):
                generation_time = gr.Slider(minimum=1, maximum=30, value=8, step=0.5,
                                            label="Generation Time (seconds)")
                gpu_usage = gr.Slider(minimum=10, maximum=100, value=70, step=5,
                                      label="GPU Usage (%)")
                is_hand_edited = gr.Checkbox(label="Hand Edited?", value=False)
                ethical_concerns_flag = gr.Checkbox(label="Ethical Concerns?", value=False)
            with gr.Accordion("Posting Details", open=True):
                platform = gr.Radio(["Instagram", "Twitter", "TikTok", "Reddit"],
                                    label="Platform", value="Instagram")
                day_of_week = gr.Slider(minimum=0, maximum=6, value=4, step=1,
                                        label="Day of Week (0=Mon, 6=Sun)")
                month = gr.Slider(minimum=1, maximum=12, value=7, step=1,
                                  label="Month (1-12)")
                hour = gr.Slider(minimum=0, maximum=23, value=18, step=1,
                                 label="Hour of Day (0-23)")
            predict_btn = gr.Button("π Predict Virality", variant="primary", size="lg")
        # Output Column
        with gr.Column(scale=3):
            gr.Markdown("### π Prediction Results")
            # Main predictions: headline numbers from the configured best models.
            with gr.Row():
                best_likes_output = gr.Number(
                    label=f"β€οΈ Predicted Likes ({BEST_MODEL_LIKES})",
                    interactive=False
                )
                best_shares_output = gr.Number(
                    label=f"π Predicted Shares ({BEST_MODEL_SHARES})",
                    interactive=False
                )
            # Summary
            summary_output = gr.Markdown(label="Summary")
            # Detailed predictions: per-model tables, collapsed by default.
            with gr.Row():
                with gr.Accordion("All Models - Likes", open=False):
                    likes_table_output = gr.Markdown(label="Likes Predictions")
                with gr.Accordion("All Models - Shares", open=False):
                    shares_table_output = gr.Markdown(label="Shares Predictions")
    # Connect the button to the function.
    # The inputs list order must match predict_virality_gradio's parameter
    # order, and outputs must match its 5-tuple return order.
    predict_btn.click(
        fn=predict_virality_gradio,
        inputs=[
            generation_time, gpu_usage, file_size_kb,
            width, height, style_accuracy_score,
            is_hand_edited, ethical_concerns_flag,
            day_of_week, month, hour, platform
        ],
        outputs=[
            best_likes_output,
            best_shares_output,
            likes_table_output,
            shares_table_output,
            summary_output
        ]
    )
# Launch the app
if __name__ == "__main__":
    # Load models once at startup; the UI refuses predictions otherwise.
    load_models()
    if not models_loaded:
        print("\nCannot launch Gradio app because models failed to load.")
    else:
        # Pass share=True / debug=True here when needed.
        demo.launch()