|
|
|
|
|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import joblib |
|
import os |
|
|
|
|
|
# Which model's prediction is surfaced as the headline number in the UI
# (chosen offline as the best performer for each target).
BEST_MODEL_LIKES = 'Random Forest'

BEST_MODEL_SHARES = 'Random Forest'

# Flipped to True by load_models() once every artifact deserializes cleanly.
models_loaded = False


# Registries populated by load_models(): display name -> fitted estimator.
all_models_likes = {}

all_models_shares = {}

# Display names of the estimators expected on disk (set in load_models()).
model_names = []

# Fitted feature scaler; None until load_models() succeeds.
scaler = None

# Feature names, in training order, that the scaler/models expect.
expected_columns = []
|
|
|
|
|
def predict_virality_all_models(generation_time, gpu_usage, file_size_kb,
                                width, height, style_accuracy_score,
                                is_hand_edited, ethical_concerns_flag,
                                day_of_week, month, hour, platform):
    """Run every loaded regressor on a single engineered feature row.

    Returns a tuple ``(likes_by_model, shares_by_model, error)`` where the
    first two map model name -> non-negative int prediction, and ``error``
    is ``None`` on success or a message string if scaling failed.
    """
    global all_models_likes, all_models_shares, model_names, scaler, expected_columns

    pixel_count = width * height

    # Raw inputs plus the engineered features the models were trained on.
    features = {
        'style_accuracy_score': style_accuracy_score,
        'generation_time': generation_time,
        'gpu_usage': gpu_usage,
        'file_size_kb': file_size_kb,
        'is_hand_edited': int(is_hand_edited),
        'ethical_concerns_flag': int(ethical_concerns_flag),
        'width': width,
        'height': height,
        'day_of_week': day_of_week,
        'month': month,
        'hour': hour,
        # Geometry-derived features (guard the ratio against height == 0).
        'aspect_ratio': width / height if height > 0 else 0,
        'total_pixels': pixel_count,
        'is_square': int(width == height),
        'is_weekend': int(day_of_week >= 5),
        # Size/compute ratios; the "+ 1" terms avoid division by zero.
        'file_density': file_size_kb / (pixel_count / 1000 + 1),
        'gpu_efficiency': generation_time / (gpu_usage + 1),
        # Cyclical encodings so hour 23 is "near" hour 0, Dec near Jan, etc.
        'month_sin': np.sin(2 * np.pi * month / 12),
        'month_cos': np.cos(2 * np.pi * month / 12),
        'day_sin': np.sin(2 * np.pi * day_of_week / 7),
        'day_cos': np.cos(2 * np.pi * day_of_week / 7),
        'hour_sin': np.sin(2 * np.pi * hour / 24),
        'hour_cos': np.cos(2 * np.pi * hour / 24),
    }

    # One-hot platform columns, mirroring the training-time dummy encoding.
    for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:
        features[f'platform_{p}'] = 1 if platform == p else 0

    # Align to the training column order; any column we did not produce
    # is filled with 0 so the scaler sees a complete row.
    row = pd.DataFrame([features]).reindex(columns=expected_columns, fill_value=0)

    try:
        row_scaled = scaler.transform(row)
    except Exception as e:
        return {}, {}, f"Error during scaling: {e}"

    # Clamp to zero: engagement counts cannot be negative.
    predictions_likes = {
        name: max(0, int(all_models_likes[name].predict(row_scaled)[0]))
        for name in model_names
        if name in all_models_likes
    }
    predictions_shares = {
        name: max(0, int(all_models_shares[name].predict(row_scaled)[0]))
        for name in model_names
        if name in all_models_shares
    }

    return predictions_likes, predictions_shares, None
|
|
|
def load_models():
    """Load all per-target models plus the shared scaler from models/.

    Populates the module-level registries ``all_models_likes`` and
    ``all_models_shares`` (name -> estimator), the ``scaler``, and
    ``expected_columns`` (the scaler's training-time feature order),
    then sets the ``models_loaded`` flag.

    Returns:
        bool: True when every artifact loaded, False otherwise. (The
        result is also mirrored in the ``models_loaded`` global, which
        existing callers read.)
    """
    global all_models_likes, all_models_shares, model_names, scaler, expected_columns, models_loaded

    all_models_likes = {}
    all_models_shares = {}
    model_names = [
        'Linear Regression', 'Ridge Regression', 'Lasso Regression',
        'Random Forest', 'Gradient Boosting'
    ]

    try:
        for name in model_names:
            # On-disk naming convention: e.g. "random_forest_likes.joblib".
            stem = name.lower().replace(' ', '_')

            filename_likes = os.path.join("models", f"{stem}_likes.joblib")
            all_models_likes[name] = joblib.load(filename_likes)

            filename_shares = os.path.join("models", f"{stem}_shares.joblib")
            all_models_shares[name] = joblib.load(filename_shares)

            print(f"Loaded: {name} (both likes and shares)")

        scaler = joblib.load(os.path.join("models", "scaler.joblib"))
        print("Loaded: scaler.joblib")

        # The scaler remembers the training column order; prediction rows
        # must be reindexed to exactly these columns.
        expected_columns = scaler.feature_names_in_
        print(f"Model expects {len(expected_columns)} features.")

        models_loaded = True
        print("\n✅ All models and scaler loaded successfully!")

    except FileNotFoundError as e:
        print(f"\n❌ ERROR: Could not find a model file: {e}")
        print("Please make sure all '.joblib' files are in the 'models/' directory.")
        models_loaded = False
    except Exception as e:
        # BUGFIX: previously only FileNotFoundError was handled, so a corrupt
        # or version-incompatible artifact crashed the app at startup.
        print(f"\n❌ ERROR: Failed to load models: {e}")
        models_loaded = False

    return models_loaded
