File size: 10,841 Bytes
713c6ae
 
95eab0a
 
 
 
 
 
713c6ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95eab0a
713c6ae
95eab0a
713c6ae
 
 
95eab0a
 
 
 
 
 
 
 
 
 
 
 
 
713c6ae
 
95eab0a
 
 
 
713c6ae
 
95eab0a
 
713c6ae
 
95eab0a
 
713c6ae
 
95eab0a
 
 
 
713c6ae
 
 
 
95eab0a
 
713c6ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95eab0a
713c6ae
 
95eab0a
713c6ae
 
 
95eab0a
713c6ae
 
 
 
95eab0a
713c6ae
 
 
95eab0a
713c6ae
 
 
95eab0a
713c6ae
 
 
 
95eab0a
713c6ae
95eab0a
713c6ae
95eab0a
 
713c6ae
95eab0a
 
713c6ae
95eab0a
713c6ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95eab0a
713c6ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95eab0a
 
 
713c6ae
95eab0a
713c6ae
95eab0a
 
 
 
 
713c6ae
 
 
 
 
95eab0a
 
 
 
 
713c6ae
95eab0a
 
 
713c6ae
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# Gradio Demo App for Predicting Both Likes and Shares

import gradio as gr
import pandas as pd
import numpy as np
import joblib
import os

# Model whose predictions are shown as the headline numbers in the UI
# (update based on your results).
BEST_MODEL_LIKES = 'Random Forest'
BEST_MODEL_SHARES = 'Random Forest'

# Module-level state. load_models() fills these in; the prediction
# functions only read them.
models_loaded = False          # True once every model and the scaler loaded
all_models_likes = dict()      # model name -> fitted likes regressor
all_models_shares = dict()     # model name -> fitted shares regressor
model_names = list()           # display names, in the order they are run
scaler = None                  # fitted feature scaler
expected_columns = list()      # feature columns, in the scaler's order

# Prediction Function for Both Likes and Shares
def _engineer_features(generation_time, gpu_usage, file_size_kb,
                       width, height, style_accuracy_score,
                       is_hand_edited, ethical_concerns_flag,
                       day_of_week, month, hour, platform):
    """Build the raw, unscaled feature dict for a single post."""
    # Raw inputs (the target columns, likes/shares, are deliberately absent).
    features = {
        'style_accuracy_score': style_accuracy_score,
        'generation_time': generation_time,
        'gpu_usage': gpu_usage,
        'file_size_kb': file_size_kb,
        'is_hand_edited': int(is_hand_edited),
        'ethical_concerns_flag': int(ethical_concerns_flag),
        'width': width,
        'height': height,
        'day_of_week': day_of_week,
        'month': month,
        'hour': hour,
    }

    # Image geometry; the ratio falls back to 0 for a non-positive height.
    features['aspect_ratio'] = width / height if height > 0 else 0
    features['total_pixels'] = width * height
    features['is_square'] = int(width == height)
    # The UI encodes day_of_week as 0=Mon .. 6=Sun, so 5 and 6 are the weekend.
    features['is_weekend'] = int(day_of_week >= 5)

    # One-hot platform encoding; an unrecognised platform leaves every flag 0.
    for p in ('Twitter', 'TikTok', 'Reddit', 'Instagram'):
        features[f'platform_{p}'] = 1 if platform == p else 0

    # Technical ratios; the +1 keeps the denominators positive for
    # non-negative inputs.
    features['file_density'] = file_size_kb / (features['total_pixels'] / 1000 + 1)
    features['gpu_efficiency'] = generation_time / (gpu_usage + 1)

    # Cyclical sin/cos encodings so wrap-around values (e.g. hour 23 and
    # hour 0) end up close together in feature space.
    for prefix, value, period in (('month', month, 12),
                                  ('day', day_of_week, 7),
                                  ('hour', hour, 24)):
        angle = 2 * np.pi * value / period
        features[f'{prefix}_sin'] = np.sin(angle)
        features[f'{prefix}_cos'] = np.cos(angle)

    return features


def _to_count(raw_prediction):
    """Convert one raw regression output into a non-negative whole count."""
    # round() instead of int(): int() truncates toward zero, which biases every
    # prediction downward by up to one. (round() is round-half-to-even.)
    # NaN/inf outputs raise here and are reported by the caller.
    return max(0, round(float(raw_prediction)))


