# Author: ssyok
# Commit 713c6ae: change the logic to predict both likes and shares. Previously,
# likes was used as an input feature, which was incorrect.
# Gradio Demo App for Predicting Both Likes and Shares
import gradio as gr
import pandas as pd
import numpy as np
import joblib
import os
# Model whose prediction is shown as the headline number for each target
# (update based on your results).
BEST_MODEL_LIKES = 'Random Forest'
BEST_MODEL_SHARES = 'Random Forest'

# Runtime state; everything below is (re)assigned by load_models().
models_loaded = False   # True only after every model and the scaler loaded
all_models_likes = {}   # model name -> loaded regressor predicting likes
all_models_shares = {}  # model name -> loaded regressor predicting shares
model_names = []        # model names queried at prediction time
scaler = None           # loaded feature scaler applied before every predict()
expected_columns = []   # feature column order the scaler expects
# Prediction Function for Both Likes and Shares
def _build_feature_row(generation_time, gpu_usage, file_size_kb,
                       width, height, style_accuracy_score,
                       is_hand_edited, ethical_concerns_flag,
                       day_of_week, month, hour, platform):
    """Build the raw and engineered features for a single post.

    Likes are deliberately NOT a feature: they are a prediction target.

    Returns:
        dict mapping feature name -> value. Order does not matter; the caller
        realigns the columns to the scaler's training order.
    """
    features = {
        'style_accuracy_score': style_accuracy_score,
        'generation_time': generation_time,
        'gpu_usage': gpu_usage,
        'file_size_kb': file_size_kb,
        'is_hand_edited': int(is_hand_edited),
        'ethical_concerns_flag': int(ethical_concerns_flag),
        'width': width,
        'height': height,
        'day_of_week': day_of_week,
        'month': month,
        'hour': hour,
    }

    # Image geometry.
    features['aspect_ratio'] = width / height if height > 0 else 0
    features['total_pixels'] = width * height
    features['is_square'] = int(width == height)
    # day_of_week follows 0=Mon ... 6=Sun, so 5 and 6 are the weekend.
    features['is_weekend'] = int(day_of_week >= 5)

    # One-hot encode the platform; an unrecognised platform yields all zeros.
    for p in ['Twitter', 'TikTok', 'Reddit', 'Instagram']:
        features[f'platform_{p}'] = 1 if platform == p else 0

    # Technical ratios; the "+ 1" keeps both denominators strictly positive.
    features['file_density'] = file_size_kb / (features['total_pixels'] / 1000 + 1)
    features['gpu_efficiency'] = generation_time / (gpu_usage + 1)

    # Cyclical encodings put periodic values on the unit circle, so e.g.
    # December/January and 23:00/00:00 end up close together.
    features['month_sin'] = np.sin(2 * np.pi * month / 12)
    features['month_cos'] = np.cos(2 * np.pi * month / 12)
    features['day_sin'] = np.sin(2 * np.pi * day_of_week / 7)
    features['day_cos'] = np.cos(2 * np.pi * day_of_week / 7)
    features['hour_sin'] = np.sin(2 * np.pi * hour / 24)
    features['hour_cos'] = np.cos(2 * np.pi * hour / 24)
    return features


def _to_count(raw_prediction):
    """Turn a raw regression output into a non-negative integer count.

    Rounds to the nearest integer (instead of truncating toward zero) and
    clips negatives to 0.
    """
    return max(0, int(round(float(raw_prediction))))


