import torch
from torch import Tensor

from transformer import Transformer
from tokenizers import Tokenizer
from dataset import causal_mask


def greedy_decode(
    model: Transformer, 
    src: Tensor, 
    src_mask: Tensor, 
    src_tokenizer: Tokenizer, 
    tgt_tokenizer: Tokenizer, 
    tgt_max_seq_len: int, 
    device,
    give_attn: bool = False,
):
    """
    Decodes greedily.
    """
    sos_idx = src_tokenizer.token_to_id('<sos>')
    eos_idx = src_tokenizer.token_to_id('<eos>')

    encoder_output = model.encode(src, src_mask)

    attn = None
    # Start the decoder input with the '<sos>' token
    decoder_input = torch.full((1, 1), sos_idx).type_as(src).to(device)
    
    # Generate tokens until '<eos>' or the maximum target length is reached
    while decoder_input.size(1) < tgt_max_seq_len:

        # build target mask
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(src).to(device)

        # get decoder output
        decoder_output, attn = model.decode(encoder_output, src_mask, decoder_input, decoder_mask)

        # project the last decoder position onto the vocabulary and greedily
        # pick the most likely next token
        prob = model.project(decoder_output[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.full((1, 1), next_word.item()).type_as(src).to(device)], dim=1
        )

        # Stop as soon as the end-of-sequence token is produced
        if next_word.item() == eos_idx:
            break

    if give_attn:
        return (decoder_input.squeeze(0), attn)
    return decoder_input.squeeze(0)
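
# A minimal usage sketch for greedy_decode (commented out so importing this
# module has no side effects). The tokenizer file names, the '<pad>' token,
# the max length of 350, and how `model` is built are illustrative
# assumptions, not something this module defines:
#
#   src_tok = Tokenizer.from_file("tokenizer_src.json")
#   tgt_tok = Tokenizer.from_file("tokenizer_tgt.json")
#   src_ids = torch.tensor([src_tok.encode("hello world").ids], device=device)
#   pad_idx = src_tok.token_to_id('<pad>')
#   src_mask = (src_ids != pad_idx).unsqueeze(0).unsqueeze(0).int()
#   out_ids = greedy_decode(model, src_ids, src_mask, src_tok, tgt_tok, 350, device)
#   print(tgt_tok.decode(out_ids.detach().cpu().tolist()))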


def beam_search_decode(
    model: Transformer, 
    src: Tensor, 
    src_mask: Tensor, 
    src_tokenizer: Tokenizer, 
    tgt_tokenizer: Tokenizer, 
    tgt_max_seq_len: int, 
    device,
    beam_size: int = 3,
):
    """
    Decodes with beam search: at each step, keep the beam_size highest-scoring
    partial sequences, where a sequence's score is the sum of the
    log-probabilities of its tokens. Returns the best sequence's token ids.
    """
    sos_idx = tgt_tokenizer.token_to_id('<sos>')
    eos_idx = tgt_tokenizer.token_to_id('<eos>')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(src, src_mask)
    # Initialize the decoder input with the sos token
    decoder_initial_input = torch.empty(1,1).fill_(sos_idx).type_as(src).to(device)

    # Each candidate is a (sequence, score) pair; scores are summed
    # log-probabilities, so the initial (empty) prefix scores 0.0
    candidates = [(decoder_initial_input, 0.0)]

    while True:

        # Stop once any candidate has reached the maximum target length,
        # i.e. the search has run for tgt_max_seq_len - 1 steps
        if any(cand.size(1) == tgt_max_seq_len for cand, _ in candidates):
            break

        # Create a new list of candidates
        new_candidates = []

        for candidate, score in candidates:

            # Candidates that have already produced '<eos>' are kept as-is
            # rather than being expanded further
            if candidate[0, -1].item() == eos_idx:
                new_candidates.append((candidate, score))
                continue

            # Build the candidate's mask
            candidate_mask = causal_mask(candidate.size(1)).type_as(src_mask).to(device)
            # run one decoder step for this candidate
            out, _ = model.decode(encoder_output, src_mask, candidate, candidate_mask)
            # get next-token log-probabilities from the last decoder position
            prob = model.project(out[:, -1])
            # keep only the beam_size most likely continuations
            topk_prob, topk_idx = torch.topk(prob, beam_size, dim=1)
            for i in range(beam_size):
                # token id and log-probability of the i-th best continuation
                token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)
                token_prob = topk_prob[0][i].item()
                # extend the current candidate with this token
                new_candidate = torch.cat([candidate, token], dim=1)
                # scores are summed because they are log-probabilities
                new_candidates.append((new_candidate, score + token_prob))

        # Sort the new candidates by their score
        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
        # Keep only the top k candidates
        candidates = candidates[:beam_size]

        # If all the candidates have reached the eos token, stop
        if all(cand[0, -1].item() == eos_idx for cand, _ in candidates):
            break

    # Return the best candidate
    return candidates[0][0].squeeze()
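

# Hedged sketch, not used by the functions above: summing log-probabilities
# biases beam search toward short sequences, and a common fix is to rescore
# candidates with a length penalty (the GNMT-style penalty is shown here).
# The alpha default is an illustrative assumption.
def length_normalized_score(log_prob_sum: float, length: int, alpha: float = 0.6) -> float:
    # GNMT length penalty: ((5 + length) / 6) ** alpha
    penalty = ((5 + length) / 6) ** alpha
    return log_prob_sum / penalty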