import os.path as osp
import random

import cv2
import decord
import numpy as np
import skvideo.io
import torch
import torchvision
from decord import VideoReader, cpu, gpu
from tqdm import tqdm

random.seed(42)

decord.bridge.set_bridge("torch")


def get_spatial_fragments(
    video,
    fragments_h=7,
    fragments_w=7,
    fsize_h=32,
    fsize_w=32,
    aligned=32,
    nfrags=1,
    random=False,
    fallback_type="upsample",
):
    size_h = fragments_h * fsize_h
    size_w = fragments_w * fsize_w

    ## situation for images
    if video.shape[1] == 1:
        aligned = 1

    dur_t, res_h, res_w = video.shape[-3:]
    ratio = min(res_h / size_h, res_w / size_w)
    if fallback_type == "upsample" and ratio < 1:
        ## upsample so each frame is at least as large as the target fragment grid
        ovideo = video
        video = torch.nn.functional.interpolate(
            video / 255.0, scale_factor=1 / ratio, mode="bilinear"
        )
        video = (video * 255.0).type_as(ovideo)
        ## refresh the resolution so the sampling grids cover the upsampled frame
        dur_t, res_h, res_w = video.shape[-3:]

    assert dur_t % aligned == 0, "clip length (dur_t) must be divisible by `aligned`"
    size = size_h, size_w

    ## make sure that sampling will not run out of the picture
    hgrids = torch.LongTensor(
        [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)]
    )
    wgrids = torch.LongTensor(
        [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)]
    )
    hlength, wlength = res_h // fragments_h, res_w // fragments_w

    if random:
        print("This part is deprecated. Please remind that.")
        if res_h > fsize_h:
            rnd_h = torch.randint(
                res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned)
            )
        else:
            rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
        if res_w > fsize_w:
            rnd_w = torch.randint(
                res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned)
            )
        else:
            rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
    else:
        if hlength > fsize_h:
            rnd_h = torch.randint(
                hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned)
            )
        else:
            rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
        if wlength > fsize_w:
            rnd_w = torch.randint(
                wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned)
            )
        else:
            rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()

    target_video = torch.zeros(video.shape[:-2] + size).to(video.device)
    # target_videos = []

    for i, hs in enumerate(hgrids):
        for j, ws in enumerate(wgrids):
            for t in range(dur_t // aligned):
                t_s, t_e = t * aligned, (t + 1) * aligned
                h_s, h_e = i * fsize_h, (i + 1) * fsize_h
                w_s, w_e = j * fsize_w, (j + 1) * fsize_w
                if random:
                    h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h
                    w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w
                else:
                    h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h
                    w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w
                target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[
                    :, t_s:t_e, h_so:h_eo, w_so:w_eo
                ]
    # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo])
    # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6)
    # target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments
    return target_video
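

# Hedged usage sketch (illustration only, not part of the pipeline): shows how
# get_spatial_fragments splices a 7x7 grid of 32x32 patches from a dummy
# (C, T, H, W) uint8 video. The tensor contents and shape are made-up values.
def _demo_get_spatial_fragments():
    dummy_video = torch.randint(0, 256, (3, 32, 540, 960), dtype=torch.uint8)
    fragments = get_spatial_fragments(
        dummy_video, fragments_h=7, fragments_w=7, fsize_h=32, fsize_w=32, aligned=32
    )
    # Expected shape: (3, 32, 224, 224), since 7 * 32 = 224 per spatial side.
    print(fragments.shape)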


class FragmentSampleFrames:
    def __init__(self, fsize_t, fragments_t, frame_interval=1, num_clips=1):

        self.fragments_t = fragments_t
        self.fsize_t = fsize_t
        self.size_t = fragments_t * fsize_t
        self.frame_interval = frame_interval
        self.num_clips = num_clips

    def get_frame_indices(self, num_frames):

        tgrids = np.array(
            [num_frames // self.fragments_t * i for i in range(self.fragments_t)],
            dtype=np.int32,
        )
        tlength = num_frames // self.fragments_t

        if tlength > self.fsize_t * self.frame_interval:
            rnd_t = np.random.randint(
                0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids)
            )
        else:
            rnd_t = np.zeros(len(tgrids), dtype=np.int32)

        ranges_t = (
            np.arange(self.fsize_t)[None, :] * self.frame_interval
            + rnd_t[:, None]
            + tgrids[:, None]
        )
        return np.concatenate(ranges_t)

    def __call__(self, total_frames, train=False, start_index=0):
        frame_inds = []
        for i in range(self.num_clips):
            frame_inds += [self.get_frame_indices(total_frames)]
        frame_inds = np.concatenate(frame_inds)
        frame_inds = np.mod(frame_inds + start_index, total_frames)
        return frame_inds
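

# Hedged sketch of the temporal sampler (illustration only): with made-up
# numbers, FragmentSampleFrames splits a 240-frame video into 8 temporal grids
# and draws 4 interval-2 frames from each, i.e. 32 indices per clip.
def _demo_fragment_sample_frames():
    sampler = FragmentSampleFrames(fsize_t=4, fragments_t=8, frame_interval=2, num_clips=1)
    inds = sampler(240)
    # Expected: 8 * 4 = 32 indices, all within [0, 240).
    print(inds.shape, inds.min(), inds.max())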


class SampleFrames:
    def __init__(self, clip_len, frame_interval=1, num_clips=1):

        self.clip_len = clip_len
        self.frame_interval = frame_interval
        self.num_clips = num_clips

    def _get_train_clips(self, num_frames):
        """Get clip offsets in train mode.

        It calculates the average interval between selected clips and randomly
        shifts each clip within [0, avg_interval). If the total number of frames
        is smaller than the number of clips or the original clip length, all-zero
        offsets are returned.

        Args:
            num_frames (int): Total number of frames in the video.

        Returns:
            np.ndarray: Sampled frame indices in train mode.
        """
        ori_clip_len = self.clip_len * self.frame_interval
        avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips

        if avg_interval > 0:
            base_offsets = np.arange(self.num_clips) * avg_interval
            clip_offsets = base_offsets + np.random.randint(
                avg_interval, size=self.num_clips
            )
        elif num_frames > max(self.num_clips, ori_clip_len):
            clip_offsets = np.sort(
                np.random.randint(num_frames - ori_clip_len + 1, size=self.num_clips)
            )
        elif avg_interval == 0:
            ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips
            clip_offsets = np.around(np.arange(self.num_clips) * ratio)
        else:
            clip_offsets = np.zeros((self.num_clips,), dtype=np.int32)
        return clip_offsets

    def _get_test_clips(self, num_frames, start_index=0):
        """Get clip offsets in test mode.

        Calculate the average interval between selected clips and shift each
        clip by a fixed offset of avg_interval / 2.

        Args:
            num_frames (int): Total number of frames in the video.

        Returns:
            np.ndarray: Sampled frame indices in test mode.
        """
        ori_clip_len = self.clip_len * self.frame_interval
        avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips)
        if num_frames > ori_clip_len - 1:
            base_offsets = np.arange(self.num_clips) * avg_interval
            clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int32)
        else:
            clip_offsets = np.zeros((self.num_clips,), dtype=np.int32)
        return clip_offsets

    def __call__(self, total_frames, train=False, start_index=0):
        """Perform the SampleFrames loading.

        Args:
            results (dict): The resulting dict to be modified and passed
                to the next transform in pipeline.
        """
        if train:
            clip_offsets = self._get_train_clips(total_frames)
        else:
            clip_offsets = self._get_test_clips(total_frames)
        frame_inds = (
            clip_offsets[:, None]
            + np.arange(self.clip_len)[None, :] * self.frame_interval
        )
        frame_inds = np.concatenate(frame_inds)

        frame_inds = frame_inds.reshape((-1, self.clip_len))
        frame_inds = np.mod(frame_inds, total_frames)
        frame_inds = np.concatenate(frame_inds) + start_index
        return frame_inds.astype(np.int32)
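

# Hedged sketch of SampleFrames (illustration only): with made-up numbers,
# 4 clips of 32 frames (interval 2) are drawn from a 300-frame video; test mode
# centers each clip in its segment, train mode jitters the clip offsets randomly.
def _demo_sample_frames():
    sampler = SampleFrames(clip_len=32, frame_interval=2, num_clips=4)
    test_inds = sampler(300, train=False)
    train_inds = sampler(300, train=True)
    # Each call yields 4 * 32 = 128 frame indices in [0, 300).
    print(test_inds.shape, train_inds.shape)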


class FastVQAPlusPlusDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        ann_file,
        data_prefix,
        frame_interval=2,
        aligned=32,
        fragments=(8, 8, 8),
        fsize=(4, 32, 32),
        num_clips=1,
        nfrags=1,
        cache_in_memory=False,
        phase="test",
        fallback_type="oversample",
    ):
        """
        Fragments.
        args:
            fragments: G_f as in the paper.
            fsize: S_f as in the paper.
            nfrags: number of samples (spatially) as in the paper.
            num_clips: number of samples (temporally) as in the paper.
        """
        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.frame_interval = frame_interval
        self.num_clips = num_clips
        self.fragments = fragments
        self.fsize = fsize
        self.nfrags = nfrags
        self.clip_len = fragments[0] * fsize[0]
        self.aligned = aligned
        self.fallback_type = fallback_type
        self.sampler = FragmentSampleFrames(
            fsize[0], fragments[0], frame_interval, num_clips
        )
        self.video_infos = []
        self.phase = phase
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        if isinstance(self.ann_file, list):
            self.video_infos = self.ann_file
        else:
            with open(self.ann_file, "r") as fin:
                for line in fin:
                    line_split = line.strip().split(",")
                    filename, _, _, label = line_split
                    label = float(label)
                    filename = osp.join(self.data_prefix, filename)
                    self.video_infos.append(dict(filename=filename, label=label))
        if cache_in_memory:
            self.cache = {}
            for i in tqdm(range(len(self)), desc="Caching fragments"):
                self.cache[i] = self.__getitem__(i, tocache=True)
        else:
            self.cache = None

    def __getitem__(
        self, index, tocache=False, need_original_frames=False,
    ):
        if tocache or self.cache is None:
            fx, fy = self.fragments[1:]
            fsx, fsy = self.fsize[1:]
            video_info = self.video_infos[index]
            filename = video_info["filename"]
            label = video_info["label"]
            if filename.endswith(".yuv"):
                video = skvideo.io.vread(
                    filename, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
                )
                frame_inds = self.sampler(video.shape[0], self.phase == "train")
                imgs = [torch.from_numpy(video[idx]) for idx in frame_inds]
            else:
                vreader = VideoReader(filename)
                frame_inds = self.sampler(len(vreader), self.phase == "train")
                frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)}
                imgs = [frame_dict[idx] for idx in frame_inds]
            img_shape = imgs[0].shape
            video = torch.stack(imgs, 0)
            video = video.permute(3, 0, 1, 2)
            if self.nfrags == 1:
                vfrag = get_spatial_fragments(
                    video,
                    fx,
                    fy,
                    fsx,
                    fsy,
                    aligned=self.aligned,
                    fallback_type=self.fallback_type,
                )
            else:
                vfrag = get_spatial_fragments(
                    video,
                    fx,
                    fy,
                    fsx,
                    fsy,
                    aligned=self.aligned,
                    fallback_type=self.fallback_type,
                )
                for i in range(1, self.nfrags):
                    vfrag = torch.cat(
                        (
                            vfrag,
                            get_spatial_fragments(
                                video,
                                fx,
                                fy,
                                fsx,
                                fsy,
                                aligned=self.aligned,
                                fallback_type=self.fallback_type,
                            ),
                        ),
                        1,
                    )
            if tocache:
                return (vfrag, frame_inds, label, img_shape)
        else:
            vfrag, frame_inds, label, img_shape = self.cache[index]
        vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2)
        data = {
            "video": vfrag.reshape(
                (-1, self.nfrags * self.num_clips, self.clip_len) + vfrag.shape[2:]
            ).transpose(
                0, 1
            ),  # (V, C, T, H, W); the batch dim is added by the DataLoader
            "frame_inds": frame_inds,
            "gt_label": label,
            "original_shape": img_shape,
        }
        if need_original_frames:
            data["original_video"] = video.reshape(
                (-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:]
            ).transpose(0, 1)
        return data

    def __len__(self):
        return len(self.video_infos)
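

# Hedged construction sketch (illustration only): the annotation path, data
# prefix, and CSV layout ("filename,_,_,label" per line) below are hypothetical
# placeholders; only the keyword arguments mirror the class defaults above.
def _demo_fastvqa_plus_plus_dataset():
    dataset = FastVQAPlusPlusDataset(
        ann_file="data/example_labels.txt",   # hypothetical annotation file
        data_prefix="data/videos",            # hypothetical video directory
        fragments=(8, 8, 8),
        fsize=(4, 32, 32),
        num_clips=1,
        phase="test",
    )
    sample = dataset[0]
    # sample["video"] has shape (V, C, T, H, W) with V = nfrags * num_clips.
    print(sample["video"].shape, sample["gt_label"])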


class FragmentVideoDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        ann_file,
        data_prefix,
        clip_len=32,
        frame_interval=2,
        num_clips=4,
        aligned=32,
        fragments=7,
        fsize=32,
        nfrags=1,
        cache_in_memory=False,
        phase="test",
    ):
        """
        Fragments.
        args:
            fragments: G_f as in the paper.
            fsize: S_f as in the paper.
            nfrags: number of samples as in the paper.
        """
        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.clip_len = clip_len
        self.frame_interval = frame_interval
        self.num_clips = num_clips
        self.fragments = fragments
        self.fsize = fsize
        self.nfrags = nfrags
        self.aligned = aligned
        self.sampler = SampleFrames(clip_len, frame_interval, num_clips)
        self.video_infos = []
        self.phase = phase
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        if isinstance(self.ann_file, list):
            self.video_infos = self.ann_file
        else:
            with open(self.ann_file, "r") as fin:
                for line in fin:
                    line_split = line.strip().split(",")
                    filename, _, _, label = line_split
                    label = float(label)
                    filename = osp.join(self.data_prefix, filename)
                    self.video_infos.append(dict(filename=filename, label=label))
        if cache_in_memory:
            self.cache = {}
            for i in tqdm(range(len(self)), desc="Caching fragments"):
                self.cache[i] = self.__getitem__(i, tocache=True)
        else:
            self.cache = None

    def __getitem__(
        self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False,
    ):
        if tocache or self.cache is None:
            if fragments == -1:
                fragments = self.fragments
            if fsize == -1:
                fsize = self.fsize
            video_info = self.video_infos[index]
            filename = video_info["filename"]
            label = video_info["label"]
            if filename.endswith(".yuv"):
                video = skvideo.io.vread(
                    filename, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
                )
                frame_inds = self.sampler(video.shape[0], self.phase == "train")
                imgs = [torch.from_numpy(video[idx]) for idx in frame_inds]
            else:
                vreader = VideoReader(filename)
                frame_inds = self.sampler(len(vreader), self.phase == "train")
                frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)}
                imgs = [frame_dict[idx] for idx in frame_inds]
            img_shape = imgs[0].shape
            video = torch.stack(imgs, 0)
            video = video.permute(3, 0, 1, 2)
            if self.nfrags == 1:
                vfrag = get_spatial_fragments(
                    video, fragments, fragments, fsize, fsize, aligned=self.aligned
                )
            else:
                vfrag = get_spatial_fragments(
                    video, fragments, fragments, fsize, fsize, aligned=self.aligned
                )
                for i in range(1, self.nfrags):
                    vfrag = torch.cat(
                        (
                            vfrag,
                            get_spatial_fragments(
                                video,
                                fragments,
                                fragments,
                                fsize,
                                fsize,
                                aligned=self.aligned,
                            ),
                        ),
                        1,
                    )
            if tocache:
                return (vfrag, frame_inds, label, img_shape)
        else:
            vfrag, frame_inds, label, img_shape = self.cache[index]
        vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2)
        data = {
            "video": vfrag.reshape(
                (-1, self.nfrags * self.num_clips, self.clip_len) + vfrag.shape[2:]
            ).transpose(
                0, 1
            ),  # (V, C, T, H, W); the batch dim is added by the DataLoader
            "frame_inds": frame_inds,
            "gt_label": label,
            "original_shape": img_shape,
        }
        if need_original_frames:
            data["original_video"] = video.reshape(
                (-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:]
            ).transpose(0, 1)
        return data

    def __len__(self):
        return len(self.video_infos)
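

# Hedged DataLoader sketch (illustration only): the annotation file and data
# prefix are hypothetical placeholders; it shows how the per-sample dict
# produced above batches into tensors with a leading batch dimension.
def _demo_fragment_video_dataloader():
    dataset = FragmentVideoDataset(
        ann_file="data/example_labels.txt",  # hypothetical annotation file
        data_prefix="data/videos",           # hypothetical video directory
        fragments=7,
        fsize=32,
        num_clips=4,
        phase="test",
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=0)
    for batch in loader:
        # batch["video"]: (B, V, C, T, H, W); batch["gt_label"]: (B,)
        print(batch["video"].shape, batch["gt_label"].shape)
        break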


class ResizedVideoDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        ann_file,
        data_prefix,
        clip_len=32,
        frame_interval=2,
        num_clips=4,
        aligned=32,
        size=224,
        cache_in_memory=False,
        phase="test",
    ):
        """
        Using resizing.
        """
        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.clip_len = clip_len
        self.frame_interval = frame_interval
        self.num_clips = num_clips
        self.size = size
        self.aligned = aligned
        self.sampler = SampleFrames(clip_len, frame_interval, num_clips)
        self.video_infos = []
        self.phase = phase
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        if isinstance(self.ann_file, list):
            self.video_infos = self.ann_file
        else:
            with open(self.ann_file, "r") as fin:
                for line in fin:
                    line_split = line.strip().split(",")
                    filename, _, _, label = line_split
                    label = float(label)
                    filename = osp.join(self.data_prefix, filename)
                    self.video_infos.append(dict(filename=filename, label=label))
        if cache_in_memory:
            self.cache = {}
            for i in tqdm(range(len(self)), desc="Caching resized videos"):
                self.cache[i] = self.__getitem__(i, tocache=True)
        else:
            self.cache = None

    def __getitem__(self, index, tocache=False, need_original_frames=False):
        if tocache or self.cache is None:
            video_info = self.video_infos[index]
            filename = video_info["filename"]
            label = video_info["label"]
            vreader = VideoReader(filename)
            frame_inds = self.sampler(len(vreader), self.phase == "train")
            frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)}
            imgs = [frame_dict[idx] for idx in frame_inds]
            img_shape = imgs[0].shape
            video = torch.stack(imgs, 0)
            video = video.permute(3, 0, 1, 2)
            video = torch.nn.functional.interpolate(video, size=(self.size, self.size))
            vfrag = video
            if tocache:
                return (vfrag, frame_inds, label, img_shape)
        else:
            vfrag, frame_inds, label, img_shape = self.cache[index]
        vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2)
        data = {
            "video": vfrag.reshape(
                (-1, self.num_clips, self.clip_len) + vfrag.shape[2:]
            ).transpose(
                0, 1
            ),  # (V, C, T, H, W); the batch dim is added by the DataLoader
            "frame_inds": frame_inds,
            "gt_label": label,
            "original_shape": img_shape,
        }
        if need_original_frames:
            data["original_video"] = video.reshape(
                (-1, self.num_clips, self.clip_len) + video.shape[2:]
            ).transpose(0, 1)
        return data

    def __len__(self):
        return len(self.video_infos)


class CroppedVideoDataset(FragmentVideoDataset):
    def __init__(
        self,
        ann_file,
        data_prefix,
        clip_len=32,
        frame_interval=2,
        num_clips=4,
        aligned=32,
        size=224,
        ncrops=1,
        cache_in_memory=False,
        phase="test",
    ):

        """
        Regard Cropping as a special case for Fragments in Grid 1*1.
        """
        super().__init__(
            ann_file,
            data_prefix,
            clip_len=clip_len,
            frame_interval=frame_interval,
            num_clips=num_clips,
            aligned=aligned,
            fragments=1,
            fsize=size,
            nfrags=ncrops,
            cache_in_memory=cache_in_memory,
            phase=phase,
        )


class FragmentImageDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        ann_file,
        data_prefix,
        fragments=7,
        fsize=32,
        nfrags=1,
        cache_in_memory=False,
        phase="test",
    ):
        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.fragments = fragments
        self.fsize = fsize
        self.nfrags = nfrags
        self.image_infos = []
        self.phase = phase
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        if isinstance(self.ann_file, list):
            self.image_infos = self.ann_file
        else:
            with open(self.ann_file, "r") as fin:
                for line in fin:
                    line_split = line.strip().split(",")
                    filename, _, _, label = line_split
                    label = float(label)
                    filename = osp.join(self.data_prefix, filename)
                    self.image_infos.append(dict(filename=filename, label=label))
        if cache_in_memory:
            self.cache = {}
            for i in tqdm(range(len(self)), desc="Caching fragments"):
                self.cache[i] = self.__getitem__(i, tocache=True)
        else:
            self.cache = None

    def __getitem__(
        self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False
    ):
        if tocache or self.cache is None:
            if fragments == -1:
                fragments = self.fragments
            if fsize == -1:
                fsize = self.fsize
            image_info = self.image_infos[index]
            filename = image_info["filename"]
            label = image_info["label"]
            try:
                img = torchvision.io.read_image(filename)
            except Exception:
                ## fall back to OpenCV; convert BGR to RGB and HWC to CHW
                img = cv2.imread(filename)
                img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1)
            img_shape = img.shape[1:]
            image = img.unsqueeze(1)
            if self.nfrags == 1:
                ifrag = get_spatial_fragments(image, fragments, fragments, fsize, fsize)
            else:
                ifrag = get_spatial_fragments(image, fragments, fragments, fsize, fsize)
                for i in range(1, self.nfrags):
                    ifrag = torch.cat(
                        (
                            ifrag,
                            get_spatial_fragments(
                                image, fragments, fragments, fsize, fsize
                            ),
                        ),
                        1,
                    )
            if tocache:
                return (ifrag, label, img_shape)
        else:
            ifrag, label, img_shape = self.cache[index]
            filename = self.image_infos[index]["filename"]
        if self.nfrags == 1:
            ifrag = (
                ((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std)
                .squeeze(0)
                .permute(2, 0, 1)
            )
        else:
            ### During testing, one image as a batch
            ifrag = (
                ((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std)
                .squeeze(0)
                .permute(0, 3, 1, 2)
            )
        data = {
            "image": ifrag,
            "gt_label": label,
            "original_shape": img_shape,
            "name": filename,
        }
        if need_original_frames:
            data["original_image"] = image.squeeze(1)
        return data

    def __len__(self):
        return len(self.image_infos)
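

# Hedged usage sketch (illustration only): the annotation file and prefix are
# hypothetical; a single image is fragmented into a 7x7 grid of 32x32 patches
# and returned as a normalized (C, 224, 224) tensor under the "image" key.
def _demo_fragment_image_dataset():
    dataset = FragmentImageDataset(
        ann_file="data/example_image_labels.txt",  # hypothetical annotation file
        data_prefix="data/images",                 # hypothetical image directory
        fragments=7,
        fsize=32,
    )
    sample = dataset[0]
    print(sample["image"].shape, sample["gt_label"])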


class ResizedImageDataset(torch.utils.data.Dataset):
    def __init__(
        self, ann_file, data_prefix, size=224, cache_in_memory=False, phase="test",
    ):
        self.ann_file = ann_file
        self.data_prefix = data_prefix
        self.size = size
        self.image_infos = []
        self.phase = phase
        self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
        self.std = torch.FloatTensor([58.395, 57.12, 57.375])
        if isinstance(self.ann_file, list):
            self.image_infos = self.ann_file
        else:
            with open(self.ann_file, "r") as fin:
                for line in fin:
                    line_split = line.strip().split(",")
                    filename, _, _, label = line_split
                    label = float(label)
                    filename = osp.join(self.data_prefix, filename)
                    self.image_infos.append(dict(filename=filename, label=label))
        if cache_in_memory:
            self.cache = {}
            for i in tqdm(range(len(self)), desc="Caching fragments"):
                self.cache[i] = self.__getitem__(i, tocache=True)
        else:
            self.cache = None

    def __getitem__(self, index, tocache=False, need_original_frames=False):
        if tocache or self.cache is None:
            image_info = self.image_infos[index]
            filename = image_info["filename"]
            label = image_info["label"]
            img = torchvision.io.read_image(filename)
            img_shape = img.shape[1:]
            image = img.unsqueeze(1)
            ## resize to the target square resolution, keeping the original dtype
            ifrag = torch.nn.functional.interpolate(
                image / 255.0, size=(self.size, self.size), mode="bilinear"
            )
            ifrag = (ifrag * 255.0).type_as(img)
            if tocache:
                return (ifrag, label, img_shape)
        else:
            ifrag, label, img_shape = self.cache[index]
        ifrag = (
            ((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std)
            .squeeze(0)
            .permute(2, 0, 1)
        )
        data = {
            "image": ifrag,
            "gt_label": label,
            "original_shape": img_shape,
        }
        if need_original_frames:
            data["original_image"] = image.squeeze(1)
        return data

    def __len__(self):
        return len(self.image_infos)


class CroppedImageDataset(FragmentImageDataset):
    def __init__(
        self,
        ann_file,
        data_prefix,
        size=224,
        ncrops=1,
        cache_in_memory=False,
        phase="test",
    ):

        """
        Regard Cropping as a special case for Fragments in Grid 1*1.
        """
        super().__init__(
            ann_file,
            data_prefix,
            fragments=1,
            fsize=size,
            nfrags=ncrops,
            cache_in_memory=cache_in_memory,
            phase=phase,
        )
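

# Hedged sketch (illustration only): cropping reuses the fragment machinery with
# a 1x1 grid, so a CroppedImageDataset with size=224 yields a single 224x224
# crop per image. The paths below are hypothetical placeholders.
def _demo_cropped_image_dataset():
    dataset = CroppedImageDataset(
        ann_file="data/example_image_labels.txt",  # hypothetical annotation file
        data_prefix="data/images",                 # hypothetical image directory
        size=224,
        ncrops=1,
    )
    sample = dataset[0]
    # Equivalent to FragmentImageDataset(fragments=1, fsize=224, nfrags=1).
    print(sample["image"].shape)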