def predict_virality_all_models(generation_time, gpu_usage, file_size_kb,
                                width, height, style_accuracy_score,
                                is_hand_edited, ethical_concerns_flag,
                                day_of_week, month, hour, platform):
    """Predict both likes and shares for one post with every loaded model.

    Args:
        generation_time: Generation time in seconds.
        gpu_usage: GPU usage in percent.
        file_size_kb: File size in KB.
        width, height: Image dimensions in pixels.
        style_accuracy_score: Style accuracy score (0-100).
        is_hand_edited, ethical_concerns_flag: Booleans, encoded as 0/1.
        day_of_week: 0=Monday ... 6=Sunday.
        month: 1-12.
        hour: 0-23.
        platform: One of 'Instagram', 'Twitter', 'TikTok', 'Reddit'.

    Returns:
        (likes, shares, error): two dicts mapping model name -> predicted
        non-negative integer count, and ``None`` on success. If the scaler is
        missing, scaling fails or any model fails to predict, both dicts are
        empty and ``error`` is a human-readable message.
    """
    if scaler is None:
        return {}, {}, "Error during scaling: scaler is not loaded (call load_models() first)."

    sample_df = pd.DataFrame([_build_feature_row(
        generation_time, gpu_usage, file_size_kb,
        width, height, style_accuracy_score,
        is_hand_edited, ethical_concerns_flag,
        day_of_week, month, hour, platform,
    )])
    # Match the training column order exactly; any expected column not built
    # above is filled with 0 and any extra column is dropped.
    sample_df = sample_df.reindex(columns=expected_columns, fill_value=0)

    try:
        sample_scaled = scaler.transform(sample_df)
    except Exception as e:
        return {}, {}, f"Error during scaling: {e}"

    predictions_likes = {}
    predictions_shares = {}
    for name in model_names:
        try:
            if name in all_models_likes:
                predictions_likes[name] = _to_count(all_models_likes[name].predict(sample_scaled)[0])
            if name in all_models_shares:
                predictions_shares[name] = _to_count(all_models_shares[name].predict(sample_scaled)[0])
        except Exception as e:
            # Report failures like scaling errors instead of raising into the UI.
            return {}, {}, f"Error during prediction with {name}: {e}"

    return predictions_likes, predictions_shares, None
def load_models(models_dir="models"):
    """Load every likes/shares regressor plus the feature scaler from disk.

    For each model name below this expects ``<name>_likes.joblib`` and
    ``<name>_shares.joblib`` (name lower-cased, spaces -> underscores), plus
    ``scaler.joblib``, inside *models_dir*.

    The module-level globals (``all_models_likes``, ``all_models_shares``,
    ``model_names``, ``scaler``, ``expected_columns``) are replaced only after
    everything has loaded. A failure part-way through therefore never leaves a
    half-populated set of models behind; ``models_loaded`` is set to False.

    Args:
        models_dir: Directory holding the ``.joblib`` files. Relative paths
            resolve against the current working directory. Defaults to
            ``"models"``.

    Returns:
        bool: The new value of ``models_loaded``.
    """
    global all_models_likes, all_models_shares, model_names, scaler, expected_columns, models_loaded

    names = [
        'Linear Regression', 'Ridge Regression', 'Lasso Regression',
        'Random Forest', 'Gradient Boosting',
    ]

    def _model_path(name, target):
        # e.g. ('Random Forest', 'likes') -> <models_dir>/random_forest_likes.joblib
        stem = name.lower().replace(' ', '_')
        return os.path.join(models_dir, f"{stem}_{target}.joblib")

    loaded_likes = {}
    loaded_shares = {}
    try:
        for name in names:
            loaded_likes[name] = joblib.load(_model_path(name, "likes"))
            loaded_shares[name] = joblib.load(_model_path(name, "shares"))
            print(f"Loaded: {name} (both likes and shares)")
        loaded_scaler = joblib.load(os.path.join(models_dir, "scaler.joblib"))
        print("Loaded: scaler.joblib")
    except FileNotFoundError as e:
        print(f"\n❌ ERROR: Could not find a model file: {e}")
        print(f"Please make sure all '.joblib' files are in the '{models_dir}/' directory.")
        models_loaded = False
        return models_loaded
    except Exception as e:
        # Startup boundary: report any other load failure (e.g. an unreadable
        # file or a library mismatch) so the launcher can refuse to start.
        print(f"\n❌ ERROR: Failed to load model files: {type(e).__name__}: {e}")
        models_loaded = False
        return models_loaded

    # The column names are required to align prediction inputs with training.
    columns = getattr(loaded_scaler, "feature_names_in_", None)
    if columns is None:
        print("\n❌ ERROR: scaler.joblib has no 'feature_names_in_' "
              "(was it fitted on a DataFrame with named columns?).")
        models_loaded = False
        return models_loaded

    # Everything loaded: publish the new state in one go.
    all_models_likes = loaded_likes
    all_models_shares = loaded_shares
    model_names = names
    scaler = loaded_scaler
    expected_columns = columns
    print(f"Model expects {len(expected_columns)} features.")
    models_loaded = True
    print("\n✅ All models and scaler loaded successfully!")
    return models_loaded
