# Gradio Demo App for Predicting Both Likes and Shares
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import os

# Best models (update based on your results)
BEST_MODEL_LIKES = 'Random Forest'
BEST_MODEL_SHARES = 'Random Forest'

# Set to True by load_models() once every model and the scaler are loaded.
models_loaded = False

# Global variables for models and scaler (populated by load_models()).
all_models_likes = {}
all_models_shares = {}
model_names = []
scaler = None
expected_columns = []


# Prediction Function for Both Likes and Shares
def predict_virality_all_models(generation_time, gpu_usage, file_size_kb,
                                width, height, style_accuracy_score,
                                is_hand_edited, ethical_concerns_flag,
                                day_of_week, month, hour, platform):
    """Predict both likes and shares using all loaded models.

    Builds a single-row feature frame from the raw inputs, adds the
    engineered features, aligns the columns to ``expected_columns``, scales
    the row with the module-level ``scaler`` and runs every model listed in
    ``model_names``.

    Returns:
        A ``(predictions_likes, predictions_shares, error)`` tuple. The two
        dicts map model name to a non-negative integer prediction. ``error``
        is ``None`` on success; if scaling fails the dicts are empty and
        ``error`` is a message string.
    """
    # Create feature dictionary (WITHOUT likes)
    sample_data = {
        'style_accuracy_score': style_accuracy_score,
        'generation_time': generation_time,
        'gpu_usage': gpu_usage,
        'file_size_kb': file_size_kb,
        'is_hand_edited': int(is_hand_edited),
        'ethical_concerns_flag': int(ethical_concerns_flag),
        'width': width,
        'height': height,
        'day_of_week': day_of_week,
        'month': month,
        'hour': hour,
    }

    # Perform feature engineering
    # Guard against division by zero for a zero/negative height.
    sample_data['aspect_ratio'] = width / height if height > 0 else 0
    sample_data['total_pixels'] = width * height
    sample_data['is_square'] = int(width == height)
    # The UI labels days as 0=Mon .. 6=Sun, so 5 and 6 are the weekend.
    sample_data['is_weekend'] = int(day_of_week >= 5)

    # One-hot encode platform
    for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:
        sample_data[f'platform_{p}'] = 1 if platform == p else 0

    # Technical features (the +1 terms keep the denominators non-zero)
    sample_data['file_density'] = file_size_kb / (sample_data['total_pixels'] / 1000 + 1)
    sample_data['gpu_efficiency'] = generation_time / (gpu_usage + 1)

    # Temporal cyclical features
    sample_data['month_sin'] = np.sin(2 * np.pi * month / 12)
    sample_data['month_cos'] = np.cos(2 * np.pi * month / 12)
    sample_data['day_sin'] = np.sin(2 * np.pi * day_of_week / 7)
    sample_data['day_cos'] = np.cos(2 * np.pi * day_of_week / 7)
    sample_data['hour_sin'] = np.sin(2 * np.pi * hour / 24)
    sample_data['hour_cos'] = np.cos(2 * np.pi * hour / 24)

    # Create DataFrame and align columns: columns the scaler expects but we
    # did not compute are filled with 0, extra columns are dropped.
    sample_df = pd.DataFrame([sample_data])
    sample_df = sample_df.reindex(columns=expected_columns, fill_value=0)

    # Scale features. Any failure is reported to the caller as a message
    # rather than raised, so the UI can display it.
    try:
        sample_scaled = scaler.transform(sample_df)
    except Exception as e:
        return {}, {}, f"Error during scaling: {e}"

    # Predict with all models
    predictions_likes = {}
    predictions_shares = {}

    for name in model_names:
        # Predict likes (truncated to int and clamped at zero)
        if name in all_models_likes:
            pred_likes = all_models_likes[name].predict(sample_scaled)[0]
            predictions_likes[name] = max(0, int(pred_likes))

        # Predict shares (truncated to int and clamped at zero)
        if name in all_models_shares:
            pred_shares = all_models_shares[name].predict(sample_scaled)[0]
            predictions_shares[name] = max(0, int(pred_shares))

    return predictions_likes, predictions_shares, None


