# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""

import json
import os
import time
from functools import partial

import numpy as np
import torch

from megatron import get_args
from megatron import print_rank_last, is_last_rank
from megatron import mpu
from megatron.schedules import get_forward_backward_func
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch
from tasks.label_dict import get_label_dict
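
# A hedged usage sketch (the names below stand in for whatever the calling
# task provides): a task script typically builds a dataset provider and passes
# the resulting callback to its finetuning driver, which invokes it
# periodically during training:
#
#     metrics_func = accuracy_func_provider(single_dataset_provider)
#     metrics_func(model, step)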

def accuracy_func_provider(single_dataset_provider):
    """Provide function that calculates accuracies."""
    args = get_args()

    # Build dataloaders.
    datapaths = [args.valid_data[0], args.test_data[0]]
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
        dataloader = build_data_loader(
            dataset, args.micro_batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))
    
    def _generate_prediction_json(predictions, step, save_acc):
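        """Write one JSON object per test example to
        <res_path>/<task>_prediction_<step>_<save_acc>.json.

        Each output line has the form (label strings depend on the task's
        label dict):

            {"id": 0, "label": "<predicted label string>"}

        Ids are shifted down by one when the dataset numbers examples from 1.
        """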
        
        probs_list = predictions[0]
        # predictions[1] holds the gold labels; they are not needed when writing predictions.
        ids_list = predictions[2]
        min_id = min(ids_list)
        max_id = max(ids_list)
        LABELS = get_label_dict(args.task, write2file=True)
        output_submit_file = os.path.join(
            args.res_path[0],
            "{}_prediction_{}_{}.json".format(args.task, step, save_acc))
        with open(output_submit_file, "w") as writer:
            for i in range(min_id, max_id + 1):
                label_index = ids_list.index(i)
                pred_prob_list = probs_list[label_index]
                label = pred_prob_list.index(max(pred_prob_list))
                json_d = {}
                if min_id == 1:
                    json_d['id'] = i - 1
                else:
                    json_d['id'] = i
                json_d["label"] = LABELS[str(label)]
                writer.write(json.dumps(json_d) + '\n')

    def _generate_prediction_prob(predictions, step, save_acc):

        probs_list = predictions[0]
        ids_list = predictions[2]
        min_id = min(ids_list)
        max_id = max(ids_list)

        output_prob_file = os.path.join(
            args.res_path[0],
            "{}_prob_{}_{}".format(args.task, step, save_acc))
        prob_arr = []
        for i in range(min_id, max_id + 1):
            label_index = ids_list.index(i)
            prob_arr.append(probs_list[label_index])
        prob_arr = np.array(prob_arr)
        np.save(output_prob_file, prob_arr)
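        # A hedged loading sketch for the array written above (np.save appends
        # ".npy" to the path):
        #
        #     probs = np.load(output_prob_file + ".npy")  # shape: (num_examples, num_labels)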

    def metrics_func(model, step):
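        """Evaluate `model` on the validation and test dataloaders.

        Accuracy is accumulated over the validation split; for the test split
        (the second dataloader) per-example predictions and probabilities are
        written to disk instead, tagged with the validation accuracy.
        """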
        print_rank_last('calculating metrics ...')
        correct = 0
        total = 0

        for index, (name, dataloader) in enumerate(dataloaders):
            if index == 1:
                output_predictions = True
                assert mpu.get_data_parallel_world_size() == 1
                named_predictions = []
                names = 'predictions'
            else:
                output_predictions = False

            output = calculate_correct_answers(name, model, dataloader,
                                               step, output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
                correct += correct_ans
                total += total_count
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            # Validation accuracy (scaled to an integer in [0, 10000]) used to
            # tag the prediction files.
            save_acc = str(int(round(correct / total * 10000)))

            if output_predictions:
                print_rank_last('generating prediction files ...')
                _generate_prediction_json(predictions, step, save_acc)
                _generate_prediction_prob(predictions, step, save_acc)
                print_rank_last('prediction files written.')
        # if is_last_rank():
        #     percent = float(correct) * 100.0 / float(total)
        #     print(' >> |step: {}| overall: correct / total = {} / {} = '
        #           '{:.4f} %'.format(step, correct, total, percent))
        # if output_predictions and is_last_rank():
        #     assert args.load is not None
        #     filename = os.path.join(args.load, names + '.pt')
        #     torch.save(named_predictions, filename)

    return metrics_func


def calculate_correct_answers(name, model, dataloader,
                              step, output_predictions):
    """Calculate correct over total answers and return prediction if the
    `output_predictions` is true."""
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # If the dataset has a sample_multiplier attribute, each "sample"
        # from the dataset actually contains multiple samples that collapse
        # into the batch dimension (for example, the RACE dataset has several
        # options per question), so we need to account for that when setting
        # the micro batch size.
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
    num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel
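    # Example with illustrative numbers: orig_global_batch_size=32,
    # orig_micro_batch_size=4, data_parallel_size=2
    # -> num_micro_batches = 32 // (4 * 2) = 4.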

    def loss_func(output_predictions, labels, output_tensor):
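        """Convert the model output into evaluation statistics.

        Returns a dummy scalar loss (0, since evaluation never backpropagates)
        and a dict with 'total' and 'correct' counts; when `output_predictions`
        is set, the dict also carries softmaxes, labels and example ids (the
        ids are read from the enclosing loop's `batch['uid']`).
        """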
        logits = output_tensor

        loss_dict = {}
        # Add output predictions.
        if output_predictions:
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    # defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
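        """Forward step handed to forward_backward_func: run the model on one
        micro-batch and return the output tensor together with `loss_func`
        partially applied with `output_predictions` and the labels."""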
        # `batch` may be either an iterator over micro-batches or a single
        # micro-batch dict; handle both.
        try:
            batch_ = next(batch)
        except (StopIteration, TypeError):
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        return output_tensor, partial(loss_func, output_predictions, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # In evaluation-only mode we use drop_last = False to get all the
            # samples, which means the last batch might not be full, so adjust
            # the batch size here to the actual batch size of the data ...
            actual_batch_size = len(batch['label'])
            # ... applying the sample_multiplier if necessary.
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches

            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            for loss_dict in loss_dicts:
                if output_predictions:
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])
                total += loss_dict['total']
                correct += loss_dict['correct']


    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.

        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        if not output_predictions:
            print_rank_last(' > |step: {} | metrics for {}: correct / total '
                            '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                                step, name, correct_ans, total_count,
                                percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0