def _markdown_table(df):
    """Render *df* (without its index) as a markdown pipe table.

    Uses ``DataFrame.to_markdown`` when its optional ``tabulate`` dependency
    is installed. Otherwise it falls back to a minimal hand-built table
    instead of failing with ImportError.
    """
    try:
        return df.to_markdown(index=False)
    except ImportError:
        header = "| " + " | ".join(str(col) for col in df.columns) + " |"
        divider = "|" + "|".join(" --- " for _ in df.columns) + "|"
        rows = [
            "| " + " | ".join(str(value) for value in row) + " |"
            for row in df.itertuples(index=False)
        ]
        return "\n".join([header, divider, *rows])


def _format_mean(values):
    """Return the mean of *values* as a whole number, or 'n/a' if there are none."""
    return f"{np.mean(values):.0f}" if values else "n/a"


def predict_virality_gradio(generation_time, gpu_usage, file_size_kb,
                            width, height, style_accuracy_score,
                            is_hand_edited, ethical_concerns_flag,
                            day_of_week, month, hour, platform):
    """Gradio callback: run every model and format the results for the UI.

    Takes the same arguments as ``predict_virality_all_models``.

    Returns:
        A 5-tuple in the order of the button's outputs: (best-model likes,
        best-model shares, likes comparison table, shares comparison table,
        summary markdown). On failure both numbers are 0 and the three text
        outputs all carry the error message.
    """
    if not models_loaded:
        error_msg = "Models are not loaded. Please check the console for errors."
        return 0, 0, error_msg, error_msg, error_msg

    likes_preds, shares_preds, error = predict_virality_all_models(
        generation_time, gpu_usage, file_size_kb,
        width, height, style_accuracy_score,
        is_hand_edited, ethical_concerns_flag,
        day_of_week, month, hour, platform
    )
    if error:
        return 0, 0, error, error, error

    # Headline numbers come from the configured best model (0 if it is absent).
    best_likes = likes_preds.get(BEST_MODEL_LIKES, 0)
    best_shares = shares_preds.get(BEST_MODEL_SHARES, 0)

    # Per-model comparison tables, highest prediction first.
    likes_df = pd.DataFrame(list(likes_preds.items()), columns=['Model', 'Predicted Likes'])
    likes_table = _markdown_table(likes_df.sort_values('Predicted Likes', ascending=False))
    shares_df = pd.DataFrame(list(shares_preds.items()), columns=['Model', 'Predicted Shares'])
    shares_table = _markdown_table(shares_df.sort_values('Predicted Shares', ascending=False))

    # Built line by line so source indentation can never leak into the
    # markdown (4+ leading spaces would render as a code block).
    summary = (
        "\n### Prediction Summary\n"
        "**Average Predictions Across All Models:**\n"
        f"- Likes: {_format_mean(list(likes_preds.values()))}\n"
        f"- Shares: {_format_mean(list(shares_preds.values()))}\n"
    )
    return best_likes, best_shares, likes_table, shares_table, summary
