# -*- coding: utf-8 -*-
'''
This script computes the recall scores given the ground-truth annotations and predictions.
'''

import json
import sys
import os
import string
import numpy as np
import time

# NOTE(review): NUM_K is not referenced anywhere in the visible code, which
# hard-codes k = 10 at its call sites. Presumably this is the intended number
# of predicted image_ids per text -- confirm before relying on it.
NUM_K = 10

def read_submission(submit_path, reference, k=5):
    """Load and validate a JSON-lines prediction file.

    Every line must be a JSON object with an integer ``text_id`` and an
    ``image_ids`` list holding exactly ``k`` distinct integers.

    Args:
        submit_path: Path of the submitted prediction file.
        reference: Mapping whose keys are the text_ids that must be predicted.
        k: Required number of predicted image_ids per text.

    Returns:
        Dict mapping each text_id to its list of predicted image_ids. If a
        text_id appears on several lines, the last line wins.

    Raises:
        Exception: If the file is missing, a line is not valid JSON, a line
            violates the schema described above, or a text_id present in
            ``reference`` has no prediction.
    """
    # check whether the path of submitted file exists
    if not os.path.exists(submit_path):
        raise Exception("The submission file is not found!")

    submission_dict = {}
    ref_qids = set(reference.keys())

    with open(submit_path, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            try:
                pred_obj = json.loads(line)
            except (ValueError, RecursionError) as err:
                # json.JSONDecodeError is a ValueError subclass; a bare
                # ``except`` would also swallow KeyboardInterrupt/SystemExit.
                raise Exception('Cannot parse this line into json object: {}'.format(line)) from err
            # A line that parses to a non-object (number, string, list...)
            # cannot carry a text_id, so report it as a schema error too.
            if not isinstance(pred_obj, dict) or "text_id" not in pred_obj:
                raise Exception('There exists one line not containing text_id: {}'.format(line))
            # Bind qid *before* validating it so the error message can show
            # the offending value.
            qid = pred_obj["text_id"]
            if not isinstance(qid, int):
                raise Exception('Found an invalid text_id {}, it should be an integer (not string), please check your schema'.format(qid))
            if "image_ids" not in pred_obj:
                raise Exception('There exists one line not containing the predicted image_ids: {}'.format(line))
            image_ids = pred_obj["image_ids"]
            if not isinstance(image_ids, list):
                raise Exception('The image_ids field of text_id {} is not a list, please check your schema'.format(qid))
            # check whether there are K products for each text
            if len(image_ids) != k:
                raise Exception('Text_id {} has wrong number of predicted image_ids! Require {}, but {} founded.'.format(qid, k, len(image_ids)))
            # check whether there exist an invalid prediction for any text
            for rank, image_id in enumerate(image_ids):
                if not isinstance(image_id, int):
                    raise Exception('Text_id {} has an invalid predicted image_id {} at rank {}, it should be an integer (not string), please check your schema'.format(qid, image_id, rank + 1))
            # check whether there are duplicate predicted products for a single text
            if len(set(image_ids)) != k:
                raise Exception('Text_id {} has duplicate products in your prediction. Pleace check again!'.format(qid))
            submission_dict[qid] = image_ids  # here we save the list of product ids

    # check if any text is missing in the submission
    pred_qids = set(submission_dict.keys())
    nopred_qids = ref_qids - pred_qids
    if len(nopred_qids) != 0:
        raise Exception('The following text_ids have no prediction in your submission, please check again: {}'.format(", ".join([str(idx) for idx in nopred_qids])))

    return submission_dict


def dump_2_json(info, path):
    """Serialise ``info`` as JSON into the UTF-8 text file at ``path``.

    The file is opened in write mode, so any existing content is replaced.
    """
    with open(path, mode='w', encoding="utf-8") as sink:
        json.dump(info, sink)


def report_error_msg(detail, showMsg, out_p):
    """Write a failure report to ``out_p`` as JSON.

    The report carries ``detail`` and ``showMsg`` verbatim, a score of 0,
    an empty score breakdown and ``success`` set to False.
    """
    failure_report = {
        'errorDetail': detail,
        'errorMsg': showMsg,
        'score': 0,
        'scoreJson': {},
        'success': False,
    }
    dump_2_json(failure_report, out_p)


def report_score(r1, r5, r10, out_p):
    """Write a success report to ``out_p`` as JSON.

    ``r1``, ``r5`` and ``r10`` are multiplied by 100 for the report; their
    arithmetic mean (also multiplied by 100) is stored as the overall score.
    """
    mean_recall = (r1 + r5 + r10) / 3.0
    overall = mean_recall * 100
    breakdown = {
        'score': overall,
        'mean_recall': overall,
        'r1': r1 * 100,
        'r5': r5 * 100,
        'r10': r10 * 100,
    }
    dump_2_json({'success': True, 'score': overall, 'scoreJson': breakdown}, out_p)


def read_reference(path):
    """Load the ground-truth annotations from a JSON-lines file.

    Each line is parsed as a JSON object; its ``text_id`` becomes a key of the
    result and its ``image_ids`` the corresponding value.

    Args:
        path: Path of the UTF-8 encoded ground-truth file.

    Returns:
        Dict mapping text_id to the ``image_ids`` value stored for it.
    """
    reference = dict()
    # Use a context manager so the handle is closed even if parsing fails.
    with open(path, encoding="utf-8") as fin:
        for line in fin:
            obj = json.loads(line.strip())
            reference[obj['text_id']] = obj['image_ids']
    return reference

def compute_score(golden_file, predict_file):
    """Compute recall@1, recall@5, recall@10 and their mean, as percentages.

    A text counts as a hit at cutoff N when at least one of its ground-truth
    image ids appears among its first N predicted ids. Predictions are read
    with ``k`` fixed to 10.

    Returns:
        ``[mean_recall, r1, r5, r10]``, each multiplied by 100.
    """
    # ground truth first, then the predictions validated against it
    reference = read_reference(golden_file)
    k = 10
    predictions = read_submission(predict_file, reference, k)

    # one hit counter per cutoff, in the order r1, r5, r10
    cutoffs = (1, 5, 10)
    hit_counts = [0, 0, 0]
    for qid in reference:
        ground_truth_ids = set(reference[qid])
        ranked_ids = predictions[qid]
        for slot, cutoff in enumerate(cutoffs):
            head = ranked_ids[:cutoff]
            if any(idx in head for idx in ground_truth_ids):
                hit_counts[slot] += 1

    # the higher score, the better
    total = len(reference)
    r1, r5, r10 = [count * 1.0 / total for count in hit_counts]
    mean_recall = (r1 + r5 + r10) / 3.0
    return [score * 100 for score in (mean_recall, r1, r5, r10)]


if __name__=="__main__":
    # Command line: <answer jsonl> <prediction jsonl> <output json>
    # the path of answer json file (eg. test_queries_answers.jsonl)
    standard_path = sys.argv[1]
    # the path of prediction file (eg. example_pred.jsonl)
    submit_path = sys.argv[2]
    # the score will be dumped into this output json file
    out_path = sys.argv[3]

    print("Read standard from %s" % standard_path)
    print("Read user submit file from %s" % submit_path)

    try:
        # read ground-truth
        reference = read_reference(standard_path)

        # read predictions
        k = 10
        predictions = read_submission(submit_path, reference, k)

        # compute score for each text: a text is a hit at cutoff N when any of
        # its ground-truth ids appears among its first N predictions
        r1_stat, r5_stat, r10_stat = 0, 0, 0
        for qid in reference:
            ground_truth_ids = set(reference[qid])
            top10_pred_ids = predictions[qid]
            if any(idx in top10_pred_ids[:1] for idx in ground_truth_ids):
                r1_stat += 1
            if any(idx in top10_pred_ids[:5] for idx in ground_truth_ids):
                r5_stat += 1
            if any(idx in top10_pred_ids[:10] for idx in ground_truth_ids):
                r10_stat += 1
        # the higher score, the better
        r1, r5, r10 = r1_stat * 1.0 / len(reference), r5_stat * 1.0 / len(reference), r10_stat * 1.0 / len(reference)
        report_score(r1, r5, r10, out_path)
        print("The evaluation finished successfully.")
    except Exception as e:
        # str(e) is always a JSON-serialisable string. e.args[0] would raise
        # IndexError for an exception without args and is only the errno
        # integer for OSError (e.g. an unreadable reference file).
        error_text = str(e)
        report_error_msg(error_text, error_text, out_path)
        print("The evaluation failed: {}".format(error_text))