def predict_virality_all_models(generation_time, gpu_usage, file_size_kb,
                               width, height, style_accuracy_score,
                               is_hand_edited, ethical_concerns_flag,
                               day_of_week, month, hour, platform):
    """
    Predict likes and shares for one post with every loaded model.

    Reads the module-level state populated by ``load_models()``:
    ``scaler``, ``expected_columns``, ``model_names``, ``all_models_likes``
    and ``all_models_shares``.

    Returns:
        A tuple ``(likes_by_model, shares_by_model, error)``. On success the
        two dicts map model name to a non-negative int prediction and
        ``error`` is ``None``. If scaling or any model's prediction fails,
        both dicts are empty and ``error`` is a human-readable message.
    """
    features = _engineer_features(
        generation_time, gpu_usage, file_size_kb,
        width, height, style_accuracy_score,
        is_hand_edited, ethical_concerns_flag,
        day_of_week, month, hour, platform,
    )

    # Align to the exact column order the scaler was fitted on; any expected
    # column we did not build is filled with 0, and extras are dropped.
    sample_df = pd.DataFrame([features]).reindex(columns=expected_columns, fill_value=0)

    try:
        sample_scaled = scaler.transform(sample_df)
    except Exception as e:
        return {}, {}, f"Error during scaling: {e}"

    predictions_likes = {}
    predictions_shares = {}
    for name in model_names:
        # Report a failing model the same way as a scaling failure, instead
        # of letting the exception escape into the UI callback.
        try:
            if name in all_models_likes:
                predictions_likes[name] = _to_count(
                    all_models_likes[name].predict(sample_scaled)[0])
            if name in all_models_shares:
                predictions_shares[name] = _to_count(
                    all_models_shares[name].predict(sample_scaled)[0])
        except Exception as e:
            return {}, {}, f"Error during prediction with {name}: {e}"

    return predictions_likes, predictions_shares, None

def load_models(models_dir="models"):
    """
    Load every likes/shares model and the fitted scaler into module globals.

    For each entry in ``model_names`` this expects
    ``<models_dir>/<stem>_likes.joblib`` and ``<models_dir>/<stem>_shares.joblib``,
    where ``<stem>`` is the model name lower-cased with spaces replaced by
    underscores, plus ``<models_dir>/scaler.joblib``.

    Args:
        models_dir: Directory containing the ``.joblib`` files. Defaults to
            ``"models"``, which is resolved relative to the current working
            directory.

    Side effects:
        Sets the globals ``all_models_likes``, ``all_models_shares``,
        ``model_names``, ``scaler``, ``expected_columns`` and
        ``models_loaded``. Any loading failure is reported on stdout and
        leaves ``models_loaded`` False rather than raising.
    """
    global all_models_likes, all_models_shares, model_names, scaler, expected_columns, models_loaded

    # Start pessimistic so a failed (re)load can never leave a stale True.
    models_loaded = False

    # Dictionaries to hold the loaded model objects
    all_models_likes = {}
    all_models_shares = {}
    model_names = [
        'Linear Regression', 'Ridge Regression', 'Lasso Regression',
        'Random Forest', 'Gradient Boosting'
    ]

    def model_path(name, target):
        # e.g. ("Random Forest", "likes") -> <models_dir>/random_forest_likes.joblib
        stem = name.lower().replace(' ', '_')
        return os.path.join(models_dir, f"{stem}_{target}.joblib")

    try:
        for name in model_names:
            all_models_likes[name] = joblib.load(model_path(name, "likes"))
            all_models_shares[name] = joblib.load(model_path(name, "shares"))
            print(f"Loaded: {name} (both likes and shares)")

        scaler = joblib.load(os.path.join(models_dir, "scaler.joblib"))
        print("Loaded: scaler.joblib")

        # The scaler's fitted column names define the feature order used at
        # prediction time (an AttributeError here is handled below).
        expected_columns = scaler.feature_names_in_
        print(f"Model expects {len(expected_columns)} features.")

    except FileNotFoundError as e:
        print(f"\n❌ ERROR: Could not find a model file: {e}")
        print(f"Please make sure all '.joblib' files are in the '{models_dir}/' directory.")
        return
    except Exception as e:
        # Startup boundary: corrupt/incompatible joblib files or a scaler
        # fitted without feature names should be reported, not crash the app.
        print(f"\n❌ ERROR: Failed to load models from '{models_dir}': "
              f"{type(e).__name__}: {e}")
        return

    models_loaded = True
    print("\n✅ All models and scaler loaded successfully!")


def predict_virality_gradio(generation_time, gpu_usage, file_size_kb,
                           width, height, style_accuracy_score,
                           is_hand_edited, ethical_concerns_flag,
                           day_of_week, month, hour, platform):
    """
    Run every model for one post and shape the results for the Gradio outputs.

    Returns a 5-tuple matching the button's ``outputs`` list:
    (best-model likes, best-model shares, likes table markdown,
    shares table markdown, summary markdown). When models are not loaded or
    prediction fails, both numbers are 0 and the message fills the three
    markdown slots.
    """
    if not models_loaded:
        message = "Models are not loaded. Please check the console for errors."
        return 0, 0, message, message, message

    feature_args = (generation_time, gpu_usage, file_size_kb,
                    width, height, style_accuracy_score,
                    is_hand_edited, ethical_concerns_flag,
                    day_of_week, month, hour, platform)
    likes_by_model, shares_by_model, failure = predict_virality_all_models(*feature_args)
    if failure:
        return 0, 0, failure, failure, failure

    def ranked_markdown(predictions, value_column):
        # One row per model, highest prediction first, as a Markdown table.
        frame = pd.DataFrame(list(predictions.items()), columns=['Model', value_column])
        ranked = frame.sort_values(value_column, ascending=False)
        return ranked.to_markdown(index=False)

    headline_likes = likes_by_model.get(BEST_MODEL_LIKES, 0)
    headline_shares = shares_by_model.get(BEST_MODEL_SHARES, 0)
    likes_markdown = ranked_markdown(likes_by_model, 'Predicted Likes')
    shares_markdown = ranked_markdown(shares_by_model, 'Predicted Shares')

    mean_likes = np.mean(list(likes_by_model.values()))
    mean_shares = np.mean(list(shares_by_model.values()))
    summary = f"""
    ### Prediction Summary

    **Average Predictions Across All Models:**
    - Likes: {mean_likes:.0f}
    - Shares: {mean_shares:.0f}
    """

    return headline_likes, headline_shares, likes_markdown, shares_markdown, summary

