import functools
import glob

import gradio as gr
import matplotlib.pyplot as plt
import timm
import torch
from timm import create_model
from timm.models.layers import PatchEmbed
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.transforms import functional as F

# Pretrained `cait_xxs24_224.fb_dist_in1k` checkpoint, put into inference
# mode once at import time (`.eval()` mutates the module in place).
CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True)
CAIT_MODEL.eval()

# Preprocessing pipeline derived from the checkpoint's own pretrained config,
# so inputs are prepared the same way the model expects.
_CAIT_DATA_CONFIG = timm.data.resolve_data_config(CAIT_MODEL.pretrained_cfg)
TRANSFORM = timm.data.create_transform(**_CAIT_DATA_CONFIG)

# Side length (pixels) of the square patches the image is tokenized into.
# Assumed to match the model's patch embedding (224 px input -> 14x14 grid).
PATCH_SIZE = 16


@functools.lru_cache(maxsize=8)
def create_attn_extractor(block_id=0):
    """Create (or reuse) a model that produces the softmax attention scores.

    The extractor is an FX-traced view of ``CAIT_MODEL`` whose only output is
    the ``blocks_token_only.{block_id}.attn.softmax`` node. Tracing the model
    is expensive, so one extractor per ``block_id`` is cached and reused
    across calls instead of being rebuilt for every request.

    Args:
        block_id: Index into ``CAIT_MODEL.blocks_token_only``. It is passed
            through ``int()`` so numeric UI values such as ``1.0`` map to a
            valid node name.

    Returns:
        The feature extractor. Calling it on an image batch returns a dict
        keyed by the node name above.

    Raises:
        ValueError: If ``block_id`` is outside the model's token-only blocks.

    References:
        https://github.com/huggingface/pytorch-image-models/discussions/926
    """
    block_id = int(block_id)
    num_blocks = len(CAIT_MODEL.blocks_token_only)
    if not 0 <= block_id < num_blocks:
        raise ValueError(
            f"block_id must be in [0, {num_blocks - 1}], got {block_id}"
        )

    return create_feature_extractor(
        CAIT_MODEL,
        return_nodes=[f"blocks_token_only.{block_id}.attn.softmax"],
        # Treat PatchEmbed as an opaque leaf so the FX tracer does not
        # descend into it.
        tracer_kwargs={"leaf_modules": [PatchEmbed]},
    )


def get_cls_attention_map(
    image, attn_score_dict, block_key="blocks_token_only.0.attn.softmax"
):
    """Turn CLS-token attention scores into per-head spatial maps.

    Args:
        image: Batched image tensor that was fed to the model. Only its
            spatial size (dims 2 and 3, height and width) is used, to recover
            the patch grid.
        attn_score_dict: Mapping from node name to softmax attention scores,
            as returned by the extractor from ``create_attn_extractor``.
        block_key: Key selecting which block's scores to use.

    Returns:
        Tensor of shape ``(num_heads, h_featmap * PATCH_SIZE,
        w_featmap * PATCH_SIZE)`` holding one upsampled map per head.
    """
    h_featmap = image.shape[2] // PATCH_SIZE
    w_featmap = image.shape[3] // PATCH_SIZE

    attention_scores = attn_score_dict[block_key]
    num_heads = attention_scores.shape[1]

    # First batch element, every head, CLS query (row 0); drop key 0 (the
    # CLS token itself) so only the patch tokens remain.
    attentions = attention_scores[0, :, 0, 1:]

    # Patch tokens are flattened row by row, so the grid is (height, width).
    # Reshaping as (width, height) would transpose non-square maps.
    attentions = attentions.reshape(num_heads, h_featmap, w_featmap)

    # Upsample each head's patch grid back to pixel resolution.
    return F.resize(
        attentions,
        size=(h_featmap * PATCH_SIZE, w_featmap * PATCH_SIZE),
        interpolation=F.InterpolationMode.BICUBIC,
    )