# Create Gradio interface. Components are laid out in the order they are
# created inside the nested `with` contexts, so this block is order-sensitive.
with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Virality Predictor") as demo:
    gr.Markdown("# 🎨 AI Ghibli Image Virality Predictor")
    gr.Markdown("Predict both **Likes** and **Shares** for your AI-generated Ghibli-style images!")

    with gr.Row():
        # Input column: one control per user-settable model input.
        with gr.Column(scale=2):
            gr.Markdown("### 📝 Input Features")

            with gr.Accordion("Image Properties", open=True):
                width = gr.Slider(minimum=256, maximum=2048, value=1024, step=64,
                                  label="Width (px)")
                height = gr.Slider(minimum=256, maximum=2048, value=1024, step=64,
                                   label="Height (px)")
                file_size_kb = gr.Slider(minimum=100, maximum=5000, value=1500, step=100,
                                         label="File Size (KB)")
                style_accuracy_score = gr.Slider(minimum=0, maximum=100, value=85, step=1,
                                                 label="Style Accuracy Score (%)")

            with gr.Accordion("Technical Details", open=True):
                generation_time = gr.Slider(minimum=1, maximum=30, value=8, step=0.5,
                                            label="Generation Time (seconds)")
                gpu_usage = gr.Slider(minimum=10, maximum=100, value=70, step=5,
                                      label="GPU Usage (%)")
                is_hand_edited = gr.Checkbox(label="Hand Edited?", value=False)
                ethical_concerns_flag = gr.Checkbox(label="Ethical Concerns?", value=False)

            with gr.Accordion("Posting Details", open=True):
                platform = gr.Radio(["Instagram", "Twitter", "TikTok", "Reddit"],
                                    label="Platform", value="Instagram")
                day_of_week = gr.Slider(minimum=0, maximum=6, value=4, step=1,
                                        label="Day of Week (0=Mon, 6=Sun)")
                month = gr.Slider(minimum=1, maximum=12, value=7, step=1,
                                  label="Month (1-12)")
                hour = gr.Slider(minimum=0, maximum=23, value=18, step=1,
                                 label="Hour of Day (0-23)")

            predict_btn = gr.Button("🚀 Predict Virality", variant="primary", size="lg")

        # Output column: headline numbers, summary, then per-model details.
        with gr.Column(scale=3):
            gr.Markdown("### 📊 Prediction Results")

            # Main predictions from the configured best models.
            with gr.Row():
                best_likes_output = gr.Number(
                    label=f"❤️ Predicted Likes ({BEST_MODEL_LIKES})",
                    interactive=False
                )
                best_shares_output = gr.Number(
                    label=f"🔄 Predicted Shares ({BEST_MODEL_SHARES})",
                    interactive=False
                )

            # Averages across all models.
            summary_output = gr.Markdown(label="Summary")

            # Detailed per-model predictions, collapsed by default.
            with gr.Row():
                with gr.Accordion("All Models - Likes", open=False):
                    likes_table_output = gr.Markdown(label="Likes Predictions")
                with gr.Accordion("All Models - Shares", open=False):
                    shares_table_output = gr.Markdown(label="Shares Predictions")

    # Connect the button to the callback; `inputs` follows the parameter order
    # of predict_virality_gradio and `outputs` the order of its return tuple.
    predict_btn.click(
        fn=predict_virality_gradio,
        inputs=[
            generation_time, gpu_usage, file_size_kb,
            width, height, style_accuracy_score,
            is_hand_edited, ethical_concerns_flag,
            day_of_week, month, hour, platform
        ],
        outputs=[
            best_likes_output,
            best_shares_output,
            likes_table_output,
            shares_table_output,
            summary_output
        ]
    )
# Entry point: load the models first and only start the web UI if they loaded.
if __name__ == "__main__":
    load_models()
    if models_loaded:
        # Optional launch flags previously kept here as comments: share=True, debug=True.
        demo.launch()
    else:
        print("\nCannot launch Gradio app because models failed to load.")