Spaces:
Running
Running
File size: 5,965 Bytes
569f484 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import sys
sys.path.append('../XGBoost_Prediction_Model/')
import warnings
warnings.filterwarnings("ignore")
import Predict
import XGBoost_utils
import torch
import numpy as np
import os
from os.path import isfile, isdir, join
# Model artifacts and sample data. All paths are relative, so they resolve
# against the current working directory.
text_detection_model_path = '../XGBoost_Prediction_Model/EAST-Text-Detection/frozen_east_text_detection.pb'
LDA_model_pth = '../XGBoost_Prediction_Model/LDA_Model_trained/lda_model_best_tot.model'
training_ad_text_dictionary_path = '../XGBoost_Prediction_Model/LDA_Model_trained/object_word_dictionary'
training_lang_preposition_path = '../XGBoost_Prediction_Model/LDA_Model_trained/dutch_preposition'
mypath = '../XGBoost_Prediction_Model/Adversarial Samples New'

# Names of the adversarial ad types (one sample directory per type).
Adversarial_types = ['Newspaper', 'Banner', 'Multiple Ads', 'Outdoor', 'Others']

# Per ad type, the `ad_location` value passed to the predictor for each of
# the 11 samples, indexed by (sample number - 1).
ad_locations_total = {
    'Newspaper': [0] * 11,
    'Banner': [2] * 11,
    'Multiple Ads': [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0],
    'Outdoor': [2] * 11,
    'Others': [1, 0, 1, 0, 2, 0, 1, 1, 0, 0, 0],
}

# Result containers keyed by ad type: model predictions and ground truth
# for ad gaze (AG) and brand gaze (BG).
Predicted_AG, Predicted_BG = {}, {}
GT_AG, GT_BG = {}, {}
# Run the gaze-time predictor on every adversarial sample.
#
# Layout read from `mypath` (as established by the loads below):
#   <type>/                 e.g. 'Outdoor'; must be a key of ad_locations_total
#     <type>_Adversarial_Surfaces, <type>_Adversarial_Categories,
#     <type>_ad_topic_embeddings, <type>_ctpg_topic_embeddings, AGs, BGs
#     <n>/                  sample directories, numbered from 1
#       '<name> Ad.<ext>'       required ad image
#       '<name> Context.<ext>'  optional context-page image
#
# Fills, per type: GT_AG/GT_BG with the loaded ground truth, and
# Predicted_AG/Predicted_BG with one prediction per sample for
# Gaze_Time_Type 'Ad' and 'Brand' respectively.
# Directories are visited in sorted order so runs are reproducible.
# To debug a single type, add e.g. `if f != 'Outdoor': continue` at the top
# of the loop.
for f in sorted(os.listdir(mypath)):
    path_temp = join(mypath, f)
    if not isdir(path_temp):
        continue
    print('Currently processing samples of type '+f+'......')
    surfaces = torch.load(join(path_temp, f+'_Adversarial_Surfaces'))
    print(surfaces)
    categories = torch.load(join(path_temp, f+'_Adversarial_Categories'))
    ad_locations_curr = ad_locations_total[f]
    ad_embeddings = torch.load(join(path_temp, f+'_ad_topic_embeddings'))
    ctpg_embeddings = torch.load(join(path_temp, f+'_ctpg_topic_embeddings'))
    GT_AG[f] = torch.load(join(path_temp, 'AGs'))
    GT_BG[f] = torch.load(join(path_temp, 'BGs'))
    # One slot per sample. ad_locations_total lists one location per sample,
    # so size from it instead of hard-coding 11.
    num_samples = len(ad_locations_curr)
    AG_predictions_per_type = np.zeros(num_samples)
    BG_predictions_per_type = np.zeros(num_samples)
    for sub_f in sorted(os.listdir(path_temp)):
        sample_dir = join(path_temp, sub_f)
        if not isdir(sample_dir):
            continue
        print('Sample Number '+sub_f+'...... ')
        # Reset both image paths for every sample so a sample missing its Ad
        # image fails loudly instead of reusing the previous sample's file.
        ad_pth = None
        context_pth = None
        for imgs in os.listdir(sample_dir):
            # The image role is the last space-separated word of the file
            # stem; splitext handles any extension length ('.jpg', '.jpeg').
            jpg_removed_split = os.path.splitext(imgs)[0].split(' ')
            if len(jpg_removed_split) < 2:
                continue
            img_type = jpg_removed_split[-1]
            if img_type == 'Ad':
                ad_pth = join(sample_dir, imgs)
                print(ad_pth)
            elif img_type == 'Context':
                context_pth = join(sample_dir, imgs)
        if ad_pth is None:
            raise FileNotFoundError(
                "No '<name> Ad' image found in sample directory " + sample_dir)
        # Start predicting. Sample directories are numbered from 1.
        sample_num = int(sub_f)-1
        # Both targets use the same inputs; only Gaze_Time_Type differs.
        # Arguments are rebuilt for each call (reshape / list).
        for gaze_type, predictions in (('Ad', AG_predictions_per_type),
                                       ('Brand', BG_predictions_per_type)):
            predictions[sample_num] = Predict.Ad_Gaze_Prediction(
                input_ad_path=ad_pth,
                input_ctpg_path=context_pth,
                text_detection_model_path=text_detection_model_path,
                LDA_model_pth=LDA_model_pth,
                training_ad_text_dictionary_path=training_ad_text_dictionary_path,
                training_lang_preposition_path=training_lang_preposition_path,
                training_language='dutch',
                ad_embeddings=ad_embeddings[sample_num].reshape(1, 768),
                ctpg_embeddings=ctpg_embeddings[sample_num].reshape(1, 768),
                surface_sizes=list(surfaces[sample_num]),
                Product_Group=list(categories[sample_num]),
                obj_detection_model_pth=None,
                ad_location=ad_locations_curr[sample_num],
                num_topic=20,
                Gaze_Time_Type=gaze_type)
    Predicted_AG[f] = AG_predictions_per_type
    Predicted_BG[f] = BG_predictions_per_type
