from typing import Optional, Union, Dict, Any

import math

import numpy as np
import PIL.Image
import torch
from PIL import Image

from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers import AutoImageProcessor
from transformers.image_transforms import to_channel_dimension_format
from transformers.image_utils import (
    ImageInput,
    make_list_of_images,
    valid_images,
    is_torch_tensor,
    to_numpy_array,
    infer_channel_dimension_format,
    ChannelDimension,
)


def recursive_converter(converter, value):
    """Apply `converter` to `value`, recursing into arbitrarily nested lists."""
    if isinstance(value, list):
        return [recursive_converter(converter, v) for v in value]
    return converter(value)


class MiniCPMVBatchFeature(BatchFeature):
    r"""
    Extend from BatchFeature for supporting various image size
    """
    def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
        super().__init__(data)
        self.convert_to_tensors(tensor_type=tensor_type)

    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
        if tensor_type is None:
            return self

        is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)

        def converter(value):
            # `key` is read from the enclosing loop so the error message can name the field.
            try:
                if not is_tensor(value):
                    return as_tensor(value)
                # Already a tensor: return it unchanged (returning None here would
                # silently drop the field).
                return value
            except Exception:
                if key == "overflowing_values":
                    raise ValueError("Unable to create tensor returning overflowing values of different lengths.")
                raise ValueError(
                    "Unable to create tensor, you should probably activate padding "
                    "with 'padding=True' to have batched tensors with the same length."
                )

        for key, value in self.items():
            self[key] = recursive_converter(converter, value)
        return self
            
    def to(self, *args, **kwargs) -> "MiniCPMVBatchFeature":
        requires_backends(self, ["torch"])
        import torch

        # Work out whether the first positional argument is a device or a dtype.
        device = kwargs.get("device")
        if device is None and len(args) > 0:
            arg = args[0]
            if is_torch_dtype(arg):
                # The first argument is a dtype; leave `device` unset.
                pass
            elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
                device = arg
            else:
                raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")

        def cast_tensor(v):
            # Cast only floating point tensors to avoid issues with tokenizers
            # casting `LongTensor` to `FloatTensor`.
            if torch.is_floating_point(v):
                # cast and send to device
                return v.to(*args, **kwargs)
            elif device is not None:
                return v.to(device=device)
            else:
                return v

        new_data = {}
        for k, v in self.items():
            new_data[k] = recursive_converter(cast_tensor, v)
        self.data = new_data
        return self
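
    # Illustrative usage (a sketch, not part of the original module): given
    # `batch = processor.preprocess(image, return_tensors="pt")`, calling
    # `batch.to("cuda", dtype=torch.float16)` casts the floating-point tensors
    # and moves everything else to the device unchanged.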


class MiniCPMVImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
            self,
            max_slice_nums=9,
            scale_resolution=448,
            patch_size=14,
            **kwargs):
        super().__init__(**kwargs)
        # Slicing configuration: an image is cut into at most `max_slice_nums`
        # sub-images, each resized toward `scale_resolution` x `scale_resolution`
        # with both sides snapped to multiples of `patch_size`.
        self.max_slice_nums = max_slice_nums
        self.scale_resolution = scale_resolution
        self.patch_size = patch_size
        self.image_feature_size = kwargs.pop("image_feature_size", 64)
        # Special tokens delimiting image and slice placeholders in the prompt.
        self.im_start_token = kwargs.pop("im_start", "<image>")
        self.im_end_token = kwargs.pop("im_end", "</image>")
        self.slice_start_token = kwargs.pop("slice_start", "<slice>")
        self.slice_end_token = kwargs.pop("slice_end", "</slice>")
        self.unk_token = kwargs.pop("unk", "<unk>")
        # Normalization statistics (defaults map [0, 1] pixels to [-1, 1]).
        self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5]))
        self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5]))
        self.version = kwargs.pop("version", 2.0)

    def ensure_divide(self, length, patch_size):
        # Round `length` to the nearest multiple of `patch_size`, with a floor of one patch.
        return max(round(length / patch_size) * patch_size, patch_size)

    def find_best_resize(self,
                         original_size,
                         scale_resolution,
                         patch_size,
                         allow_upscale=False):
        # Rescale to roughly `scale_resolution`**2 pixels while preserving the aspect
        # ratio, then snap both sides to multiples of `patch_size`.
        width, height = original_size
        if (width * height >
                scale_resolution * scale_resolution) or allow_upscale:
            r = width / height
            height = int(scale_resolution / math.sqrt(r))
            width = int(height * r)
        best_width = self.ensure_divide(width, patch_size)
        best_height = self.ensure_divide(height, patch_size)
        return (best_width, best_height)

    def get_refine_size(self,
                        original_size,
                        grid,
                        scale_resolution,
                        patch_size,
                        allow_upscale=False):
        width, height = original_size
        grid_x, grid_y = grid

        refine_width = self.ensure_divide(width, grid_x)
        refine_height = self.ensure_divide(height, grid_y)

        grid_width = refine_width / grid_x
        grid_height = refine_height / grid_y

        best_grid_size = self.find_best_resize((grid_width, grid_height),
                                               scale_resolution,
                                               patch_size,
                                               allow_upscale=allow_upscale)
        refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
        return refine_size

    def split_to_patches(self, image, grid):
        patches = []
        width, height = image.size
        grid_x = int(width / grid[0])
        grid_y = int(height / grid[1])
        for i in range(0, height, grid_y):
            images = []
            for j in range(0, width, grid_x):
                box = (j, i, j + grid_x, i + grid_y)
                patch = image.crop(box)
                images.append(patch)
            patches.append(images)
        return patches

    def slice_image(
        self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
    ):
        original_size = image.size
        original_width, original_height = original_size
        log_ratio = math.log(original_width / original_height)
        # Approximate number of `scale_resolution`**2 tiles the image covers,
        # capped at `max_slice_nums`.
        ratio = original_width * original_height / (scale_resolution * scale_resolution)
        multiple = min(math.ceil(ratio), max_slice_nums)

        source_image = None
        best_grid = None
        patches = []

        if multiple <= 1 or never_split:
            # no need to slice; just resize (upsampling if necessary)
            best_size = self.find_best_resize(
                original_size, scale_resolution, patch_size, allow_upscale=True
            )
            source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC)
        else:
            candidate_split_grids_nums = []
            for i in [multiple - 1, multiple, multiple + 1]:
                if i == 1 or i > max_slice_nums:
                    continue
                candidate_split_grids_nums.append(i)

            # source image: down-sample and make both sides divisible by patch_size
            best_resize = self.find_best_resize(original_size, scale_resolution, patch_size)
            source_image = image.copy().resize(best_resize, resample=Image.Resampling.BICUBIC)
            candidate_grids = []

            # find best grid
            for split_grids_nums in candidate_split_grids_nums:
                m = 1
                while m <= split_grids_nums:
                    if split_grids_nums % m == 0:
                        candidate_grids.append([m, split_grids_nums // m])
                    m += 1

            best_grid = [1, 1]
            min_error = float("inf")
            for grid in candidate_grids:
                error = abs(log_ratio - math.log(grid[0] / grid[1]))
                if error < min_error:
                    best_grid = grid
                    min_error = error

            refine_size = self.get_refine_size(
                original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
            )

            refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC)
            patches = self.split_to_patches(refine_image, best_grid)

        return source_image, patches, best_grid
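
    # Worked example (illustrative, using the defaults above): a 1344x896 image
    # has area ratio 1344 * 896 / 448**2 = 6.0, so multiple = 6 and the candidate
    # grid counts are {5, 6, 7}. Because log(1344 / 896) = log(3 / 2) exactly,
    # the best grid is [3, 2]: the refined 1344x896 image is cut into six
    # 448x448 patches, while the source image is down-sampled to 546x364
    # (both sides snapped to multiples of patch_size = 14).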

    def get_grid_placeholder(self, grid):
        if grid is None:
            return ""
        image_placeholder = (
            self.im_start_token 
            + self.unk_token * self.image_feature_size
            + self.im_end_token
        )

        cols, rows = grid
        slices = []
        for _ in range(rows):
            slices.append(image_placeholder * cols)

        slice_placeholder = self.slice_start_token + "\n".join(slices) + self.slice_end_token
        return slice_placeholder
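
    # Schematic output (an illustration, not emitted verbatim): for grid = [3, 2]
    # each row is three copies of "<image>" + "<unk>" * 64 + "</image>", the two
    # rows are joined with "\n", and the whole block is wrapped in
    # "<slice>...</slice>".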

    def get_sliced_images(self, image):
        slice_images = []

        source_image, patches, _best_grid = self.slice_image(
            image,
            self.max_slice_nums,  # default: 9
            self.scale_resolution,  # default: 448
            self.patch_size  # default: 14
        )
        slice_images.append(source_image)

        # Flatten the patch grid row by row after the down-sampled source image.
        for row in patches:
            slice_images.extend(row)
        return slice_images

    def get_sliced_grid(self, image_size):
        # Mirrors the grid-selection logic in `slice_image`, returning only the
        # best [cols, rows] grid, or None if the image would not be sliced.
        original_width, original_height = image_size
        log_ratio = math.log(original_width / original_height)
        ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution)
        multiple = min(math.ceil(ratio), self.max_slice_nums)
        if multiple <= 1:
            return None
        candidate_split_grids_nums = []
        for i in [multiple - 1, multiple, multiple + 1]:
            if i == 1 or i > self.max_slice_nums:
                continue
            candidate_split_grids_nums.append(i)
        
        candidate_grids = []
        for split_grids_nums in candidate_split_grids_nums:
            m = 1
            while m <= split_grids_nums:
                if split_grids_nums % m == 0:
                    candidate_grids.append([m, split_grids_nums // m])
                m += 1

        best_grid = [1, 1]
        min_error = float("inf")
        for grid in candidate_grids:
            error = abs(log_ratio - math.log(grid[0] / grid[1]))
            if error < min_error:
                best_grid = grid
                min_error = error
        
        return best_grid

    def get_slice_image_placeholder(self, image_size):
        grid = self.get_sliced_grid(image_size=image_size)
        return (
            self.im_start_token 
            + self.unk_token * self.image_feature_size 
            + self.im_end_token
        ) + self.get_grid_placeholder(grid=grid)

    def to_pil_image(self, image, rescale=None) -> PIL.Image.Image:
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
        needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.
        """
        if isinstance(image, PIL.Image.Image):
            return image
        if is_torch_tensor(image):
            image = image.numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # rescale default to the array being of floating type.
                rescale = isinstance(image.flat[0], np.floating)
            # If the channel dimension has been moved to the front, put it back at the end.
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)
            if rescale:
                image = image * 255
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        return image

    def reshape_by_patch(self, image):
        """
        Flatten an image into a sequence of `patch_size` x `patch_size` patches.

        :param image: np.ndarray of shape [3, H, W], with H and W multiples of `self.patch_size`
        :return: np.ndarray of shape [3, patch_size, H * W / patch_size]
        """
        image = torch.from_numpy(image)
        patch_size = self.patch_size
        patches = torch.nn.functional.unfold(
            image,
            (patch_size, patch_size),
            stride=(patch_size, patch_size)
        )

        patches = patches.reshape(image.size(0), patch_size, patch_size, -1)
        patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1)
        return patches.numpy()
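
    # Shape trace (illustrative): for a 546x364 slice with patch_size = 14 the
    # input is [3, 364, 546]; unfold yields [3 * 14 * 14, 26 * 39] = [588, 1014],
    # and the reshape/permute produces [3, 14, 1014 * 14] = [3, 14, 14196], i.e.
    # all 14x14 patches laid out side by side along the last axis.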

    def preprocess(
            self, 
            images: ImageInput,
            do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5
            return_tensors: Optional[Union[str, TensorType]] = None
        ) -> MiniCPMVBatchFeature:
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        
        images = [self.to_pil_image(image).convert("RGB") for image in images]
        input_data_format = infer_channel_dimension_format(np.array(images[0]))

        new_images = []
        image_sizes = [image.size for image in images]
        tgt_sizes = []
        for image in images:
            # Slice each image into a down-sampled source image plus optional patches.
            image_patches = self.get_sliced_images(image)
            # Scale pixels to [0, 1], normalize, and move channels to the front.
            image_patches = [to_numpy_array(patch).astype(np.float32) / 255 for patch in image_patches]
            image_patches = [
                self.normalize(image=patch, mean=self.mean, std=self.std, input_data_format=input_data_format)
                for patch in image_patches
            ]
            image_patches = [
                to_channel_dimension_format(patch, ChannelDimension.FIRST, input_channel_dim=input_data_format)
                for patch in image_patches
            ]
            for slice_image in image_patches:
                # Each slice is flattened into patch columns; tgt_sizes records its
                # (height, width) extent measured in patches.
                new_images.append(self.reshape_by_patch(slice_image))
                tgt_sizes.append(
                    np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
                )

        if tgt_sizes:
            tgt_sizes = np.vstack(tgt_sizes)
        return MiniCPMVBatchFeature(
            data={"pixel_values": [new_images], "image_sizes": [image_sizes], "tgt_sizes": [tgt_sizes]}, tensor_type=return_tensors
        )

# Register the class so that `AutoImageProcessor` can resolve it when the model
# is loaded with `trust_remote_code=True`.
AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor)
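

# Minimal usage sketch (an illustrative example, not part of the original
# module; the 1344x896 size is arbitrary): run a single synthetic image
# through `preprocess` and inspect the outputs.
if __name__ == "__main__":
    demo_image = Image.new("RGB", (1344, 896), color=(127, 127, 127))
    processor = MiniCPMVImageProcessor()
    batch = processor.preprocess(demo_image, return_tensors=None)
    # One entry per input image; each holds the source image plus its patches.
    print(len(batch["pixel_values"][0]))  # 7 slices: 1 source + 3x2 patches
    print(batch["tgt_sizes"][0])          # per-slice (H / 14, W / 14) patch grids
    print(batch["image_sizes"][0])        # original PIL sizes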