|
|
|
|
|
def predict_virality_gradio(generation_time, gpu_usage, file_size_kb,
                            width, height, style_accuracy_score,
                            is_hand_edited, ethical_concerns_flag,
                            day_of_week, month, hour, platform):
    """Adapter between the Gradio UI and the model ensemble.

    Produces five outputs in UI order: best-model likes, best-model
    shares, a markdown table of per-model likes, the same for shares,
    and a markdown summary of the ensemble averages.
    """
    if not models_loaded:
        error_msg = "Models are not loaded. Please check the console for errors."
        return 0, 0, error_msg, error_msg, error_msg

    likes_preds, shares_preds, error = predict_virality_all_models(
        generation_time, gpu_usage, file_size_kb,
        width, height, style_accuracy_score,
        is_hand_edited, ethical_concerns_flag,
        day_of_week, month, hour, platform
    )
    if error:
        return 0, 0, error, error, error

    # Headline numbers come from the best-performing model per target.
    best_likes = likes_preds.get(BEST_MODEL_LIKES, 0)
    best_shares = shares_preds.get(BEST_MODEL_SHARES, 0)

    def _as_markdown_table(preds, value_column):
        # One row per model, sorted highest prediction first.
        frame = pd.DataFrame(list(preds.items()), columns=['Model', value_column])
        frame = frame.sort_values(value_column, ascending=False)
        return frame.to_markdown(index=False)

    likes_table = _as_markdown_table(likes_preds, 'Predicted Likes')
    shares_table = _as_markdown_table(shares_preds, 'Predicted Shares')

    summary = f"""
### Prediction Summary

**Average Predictions Across All Models:**
- Likes: {np.mean(list(likes_preds.values())):.0f}
- Shares: {np.mean(list(shares_preds.values())):.0f}
"""

    return best_likes, best_shares, likes_table, shares_table, summary
|
|
|
|
|
# --- Gradio UI definition ---
# Two-column layout: inputs on the left, results on the right. A single
# button routes the form values through predict_virality_gradio.
with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Virality Predictor") as demo:
    gr.Markdown("# π¨ AI Ghibli Image Virality Predictor")
    gr.Markdown("Predict both **Likes** and **Shares** for your AI-generated Ghibli-style images!")

    with gr.Row():

        # Input side: three collapsible groups of widgets.
        with gr.Column(scale=2):
            gr.Markdown("### π Input Features")

            with gr.Accordion("Image Properties", open=True):
                width = gr.Slider(minimum=256, maximum=2048, value=1024, step=64,
                                  label="Width (px)")
                height = gr.Slider(minimum=256, maximum=2048, value=1024, step=64,
                                   label="Height (px)")
                file_size_kb = gr.Slider(minimum=100, maximum=5000, value=1500, step=100,
                                         label="File Size (KB)")
                style_accuracy_score = gr.Slider(minimum=0, maximum=100, value=85, step=1,
                                                 label="Style Accuracy Score (%)")

            with gr.Accordion("Technical Details", open=True):
                generation_time = gr.Slider(minimum=1, maximum=30, value=8, step=0.5,
                                            label="Generation Time (seconds)")
                gpu_usage = gr.Slider(minimum=10, maximum=100, value=70, step=5,
                                      label="GPU Usage (%)")
                is_hand_edited = gr.Checkbox(label="Hand Edited?", value=False)
                ethical_concerns_flag = gr.Checkbox(label="Ethical Concerns?", value=False)

            with gr.Accordion("Posting Details", open=True):
                platform = gr.Radio(["Instagram", "Twitter", "TikTok", "Reddit"],
                                    label="Platform", value="Instagram")
                day_of_week = gr.Slider(minimum=0, maximum=6, value=4, step=1,
                                        label="Day of Week (0=Mon, 6=Sun)")
                month = gr.Slider(minimum=1, maximum=12, value=7, step=1,
                                  label="Month (1-12)")
                hour = gr.Slider(minimum=0, maximum=23, value=18, step=1,
                                 label="Hour of Day (0-23)")

            predict_btn = gr.Button("π Predict Virality", variant="primary", size="lg")

        # Output side: headline numbers, summary, and per-model tables.
        with gr.Column(scale=3):
            gr.Markdown("### π Prediction Results")

            with gr.Row():
                # Headline predictions from the configured best models.
                best_likes_output = gr.Number(
                    label=f"β€οΈ Predicted Likes ({BEST_MODEL_LIKES})",
                    interactive=False
                )
                best_shares_output = gr.Number(
                    label=f"π Predicted Shares ({BEST_MODEL_SHARES})",
                    interactive=False
                )

            summary_output = gr.Markdown(label="Summary")

            with gr.Row():
                # Per-model breakdowns, collapsed by default.
                with gr.Accordion("All Models - Likes", open=False):
                    likes_table_output = gr.Markdown(label="Likes Predictions")

                with gr.Accordion("All Models - Shares", open=False):
                    shares_table_output = gr.Markdown(label="Shares Predictions")

    # Wire the button: `inputs` order must match predict_virality_gradio's
    # signature; `outputs` order matches its 5-tuple return value.
    predict_btn.click(
        fn=predict_virality_gradio,
        inputs=[
            generation_time, gpu_usage, file_size_kb,
            width, height, style_accuracy_score,
            is_hand_edited, ethical_concerns_flag,
            day_of_week, month, hour, platform
        ],
        outputs=[
            best_likes_output,
            best_shares_output,
            likes_table_output,
            shares_table_output,
            summary_output
        ]
    )
|
|
|
|
|
if __name__ == "__main__":
    # Load every artifact first; only start the web app when ready.
    load_models()
    if models_loaded:
        demo.launch()
    else:
        print("\nCannot launch Gradio app because models failed to load.")