def _rmse(ground_truth, predicted):
    """Return the root-mean-square of the element-wise difference."""
    return np.sqrt(np.mean((np.asarray(ground_truth) - np.asarray(predicted)) ** 2))


# Evaluate predictions against ground truth, per ad type and pooled over all
# types: RMSE, RMSRPD (XGBoost_utils.RMSRPD) and Pearson correlation.
print("Final results: ")
if not Predicted_AG:
    # np.concatenate below would otherwise fail with an opaque error.
    raise SystemExit("No predictions were produced; nothing to evaluate.")
diffs_rmse = {}
diffs_rmsrpd = {}
GT_AG_tot = []
GT_BG_tot = []
Pred_AG_tot = []
Pred_BG_tot = []
for key in Predicted_AG:
    print(key)
    print('AGs', Predicted_AG[key])
    print('BGs', Predicted_BG[key])
    # Ground truth comes from torch.load; coerce to float arrays so all
    # arithmetic below is plain NumPy whatever container was saved.
    gt_ag = np.asarray(GT_AG[key], dtype=float)
    gt_bg = np.asarray(GT_BG[key], dtype=float)
    Pred_AG_tot.append(Predicted_AG[key])
    Pred_BG_tot.append(Predicted_BG[key])
    GT_AG_tot.append(gt_ag)
    GT_BG_tot.append(gt_bg)
    diffs_rmse[key] = (_rmse(gt_ag, Predicted_AG[key]),
                       _rmse(gt_bg, Predicted_BG[key]))
    diffs_rmsrpd[key] = (XGBoost_utils.RMSRPD(gt_ag, Predicted_AG[key]),
                         XGBoost_utils.RMSRPD(gt_bg, Predicted_BG[key]))
    print()
Pred_AG_tot = np.concatenate(Pred_AG_tot)
Pred_BG_tot = np.concatenate(Pred_BG_tot)
GT_AG_tot = np.concatenate(GT_AG_tot)
GT_BG_tot = np.concatenate(GT_BG_tot)

print("RMSE: ")
print("Total AG: ", _rmse(GT_AG_tot, Pred_AG_tot))
print("Total BG: ", _rmse(GT_BG_tot, Pred_BG_tot))
print()
for key, (rmse_ag, rmse_bg) in diffs_rmse.items():
    print(key, rmse_ag, rmse_bg)
print()

print("RMSRPD: ")
# RMSRPD is called as (ground truth, prediction) for both the per-type and the
# pooled values so they are comparable.
# NOTE(review): confirm this order against XGBoost_utils.RMSRPD's signature.
print("Total AG: ", XGBoost_utils.RMSRPD(GT_AG_tot, Pred_AG_tot))
print("Total BG: ", XGBoost_utils.RMSRPD(GT_BG_tot, Pred_BG_tot))
print()
for key, (rmsrpd_ag, rmsrpd_bg) in diffs_rmsrpd.items():
    print(key, rmsrpd_ag, rmsrpd_bg)
print()

print("Correlation: ")
# np.corrcoef returns the 2x2 correlation matrix; the off-diagonal [0, 1]
# entry is the Pearson r between prediction and ground truth.
print("Total AG: ", np.corrcoef(Pred_AG_tot, GT_AG_tot)[0, 1])
print("Total BG: ", np.corrcoef(Pred_BG_tot, GT_BG_tot)[0, 1])
|