import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import timm

from PIL import Image

import matplotlib.pyplot as plt

import os

# Thanks to ( ), a proxy can be essential :)
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:10809'
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:10809'
# os.environ['ALL_PROXY'] = 'socks5://127.0.0.1:10808'

# Multiplier applied to the tanh-bounded outputs of Scorer.mlp_1.
TANH_SCALE = 1

# Images scored by main(), displayed left to right in this order.
IMG_FILE_LIST = [
    './testcases/14.jpg',
    './testcases/15.jpg',
    './testcases/16.jpg',
    './testcases/17.jpg',
    './testcases/18.jpg',
    './testcases/19.jpg',
]


class Scorer(nn.Module):
    """Multi-head scorer on top of a timm multi-stage feature backbone.

    Each feature map returned by the backbone is global-average-pooled over its
    spatial dimensions, layer-normalised, concatenated, and passed through a
    shared MLP trunk (``mlp``) feeding two heads:

    * ``mlp_1``: three tanh-bounded outputs (scaled by ``TANH_SCALE``), returned
      as the liking, collection and relative-popularity scores.
    * ``mlp_2``: one logit, returned through a sigmoid as the AI probability.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        pretrained: whether timm should load pretrained backbone weights.
        features_only: forwarded to ``timm.create_model``. ``forward`` iterates
            over the backbone output as a list of per-stage feature maps, which
            is what ``features_only=True`` produces.
        embedding_dim: currently unused (presumably reserved for the disabled
            upload-date embedding); kept for interface compatibility.
    """

    # Per-stage channel widths used when the backbone does not expose timm's
    # FeatureInfo (these match the convnextv2_base stages this file targets).
    _DEFAULT_STAGE_CHANNELS = (128, 256, 512, 1024)

    def __init__(
            self,
            model_name,
            pretrained=False,
            features_only=True,
            embedding_dim=128
    ):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, features_only=features_only)
        stage_channels = self._stage_channels(self.model)
        pooled_dim = sum(stage_channels)
        # One LayerNorm per stage so each pooled vector is normalised on its own scale.
        self.layer_norms = nn.ModuleList([nn.LayerNorm(channels) for channels in stage_channels])
        # Shared trunk over the concatenated, normalised stage descriptors.
        self.mlp = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim),
            nn.BatchNorm1d(pooled_dim),
            nn.GELU(),
        )
        # Author's note: a BYOL-style BatchNorm in this head might help (experimental).
        self.mlp_1 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.BatchNorm1d(pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 3),
            nn.Tanh()
        )
        # Single-logit head; the sigmoid is applied in forward().
        self.mlp_2 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 1),
        )

    @classmethod
    def _stage_channels(cls, backbone):
        """Return the channel count of each feature map *backbone* emits.

        Uses timm's ``feature_info.channels()`` (available on ``features_only``
        models) and falls back to ``_DEFAULT_STAGE_CHANNELS`` otherwise.
        """
        feature_info = getattr(backbone, 'feature_info', None)
        if hasattr(feature_info, 'channels'):
            return list(feature_info.channels())
        return list(cls._DEFAULT_STAGE_CHANNELS)

    def forward(self, x, upload_date=None, freeze_backbone=False):
        """Score a batch of images.

        Args:
            x: image batch fed to the backbone (presumably normalised
                (B, 3, H, W) tensors — confirm against the caller).
            upload_date: ignored; the upload-date embedding is not implemented.
            freeze_backbone: if True, run the backbone under ``torch.no_grad()``
                so no gradient flows into the backbone (or back to ``x``).

        Returns:
            Tuple ``(liking, collection, ai_probability, relative_popularity)``,
            each of shape ``(B,)``.
        """
        if freeze_backbone:
            with torch.no_grad():
                stage_maps = self.model(x)
        else:
            stage_maps = self.model(x)
        # Global-average-pool each (B, C_i, H_i, W_i) map over its spatial dims -> (B, C_i).
        pooled = [F.adaptive_avg_pool2d(fmap, 1).flatten(1) for fmap in stage_maps]
        # Normalise each stage descriptor with its own LayerNorm.
        normed = [self.layer_norms[i](vec) for i, vec in enumerate(pooled)]
        # TODO: concatenate an upload-date embedding here once self.embedding exists.
        trunk = self.mlp(torch.cat(normed, dim=-1))
        rl_out = self.mlp_1(trunk) * TANH_SCALE
        ai_logit = self.mlp_2(trunk).squeeze(-1)
        return rl_out[:, 0], rl_out[:, 1], torch.sigmoid(ai_logit), rl_out[:, 2]


# timm identifier of the backbone passed to Scorer.
BACKBONE = 'convnextv2_base.fcmae'
# Side length (pixels) of the square images fed to the model.
RESOLUTION = 640
# When True, main() adds the input gradient of the selected objective to each shown image.
SHOW_GRAD = False
# Multiplier applied to that gradient before it is added to the image.
GRAD_SCALE = 50

