import logging
from bisect import bisect_left
from collections import OrderedDict

import cv2
import numpy as np
import torch

from m4.training.utils import FAKE_TOKEN_AROUND_IMAGE_V2, IMAGE_TOKEN, _convert_to_rgb


logger = logging.getLogger(__name__)


# Hyper-parameters
# Bonus added to the sampling score of every token that precedes an image token (see `split_pack_and_pad`).
_IMAGE_BONUS_VALUE = 2
# Minimum length (measured in number of tokens) of a document to be kept for packing; shorter ones are dropped.
# The last `_MIN_LENGTH_DOCUMENTS_TO_PACK` positions of a long document are also never used as sampling start indices.
_MIN_LENGTH_DOCUMENTS_TO_PACK = 5


def incremental_to_binary_attention_mask(incremental_mask, num_classes=-1):
    """Convert an incremental image-index mask into a one-hot (binary) attention mask.

    Each entry of `incremental_mask` is the index of the image a token attends to, or -1
    for "no image". The result gains a trailing one-hot dimension, e.g.
    `[-1, 0, 1] => [[0, 0], [1, 0], [0, 1]]`.

    Args:
        incremental_mask (torch.LongTensor): Image indices, -1 meaning "no image".
            The tensor is not modified.
        num_classes (int, optional): Size of the one-hot dimension. Indices `>= num_classes`
            (words after the max number of allowed images has been seen) are treated like -1
            and attend to nothing. With the default -1, `one_hot` infers the size from the
            largest index.

    Returns:
        torch.LongTensor: Binary mask of shape `incremental_mask.shape + (num_classes,)`.
    """
    # Work on a copy so the caller's tensor is not rewritten as a side effect.
    incremental_mask = incremental_mask.clone()
    if num_classes != -1:
        incremental_mask[incremental_mask >= num_classes] = -1

    negatives = incremental_mask == -1
    # `one_hot` rejects negative indices: map them to class 0 first, then blank those rows.
    incremental_mask[negatives] = 0
    attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
    attn_mask[negatives, :] = 0
    return attn_mask


def image_attention_mask_for_packed_input_ids(input_ids, tokenizer):
    """Compute incremental image-attention masks for a batch of packed sequences.

    For every token, find the index of the image token it should attend to. Image tokens
    are numbered per row in order of appearance, starting at 0; -1 means "no image".
    Documents packed into the same row are delimited by `tokenizer.eos_token_id`.

    Args:
        input_ids (torch.Tensor): Token ids; rows are taken along dim 0 and each row is
            scanned token by token. The callers in this file pass a stacked
            `(num_sequences, max_seq_len)` tensor.
        tokenizer: Provides `convert_tokens_to_ids(IMAGE_TOKEN)` and `eos_token_id`.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: two tensors shaped like `input_ids`:
            - `image_attention_mask`: index of the closest image token at or before each
              position. Tokens before the first image of the row get -1. Tokens after an eos
              token get -1 until the next image token. The eos token itself keeps the
              previous image's index.
            - `next_image_attention_mask`: index of the closest image token at or after each
              position. Tokens after the last image of the row get -1. Tokens between the
              last image of a document and its eos token (eos included) also get -1.
    """
    image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    eod_token_id = tokenizer.eos_token_id
    # Forward pass: each token gets the index of the latest image seen so far.
    for batch_idx in range(input_ids.size(0)):
        count = -1
        seen_eod = False
        for idx, token_id in enumerate(input_ids[batch_idx]):
            if token_id == image_token_id:
                count += 1
                image_attention_mask[batch_idx][idx] = count
                seen_eod = False
            else:
                image_attention_mask[batch_idx][idx] = count

            # Checked before `seen_eod` is updated, so the eos token itself keeps `count`.
            if seen_eod:
                image_attention_mask[batch_idx][idx] = -1

            if token_id == eod_token_id:
                seen_eod = True

    # Backward pass: each token gets the index (counted from the END of the row) of the next image.
    for batch_idx in range(input_ids.size(0)):
        count = -1
        seen_eod = False
        for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
            token_id = input_ids[batch_idx][idx]
            if token_id == image_token_id:
                count += 1
                next_image_attention_mask[batch_idx][idx] = count
                seen_eod = False
            else:
                next_image_attention_mask[batch_idx][idx] = count

            # Unlike the forward pass, `seen_eod` is set before the check, so the eos token gets -1 too.
            if token_id == eod_token_id:
                seen_eod = True

            if seen_eod:
                next_image_attention_mask[batch_idx][idx] = -1

        # Re-number from the start of the row: `count` now equals (number of image tokens in the row) - 1,
        # so an index k counted from the end becomes `count - k` counted from the start.
        non_negative_indices = next_image_attention_mask[batch_idx] != -1
        next_image_attention_mask[batch_idx][non_negative_indices] -= count
        next_image_attention_mask[batch_idx][non_negative_indices] *= -1

    return image_attention_mask, next_image_attention_mask


