import sys
sys.path.append('XGBoost_Prediction_Model/')

import warnings
warnings.filterwarnings("ignore")
import Predict
import torch
import numpy as np
import os
from os.path import isfile, isdir, join
from Magazine_Optimization import *

# Root folder that holds one sub-folder per magazine.
mypath = 'XGBoost_Prediction_Model/Magazine_Optimization_Demo/Magazines'

# Target magazine folders (the magazines that receive the ads).
Target = [join(mypath, magazine) for magazine in ('M1', 'M3')]

# Per-pair outcomes, filled in by the main loop below.
results = {}

# Files read from a magazine's data folder (any sub-folder not named 'Jpg_*').
_TARGET_FIELDS = ('Slots', 'Textboxes', 'Obj_and_Topics')
_AD_FIELDS = ('Slots', 'surfaces', 'Prod_Cat', 'Textboxes', 'Obj_and_Topics')

# Folder names of the target magazines; they are never used as ad magazines.
_target_names = {os.path.basename(os.path.normpath(t)) for t in Target}


def _load_magazine(folder, fields):
    """Load the page image paths and the stored data of one magazine folder.

    Sub-folders whose name starts with 'Jpg_' (or is 'Jpg') provide the sorted
    list of page files; any other sub-folder is treated as the data folder and
    each name in *fields* is read from it with ``torch.load``.

    Returns ``(pages, data)`` where *pages* is a list of paths (empty when no
    Jpg folder exists) and *data* is a dict keyed by field name, or ``None``
    when the magazine has no data folder. Returning fresh values per call
    prevents data of a previously processed magazine from being reused.
    """
    pages = []
    data = None
    for entry in os.listdir(folder):
        entry_path = join(folder, entry)
        if not isdir(entry_path):
            continue
        if entry.split('_')[0] == 'Jpg':
            pages = [join(entry_path, name) for name in sorted(os.listdir(entry_path))]
        else:
            # NOTE(review): torch.load unpickles its input, so only use trusted
            # files. The `.astype` call below suggests NumPy arrays are stored;
            # recent PyTorch releases default to weights_only=True, which may
            # reject such files -- confirm against the installed version.
            data = {name: torch.load(join(entry_path, name)) for name in fields}
            data['Slots'] = data['Slots'].astype('int32')
    return pages, data


def _pad_to_square(matrix, size):
    """Return a size x size array holding *matrix* in its top-left corner, zeros elsewhere."""
    rows, cols = matrix.shape
    padded = np.zeros((size, size))
    padded[:rows, :cols] = matrix
    return padded


def _selected_pairs(problem):
    """Return ``(prefix, row, col)`` for every variable the solver set to one.

    Variable names are expected to look like ``<prefix>_<worker>_<job>`` with
    1-based indices; *row* and *col* are returned 0-based.
    """
    pairs = []
    for v in problem.variables():
        # Solver values are floats (or None when unset); use a threshold
        # rather than an exact `== 1` comparison.
        if v.varValue is None or v.varValue <= 0.5:
            continue
        parts = v.name.split('_')
        pairs.append((parts[0], int(parts[1]) - 1, int(parts[2]) - 1))
    return pairs


def _print_strategy(pairs, n_ads, n_slots, ad_ids, target_ids):
    """Print the real (non-padding) assignments and return them as one string."""
    strategy = ''
    for prefix, row, col in pairs:
        # Rows/columns beyond the original matrix are zero padding, not real
        # ads or counterpages.
        if row < n_ads and col < n_slots:
            temp = prefix+' Ad '+str(ad_ids[row])+' to Counterpage '+str(target_ids[col])
            strategy += temp+'; '
            print(temp)
    return strategy