def load_models():
    """Load the likes/shares models and the scaler from ``models/``.

    Populates the module-level ``all_models_likes``, ``all_models_shares``,
    ``model_names``, ``scaler`` and ``expected_columns``. Sets
    ``models_loaded`` to True on success, or to False (after printing a
    hint) if any file is missing. Other exceptions propagate.
    """
    global all_models_likes, all_models_shares, model_names, scaler, expected_columns, models_loaded

    # Dictionaries to hold the loaded model objects
    all_models_likes = {}
    all_models_shares = {}

    model_names = [
        'Linear Regression',
        'Ridge Regression',
        'Lasso Regression',
        'Random Forest',
        'Gradient Boosting',
    ]

    try:
        # Load all the regression models for both targets
        for name in model_names:
            # File names are the lower-cased model name with underscores,
            # e.g. "random_forest_likes.joblib".
            slug = name.lower().replace(' ', '_')

            # Load likes model
            filename_likes = os.path.join("models", f"{slug}_likes.joblib")
            all_models_likes[name] = joblib.load(filename_likes)

            # Load shares model
            filename_shares = os.path.join("models", f"{slug}_shares.joblib")
            all_models_shares[name] = joblib.load(filename_shares)

            print(f"Loaded: {name} (both likes and shares)")

        # Load the scaler
        scaler = joblib.load(os.path.join("models", "scaler.joblib"))
        print("Loaded: scaler.joblib")

        # Get the feature names the scaler recorded
        expected_columns = scaler.feature_names_in_
        print(f"Model expects {len(expected_columns)} features.")

        models_loaded = True
        print("\n✅ All models and scaler loaded successfully!")

    except FileNotFoundError as e:
        print(f"\n❌ ERROR: Could not find a model file: {e}")
        print("Please make sure all '.joblib' files are in the 'models/' directory.")
        models_loaded = False


