# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
import torch
import torch.nn as nn
import pytorch_lightning as pl
from typing import Optional, Tuple, Union
from omegaconf import OmegaConf
from torch.cuda.amp import autocast
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.nn import functional as F
from .stage1.vqgan import VQGAN
from .stage2.transformer import Transformer1d, iGPT
from .stage2.layers import Block
from .. import utils
from ..utils.config import get_base_config
from ..utils.sampling import sampling, sampling_igpt, get_positional_encoding, sampling_prefix, sampling_conditional
from ..utils.utils import save_image
from .tokenizer import build_tokenizer
import numpy as np
from .stage2.layers import CrossAttentionLayer
from huggingface_hub import hf_hub_download

# Short model aliases mapped to the URL of their pretrained archive.
# The `from_pretrained` methods of PromptDalle / PrefixTuningDalle replace an
# alias with its URL before calling `utils.realpath_url_or_path`.
_MODELS = {
    'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
}

class Dalle(pl.LightningModule):
    def __init__(self,
                 config: OmegaConf) -> None:
        """Build the stage-1 and stage-2 sub-models described by ``config``.

        Args:
            config: Configuration providing the ``stage1``, ``stage2`` and
                ``dataset`` sections read below.
        """
        super().__init__()
        # No tokenizer yet; ``from_pretrained`` attaches one after construction.
        self.tokenizer = None

        stage1_cfg = config.stage1
        stage2_cfg = config.stage2

        # Stage 1: VQGAN, used by this class to turn images into codes and back.
        self.stage1 = VQGAN(n_embed=stage1_cfg.n_embed,
                            embed_dim=stage1_cfg.embed_dim,
                            hparams=stage1_cfg.hparams)
        # Stage 2: transformer consuming image codes and text tokens.
        self.stage2 = Transformer1d(vocab_size_txt=stage2_cfg.vocab_size_txt,
                                    vocab_size_img=stage2_cfg.vocab_size_img,
                                    hparams=stage2_cfg.hparams)

        # Keep the full config and its sub-sections handy for later use.
        self.config = config
        self.config_stage1 = stage1_cfg
        self.config_stage2 = stage2_cfg
        self.config_dataset = config.dataset

        # # make the parameters in stage 1 not trainable
        # self.stage1.eval()
        # for p in self.stage1.parameters():
        #     p.requires_grad = False

    @classmethod
    def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
        """Create a model from a checkpoint directory and load its weights.

        Args:
            args: Namespace-like object. Reads ``model_name_or_path`` (directory
                containing ``config.yaml``, ``tokenizer`` and the stage
                checkpoints), ``do_train`` and, if present, ``dalle_path``.

        Returns:
            ``(model, config)`` where ``config`` is the merged configuration
            the model was built from.
        """
        ckpt_dir = args.model_name_or_path
        config_new = OmegaConf.load(os.path.join(ckpt_dir, 'config.yaml'))

        finetune = args.do_train
        config_base = get_base_config('finetuning' if finetune else 'default')
        config_update = OmegaConf.merge(config_base, config_new)

        if finetune:
            # Command-line values override matching optimizer / experiment keys.
            for key, val in vars(args).items():
                if key in config_update.optimizer.keys():
                    OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
                if key in config_update.experiment.keys():
                    OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)

        model = cls(config_update)
        model.tokenizer = build_tokenizer(os.path.join(ckpt_dir, 'tokenizer'),
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)

        print("Loading models from checkpoint %s" % ckpt_dir)

        # A single ``.pth`` file (if given) holds the whole model; otherwise the
        # two stages are restored from their own checkpoints in ``ckpt_dir``.
        dalle_path = getattr(args, 'dalle_path', None)
        if dalle_path and dalle_path.endswith('.pth'):
            model.load_state_dict(torch.load(dalle_path)["model_state_dict"])
        else:
            model.stage1.from_ckpt(os.path.join(ckpt_dir, 'stage1_last.ckpt'))
            model.stage2.from_ckpt(os.path.join(ckpt_dir, 'stage2_last.ckpt'))

        return model, config_update
        

    @torch.no_grad()
    def sampling(self,
                 prompt: Union[str, torch.LongTensor],
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True) -> torch.FloatTensor:
        """Sample ``num_candidates`` images for a single prompt.

        Args:
            prompt: Raw text (encoded with ``self.tokenizer``) or an
                already-encoded token tensor; a leading batch axis is added here.
            top_k: Forwarded to the ``sampling`` helper.
            top_p: Forwarded to the ``sampling`` helper.
            softmax_temperature: Forwarded to the ``sampling`` helper.
            num_candidates: How many times the prompt is repeated along the
                batch axis, i.e. how many images are produced.
            device: Device the token tensor is moved to before sampling.
            use_fp16: Forwarded to the ``sampling`` helper.

        Returns:
            Decoded images, rescaled with ``x * 0.5 + 0.5`` and clamped to [0, 1].
        """
        self.stage1.eval()
        self.stage2.eval()

        # isinstance instead of an exact type comparison, so str subclasses are
        # tokenized as well rather than being treated as a tensor.
        if isinstance(prompt, str):
            tokens = self.tokenizer.encode(prompt)
            tokens = torch.LongTensor(tokens.ids)
        else:
            tokens = prompt
        # One copy of the prompt per requested candidate.
        tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)

        # Check if the encoding works as intended
        # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])

        tokens = tokens.to(device)
        codes = sampling(self.stage2,
                         tokens,
                         top_k=top_k,
                         top_p=top_p,
                         softmax_temperature=softmax_temperature,
                         use_fp16=use_fp16)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)
        return pixels

    def forward(self,
                images: torch.FloatTensor,
                texts: Optional[torch.LongTensor],
                past=None
                ) -> tuple:
        """Encode ``images`` to codes and run stage 2 on the codes and ``texts``.

        Args:
            images: Image batch; must unpack into four dimensions.
            texts: Text token batch.
            past: Passed through to stage 2 as its last positional argument.

        Returns:
            ``(logits_img, logits_txt, codes)``.
        """
        # Unpacking doubles as a check that ``images`` is 4-dimensional.
        _batch, _channels, _height, _width = images.shape

        # Stage 1 only provides targets/inputs: no gradients, autocast disabled.
        with torch.no_grad(), autocast(enabled=False):
            codes = self.stage1.get_codes(images).detach()
        codes = codes.clone().detach()

        pos_enc_tokens = get_positional_encoding(texts, mode='1d')
        pos_enc_code = get_positional_encoding(codes, mode='1d')

        logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, past)
        return logits_img, logits_txt, codes

    def training_step(self, batch, batch_idx):
        """Return the summed image and text cross-entropy loss for one batch.

        Args:
            batch: ``(images, texts)`` pair.
            batch_idx: Batch index (not used).
        """
        images, texts = batch
        logits_img, logits_txt, codes = self(images, texts)

        # Flatten logits to (N, vocab); text targets drop the first token position.
        flat_img_logits = logits_img.view(-1, logits_img.size(-1))
        flat_txt_logits = logits_txt.view(-1, logits_txt.size(-1))
        loss_img = F.cross_entropy(flat_img_logits, codes.view(-1))
        loss_txt = F.cross_entropy(flat_txt_logits, texts[:, 1:].reshape(-1))

        log_opts = dict(on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log("train/loss_img", loss_img, **log_opts)
        self.log("train/loss_txt", loss_txt, **log_opts)
        return loss_img + loss_txt

    def validation_step(self, batch, batch_idx):
        """Return the summed image and text cross-entropy loss for a validation batch.

        Args:
            batch: ``(images, texts)`` pair.
            batch_idx: Batch index (not used).
        """
        images, texts = batch
        logits_img, logits_txt, codes = self(images, texts)

        # Flatten logits to (N, vocab); text targets drop the first token position.
        flat_img_logits = logits_img.view(-1, logits_img.size(-1))
        flat_txt_logits = logits_txt.view(-1, logits_txt.size(-1))
        loss_img = F.cross_entropy(flat_img_logits, codes.view(-1))
        loss_txt = F.cross_entropy(flat_txt_logits, texts[:, 1:].reshape(-1))

        # Validation metrics are logged per epoch only.
        log_opts = dict(on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log("val/loss_img", loss_img, **log_opts)
        self.log("val/loss_txt", loss_txt, **log_opts)
        return loss_img + loss_txt

    def configure_optimizers(self):
        """Create the AdamW optimizer and a linearly decaying LR schedule.

        Returns:
            ``([optimizer], [scheduler_dict])``; the scheduler dict is named
            ``'linear'``.
        """
        opt_cfg = self.config.optimizer
        assert opt_cfg.opt_type == 'adamW'

        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=opt_cfg.learning_rate,
                                      betas=opt_cfg.betas,
                                      weight_decay=opt_cfg.weight_decay)

        def lr_lambda(current_step: int):
            # LR multiplier falls linearly from 1 to 0 over ``max_steps`` and
            # stays at 0 afterwards.
            total_steps = self.config.optimizer.max_steps
            fraction_left = float(total_steps - current_step) / float(max(1, total_steps))
            return max(0.0, fraction_left)

        scheduler = LambdaLR(optimizer, lr_lambda)
        return [optimizer], [{'scheduler': scheduler, 'name': 'linear'}]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        """Step the optimizer, then the LR scheduler, and log the resulting LR."""
        optimizer.step(closure=optimizer_closure)

        # The schedule advances once per optimizer step.
        scheduler = self.lr_schedulers()
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        self.log("lr", current_lr, on_step=True, on_epoch=False, prog_bar=True, logger=True)

    def on_epoch_start(self):
        """Put the stage-1 module into eval mode.

        NOTE(review): recent PyTorch Lightning releases appear to have replaced
        the ``on_epoch_start`` hook with ``on_train_epoch_start`` — confirm this
        hook still fires with the pinned Lightning version.
        """
        self.stage1.eval()


class ImageGPT(pl.LightningModule):
    def __init__(self,
                 config: OmegaConf) -> None:
        """Build the VQGAN and iGPT sub-models and freeze the VQGAN.

        Args:
            config: Configuration providing the ``stage1`` and ``stage2``
                sections read below.
        """
        super().__init__()
        stage1_cfg = config.stage1
        stage2_cfg = config.stage2

        self.stage1 = VQGAN(n_embed=stage1_cfg.n_embed,
                            embed_dim=stage1_cfg.embed_dim,
                            hparams=stage1_cfg.hparams)
        self.stage2 = iGPT(vocab_size_img=stage2_cfg.vocab_size_img,
                           use_cls_cond=stage2_cfg.use_cls_cond,
                           hparams=stage2_cfg.hparams)
        self.config = config
        self.use_cls_cond = stage2_cfg.use_cls_cond

        # Stage 1 is not trained here: eval mode and no gradients on its parameters.
        self.stage1.eval()
        for frozen_param in self.stage1.parameters():
            frozen_param.requires_grad = False

    @classmethod
    def from_pretrained(cls,
                        path_upstream: str,
                        path_downstream: str) -> Tuple[nn.Module, OmegaConf]:
        """Build the model from a downstream config and upstream checkpoints.

        Args:
            path_upstream: Directory holding ``stage1_last.ckpt`` and
                ``stage2_last.ckpt``.
            path_downstream: Path of the config file merged over the base config.

        Returns:
            ``(model, config)`` with the merged configuration.
        """
        # NOTE(review): other classes in this file call get_base_config with a
        # positional name string; confirm the helper also accepts ``use_default``.
        merged_cfg = OmegaConf.merge(get_base_config(use_default=False),
                                     OmegaConf.load(path_downstream))

        model = cls(merged_cfg)
        # Stage 1 is loaded strictly; stage 2 is loaded with strict=False.
        restore_plan = (
            (model.stage1, 'stage1_last.ckpt', True),
            (model.stage2, 'stage2_last.ckpt', False),
        )
        for stage, ckpt_name, strict in restore_plan:
            stage.from_ckpt(os.path.join(path_upstream, ckpt_name), strict=strict)
        return model, merged_cfg

    def sample(self,
               cls_idx: Optional[int] = None,
               top_k: int = 256,
               top_p: Optional[float] = None,
               softmax_temperature: float = 1.0,
               num_candidates: int = 16,
               device: str = 'cuda:0',
               use_fp16: bool = True,
               is_tqdm: bool = True) -> torch.FloatTensor:
        """Sample ``num_candidates`` images, optionally for class ``cls_idx``.

        Args:
            cls_idx: Class index looked up through ``self.stage2.sos``; when
                ``None`` the ``sos`` tensor itself is repeated instead.
            top_k, top_p, softmax_temperature, use_fp16, is_tqdm: Forwarded to
                ``sampling_igpt``.
            num_candidates: Number of images to sample.
            device: Device on which the class-index tensor is created.

        Returns:
            Decoded images, rescaled with ``x * 0.5 + 0.5`` and clamped to [0, 1].
        """
        self.stage1.eval()
        self.stage2.eval()

        if cls_idx is not None:
            # Conditional case: one start embedding per candidate for the class.
            class_ids = torch.LongTensor([cls_idx]).to(device=device).repeat(num_candidates)
            sos = self.stage2.sos(class_ids).unsqueeze(1)
        else:
            # Unconditional case: repeat the start tensor for each candidate.
            sos = self.stage2.sos.repeat(num_candidates, 1, 1)

        codes = sampling_igpt(self.stage2,
                              sos=sos,
                              top_k=top_k,
                              top_p=top_p,
                              softmax_temperature=softmax_temperature,
                              use_fp16=use_fp16,
                              is_tqdm=is_tqdm)
        code_grid = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        return torch.clamp(self.stage1.decode_code(code_grid) * 0.5 + 0.5, 0, 1)

    def forward(self,
                images: torch.FloatTensor,
                labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
        """Encode ``images`` to codes and run stage 2 on them.

        Args:
            images: Image batch; must unpack into four dimensions.
            labels: Optional labels passed to stage 2 alongside the codes.

        Returns:
            The pair ``(logits, codes)`` (note: a tuple, despite the annotation).
        """
        # Unpacking doubles as a check that ``images`` is 4-dimensional.
        _batch, _channels, _height, _width = images.shape

        # Stage 1 runs without gradients and with autocast disabled.
        with torch.no_grad(), autocast(enabled=False):
            codes = self.stage1.get_codes(images).detach()

        return self.stage2(codes, labels), codes

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy between stage-2 logits and the image codes.

        Args:
            batch: ``(images, labels)`` pair; labels are only used when
                ``self.use_cls_cond`` is set.
            batch_idx: Batch index (not used).
        """
        images, labels = batch
        cond_labels = labels if self.use_cls_cond else None
        logits, codes = self(images, labels=cond_labels)

        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, codes.view(-1))
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation cross-entropy between logits and image codes.

        Args:
            batch: ``(images, labels)`` pair; labels are only used when
                ``self.use_cls_cond`` is set.
            batch_idx: Batch index (not used).
        """
        images, labels = batch
        cond_labels = labels if self.use_cls_cond else None
        logits, codes = self(images, labels=cond_labels)

        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, codes.view(-1))
        # Logged per epoch only.
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def configure_optimizers(self):
        """Create the AdamW optimizer and a cosine-annealing LR schedule.

        Returns:
            ``([optimizer], [scheduler_dict])``; the scheduler dict is named
            ``'cosine'``.
        """
        opt_cfg = self.config.optimizer
        assert opt_cfg.opt_type == 'adamW'
        assert opt_cfg.sched_type == 'cosine'

        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=opt_cfg.base_lr,
                                      betas=opt_cfg.betas,
                                      weight_decay=opt_cfg.weight_decay)
        scheduler = CosineAnnealingLR(optimizer,
                                      T_max=opt_cfg.max_steps,
                                      eta_min=opt_cfg.min_lr)
        return [optimizer], [{'scheduler': scheduler, 'name': 'cosine'}]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        """Step the optimizer, then the LR scheduler, and log the resulting LR."""
        optimizer.step(closure=optimizer_closure)

        # The schedule advances once per optimizer step.
        scheduler = self.lr_schedulers()
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        self.log("lr", current_lr, on_step=True, on_epoch=False, prog_bar=True, logger=True)

    def on_epoch_start(self):
        """Put the stage-1 module into eval mode.

        NOTE(review): recent PyTorch Lightning releases appear to have replaced
        the ``on_epoch_start`` hook with ``on_train_epoch_start`` — confirm this
        hook still fires with the pinned Lightning version.
        """
        self.stage1.eval()


