from __future__ import division

import base64
import os
import pickle as pkl
import random
from io import BytesIO

import cv2
import torch

from yolo.utils import *
from yolo.darknet import Darknet

from PIL import Image

class yolo_model():

    
    batch_size = 1
    confidence = 0.5
    nms_thresh = 0.4
    reso = 416

    CUDA = torch.cuda.is_available()

    num_classes = 80

    def __init__(self):

        self.classes = load_classes( os.path.join( 'yolo' , 'data', 'coco.names' ) )

        # Set up the neural network

        self.model = Darknet( os.path.join( 'yolo' , 'yolov3-tiny.cfg' ) )
        self.model.load_weights( os.path.join(  'yolo' , 'yolov3-tiny.weights' ) )
        print(' [*] Model Loaded Successfully')

        # set model resolution

        self.model.net_info["height"] = self.reso
        self.inp_dim = int(self.model.net_info["height"])
        
        # YOLO downsamples by a total stride of 32, so the working resolution
        # must be a positive multiple of 32 (e.g. 416 gives a 13x13 grid).
        assert self.inp_dim % 32 == 0
        assert self.inp_dim > 32

        # If there's a GPU available, put the model on GPU
        if self.CUDA:
            self.model.cuda()

        # Set the model in evaluation mode (use batch-norm running statistics
        # rather than per-batch estimates during inference)
        self.model.eval()

    def write(self, x, batches, results, colors=None):
        # Draw one detection onto its source image: x[0] is the image index,
        # x[1:5] the (x1, y1, x2, y2) corners, x[-1] the class index.
        c1 = tuple(int(v) for v in x[1:3])
        c2 = tuple(int(v) for v in x[3:5])
        img = results[int(x[0])]

        cls = int(x[-1])
        label = "{0}".format(self.classes[cls])
        color = random.choice(colors) if colors else (0, 255, 0)
        cv2.rectangle(img, c1, c2, color, 1)
        t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
        c2 = c1[0] + t_size[0] + 3, c1[1] + t_size[1] + 4
        cv2.rectangle(img, c1, c2, color, -1)  # filled background for the label
        cv2.putText(img, label, (c1[0], c1[1] + t_size[1] + 4),
                    cv2.FONT_HERSHEY_PLAIN, 1, [225, 255, 255], 1)
        return img

    def img_to_base64_str(self, img):
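        """Encode a PIL image as a PNG data URI ("data:image/png;base64,...")."""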
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        buffered.seek(0)
        img_byte = buffered.getvalue()
        img_str = "data:image/png;base64," + base64.b64encode(img_byte).decode()
        return img_str


    def predict(self, image):

        imlist = [image]

        # prep_image_org (from yolo.utils) is expected to return, per image, a
        # (letterboxed input tensor, original image, (w, h)) triple.
        batches = list(map(prep_image_org, imlist, [self.inp_dim] * len(imlist)))
        im_batches = [x[0] for x in batches]
        orig_ims = [x[1] for x in batches]
        im_dim_list = [x[2] for x in batches]

        im_dim_list = torch.FloatTensor(im_dim_list).repeat(1,2)
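        # Each row of im_dim_list is now (w, h, w, h), matching the
        # (x1, y1, x2, y2) corner layout of the detections below.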

        if self.CUDA:
            im_dim_list = im_dim_list.cuda()

        batch = im_batches[0]

        if self.CUDA:
            batch = batch.cuda()


        # Forward pass. The network applies the anchor offsets, transforms the
        # predictions as described in the YOLO paper, and flattens the output:
        # B x (bbox coords x n_anchors) x grid_w x grid_h --> B x n_boxes x bbox,
        # i.e. every proposed box becomes a row.
        with torch.no_grad():
            prediction = self.model(batch, self.CUDA)


        # Keep boxes with object confidence above the threshold, convert their
        # coordinates to absolute coordinates, and run NMS. Thresholding and
        # NMS both loop over boxes, so write_results clubs them into a single
        # pass instead of two separate (slow, non-vectorised) loops.

        prediction = write_results(prediction, self.confidence, self.num_classes, nms=True, nms_conf=self.nms_thresh)
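        # The code below relies on each detection row carrying its image index
        # at position 0, the (x1, y1, x2, y2) corners at positions 1..4, and
        # the class index last.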

        output = prediction
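        # write_results in the reference implementation returns a plain int (0)
        # when no detection clears the confidence threshold (an assumption
        # about yolo.utils); bail out early with the unannotated image.
        if isinstance(output, int):
            output_image = Image.fromarray(cv2.cvtColor(orig_ims[0], cv2.COLOR_BGR2RGB))
            return {'image': self.img_to_base64_str(output_image), 'objects': []}, output_image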

        for im_num, _ in enumerate(imlist):
            im_id = im_num
            objs = [self.classes[int(x[-1])] for x in output if int(x[0]) == im_id]
            print("{0:20s} {1:s}".format("Objects Detected:", " ".join(objs)))
            print("----------------------------------------------------------")

        # Map boxes from the padded inp_dim x inp_dim network frame back onto
        # each original image: select that image's (w, h), undo the letterbox
        # padding, divide by the resize scale, then clamp to the image bounds.
        im_dim_list = torch.index_select(im_dim_list, 0, output[:, 0].long())

        scaling_factor = torch.min(self.inp_dim / im_dim_list, 1)[0].view(-1, 1)

        output[:,[1,3]] -= (self.inp_dim - scaling_factor*im_dim_list[:,0].view(-1,1))/2
        output[:,[2,4]] -= (self.inp_dim - scaling_factor*im_dim_list[:,1].view(-1,1))/2

        output[:,1:5] /= scaling_factor

        for i in range(output.shape[0]):
            output[i, [1,3]] = torch.clamp(output[i, [1,3]], 0.0, im_dim_list[i,0])
            output[i, [2,4]] = torch.clamp(output[i, [2,4]], 0.0, im_dim_list[i,1])

        # Load the colour palette bundled with the repo and draw every
        # detection onto its original image.
        with open(os.path.join('yolo', 'pallete'), 'rb') as f:
            colors = pkl.load(f)

        for det in output:
            self.write(det, im_batches, orig_ims, colors=colors)

        # OpenCV images are BGR; convert to RGB before handing off to PIL
        # (assuming the image came in through the usual cv2 pipeline).
        output_image = Image.fromarray(cv2.cvtColor(orig_ims[0], cv2.COLOR_BGR2RGB))

        img_str = self.img_to_base64_str(output_image)

        payload = {'image': img_str, 'objects': objs}

        return payload, output_image
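

if __name__ == '__main__':
    # Minimal usage sketch (added for illustration; 'sample.jpg' is a
    # hypothetical path, and prep_image_org is assumed to accept an
    # already-loaded OpenCV BGR image).
    detector = yolo_model()
    frame = cv2.imread('sample.jpg')
    payload, annotated = detector.predict(frame)
    print('objects:', payload['objects'])
    annotated.save('sample_annotated.png')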