import argparse
import os, sys
import torch
import cv2
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from opt import opt
from dataset_curation_pipeline.IC9600.ICNet import ICNet



# Preprocessing applied to every PIL image before it is fed to the model:
#   1. resize to a fixed 512x512 resolution,
#   2. convert to a float tensor,
#   3. normalize each channel with the mean/std below (these are the widely
#      used ImageNet channel statistics).
inference_transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def blend(ori_img, ic_img, alpha = 0.8, cm = plt.get_cmap("magma")):
    """Overlay a color-mapped version of ``ic_img`` on top of ``ori_img``.

    Args:
        ori_img: array accepted by ``Image.fromarray`` (the base picture).
        ic_img: array passed to the colormap ``cm``; the colormap output must
            have the same height/width as ``ori_img`` for ``Image.blend`` to
            accept it.
        alpha: blending factor forwarded to ``Image.blend``; ``1.0`` shows only
            the heatmap, ``0.0`` only the original image.
        cm: callable colormap; defaults to matplotlib's "magma" (looked up
            once, when the function is defined).

    Returns:
        The blended picture as a NumPy array.
    """
    colored = cm(ic_img)
    # `-2::-1` starts at the second-to-last channel and walks backwards to the
    # first one: the last channel is dropped and the remaining channels are
    # reversed (for an RGBA colormap output this yields B, G, R order).
    heat_pixels = (colored[:, :, -2::-1] * 255).astype(np.uint8)
    heat_layer = Image.fromarray(heat_pixels)
    base_layer = Image.fromarray(ori_img)
    mixed = Image.blend(base_layer, heat_layer, alpha=alpha)
    return np.array(mixed)

        
def infer_one_image(model, img_path):
    """Run ``model`` on a single image file and return its scalar score.

    The image is loaded with PIL, converted to RGB, preprocessed with the
    module-level ``inference_transform`` and sent through ``model`` as a batch
    of size one, with gradient tracking disabled.

    Args:
        model: callable returning a ``(ic_score, ic_map)`` pair for a batched
            image tensor. If it exposes ``parameters()`` (e.g. a
            ``torch.nn.Module``), the input is placed on the device of its
            first parameter.
        img_path: path of the image file to score.

    Returns:
        ``ic_score.item()`` -- the model's score output as a plain Python
        number.
    """
    # Put the input on the same device as the model instead of an
    # unconditional `.cuda()`: `.cuda()` always targets the *current* CUDA
    # device, which mismatches a model moved to another device index.
    # Models without parameters keep the previous behavior (default CUDA device).
    params = model.parameters() if hasattr(model, "parameters") else iter(())
    first_param = next(params, None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    with torch.no_grad():
        # The context manager releases the underlying file handle as soon as
        # the pixel data has been converted.
        with Image.open(img_path) as raw_img:
            ori_img = raw_img.convert("RGB")

        img = inference_transform(ori_img).unsqueeze(0).to(device)
        ic_score, ic_map = model(img)
        ic_score = ic_score.item()

        # --- Disabled: export of the map output (`ic_map`) -----------------
        # ic_map = F.interpolate(ic_map, (ori_img.height, ori_img.width), mode = 'bilinear')

        ## generate ic map
        # ic_map_np = ic_map.squeeze().detach().cpu().numpy()
        # out_ic_map_name = os.path.basename(img_path).split('.')[0] + '_' + str(ic_score)[:7] + '.npy'
        # out_ic_map_path = os.path.join(args.output, out_ic_map_name)
        # np.save(out_ic_map_path, ic_map_np)

        ## generate blend map
        # ic_map_img = (ic_map * 255).round().squeeze().detach().cpu().numpy().astype('uint8')
        # blend_img = blend(np.array(ori_img), ic_map_img)
        # out_blend_img_name = os.path.basename(img_path).split('.')[0]  + '.png'
        # out_blend_img_path = os.path.join(args.output, out_blend_img_name)
        # cv2.imwrite(out_blend_img_path, blend_img)
        return ic_score

    
    
def infer_directory(img_dir, model=None):
    """Score every file in ``img_dir`` and print the 50 highest-scoring ones.

    Files are processed in sorted name order; each path and score is printed
    as soon as it is computed. Afterwards the ``(score, path)`` pairs are
    ranked by descending score and the first 50 are printed.

    Args:
        img_dir: directory whose regular files are scored. Subdirectories and
            other non-file entries are skipped.
        model: model forwarded to ``infer_one_image``. When omitted, the
            module-level ``model`` created by the script entry point is used,
            so existing ``infer_directory(path)`` calls keep working.

    Returns:
        The full list of ``(score, img_path)`` tuples, highest score first.

    Raises:
        RuntimeError: if no model is passed and no module-level ``model``
            exists.
    """
    if model is None:
        # `infer_one_image` requires the model as its first argument; the
        # script entry point stores it in a module-level global.
        model = globals().get("model")
        if model is None:
            raise RuntimeError(
                "infer_directory() needs a model: pass it explicitly or "
                "define a module-level `model` first."
            )

    scores = []
    for name in tqdm(sorted(os.listdir(img_dir))):
        img_path = os.path.join(img_dir, name)
        if not os.path.isfile(img_path):
            # A nested directory cannot be opened as an image.
            continue
        score = infer_one_image(model, img_path)
        scores.append((score, img_path))
        print(img_path, score)

    # Highest score first.
    scores.sort(key=lambda pair: pair[0], reverse=True)

    for entry in scores[:50]:
        print(entry)

    return scores


if __name__ == "__main__":
    # Command-line entry point: load the checkpoint, move the model to the
    # requested device and score either one image or a whole directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type = str, default = './example',
                        help='image file or directory of images to score')
    # NOTE: --output is only referenced by the disabled map-export code in
    # infer_one_image; it currently has no effect.
    parser.add_argument('-o', '--output', type = str, default = './out',
                        help='output directory (currently unused)')
    parser.add_argument('-d', '--device', type = int, default=0,
                        help='device index passed to torch.device')
    parser.add_argument('-c', '--checkpoint', type = str, default = './checkpoint/ck.pth',
                        help='path to the model weights')

    args  = parser.parse_args()

    # `model` deliberately stays a module-level name: infer_directory is
    # called without a model argument below.
    model = ICNet()
    # Load the weights on the CPU first, then move the whole model over.
    model.load_state_dict(torch.load(args.checkpoint, map_location=torch.device('cpu')))
    model.eval()
    device = torch.device(args.device)
    model.to(device)

    # The preprocessing pipeline is the module-level `inference_transform`;
    # it is not redefined here.
    if os.path.isfile(args.input):
        # infer_one_image takes the model as its first argument.
        score = infer_one_image(model, args.input)
        print(args.input, score)
    else:
        infer_directory(args.input)