class PromptDalle(Dalle):
    """Classification Head for  transformer encoders"""
    def __init__(self, config):
        """Extend ``Dalle`` with a learnable prompt of ``prompt.preseqlen`` vectors.

        Args:
            config: Configuration; besides what ``Dalle`` reads, this uses
                ``stage2.hparams.embed_dim``, ``prompt.preseqlen`` and
                ``prompt.prefix_dropout``.
        """
        super().__init__(config)
        print('Initializing the PromptTuning model')

        self.config = config
        self.n_embd = config.stage2.hparams.embed_dim
        prompt_cfg = config.prompt
        self.preseqlen = prompt_cfg.preseqlen
        self.prefix_dropout = prompt_cfg.prefix_dropout

        print('[Full prompt-tuning Setting :) ]')
        # Fixed indices 0..preseqlen-1 used to look up every row of ``wte``.
        self.input_tokens = torch.arange(self.preseqlen).long()
        self.wte = nn.Embedding(self.preseqlen, self.n_embd)
        # MLP applied to the prompt embeddings (embed_dim -> embed_dim -> embed_dim).
        self.control_trans = nn.Sequential(
            nn.Linear(self.n_embd, self.n_embd),
            nn.Tanh(),
            nn.Linear(self.n_embd, self.n_embd))
        self.get_prompt = self.get_prompt_p5
        self.dropout = nn.Dropout(self.prefix_dropout)

        # Report the total parameter count of the whole module.
        total_param = sum(param.numel() for param in self.parameters())
        print('Total parameters is {}'.format(total_param))


    @classmethod
    def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
        """Create a prompt-tuning model and load pretrained weights.

        Args:
            args: Namespace-like object. Reads ``prefix_model_name_or_path``
                (a ``_MODELS`` alias, URL or local path of the base model
                directory) and ``model_name_or_path`` (optional checkpoint file
                holding the full model). Attributes whose names match keys of
                the ``prompt``, ``optimizer`` or ``experiment`` config sections
                override those keys.

        Returns:
            ``(model, config)`` where ``config`` is the merged configuration.

        Raises:
            KeyError: If the checkpoint has neither a ``'state_dict'`` nor a
                ``'model_state_dict'`` entry.
        """
        path = args.prefix_model_name_or_path
        path = _MODELS.get(path, path)
        path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))

        config_base = get_base_config('prompt_tuning')
        config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
        config_update = OmegaConf.merge(config_base, config_new)

        # Command-line values override matching config keys.
        for key, val in vars(args).items():
            if key in config_update.prompt.keys():
                OmegaConf.update(config_update, "prompt.%s" % key, val, merge=False)
            if key in config_update.optimizer.keys():
                OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
            if key in config_update.experiment.keys():
                OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)

        model = cls(config_update)
        model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)

        if args.model_name_or_path:
            print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
            # Deserialize the checkpoint once and pick whichever key it uses.
            # Previously the file was loaded a second time on fallback, and a
            # KeyError raised inside load_state_dict also triggered the fallback.
            # NOTE: torch.load unpickles the file — only load trusted checkpoints.
            checkpoint = torch.load(args.model_name_or_path)
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint['model_state_dict']
            model.load_state_dict(state_dict)
        else:
            print("Loading models from checkpoint %s" % path)
            model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
            model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))

        return model, config_update

    def get_prompt_p5(self, bsz=None, eval=False):
        """Compute the prompt embeddings for a batch of ``bsz`` items.

        Args:
            bsz: Batch size the prompt indices are expanded to.
            eval: When true, dropout is skipped.

        Returns:
            Output of ``control_trans`` applied to the ``wte`` embeddings,
            passed through dropout unless ``eval`` is set.
        """
        token_ids = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
        prompt = self.control_trans(self.wte(token_ids))
        return prompt if eval else self.dropout(prompt)

    def forward(self,
                images: torch.FloatTensor,
                texts: Optional[torch.LongTensor],
                **kwargs,
        ):
        """Run stage 2 on image codes and text, conditioned on the learned prompt.

        Args:
            images: Image batch; must unpack into four dimensions.
            texts: Text token batch.
            **kwargs: Accepted but not used.

        Returns:
            ``(logits_img, logits_txt, codes)``.
        """
        batch_size, _channels, _height, _width = images.shape

        # Learned prompt plus positional encoding of its index sequence.
        prompt = self.get_prompt(bsz=batch_size)
        prompt_ids = self.input_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
        pos_enc_prompt = get_positional_encoding(prompt_ids, mode='1d')

        # Stage 1 runs without gradients and with autocast disabled.
        with torch.no_grad(), autocast(enabled=False):
            codes = self.stage1.get_codes(images).detach()
        codes = codes.clone().detach()

        pos_enc_tokens = get_positional_encoding(texts, mode='1d')
        pos_enc_code = get_positional_encoding(codes, mode='1d')

        logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens,
                                             prompt=prompt, pos_prompt=pos_enc_prompt)
        return logits_img, logits_txt, codes


    @torch.no_grad()
    def sampling(self,
                 tokens: torch.LongTensor,
                 prompt: torch.FloatTensor,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True,
                 labels = None) -> torch.FloatTensor:
        """Sample images for ``tokens`` conditioned on ``prompt``.

        Args:
            tokens: Text tokens; moved to ``device`` and passed on as given
                (they are not repeated per candidate here).
            prompt: Prompt embeddings forwarded to the ``sampling`` helper.
            top_k, top_p, softmax_temperature, use_fp16: Forwarded to the
                ``sampling`` helper.
            num_candidates: Batch size used for the prompt positional encoding.
            device: Device the tokens are moved to.
            labels: Accepted but not used.

        Returns:
            Decoded images, rescaled with ``x * 0.5 + 0.5`` and clamped to [0, 1].
        """
        self.stage1.eval()
        self.stage2.eval()

        tokens = tokens.to(device)
        prompt_ids = self.input_tokens.unsqueeze(0).expand(num_candidates, -1).to(self.device)
        pos_enc_prompt = get_positional_encoding(prompt_ids, mode='1d')

        codes = sampling(self.stage2,
                         tokens,
                         top_k=top_k,
                         top_p=top_p,
                         softmax_temperature=softmax_temperature,
                         use_fp16=use_fp16,
                         prompt=prompt,
                         pos_prompt=pos_enc_prompt)

        code_grid = codes.view(-1, 16, 16)  # [B, 16, 16]
        return torch.clamp(self.stage1.decode_code(code_grid) * 0.5 + 0.5, 0, 1)


    @torch.no_grad()
    def predict_step(self, batch, batch_idx, return_images=False):
        """Generate samples for every text in ``batch`` and save or return them.

        Args:
            batch: ``(orig_images, texts)`` pair.
            batch_idx: Batch index, used to name the saved outputs.
            return_images: When true, return the samples instead of saving.

        Returns:
            A list with one channel-last numpy array of samples per text when
            ``return_images`` is true; otherwise ``None`` after saving.
        """
        orig_images, texts = batch

        # Extra check: decode the argmax of the logits from a regular forward
        # pass on the original images.
        logits_img, logits_txt, codes = self(orig_images, texts)
        pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
        bs = orig_images.shape[0]
        pred = pred.view(bs, 16, 16)  # [B, 16, 16]
        recon_pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy()
        recon_pixels = np.transpose(recon_pixels, (0, 2, 3, 1))  # channels last

        prompt = self.get_prompt(bsz=5, eval=True)

        images = []
        for i, t in enumerate(texts):
            # Separate name: reusing ``pixels`` here used to overwrite the
            # reconstruction above, so the check image saved below was really
            # the last sample.
            sampled = self.sampling(t, prompt, top_k=16, num_candidates=5, labels=codes[i]).cpu().numpy()
            images.append(np.transpose(sampled, (0, 2, 3, 1)))

        if return_images:
            return images

        save_image(orig_images, recon_pixels, './out/images/pororo_prompt', batch_idx+10)
        save_image(orig_images, images, './out/images/pororo_prompt', batch_idx)


