# Copyright 2024 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-11-28
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/RollingDepth#-citation
# More information about the method can be found at https://rollingdepth.github.io
# ---------------------------------------------------------------------------------

import argparse
import logging
import os
from pathlib import Path

import numpy as np
import torch
from tqdm.auto import tqdm
import einops
from omegaconf import OmegaConf

from rollingdepth import (
    RollingDepthOutput,
    RollingDepthPipeline,
    write_video_from_numpy,
    get_video_fps,
    concatenate_videos_horizontally_torch,
)
from src.util.colorize import colorize_depth_multi_thread
from src.util.config import str2bool

if __name__ == "__main__":
    # Command-line driver: parse arguments, resolve an inference preset, load the
    # RollingDepth pipeline, then predict depth for every input video and save the
    # requested outputs (.npy depth, .npz snippets, colorized / side-by-side videos).
    logging.basicConfig(level=logging.INFO)

    # -------------------- Arguments --------------------
    parser = argparse.ArgumentParser(
        description="Run video depth estimation using RollingDepth."
    )
    parser.add_argument(
        "-i",
        "--input-video",
        type=str,
        required=True,
        help=(
            "Path to the input video(s) to be processed. Accepts: "
            "- Single video file path (e.g., 'video.mp4') "
            "- Text file containing a list of video paths (one per line) "
            "- Directory path containing video files "
            "Required argument."
        ),
        dest="input_video",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        type=str,
        required=True,
        help=(
            "Directory path where processed outputs will be saved. "
            "Will be created if it doesn't exist. "
            "Required argument."
        ),
        dest="output_dir",
    )
    parser.add_argument(
        "-p",
        "--preset",
        type=str,
        choices=["fast", "fast1024", "full", "paper", "none"],
        # Without a default, omitting -p used to crash with KeyError(None).
        default="fast",
        help=(
            "Inference preset. Explicitly passed arguments override preset values. "
            "'fast': dilations [1, 25], no refinement, resolution 768. "
            "'fast1024': same as 'fast' at resolution 1024. "
            "'full': dilations [1, 10, 25], 10 refinement steps, resolution 1024. "
            "'paper': dilations [1, 10, 25], 10 refinement steps, fp32, "
            "no dilation capping, resolution 768. "
            "'none': base defaults only; --dilations and --refine-step must be given. "
            "Default: 'fast'"
        ),
    )
    parser.add_argument(
        "--start-frame",
        "--from",
        type=int,
        default=0,
        help=(
            "Specifies the starting frame index for processing. "
            "Use 0 to start from the beginning of the video. "
            "Default: 0"
        ),
        dest="start_frame",
    )
    parser.add_argument(
        "--frame-count",
        "--frames",
        type=int,
        default=0,
        help=(
            "Number of frames to process after the starting frame. "
            "Set to 0 to process until the end of the video. "
            "Default: 0 (process all frames)"
        ),
        dest="frame_count",
    )

    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        default="prs-eth/rollingdepth-v1-0",
        help=(
            "Path to the model checkpoint to use for inference. Can be either: "
            "- A local path to checkpoint files "
            "- A Hugging Face model hub identifier (e.g., 'prs-eth/rollingdepth-v1-0') "
            "Default: 'prs-eth/rollingdepth-v1-0'"
        ),
        dest="checkpoint",
    )
    parser.add_argument(
        "--res",
        "--processing-resolution",
        type=int,
        default=None,
        help=(
            "Specifies the maximum resolution (in pixels) at which image processing will be performed. "
            "If set to None, uses the preset configuration value. "
            "If set to 0, processes at the original input image resolution. "
            "Default: None"
        ),
        dest="res",
    )
    parser.add_argument(
        "--max-vae-bs",
        type=int,
        default=4,
        help=(
            "Maximum batch size for the Variational Autoencoder (VAE) processing. "
            "Higher values increase memory usage but may improve processing speed. "
            "Reduce this value if encountering out-of-memory errors. "
            "Default: 4"
        ),
    )

    # Output settings
    parser.add_argument(
        "--fps",
        "--output-fps",
        type=int,
        default=0,
        help=(
            "Frame rate (FPS) for the output video. "
            "Set to 0 to match the input video's frame rate. "
            "Default: 0"
        ),
        dest="output_fps",
    )
    parser.add_argument(
        "--restore-resolution",
        "--restore-res",
        type=str2bool,
        nargs="?",
        default=False,
        help=(
            "Whether to restore the output to the original input resolution after processing. "
            "Only applies when input has been resized during processing. "
            "Default: False"
        ),
        dest="restore_res",
    )
    parser.add_argument(
        # Two separate option strings. A missing comma here used to
        # concatenate them into one bogus flag.
        "--save-sbs",
        "--save-side-by-side",
        type=str2bool,
        nargs="?",
        default=True,
        help=(
            "Whether to save RGB and colored depth videos side-by-side. "
            "If True, the first non-empty color map will be used. "
            "Default: True"
        ),
        dest="save_sbs",
    )
    parser.add_argument(
        "--save-npy",
        type=str2bool,
        nargs="?",
        default=True,
        help=(
            "Whether to save depth maps as NumPy (.npy) files. "
            "Enables further processing and analysis of raw depth data. "
            "Default: True"
        ),
    )
    parser.add_argument(
        "--save-snippets",
        type=str2bool,
        nargs="?",
        default=False,
        help=(
            "Whether to save visualization snippets of the depth estimation process. "
            "Useful for debugging and quality assessment. "
            "Default: False"
        ),
    )
    parser.add_argument(
        "--cmap",
        "--color-maps",
        type=str,
        nargs="+",
        default=["Spectral_r", "Greys_r"],
        help=(
            "One or more matplotlib color maps for depth visualization. "
            "Multiple maps can be specified for different visualization styles. "
            "Common options: 'Spectral_r', 'Greys_r', 'viridis', 'magma'. "
            "Use '' (empty string) to skip colorization. "
            "Default: ['Spectral_r', 'Greys_r']"
        ),
        dest="color_maps",
    )

    # Inference setting
    parser.add_argument(
        "-d",
        "--dilations",
        type=int,
        nargs="+",
        default=None,
        help=(
            "Spacing between frames for temporal analysis. "
            "Set to None to use preset configurations based on video length. "
            "Custom configurations: "
            "- [1, 10, 25]: Best accuracy, slower processing "
            "- [1, 25]: Balanced speed and accuracy "
            "- [1, 10]: For short videos (<78 frames) "
            "Default: None (auto-select based on video length)"
        ),
        dest="dilations",
    )
    parser.add_argument(
        "--cap-dilation",
        type=str2bool,
        default=None,
        help=(
            "Whether to automatically reduce dilation spacing for short videos. "
            "Set to None to use preset configuration. "
            "Enabling this prevents temporal windows from extending beyond video length. "
            "Default: None (automatically determined based on video length)"
        ),
        dest="cap_dilation",
    )
    parser.add_argument(
        "--dtype",
        "--data-type",
        type=str,
        choices=["fp16", "fp32", None],
        default=None,
        help=(
            "Specifies the floating-point precision for inference operations. "
            "Options: 'fp16' (16-bit), 'fp32' (32-bit), or None. "
            "If None, uses the preset configuration value. "
            "Lower precision (fp16) reduces memory usage but may affect accuracy. "
            "Default: None"
        ),
        dest="dtype",
    )
    parser.add_argument(
        "--snip-len",
        "--snippet-lengths",
        type=int,
        nargs="+",
        choices=[2, 3, 4],
        default=None,
        help=(
            "Number of consecutive frames to analyze in each temporal window. "
            "Set to None to use preset value (3). "
            "Can specify multiple values corresponding to different dilation rates. "
            "Example: '--dilations 1 25 --snippet-lengths 2 3' uses "
            "2 frames for dilation 1 and 3 frames for dilation 25. "
            "Allowed values: 2, 3, or 4 frames. "
            "Default: None"
        ),
        dest="snippet_lengths",
    )
    parser.add_argument(
        "--refine-step",
        type=int,
        default=None,
        help=(
            "Number of refinement iterations to improve depth estimation accuracy. "
            "Set to None to use preset configuration. "
            "Set to 0 to disable refinement. "
            "Higher values may improve accuracy but increase processing time. "
            "Default: None (uses 0, no refinement)"
        ),
        dest="refine_step",
    )
    parser.add_argument(
        "--refine-snippet-len",
        type=int,
        default=None,
        help=(
            "Number of frames per snippet used during the refinement phase. "
            "If not specified (None), the preset value (3) is used. "
            "Default: None"
        ),
    )
    parser.add_argument(
        "--refine-start-dilation",
        type=int,
        default=None,
        help=(
            "Initial dilation factor for the coarse-to-fine refinement process. "
            "Controls the starting granularity of the refinement steps. "
            "Higher values result in larger initial search windows. "
            "If not specified (None), uses system default. "
            "Default: None"
        ),
    )

    # Other settings
    parser.add_argument(
        "--resample-method",
        type=str,
        choices=["BILINEAR", "NEAREST_EXACT", "BICUBIC"],
        default="BILINEAR",
        help="Resampling method used to resize images.",
    )
    parser.add_argument(
        "--unload-snippet",
        type=str2bool,
        default=False,
        help=(
            "Controls memory optimization by moving processed data snippets to CPU. "
            "When enabled, reduces GPU memory usage at the cost of slower processing. "
            "Useful for systems with limited GPU memory or large datasets. "
            "Default: False"
        ),
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help=("Enable detailed progress and information reporting during processing. "),
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help=(
            "Random number generator seed for reproducibility (up to computational randomness). "
            "Using the same seed value will produce identical results across runs. "
            "If not specified (None), a random seed will be used. "
            "Default: None"
        ),
    )

    # -------------------- Config preset arguments --------------------
    input_args = parser.parse_args()

    # Base defaults shared by every preset.
    args = OmegaConf.create(
        {
            "res": 768,
            "snippet_lengths": [3],
            "cap_dilation": True,
            "dtype": "fp16",
            "refine_snippet_len": 3,
            "refine_start_dilation": 6,
        }
    )
    # Keys must match the `--preset` choices above. "fast1024" was previously
    # registered as "fasthr", which made `-p fast1024` crash.
    preset_args_dict = {
        "fast": OmegaConf.create(
            {
                "dilations": [1, 25],
                "refine_step": 0,
            }
        ),
        "fast1024": OmegaConf.create(
            {
                "res": 1024,
                "dilations": [1, 25],
                "refine_step": 0,
            }
        ),
        "full": OmegaConf.create(
            {
                "res": 1024,
                "dilations": [1, 10, 25],
                "refine_step": 10,
            }
        ),
        "paper": OmegaConf.create(
            {
                "dilations": [1, 10, 25],
                "cap_dilation": False,
                "dtype": "fp32",
                "refine_step": 10,
            }
        ),
    }
    if "none" != input_args.preset:
        logging.info(f"Using preset: {input_args.preset}")
        args.update(preset_args_dict[input_args.preset])

    # Merge CLI arguments on top of the preset: explicitly passed (non-None)
    # values win; unset (None) values keep the preset value.
    for key, value in vars(input_args).items():
        if key in args.keys():
            # overwrite if value is set and different from preset
            if value is not None and value != args[key]:
                logging.warning(f"Overwritting argument: {key} = {value}")
                args[key] = value
        else:
            # Not covered by the preset: the CLI value must be defined.
            # parser.error (not assert) so the check survives `python -O`.
            if value is None and key not in ("seed",):
                parser.error(f"Undefined argument: {key}")
            args[key] = value

    msg = f"arguments: {args}"
    if args.verbose:
        logging.info(msg)
    else:
        logging.debug(msg)

    # Argument check
    if args.save_sbs and len(args.color_maps) == 0:
        parser.error("No color map is given, can not save side-by-side videos.")

    input_video = Path(args.input_video)
    output_dir = Path(args.output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # -------------------- Device --------------------
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logging.warning("CUDA is not available. Running on CPU will be slow.")
    logging.info(f"device = {device}")

    # -------------------- Data --------------------
    if input_video.is_dir():
        # Only regular files; sub-directories cannot be processed as videos.
        input_video_ls = [p for p in input_video.iterdir() if p.is_file()]
    elif ".txt" == input_video.suffix:
        with open(input_video, "r") as f:
            # Skip blank lines (e.g. a trailing newline), which would
            # otherwise turn into Path(".").
            input_video_ls = [Path(line.strip()) for line in f if line.strip()]
    else:
        input_video_ls = [input_video]
    input_video_ls = sorted(input_video_ls)

    logging.info(f"Found {len(input_video_ls)} videos.")

    # -------------------- Model --------------------
    if "fp16" == args.dtype:
        dtype = torch.float16
    elif "fp32" == args.dtype:
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    pipe: RollingDepthPipeline = RollingDepthPipeline.from_pretrained(
        args.checkpoint, torch_dtype=dtype
    )  # type: ignore

    try:
        pipe.enable_xformers_memory_efficient_attention()
        logging.info("xformers enabled")
    except ImportError:
        logging.warning("Run without xformers")

    pipe = pipe.to(device)

    # Color maps that are actually rendered ('' means "skip colorization").
    # The side-by-side video uses the first of these.
    active_color_maps = [cmap for cmap in args.color_maps if "" != cmap]

    # -------------------- Inference and saving --------------------
    with torch.no_grad():
        if args.verbose:
            video_iterable = tqdm(input_video_ls, desc="Processing videos", leave=True)
        else:
            video_iterable = input_video_ls
        for video_path in video_iterable:
            # Random number generator (re-seeded per video for reproducibility)
            if args.seed is None:
                generator = None
            else:
                generator = torch.Generator(device=device)
                generator.manual_seed(args.seed)

            # Predict depth
            pipe_out: RollingDepthOutput = pipe(
                # input setting
                input_video_path=video_path,
                start_frame=args.start_frame,
                frame_count=args.frame_count,
                processing_res=args.res,
                resample_method=args.resample_method,
                # infer setting
                dilations=list(args.dilations),
                cap_dilation=args.cap_dilation,
                snippet_lengths=list(args.snippet_lengths),
                init_infer_steps=[1],
                strides=[1],
                coalign_kwargs=None,
                refine_step=args.refine_step,
                refine_snippet_len=args.refine_snippet_len,
                refine_start_dilation=args.refine_start_dilation,
                # other settings
                generator=generator,
                verbose=args.verbose,
                max_vae_bs=args.max_vae_bs,
                # output settings
                restore_res=args.restore_res,
                unload_snippet=args.unload_snippet,
            )

            depth_pred = pipe_out.depth_pred  # [N 1 H W]
            # Convert to NumPy once; .cpu() is a no-op for tensors already on CPU.
            depth_np = depth_pred.cpu().numpy()  # [N 1 H W]

            # Save prediction as npy
            if args.save_npy:
                save_to = output_dir.joinpath(f"{video_path.stem}_pred.npy")
                if args.verbose:
                    logging.info(f"Saving predictions to {save_to}")
                np.save(save_to, depth_np.squeeze(1))  # [N H W]

            # Save intermediate snippets
            if args.save_snippets and pipe_out.snippet_ls is not None:
                save_to = output_dir.joinpath(f"{video_path.stem}_snippets.npz")
                if args.verbose:
                    logging.info(f"Saving snippets to {save_to}")
                snippet_dict = {}
                for i_dil, snippets in enumerate(pipe_out.snippet_ls):
                    dilation = args.dilations[i_dil]
                    snippet_dict[f"dilation{dilation}"] = (
                        snippets.cpu().numpy().squeeze(2)
                    )  # [n_snip, snippet_len, H W]
                np.savez_compressed(save_to, **snippet_dict)

            if not active_color_maps:
                continue

            # Resolve the output frame rate once per video. Previously the
            # computed value was discarded and args.output_fps (0 by default)
            # was passed to the writer.
            if args.output_fps > 0:
                output_fps = args.output_fps
            else:
                output_fps = int(get_video_fps(video_path))

            # Colorize results
            for i_cmap, cmap in enumerate(active_color_maps):
                colored_np = colorize_depth_multi_thread(
                    depth=depth_np,
                    valid_mask=None,
                    chunk_size=4,
                    num_threads=4,
                    color_map=cmap,
                    verbose=args.verbose,
                )  # [n h w 3], in [0, 255]
                save_to = output_dir.joinpath(f"{video_path.stem}_{cmap}.mp4")
                write_video_from_numpy(
                    frames=colored_np,
                    output_path=save_to,
                    fps=output_fps,
                    crf=23,
                    preset="medium",
                    verbose=args.verbose,
                )

                # Save side-by-side video with the first rendered color map
                if args.save_sbs and 0 == i_cmap:
                    rgb = pipe_out.input_rgb.cpu() * 255  # [N 3 H W]
                    colored_depth = einops.rearrange(
                        torch.from_numpy(colored_np), "n h w c -> n c h w"
                    )
                    concat_video = (
                        concatenate_videos_horizontally_torch(rgb, colored_depth, gap=10)
                        .int()
                        .numpy()
                        .astype(np.uint8)
                    )
                    concat_video = einops.rearrange(concat_video, "n c h w -> n h w c")
                    save_to = output_dir.joinpath(f"{video_path.stem}_rgbd.mp4")
                    write_video_from_numpy(
                        frames=concat_video,
                        output_path=save_to,
                        fps=output_fps,
                        crf=23,
                        preset="medium",
                        verbose=args.verbose,
                    )

    logging.info(
        f"Finished. {len(input_video_ls)} predictions are saved to {output_dir}"
    )