def laplacian_blur_detection(image, threshold=0.0):
    """Flag a 3-channel image as blurred when the variance of its Laplacian is below `threshold`.

    The variance of the Laplacian of the grayscale image is used as a focus measure.
    Few edges give a low variance, which suggests a blurry image.

    Args:
        image: Anything `np.array` accepts (e.g. a PIL image or a numpy array).
        threshold (float): Focus-measure cut-off. `0.0` disables the check.

    Returns:
        bool: `False` when the check is disabled or the image is not an `(H, W, 3)` array.
        Grayscale and other layouts are never flagged. Otherwise, whether the Laplacian
        variance is below `threshold`.
    """
    if threshold == 0.0:
        return False

    pixels = np.array(image)
    has_three_channels = pixels.ndim == 3 and pixels.shape[2] == 3
    if not has_three_channels:
        # Don't remove grayscale (or any non 3-channel) images.
        return False

    gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
    focus_measure = cv2.Laplacian(gray, cv2.CV_64F).var()
    return focus_measure < threshold


def fft_blur_detection(image, size=50, threshold=0.0):
    """Return whether `image` looks blurry, judging by its high-frequency content.

    The image is moved to the frequency domain. A `2*size x 2*size` window of low
    frequencies around the centre of the shifted spectrum is zeroed, and the image is
    reconstructed from the remaining frequencies. The mean log-magnitude (`20 * log`) of
    that reconstruction is a sharpness score: little high-frequency energy (a low score)
    means a blurry image.

    Args:
        image: Anything `np.asarray` accepts: a numpy array, nested lists, or an object
            exposing the array interface such as a PIL image. 2D inputs are used as-is.
            `(H, W, C)` inputs are converted to grayscale.
        size (int): Half-width of the low-frequency window removed from the spectrum.
        threshold (float): Score below which the image is considered blurred.
            `0.0` disables the check.

    Returns:
        bool: `False` when the check is disabled, otherwise whether the score is below `threshold`.
    """
    if threshold == 0.0:
        return False

    # Callers pass PIL / RGB images, which have no `.shape` or are 3-D: normalise to a 2D float array.
    gray = np.asarray(image, dtype=np.float64)
    if gray.ndim == 3:
        if gray.shape[2] >= 3:
            # ITU-R BT.601 luma weights (as in cv2.COLOR_RGB2GRAY); extra channels such as alpha are ignored.
            gray = gray[..., :3] @ np.array([0.299, 0.587, 0.114])
        else:
            gray = gray.mean(axis=2)

    h, w = gray.shape
    c_y, c_x = h // 2, w // 2
    spectrum = np.fft.fftshift(np.fft.fft2(gray))
    # Clamp at 0: a negative start index would wrap around and zero the wrong region.
    spectrum[max(c_y - size, 0) : c_y + size, max(c_x - size, 0) : c_x + size] = 0
    recon = np.fft.ifft2(np.fft.ifftshift(spectrum))
    # Exact zeros (e.g. a flat image) give -inf, i.e. "maximally blurred"; only the warning is silenced.
    with np.errstate(divide="ignore"):
        magnitude = 20 * np.log(np.abs(recon))
    return np.mean(magnitude) < threshold


def _split_on_end_of_document(raw_images, raw_texts):
    """Split one interleaved document at every `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED` marker.

    A text containing the marker closes the current sub-document with the part before the
    first marker. The part after the last marker (if non-empty) opens the next sub-document.
    Anything between two markers inside the same text is dropped. Every closed sub-document
    gets a trailing `None` image so that images and texts stay aligned.

    Args:
        raw_images (list): Images (or `None`) of the document, aligned with `raw_texts`.
        raw_texts (list): Texts (or `None`) of the document.

    Returns:
        Tuple[list, list]: Parallel lists of per-sub-document images and texts. Sub-documents
        without any image are dropped. When no marker is present, the inputs are returned as
        the only sub-document. The caller's lists are never modified.
    """
    marker = "END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED"
    split_indices = [i for i, text in enumerate(raw_texts) if isinstance(text, str) and marker in text]
    if not split_indices:
        return [raw_images], [raw_texts]

    # Texts holding a marker are rewritten below: copy so the input sample is left intact.
    raw_texts = list(raw_texts)
    sub_docs_images, sub_docs_texts = [], []
    previous_i = 0
    for i in split_indices:
        pieces = raw_texts[i].split(marker)
        head, tail = pieces[0].strip(), pieces[-1].strip()

        sub_doc_images = raw_images[previous_i:i] + [None]
        sub_doc_texts = raw_texts[previous_i:i] + [head]
        # An image-less sub-document (all images in raw_images[previous_i:i] are None) is dropped.
        # We still advance past it below; otherwise its texts, including the raw marker, would leak
        # into the next sub-document.
        if any(sub_doc_images):
            sub_docs_images.append(sub_doc_images)
            sub_docs_texts.append(sub_doc_texts)

        if tail == "":
            previous_i = i + 1
        else:
            # The text after the marker opens the next sub-document.
            raw_texts[i] = tail
            previous_i = i

    if previous_i < len(raw_images) and any(raw_images[previous_i:]):
        sub_docs_images.append(raw_images[previous_i:])
        sub_docs_texts.append(raw_texts[previous_i:])

    return sub_docs_images, sub_docs_texts