class PrefixTuningDalle(Dalle):
    """Classification Head for  transformer encoders"""
    def __init__(self, config):
        """Extend ``Dalle`` with a learnable prefix of per-layer key/value vectors.

        Args:
            config: Configuration; besides what ``Dalle`` reads, this uses
                ``stage2.hparams`` (``n_layers``, ``n_heads``, ``embed_dim``)
                and the ``prefix`` section.
        """
        super().__init__(config)
        print('Initializing the PrefixTuning model')

        self.config = config

        # Transformer dimensions the prefix has to match.
        hparams = config.stage2.hparams
        self.match_n_layer = hparams.n_layers
        self.match_n_head = hparams.n_heads
        self.match_n_embd = hparams.embed_dim // hparams.n_heads  # per-head width
        self.n_embd = hparams.embed_dim

        # Prefix-specific settings, copied verbatim from the config.
        prefix_cfg = config.prefix
        self.optim_prefix = prefix_cfg.optim_prefix
        self.preseqlen = prefix_cfg.preseqlen
        self.prefix_dropout = prefix_cfg.prefix_dropout
        self.init_random = prefix_cfg.init_random
        self.hidden_dim_prefix = prefix_cfg.hidden_dim_prefix

        self.lowdata_token = prefix_cfg.lowdata_token
        self.init_shallow = prefix_cfg.init_shallow
        self.init_shallow_word = prefix_cfg.init_shallow_word
        self.mode_para = 0

        print('PrefixTuning')
        print('preseqlen is {}, optimizing the prefix directly'.format(self.preseqlen))

        print('[Full prefix-tuning Setting :) ]')
        # Fixed indices 0..preseqlen-1 used to look up every row of ``wte``.
        self.input_tokens = torch.arange(self.preseqlen).long()
        self.wte = nn.Embedding(self.preseqlen, self.n_embd)
        # MLP whose output holds, per prefix position, 2 vectors of size
        # embed_dim for each of the n_layers layers.
        self.control_trans = nn.Sequential(
            nn.Linear(self.n_embd, self.hidden_dim_prefix),
            nn.Tanh(),
            nn.Linear(self.hidden_dim_prefix, self.match_n_layer * 2 * self.n_embd))
        self.get_prompt = self.get_prompt_p5
        self.dropout = nn.Dropout(self.prefix_dropout)

        # Report the total parameter count of the whole module.
        total_param = sum(param.numel() for param in self.parameters())
        print('Total parameters is {}'.format(total_param))


    @classmethod
    def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
        """Create a prefix-tuning model and load pretrained weights.

        Args:
            args: Namespace-like object. Reads ``prefix_model_name_or_path``
                (a ``_MODELS`` alias, URL or local path of the base model
                directory) and ``model_name_or_path`` (optional checkpoint file
                holding the full model). Attributes whose names match keys of
                the ``prefix``, ``optimizer`` or ``experiment`` config sections
                override those keys.

        Returns:
            ``(model, config)`` where ``config`` is the merged configuration.

        Raises:
            KeyError: If the checkpoint has neither a ``'state_dict'`` nor a
                ``'model_state_dict'`` entry.
        """
        path = args.prefix_model_name_or_path
        path = _MODELS.get(path, path)
        path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))

        config_base = get_base_config('prefixtuning')
        config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
        config_update = OmegaConf.merge(config_base, config_new)

        # Command-line values override matching config keys.
        for key, val in vars(args).items():
            if key in config_update.prefix.keys():
                OmegaConf.update(config_update, "prefix.%s" % key, val, merge=False)
            if key in config_update.optimizer.keys():
                OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
            if key in config_update.experiment.keys():
                OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)

        model = cls(config_update)
        model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)

        if args.model_name_or_path:
            print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
            # Deserialize the checkpoint once and pick whichever key it uses.
            # Previously the file was loaded a second time on fallback, and a
            # KeyError raised inside load_state_dict also triggered the fallback.
            # NOTE: torch.load unpickles the file — only load trusted checkpoints.
            checkpoint = torch.load(args.model_name_or_path)
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint['model_state_dict']
            model.load_state_dict(state_dict)
        else:
            print("Loading models from checkpoint %s" % path)
            model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
            model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))

        return model, config_update

    def get_prompt_p5(self, bsz=None, eval=False):
        """Compute the per-layer prefix tensors for a batch of ``bsz`` items.

        Args:
            bsz: Batch size the prefix indices are expanded to.
            eval: When true, dropout is skipped.

        Returns:
            Tuple of chunks, each of shape
            ``(2, bsz, n_head, preseqlen, head_dim)``, obtained by splitting
            the leading ``n_layer * 2`` axis into pieces of two.
        """
        token_ids = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
        flat_prefix = self.control_trans(self.wte(token_ids))  # (bsz, seqlen, layer*2*emb)

        batch, seqlen, _ = flat_prefix.shape
        prefix = flat_prefix.view(batch, seqlen,
                                  self.match_n_layer * 2,
                                  self.match_n_head,
                                  self.match_n_embd)
        if not eval:
            prefix = self.dropout(prefix)

        # (bsz, seqlen, layer*2, head, dim) -> (layer*2, bsz, head, seqlen, dim)
        prefix = prefix.permute([2, 0, 3, 1, 4])
        return prefix.split(2)

    def forward(self,
                images: torch.FloatTensor,
                texts: Optional[torch.LongTensor],
                **kwargs,
        ):
        """Run stage 2 on image codes and text with the learned prefix.

        Args:
            images: Image batch; must unpack into four dimensions.
            texts: Text token batch.
            **kwargs: Accepted but not used.

        Returns:
            ``(logits_img, logits_txt, codes)``.
        """
        batch_size, _channels, _height, _width = images.shape

        # The prefix is built the same way whatever ``self.mode_para`` is.
        past_key_values_prompt = self.get_prompt(bsz=batch_size)

        # Stage 1 runs without gradients and with autocast disabled.
        with torch.no_grad(), autocast(enabled=False):
            codes = self.stage1.get_codes(images).detach()
        codes = codes.clone().detach()

        pos_enc_tokens = get_positional_encoding(texts, mode='1d')
        pos_enc_code = get_positional_encoding(codes, mode='1d')

        # The prefix is passed as stage 2's fifth positional argument.
        logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens,
                                             past_key_values_prompt)
        return logits_img, logits_txt, codes

    @torch.no_grad()
    def sampling(self,
                 tokens: torch.LongTensor,
                 past: torch.FloatTensor,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True,
                 labels = None) -> torch.FloatTensor:
        """Sample ``num_candidates`` images for ``tokens`` using a prefix.

        Args:
            tokens: Text tokens for one prompt; a batch axis is added and the
                prompt is repeated ``num_candidates`` times.
            past: Prefix tensor; a 6-D input has its third and fourth axes
                merged before use.
            top_k, top_p, softmax_temperature, use_fp16: Forwarded to
                ``sampling_prefix``.
            num_candidates: Number of images to sample.
            device: Device the tokens are moved to.
            labels: Optional tensor, flattened and forwarded to
                ``sampling_prefix``.

        Returns:
            Decoded images, rescaled with ``x * 0.5 + 0.5`` and clamped to [0, 1].
        """
        self.stage1.eval()
        self.stage2.eval()

        if past.ndim == 6:
            n_layers, kv, bs, n_heads, seq_len, n_dim = past.shape
            past = past.view(n_layers, kv, bs * n_heads, seq_len, n_dim)

        # One copy of the prompt per requested candidate.
        tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
        tokens = tokens.to(device)

        flat_labels = labels.view(-1) if labels is not None else None
        codes = sampling_prefix(self.stage2,
                                tokens,
                                past,
                                top_k=top_k,
                                top_p=top_p,
                                softmax_temperature=softmax_temperature,
                                use_fp16=use_fp16,
                                labels=flat_labels)

        code_grid = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        return torch.clamp(self.stage1.decode_code(code_grid) * 0.5 + 0.5, 0, 1)

    def training_step(self, batch, batch_idx):
        """Return the summed image and text cross-entropy loss for one batch.

        Args:
            batch: ``(images, texts)`` pair.
            batch_idx: Batch index (not used).
        """
        images, texts = batch
        logits_img, logits_txt, codes = self(images, texts)

        # Flatten logits to (N, vocab); text targets drop the first token position.
        flat_img_logits = logits_img.view(-1, logits_img.size(-1))
        flat_txt_logits = logits_txt.view(-1, logits_txt.size(-1))
        loss_img = F.cross_entropy(flat_img_logits, codes.view(-1))
        loss_txt = F.cross_entropy(flat_txt_logits, texts[:, 1:].reshape(-1))

        log_opts = dict(on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log("train/loss_img", loss_img, **log_opts)
        self.log("train/loss_txt", loss_txt, **log_opts)
        return loss_img + loss_txt

    def validation_step(self, batch, batch_idx):
        """Return the summed image and text cross-entropy loss for a validation batch.

        Args:
            batch: ``(images, texts)`` pair.
            batch_idx: Batch index (not used).
        """
        images, texts = batch
        logits_img, logits_txt, codes = self(images, texts)

        # Flatten logits to (N, vocab); text targets drop the first token position.
        flat_img_logits = logits_img.view(-1, logits_img.size(-1))
        flat_txt_logits = logits_txt.view(-1, logits_txt.size(-1))
        loss_img = F.cross_entropy(flat_img_logits, codes.view(-1))
        loss_txt = F.cross_entropy(flat_txt_logits, texts[:, 1:].reshape(-1))

        # Validation metrics are logged per epoch only.
        log_opts = dict(on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log("val/loss_img", loss_img, **log_opts)
        self.log("val/loss_txt", loss_txt, **log_opts)
        return loss_img + loss_txt

    @torch.no_grad()
    def predict_step(self, batch, batch_idx, return_images=False):
        """Generate samples for every text in ``batch`` and save or return them.

        Args:
            batch: ``(orig_images, texts)`` pair.
            batch_idx: Batch index, used to name the saved outputs.
            return_images: When true, return the samples instead of saving.

        Returns:
            A list with one channel-last numpy array of samples per text when
            ``return_images`` is true; otherwise ``None`` after saving.
        """
        orig_images, texts = batch

        # Extra check: decode the argmax of the logits from a regular forward
        # pass on the original images.
        logits_img, logits_txt, codes = self(orig_images, texts)
        pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
        bs = orig_images.shape[0]
        pred = pred.view(bs, 16, 16)  # [B, 16, 16]
        recon_pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy()
        recon_pixels = np.transpose(recon_pixels, (0, 2, 3, 1))  # channels last

        # Stack the per-layer prefix chunks into one tensor, then merge the
        # batch and head axes for downstream processing.
        past_key_values_prompt = self.get_prompt(bsz=5, eval=True)
        past_key_values_prompt = torch.stack(list(past_key_values_prompt), dim=0)
        n_layers, temp, bs, n_heads, seq_len, n_dim = past_key_values_prompt.shape
        past_key_values_prompt = past_key_values_prompt.view(n_layers, temp, bs*n_heads, seq_len, n_dim)

        images = []
        for i, t in enumerate(texts):
            # Separate name: reusing ``pixels`` here used to overwrite the
            # reconstruction above, so the check image saved below was really
            # the last sample.
            sampled = self.sampling(t, past_key_values_prompt, top_k=16, num_candidates=5, labels=codes[i]).cpu().numpy()
            images.append(np.transpose(sampled, (0, 2, 3, 1)))

        if return_images:
            return images

        save_image(orig_images, recon_pixels, './out/images/pororo_prefix', batch_idx+10)
        save_image(orig_images, images, './out/images/pororo_prefix', batch_idx)


class ConditionalDalle(Dalle):
    """DALL-E variant whose stage-2 transformer is also conditioned on a source image.

    The source image is encoded into stage-1 codes and handed, together with
    a set of ``CrossAttentionLayer`` modules, to
    ``stage2.forward_with_context`` (training) or ``sampling_conditional``
    (generation).
    """

    def __init__(self, config):
        """Build the base model and create the cross-attention layers.

        Args:
            config: Model configuration; ``config.stage2.hparams`` provides
                the sizes used for the cross-attention layers.
        """
        super().__init__(config)
        print('Initializing the Conditional Dalle model')

        self.config = config

        print('Setting up Cross-attention Layers')
        self.init_cross_attention(list(range(2, 42, 3)), config.stage2.hparams)

        # Report the number of parameters registered on this module.
        total_param = sum(param.numel() for param in self.parameters())
        print('Total parameters is {}'.format(total_param))

    @classmethod
    def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
        """Create a model from a model directory and, optionally, a checkpoint.

        Args:
            args: Namespace-like object. Reads ``model_name_or_path`` (directory
                with ``config.yaml``, ``tokenizer`` and the stage checkpoints),
                ``do_train`` and ``dalle_path`` (optional full-model checkpoint).
                When ``do_train`` is set, every attribute whose name matches a
                key of the ``optimizer`` or ``experiment`` config section
                overrides that key.

        Returns:
            ``(model, config_update)`` - the model and the merged configuration.
        """
        path = args.model_name_or_path
        config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
        if args.do_train:
            config_base = get_base_config('finetuning')
            config_update = OmegaConf.merge(config_base, config_new)
            # Command-line values take precedence over the stored config.
            for key, val in vars(args).items():
                if key in config_update.optimizer.keys():
                    OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
                if key in config_update.experiment.keys():
                    OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
        else:
            config_base = get_base_config('default')
            config_update = OmegaConf.merge(config_base, config_new)

        model = cls(config_update)
        model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)
        print(model.cross_attention_idxs)

        if args.dalle_path:
            print("Loading model from pretrained checkpoint %s" % args.dalle_path)
            # Deserialize onto the CPU so that a checkpoint saved from a GPU
            # also loads on a host without CUDA; load_state_dict copies the
            # values into the model's existing parameters.
            checkpoint = torch.load(args.dalle_path, map_location=torch.device('cpu'))
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            print("Loading models from checkpoint %s" % path)
            model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
            model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))

        return model, config_update

    def init_cross_attention(self, cross_attention_layers, hparams):
        """Create one ``CrossAttentionLayer`` per entry of ``cross_attention_layers``.

        Args:
            cross_attention_layers: Indices stored as ``self.cross_attention_idxs``
                and forwarded to stage 2 alongside the layers.
            hparams: Stage-2 hyper-parameters supplying the layer sizes.
        """
        self.cross_attention_idxs = cross_attention_layers
        # NOTE(review): the layers live in a plain Python list, which nn.Module
        # does not register as sub-modules (unlike nn.ModuleList). Confirm that
        # callers handle device placement, optimisation and checkpointing of
        # these layers; switching containers would change the state-dict keys.
        self.cross_attention_layers = [CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
                                                           embed_dim=hparams.embed_dim,
                                                           n_heads=hparams.n_heads,
                                                           attn_bias=hparams.attn_bias,
                                                           resid_pdrop=hparams.resid_pdrop,
                                                           attn_pdrop=hparams.attn_pdrop)
                                       for _ in cross_attention_layers]

    def forward(self,
                images: torch.FloatTensor,
                src_images: Optional[torch.FloatTensor],
                texts: Optional[torch.LongTensor],
                **kwargs,
        ):
        """Compute image and text logits for target images given source images.

        Args:
            images: Target images, encoded to codes by stage 1.
            src_images: Source (conditioning) images, encoded the same way.
            texts: Caption token ids.
            **kwargs: Accepted and ignored.

        Returns:
            ``(logits_img, logits_txt, codes)`` where ``codes`` are the stage-1
            codes of ``images``.
        """
        # Stage 1 only provides discrete codes: no gradients, no autocast.
        with torch.no_grad():
            with autocast(enabled=False):
                codes = self.stage1.get_codes(images).detach()
                src_codes = self.stage1.get_codes(src_images).detach()

        pos_enc_tokens = get_positional_encoding(texts, mode='1d')
        codes = codes.clone().detach()
        pos_enc_code = get_positional_encoding(codes, mode='1d')
        src_codes = src_codes.clone().detach()
        src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')

        logits_img, logits_txt = self.stage2.forward_with_context(codes, texts,
                                                                  pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
                                                                  self.cross_attention_idxs, self.cross_attention_layers)
        return logits_img, logits_txt, codes

    @torch.no_grad()
    def sampling(self,
                 prompt: torch.LongTensor,
                 source: torch.FloatTensor,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True) -> torch.FloatTensor:
        """Sample ``num_candidates`` images for one caption and one source image.

        Args:
            prompt: Caption, either as a string (tokenized here) or as a tensor
                of token ids.
            source: Source image(s) to condition on.
            top_k: Forwarded to ``sampling_conditional``.
            top_p: Forwarded to ``sampling_conditional``.
            softmax_temperature: Forwarded to ``sampling_conditional``.
            num_candidates: Number of images to draw; tokens and source codes
                are repeated this many times.
            device: Device the tokens and the source image are moved to.
            use_fp16: Forwarded to ``sampling_conditional``.

        Returns:
            Decoded images clamped to ``[0, 1]``.
        """
        self.stage1.eval()
        self.stage2.eval()

        if isinstance(prompt, str):
            tokens = torch.LongTensor(self.tokenizer.encode(prompt).ids)
        else:
            tokens = prompt

        tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)

        tokens = tokens.to(device)
        source = source.to(device)

        with autocast(enabled=False):
            src_codes = self.stage1.get_codes(source).detach()
        src_codes = torch.repeat_interleave(src_codes, num_candidates, dim=0)

        codes = sampling_conditional(self.stage2,
                                     self.cross_attention_idxs,
                                     self.cross_attention_layers,
                                     tokens,
                                     src_codes,
                                     top_k=top_k,
                                     top_p=top_p,
                                     softmax_temperature=softmax_temperature,
                                     use_fp16=use_fp16)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]
        return pixels

    def training_step(self, batch, batch_idx):
        """Return the summed image and text cross-entropy for one training batch."""
        images, texts = batch
        # NOTE(review): forward() of this class takes (images, src_images, texts);
        # this two-argument call does not appear to match it - confirm the
        # intended batch layout before relying on this step.
        logits_img, logits_txt, codes = self(images, texts)

        loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
        loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
        self.log("train/loss_img", loss_img, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log("train/loss_txt", loss_txt, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss_img + loss_txt

    def validation_step(self, batch, batch_idx):
        """Return the summed image and text cross-entropy for one validation batch."""
        images, texts = batch
        # NOTE(review): same signature mismatch with forward() as in
        # training_step - confirm the intended batch layout.
        logits_img, logits_txt, codes = self(images, texts)

        loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
        loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
        self.log("val/loss_img", loss_img, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log("val/loss_txt", loss_txt, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss_img + loss_txt

    @torch.no_grad()
    def predict_step(self, batch, batch_idx):
        """Sample five candidates per caption in ``batch`` and save them under ``./out/images/``."""
        orig_images, texts = batch
        # NOTE(review): get_prompt is not defined in this class, and sampling()
        # here expects a source image as its second argument rather than a
        # prompt tensor - this method looks stale; confirm before use.
        past_key_values_prompt = self.get_prompt(bsz=5)
        past_key_values_prompt = torch.cat([x.unsqueeze(0) for x in past_key_values_prompt], dim=0)
        images = []
        for t in texts:
            pixels = self.sampling(t, past_key_values_prompt, top_k=64, num_candidates=5).cpu().numpy()
            pixels = np.transpose(pixels, (0, 2, 3, 1))
            images.append(pixels)

        save_image(orig_images, images, './out/images/', batch_idx)


class PromptConditionalDalle(Dalle):
    """Source-image-conditioned DALL-E with a learned prompt.

    Combines the cross-attention conditioning on a source image with a
    trainable prompt (an embedding table followed by a two-layer MLP) that is
    passed to stage 2 as ``prompt``.
    """

    def __init__(self, config):
        """Build the base model, the cross-attention layers and the prompt modules.

        Args:
            config: Model configuration; reads ``config.stage2.hparams`` and
                ``config.story.preseqlen`` / ``config.story.prefix_dropout``.
        """
        super().__init__(config)
        print('Initializing the Conditional Dalle model')

        self.config = config

        print('Setting up Cross-attention Layers')
        self.init_cross_attention(list(range(2, 42, 3)), config.stage2.hparams)

        self.n_embd = config.stage2.hparams.embed_dim
        self.preseqlen = config.story.preseqlen
        self.prefix_dropout = config.story.prefix_dropout

        # Prompt parametrisation: embedding table -> Linear/Tanh/Linear.
        print('[Full prompt-tuning Setting :) ]')
        self.input_tokens = torch.arange(self.preseqlen).long()
        self.wte = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans = nn.Sequential(
            nn.Linear(self.n_embd, self.n_embd),
            nn.Tanh(),
            nn.Linear(self.n_embd, self.n_embd))
        self.get_prompt = self.get_prompt_p5
        self.dropout = nn.Dropout(self.prefix_dropout)

        # Report the number of parameters registered on this module.
        total_param = sum(param.numel() for param in self.parameters())
        print('Total parameters is {}'.format(total_param))

    @classmethod
    def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
        """Create a model from a base model and, optionally, a full checkpoint.

        Args:
            args: Namespace-like object. Reads ``prefix_model_name_or_path``
                (key of ``_MODELS``, URL or directory of the base model),
                ``do_train`` and ``model_name_or_path`` (optional checkpoint
                file). When ``do_train`` is set, every attribute whose name
                matches a key of the ``story``, ``optimizer`` or ``experiment``
                config section overrides that key.

        Returns:
            ``(model, config_update)`` - the model and the merged configuration.

        Raises:
            KeyError: If the checkpoint has neither a ``state_dict`` nor a
                ``model_state_dict`` entry.
        """
        path = args.prefix_model_name_or_path
        path = _MODELS[path] if path in _MODELS else path
        path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))

        config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
        if args.do_train:
            config_base = get_base_config('story')
            config_update = OmegaConf.merge(config_base, config_new)
            # Command-line values take precedence over the stored config.
            for key, val in vars(args).items():
                if key in config_update.story.keys():
                    OmegaConf.update(config_update, "story.%s" % key, val, merge=False)
                if key in config_update.optimizer.keys():
                    OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
                if key in config_update.experiment.keys():
                    OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
        else:
            config_base = get_base_config('default')
            config_update = OmegaConf.merge(config_base, config_new)

        model = cls(config_update)
        model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)
        print(model.cross_attention_idxs)

        if args.model_name_or_path:
            print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
            # Read the file once, onto the CPU so that GPU-saved checkpoints
            # also load on hosts without CUDA, then pick whichever of the two
            # known keys is present.
            checkpoint = torch.load(args.model_name_or_path, map_location=torch.device('cpu'))
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint['model_state_dict']
            model.load_state_dict(state_dict)
        else:
            print("Loading models from checkpoint %s" % path)
            model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
            model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))

        return model, config_update

    def init_cross_attention(self, cross_attention_layers, hparams):
        """Create one ``CrossAttentionLayer`` per entry of ``cross_attention_layers``.

        Args:
            cross_attention_layers: Indices stored as ``self.cross_attention_idxs``
                and forwarded to stage 2 alongside the layers.
            hparams: Stage-2 hyper-parameters supplying the layer sizes.
        """
        self.cross_attention_idxs = cross_attention_layers
        # NOTE(review): the layers live in a plain Python list, which nn.Module
        # does not register as sub-modules (unlike nn.ModuleList). Confirm that
        # callers handle device placement, optimisation and checkpointing of
        # these layers; switching containers would change the state-dict keys.
        self.cross_attention_layers = [CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
                                                           embed_dim=hparams.embed_dim,
                                                           n_heads=hparams.n_heads,
                                                           attn_bias=hparams.attn_bias,
                                                           resid_pdrop=hparams.resid_pdrop,
                                                           attn_pdrop=hparams.attn_pdrop)
                                       for _ in cross_attention_layers]

    def get_prompt_p5(self, bsz=None, eval=False):
        """Return the learned prompt expanded to ``bsz`` rows.

        Args:
            bsz: Number of rows the prompt indices are expanded to.
            eval: When true, the prefix dropout is skipped.

        Returns:
            Output of ``control_trans`` applied to the prompt embeddings, i.e.
            a tensor of shape ``(bsz, preseqlen, n_embd)``.
        """
        input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
        temp_control = self.wte(input_tokens)
        past_key_values = self.control_trans(temp_control)
        if not eval:
            past_key_values = self.dropout(past_key_values)
        return past_key_values

    def forward(self,
                images: torch.FloatTensor,
                src_images: Optional[torch.FloatTensor],
                texts: Optional[torch.LongTensor],
                **kwargs,
        ):
        """Compute image and text logits given source images and the learned prompt.

        Args:
            images: Target images with shape ``(B, C, H, W)``.
            src_images: Source (conditioning) images.
            texts: Caption token ids.
            **kwargs: Accepted and ignored.

        Returns:
            ``(logits_img, logits_txt, codes)`` where ``codes`` are the stage-1
            codes of ``images``.
        """
        # Stage 1 only provides discrete codes: no gradients, no autocast.
        with torch.no_grad():
            with autocast(enabled=False):
                codes = self.stage1.get_codes(images).detach()
                src_codes = self.stage1.get_codes(src_images).detach()

        B, C, H, W = images.shape
        prompt = self.get_prompt(bsz=B)
        pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(B, -1).to(self.device), mode='1d')

        pos_enc_tokens = get_positional_encoding(texts, mode='1d')
        codes = codes.clone().detach()
        pos_enc_code = get_positional_encoding(codes, mode='1d')
        src_codes = src_codes.clone().detach()
        src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')

        logits_img, logits_txt = self.stage2.forward_with_context(codes, texts,
                                                                  pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
                                                                  self.cross_attention_idxs, self.cross_attention_layers,
                                                                  prompt=prompt, pos_prompt=pos_enc_prompt)
        return logits_img, logits_txt, codes

    @torch.no_grad()
    def sampling(self,
                 tokens: torch.LongTensor,
                 prompt: torch.LongTensor,
                 source: torch.FloatTensor,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True,
                 labels=None) -> torch.FloatTensor:
        """Sample ``num_candidates`` images for one caption and one source image.

        Args:
            tokens: Caption, either as a string (tokenized here) or as a tensor
                of token ids.
            prompt: Learned prompt forwarded to ``sampling_conditional``.
            source: Source image(s) to condition on.
            top_k: Forwarded to ``sampling_conditional``.
            top_p: Forwarded to ``sampling_conditional``.
            softmax_temperature: Forwarded to ``sampling_conditional``.
            num_candidates: Number of images to draw; tokens and source codes
                are repeated this many times.
            device: Device the tokens and the source image are moved to.
            use_fp16: Forwarded to ``sampling_conditional``.
            labels: Not used by this method.

        Returns:
            Decoded images clamped to ``[0, 1]``.
        """
        self.stage1.eval()
        self.stage2.eval()

        if isinstance(tokens, str):
            # Encode the caption itself; the original encoded `prompt`, which
            # is the prefix tensor rather than text.
            tokens = torch.LongTensor(self.tokenizer.encode(tokens).ids)

        tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)

        tokens = tokens.to(device)
        source = source.to(device)

        pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(num_candidates, -1).to(self.device), mode='1d')

        with autocast(enabled=False):
            src_codes = self.stage1.get_codes(source).detach()
        src_codes = torch.repeat_interleave(src_codes, num_candidates, dim=0)

        codes = sampling_conditional(self.stage2,
                                     self.cross_attention_idxs,
                                     self.cross_attention_layers,
                                     tokens,
                                     src_codes,
                                     top_k=top_k,
                                     top_p=top_p,
                                     softmax_temperature=softmax_temperature,
                                     use_fp16=use_fp16,
                                     prompt=prompt,
                                     pos_prompt=pos_enc_prompt)

        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]
        return pixels

    @torch.no_grad()
    def predict_step(self, batch, batch_idx, return_images=False):
        """Sample images for every caption in a batch and return or save them.

        Args:
            batch: Pair ``(orig_images, texts)``.
            batch_idx: Index of the batch; used to number the saved files.
            return_images: When true, return the sampled images and write
                nothing to disk.

        Returns:
            When ``return_images`` is true, a list holding one NumPy array of
            sampled candidates per caption; otherwise ``None``.
        """
        orig_images, texts = batch

        # Sanity check: decode the arg-max prediction of a forward pass.
        # NOTE(review): forward() of this class takes (images, src_images, texts)
        # and sampling() takes (tokens, prompt, source, ...); the calls below
        # pass no source image and do not appear to match - confirm the
        # intended batch layout before relying on this step.
        logits_img, logits_txt, codes = self(orig_images, texts)
        pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
        bs = orig_images.shape[0]
        pred = pred.view(bs, 16, 16)  # [B, 16, 16]
        recon = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy()
        recon = np.transpose(recon, (0, 2, 3, 1))

        prompt = self.get_prompt(bsz=5, eval=True)

        images = []
        for i, t in enumerate(texts):
            # A separate name keeps `recon` intact; the original reused
            # `pixels` here and saved the last sample as the reconstruction.
            sampled = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
            images.append(np.transpose(sampled, (0, 2, 3, 1)))

        if return_images:
            return images

        save_image(orig_images, recon, './out/images/pororo_story', batch_idx + 10)
        save_image(orig_images, images, './out/images/pororo_story', batch_idx)


