import json
import logging
import math
import os
import time
from contextlib import suppress

import numpy as np
import torch
import torch.nn.functional as F

try:
    import wandb
except ImportError:
    wandb = None

from open_clip import ClipLoss, gather_features
from .distributed import is_master
from .zero_shot import zero_shot_eval


class AverageMeter(object):
    """Keep the most recent value together with a weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every tracked statistic to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count


def unwrap_model(model):
    """Return ``model.module`` when that attribute exists, otherwise ``model`` itself."""
    return getattr(model, "module", model)


def train_one_epoch(
    model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
):
    """Train ``model`` for one pass over ``data["train"].dataloader``.

    Args:
        model: Called as ``model(audios, texts, device)``; must return the
            6-tuple ``(audio_features, text_features, audio_features_mlp,
            text_features_mlp, logit_scale_a, logit_scale_t)``.
        data: Mapping whose ``"train"`` entry exposes ``dataloader`` and
            ``sampler``.
        epoch: Current epoch index; used for the sampler epoch and to derive
            the global step (``num_batches * epoch + batch_index``).
        optimizer: A single optimizer, or a dict of optimizers that are all
            stepped on every batch.
        scaler: Gradient scaler used for loss scaling, or ``None`` to run a
            plain ``backward()`` / ``step()``.
        scheduler: A callable taking the global step, or a dict of such
            callables.
        args: Run configuration (device, precision, loss options, logging
            switches, distributed settings).
        tb_writer: Optional writer exposing ``add_scalar(tag, value, step)``.
    """
    device = torch.device(args.device)
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    model.train()
    loss = ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
        mlp_loss=args.clap_mlploss,
        weight_loss_kappa=args.kappa,
    )

    dataloader, sampler = data["train"].dataloader, data["train"].sampler
    if args.distributed and sampler is not None:
        sampler.set_epoch(epoch)
    num_batches_per_epoch = dataloader.num_batches
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    # for toy dataset
    if args.dataset_type == "toy":
        dataloader.dataset.generate_queue()

    # Normalise single vs. dict-of-many so the loop below has one code path.
    multi_optimizer = isinstance(optimizer, dict)
    optimizers = list(optimizer.values()) if multi_optimizer else [optimizer]
    optimizer_names = list(optimizer.keys()) if multi_optimizer else None
    schedulers = (
        list(scheduler.values()) if isinstance(scheduler, dict) else [scheduler]
    )

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    for i, batch in enumerate(dataloader):
        step = num_batches_per_epoch * epoch + i
        for schedule in schedulers:
            schedule(step)
        audios = batch  # contains mel_spec, wavform, and longer list
        texts = batch["text"]

        data_time_m.update(time.time() - end)
        for opt in optimizers:
            opt.zero_grad()

        with autocast():
            (
                audio_features,
                text_features,
                audio_features_mlp,
                text_features_mlp,
                logit_scale_a,
                logit_scale_t,
            ) = model(audios, texts, device)

            if args.clap_mlploss:
                total_loss = loss(
                    audio_features=audio_features,
                    text_features=text_features,
                    logit_scale_a=logit_scale_a,
                    logit_scale_t=logit_scale_t,
                    audio_features_mlp=audio_features_mlp,
                    text_features_mlp=text_features_mlp,
                )
            else:
                total_loss = loss(
                    audio_features=audio_features,
                    text_features=text_features,
                    logit_scale_a=logit_scale_a,
                )

        if scaler is not None:
            scaler.scale(total_loss).backward()
            for opt in optimizers:
                if args.horovod:
                    # Horovod: finish the allreduce before unscaling, then
                    # step without triggering a second synchronisation.
                    opt.synchronize()
                    scaler.unscale_(opt)
                    with opt.skip_synchronize():
                        scaler.step(opt)
                else:
                    scaler.step(opt)
            scaler.update()
        else:
            total_loss.backward()
            for opt in optimizers:
                opt.step()

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).logit_scale_a.clamp_(0, math.log(100))
            if args.clap_mlploss:
                unwrap_model(model).logit_scale_t.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1
        if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
            if isinstance(audios, dict):
                batch_size = len(audios["waveform"])
            else:
                batch_size = len(audios)
            num_samples = batch_count * batch_size * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE loss is coarsely sampled, just master node and per log update
            loss_m.update(total_loss.item(), batch_size)
            logit_scale_scalar_a = logit_scale_a.item()
            logit_scale_scalar_t = logit_scale_t.item()

            if multi_optimizer:
                lr = [opt.param_groups[0]["lr"] for opt in optimizers]
                lr_text = f"{lr}"
            else:
                lr = optimizer.param_groups[0]["lr"]
                lr_text = f"{lr:5f}"

            message = (
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f} "
                f"LR: {lr_text} "
                f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
            )
            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "loss": loss_m.val,
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "scale_audio": logit_scale_scalar_a,
            }
            if args.clap_mlploss:
                # Leading space keeps the two scale fields from running together.
                message += f" Logit Scale Text: {logit_scale_scalar_t:.3f}"
                log_data["scale_text"] = logit_scale_scalar_t
            log_data["lr"] = lr
            logging.info(message)

            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    if isinstance(val, list):
                        # add_scalar needs a scalar: emit one tag per optimizer.
                        for opt_name, opt_val in zip(optimizer_names, val):
                            tb_writer.add_scalar(f"{name}/{opt_name}", opt_val, step)
                    else:
                        tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, "Please install wandb."
                    wandb.log({name: val, "step": step})

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()
    # end for


def _dataset_tag_from_url(url):
    """Join the third- and second-to-last "/"-separated parts of ``url`` with "-"."""
    return "-".join(url.split("/")[-3:-1])


def _empty_eval_record(with_mlp):
    """Return a fresh feature accumulator; MLP feature lists are added when ``with_mlp``."""
    record = {
        "cumulative_loss": 0.0,
        "num_samples": 0,
        "all_audio_features": [],
        "all_text_features": [],
    }
    if with_mlp:
        record["all_audio_features_mlp"] = []
        record["all_text_features_mlp"] = []
    return record


def evaluate(model, data, epoch, args, tb_writer=None):
    """Evaluate ``model`` on the validation data and log/persist the metrics.

    Two paths exist:

    * ``args.val_dataset_names == ["Clotho", "audiocaps"]`` delegates to
      ``evaluate_clotho_audiocaps`` (parallel evaluation is rejected there).
    * Otherwise, when ``"val"`` is in ``data`` and the epoch matches
      ``args.val_frequency`` (or is the final epoch), features are collected
      per dataset tag (derived from each sample's ``__url__``) plus an
      aggregate ``"all"`` bucket, and scored with ``get_metrics``.

    Args:
        model: Called as ``model(audios, texts, device)`` returning the
            6-tuple of features and logit scales.
        data: Mapping that may contain a ``"val"`` entry exposing
            ``dataloader``.
        epoch: Epoch number recorded alongside the metrics.
        args: Run configuration.
        tb_writer: Optional writer exposing ``add_scalar``.

    Returns:
        dict: Metric name -> value. Empty on non-master ranks, when no
        evaluation was due, or when the validation loader yielded nothing.

    Raises:
        NotImplementedError: Clotho/audiocaps evaluation requested together
            with ``args.parallel_eval``.
    """
    metrics = {}
    if not args.parallel_eval and not is_master(args):
        return metrics
    device = torch.device(args.device)
    model.eval()

    # CHANGE
    # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
    # metrics.update(zero_shot_metrics)
    if is_master(args):
        print("Evaluating...")
    autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
    val_metrics_per_dataset = {}
    if args.val_dataset_names == ["Clotho", "audiocaps"]:
        # if only clotho and audiocaps are used, then we will use a different evaluation function.
        # This is because in the Clotho and audiocaps valid and test set, there are 5 text for 1 audio.
        if args.parallel_eval:
            # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps.
            raise NotImplementedError(
                "Parallel evaluation not supported for eval only Clotho and audiocaps."
            )
        val_metrics_per_dataset = evaluate_clotho_audiocaps(
            model, data, epoch, args, autocast, device, tb_writer
        )
        for m in val_metrics_per_dataset.values():
            metrics.update(m)
        metrics.setdefault("epoch", epoch)
        metrics = select_top_metric_clotho_audiocaps(
            metrics, val_metrics_per_dataset, args
        )
    elif "val" in data and (
        args.val_frequency
        and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
    ):
        dataloader = data["val"].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples
        use_mlp = args.clap_mlploss

        # FIXME this does not scale past small eval datasets
        # all_audio_features @ all_text_features will blow up memory and compute very quickly
        eval_info = {"all": _empty_eval_record(use_mlp)}
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                audios = batch  # contains mel_spec, wavform, and longer list
                texts = batch["text"]

                # Derive the per-sample dataset tags once per batch.
                batch_tags = [_dataset_tag_from_url(u) for u in batch["__url__"]]
                batch_tags_arr = np.array(batch_tags)
                all_names = list(set(batch_tags))
                for name in all_names:
                    if name not in eval_info:
                        eval_info[name] = _empty_eval_record(use_mlp)

                with autocast():
                    (
                        audio_features,
                        text_features,
                        audio_features_mlp,
                        text_features_mlp,
                        logit_scale_a,
                        logit_scale_t,
                    ) = model(audios, texts, device)

                    if args.parallel_eval:
                        # multi-GPU eval
                        if use_mlp:
                            (
                                audio_features,
                                text_features,
                                audio_features_mlp,
                                text_features_mlp,
                            ) = gather_features(
                                audio_features=audio_features,
                                text_features=text_features,
                                audio_features_mlp=audio_features_mlp,
                                text_features_mlp=text_features_mlp,
                                local_loss=False,
                                gather_with_grad=False,
                                rank=args.rank,
                                world_size=args.world_size,
                                use_horovod=args.horovod,
                                mlp_loss=args.clap_mlploss,
                            )
                        else:
                            (audio_features, text_features,) = gather_features(
                                audio_features=audio_features,
                                text_features=text_features,
                                local_loss=False,
                                gather_with_grad=False,
                                rank=args.rank,
                                world_size=args.world_size,
                                use_horovod=args.horovod,
                                mlp_loss=args.clap_mlploss,
                            )

                    if is_master(args):
                        num_samples += audio_features.shape[0]
                        # Move each feature tensor to CPU once and reuse it
                        # for the aggregate and every per-dataset bucket.
                        batch_features = {
                            "all_audio_features": audio_features.cpu(),
                            "all_text_features": text_features.cpu(),
                        }
                        if use_mlp:
                            batch_features[
                                "all_audio_features_mlp"
                            ] = audio_features_mlp.cpu()
                            batch_features[
                                "all_text_features_mlp"
                            ] = text_features_mlp.cpu()

                        for key, feats in batch_features.items():
                            eval_info["all"][key].append(feats)

                        # NOTE(review): batch["__url__"] describes the local
                        # batch only, whereas under parallel_eval the features
                        # above have been gathered; confirm that indexing the
                        # gathered tensors with local positions is intended.
                        for n in all_names:
                            idx = torch.tensor(
                                np.where(batch_tags_arr == n)[0]
                            ).long()
                            for key, feats in batch_features.items():
                                eval_info[n][key].append(feats.index_select(0, idx))

                if is_master(args) and (i % 100) == 0:  # and i != 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
                    )
            if is_master(args):
                if not eval_info["all"]["all_audio_features"]:
                    # No batch was processed: there are no features to
                    # concatenate and no logit scale to score with.
                    logging.warning(
                        f"Eval Epoch: {epoch} validation dataloader yielded no batches; "
                        "skipping metric computation."
                    )
                else:
                    for n, record in eval_info.items():
                        metric_kwargs = {
                            "audio_features": torch.cat(record["all_audio_features"]),
                            "text_features": torch.cat(record["all_text_features"]),
                            "logit_scale_a": logit_scale_a.cpu(),
                            "mlp_loss": args.clap_mlploss,
                        }
                        if use_mlp:
                            metric_kwargs.update(
                                audio_features_mlp=torch.cat(
                                    record["all_audio_features_mlp"]
                                ),
                                text_features_mlp=torch.cat(
                                    record["all_text_features_mlp"]
                                ),
                                logit_scale_t=logit_scale_t.cpu(),
                            )
                        metrics_single_dataset = get_metrics(**metric_kwargs)
                        val_metrics_per_dataset[n] = {
                            n + "/" + k: v for k, v in metrics_single_dataset.items()
                        }
                        metrics.update(val_metrics_per_dataset[n])
                        metrics.setdefault("epoch", epoch)

    if not is_master(args) or not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\n".join(
            [
                "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()])
                for m in val_metrics_per_dataset.values()
            ]
        )
    )

    if args.save_logs:
        if tb_writer is not None:
            for name, val in metrics.items():
                tb_writer.add_scalar(f"val/{name}", val, epoch)

        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, "Please install wandb."
        for name, val in metrics.items():
            wandb.log({f"val/{name}": val, "epoch": epoch})

    return metrics


def get_metrics(
    audio_features,
    text_features,
    logit_scale_a,
    audio_features_mlp=None,
    text_features_mlp=None,
    logit_scale_t=None,
    mlp_loss=False,
):
    """Score paired audio/text features with a contrastive loss and retrieval metrics.

    Row ``i`` of the audio features is treated as the match of row ``i`` of
    the text features. With ``mlp_loss`` the two cross-modal similarity
    matrices (audio vs. MLP text, MLP audio vs. text) are averaged for
    ranking and contribute four cross-entropy terms to the loss; otherwise a
    single similarity matrix with two terms is used.

    Returns:
        dict with ``cumulative_loss``, ``num_samples`` and, for each of
        ``audio_to_text`` / ``text_to_audio``: ``mean_rank``, ``median_rank``,
        ``R@1``, ``R@5``, ``R@10`` and ``mAP@10``.
    """
    sample_count = audio_features.shape[0]
    labels = torch.arange(sample_count).long()

    def _similarity_pair(scale, lhs, rhs):
        # Scaled similarity matrix and its transpose, both detached on CPU.
        per_audio = (scale * lhs @ rhs.t()).detach().cpu()
        return per_audio, per_audio.t().detach().cpu()

    if mlp_loss:
        a_per_audio, a_per_text = _similarity_pair(
            logit_scale_a, audio_features, text_features_mlp
        )
        t_per_audio, t_per_text = _similarity_pair(
            logit_scale_t, audio_features_mlp, text_features
        )
        # 2x2 combined cross-entropy: four terms instead of two.
        total_loss = (
            F.cross_entropy(a_per_audio, labels)
            + F.cross_entropy(a_per_text, labels)
            + F.cross_entropy(t_per_audio, labels)
            + F.cross_entropy(t_per_text, labels)
        ) / 4
        logits = {
            "audio_to_text": (a_per_audio + t_per_audio) / 2,
            "text_to_audio": (a_per_text + t_per_text) / 2,
        }
    else:
        per_audio, per_text = _similarity_pair(
            logit_scale_a, audio_features, text_features
        )
        total_loss = (
            F.cross_entropy(per_audio, labels) + F.cross_entropy(per_text, labels)
        ) / 2
        logits = {"audio_to_text": per_audio, "text_to_audio": per_text}

    metrics = {
        "cumulative_loss": total_loss.item(),
        "num_samples": sample_count,
    }
    ground_truth = torch.arange(len(text_features)).view(-1, 1)

    for direction, scores in logits.items():
        order = torch.argsort(scores, descending=True)
        # Column at which each row's ground-truth index appears, i.e. its
        # 0-based rank. (yusong) this line is slow because it uses single thread
        ranks = torch.where(order == ground_truth)[1].detach().cpu().numpy()
        metrics[f"{direction}_mean_rank"] = ranks.mean() + 1
        metrics[f"{direction}_median_rank"] = np.floor(np.median(ranks)) + 1
        for k in (1, 5, 10):
            metrics[f"{direction}_R@{k}"] = np.mean(ranks < k)
        # map@10
        metrics[f"{direction}_mAP@10"] = np.mean(
            np.where(ranks < 10, 1 / (ranks + 1), 0.0)
        )

    return metrics


def evaluate_clotho_audiocaps(
    model, data, epoch, args, autocast, device, tb_writer=None
):
    """
    Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py.
    1. for text-to-audio retrieval, do 5 times and average the results
    2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text
    3. for map@10 in audio-to-text retrieval:
        3.1: sort the rank of 5 text
        3.2: exclude the rank >=10 (0-index)
        3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks).
        (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth.
        (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc.

    Args:
        model: Called as ``model(audios, None, device)`` for audio features,
            ``model(None, texts, device)`` for text features, and
            ``model(None, None, device)`` for the two logit scales.
        data: Mapping whose ``"val"`` entry exposes ``dataloader``; each batch
            provides ``"full_text"`` (5 captions per audio) and ``"__url__"``.
        epoch: Unused here; kept for signature compatibility.
        args: Run configuration; ``args.tmodel`` selects the tokenizer.
        autocast: Context-manager factory wrapped around the forward passes.
        device: Device handed to ``model``.
        tb_writer: Unused here; kept for signature compatibility.

    Returns:
        dict: dataset tag -> {"<tag>/<metric>": value}.
    """
    # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now.
    texts_per_audio = 5
    dataloader = data["val"].dataloader
    with torch.no_grad():
        eval_info = {}
        for batch in dataloader:
            audios = batch  # contains mel_spec, wavform, and longer list

            # each item in the list has 5 texts
            if args.tmodel == "transformer":
                from open_clip import tokenize

                texts = torch.cat([tokenize(t) for t in batch["full_text"]])
            else:
                from .data import tokenizer

                tokenized = [
                    tokenizer(t) for t in batch["full_text"]
                ]  # 5 texts for each audio
                texts = {
                    k: torch.cat([t[k] for t in tokenized])
                    for k in tokenized[0].keys()
                }  # 5 x batch

            # Dataset tag of every sample, computed once per batch.
            batch_tags = ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]
            batch_tags_arr = np.array(batch_tags)
            all_names = list(set(batch_tags))
            for name in all_names:
                if name not in eval_info:
                    # we will not use mlp outputs even if args.clap_mlploss=True
                    eval_info[name] = {
                        "cumulative_loss": 0.0,
                        "num_samples": 0,
                        "all_audio_features": [],
                        "all_text_features": [],
                    }
            with autocast():
                audio_features = model(audios, None, device)
                text_features = model(None, texts, device)
                audio_features = F.normalize(audio_features, dim=-1)
                text_features = F.normalize(text_features, dim=-1)

                feature_dim = text_features.shape[1]
                audio_cpu = audio_features.cpu()
                # Group the 5 consecutive caption rows of each audio so that
                # one index in idx selects all 5 captions at once.
                text_cpu_grouped = text_features.cpu().reshape(
                    [-1, texts_per_audio, feature_dim]
                )
                for n in all_names:
                    idx = torch.tensor(np.where(batch_tags_arr == n)[0]).long()
                    eval_info[n]["all_audio_features"].append(
                        audio_cpu.index_select(0, idx)
                    )
                    eval_info[n]["all_text_features"].append(
                        text_cpu_grouped.index_select(0, idx).reshape(
                            [-1, feature_dim]
                        )
                    )

        val_metrics_all = {}

        if eval_info:
            # The logit scale does not depend on the dataset: fetch it once.
            logit_scale_a, logit_scale_t = model(None, None, device)
            logit_scale_a = logit_scale_a.cpu()

        for n, record in eval_info.items():
            audio_features = torch.cat(record["all_audio_features"], dim=0)
            text_features = torch.cat(record["all_text_features"], dim=0)

            logits_per_audio = (
                (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
            )
            logits_per_text = logits_per_audio.t().detach().cpu()

            # logits_per_audio shape: [num_samples, num_samples*5]
            # logits_per_text shape: [num_samples*5, num_samples]

            logging.info(
                f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, "
                f"logits_per_text shape: {logits_per_text.shape}"
            )

            metrics = {}
            num_samples = audio_features.shape[0]
            metrics["num_samples"] = num_samples

            # (yusong) the following code is very important, please double-check:
            # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d]
            # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
            # Those two are retrieving one of the 5 text for each audio.
            audio_view = logits_per_audio.reshape(
                num_samples, num_samples, texts_per_audio
            )
            text_view = logits_per_text.reshape(
                num_samples, texts_per_audio, num_samples
            )
            labels = torch.arange(num_samples).long()
            audio_to_text_loss = torch.stack(
                [
                    F.cross_entropy(audio_view[:, :, d], labels)
                    for d in range(texts_per_audio)
                ]
            ).mean()
            text_to_audio_loss = torch.stack(
                [
                    F.cross_entropy(text_view[:, d, :], labels)
                    for d in range(texts_per_audio)
                ]
            ).mean()
            total_loss = (audio_to_text_loss + text_to_audio_loss) / 2

            metrics["cumulative_loss"] = total_loss.item()

            # text to audio: do 5 times
            pred_text = []
            for d in range(texts_per_audio):
                logit = text_view[:, d, :]
                ground_truth = torch.arange(len(logit)).view(-1, 1)
                ranking = torch.argsort(
                    logit, descending=True
                )  # [num_samples, num_samples]
                preds = torch.where(ranking == ground_truth)[1]
                pred_text.append(preds.detach().cpu().numpy())
            pred_text_concat = np.concatenate(pred_text, axis=0)  # [5*num_samples]
            metrics["text_to_audio_mean_rank"] = pred_text_concat.mean() + 1
            metrics["text_to_audio_median_rank"] = (
                np.floor(np.median(pred_text_concat)) + 1
            )
            for k in [1, 5, 10]:
                metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k)
            # map@10
            metrics["text_to_audio_mAP@10"] = np.mean(
                np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0)
            )

            # audio to text: take the best result
            # for audio to text map 10, sort and assign descending ground truth.
            # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103
            # map@10
            map_all = []
            pred_audio_all = []
            for audio_idx in range(num_samples):
                # logits_per_audio: [num_samples, num_samples*5]
                logit_single = logits_per_audio[audio_idx, :]  # [5*num_samples]
                ranking = torch.argsort(
                    logit_single, descending=True
                )  # [5*num_samples]
                # Ground-truth caption indices: [5*a, 5*a+1, ..., 5*a+4]
                ground_truth = torch.arange(
                    audio_idx * texts_per_audio,
                    audio_idx * texts_per_audio + texts_per_audio,
                ).view(-1, 1)
                # One row per caption; the matching column is that caption's
                # 0-based rank. Rows come back in caption order, NOT rank order.
                all_pred = torch.where(ranking.unsqueeze(0) == ground_truth)[1]
                pred_audio_all.append(torch.min(all_pred).item())
                # Step 3.1/3.2: keep ranks < 10 and sort them ascending so the
                # i-th retained hit is paired with numerator i in precision@k.
                all_pred_filter = np.sort(
                    all_pred[all_pred < 10].detach().cpu().numpy()
                )
                # /5 because we have 5 text, so it means for the text rank >=10 we count as 0.
                map_single = (
                    np.sum(
                        np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1)
                    )
                    / texts_per_audio
                )
                map_all.append(map_single)
            metrics["audio_to_text_mAP@10"] = np.mean(map_all)
            for k in [1, 5, 10]:
                metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k)

            val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()}
    return val_metrics_all


def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset):
    """
    Compute the model-selection score for Clotho+AudioCaps.

    For every dataset name ``n`` in ``val_metrics_per_dataset``, the entries
    ``"{n}/audio_to_text_mAP@10"`` and ``"{n}/text_to_audio_mAP@10"`` are averaged.
    The returned value is the mean (via ``np.mean``) of those per-dataset averages.
    A missing metric key raises ``KeyError``.
    """
    per_dataset_scores = [
        (
            dataset_metrics[f"{name}/audio_to_text_mAP@10"]
            + dataset_metrics[f"{name}/text_to_audio_mAP@10"]
        )
        / 2
        for name, dataset_metrics in val_metrics_per_dataset.items()
    ]
    return np.mean(per_dataset_scores)


def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args):
    """
    Merge the best-so-far Clotho+AudioCaps validation metrics into ``metrics``.

    Arguments:
        metrics: dict mapping metric name to value; must contain ``"epoch"`` whenever
            a new best is recorded. It is updated in place and returned.
        val_metrics_per_dataset: dict mapping dataset name to a dict of
            ``"<dataset>/<metric>"`` -> value for the current evaluation.
        args: object used (as a hack) to persist state between calls through the
            attributes ``top_selection_performance`` and ``top_metric``.

    On the first call (``args`` has no ``top_selection_performance``) or when the
    current selection score is strictly greater than the stored one, a snapshot of the
    current metrics is stored on ``args`` and merged into ``metrics``. Snapshot keys are
    ``"<dataset>-top/<metric>"`` plus ``"top_selection_performance"`` and
    ``"top-selection-epoch"``. Otherwise the previously stored snapshot is merged.
    """
    current_score = calculate_selection_performance_clotho_audiocaps(
        val_metrics_per_dataset
    )

    # Guard clause: a best already exists and this evaluation does not beat it.
    has_previous_best = hasattr(args, "top_selection_performance")
    if has_previous_best and not (current_score > args.top_selection_performance):
        metrics.update(args.top_metric)
        return metrics

    # New best (or very first evaluation): snapshot every per-dataset metric under
    # a "<dataset>-top/<metric>" key built from the first two "/"-separated segments.
    best_snapshot = {}
    for dataset_metrics in val_metrics_per_dataset.values():
        for full_name, value in dataset_metrics.items():
            segments = full_name.split("/")
            best_snapshot[f"{segments[0]}-top/{segments[1]}"] = value
    best_snapshot["top_selection_performance"] = current_score
    best_snapshot["top-selection-epoch"] = metrics["epoch"]

    metrics.update(best_snapshot)
    args.top_metric = best_snapshot
    args.top_selection_performance = current_score
    return metrics