def generate_plot(processed_map):
    """Generate a one-row figure with one class-attention map per head.

    Args:
        processed_map: Tensor indexed by head along dim 0 (e.g. the output of
            ``get_cls_attention_map``); each slice is shown as an image.

    Returns:
        The matplotlib ``Figure``. The caller owns it and should close it when
        done.
    """
    num_heads = processed_map.shape[0]
    # One column per head (at least one so subplots() is valid); squeeze=False
    # keeps `axes` 2-D even when there is a single column.
    fig, axes = plt.subplots(
        nrows=1, ncols=max(num_heads, 1), figsize=(13, 13), squeeze=False
    )

    for head_idx in range(num_heads):
        ax = axes[0, head_idx]
        ax.imshow(processed_map[head_idx].detach().cpu().numpy())
        ax.title.set_text(f"Attention head: {head_idx}")
        ax.axis("off")

    fig.tight_layout()
    return fig


def serialize_images(processed_map):
    """Save each head's attention map as ``attention_map_{i}.png``.

    Files are written to the current working directory, one per index along
    dim 0 of ``processed_map``. Each map gets its own figure, which is closed
    afterwards. This avoids drawing into (and leaking) pyplot's shared
    current figure across repeated calls.

    Args:
        processed_map: Tensor indexed by head along dim 0.

    Returns:
        List of written file paths, in head order.
    """
    num_maps = processed_map.shape[0]
    print(f"Number of maps: {num_maps}")

    paths = []
    for i in range(num_maps):
        path = f"attention_map_{i}.png"
        fig, ax = plt.subplots()
        try:
            ax.imshow(processed_map[i].detach().cpu().numpy())
            ax.set_title(f"Attention head: {i}", fontsize=14)
            ax.axis("off")
            fig.savefig(path)
        finally:
            # Release the figure even if drawing/saving fails.
            plt.close(fig)
        paths.append(path)
    return paths


def generate_class_attn_map(image, block_id=0):
    """Produce per-head class-attention map images for one input image.

    Args:
        image: Input image as delivered by the ``gr.Image(type="pil")``
            component; ``None`` when the user submitted without an image.
        block_id: Index of the token-only (class-attention) block to read.
            Coerced with ``int()`` because UI sliders may send numbers as
            floats.

    Returns:
        List of PNG paths, one per attention head, in head order.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")
    block_id = int(block_id)

    image_tensor = TRANSFORM(image).unsqueeze(0)
    feature_extractor = create_attn_extractor(block_id)

    with torch.no_grad():
        out = feature_extractor(image_tensor)

    block_key = f"blocks_token_only.{block_id}.attn.softmax"
    processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)

    serialize_images(processed_cls_attn_map)
    # Build the expected paths directly instead of globbing. A glob would
    # also return stale files from earlier runs, and its lexicographic sort
    # puts "attention_map_10" before "attention_map_2".
    all_attn_img_paths = [
        f"attention_map_{i}.png"
        for i in range(processed_cls_attn_map.shape[0])
    ]
    print(f"Number of images: {len(all_attn_img_paths)}")
    return all_attn_img_paths


title = "Class Attention Maps"
article = "Class attention maps as investigated in [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239) (Touvron et al.). We use the [cait_xxs24_224](https://huggingface.co/timm/cait_xxs24_224.fb_dist_in1k) variant of CaiT. One can find all the other variants [here](https://huggingface.co/models?search=cait)."

# UI inputs: the image to analyse and which class-attention block to read
# (slider limited to blocks 0 and 1).
input_components = [
    gr.Image(type="pil", label="Input Image"),
    gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
]
# UI output: a two-column gallery of the saved per-head attention maps.
gallery = gr.Gallery().style(columns=2, height="auto", object_fit="scale-down")

iface = gr.Interface(
    fn=generate_class_attn_map,
    inputs=input_components,
    outputs=gallery,
    title=title,
    article=article,
    allow_flagging="never",
    # The bundled example is run and cached ahead of time.
    examples=[["./bird.png", 0]],
    cache_examples=True,
)
iface.launch(debug=True)