# Objective flags: each enabled flag contributes one prediction to the quantity
# that is back-propagated when SHOW_GRAD is True.
MORE_LIKE = False
MORE_COLLECTION = False
LESS_AI = False
MORE_RELATIVE_POP = True

# Path of the saved Scorer state_dict.
WEIGHT_PATH = './scorer.pt'

# Prefer the GPU, but fall back to the CPU instead of failing at model.to('cuda')
# on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Backward-compatible alias for the original (misspelled) name.
DECIVE = DEVICE


def _build_transform():
    """Return the inference preprocessing pipeline.

    Resizes to ``RESOLUTION x RESOLUTION``, converts to a tensor and applies the
    ImageNet mean/std normalisation.
    """
    return transforms.Compose([
        transforms.Resize((RESOLUTION, RESOLUTION)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])


def _load_model():
    """Build a Scorer, load its weights from WEIGHT_PATH and move it to DECIVE in eval mode."""
    model = Scorer(BACKBONE)
    # map_location lets the checkpoint load regardless of the device it was saved from.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(WEIGHT_PATH, map_location=DECIVE)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DECIVE)
    return model


def _objective(liking_pred, collection_pred, ai_pred, relative_pop):
    """Combine the predictions selected by the objective flags into one quantity.

    Liking, collection and relative popularity enter negated; the AI
    probability enters as-is. Returns ``None`` when no flag is enabled.
    """
    terms = []
    if MORE_LIKE:
        terms.append(-liking_pred)
    if MORE_COLLECTION:
        terms.append(-collection_pred)
    if LESS_AI:
        terms.append(ai_pred)
    if MORE_RELATIVE_POP:
        terms.append(-relative_pop)
    return sum(terms) if terms else None


def _overlay_gradients(img, input_grad):
    """Return *img* with *input_grad* (resized to the image) added, scaled by GRAD_SCALE.

    *input_grad* is the gradient w.r.t. the batched network input, so its batch
    dimension is dropped first. The result is clamped to [0, 1] because
    ``ToPILImage`` converts float tensors by multiplying by 255 and casting to
    uint8, which does not saturate out-of-range values.
    """
    grad = input_grad.squeeze(0).detach().cpu()
    grad = transforms.Resize((img.height, img.width))(grad)
    pixels = transforms.ToTensor()(img) + grad * GRAD_SCALE
    return transforms.ToPILImage()(pixels.clamp(0.0, 1.0))


def main():
    """Score every image in IMG_FILE_LIST and show them side by side with their predictions.

    When SHOW_GRAD is set, the (unnormalised) gradient of the objective selected
    by MORE_LIKE / MORE_COLLECTION / LESS_AI / MORE_RELATIVE_POP with respect to
    the input is added to each displayed image, scaled by GRAD_SCALE.

    Raises:
        ValueError: if SHOW_GRAD is set but no objective flag is enabled, since
            there would be nothing to differentiate.
    """
    if SHOW_GRAD and not (MORE_LIKE or MORE_COLLECTION or LESS_AI or MORE_RELATIVE_POP):
        raise ValueError('SHOW_GRAD requires at least one of MORE_LIKE, MORE_COLLECTION, '
                         'LESS_AI or MORE_RELATIVE_POP to be enabled')
    model = _load_model()
    transform = _build_transform()

    # Show all the images horizontally, with the predicted values above each image.
    fig = plt.figure(figsize=(20, 20))
    for i, img_file in enumerate(IMG_FILE_LIST):
        # Close the file handle; convert() returns an independent, fully loaded image.
        with Image.open(img_file) as raw:
            img = raw.convert('RGB')
        transformed_img = transform(img).unsqueeze(0).to(DECIVE)
        # Input gradients (and the autograd graph) are only needed for the overlay.
        transformed_img.requires_grad_(SHOW_GRAD)
        with torch.set_grad_enabled(SHOW_GRAD):
            # The upload_date argument is a placeholder; Scorer does not use it.
            preds = model(transformed_img, torch.tensor([1]), False)
        liking_pred, collection_pred, ai_pred, relative_pop = preds
        ax = fig.add_subplot(1, len(IMG_FILE_LIST), i + 1)

        if SHOW_GRAD:
            model.zero_grad()
            # Find which parts of the image drive the selected objective.
            _objective(*preds).backward()
            img = _overlay_gradients(img, transformed_img.grad)
        ax.imshow(img)
        ax.set_title(
            f'Liking: {liking_pred.item():.3f}\nCollection: {collection_pred.item():.3f}\nAI: {ai_pred.item() * 100:.3f}%\nPopularity: {relative_pop.item():.3f}')
    plt.show()


if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()