class StoryDalle(Dalle):
    """Base model with story block"""
    def __init__(self, config):
        """Build the base model plus the story encoder and the optional extras.

        Args:
            config: Model configuration. ``config.story`` selects the optional
                parts: ``prompt`` enables the learned prompt modules and
                ``condition`` enables the cross-attention layers.
        """
        super().__init__(config)
        print('Initializing the Conditional Dalle model')

        self.config = config
        story_cfg = config.story
        hparams = config.stage2.hparams

        # Sentence embeddings are projected to the transformer width and then
        # passed through one transformer block spanning the story.
        self.story_linear = nn.Linear(story_cfg.sent_embed, hparams.embed_dim)
        self.story_block = Block(
            ctx_len=story_cfg.story_len,
            embed_dim=hparams.embed_dim,
            n_heads=hparams.n_heads,
            mlp_bias=hparams.mlp_bias,
            attn_bias=hparams.attn_bias,
            resid_pdrop=hparams.resid_pdrop,
            attn_pdrop=hparams.attn_pdrop,
            gelu_use_approx=hparams.gelu_use_approx,
        )

        if story_cfg.prompt:
            self.n_embd = hparams.embed_dim
            self.preseqlen = story_cfg.preseqlen
            self.prefix_dropout = story_cfg.prefix_dropout

            # Prompt parametrisation: embedding table -> Linear/Tanh/Linear.
            print('[Full prompt-tuning Setting :) ]')
            self.input_tokens = torch.arange(self.preseqlen).long()
            self.wte = nn.Embedding(self.preseqlen, self.n_embd)
            self.control_trans = nn.Sequential(
                nn.Linear(self.n_embd, self.n_embd),
                nn.Tanh(),
                nn.Linear(self.n_embd, self.n_embd),
            )
            self.get_prompt = self.get_prompt_p5
            self.dropout = nn.Dropout(self.prefix_dropout)

        if story_cfg.condition:
            print('Setting up Cross-attention Layers')
            self.init_cross_attention(list(range(2, 42, 3)), hparams)

        # Report the number of parameters registered on this module.
        total_param = sum(weight.numel() for _, weight in self.named_parameters())
        print('Total parameters is {}'.format(total_param))

    @classmethod
    def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
        """Create a StoryDalle model from a base model and, optionally, a checkpoint.

        Args:
            args: Namespace-like object. Reads ``prefix_model_name_or_path``
                (key of ``_MODELS``, URL or directory of the base model) and
                ``model_name_or_path`` (optional fine-tuned checkpoint file;
                the tokenizer is then loaded from that file's directory).
                Every attribute whose name matches a key of the ``story``,
                ``optimizer`` or ``experiment`` config section overrides it.

        Returns:
            ``(model, config_update)`` - the model and the merged configuration.

        Raises:
            KeyError: If the checkpoint has neither a ``state_dict`` nor a
                ``model_state_dict`` entry.
        """
        path = args.prefix_model_name_or_path
        path = _MODELS.get(path, path)
        path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))

        config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
        config_update = OmegaConf.merge(get_base_config('story'), config_new)
        # Command-line values take precedence over the stored config.
        for key, val in vars(args).items():
            for section in ('story', 'optimizer', 'experiment'):
                if key in config_update[section].keys():
                    OmegaConf.update(config_update, "%s.%s" % (section, key), val, merge=False)

        model = cls(config_update)
        # __init__ only sets cross_attention_idxs when config.story.condition
        # is true; fall back to None so the diagnostic print cannot raise.
        cross_attention_idxs = getattr(model, 'cross_attention_idxs', None)

        if args.model_name_or_path:
            model_dir = os.path.dirname(args.model_name_or_path)
            print(model_dir)
            model.tokenizer = build_tokenizer(model_dir,
                                              context_length=model.config_dataset.context_length,
                                              lowercase=True,
                                              dropout=None)
            print("Loaded tokenizer from finetuned checkpoint")
            print(cross_attention_idxs)
            print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)

            # Deserialize the file once and pick whichever of the two known
            # keys is present, instead of re-reading it on the fallback path.
            checkpoint = torch.load(args.model_name_or_path, map_location=torch.device('cpu'))
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint['model_state_dict']
            model.load_state_dict(state_dict)
        else:
            print(cross_attention_idxs)
            print("Loading models from checkpoint %s" % path)
            model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
            model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))

            model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
                                              context_length=model.config_dataset.context_length,
                                              lowercase=True,
                                              dropout=None)

        return model, config_update


    def init_cross_attention(self, cross_attention_layers, hparams):
        """Create one ``CrossAttentionLayer`` per entry of ``cross_attention_layers``.

        Args:
            cross_attention_layers: Indices stored as ``self.cross_attention_idxs``.
            hparams: Stage-2 hyper-parameters supplying the layer sizes.
        """
        self.cross_attention_idxs = cross_attention_layers

        # NOTE(review): the layers are kept in a plain Python list, which
        # nn.Module does not register as sub-modules - confirm that callers
        # handle device placement and checkpointing of these layers.
        context_length = hparams.ctx_len_img + hparams.ctx_len_txt
        layers = []
        for _ in cross_attention_layers:
            layers.append(
                CrossAttentionLayer(
                    ctx_len=context_length,
                    embed_dim=hparams.embed_dim,
                    n_heads=hparams.n_heads,
                    attn_bias=hparams.attn_bias,
                    resid_pdrop=hparams.resid_pdrop,
                    attn_pdrop=hparams.attn_pdrop,
                )
            )
        self.cross_attention_layers = layers


    def get_prompt_p5(self, bsz=None, eval=False):
        """Return the learned prompt expanded to ``bsz`` rows.

        Args:
            bsz: Number of rows the prompt indices are expanded to.
            eval: When true, the prefix dropout is skipped.

        Returns:
            Output of ``control_trans`` applied to the prompt embeddings,
            with dropout applied unless ``eval`` is true.
        """
        prompt_ids = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
        prompt = self.control_trans(self.wte(prompt_ids))
        if eval:
            return prompt
        return self.dropout(prompt)


    def forward(self,
                images: torch.FloatTensor,
                src_images: Optional[torch.FloatTensor],
                texts: Optional[torch.LongTensor],
                sent_embeds: Optional[torch.FloatTensor],
                **kwargs,
        ):
        """Compute image and text logits for every frame of a batch of stories.

        Args:
            images: Target frames with shape ``(B, L, C, H, W)``.
            src_images: One source image per story; it is repeated ``L`` times
                so that each frame is paired with it.
            texts: Caption token ids, reshaped to ``(B * L, -1)``.
            sent_embeds: Sentence embeddings fed through ``story_linear`` and
                ``story_block``; each frame receives one resulting vector as
                an extra prompt position.
            **kwargs: Accepted and ignored.

        Returns:
            ``(logits_img, logits_txt, codes)`` where ``codes`` are the stage-1
            codes of the flattened frames.
        """
        n_stories, n_frames, channels, height, width = images.shape
        flat_bs = n_stories * n_frames

        # Flatten stories into a batch of individual frames.
        images = images.view(flat_bs, channels, height, width)
        src_images = (src_images.unsqueeze(1)
                      .expand(-1, n_frames, -1, -1, -1)
                      .reshape(flat_bs, channels, height, width))
        story_ctx = self.story_block(self.story_linear(sent_embeds))
        sent_embeds = story_ctx.view(flat_bs, -1).unsqueeze(1)
        texts = texts.view(flat_bs, -1)

        # Stage 1 only provides discrete codes, so no gradients are tracked.
        with torch.no_grad():
            codes = self.stage1.get_codes(images).detach()
            src_codes = self.stage1.get_codes(src_images).detach()

        use_prompt = self.config.story.prompt
        if use_prompt:
            # Learned prompt first, story context as the final position.
            prompt = torch.cat([self.get_prompt(bsz=flat_bs), sent_embeds], dim=1)
        else:
            prompt = sent_embeds

        prompt_positions = torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(flat_bs, -1)
        pos_enc_prompt = get_positional_encoding(prompt_positions.to(self.device), mode='1d')

        pos_enc_tokens = get_positional_encoding(texts, mode='1d')
        codes = codes.clone().detach()
        pos_enc_code = get_positional_encoding(codes, mode='1d')
        src_codes = src_codes.clone().detach()
        src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')

        if not self.config.story.condition:
            logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens,
                                                 prompt=prompt, pos_prompt=pos_enc_prompt)
            return logits_img, logits_txt, codes

        logits_img, logits_txt = self.stage2.forward_with_context(
            codes, texts,
            pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
            self.cross_attention_idxs, self.cross_attention_layers,
            prompt=prompt, pos_prompt=pos_enc_prompt)
        return logits_img, logits_txt, codes


    @torch.no_grad()
    def sampling(self,
                 tokens: torch.LongTensor,
                 source: torch.FloatTensor,
                 sent_embeds: torch.FloatTensor,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True,
                 labels=None,
                 prompt = None) -> torch.FloatTensor:
        """Sample one image per story frame.

        Args:
            tokens: Captions, as a string (tokenized here) or as token ids.
            source: Source image(s); their codes are repeated ``story_len`` times.
            sent_embeds: Sentence embeddings with shape ``(B, L, dim)``.
            top_k: Forwarded to the sampler.
            top_p: Forwarded to the sampler.
            softmax_temperature: Forwarded to the sampler.
            num_candidates: Not used by this method.
            device: Not used by this method.
            use_fp16: Forwarded to the sampler.
            labels: Not used by this method.
            prompt: Optional learned prompt; the story context is appended to
                it along dimension 1.

        Returns:
            Decoded images clamped to ``[0, 1]``, ``story_len`` of them.
        """
        self.stage1.eval()
        self.stage2.eval()

        if type(tokens) is str:
            encoded = self.tokenizer.encode(tokens)
            tokens = torch.LongTensor(encoded.ids)

        n_stories, n_frames, _ = sent_embeds.shape
        flat_bs = n_stories * n_frames
        story_ctx = self.story_block(self.story_linear(sent_embeds)).view(flat_bs, -1).unsqueeze(1)
        prompt = story_ctx if prompt is None else torch.cat([prompt, story_ctx], dim=1)

        prompt_positions = torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(flat_bs, -1)
        pos_enc_prompt = get_positional_encoding(prompt_positions.to(self.device), mode='1d')

        story_len = self.config.story.story_len
        src_codes = self.stage1.get_codes(source).detach()
        src_codes = torch.repeat_interleave(src_codes, story_len, dim=0)
        print(tokens.shape, src_codes.shape, prompt.shape)

        # Keyword arguments shared by both samplers.
        sampler_kwargs = dict(top_k=top_k,
                              top_p=top_p,
                              softmax_temperature=softmax_temperature,
                              use_fp16=use_fp16,
                              prompt=prompt,
                              pos_prompt=pos_enc_prompt)
        if self.config.story.condition:
            codes = sampling_conditional(self.stage2,
                                         self.cross_attention_idxs,
                                         self.cross_attention_layers,
                                         tokens,
                                         src_codes,
                                         **sampler_kwargs)
        else:
            codes = sampling(self.stage2, tokens, **sampler_kwargs)

        codes = codes.view(story_len, 16, 16)  # [B, 16, 16]
        return torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]


    @torch.no_grad()
    def sampling_batch(self,
                 tokens: torch.LongTensor,
                 source: torch.FloatTensor,
                 sent_embeds: torch.FloatTensor,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 device: str = 'cuda:0',
                 use_fp16: bool = True,
                 prompt=None, n_candidates=1) -> torch.FloatTensor:
        """Sample ``n_candidates`` complete stories in a single batch.

        Args:
            tokens: Captions, as a string (tokenized here) or as token ids.
            source: Source image(s); their codes are repeated
                ``story_len * n_candidates`` times.
            sent_embeds: Sentence embeddings with shape ``(B, L, dim)``.
            top_k: Forwarded to the sampler.
            top_p: Forwarded to the sampler.
            softmax_temperature: Forwarded to the sampler.
            device: Not used by this method.
            use_fp16: Forwarded to the sampler.
            prompt: Optional learned prompt; the story context is appended to
                it along dimension 1.
            n_candidates: Number of stories to sample.

        Returns:
            Decoded images clamped to ``[0, 1]`` and reshaped to
            ``(n_candidates, story_len, C, H, W)``.
        """
        self.stage1.eval()
        self.stage2.eval()

        if type(tokens) is str:
            encoded = self.tokenizer.encode(tokens)
            tokens = torch.LongTensor(encoded.ids)

        n_stories, n_frames, _ = sent_embeds.shape
        flat_bs = n_stories * n_frames
        story_ctx = self.story_block(self.story_linear(sent_embeds)).view(flat_bs, -1).unsqueeze(1)
        prompt = story_ctx if prompt is None else torch.cat([prompt, story_ctx], dim=1)

        prompt_positions = torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(flat_bs, -1)
        pos_enc_prompt = get_positional_encoding(prompt_positions.to(tokens.device), mode='1d')

        src_codes = self.stage1.get_codes(source).detach()

        story_len = self.config.story.story_len
        n_images = story_len * n_candidates

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            # Repeat every input so that it covers all candidates and frames.
            src_codes = torch.repeat_interleave(src_codes, n_images, dim=0)
            prompt = prompt.repeat(n_candidates, 1, 1)
            pos_enc_prompt = pos_enc_prompt.repeat(n_candidates, 1)
            tokens = tokens.repeat(n_candidates, 1)
            print(tokens.shape, src_codes.shape, prompt.shape, pos_enc_prompt.shape)

            # Keyword arguments shared by both samplers.
            sampler_kwargs = dict(top_k=top_k,
                                  top_p=top_p,
                                  softmax_temperature=softmax_temperature,
                                  use_fp16=use_fp16,
                                  prompt=prompt,
                                  pos_prompt=pos_enc_prompt)
            if self.config.story.condition:
                codes = sampling_conditional(self.stage2,
                                             self.cross_attention_idxs,
                                             self.cross_attention_layers,
                                             tokens,
                                             src_codes,
                                             **sampler_kwargs)
            else:
                codes = sampling(self.stage2, tokens, **sampler_kwargs)

            codes = codes.view(n_images, 16, 16)  # [B, 16, 16]
            print(codes.shape)

        # Decoding happens outside the autocast region.
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 3, 256, 256]
        print(pixels.shape)
        return pixels.view(n_candidates, story_len, *pixels.shape[-3:])


    @torch.no_grad()
    def predict_step(self, batch, batch_idx, return_images=False):
        """Sample images for every text in ``batch`` and return or save them.

        The batch is first run through the model itself; the argmax over the
        image logits is decoded by stage 1 into a reconstruction that is only
        written out as a sanity check.  Afterwards ``self.sampling`` is called
        once per text, sharing a single prompt obtained from
        ``self.get_prompt``.

        Args:
            batch: Pair ``(orig_images, texts)``.
            batch_idx: Index of the batch; used to number the saved outputs.
            return_images: When true, return the sampled images and save
                nothing.

        Returns:
            If ``return_images`` is true, a list holding one NumPy array per
            text (the sampler output with axes permuted to ``(0, 2, 3, 1)``).
            Otherwise ``None``; the results are passed to ``save_image``.
        """
        # Candidates drawn per text; the prompt is built for the same count.
        num_candidates = 5

        orig_images, texts = batch

        # Forward pass on the batch; `codes` is reused below as sampling labels.
        logits_img, _, codes = self(orig_images, texts)
        bs = orig_images.shape[0]
        pred = torch.argmax(logits_img, dim=-1).view(bs, 16, 16)  # [B, 16, 16]

        # Reconstruction from the argmax codes, rescaled from [-1, 1] to [0, 1].
        # Kept in its own variable: the sampling loop below must not overwrite
        # it, otherwise the check image saved at `batch_idx + 10` would be the
        # last sampled candidates instead of the reconstruction.
        recon = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy()
        recon = np.transpose(recon, (0, 2, 3, 1))  # presumably channels-last for saving

        prompt = self.get_prompt(bsz=num_candidates, eval=True)

        images = []
        for i, text in enumerate(texts):
            sampled = self.sampling(text, prompt, top_k=64, num_candidates=num_candidates,
                                    labels=codes[i]).cpu().numpy()
            images.append(np.transpose(sampled, (0, 2, 3, 1)))

        if return_images:
            return images

        # Offset the reconstruction's index so it does not collide with the samples.
        save_image(orig_images, recon, './out/images/pororo_story', batch_idx + 10)
        save_image(orig_images, images, './out/images/pororo_story', batch_idx)