def split_pack_and_pad(
    sample,
    tokenizer,
    max_seq_len,
    image_transform,
    max_num_images,
    max_num_samples_per_document=10,
    prefix_seed=(0, 0),
    is_blurred_fn=None,
    blur_threshold=0.0,
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
    max_num_images_per_document=None,
):
    """
    Return a batch of samples in the format expected by the model.

    Each document is turned into one token sequence in which every image is represented as
    `FAKE_TOKEN_AROUND_IMAGE_V2 IMAGE_TOKEN FAKE_TOKEN_AROUND_IMAGE_V2`. Documents are first
    split on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED` markers.

    - Documents longer than `max_seq_len`: we sample sub-sequences of `max_seq_len` tokens.
      At most `min(max_num_samples_per_document, len(document) // max_seq_len)` are kept, and
      only those containing at least one image. Start indices are skewed towards tokens that
      precede images. For each kept sub-sequence we take at most `max_num_images` of its
      images, zero-padded up to `max_num_images`.
    - Shorter documents, and sampled sub-sequences that run past the end of their document,
      are concatenated and packed into chunks of `max_seq_len` tokens. The last chunk is
      padded with `tokenizer.pad_token_id`.

    Args:
        sample (Dict): A sample object containing the document with images and text: "texts"
            plus either "images" (raw images) or "image_embeddings". Each entry is a list
            aligned element-wise with the others.
        tokenizer (PretrainedTokenizer): Text tokenizer to be used.
        max_seq_len (int): Maximum sequence length of the returned text tokens.
        image_transform (Callable): Transform to be applied on the raw images.
        max_num_images (int): Maximum number of images to be sampled per sample. If less, they are padded with zeros.
        max_num_samples_per_document (int, optional): Maximum number of samples per document to be sampled. Defaults to 10.
        prefix_seed: Prefix seed sequence for "reproducible randomness" in calls to `np.random.choice`.
        is_blurred_fn (Callable, optional): Called as `is_blurred_fn(rgb_image, threshold=blur_threshold)`
            on raw images; defaults to `fft_blur_detection`.
        blur_threshold (float, optional): The blur filter only runs on raw images and only when this is > 0.
            A document is dropped if any of its images is flagged.
        add_begin_of_doc_token (bool, optional): Prepend `tokenizer.bos_token_id` to each document.
        add_end_of_doc_token (bool, optional): Append `tokenizer.eos_token_id` to each document.
        max_num_images_per_document (int, optional): Documents with more non-`None` images are skipped.

    Returns:
        Dict[str, torch.Tensor]: `input_ids` and `attention_mask` of shape `(N, max_seq_len)`;
        binary `image_attention_mask` and `next_image_attention_mask` of shape
        `(N, max_seq_len, max_num_images)`; `num_images` and `num_text_tokens` of shape `(N,)`;
        and `pixel_values` (raw images) or `image_embeddings` of shape `(N, max_num_images, ...)`.
        All values are empty tensors when no sample could be built.

    Raises:
        ValueError: If the sample has no images, or if splitting a document misaligns images and texts.
    """
    text_batch = sample["texts"]

    image_batch = sample.get("image_embeddings", None)
    is_raw_images = False
    if image_batch is None:
        image_batch = sample.get("images", None)
        is_raw_images = True
    if image_batch is None:
        raise ValueError("Either image_embeddings or images must be present in the sample")

    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

    if is_blurred_fn is None:
        is_blurred_fn = fft_blur_detection

    all_images = []
    all_texts = []
    for raw_images, raw_texts in zip(image_batch, text_batch):
        # Filter ones that don't have at least one image and one text word
        if not any(raw_images) or not any(raw_texts):
            continue

        if max_num_images_per_document:
            num_images = sum(1 for image in raw_images if image is not None)
            if num_images > max_num_images_per_document:
                continue

        if is_raw_images and blur_threshold > 0.0:
            # Drop the whole document as soon as one of its images is flagged as blurred.
            if any(
                is_blurred_fn(_convert_to_rgb(image), threshold=blur_threshold)
                for image in raw_images
                if image is not None
            ):
                continue

        splitted_raw_images, splitted_raw_texts = _split_on_end_of_document(raw_images, raw_texts)

        # Sanity check
        if [len(ims) for ims in splitted_raw_images] != [len(txts) for txts in splitted_raw_texts]:
            raise ValueError(
                "Number of images and texts don't match after splitting on `END_OF_DOCUMENT_TOKEN_TO_BE_REPLACED`."
                " Something core went wrong during the splitting and needs to be fixed."
            )

        for s_r_ims, s_r_txts in zip(splitted_raw_images, splitted_raw_texts):
            images, web_text = [], ""
            # Reset per sub-document: a previous document ending with an image must not make this
            # one start with a stray closing fake token.
            last_was_image = False
            for image, text in zip(s_r_ims, s_r_txts):
                if text is None and image is None:
                    continue

                if image is not None:
                    web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}"
                    if is_raw_images:
                        images.append(image_transform(image))
                    else:
                        images.append(torch.tensor(image))
                    last_was_image = True
                elif text is not None:
                    if last_was_image:
                        # Close the image with its trailing fake token.
                        web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}{text}"
                        last_was_image = False
                    else:
                        web_text += f" {text}" if web_text != "" else text

            if last_was_image:
                web_text += f"{FAKE_TOKEN_AROUND_IMAGE_V2}"

            web_text = web_text.strip(" ")

            # This is mostly a sanity check. Cases like that should not happen at that point.
            if web_text == "" or len(images) == 0:
                continue

            images = torch.stack(images)
            all_images.append(images)

            web_text_ids = tokenizer.encode(web_text, add_special_tokens=False)
            if add_end_of_doc_token:
                web_text_ids += [tokenizer.eos_token_id]

            if add_begin_of_doc_token:
                web_text_ids = [tokenizer.bos_token_id] + web_text_ids
            all_texts.append(web_text_ids)

    output_input_ids = []
    output_images = []
    output_attention_masks = []
    output_num_images = []
    output_num_text_tokens = []

    # Documents (and document tails) shorter than `max_seq_len` are concatenated here and chunked below.
    input_ids_to_pack = []
    images_to_pack = []
    for images, text in zip(all_images, all_texts):
        if len(text) <= max_seq_len:
            if len(text) < _MIN_LENGTH_DOCUMENTS_TO_PACK:  # Filter out extremely short sequences
                continue
            input_ids_to_pack.extend(text)
            images_to_pack.extend(images)
            continue

        # Skew the start-index sampling towards tokens preceding images.
        # Every token starts with a score of 1 (uniform sampling). Each image adds `_IMAGE_BONUS_VALUE`
        # to the `max_seq_len` tokens preceding it (and to itself); bonuses from several images add up.
        # The scores are then L1-normalised into a probability distribution.
        # A sub-sequence contains a given image only if its start lies in [image_idx - max_seq_len + 1, image_idx].
        # Boosting only the right part of that interval would put images at the beginning of the
        # sub-sequences. Boosting only the left part would put them at the end. Boosting only the middle
        # would put them in the middle. Boosting the whole interval keeps all these placements (and hence
        # the p_next=0 / p_next=1 cases) equally represented.
        # Bonuses are not clipped, so documents with many images get strongly skewed distributions.
        all_scores = np.array([1] * len(text))
        for img_token_idx in np.where(np.array(text) == image_token_id)[0]:
            all_scores[max(0, img_token_idx - max_seq_len) : img_token_idx + 1] += _IMAGE_BONUS_VALUE
        # The last `_MIN_LENGTH_DOCUMENTS_TO_PACK` positions are never used as start indices.
        all_scores = all_scores[:-_MIN_LENGTH_DOCUMENTS_TO_PACK]

        # The number of samples is proportional to the length of the text and inversely proportional to the
        # maximum sequence length.
        max_num_samples_for_curr_document = len(text) // max_seq_len
        # "Reproducible randomness": the generator is seeded with `prefix_seed` plus the document length.
        choices = np.random.default_rng(seed=list(prefix_seed) + [len(text)]).choice(
            range(len(text) - _MIN_LENGTH_DOCUMENTS_TO_PACK),  # shorter sub-sequences are reserved for packing
            # Sample more than necessary and stop once enough valid sub-sequences were collected.
            min(len(text) - max_seq_len, 2 * max_num_samples_per_document),
            p=all_scores / np.linalg.norm(all_scores, ord=1),
            replace=False,
        )

        nb_effective_sequences_out_of_sampling = 0
        for start_index in choices:
            image_start_index = text[:start_index].count(image_token_id)
            text_sub_sequence = text[start_index : start_index + max_seq_len]
            image_count = text_sub_sequence.count(image_token_id)
            if image_count == 0:
                # Skip if there are no images in the sequence
                continue

            if len(text_sub_sequence) < max_seq_len:
                # A sub-sequence shorter than `max_seq_len` can only come from the end of the document, so it
                # contains every image from `image_start_index` onwards: reserve it for packing.
                if image_count != len(images[image_start_index:]):
                    # A safeguard for this
                    logger.warning(
                        "Skipping this sample because of mismatch in actual number of images and "
                        "the '<image>' tokens in the text"
                    )
                    continue
                input_ids_to_pack.extend(text_sub_sequence)
                images_to_pack.extend(images[image_start_index:])
                continue

            num_kept_images = min(max_num_images, image_count)
            current_images = images[image_start_index : image_start_index + num_kept_images]
            if len(current_images) != num_kept_images:
                # A safeguard for something off about this document, maybe an `<image>` tag that was there
                # from before or some issue in parsing the image? The rest of the document is abandoned.
                logger.warning(
                    "Skipping this sample because of mismatch in actual number of images and "
                    "the '<image>' tokens in the text"
                )
                break
            padded_image_tensor = torch.zeros(max_num_images, *images.size()[1:])
            padded_image_tensor[:num_kept_images] = current_images
            output_images.append(padded_image_tensor)
            output_num_images.append(num_kept_images)

            output_input_ids.append(torch.tensor(text_sub_sequence))
            output_num_text_tokens.append(len(text_sub_sequence))
            output_attention_masks.append(torch.ones((max_seq_len,), dtype=torch.long))

            nb_effective_sequences_out_of_sampling += 1
            if nb_effective_sequences_out_of_sampling >= min(
                max_num_samples_for_curr_document, max_num_samples_per_document
            ):
                # We got all the samples we need for this document, so breaking out
                break

    # Pack the remaining sequences from `input_ids_to_pack` x `images_to_pack` into `max_seq_len` chunks.
    image_counter = 0
    for i in range(0, len(input_ids_to_pack), max_seq_len):
        current_input_ids = input_ids_to_pack[i : i + max_seq_len]
        unpadded_seq_len = len(current_input_ids)
        num_images = current_input_ids.count(image_token_id)
        if num_images == 0:
            continue
        current_images = images_to_pack[image_counter : image_counter + num_images]
        image_counter += num_images
        if unpadded_seq_len < max_seq_len:
            padded_input_ids = [tokenizer.pad_token_id] * max_seq_len
            padded_input_ids[:unpadded_seq_len] = current_input_ids
            current_input_ids = padded_input_ids
        elif unpadded_seq_len > max_seq_len:
            # This case has no purpose other than safeguard
            continue
        try:
            current_images = torch.stack(current_images)[:max_num_images]
        except RuntimeError:
            # `torch.stack` fails on an empty list (fewer images than `<image>` tokens) or mismatched shapes.
            logger.warning("Skipping this packed sequence because its images could not be stacked")
            continue
        padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
        padded_image_tensor[: current_images.size(0)] = current_images
        attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
        attention_mask[:unpadded_seq_len] = 1

        output_images.append(padded_image_tensor)
        output_input_ids.append(torch.tensor(current_input_ids))
        output_num_text_tokens.append(unpadded_seq_len)
        output_num_images.append(min(max_num_images, num_images))

        output_attention_masks.append(attention_mask)

    if len(output_images) == 0 or len(output_input_ids) == 0:
        result = {
            "input_ids": torch.tensor([], dtype=torch.long),
            "attention_mask": torch.tensor([], dtype=torch.bool),
            "image_attention_mask": torch.tensor([], dtype=torch.bool),
            "next_image_attention_mask": torch.tensor([], dtype=torch.bool),
            "num_images": torch.tensor([], dtype=torch.long),
            "num_text_tokens": torch.tensor([], dtype=torch.long),
        }
        if is_raw_images:
            result["pixel_values"] = torch.tensor([], dtype=torch.float32)
        else:
            result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
        return result

    output_input_ids = torch.stack(output_input_ids)
    output_images = torch.stack(output_images)
    output_attention_masks = torch.stack(output_attention_masks)

    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(
        output_input_ids, tokenizer
    )
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    result = {
        "input_ids": output_input_ids,
        "attention_mask": output_attention_masks,
        "image_attention_mask": image_attention_mask,
        "next_image_attention_mask": next_image_attention_mask,
        "num_images": torch.tensor(output_num_images),
        "num_text_tokens": torch.tensor(output_num_text_tokens),
    }
    if is_raw_images:
        result["pixel_values"] = output_images
    else:
        result["image_embeddings"] = output_images
    return result