for Tg in Target:
    dir_list_target, target_data = _load_magazine(Tg, _TARGET_FIELDS)
    if target_data is None:
        print('Skipping Target Magazine '+Tg+': no data folder found.')
        continue

    for f in os.listdir(mypath):
        if not isdir(join(mypath, f)) or f in _target_names:
            continue
        print('Currently processing Target Magazine '+Tg+' with Ad Magazine '+f+'......')
        dir_list_ad, ad_data = _load_magazine(join(mypath, f), _AD_FIELDS)
        if ad_data is None:
            print('Skipping Ad Magazine '+f+': no data folder found.')
            continue

        result = Preference_Matrix_different_magazine(
            dir_list_target, dir_list_ad,
            target_data['Slots'], ad_data['Slots'],
            ad_data['Prod_Cat'], ad_data['surfaces'],
            Textboxes_Target=target_data['Textboxes'], Textboxes_Ad=ad_data['Textboxes'],
            Obj_and_Topics_Target=target_data['Obj_and_Topics'],
            Obj_and_Topics_Ad=ad_data['Obj_and_Topics'])
        if result is None:
            print("Ads cannot be fully assigned!")
            continue

        (Ad_Gaze, Brand_Gaze, Double_Page_Ad_Attention, Double_Page_Brand_Attention,
         Assign_ids_ad, Assign_ids_target) = result

        # Assignment problem: rows index ads, columns index target counterpages.
        # Both gaze matrices are zero-padded to N x N so the problem is square.
        n_ads, n_slots = Ad_Gaze.shape
        N = max(n_ads, n_slots)
        workers = list(range(1, N + 1))
        jobs = list(range(1, N + 1))
        Ad_Gaze = _pad_to_square(Ad_Gaze, N)
        Brand_Gaze = _pad_to_square(Brand_Gaze, N)

        # Turn the maximisation into a minimisation: cost = max - attention.
        max_ad_attention = np.max(Ad_Gaze)
        max_brand_attention = np.max(Brand_Gaze)
        Ad_Gaze_cost = max_ad_attention - Ad_Gaze
        Brand_Gaze_cost = max_brand_attention - Brand_Gaze

        Prob_solved_Ad = Assignment_Problem(Ad_Gaze_cost, workers, jobs)
        Prob_solved_Brand = Assignment_Problem(Brand_Gaze_cost, workers, jobs)

        # Strategy that maximises overall ad attention
        print('If based on maximizing Overall Ad Attention: ')
        ag_pairs = _selected_pairs(Prob_solved_Ad)
        strategy_AG = _print_strategy(ag_pairs, n_ads, n_slots, Assign_ids_ad, Assign_ids_target)
        # Brand cost incurred when the ad-attention assignment is used instead.
        BG_under_AG_assignment = sum(Brand_Gaze_cost[row, col] for _, row, col in ag_pairs)

        # Undo the cost transform (N * max - cost) and add the double-page ads.
        m_ad = N*max_ad_attention - value(Prob_solved_Ad.objective) + sum(Double_Page_Ad_Attention)
        print("Maximized Ad Attention = ", m_ad, " sec.")
        # NOTE(review): the average divides by the padded size N, which exceeds
        # the number of real ads when the matrix is not square -- confirm intent.
        print("Maximized Average Ad attention on each Ad = ", m_ad/(N + len(Double_Page_Ad_Attention)), " sec.")
        print()

        # Strategy that maximises overall brand attention
        print('If based on maximizing Overall Brand Attention: ')
        bg_pairs = _selected_pairs(Prob_solved_Brand)
        strategy_BG = _print_strategy(bg_pairs, n_ads, n_slots, Assign_ids_ad, Assign_ids_target)

        m_brand = N*max_brand_attention - value(Prob_solved_Brand.objective) + sum(Double_Page_Brand_Attention)
        BG_under_AG_assignment = N*max_brand_attention - BG_under_AG_assignment + sum(Double_Page_Brand_Attention)
        print("Maximized Brand Attention = ", m_brand, " sec.")
        print("New Brand Gaze under AG assignment = ", BG_under_AG_assignment, " sec.")
        print("Maximized Average Brand attention on each Ad = ", m_brand/(N + len(Double_Page_Brand_Attention)), " sec.")
        print('End of Magazine '+f+'......')

        # np.trace gives the attention of the identity (row i -> column i) assignment.
        results[Tg+' '+f] = {'AG':[strategy_AG,m_ad,np.trace(Ad_Gaze)], 'BG':[strategy_BG,m_brand,np.trace(Brand_Gaze),BG_under_AG_assignment]}
        print()
        print()


# Final recap: one section per processed (target, ad magazine) pair, listing
# the chosen strategy and the attained attention for each objective.
print()
print('Summary: ')
for pair_name, outcome in results.items():
    print('Magazine '+pair_name+': ')
    for heading, key in (('Ad Gaze: ', 'AG'), ('Brand Gaze: ', 'BG')):
        entry = outcome[key]
        print(heading)
        print('Strategy: '+entry[0])
        print('max Attention: ', entry[1])
        print('------------------------')
    print()