def predict_virality_gradio(generation_time, gpu_usage, file_size_kb,
                            width, height, style_accuracy_score,
                            is_hand_edited, ethical_concerns_flag,
                            day_of_week, month, hour, platform):
    """Gradio wrapper for the prediction function.

    Returns a 5-tuple matching the click handler's outputs:
    ``(best_likes, best_shares, likes_table, shares_table, summary)``.
    When models are not loaded or prediction fails, the two numbers are 0
    and the three text outputs all carry the error message.
    """
    if not models_loaded:
        error_msg = "Models are not loaded. Please check the console for errors."
        return 0, 0, error_msg, error_msg, error_msg

    # Get predictions
    likes_preds, shares_preds, error = predict_virality_all_models(
        generation_time, gpu_usage, file_size_kb, width, height,
        style_accuracy_score, is_hand_edited, ethical_concerns_flag,
        day_of_week, month, hour, platform
    )

    if error:
        return 0, 0, error, error, error

    # Get best model predictions
    best_likes = likes_preds.get(BEST_MODEL_LIKES, 0)
    best_shares = shares_preds.get(BEST_MODEL_SHARES, 0)

    # Create comparison tables, highest prediction first
    likes_df = pd.DataFrame(list(likes_preds.items()), columns=['Model', 'Predicted Likes'])
    likes_df = likes_df.sort_values('Predicted Likes', ascending=False)
    likes_table = likes_df.to_markdown(index=False)

    shares_df = pd.DataFrame(list(shares_preds.items()), columns=['Model', 'Predicted Shares'])
    shares_df = shares_df.sort_values('Predicted Shares', ascending=False)
    shares_table = shares_df.to_markdown(index=False)

    # Create summary statistics. The Markdown is assembled from explicit
    # lines so that source-code indentation never leaks into the text
    # (indented Markdown can render as a code block).
    avg_likes = np.mean(list(likes_preds.values()))
    avg_shares = np.mean(list(shares_preds.values()))
    summary = (
        "### Prediction Summary\n\n"
        "**Average Predictions Across All Models:**\n"
        f"- Likes: {avg_likes:.0f}\n"
        f"- Shares: {avg_shares:.0f}\n"
    )

    return best_likes, best_shares, likes_table, shares_table, summary


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Virality Predictor") as demo:
    gr.Markdown("# 🎨 AI Ghibli Image Virality Predictor")
    gr.Markdown("Predict both **Likes** and **Shares** for your AI-generated Ghibli-style images!")

    with gr.Row():
        # Input Column
        with gr.Column(scale=2):
            gr.Markdown("### 📝 Input Features")

            with gr.Accordion("Image Properties", open=True):
                width = gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label="Width (px)")
                height = gr.Slider(minimum=256, maximum=2048, value=1024, step=64, label="Height (px)")
                file_size_kb = gr.Slider(minimum=100, maximum=5000, value=1500, step=100, label="File Size (KB)")
                style_accuracy_score = gr.Slider(minimum=0, maximum=100, value=85, step=1, label="Style Accuracy Score (%)")

            with gr.Accordion("Technical Details", open=True):
                generation_time = gr.Slider(minimum=1, maximum=30, value=8, step=0.5, label="Generation Time (seconds)")
                gpu_usage = gr.Slider(minimum=10, maximum=100, value=70, step=5, label="GPU Usage (%)")
                is_hand_edited = gr.Checkbox(label="Hand Edited?", value=False)
                ethical_concerns_flag = gr.Checkbox(label="Ethical Concerns?", value=False)

            with gr.Accordion("Posting Details", open=True):
                platform = gr.Radio(["Instagram", "Twitter", "TikTok", "Reddit"], label="Platform", value="Instagram")
                day_of_week = gr.Slider(minimum=0, maximum=6, value=4, step=1, label="Day of Week (0=Mon, 6=Sun)")
                month = gr.Slider(minimum=1, maximum=12, value=7, step=1, label="Month (1-12)")
                hour = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Hour of Day (0-23)")

            predict_btn = gr.Button("🚀 Predict Virality", variant="primary", size="lg")

        # Output Column
        with gr.Column(scale=3):
            gr.Markdown("### 📊 Prediction Results")

            # Main predictions
            with gr.Row():
                best_likes_output = gr.Number(
                    label=f"❤️ Predicted Likes ({BEST_MODEL_LIKES})",
                    interactive=False
                )
                best_shares_output = gr.Number(
                    label=f"🔄 Predicted Shares ({BEST_MODEL_SHARES})",
                    interactive=False
                )

            # Summary
            summary_output = gr.Markdown(label="Summary")

            # Detailed predictions
            with gr.Row():
                with gr.Accordion("All Models - Likes", open=False):
                    likes_table_output = gr.Markdown(label="Likes Predictions")

                with gr.Accordion("All Models - Shares", open=False):
                    shares_table_output = gr.Markdown(label="Shares Predictions")

    # Connect the button to the function. The input order must match the
    # parameter order of predict_virality_gradio, and the output order must
    # match the order of its returned tuple.
    predict_btn.click(
        fn=predict_virality_gradio,
        inputs=[
            generation_time, gpu_usage, file_size_kb, width, height,
            style_accuracy_score, is_hand_edited, ethical_concerns_flag,
            day_of_week, month, hour, platform
        ],
        outputs=[
            best_likes_output,
            best_shares_output,
            likes_table_output,
            shares_table_output,
            summary_output
        ]
    )

# Launch the app
if __name__ == "__main__":
    load_models()

    if not models_loaded:
        print("\nCannot launch Gradio app because models failed to load.")
    else:
        demo.launch(
            # share=True,
            # debug=True
        )