def split_and_pad_pmd(
    sample,
    tokenizer,
    max_seq_len,
    image_transform,
    max_num_images,
    prefix_seed=(0, 0),
    is_blurred_fn=None,
    blur_threshold=0.0,
    prob_image_at_end=0.5,  # If 1, the <image> token is always added at the end of the text
    # If set to -1, all padding will be tolerated. If set to 0, no padding will be tolerated.
    padding_tolerance=-1,
    add_begin_of_doc_token=False,
    add_end_of_doc_token=True,
):
    """Pack PMD-style (image, caption) pairs into padded sequences of `max_seq_len` tokens.

    Each pair where neither the image nor the text is `None` becomes
    `FAKE_TOKEN_AROUND_IMAGE_V2 IMAGE_TOKEN FAKE_TOKEN_AROUND_IMAGE_V2` placed before or after
    the stripped caption. Optional bos/eos tokens are added. Whether images go at the end is
    drawn once per call, with probability `prob_image_at_end`, from a generator seeded by
    `prefix_seed`.

    Pairs are packed greedily. A sequence starts with the longest remaining pair (truncated to
    `max_seq_len`). It is then repeatedly extended with the longest remaining pair that still
    fits, found by binary search over the sorted lengths. When nothing fits and the leftover
    space exceeds `padding_tolerance` (unless it is -1), the sequence is topped up with randomly
    drawn pairs until at most `padding_tolerance` tokens remain. These pairs may repeat ones
    already used. The sequence is then truncated to `max_seq_len` and padded with
    `tokenizer.pad_token_id`.

    Args:
        sample (Dict): Has "text" and either "image_embedding" or "image" (raw images), aligned element-wise.
        tokenizer (PretrainedTokenizer): Text tokenizer to be used.
        max_seq_len (int): Length of every returned sequence.
        image_transform (Callable): Transform applied to raw images.
        max_num_images (int): Images per sequence beyond this are dropped; fewer are zero-padded.
        prefix_seed: Seed prefix for the image placement draw and for the padding fill.
        is_blurred_fn (Callable, optional): Called as `is_blurred_fn(image, threshold=blur_threshold)` on raw
            images; defaults to `fft_blur_detection`.
        blur_threshold (float, optional): Passed to `is_blurred_fn`.
        prob_image_at_end (float, optional): Probability of putting the image tokens after the caption.
        padding_tolerance (int, optional): Maximum number of padding tokens accepted before topping up with
            random pairs; -1 tolerates any amount.
        add_begin_of_doc_token (bool, optional): Prepend `tokenizer.bos_token_id` to each pair.
        add_end_of_doc_token (bool, optional): Append `tokenizer.eos_token_id` to each pair.

    Returns:
        Dict[str, torch.Tensor]: Contains:
            - `input_ids` and `attention_mask`.
            - `image_attention_mask`: the binary "next image" mask when images are placed at the end.
            - `num_images` and `num_text_tokens`.
            - either `pixel_values` (raw images) or `image_embeddings`.
        All values are empty tensors when no pair survives filtering.

    Raises:
        ValueError: If neither "image_embedding" nor "image" is present in `sample`.
    """
    if is_blurred_fn is None:
        is_blurred_fn = fft_blur_detection

    text_batch = sample["text"]
    image_batch = sample.get("image_embedding", None)
    is_raw_images = False
    if image_batch is None:
        image_batch = sample.get("image", None)
        is_raw_images = True
    if image_batch is None:
        raise ValueError("Either image_embedding or image must be present in the sample")

    def _empty_output():
        # Returned when nothing could be built from this batch.
        result = {
            "input_ids": torch.tensor([], dtype=torch.long),
            "attention_mask": torch.tensor([], dtype=torch.bool),
            "image_attention_mask": torch.tensor([], dtype=torch.bool),
            "num_images": torch.tensor([], dtype=torch.long),
            "num_text_tokens": torch.tensor([], dtype=torch.long),
        }
        if is_raw_images:
            result["pixel_values"] = torch.tensor([], dtype=torch.float32)
        else:
            result["image_embeddings"] = torch.tensor([], dtype=torch.float32)
        return result

    def _prepare_image(image):
        # Raw images go through the transform; embeddings are only converted to tensors.
        return image_transform(image) if is_raw_images else torch.tensor(image)

    # Decide once for the whole PMD batch whether the images will be at the start or at the end.
    # rng.random() is in [0, 1), so with prob_image_at_end == 1 images are always at the end.
    rng = np.random.default_rng(seed=list(prefix_seed))
    is_image_at_end = rng.random() < prob_image_at_end

    filtered_image_batch = []
    filtered_input_ids = []
    for image, text in zip(image_batch, text_batch):
        if text is None or image is None:
            continue

        if is_raw_images and is_blurred_fn(image, threshold=blur_threshold):
            continue

        image_tokens = f"{FAKE_TOKEN_AROUND_IMAGE_V2}{IMAGE_TOKEN}{FAKE_TOKEN_AROUND_IMAGE_V2}"
        # Remove trailing and leading whitespaces, including newlines and tabs
        text = text.strip()
        sample_text = f"{text}{image_tokens}" if is_image_at_end else f"{image_tokens}{text}"

        sample_input_ids = tokenizer.encode(sample_text, add_special_tokens=False)
        if add_end_of_doc_token:
            sample_input_ids += [tokenizer.eos_token_id]

        if add_begin_of_doc_token:
            sample_input_ids = [tokenizer.bos_token_id] + sample_input_ids

        filtered_image_batch.append(image)
        filtered_input_ids.append(sample_input_ids)

    if not filtered_input_ids:
        # Everything was filtered out; `zip(*sorted([]))` below would fail to unpack.
        return _empty_output()

    # Sort by length of text and group same-length elements so candidates can be retrieved later.
    # Keys are inserted in increasing order, which the binary search below relies on.
    filtered_image_batch, filtered_input_ids = zip(
        *sorted(zip(filtered_image_batch, filtered_input_ids), key=lambda x: len(x[1]))
    )
    mapping_by_len = OrderedDict()
    for image, sample_input_ids in zip(filtered_image_batch, filtered_input_ids):
        mapping_by_len.setdefault(len(sample_input_ids), []).append((image, sample_input_ids))

    all_images = []
    all_texts = []
    all_attention_masks = []
    all_num_images = []
    all_num_text_tokens = []
    current_text = []
    current_images = []

    while True:
        current_lens = list(mapping_by_len.keys())
        if len(current_text) > 0:
            # Binary search for the biggest remaining sequence that fits into the current one.
            # This uses up bigger sequences faster and leaves smaller ones to pack with each other later.
            diff = max_seq_len - len(current_text)
            if len(current_lens) == 0:
                possible_index = -1
            else:
                possible_index = bisect_left(current_lens, diff)
                if possible_index == len(current_lens) or current_lens[possible_index] != diff:
                    possible_index -= 1

            if possible_index >= 0:
                best_possible_length = current_lens[possible_index]
                image, sample_input_ids = mapping_by_len[best_possible_length].pop(0)

                # If we have used up all the samples of a certain length, remove that length from the mapping.
                if len(mapping_by_len[best_possible_length]) == 0:
                    del mapping_by_len[best_possible_length]
                current_text.extend(sample_input_ids)
                current_images.append(_prepare_image(image))
            elif diff > padding_tolerance and padding_tolerance != -1:
                # Too much padding left and no unused sequence fits: fill with random picks from the full
                # (unchanged) set of sequences, possibly overshooting `max_seq_len`.
                while diff > padding_tolerance:
                    # `prefix_seed` is the same for the whole batch, so extra state is mixed in to avoid drawing
                    # the same sequence every time padding is needed. Seed entropy must be a flat sequence of
                    # ints, hence `tuple(...)` and the scalar `sum(all_num_images)`.
                    rng = np.random.default_rng(
                        tuple(prefix_seed)
                        + (
                            diff,
                            len(current_text),
                            len(all_texts),
                            sum(all_num_images),
                        )
                    )
                    choice = rng.choice(range(len(filtered_input_ids)))
                    image, sample_input_ids = filtered_image_batch[choice], filtered_input_ids[choice]
                    current_text.extend(sample_input_ids)
                    current_images.append(_prepare_image(image))
                    diff = max_seq_len - len(current_text)
                # In the next top-level iteration this goes to the `else` clause below, which also handles
                # sequences longer than max_seq_len.
            else:
                # Finalize the current sequence: truncate / pad text and images.
                current_images = torch.stack(current_images)
                num_kept_images = min(max_num_images, current_images.size(0))
                padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
                padded_image_tensor[:num_kept_images] = current_images[:num_kept_images]
                all_num_images.append(num_kept_images)
                all_images.append(padded_image_tensor)

                padded_input_ids = torch.full((max_seq_len,), tokenizer.pad_token_id)
                current_max_len = min(max_seq_len, len(current_text))
                padded_input_ids[:current_max_len] = torch.tensor(current_text)[:current_max_len]
                all_num_text_tokens.append(current_max_len)
                all_texts.append(padded_input_ids)

                attention_mask = torch.zeros((max_seq_len,), dtype=torch.long)
                attention_mask[:current_max_len] = 1
                all_attention_masks.append(attention_mask)

                # Make sure to reset the current text and images.
                current_images = []
                current_text = []
                if len(current_lens) == 0:
                    break
        else:
            # Start a new sequence with the longest remaining sample (or stop when none is left).
            if len(current_lens) == 0:
                break
            image, sample_input_ids = mapping_by_len[current_lens[-1]].pop(0)
            if len(mapping_by_len[current_lens[-1]]) == 0:
                del mapping_by_len[current_lens[-1]]
            current_text = sample_input_ids[:max_seq_len]
            current_images = [_prepare_image(image)]

    if len(all_images) == 0 or len(all_texts) == 0:
        return _empty_output()

    all_texts = torch.stack(all_texts)
    all_images = torch.stack(all_images)
    all_attention_masks = torch.stack(all_attention_masks)

    image_attention_mask, next_image_attention_mask = image_attention_mask_for_packed_input_ids(all_texts, tokenizer)
    image_attention_mask = incremental_to_binary_attention_mask(image_attention_mask, num_classes=max_num_images)
    next_image_attention_mask = incremental_to_binary_attention_mask(
        next_image_attention_mask, num_classes=max_num_images
    )

    output = {
        "input_ids": all_texts,
        "attention_mask": all_attention_masks,
        "image_attention_mask": image_attention_mask,
        "num_images": torch.tensor(all_num_images),
        "num_text_tokens": torch.tensor(all_num_text_tokens),
    }
    if is_raw_images:
        output["pixel_values"] = all_images
    else:
        output["image_embeddings"] = all_images

    if is_image_at_end:
        # When the image comes after its caption, the caption tokens must attend to the *next* image.
        output["image_attention_mask"] = next_image_attention_mask

    return output