# Create Gradio interface. The layout is built at import time and the button
# is wired to predict_virality_gradio; models are loaded separately.
with gr.Blocks(theme=gr.themes.Soft(), title="AI Image Virality Predictor") as demo:
    gr.Markdown("# 🎨 AI Ghibli Image Virality Predictor")
    gr.Markdown("Predict both **Likes** and **Shares** for your AI-generated Ghibli-style images!")

    with gr.Row():
        # Input Column
        with gr.Column(scale=2):
            gr.Markdown("### 📝 Input Features")

            with gr.Accordion("Image Properties", open=True):
                width = gr.Slider(minimum=256, maximum=2048, value=1024, step=64,
                                 label="Width (px)")
                height = gr.Slider(minimum=256, maximum=2048, value=1024, step=64,
                                  label="Height (px)")
                file_size_kb = gr.Slider(minimum=100, maximum=5000, value=1500, step=100,
                                        label="File Size (KB)")
                style_accuracy_score = gr.Slider(minimum=0, maximum=100, value=85, step=1,
                                               label="Style Accuracy Score (%)")

            with gr.Accordion("Technical Details", open=True):
                generation_time = gr.Slider(minimum=1, maximum=30, value=8, step=0.5,
                                          label="Generation Time (seconds)")
                gpu_usage = gr.Slider(minimum=10, maximum=100, value=70, step=5,
                                     label="GPU Usage (%)")
                is_hand_edited = gr.Checkbox(label="Hand Edited?", value=False)
                ethical_concerns_flag = gr.Checkbox(label="Ethical Concerns?", value=False)

            with gr.Accordion("Posting Details", open=True):
                platform = gr.Radio(["Instagram", "Twitter", "TikTok", "Reddit"],
                                   label="Platform", value="Instagram")
                day_of_week = gr.Slider(minimum=0, maximum=6, value=4, step=1,
                                       label="Day of Week (0=Mon, 6=Sun)")
                month = gr.Slider(minimum=1, maximum=12, value=7, step=1,
                                 label="Month (1-12)")
                hour = gr.Slider(minimum=0, maximum=23, value=18, step=1,
                                label="Hour of Day (0-23)")

            predict_btn = gr.Button("🚀 Predict Virality", variant="primary", size="lg")

        # Output Column
        with gr.Column(scale=3):
            gr.Markdown("### 📊 Prediction Results")

            # Headline predictions from the configured best models
            with gr.Row():
                best_likes_output = gr.Number(
                    label=f"❀️ Predicted Likes ({BEST_MODEL_LIKES})",
                    interactive=False
                )
                best_shares_output = gr.Number(
                    label=f"🔄 Predicted Shares ({BEST_MODEL_SHARES})",
                    interactive=False
                )

            # Summary
            summary_output = gr.Markdown(label="Summary")

            # Per-model tables, collapsed by default
            with gr.Row():
                with gr.Accordion("All Models - Likes", open=False):
                    likes_table_output = gr.Markdown(label="Likes Predictions")

                with gr.Accordion("All Models - Shares", open=False):
                    shares_table_output = gr.Markdown(label="Shares Predictions")

    # Connect the button to the function. The input order must match the
    # parameter order of predict_virality_gradio, and the output order must
    # match its 5-tuple return value.
    predict_btn.click(
        fn=predict_virality_gradio,
        inputs=[
            generation_time, gpu_usage, file_size_kb,
            width, height, style_accuracy_score,
            is_hand_edited, ethical_concerns_flag,
            day_of_week, month, hour, platform
        ],
        outputs=[
            best_likes_output,
            best_shares_output,
            likes_table_output,
            shares_table_output,
            summary_output
        ]
    )

# Launch the app
if __name__ == "__main__":
    load_models()
    if not models_loaded:
        print("\nCannot launch Gradio app because models failed to load.")
        # Exit non-zero so shells, CI jobs and process supervisors can tell
        # the launch failed instead of seeing a clean exit.
        raise SystemExit(1)
    # Pass share=True for a public link, or debug=True for verbose errors.
    demo.launch()