import cv2
import torch
from model import U2NET
import numpy as np
from glob import glob
import os

def detect_single_face(face_cascade,img):
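    # detect faces with the OpenCV Haar cascade and keep only the largest one;
    # return its bounding box (x, y, w, h), or None when no face is found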
    # Convert into grayscale
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Detect faces
    faces = face_cascade.detectMultiScale(gray, 1.1, 4)
    if(len(faces)==0):
        print("Warming: no face detection, the portrait u2net will run on the whole image!")
        return None

    # filter to keep the largest face
    wh = 0
    idx = 0
    for i in range(0,len(faces)):
        (x,y,w,h) = faces[i]
        if(wh<w*h):
            idx = i
            wh = w*h

    return faces[idx]

# crop, pad and resize face region to 512x512 resolution
def crop_face(img, face):
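    # the detected bbox is enlarged by 40% of the face width on the left and right,
    # 60% of the face height on top and 20% on the bottom; any part of the enlarged
    # bbox that falls outside the image is later filled with white padding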

    # no face detected, return the whole image and the inference will run on the whole image
    if(face is None):
        return img
    (x, y, w, h) = face

    height,width = img.shape[0:2]

    # crop the face with a bigger bbox
    # l, r, t, b record how much of the enlarged bbox falls outside the image,
    # so that the missing region can be filled with white padding below
    l,r,t,b = 0,0,0,0
    lpad = int(float(w)*0.4)
    left = x-lpad
    if(left<0):
        l = lpad-x
        left = 0

    rpad = int(float(w)*0.4)
    right = x+w+rpad
    if(right>width):
        r = right-width
        right = width

    tpad = int(float(h)*0.6)
    top = y - tpad
    if(top<0):
        t = tpad-y
        top = 0

    bpad  = int(float(h)*0.2)
    bottom = y+h+bpad
    if(bottom>height):
        b = bottom-height
        bottom = height


    im_face = img[top:bottom,left:right]
    if(len(im_face.shape)==2):
        im_face = np.repeat(im_face[:,:,np.newaxis],3,axis=2)

    im_face = np.pad(im_face,((t,b),(l,r),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))

    # pad to a square shape to avoid face deformation after resizing
    hf,wf = im_face.shape[0:2]
    if(hf-2>wf):
        wfp = int((hf-wf)/2)
        im_face = np.pad(im_face,((0,0),(wfp,wfp),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
    elif(wf-2>hf):
        hfp = int((wf-hf)/2)
        im_face = np.pad(im_face,((hfp,hfp),(0,0),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))

    # resize to have 512x512 resolution
    im_face = cv2.resize(im_face, (512,512), interpolation = cv2.INTER_AREA)

    return im_face

def normPRED(d):
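    # min-max normalize the prediction map to the [0,1] range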
    ma = torch.max(d)
    mi = torch.min(d)

    dn = (d-mi)/(ma-mi)

    return dn

def inference(net,input):
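    # preprocess a BGR image (the face crop, or the whole image when no face was
    # detected), run the u2net_portrait model and return the predicted portrait
    # as a 2D numpy array with values in [0,1]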

    # scale the input to [0,1]
    tmpImg = np.zeros((input.shape[0],input.shape[1],3))
    input = input/np.max(input)

    # swap BGR (OpenCV) to RGB and normalize each channel
    tmpImg[:,:,0] = (input[:,:,2]-0.406)/0.225
    tmpImg[:,:,1] = (input[:,:,1]-0.456)/0.224
    tmpImg[:,:,2] = (input[:,:,0]-0.485)/0.229

    # HWC -> CHW and add a batch dimension
    tmpImg = tmpImg.transpose((2, 0, 1))
    tmpImg = tmpImg[np.newaxis,:,:,:]

    # convert the numpy array to a float torch tensor
    tmpImg = torch.from_numpy(tmpImg)
    tmpImg = tmpImg.type(torch.FloatTensor)

    # move the input to GPU when available
    if torch.cuda.is_available():
        tmpImg = tmpImg.cuda()

    # inference (no gradient tracking needed)
    with torch.no_grad():
        d1,d2,d3,d4,d5,d6,d7 = net(tmpImg)

    # take the first output, invert it and min-max normalize to [0,1]
    pred = 1.0 - d1[:,0,:,:]
    pred = normPRED(pred)

    # convert torch tensor to numpy array
    pred = pred.squeeze()
    pred = pred.cpu().data.numpy()

    del d1,d2,d3,d4,d5,d6,d7

    return pred

def main():

    # get the image path list for inference
    im_list = glob('./test_data/test_portrait_images/your_portrait_im/*')
    print("Number of images: ",len(im_list))
    # indicate the output directory
    out_dir = './test_data/test_portrait_images/your_portrait_results'
    if(not os.path.exists(out_dir)):
        os.makedirs(out_dir)

    # Load the cascade face detection model
    face_cascade = cv2.CascadeClassifier('./saved_models/face_detection_cv2/haarcascade_frontalface_default.xml')
    # u2net_portrait path
    model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'

    # load u2net_portrait model
    net = U2NET(3,1)
    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    # do the inference one-by-one
    for i in range(0,len(im_list)):
        print("--------------------------")
        print("inferencing ", i, "/", len(im_list), im_list[i])

        # load each image
        img = cv2.imread(im_list[i])
        if img is None:
            print("Warning: failed to read", im_list[i], ", skipping!")
            continue
        face = detect_single_face(face_cascade,img)
        im_face = crop_face(img, face)
        im_portrait = inference(net,im_face)

        # save the output as an 8-bit PNG
        im_name = os.path.splitext(os.path.basename(im_list[i]))[0]
        cv2.imwrite(os.path.join(out_dir, im_name+'.png'),(im_portrait*255).astype(np.uint8))

if __name__ == '__main__':
    main()
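
# Usage (a minimal sketch; the paths are the defaults hard-coded above):
#   1. put the input images into ./test_data/test_portrait_images/your_portrait_im/
#   2. place haarcascade_frontalface_default.xml and u2net_portrait.pth under
#      ./saved_models/ at the sub-paths used in main()
#   3. run this script; the results are written as PNG files to
#      ./test_data/test_portrait_images/your_portrait_results/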