# Copied from https://github.com/google-research/text-to-text-transfer-transformer/blob/main/t5/data/preprocessors.py
def random_spans_helper(
    inputs_length,
    noise_density,
    mean_noise_span_length,
    extra_tokens_per_span_inputs,
    extra_tokens_per_span_targets,
    verbose=False,
):
    """Training parameters to avoid padding with random_spans_noise_mask.

    When training a model with random_spans_noise_mask, we would like to set the
    other training hyperparameters in a way that avoids padding. This function
    helps us compute these hyperparameters.

    We assume that each noise span in the input is replaced by
    extra_tokens_per_span_inputs sentinel tokens, and each non-noise span in the
    targets is replaced by extra_tokens_per_span_targets sentinel tokens.

    This function tells us the required number of tokens in the raw example (for
    split_tokens()) as well as the length of the encoded targets.

    Note that this function assumes the inputs and targets will have EOS appended
    and includes that in the reported length.

    Args:
      inputs_length: an integer - desired length of the tokenized inputs sequence
      noise_density: a float
      mean_noise_span_length: a float
      extra_tokens_per_span_inputs: an integer (only 1 is supported)
      extra_tokens_per_span_targets: an integer (only 1 is supported)
      verbose: a bool indicating whether to log sequence lengths
    Returns:
      tokens_length: length of original text in tokens
      targets_length: an integer - length in tokens of encoded targets sequence
    Raises:
      NotImplementedError: if either extra_tokens_per_span_* argument is not 1.
    """

    if extra_tokens_per_span_inputs != 1:
        raise NotImplementedError(
            "extra_tokens_per_span_inputs != 1 not supported yet. You need to check"
            " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
        )
    if extra_tokens_per_span_targets != 1:
        raise NotImplementedError(
            "extra_tokens_per_span_targets != 1 not supported yet. You need to check"
            " `get_model_tflops_per_batch_per_gpu` of `VT5ForConditionalGeneration` if you change it."
        )

    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        return (
            num_nonnoise_tokens + num_noise_spans * extra_tokens_per_span_inputs + 1,
            num_noise_tokens + num_noise_spans * extra_tokens_per_span_targets + 1,
        )

    # Grow the raw length as long as the encoded inputs still fit in `inputs_length`.
    tokens_length = inputs_length - 1
    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1
    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
    # minor hack to get the targets length to be equal to inputs length
    # which is more likely to have been set to a nice round number.
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    if verbose:
        # Use the module logger: root-level `logging.info` bypasses it and may implicitly call `basicConfig`.
        logger.info(
            "tokens_length=%s inputs_length=%s targets_length=%s noise_density=%s mean_noise_span_length=%s ",
            tokens_length,
            inputs_length,
            targets_length,
            noise_density,
            mean_noise_span_length,
        )
    return tokens_length, targets_length