diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..aa7b75fdd0009905fac2cb3f8cb377472fae7984 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+unimernet/processors/formula_processor_helper/frost/frost1.png filter=lfs diff=lfs merge=lfs -text
diff --git a/examples/0000004.png b/examples/0000004.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0006f68e8c1e18e258baaab3263228c461d7f0c
Binary files /dev/null and b/examples/0000004.png differ
diff --git a/examples/0000005.png b/examples/0000005.png
new file mode 100644
index 0000000000000000000000000000000000000000..0ec583b402563a2bf34cf66664c13f2c905f631e
Binary files /dev/null and b/examples/0000005.png differ
diff --git a/examples/0000006.png b/examples/0000006.png
new file mode 100644
index 0000000000000000000000000000000000000000..5acb9e4020fe7b1b5068dc21cc001b8d02f63096
Binary files /dev/null and b/examples/0000006.png differ
diff --git a/examples/0000007.png b/examples/0000007.png
new file mode 100644
index 0000000000000000000000000000000000000000..cc807adb3ada1010f32b5b3c678d178fbe2444b0
Binary files /dev/null and b/examples/0000007.png differ
diff --git a/examples/0000011.png b/examples/0000011.png
new file mode 100644
index 0000000000000000000000000000000000000000..2417a6f37f86b585ce8ffc0527a1268f4ea1d3d6
Binary files /dev/null and b/examples/0000011.png differ
diff --git a/unimernet/__init__.py b/unimernet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6a759d516575604fdf99b56de5cdf149e579012
--- /dev/null
+++ b/unimernet/__init__.py
@@ -0,0 +1,31 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import os
+import sys
+
+from omegaconf import OmegaConf
+
+from unimernet.common.registry import registry
+
+from unimernet.datasets.builders import *
+from unimernet.models import *
+from unimernet.processors import *
+from unimernet.tasks import *
+
+
+root_dir = os.path.dirname(os.path.abspath(__file__))
+default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
+
+registry.register_path("library_root", root_dir)
+repo_root = os.path.join(root_dir, "..")
+registry.register_path("repo_root", repo_root)
+cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
+registry.register_path("cache_root", cache_root)
+
+registry.register("MAX_INT", sys.maxsize)
+registry.register("SPLIT_NAMES", ["train", "val", "test"])
diff --git a/unimernet/__pycache__/__init__.cpython-310.pyc b/unimernet/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..edea971fd906b1fc9895653f06b14d6b98d4f387
Binary files /dev/null and b/unimernet/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/common/__init__.py b/unimernet/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unimernet/common/__pycache__/__init__.cpython-310.pyc b/unimernet/common/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14c726cbd5b8ce37830e6d27ec8d481518520fe2
Binary files /dev/null and b/unimernet/common/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/common/__pycache__/config.cpython-310.pyc b/unimernet/common/__pycache__/config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b8085e1fb2360a3236efb0f8d4f080bd5fb8f8f
Binary files /dev/null and b/unimernet/common/__pycache__/config.cpython-310.pyc differ
diff --git a/unimernet/common/__pycache__/dist_utils.cpython-310.pyc b/unimernet/common/__pycache__/dist_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b714750191d75d4b2575059e05940017d2c395e1
Binary files /dev/null and b/unimernet/common/__pycache__/dist_utils.cpython-310.pyc differ
diff --git a/unimernet/common/__pycache__/logger.cpython-310.pyc b/unimernet/common/__pycache__/logger.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0964be7a43f1cd39e88bbdcf94aa45b88f831e27
Binary files /dev/null and b/unimernet/common/__pycache__/logger.cpython-310.pyc differ
diff --git a/unimernet/common/__pycache__/registry.cpython-310.pyc b/unimernet/common/__pycache__/registry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9101b0789cd5ffc61eab6ec231cc0661b0dc6e41
Binary files /dev/null and b/unimernet/common/__pycache__/registry.cpython-310.pyc differ
diff --git a/unimernet/common/__pycache__/utils.cpython-310.pyc b/unimernet/common/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..94155df43180ba04fbf358536c93dc7bd69ebcab
Binary files /dev/null and b/unimernet/common/__pycache__/utils.cpython-310.pyc differ
diff --git a/unimernet/common/config.py b/unimernet/common/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bfcff2e4c07557832501594a6a4ed0871c166fa
--- /dev/null
+++ b/unimernet/common/config.py
@@ -0,0 +1,468 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import json
+from typing import Dict
+
+from omegaconf import OmegaConf
+from unimernet.common.registry import registry
+
+
+class Config:
+    def __init__(self, args):
+        self.config = {}
+
+        self.args = args
+
+        # Register this config object in the registry for later lookup
+        registry.register("configuration", self)
+
+        user_config = self._build_opt_list(self.args.options)
+
+        config = OmegaConf.load(self.args.cfg_path)
+
+        runner_config = self.build_runner_config(config)
+        model_config = self.build_model_config(config, **user_config)
+        dataset_config = self.build_dataset_config(config)
+
+        # Validate the user-provided runner configuration
+        # model and dataset configuration are supposed to be validated by the respective classes
+        # [TODO] validate the model/dataset configuration
+        # self._validate_runner_config(runner_config)
+
+        # Override the default configuration with user options.
+        self.config = OmegaConf.merge(
+            runner_config, model_config, dataset_config, user_config
+        )
+
+    def _validate_runner_config(self, runner_config):
+        """
+        This method validates the configuration, such that
+            1) all the user specified options are valid;
+            2) no type mismatches between the user specified options and the config.
+        """
+        runner_config_validator = create_runner_config_validator()
+        runner_config_validator.validate(runner_config)
+
+    def _build_opt_list(self, opts):
+        opts_dot_list = self._convert_to_dot_list(opts)
+        return OmegaConf.from_dotlist(opts_dot_list)
+
+    @staticmethod
+    def build_model_config(config, **kwargs):
+        model = config.get("model", None)
+        assert model is not None, "Missing model configuration file."
+
+        model_cls = registry.get_model_class(model.arch)
+        assert model_cls is not None, f"Model '{model.arch}' has not been registered."
+
+        model_type = kwargs.get("model.model_type", None)
+        if not model_type:
+            model_type = model.get("model_type", None)
+        # otherwise, use the model type selected by the user.
+
+        assert model_type is not None, "Missing model_type."
+
+        model_config_path = model_cls.default_config_path(model_type=model_type)
+
+        model_config = OmegaConf.create()
+        # hierarchy override: customized config > default config
+        model_config = OmegaConf.merge(
+            model_config,
+            OmegaConf.load(model_config_path),
+            {"model": config["model"]},
+        )
+
+        return model_config
+
+    @staticmethod
+    def build_runner_config(config):
+        return {"run": config.run}
+
+    @staticmethod
+    def build_dataset_config(config):
+        datasets = config.get("datasets", None)
+        if datasets is None:
+            raise KeyError(
+                "Expecting 'datasets' as the root key for dataset configuration."
+            )
+
+        dataset_config = OmegaConf.create()
+
+        for dataset_name in datasets:
+            builder_cls = registry.get_builder_class(dataset_name)
+
+            dataset_config_type = datasets[dataset_name].get("type", "default")
+            dataset_config_path = builder_cls.default_config_path(
+                type=dataset_config_type
+            )
+
+            # hierarchy override: customized config > default config
+            dataset_config = OmegaConf.merge(
+                dataset_config,
+                OmegaConf.load(dataset_config_path),
+                {"datasets": {dataset_name: config["datasets"][dataset_name]}},
+            )
+
+        return dataset_config
+
+    def _convert_to_dot_list(self, opts):
+        if opts is None:
+            opts = []
+
+        if len(opts) == 0:
+            return opts
+
+        has_equal = opts[0].find("=") != -1
+
+        if has_equal:
+            return opts
+
+        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
+
+    def get_config(self):
+        return self.config
+
+    @property
+    def run_cfg(self):
+        return self.config.run
+
+    @property
+    def datasets_cfg(self):
+        return self.config.datasets
+
+    @property
+    def model_cfg(self):
+        return self.config.model
+
+    def pretty_print(self):
+        logging.info("\n=====  Running Parameters    =====")
+        logging.info(self._convert_node_to_json(self.config.run))
+
+        logging.info("\n======  Dataset Attributes  ======")
+        datasets = self.config.datasets
+
+        for dataset in datasets:
+            if dataset in self.config.datasets:
+                logging.info(f"\n======== {dataset} =======")
+                dataset_config = self.config.datasets[dataset]
+                logging.info(self._convert_node_to_json(dataset_config))
+            else:
+                logging.warning(f"No dataset named '{dataset}' in config. Skipping")
+
+        logging.info(f"\n======  Model Attributes  ======")
+        logging.info(self._convert_node_to_json(self.config.model))
+
+    def _convert_node_to_json(self, node):
+        container = OmegaConf.to_container(node, resolve=True)
+        return json.dumps(container, indent=4, sort_keys=True)
+
+    def to_dict(self):
+        return OmegaConf.to_container(self.config)
+
+
+def node_to_dict(node):
+    return OmegaConf.to_container(node)
+
+
+class ConfigValidator:
+    """
+    This is a preliminary implementation to centralize and validate the configuration.
+    May be altered in the future.
+
+    A helper class to validate configurations from a yaml file.
+
+    This serves the following purposes:
+        1. Ensure all the options in the yaml are defined; raise an error if not.
+        2. When type mismatches are found, the validator raises an error.
+        3. Provide a central place to store and display helpful messages for supported configurations.
+
+    """
+
+    class _Argument:
+        def __init__(self, name, choices=None, type=None, help=None):
+            self.name = name
+            self.val = None
+            self.choices = choices
+            self.type = type
+            self.help = help
+
+        def __str__(self):
+            s = f"{self.name}={self.val}"
+            if self.type is not None:
+                s += f", ({self.type})"
+            if self.choices is not None:
+                s += f", choices: {self.choices}"
+            if self.help is not None:
+                s += f", ({self.help})"
+            return s
+
+    def __init__(self, description):
+        self.description = description
+
+        self.arguments = dict()
+
+        self.parsed_args = None
+
+    def __getitem__(self, key):
+        assert self.parsed_args is not None, "No arguments parsed yet."
+
+        return self.parsed_args[key]
+
+    def __str__(self) -> str:
+        return self.format_help()
+
+    def add_argument(self, *args, **kwargs):
+        """
+        Assume the first argument is the name of the argument.
+        """
+        self.arguments[args[0]] = self._Argument(*args, **kwargs)
+
+    def validate(self, config=None):
+        """
+        Validate a dict-like yaml config against the registered arguments (names, types, choices).
+        """
+        for k, v in config.items():
+            assert (
+                k in self.arguments
+            ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
+
+            if self.arguments[k].type is not None:
+                try:
+                    self.arguments[k].val = self.arguments[k].type(v)
+                except ValueError:
+                    raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
+
+            if self.arguments[k].choices is not None:
+                assert (
+                    v in self.arguments[k].choices
+                ), f"""{k} must be one of {self.arguments[k].choices}."""
+
+        return config
+
+    def format_arguments(self):
+        return str([f"{k}" for k in sorted(self.arguments.keys())])
+
+    def format_help(self):
+        # description + key-value pair string for each argument
+        help_msg = str(self.description)
+        return help_msg + ", available arguments: " + self.format_arguments()
+
+    def print_help(self):
+        # display help message
+        print(self.format_help())
+
+
+def create_runner_config_validator():
+    validator = ConfigValidator(description="Runner configurations")
+
+    validator.add_argument(
+        "runner",
+        type=str,
+        choices=["runner_base", "runner_iter"],
+        help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
+            runner runs based on iters. Default: runner_base""",
+    )
+    # add arguments for training dataset ratios
+    validator.add_argument(
+        "train_dataset_ratios",
+        type=Dict[str, float],
+        help="""Ratios of training dataset. This is used in iteration-based runner.
+        Do not support for epoch-based runner because how to define an epoch becomes tricky.
+        Default: None""",
+    )
+    validator.add_argument(
+        "max_iters",
+        type=float,
+        help="Maximum number of iterations to run.",
+    )
+    validator.add_argument(
+        "max_epoch",
+        type=int,
+        help="Maximum number of epochs to run.",
+    )
+    # add arguments for iters_per_inner_epoch
+    validator.add_argument(
+        "iters_per_inner_epoch",
+        type=float,
+        help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
+    )
+    lr_scheds_choices = registry.list_lr_schedulers()
+    validator.add_argument(
+        "lr_sched",
+        type=str,
+        choices=lr_scheds_choices,
+        help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
+    )
+    task_choices = registry.list_tasks()
+    validator.add_argument(
+        "task",
+        type=str,
+        choices=task_choices,
+        help="Task to use, from {}".format(task_choices),
+    )
+    # add arguments for init_lr
+    validator.add_argument(
+        "init_lr",
+        type=float,
+        help="Initial learning rate. This will be the learning rate after warmup and before decay.",
+    )
+    # add arguments for min_lr
+    validator.add_argument(
+        "min_lr",
+        type=float,
+        help="Minimum learning rate (after decay).",
+    )
+    # add arguments for warmup_lr
+    validator.add_argument(
+        "warmup_lr",
+        type=float,
+        help="Starting learning rate for warmup.",
+    )
+    # add arguments for learning rate decay rate
+    validator.add_argument(
+        "lr_decay_rate",
+        type=float,
+        help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
+    )
+    # add arguments for weight decay
+    validator.add_argument(
+        "weight_decay",
+        type=float,
+        help="Weight decay rate.",
+    )
+    # add arguments for training batch size
+    validator.add_argument(
+        "batch_size_train",
+        type=int,
+        help="Training batch size.",
+    )
+    # add arguments for evaluation batch size
+    validator.add_argument(
+        "batch_size_eval",
+        type=int,
+        help="Evaluation batch size, including validation and testing.",
+    )
+    # add arguments for number of workers for data loading
+    validator.add_argument(
+        "num_workers",
+        help="Number of workers for data loading.",
+    )
+    # add arguments for warm up steps
+    validator.add_argument(
+        "warmup_steps",
+        type=int,
+        help="Number of warmup steps. Required if a warmup schedule is used.",
+    )
+    # add arguments for random seed
+    validator.add_argument(
+        "seed",
+        type=int,
+        help="Random seed.",
+    )
+    # add arguments for output directory
+    validator.add_argument(
+        "output_dir",
+        type=str,
+        help="Output directory to save checkpoints and logs.",
+    )
+    # add arguments for whether only use evaluation
+    validator.add_argument(
+        "evaluate",
+        help="Whether to only evaluate the model. If true, training will not be performed.",
+    )
+    # add arguments for splits used for training, e.g. ["train", "val"]
+    validator.add_argument(
+        "train_splits",
+        type=list,
+        help="Splits to use for training.",
+    )
+    # add arguments for splits used for validation, e.g. ["val"]
+    validator.add_argument(
+        "valid_splits",
+        type=list,
+        help="Splits to use for validation. If not provided, will skip the validation.",
+    )
+    # add arguments for splits used for testing, e.g. ["test"]
+    validator.add_argument(
+        "test_splits",
+        type=list,
+        help="Splits to use for testing. If not provided, will skip the testing.",
+    )
+    # add arguments for accumulating gradient for iterations
+    validator.add_argument(
+        "accum_grad_iters",
+        type=int,
+        help="Number of iterations to accumulate gradient for.",
+    )
+
+    # ====== distributed training ======
+    validator.add_argument(
+        "device",
+        type=str,
+        choices=["cpu", "cuda"],
+        help="Device to use. Support 'cuda' or 'cpu' as for now.",
+    )
+    validator.add_argument(
+        "world_size",
+        type=int,
+        help="Number of processes participating in the job.",
+    )
+    validator.add_argument("dist_url", type=str)
+    validator.add_argument("distributed", type=bool)
+    # add argument for whether to use a distributed sampler during evaluation
+    validator.add_argument(
+        "use_dist_eval_sampler",
+        type=bool,
+        help="Whether to use distributed sampler during evaluation or not.",
+    )
+
+    # ====== task specific ======
+    # generation task specific arguments
+    # add arguments for maximal length of text output
+    validator.add_argument(
+        "max_len",
+        type=int,
+        help="Maximal length of text output.",
+    )
+    # add arguments for minimal length of text output
+    validator.add_argument(
+        "min_len",
+        type=int,
+        help="Minimal length of text output.",
+    )
+    # add arguments for number of beams
+    validator.add_argument(
+        "num_beams",
+        type=int,
+        help="Number of beams used for beam search.",
+    )
+
+    # vqa task specific arguments
+    # add arguments for number of answer candidates
+    validator.add_argument(
+        "num_ans_candidates",
+        type=int,
+        help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
+    )
+    # add arguments for inference method
+    validator.add_argument(
+        "inference_method",
+        type=str,
+        choices=["genearte", "rank"],
+        help="""Inference method to use for question answering. If rank, requires a answer list.""",
+    )
+
+    # ====== model specific ======
+    validator.add_argument(
+        "k_test",
+        type=int,
+        help="Number of top k most similar samples from ITC/VTC selection to be tested.",
+    )
+
+    return validator
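+
+# Hedged usage sketch (illustrative only): `Config` expects an argparse-style
+# namespace with `cfg_path` and `options` attributes; the yaml path below is a
+# hypothetical placeholder and assumes the referenced model/dataset builders
+# are already registered.
+#
+#   from types import SimpleNamespace
+#   args = SimpleNamespace(cfg_path="configs/train.yaml", options=["run.seed", "42"])
+#   cfg = Config(args)
+#   print(cfg.run_cfg)    # merged runner configuration
+#   print(cfg.model_cfg)  # merged model configuration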
diff --git a/unimernet/common/dist_utils.py b/unimernet/common/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..296a3c86f29c6e82fa8f1108c7dd9fa7d3e9ce45
--- /dev/null
+++ b/unimernet/common/dist_utils.py
@@ -0,0 +1,137 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+import timm.models.hub as timm_hub
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in the master process.
+    """
+    import builtins as __builtin__
+
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def init_distributed_mode(args):
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print("Not using distributed mode")
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = "nccl"
+    print(
+        "| distributed init (rank {}, world {}): {}".format(
+            args.rank, args.world_size, args.dist_url
+        ),
+        flush=True,
+    )
+    torch.distributed.init_process_group(
+        backend=args.dist_backend,
+        init_method=args.dist_url,
+        world_size=args.world_size,
+        rank=args.rank,
+        timeout=datetime.timedelta(
+            days=365
+        ),  # allow auto-downloading and de-compressing
+    )
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+def get_dist_info():
+    if torch.__version__ < "1.0":
+        initialized = dist._initialized
+    else:
+        initialized = dist.is_initialized()
+    if initialized:
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:  # non-distributed training
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def main_process(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank == 0:
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def download_cached_file(url, check_hash=True, progress=False):
+    """
+    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
+    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
+    """
+
+    def get_cached_file_path():
+        # a hack to sync the file path across processes
+        parts = torch.hub.urlparse(url)
+        filename = os.path.basename(parts.path)
+        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
+
+        return cached_file
+
+    if is_main_process():
+        timm_hub.download_cached_file(url, check_hash, progress)
+
+    if is_dist_avail_and_initialized():
+        dist.barrier()
+
+    return get_cached_file_path()
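+
+# Hedged usage sketch: `init_distributed_mode` reads RANK/WORLD_SIZE/LOCAL_RANK
+# (as set by launchers such as torchrun) and falls back to single-process mode
+# otherwise; the `dist_url` value below is an illustrative placeholder.
+#
+#   from types import SimpleNamespace
+#   args = SimpleNamespace(distributed=True, dist_url="env://")
+#   init_distributed_mode(args)  # sets args.rank / args.world_size / args.gpu
+#   if is_main_process():
+#       print("world size:", get_world_size())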
diff --git a/unimernet/common/gradcam.py b/unimernet/common/gradcam.py
new file mode 100644
index 0000000000000000000000000000000000000000..d53a5254d4b319eaf2cbfbd081b0ca8e38c5c7a0
--- /dev/null
+++ b/unimernet/common/gradcam.py
@@ -0,0 +1,24 @@
+import numpy as np
+from matplotlib import pyplot as plt
+from scipy.ndimage import filters
+from skimage import transform as skimage_transform
+
+
+def getAttMap(img, attMap, blur=True, overlap=True):
+    attMap -= attMap.min()
+    if attMap.max() > 0:
+        attMap /= attMap.max()
+    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
+    if blur:
+        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
+        attMap -= attMap.min()
+        attMap /= attMap.max()
+    cmap = plt.get_cmap("jet")
+    attMapV = cmap(attMap)
+    attMapV = np.delete(attMapV, 3, 2)
+    if overlap:
+        attMap = (
+            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
+            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
+        )
+    return attMap
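+
+# Hedged usage sketch: `getAttMap` expects `img` as an HxWx3 float array in
+# [0, 1] and `attMap` as an HxW attention map; the random inputs below are
+# placeholders for illustration only.
+#
+#   import numpy as np
+#   img = np.random.rand(224, 224, 3)
+#   att = np.random.rand(14, 14)
+#   overlay = getAttMap(img, att, blur=True, overlap=True)  # HxWx3 overlay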
diff --git a/unimernet/common/logger.py b/unimernet/common/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..55d46267ed367996f17dc5a3df80e8bdb20b76af
--- /dev/null
+++ b/unimernet/common/logger.py
@@ -0,0 +1,195 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import logging
+import time
+from collections import defaultdict, deque
+
+import torch
+import torch.distributed as dist
+
+from unimernet.common import dist_utils
+
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not dist_utils.is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value,
+        )
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError(
+            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+        )
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append("{}: {}".format(name, str(meter)))
+        return self.delimiter.join(loss_str)
+
+    def global_avg(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ""
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt="{avg:.4f}")
+        data_time = SmoothedValue(fmt="{avg:.4f}")
+        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
+        log_msg = [
+            header,
+            "[{0" + space_fmt + "}/{1}]",
+            "eta: {eta}",
+            "{meters}",
+            "time: {time}",
+            "data: {data}",
+        ]
+        if torch.cuda.is_available():
+            log_msg.append("max mem: {memory:.0f}")
+        log_msg = self.delimiter.join(log_msg)
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                            memory=torch.cuda.max_memory_allocated() / MB,
+                        )
+                    )
+                else:
+                    print(
+                        log_msg.format(
+                            i,
+                            len(iterable),
+                            eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time),
+                            data=str(data_time),
+                        )
+                    )
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print(
+            "{} Total time: {} ({:.4f} s / it)".format(
+                header, total_time_str, total_time / len(iterable)
+            )
+        )
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
+
+
+def setup_logger():
+    logging.basicConfig(
+        level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
+        format="%(asctime)s [%(levelname)s] %(message)s",
+        handlers=[logging.StreamHandler()],
+    )
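+
+# Hedged usage sketch: `MetricLogger.log_every` wraps an iterable and prints
+# smoothed timing/metric statistics every `print_freq` iterations. The loop
+# below is illustrative; `data_loader`, `optimizer` and `loss` stand in for
+# real training objects.
+#
+#   metric_logger = MetricLogger(delimiter="  ")
+#   metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
+#   for batch in metric_logger.log_every(data_loader, print_freq=50, header="Train:"):
+#       loss = ...  # forward/backward step
+#       metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])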
diff --git a/unimernet/common/optims.py b/unimernet/common/optims.py
new file mode 100644
index 0000000000000000000000000000000000000000..148b5a2c30520ae3e0e033142300ba90703c6939
--- /dev/null
+++ b/unimernet/common/optims.py
@@ -0,0 +1,120 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import math
+
+from unimernet.common.registry import registry
+
+
+@registry.register_lr_scheduler("linear_warmup_step_lr")
+class LinearWarmupStepLRScheduler:
+    def __init__(
+        self,
+        optimizer,
+        max_epoch,
+        min_lr,
+        init_lr,
+        decay_rate=1,
+        warmup_start_lr=-1,
+        warmup_steps=0,
+        **kwargs
+    ):
+        self.optimizer = optimizer
+
+        self.max_epoch = max_epoch
+        self.min_lr = min_lr
+
+        self.decay_rate = decay_rate
+
+        self.init_lr = init_lr
+        self.warmup_steps = warmup_steps
+        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+    def step(self, cur_epoch, cur_step):
+        if cur_epoch == 0:
+            warmup_lr_schedule(
+                step=cur_step,
+                optimizer=self.optimizer,
+                max_step=self.warmup_steps,
+                init_lr=self.warmup_start_lr,
+                max_lr=self.init_lr,
+            )
+        else:
+            step_lr_schedule(
+                epoch=cur_epoch,
+                optimizer=self.optimizer,
+                init_lr=self.init_lr,
+                min_lr=self.min_lr,
+                decay_rate=self.decay_rate,
+            )
+
+
+@registry.register_lr_scheduler("linear_warmup_cosine_lr")
+class LinearWarmupCosineLRScheduler:
+    def __init__(
+        self,
+        optimizer,
+        max_epoch,
+        min_lr,
+        init_lr,
+        iters_per_epoch,
+        warmup_steps=0,
+        warmup_start_lr=-1,
+        **kwargs
+    ):
+        self.optimizer = optimizer
+
+        self.max_epoch = max_epoch
+        self.min_lr = min_lr
+
+        self.init_lr = init_lr
+        self.warmup_steps = warmup_steps
+        self.iters_per_epoch = iters_per_epoch
+        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
+
+    def step(self, cur_epoch, cur_step):
+        # assuming the warmup iters are fewer than one epoch
+        total_steps = cur_epoch * self.iters_per_epoch + cur_step
+        if total_steps < self.warmup_steps:
+            warmup_lr_schedule(
+                step=cur_step,
+                optimizer=self.optimizer,
+                max_step=self.warmup_steps,
+                init_lr=self.warmup_start_lr,
+                max_lr=self.init_lr,
+            )
+        else:
+            cosine_lr_schedule(
+                epoch=total_steps,
+                optimizer=self.optimizer,
+                max_epoch=self.max_epoch * self.iters_per_epoch,
+                init_lr=self.init_lr,
+                min_lr=self.min_lr,
+            )
+
+
+def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
+    """Decay the learning rate"""
+    lr = (init_lr - min_lr) * 0.5 * (
+        1.0 + math.cos(math.pi * epoch / max_epoch)
+    ) + min_lr
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
+    """Warmup the learning rate"""
+    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
+    """Decay the learning rate"""
+    lr = max(min_lr, init_lr * (decay_rate**epoch))
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
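+
+# Hedged usage sketch: schedulers are looked up through the registry by the
+# name they were registered with and stepped once per iteration; `optimizer`
+# and the hyper-parameters below are illustrative placeholders.
+#
+#   sched_cls = registry.get_lr_scheduler_class("linear_warmup_cosine_lr")
+#   scheduler = sched_cls(optimizer=optimizer, max_epoch=10, min_lr=1e-6,
+#                         init_lr=1e-4, iters_per_epoch=1000, warmup_steps=500)
+#   for epoch in range(10):
+#       for it in range(1000):
+#           scheduler.step(cur_epoch=epoch, cur_step=it)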
diff --git a/unimernet/common/registry.py b/unimernet/common/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..699c1bc137ea422e3ecde40d3fade83b0bac45f0
--- /dev/null
+++ b/unimernet/common/registry.py
@@ -0,0 +1,329 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+
+class Registry:
+    mapping = {
+        "builder_name_mapping": {},
+        "task_name_mapping": {},
+        "processor_name_mapping": {},
+        "model_name_mapping": {},
+        "lr_scheduler_name_mapping": {},
+        "runner_name_mapping": {},
+        "state": {},
+        "paths": {},
+    }
+
+    @classmethod
+    def register_builder(cls, name):
+        r"""Register a dataset builder to registry with key 'name'
+
+        Args:
+            name: Key with which the builder will be registered.
+
+        Usage:
+
+            from unimernet.common.registry import registry
+            from unimernet.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+        """
+
+        def wrap(builder_cls):
+            from unimernet.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+
+            assert issubclass(
+                builder_cls, BaseDatasetBuilder
+            ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
+                builder_cls
+            )
+            if name in cls.mapping["builder_name_mapping"]:
+                raise KeyError(
+                    "Name '{}' already registered for {}.".format(
+                        name, cls.mapping["builder_name_mapping"][name]
+                    )
+                )
+            cls.mapping["builder_name_mapping"][name] = builder_cls
+            return builder_cls
+
+        return wrap
+
+    @classmethod
+    def register_task(cls, name):
+        r"""Register a task to registry with key 'name'
+
+        Args:
+            name: Key with which the task will be registered.
+
+        Usage:
+
+            from unimernet.common.registry import registry
+        """
+
+        def wrap(task_cls):
+            from unimernet.tasks.base_task import BaseTask
+
+            assert issubclass(
+                task_cls, BaseTask
+            ), "All tasks must inherit BaseTask class"
+            if name in cls.mapping["task_name_mapping"]:
+                raise KeyError(
+                    "Name '{}' already registered for {}.".format(
+                        name, cls.mapping["task_name_mapping"][name]
+                    )
+                )
+            cls.mapping["task_name_mapping"][name] = task_cls
+            return task_cls
+
+        return wrap
+
+    @classmethod
+    def register_model(cls, name):
+        r"""Register a task to registry with key 'name'
+
+        Args:
+            name: Key with which the task will be registered.
+
+        Usage:
+
+            from unimernet.common.registry import registry
+        """
+
+        def wrap(model_cls):
+            from unimernet.models import BaseModel
+
+            assert issubclass(
+                model_cls, BaseModel
+            ), "All models must inherit BaseModel class"
+            if name in cls.mapping["model_name_mapping"]:
+                raise KeyError(
+                    "Name '{}' already registered for {}.".format(
+                        name, cls.mapping["model_name_mapping"][name]
+                    )
+                )
+            cls.mapping["model_name_mapping"][name] = model_cls
+            return model_cls
+
+        return wrap
+
+    @classmethod
+    def register_processor(cls, name):
+        r"""Register a processor to registry with key 'name'
+
+        Args:
+            name: Key with which the processor will be registered.
+
+        Usage:
+
+            from unimernet.common.registry import registry
+        """
+
+        def wrap(processor_cls):
+            from unimernet.processors import BaseProcessor
+
+            assert issubclass(
+                processor_cls, BaseProcessor
+            ), "All processors must inherit BaseProcessor class"
+            if name in cls.mapping["processor_name_mapping"]:
+                raise KeyError(
+                    "Name '{}' already registered for {}.".format(
+                        name, cls.mapping["processor_name_mapping"][name]
+                    )
+                )
+            cls.mapping["processor_name_mapping"][name] = processor_cls
+            return processor_cls
+
+        return wrap
+
+    @classmethod
+    def register_lr_scheduler(cls, name):
+        r"""Register a model to registry with key 'name'
+
+        Args:
+            name: Key with which the task will be registered.
+
+        Usage:
+
+            from unimernet.common.registry import registry
+        """
+
+        def wrap(lr_sched_cls):
+            if name in cls.mapping["lr_scheduler_name_mapping"]:
+                raise KeyError(
+                    "Name '{}' already registered for {}.".format(
+                        name, cls.mapping["lr_scheduler_name_mapping"][name]
+                    )
+                )
+            cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
+            return lr_sched_cls
+
+        return wrap
+
+    @classmethod
+    def register_runner(cls, name):
+        r"""Register a model to registry with key 'name'
+
+        Args:
+            name: Key with which the task will be registered.
+
+        Usage:
+
+            from unimernet.common.registry import registry
+        """
+
+        def wrap(runner_cls):
+            if name in cls.mapping["runner_name_mapping"]:
+                raise KeyError(
+                    "Name '{}' already registered for {}.".format(
+                        name, cls.mapping["runner_name_mapping"][name]
+                    )
+                )
+            cls.mapping["runner_name_mapping"][name] = runner_cls
+            return runner_cls
+
+        return wrap
+
+    @classmethod
+    def register_path(cls, name, path):
+        r"""Register a path to registry with key 'name'
+
+        Args:
+            name: Key with which the path will be registered.
+
+        Usage:
+
+            from unimernet.common.registry import registry
+        """
+        assert isinstance(path, str), "All paths must be str."
+        if name in cls.mapping["paths"]:
+            raise KeyError("Name '{}' already registered.".format(name))
+        cls.mapping["paths"][name] = path
+
+    @classmethod
+    def register(cls, name, obj):
+        r"""Register an item to registry with key 'name'
+
+        Args:
+            name: Key with which the item will be registered.
+
+        Usage::
+
+            from unimernet.common.registry import registry
+
+            registry.register("config", {})
+        """
+        path = name.split(".")
+        current = cls.mapping["state"]
+
+        for part in path[:-1]:
+            if part not in current:
+                current[part] = {}
+            current = current[part]
+
+        current[path[-1]] = obj
+
+    # @classmethod
+    # def get_trainer_class(cls, name):
+    #     return cls.mapping["trainer_name_mapping"].get(name, None)
+
+    @classmethod
+    def get_builder_class(cls, name):
+        return cls.mapping["builder_name_mapping"].get(name, None)
+
+    @classmethod
+    def get_model_class(cls, name):
+        return cls.mapping["model_name_mapping"].get(name, None)
+
+    @classmethod
+    def get_task_class(cls, name):
+        return cls.mapping["task_name_mapping"].get(name, None)
+
+    @classmethod
+    def get_processor_class(cls, name):
+        return cls.mapping["processor_name_mapping"].get(name, None)
+
+    @classmethod
+    def get_lr_scheduler_class(cls, name):
+        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
+
+    @classmethod
+    def get_runner_class(cls, name):
+        return cls.mapping["runner_name_mapping"].get(name, None)
+
+    @classmethod
+    def list_runners(cls):
+        return sorted(cls.mapping["runner_name_mapping"].keys())
+
+    @classmethod
+    def list_models(cls):
+        return sorted(cls.mapping["model_name_mapping"].keys())
+
+    @classmethod
+    def list_tasks(cls):
+        return sorted(cls.mapping["task_name_mapping"].keys())
+
+    @classmethod
+    def list_processors(cls):
+        return sorted(cls.mapping["processor_name_mapping"].keys())
+
+    @classmethod
+    def list_lr_schedulers(cls):
+        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
+
+    @classmethod
+    def list_datasets(cls):
+        return sorted(cls.mapping["builder_name_mapping"].keys())
+
+    @classmethod
+    def get_path(cls, name):
+        return cls.mapping["paths"].get(name, None)
+
+    @classmethod
+    def get(cls, name, default=None, no_warning=False):
+        r"""Get an item from registry with key 'name'
+
+        Args:
+            name (string): Key whose value needs to be retrieved.
+            default: If passed and key is not in registry, default value will
+                     be returned with a warning. Default: None
+            no_warning (bool): If passed as True, warning when key doesn't exist
+                               will not be generated. Useful for MMF's
+                               internal operations. Default: False
+        """
+        original_name = name
+        name = name.split(".")
+        value = cls.mapping["state"]
+        for subname in name:
+            value = value.get(subname, default)
+            if value is default:
+                break
+
+        if (
+            "writer" in cls.mapping["state"]
+            and value == default
+            and no_warning is False
+        ):
+            cls.mapping["state"]["writer"].warning(
+                "Key {} is not present in registry, returning default value "
+                "of {}".format(original_name, default)
+            )
+        return value
+
+    @classmethod
+    def unregister(cls, name):
+        r"""Remove an item from registry with key 'name'
+
+        Args:
+            name: Key which needs to be removed.
+        Usage::
+
+            from unimernet.common.registry import registry
+
+            config = registry.unregister("config")
+        """
+        return cls.mapping["state"].pop(name, None)
+
+
+registry = Registry()
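+
+# Hedged usage sketch: besides the class/path mappings, `register`/`get`
+# support dotted keys stored as nested dicts in the shared state; "demo" is a
+# hypothetical key used only for illustration.
+#
+#   registry.register("demo.threshold", 0.5)
+#   registry.get("demo.threshold")                                # -> 0.5
+#   registry.get("demo.missing", default=None, no_warning=True)  # -> None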
diff --git a/unimernet/common/utils.py b/unimernet/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6c3366b18db21db5330e85ab4e239a404312b86
--- /dev/null
+++ b/unimernet/common/utils.py
@@ -0,0 +1,424 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import io
+import json
+import logging
+import os
+import pickle
+import re
+import shutil
+import urllib
+import urllib.error
+import urllib.request
+from typing import Optional
+from urllib.parse import urlparse
+
+import numpy as np
+import pandas as pd
+import yaml
+from iopath.common.download import download
+from iopath.common.file_io import file_lock, g_pathmgr
+from unimernet.common.registry import registry
+from torch.utils.model_zoo import tqdm
+from torchvision.datasets.utils import (
+    check_integrity,
+    download_file_from_google_drive,
+    extract_archive,
+)
+
+
+def now():
+    from datetime import datetime
+
+    return datetime.now().strftime("%Y%m%d%H%M")[:-1]
+
+
+def is_url(url_or_filename):
+    parsed = urlparse(url_or_filename)
+    return parsed.scheme in ("http", "https")
+
+
+def get_cache_path(rel_path):
+    return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
+
+
+def get_abs_path(rel_path):
+    return os.path.join(registry.get_path("library_root"), rel_path)
+
+
+def load_json(filename):
+    with open(filename, "r") as f:
+        return json.load(f)
+
+
+# The following are adapted from torchvision and vissl
+# torchvision: https://github.com/pytorch/vision
+# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
+
+
+def makedir(dir_path):
+    """
+    Create the directory if it does not exist.
+    """
+    is_success = False
+    try:
+        if not g_pathmgr.exists(dir_path):
+            g_pathmgr.mkdirs(dir_path)
+        is_success = True
+    except BaseException:
+        print(f"Error creating directory: {dir_path}")
+    return is_success
+
+
+def get_redirected_url(url: str):
+    """
+    Given a URL, returns the URL it redirects to or the
+    original URL in case of no redirection.
+    """
+    import requests
+
+    with requests.Session() as session:
+        with session.get(url, stream=True, allow_redirects=True) as response:
+            if response.history:
+                return response.url
+            else:
+                return url
+
+
+def to_google_drive_download_url(view_url: str) -> str:
+    """
+    Utility function to transform a view URL of google drive
+    to a download URL for google drive
+    Example input:
+        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
+    Example output:
+        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
+    """
+    splits = view_url.split("/")
+    assert splits[-1] == "view"
+    file_id = splits[-2]
+    return f"https://drive.google.com/uc?export=download&id={file_id}"
+
+
+def download_google_drive_url(url: str, output_path: str, output_file_name: str):
+    """
+    Download a file from Google Drive.
+    Downloading a URL from Google Drive requires confirmation when
+    the file size is too big (Google Drive notifies that
+    anti-virus checks cannot be performed on such files).
+    """
+    import requests
+
+    with requests.Session() as session:
+
+        # First get the confirmation token and append it to the URL
+        with session.get(url, stream=True, allow_redirects=True) as response:
+            for k, v in response.cookies.items():
+                if k.startswith("download_warning"):
+                    url = url + "&confirm=" + v
+
+        # Then download the content of the file
+        with session.get(url, stream=True, verify=True) as response:
+            makedir(output_path)
+            path = os.path.join(output_path, output_file_name)
+            total_size = int(response.headers.get("Content-length", 0))
+            with open(path, "wb") as file:
+                from tqdm import tqdm
+
+                with tqdm(total=total_size) as progress_bar:
+                    for block in response.iter_content(
+                        chunk_size=io.DEFAULT_BUFFER_SIZE
+                    ):
+                        file.write(block)
+                        progress_bar.update(len(block))
+
+
+def _get_google_drive_file_id(url: str) -> Optional[str]:
+    parts = urlparse(url)
+
+    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
+        return None
+
+    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
+    if match is None:
+        return None
+
+    return match.group("id")
+
+
+def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
+    with open(filename, "wb") as fh:
+        with urllib.request.urlopen(
+            urllib.request.Request(url, headers={"User-Agent": "vissl"})
+        ) as response:
+            with tqdm(total=response.length) as pbar:
+                for chunk in iter(lambda: response.read(chunk_size), ""):
+                    if not chunk:
+                        break
+                    pbar.update(chunk_size)
+                    fh.write(chunk)
+
+
+def download_url(
+    url: str,
+    root: str,
+    filename: Optional[str] = None,
+    md5: Optional[str] = None,
+) -> None:
+    """Download a file from a url and place it in root.
+    Args:
+        url (str): URL to download file from
+        root (str): Directory to place downloaded file in
+        filename (str, optional): Name to save the file under.
+                                  If None, use the basename of the URL.
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
+    """
+    root = os.path.expanduser(root)
+    if not filename:
+        filename = os.path.basename(url)
+    fpath = os.path.join(root, filename)
+
+    makedir(root)
+
+    # check if file is already present locally
+    if check_integrity(fpath, md5):
+        print("Using downloaded and verified file: " + fpath)
+        return
+
+    # expand redirect chain if needed
+    url = get_redirected_url(url)
+
+    # check if file is located on Google Drive
+    file_id = _get_google_drive_file_id(url)
+    if file_id is not None:
+        return download_file_from_google_drive(file_id, root, filename, md5)
+
+    # download the file
+    try:
+        print("Downloading " + url + " to " + fpath)
+        _urlretrieve(url, fpath)
+    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
+        if url[:5] == "https":
+            url = url.replace("https:", "http:")
+            print(
+                "Failed download. Trying https -> http instead."
+                " Downloading " + url + " to " + fpath
+            )
+            _urlretrieve(url, fpath)
+        else:
+            raise e
+
+    # check integrity of downloaded file
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
+
+
+def download_and_extract_archive(
+    url: str,
+    download_root: str,
+    extract_root: Optional[str] = None,
+    filename: Optional[str] = None,
+    md5: Optional[str] = None,
+    remove_finished: bool = False,
+) -> None:
+    download_root = os.path.expanduser(download_root)
+    if extract_root is None:
+        extract_root = download_root
+    if not filename:
+        filename = os.path.basename(url)
+
+    download_url(url, download_root, filename, md5)
+
+    archive = os.path.join(download_root, filename)
+    print("Extracting {} to {}".format(archive, extract_root))
+    extract_archive(archive, extract_root, remove_finished)
+
+
+def cache_url(url: str, cache_dir: str) -> str:
+    """
+    This implementation downloads the remote resource and caches it locally.
+    The resource will only be downloaded if not previously requested.
+    """
+    parsed_url = urlparse(url)
+    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
+    makedir(dirname)
+    filename = url.split("/")[-1]
+    cached = os.path.join(dirname, filename)
+    with file_lock(cached):
+        if not os.path.isfile(cached):
+            logging.info(f"Downloading {url} to {cached} ...")
+            cached = download(url, dirname, filename=filename)
+    logging.info(f"URL {url} cached in {cached}")
+    return cached
+
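+# Hedged usage sketch: `cache_url` mirrors the remote path under `cache_dir`
+# and only downloads on the first request; the URL below is an illustrative
+# placeholder.
+#
+#   local_path = cache_url("https://example.com/checkpoints/model.pth",
+#                          cache_dir=get_cache_path("downloads"))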
+
+# TODO (prigoyal): convert this into RAII-style API
+def create_file_symlink(file1, file2):
+    """
+    Simply create a symlink from file1 to file2.
+    Useful during model checkpointing to symlink to the
+    latest successful checkpoint.
+    """
+    try:
+        if g_pathmgr.exists(file2):
+            g_pathmgr.rm(file2)
+        g_pathmgr.symlink(file1, file2)
+    except Exception as e:
+        logging.info(f"Could NOT create symlink. Error: {e}")
+
+
+def save_file(data, filename, append_to_json=True, verbose=True):
+    """
+    Common i/o utility to handle saving data to various file formats.
+    Supported:
+        .pkl, .pickle, .npy, .json, .yaml
+    Specifically for .json, users have the option to either append (default)
+    or rewrite by passing a Boolean value to append_to_json.
+    """
+    if verbose:
+        logging.info(f"Saving data to file: {filename}")
+    file_ext = os.path.splitext(filename)[1]
+    if file_ext in [".pkl", ".pickle"]:
+        with g_pathmgr.open(filename, "wb") as fopen:
+            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
+    elif file_ext == ".npy":
+        with g_pathmgr.open(filename, "wb") as fopen:
+            np.save(fopen, data)
+    elif file_ext == ".json":
+        if append_to_json:
+            with g_pathmgr.open(filename, "a") as fopen:
+                fopen.write(json.dumps(data, sort_keys=True) + "\n")
+                fopen.flush()
+        else:
+            with g_pathmgr.open(filename, "w") as fopen:
+                fopen.write(json.dumps(data, sort_keys=True) + "\n")
+                fopen.flush()
+    elif file_ext == ".yaml":
+        with g_pathmgr.open(filename, "w") as fopen:
+            dump = yaml.dump(data)
+            fopen.write(dump)
+            fopen.flush()
+    else:
+        raise Exception(f"Saving {file_ext} is not supported yet")
+
+    if verbose:
+        logging.info(f"Saved data to file: {filename}")
+
+
+def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
+    """
+    Common i/o utility to handle loading data from various file formats.
+    Supported:
+        .txt, .pkl, .pickle, .npy, .json, .yaml, .csv
+    For the .npy files, we support reading the files in mmap_mode.
+    If reading with mmap_mode fails, we fall back to loading the data
+    without it.
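+
+    Illustrative examples (file paths are placeholders):
+
+    >>> metrics = load_file("/tmp/metrics.json")
+    >>> features = load_file("/tmp/features.npy", mmap_mode="r")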
+    """
+    if verbose:
+        logging.info(f"Loading data from file: {filename}")
+
+    file_ext = os.path.splitext(filename)[1]
+    if file_ext == ".txt":
+        with g_pathmgr.open(filename, "r") as fopen:
+            data = fopen.readlines()
+    elif file_ext in [".pkl", ".pickle"]:
+        with g_pathmgr.open(filename, "rb") as fopen:
+            data = pickle.load(fopen, encoding="latin1")
+    elif file_ext == ".npy":
+        if mmap_mode:
+            try:
+                with g_pathmgr.open(filename, "rb") as fopen:
+                    data = np.load(
+                        fopen,
+                        allow_pickle=allow_pickle,
+                        encoding="latin1",
+                        mmap_mode=mmap_mode,
+                    )
+            except ValueError as e:
+                logging.info(
+                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
+                )
+                data = np.load(
+                    filename,
+                    allow_pickle=allow_pickle,
+                    encoding="latin1",
+                    mmap_mode=mmap_mode,
+                )
+                logging.info("Successfully loaded without g_pathmgr")
+            except Exception:
+                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
+                with g_pathmgr.open(filename, "rb") as fopen:
+                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
+        else:
+            with g_pathmgr.open(filename, "rb") as fopen:
+                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
+    elif file_ext == ".json":
+        with g_pathmgr.open(filename, "r") as fopen:
+            data = json.load(fopen)
+    elif file_ext == ".yaml":
+        with g_pathmgr.open(filename, "r") as fopen:
+            data = yaml.load(fopen, Loader=yaml.FullLoader)
+    elif file_ext == ".csv":
+        with g_pathmgr.open(filename, "r") as fopen:
+            data = pd.read_csv(fopen)
+    else:
+        raise Exception(f"Reading from {file_ext} is not supported yet")
+    return data
+
+
+def abspath(resource_path: str):
+    """
+    Make a path absolute, but take into account prefixes like
+    "http://" or "manifold://"
+    """
+    regex = re.compile(r"^\w+://")
+    if regex.match(resource_path) is None:
+        return os.path.abspath(resource_path)
+    else:
+        return resource_path
+
+
+def makedir(dir_path):
+    """
+    Create the directory if it does not exist.
+    """
+    is_success = False
+    try:
+        if not g_pathmgr.exists(dir_path):
+            g_pathmgr.mkdirs(dir_path)
+        is_success = True
+    except BaseException:
+        logging.info(f"Error creating directory: {dir_path}")
+    return is_success
+
+
+def is_url(input_url):
+    """
+    Check if an input string is a URL. Looks for http(s):// and ignores the case.
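+
+    >>> is_url("https://example.com/model.pth")   # illustrative URL
+    True
+    >>> is_url("/local/path/model.pth")
+    False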
+    """
+    is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
+    return is_url
+
+
+def cleanup_dir(dir):
+    """
+    Utility for deleting a directory. Useful for cleaning the storage space
+    that contains various training artifacts like checkpoints, data etc.
+    """
+    if os.path.exists(dir):
+        logging.info(f"Deleting directory: {dir}")
+        shutil.rmtree(dir)
+    logging.info(f"Deleted contents of directory: {dir}")
+
+
+def get_file_size(filename):
+    """
+    Given a file, get the size of file in MB
+    """
+    size_in_mb = os.path.getsize(filename) / float(1024**2)
+    return size_in_mb
diff --git a/unimernet/configs/datasets/formula/formula_eval.yaml b/unimernet/configs/datasets/formula/formula_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d6e7e1ed2acd650c5a2224c77cc64bf67d5d62e5
--- /dev/null
+++ b/unimernet/configs/datasets/formula/formula_eval.yaml
@@ -0,0 +1,6 @@
+datasets:
+  formula_rec_eval:
+    data_type: images
+    build_info:
+      images: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/val
+      annotation: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/pdfmath.txt
\ No newline at end of file
diff --git a/unimernet/configs/datasets/formula/formula_train.yaml b/unimernet/configs/datasets/formula/formula_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aa4af4cf3464521f2ac30686e087a01885c03741
--- /dev/null
+++ b/unimernet/configs/datasets/formula/formula_train.yaml
@@ -0,0 +1,6 @@
+datasets:
+  formula_rec_train:
+    data_type: images
+    build_info:
+      images: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/train
+      annotation: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/pdfmath.txt
\ No newline at end of file
diff --git a/unimernet/configs/datasets/formula/multi_scale_formula_train.yaml b/unimernet/configs/datasets/formula/multi_scale_formula_train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2c6dc5058ba06593d16f67882e2600acdcd9116b
--- /dev/null
+++ b/unimernet/configs/datasets/formula/multi_scale_formula_train.yaml
@@ -0,0 +1,21 @@
+datasets:
+  multi_scale_formula_rec_train:
+    data_type: images
+    build_info:
+      images: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/train
+      annotation: /mnt/petrelfs/share_data/hanxiao/latex-ocr/pdf/pdfmath.txt
+
+    vis_processor:
+      train:
+        name: "formula_image_multi_scale_train"
+        all_scales:
+          - [ 96, 336 ]
+          - [ 128, 448 ]
+          - [ 192, 672 ]
+          - [ 288, 1008 ]
+          - [ 384, 1344 ]
+
+    text_processor:
+      train:
+        name: "blip_caption"
+        max_words: 256
\ No newline at end of file
diff --git a/unimernet/configs/default.yaml b/unimernet/configs/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c4800a0e0d4a0444db40558ba950e08a33d80d31
--- /dev/null
+++ b/unimernet/configs/default.yaml
@@ -0,0 +1,10 @@
+ # Copyright (c) 2022, salesforce.com, inc.
+ # All rights reserved.
+ # SPDX-License-Identifier: BSD-3-Clause
+ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+
+env:
+  # For default users
+  # cache_root: "cache"
+  # For internal use with persistent storage
+  cache_root: "/export/home/.cache/vigc"
diff --git a/unimernet/configs/models/unimernet_base.yaml b/unimernet/configs/models/unimernet_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..24e31350f30b8e0b8d4fc2331cc1f0319c33f349
--- /dev/null
+++ b/unimernet/configs/models/unimernet_base.yaml
@@ -0,0 +1,31 @@
+model:
+  arch: unimernet
+  load_finetuned: False
+  load_pretrained: False
+  pretrained: "path/to/pretrained/weight"
+  finetuned: ""
+  tokenizer_name: nougat
+  tokenizer_config:
+    path: ./models/unimernet
+  model_name: unimernet
+  model_config:
+    max_seq_len: 384
+
+
+preprocess:
+  vis_processor:
+    train:
+      name: "formula_image_train"
+      image_size:
+        - 192
+        - 672
+    eval:
+      name: "formula_image_eval"
+      image_size:
+        - 192
+        - 672
+  text_processor:
+    train:
+      name: "blip_caption"
+    eval:
+      name: "blip_caption"
diff --git a/unimernet/datasets/__init__.py b/unimernet/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unimernet/datasets/__pycache__/__init__.cpython-310.pyc b/unimernet/datasets/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f42879cec833c2ad1436195309ddbe4cd93beeaf
Binary files /dev/null and b/unimernet/datasets/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/datasets/__pycache__/data_utils.cpython-310.pyc b/unimernet/datasets/__pycache__/data_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..da626f1b4f156471999aca4475572beddd9b19ac
Binary files /dev/null and b/unimernet/datasets/__pycache__/data_utils.cpython-310.pyc differ
diff --git a/unimernet/datasets/builders/__init__.py b/unimernet/datasets/builders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc9ad64e5aa54e4396ee752b2f1a1aa980f254f8
--- /dev/null
+++ b/unimernet/datasets/builders/__init__.py
@@ -0,0 +1,69 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from unimernet.datasets.builders.base_dataset_builder import load_dataset_config
+from unimernet.common.registry import registry
+from unimernet.datasets.builders.formula import FormulaRecTrainBuilder, FormulaRecEvalBuilder, \
+    MultiScaleFormulaRecTrainBuilder
+
+__all__ = [
+    "FormulaRecTrainBuilder",
+    "FormulaRecEvalBuilder",
+    "MultiScaleFormulaRecTrainBuilder",
+]
+
+
+def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
+    """
+    Example
+
+    >>> dataset = load_dataset("formula_rec_train")
+    >>> splits = dataset.keys()
+    >>> print([len(dataset[split]) for split in splits])
+
+    """
+    if cfg_path is None:
+        cfg = None
+    else:
+        cfg = load_dataset_config(cfg_path)
+
+    try:
+        builder = registry.get_builder_class(name)(cfg)
+    except TypeError:
+        print(
+            f"Dataset {name} not found. Available datasets:\n"
+            + ", ".join([str(k) for k in dataset_zoo.get_names()])
+        )
+        exit(1)
+
+    if vis_path is not None:
+        if data_type is None:
+            # use default data type in the config
+            data_type = builder.config.data_type
+
+        assert (
+                data_type in builder.config.build_info
+        ), f"Invalid data_type {data_type} for {name}."
+
+        builder.config.build_info.get(data_type).storage = vis_path
+
+    dataset = builder.build_datasets()
+    return dataset
+
+
+class DatasetZoo:
+    def __init__(self) -> None:
+        self.dataset_zoo = {
+            k: list(v.DATASET_CONFIG_DICT.keys())
+            for k, v in sorted(registry.mapping["builder_name_mapping"].items())
+        }
+
+    def get_names(self):
+        return list(self.dataset_zoo.keys())
+
+
+dataset_zoo = DatasetZoo()
diff --git a/unimernet/datasets/builders/__pycache__/__init__.cpython-310.pyc b/unimernet/datasets/builders/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0d38a0ca49e1d07a0295d7a6eeb9a4f5b43faa29
Binary files /dev/null and b/unimernet/datasets/builders/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/datasets/builders/__pycache__/base_dataset_builder.cpython-310.pyc b/unimernet/datasets/builders/__pycache__/base_dataset_builder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d02302e0d85d678b9f40452cf49f01dad6bbf614
Binary files /dev/null and b/unimernet/datasets/builders/__pycache__/base_dataset_builder.cpython-310.pyc differ
diff --git a/unimernet/datasets/builders/__pycache__/formula.cpython-310.pyc b/unimernet/datasets/builders/__pycache__/formula.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef951c05619e26e0164a7d7d2a58b5b0dbdfd3c3
Binary files /dev/null and b/unimernet/datasets/builders/__pycache__/formula.cpython-310.pyc differ
diff --git a/unimernet/datasets/builders/base_dataset_builder.py b/unimernet/datasets/builders/base_dataset_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fdc086214ee9c9b92143f6233a4dc3c8a3ec06d
--- /dev/null
+++ b/unimernet/datasets/builders/base_dataset_builder.py
@@ -0,0 +1,233 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+import shutil
+import warnings
+
+import unimernet.common.utils as utils
+import torch.distributed as dist
+from unimernet.common.dist_utils import is_dist_avail_and_initialized, is_main_process
+from unimernet.common.registry import registry
+from unimernet.processors.base_processor import BaseProcessor
+from omegaconf import OmegaConf
+from torchvision.datasets.utils import download_url
+
+
+class BaseDatasetBuilder:
+    train_dataset_cls, eval_dataset_cls = None, None
+
+    def __init__(self, cfg=None):
+        super().__init__()
+
+        if cfg is None:
+            # help to create datasets from default config.
+            self.config = load_dataset_config(self.default_config_path())
+        elif isinstance(cfg, str):
+            self.config = load_dataset_config(cfg)
+        else:
+            # when called from task.build_dataset()
+            self.config = cfg
+
+        self.data_type = self.config.data_type
+
+        self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
+
+    def build_datasets(self):
+        # download, split, etc...
+        # only called on 1 GPU/TPU in distributed
+
+        if is_main_process():
+            self._download_data()
+
+        if is_dist_avail_and_initialized():
+            dist.barrier()
+
+        # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
+        logging.info("Building datasets...")
+        datasets = self.build()  # dataset['train'/'val'/'test']
+
+        return datasets
+
+    def build_processors(self):
+        vis_proc_cfg = self.config.get("vis_processor")
+        txt_proc_cfg = self.config.get("text_processor")
+
+        if vis_proc_cfg is not None:
+            vis_train_cfg = vis_proc_cfg.get("train")
+            vis_eval_cfg = vis_proc_cfg.get("eval")
+
+            self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
+            self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
+
+        if txt_proc_cfg is not None:
+            txt_train_cfg = txt_proc_cfg.get("train")
+            txt_eval_cfg = txt_proc_cfg.get("eval")
+
+            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
+            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
+
+    @staticmethod
+    def _build_proc_from_cfg(cfg):
+        return (
+            registry.get_processor_class(cfg.name).from_config(cfg)
+            if cfg is not None
+            else None
+        )
+
+    @classmethod
+    def default_config_path(cls, type="default"):
+        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
+
+    def _download_data(self):
+        self._download_ann()
+        self._download_vis()
+
+    def _download_ann(self):
+        """
+        Download annotation files if necessary.
+        All the vision-language datasets should have annotations in a unified format.
+
+        storage_path can be:
+          (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
+          (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
+
+        Local annotation paths should be relative.
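+
+        A minimal, illustrative build_info.annotations entry (the URL and paths below
+        are placeholders, not values shipped with this repository):
+
+            annotations:
+                train:
+                    url: https://example.com/annotations/train.json
+                    storage: my_dataset/annotations/train.json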
+        """
+        anns = self.config.build_info.annotations
+
+        splits = anns.keys()
+
+        cache_root = registry.get_path("cache_root")
+
+        for split in splits:
+            info = anns[split]
+
+            urls, storage_paths = info.get("url", None), info.storage
+
+            if isinstance(urls, str):
+                urls = [urls]
+            if isinstance(storage_paths, str):
+                storage_paths = [storage_paths]
+
+            assert len(urls) == len(storage_paths)
+
+            for url_or_filename, storage_path in zip(urls, storage_paths):
+                # if storage_path is relative, make it full by prefixing with cache_root.
+                if not os.path.isabs(storage_path):
+                    storage_path = os.path.join(cache_root, storage_path)
+
+                dirname = os.path.dirname(storage_path)
+                if not os.path.exists(dirname):
+                    os.makedirs(dirname)
+
+                if os.path.isfile(url_or_filename):
+                    src, dst = url_or_filename, storage_path
+                    if not os.path.exists(dst):
+                        shutil.copyfile(src=src, dst=dst)
+                    else:
+                        logging.info("Using existing file {}.".format(dst))
+                else:
+                    if os.path.isdir(storage_path):
+                        # if only dirname is provided, suffix with basename of URL.
+                        raise ValueError(
+                            "Expecting storage_path to be a file path, got directory {}".format(
+                                storage_path
+                            )
+                        )
+                    else:
+                        filename = os.path.basename(storage_path)
+
+                    download_url(url=url_or_filename, root=dirname, filename=filename)
+
+    def _download_vis(self):
+
+        storage_path = self.config.build_info.get(self.data_type).storage
+        storage_path = utils.get_cache_path(storage_path)
+
+        if not os.path.exists(storage_path):
+            warnings.warn(
+                f"""
+                The specified path {storage_path} for visual inputs does not exist.
+                Please provide a correct path to the visual inputs or
+                refer to datasets/download_scripts/README.md for downloading instructions.
+                """
+            )
+
+    def build(self):
+        """
+        Create datasets by split, each inheriting from torch.utils.data.Dataset.
+
+        # build() can be dataset-specific. Overwrite to customize.
+        """
+        self.build_processors()
+
+        build_info = self.config.build_info
+
+        ann_info = build_info.annotations
+        vis_info = build_info.get(self.data_type)
+
+        datasets = dict()
+        for split in ann_info.keys():
+            if split not in ["train", "val", "test"]:
+                continue
+
+            is_train = split == "train"
+
+            # processors
+            vis_processor = (
+                self.vis_processors["train"]
+                if is_train
+                else self.vis_processors["eval"]
+            )
+            text_processor = (
+                self.text_processors["train"]
+                if is_train
+                else self.text_processors["eval"]
+            )
+
+            # annotation path
+            ann_paths = ann_info.get(split).storage
+            if isinstance(ann_paths, str):
+                ann_paths = [ann_paths]
+
+            abs_ann_paths = []
+            for ann_path in ann_paths:
+                if not os.path.isabs(ann_path):
+                    ann_path = utils.get_cache_path(ann_path)
+                abs_ann_paths.append(ann_path)
+            ann_paths = abs_ann_paths
+
+            # visual data storage path
+            vis_path = vis_info.storage
+
+            if not os.path.isabs(vis_path):
+                # vis_path = os.path.join(utils.get_cache_path(), vis_path)
+                vis_path = utils.get_cache_path(vis_path)
+
+            if not os.path.exists(vis_path):
+                warnings.warn("storage path {} does not exist.".format(vis_path))
+
+            # create datasets
+            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
+            datasets[split] = dataset_cls(
+                vis_processor=vis_processor,
+                text_processor=text_processor,
+                ann_paths=ann_paths,
+                vis_root=vis_path,
+            )
+
+        return datasets
+
+
+def load_dataset_config(cfg_path):
+    cfg = OmegaConf.load(cfg_path).datasets
+    cfg = cfg[list(cfg.keys())[0]]
+
+    return cfg
diff --git a/unimernet/datasets/builders/formula.py b/unimernet/datasets/builders/formula.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4e3af7568eaee4ef3ec856a7899ae6243cb5499
--- /dev/null
+++ b/unimernet/datasets/builders/formula.py
@@ -0,0 +1,105 @@
+import logging
+from unimernet.common.registry import registry
+from unimernet.datasets.builders.base_dataset_builder import BaseDatasetBuilder
+from unimernet.datasets.datasets.formula import Im2LatexDataset
+from unimernet.datasets.datasets.formula_multi_scale import MultiScaleIm2LatexDataset
+
+
+@registry.register_builder("formula_rec_train")
+class FormulaRecTrainBuilder(BaseDatasetBuilder):
+    train_dataset_cls = Im2LatexDataset
+    DATASET_CONFIG_DICT = {
+        "default": "configs/datasets/formula/formula_train.yaml"
+    }
+    LOG_INFO = "Formula Recgnition Train"
+
+    def build_datasets(self):
+        logging.info(f"Building {self.LOG_INFO} datasets ...")
+        self.build_processors()
+
+        build_info = self.config.build_info
+        anno_path = build_info.annotation
+        vis_root = build_info.images
+        anno_path = [anno_path] if isinstance(anno_path, str) else anno_path
+        vis_root = [vis_root] if isinstance(vis_root, str) else vis_root
+        datasets = dict()
+
+        # create datasets
+        dataset_cls = self.train_dataset_cls
+        datasets['train'] = dataset_cls(
+            vis_processor=self.vis_processors["train"],
+            text_processor=self.text_processors["train"],
+            vis_root=vis_root,
+            anno_path=anno_path,
+        )
+        print(datasets['train'][0])
+
+        return datasets
+
+
+@registry.register_builder("multi_scale_formula_rec_train")
+class MultiScaleFormulaRecTrainBuilder(BaseDatasetBuilder):
+    train_dataset_cls = MultiScaleIm2LatexDataset
+    DATASET_CONFIG_DICT = {
+        "default": "configs/datasets/formula/multi_scale_formula_train.yaml"
+    }
+    LOG_INFO = "Multi Scale Formula Recgnition Train"
+
+    def build_datasets(self):
+        logging.info(f"Building {self.LOG_INFO} datasets ...")
+        self.build_processors()
+
+        build_info = self.config.build_info
+        anno_path = build_info.annotation
+        vis_root = build_info.images
+
+        anno_path = [anno_path] if isinstance(anno_path, str) else anno_path
+        vis_root = [vis_root] if isinstance(vis_root, str) else vis_root
+
+        datasets = dict()
+
+        # create datasets
+        dataset_cls = self.train_dataset_cls
+        datasets['train'] = dataset_cls(
+            vis_processor=self.vis_processors["train"],
+            text_processor=self.text_processors["train"],
+            vis_root=vis_root,
+            anno_path=anno_path,
+        )
+        print(datasets['train'][0])
+
+        return datasets
+
+
+@registry.register_builder("formula_rec_eval")
+class FormulaRecEvalBuilder(BaseDatasetBuilder):
+    eval_dataset_cls = Im2LatexDataset
+    DATASET_CONFIG_DICT = {
+        "default": "configs/datasets/formula/formula_eval.yaml"
+    }
+    LOG_INFO = "Formula Recgnition Eval"
+
+    def build_datasets(self):
+        logging.info(f"Building {self.LOG_INFO} datasets ...")
+        self.build_processors()
+
+        build_info = self.config.build_info
+        anno_path = build_info.annotation
+        vis_root = build_info.images
+
+        anno_path = [anno_path] if isinstance(anno_path, str) else anno_path
+        vis_root = [vis_root] if isinstance(vis_root, str) else vis_root
+
+        datasets = dict()
+
+        # create datasets
+        dataset_cls = self.eval_dataset_cls
+        datasets['eval'] = dataset_cls(
+            vis_processor=self.vis_processors["eval"],
+            text_processor=self.text_processors["eval"],
+            vis_root=vis_root,
+            anno_path=anno_path,
+        )
+        print(datasets['eval'][0])
+
+        return datasets
diff --git a/unimernet/datasets/data_utils.py b/unimernet/datasets/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e82c90ebf4fa8a094b0434c114c566e2c8b7d61
--- /dev/null
+++ b/unimernet/datasets/data_utils.py
@@ -0,0 +1,284 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import gzip
+import logging
+import os
+import random as rnd
+import tarfile
+import zipfile
+
+import decord
+import webdataset as wds
+import numpy as np
+import torch
+from torch.utils.data.dataset import IterableDataset, ChainDataset
+from decord import VideoReader
+from unimernet.common.registry import registry
+from unimernet.datasets.datasets.base_dataset import ConcatDataset
+from tqdm import tqdm
+
+decord.bridge.set_bridge("torch")
+MAX_INT = registry.get("MAX_INT")
+
+
+def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform"):
+    vr = VideoReader(uri=video_path, height=height, width=width)
+
+    vlen = len(vr)
+    start, end = 0, vlen
+
+    n_frms = min(n_frms, vlen)
+
+    if sampling == "uniform":
+        indices = np.arange(start, end, vlen / n_frms).astype(int)
+    elif sampling == "headtail":
+        indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2))
+        indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2))
+        indices = indices_h + indices_t
+    else:
+        raise NotImplementedError
+
+    # get_batch -> T, H, W, C
+    frms = vr.get_batch(indices).permute(3, 0, 1, 2).float()  # (C, T, H, W)
+
+    return frms
+
+
+def apply_to_sample(f, sample):
+    if len(sample) == 0:
+        return {}
+
+    def _apply(x):
+        if torch.is_tensor(x):
+            return f(x)
+        elif isinstance(x, dict):
+            return {key: _apply(value) for key, value in x.items()}
+        elif isinstance(x, list):
+            return [_apply(x) for x in x]
+        else:
+            return x
+
+    return _apply(sample)
+
+
+def move_to_cuda(sample):
+    def _move_to_cuda(tensor):
+        return tensor.cuda()
+
+    return apply_to_sample(_move_to_cuda, sample)
+
+
+def prepare_sample(samples, cuda_enabled=True):
+    if cuda_enabled:
+        samples = move_to_cuda(samples)
+
+    # TODO fp16 support
+
+    return samples
+
+
+def reorg_datasets_by_split(datasets):
+    """
+    Organizes datasets by split.
+
+    Args:
+        datasets: dict of torch.utils.data.Dataset objects by name.
+
+    Returns:
+        Dict of datasets by split {split_name: List[Datasets]}.
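+
+        For example (names are illustrative), {"cap": {"train": ds1, "val": ds2}, "vqa": {"train": ds3}}
+        is reorganized into {"train": [ds1, ds3], "val": [ds2]}.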
+    """
+    # if len(datasets) == 1:
+    #     return datasets[list(datasets.keys())[0]]
+    # else:
+    reorg_datasets = dict()
+
+    # reorganize by split
+    for _, dataset in datasets.items():
+        for split_name, dataset_split in dataset.items():
+            if split_name not in reorg_datasets:
+                reorg_datasets[split_name] = [dataset_split]
+            else:
+                reorg_datasets[split_name].append(dataset_split)
+
+    return reorg_datasets
+
+
+def concat_datasets(datasets):
+    """
+    Concatenates multiple datasets into a single dataset.
+
+    It supports map-style datasets and DataPipeline from WebDataset. Currently, it does not support
+    generic IterableDataset because that requires creating separate samplers.
+
+    For now, only training datasets can be concatenated; validation and testing are assumed to
+    have only a single dataset each. This is because metrics should not be computed on the concatenated
+    datasets.
+
+    Args:
+        datasets: dict of torch.utils.data.Dataset objects by split.
+
+    Returns:
+        Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
+        "val" and "test" remain the same.
+
+        If the input training datasets contain both map-style and DataPipeline datasets, returns
+        a tuple, where the first element is a concatenated map-style dataset and the second
+        element is a chained DataPipeline dataset.
+
+    """
+    # concatenate datasets in the same split
+    for split_name in datasets:
+        if split_name != "train":
+            assert (
+                len(datasets[split_name]) == 1
+            ), "Do not support multiple {} datasets.".format(split_name)
+            datasets[split_name] = datasets[split_name][0]
+        else:
+            iterable_datasets, map_datasets = [], []
+            for dataset in datasets[split_name]:
+                if isinstance(dataset, wds.DataPipeline):
+                    logging.info(
+                        "Dataset {} is IterableDataset, can't be concatenated.".format(
+                            dataset
+                        )
+                    )
+                    iterable_datasets.append(dataset)
+                elif isinstance(dataset, IterableDataset):
+                    raise NotImplementedError(
+                        "Do not support concatenation of generic IterableDataset."
+                    )
+                else:
+                    map_datasets.append(dataset)
+
+            # if len(iterable_datasets) > 0:
+            # concatenate map-style datasets and iterable-style datasets separately
+            chained_datasets = (
+                ChainDataset(iterable_datasets) if len(iterable_datasets) > 0 else None
+            )
+            concat_datasets = (
+                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
+            )
+
+            train_datasets = concat_datasets, chained_datasets
+            train_datasets = tuple([x for x in train_datasets if x is not None])
+            train_datasets = (
+                train_datasets[0] if len(train_datasets) == 1 else train_datasets
+            )
+
+            datasets[split_name] = train_datasets
+
+    return datasets
+
+
+def extract_archive(from_path, to_path=None, overwrite=False):
+    """Extract archive.
+
+    Args:
+        from_path: the path of the archive.
+        to_path: the root path of the extracted files (directory of from_path)
+        overwrite: overwrite existing files (False)
+
+    Returns:
+        List of paths to extracted files even if not overwritten.
+
+    Examples:
+        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
+        >>> from_path = './validation.tar.gz'
+        >>> to_path = './'
+        >>> extract_archive(from_path, to_path)
+        ['.data/val.de', '.data/val.en']
+
+    """
+
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    if from_path.endswith((".tar.gz", ".tgz")):
+        logging.info("Opening tar file {} to {}.".format(from_path, to_path))
+        with tarfile.open(from_path, "r") as tar:
+            files = []
+            for file_ in tqdm(tar):
+                file_path = os.path.join(to_path, file_.name)
+                if file_.isfile():
+                    files.append(file_path)
+                    if os.path.exists(file_path):
+                        logging.info("{} already extracted.".format(file_path))
+                        if not overwrite:
+                            continue
+                tar.extract(file_, to_path)
+            logging.info("Finished extracting tar file {}.".format(from_path))
+            return files
+
+    elif from_path.endswith(".zip"):
+        assert zipfile.is_zipfile(from_path), from_path
+        logging.info("Opening zip file {} to {}.".format(from_path, to_path))
+        with zipfile.ZipFile(from_path, "r") as zfile:
+            files = []
+            for file_ in tqdm(zfile.namelist()):
+                file_path = os.path.join(to_path, file_)
+                files.append(file_path)
+                if os.path.exists(file_path):
+                    logging.info("{} already extracted.".format(file_path))
+                    if not overwrite:
+                        continue
+                zfile.extract(file_, to_path)
+        files = [f for f in files if os.path.isfile(f)]
+        logging.info("Finished extracting zip file {}.".format(from_path))
+        return files
+
+    elif from_path.endswith(".gz"):
+        logging.info("Opening gz file {} to {}.".format(from_path, to_path))
+        default_block_size = 65536
+        filename = from_path[:-3]
+        files = [filename]
+        with gzip.open(from_path, "rb") as gzfile, open(filename, "wb") as d_file:
+            while True:
+                block = gzfile.read(default_block_size)
+                if not block:
+                    break
+                else:
+                    d_file.write(block)
+        logging.info("Finished extracting gz file {}.".format(from_path))
+        return files
+
+    else:
+        raise NotImplementedError(
+            "We currently only support tar.gz, .tgz, .gz and zip achives."
+        )
+
+
+def save_frames_grid(img_array, out_path):
+    import torch
+    from PIL import Image
+    from torchvision.utils import make_grid
+
+    if len(img_array.shape) == 3:
+        img_array = img_array.unsqueeze(0)
+    elif len(img_array.shape) == 5:
+        b, t, c, h, w = img_array.shape
+        img_array = img_array.view(-1, c, h, w)
+    elif len(img_array.shape) == 4:
+        pass
+    else:
+        raise NotImplementedError(
+            "Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored."
+        )
+
+    assert img_array.shape[1] == 3, "Expecting channels-first input of shape (N, 3, H, W), i.e. RGB-only."
+
+    grid = make_grid(img_array)
+    ndarr = grid.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+
+    img = Image.fromarray(ndarr)
+
+    img.save(out_path)
diff --git a/unimernet/datasets/datasets/__pycache__/base_dataset.cpython-310.pyc b/unimernet/datasets/datasets/__pycache__/base_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b66d52b14d2c07e62d7f928930209753687b9e1
Binary files /dev/null and b/unimernet/datasets/datasets/__pycache__/base_dataset.cpython-310.pyc differ
diff --git a/unimernet/datasets/datasets/__pycache__/formula.cpython-310.pyc b/unimernet/datasets/datasets/__pycache__/formula.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8c75254a5a71a8f172cb61aea9c6faa75439dcd
Binary files /dev/null and b/unimernet/datasets/datasets/__pycache__/formula.cpython-310.pyc differ
diff --git a/unimernet/datasets/datasets/__pycache__/formula_multi_scale.cpython-310.pyc b/unimernet/datasets/datasets/__pycache__/formula_multi_scale.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0b9ae43b5ea5491a780bff67b6f4eb8fd0946d6
Binary files /dev/null and b/unimernet/datasets/datasets/__pycache__/formula_multi_scale.cpython-310.pyc differ
diff --git a/unimernet/datasets/datasets/base_dataset.py b/unimernet/datasets/datasets/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..591ab94db2a688dc3c74d06c72a1102333411a8f
--- /dev/null
+++ b/unimernet/datasets/datasets/base_dataset.py
@@ -0,0 +1,103 @@
+import json
+from PIL import Image, ImageFile
+import os.path as osp
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+from io import BytesIO
+from typing import Iterable
+from torch.utils.data import Dataset, ConcatDataset
+import torch
+
+
+class BaseDataset(Dataset):
+
+    def __init__(self, vis_processor, text_processor, vis_root, anno_path):
+
+        self.vis_root = vis_root
+        # if isinstance(anno_path, tuple) or isinstance(anno_path, list):
+        #     anno_path = anno_path[0]
+        self.anno_path = anno_path
+
+        self.vis_processor = vis_processor
+        self.text_processor = text_processor
+
+        self.samples = self.init_samples()
+        self.reader = self.init_reader()
+
+        print('total {} {} samples'.format(self.__len__(), self.__class__.__name__))
+
+        # Eagerly fetch a few samples as a sanity check, so malformed data fails fast at init.
+        for idx in range(10):
+            self.__getitem__(idx)
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def init_samples(self):
+        # read annotation from ceph
+        if self.anno_path.startswith('cluster'):
+            from petrel_client.client import Client
+            client = Client("~/petreloss.conf")
+            samples = json.loads(client.get(self.anno_path))
+        else:
+            samples = json.load(open(self.anno_path, 'r'))
+        return samples
+
+    def init_reader(self):
+        if self.vis_root.startswith('cluster'):
+            from petrel_client.client import Client
+            client = Client("~/petreloss.conf")
+            reader = {'type': 'PetrelReader', 'body': client.get}
+        else:
+            reader = {'type': 'LocalReader', 'body': Image.open}
+        return reader
+
+    def _read_image(self, sample, image_key="image"):
+        img_file = sample[image_key]
+        image_path = osp.join(self.vis_root, img_file)
+        image = self.reader['body'](image_path)
+        if isinstance(image, bytes):
+            bytes_stream = BytesIO(image)
+            image = Image.open(bytes_stream)
+        image = image.convert("RGB")
+        return image
+
+    def collater(self, samples):
+        image_list, question_list, answer_list = [], [], []
+
+        for sample in samples:
+            image_list.append(sample["image"])
+            question_list.append(sample["text_input"])
+            answer_list.append(sample["text_output"])
+
+        return {
+            "image": torch.stack(image_list, dim=0),
+            "text_input": question_list,
+            "text_output": answer_list,
+            "data_type": "vqa",
+        }
+
+
+class ConcatDataset(ConcatDataset):
+    def __init__(self, datasets: Iterable[Dataset]) -> None:
+        super().__init__(datasets)
+
+    def collater(self, samples):
+        # TODO For now only supports datasets with same underlying collater implementations
+
+        all_keys = set()
+        for s in samples:
+            all_keys.update(s)
+
+        shared_keys = all_keys
+        for s in samples:
+            shared_keys = shared_keys & set(s.keys())
+
+        samples_shared_keys = []
+        for s in samples:
+            samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
+
+        return self.datasets[0].collater(samples_shared_keys)
diff --git a/unimernet/datasets/datasets/dataloader_utils.py b/unimernet/datasets/datasets/dataloader_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab63bf9b6dce16447f816763210cf87ca3940097
--- /dev/null
+++ b/unimernet/datasets/datasets/dataloader_utils.py
@@ -0,0 +1,200 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import time
+import random
+import torch
+from unimernet.datasets.data_utils import move_to_cuda
+from torch.utils.data import DataLoader
+
+
+class MultiIterLoader:
+    """
+    A simple wrapper for iterating over multiple iterators.
+
+    Args:
+        loaders (List[Loader]): List of Iterator loaders.
+        ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
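+
+    A minimal sketch of the intended usage (the iterators below are placeholders):
+
+    >>> loader = MultiIterLoader([iter(range(10)), iter(range(100, 110))], ratios=[1, 3])
+    >>> sample = next(loader)  # drawn from the second iterator with probability ~0.75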
+    """
+
+    def __init__(self, loaders, ratios=None):
+        # assert all loaders has __next__ method
+        for loader in loaders:
+            assert hasattr(
+                loader, "__next__"
+            ), "Loader {} has no __next__ method.".format(loader)
+
+        if ratios is None:
+            ratios = [1.0] * len(loaders)
+        else:
+            assert len(ratios) == len(loaders)
+            ratios = [float(ratio) / sum(ratios) for ratio in ratios]
+
+        self.loaders = loaders
+        self.ratios = ratios
+
+    def __next__(self):
+        # random sample from each loader by ratio
+        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
+        return next(self.loaders[loader_idx])
+
+    def __len__(self):
+        return sum([len(_) for _ in self.loaders if hasattr(_, "__len__")])
+
+
+class ConcatLoader:
+    """
+    A simple wrapper for iterating over multiple iterators in an interleaved fashion.
+    At each step, a loader is sampled with probability proportional to its remaining
+    length, so every loader is exhausted exactly once per epoch.
+
+    Args:
+        loaders (List[Loader]): List of Iterator loaders; each must implement __len__.
+    """
+
+    def __init__(self, loaders):
+        # assert all loaders has __next__ method
+        for loader in loaders:
+            assert hasattr(
+                loader, "__len__"
+            ), "Loader {} has no __len__ method.".format(loader)
+
+        self._epoch = 0
+        self._loader_lens = [len(_) for _ in loaders]
+        self._rest_lens = self._loader_lens.copy()
+
+        self.loaders = loaders
+
+    def __next__(self):
+        # sample a loader with probability proportional to its remaining length
+        loader_idx = random.choices(range(len(self.loaders)), self._rest_lens, k=1)[0]
+        self._rest_lens[loader_idx] -= 1
+        if sum(self._rest_lens) == 0:
+            self._epoch += 1
+            self._rest_lens = self._loader_lens.copy()
+        return next(self.loaders[loader_idx])
+
+    def __len__(self):
+        return sum([len(_) for _ in self.loaders if hasattr(_, "__len__")])
+
+
+class PrefetchLoader(object):
+    """
+    Modified from https://github.com/ChenRocks/UNITER.
+
+    overlap compute and cuda data transfer
+    (copied and then modified from nvidia apex)
+    """
+
+    def __init__(self, loader):
+        self.loader = loader
+        self.stream = torch.cuda.Stream()
+
+    def __iter__(self):
+        loader_it = iter(self.loader)
+        self.preload(loader_it)
+        batch = self.next(loader_it)
+        while batch is not None:
+            is_tuple = isinstance(batch, tuple)
+            if is_tuple:
+                task, batch = batch
+
+            if is_tuple:
+                yield task, batch
+            else:
+                yield batch
+            batch = self.next(loader_it)
+
+    def __len__(self):
+        return len(self.loader)
+
+    def preload(self, it):
+        try:
+            self.batch = next(it)
+        except StopIteration:
+            self.batch = None
+            return
+        # if record_stream() doesn't work, another option is to make sure
+        # device inputs are created on the main stream.
+        # self.next_input_gpu = torch.empty_like(self.next_input,
+        #                                        device='cuda')
+        # self.next_target_gpu = torch.empty_like(self.next_target,
+        #                                         device='cuda')
+        # Need to make sure the memory allocated for next_* is not still in use
+        # by the main stream at the time we start copying to next_*:
+        # self.stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(self.stream):
+            self.batch = move_to_cuda(self.batch)
+            # more code for the alternative if record_stream() doesn't work:
+            # copy_ will record the use of the pinned source tensor in this
+            # side stream.
+            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
+            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
+            # self.next_input = self.next_input_gpu
+            # self.next_target = self.next_target_gpu
+
+    def next(self, it):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        batch = self.batch
+        if batch is not None:
+            record_cuda_stream(batch)
+        self.preload(it)
+        return batch
+
+    def __getattr__(self, name):
+        method = self.loader.__getattribute__(name)
+        return method
+
+
+def record_cuda_stream(batch):
+    if isinstance(batch, torch.Tensor):
+        batch.record_stream(torch.cuda.current_stream())
+    elif isinstance(batch, list) or isinstance(batch, tuple):
+        for t in batch:
+            record_cuda_stream(t)
+    elif isinstance(batch, dict):
+        for t in batch.values():
+            record_cuda_stream(t)
+    else:
+        pass
+
+
+class IterLoader:
+    """
+    A wrapper that turns a DataLoader into an infinite iterator.
+
+    Modified from:
+        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
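+
+    A minimal sketch of the intended usage (dataset and num_iters are placeholders):
+
+    >>> train_loader = IterLoader(DataLoader(dataset, batch_size=8))
+    >>> for _ in range(num_iters):      # iterate by step count rather than by epoch
+    ...     batch = next(train_loader)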
+    """
+
+    def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
+        self._dataloader = dataloader
+        self.iter_loader = iter(self._dataloader)
+        self._use_distributed = use_distributed
+        self._epoch = 0
+
+    @property
+    def epoch(self) -> int:
+        return self._epoch
+
+    def __next__(self):
+        try:
+            data = next(self.iter_loader)
+        except StopIteration:
+            self._epoch += 1
+            if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
+                self._dataloader.sampler.set_epoch(self._epoch)
+            time.sleep(2)  # Prevent possible deadlock during epoch transition
+            self.iter_loader = iter(self._dataloader)
+            data = next(self.iter_loader)
+
+        return data
+
+    def __iter__(self):
+        return self
+
+    def __len__(self):
+        return len(self._dataloader)
diff --git a/unimernet/datasets/datasets/formula.py b/unimernet/datasets/datasets/formula.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a628b149099f621502eb014cd8747545497f57a
--- /dev/null
+++ b/unimernet/datasets/datasets/formula.py
@@ -0,0 +1,71 @@
+import torch
+from .base_dataset import BaseDataset
+import os.path as osp
+import glob
+from io import BytesIO
+from PIL import Image
+
+
+class Im2LatexDataset(BaseDataset):
+
+    def init_samples(self):
+        samples = []
+        for vis_root, anno_path in zip(self.vis_root, self.anno_path):
+            images = [path.replace('\\', '/') for path in glob.glob(osp.join(vis_root, '*.png'))]
+            indices = [int(osp.basename(img).split('.')[0]) for img in images]
+
+            eqs = open(anno_path, 'r').read().split('\n')
+            eqs = [eqs[_] for _ in indices]
+
+            for i, e in zip(images, eqs):
+                samples.append({"image": i, "equation": e, "vis_root": vis_root})
+        return samples
+
+    def __getitem__(self, index):
+        ann = self.samples[index]
+        try:
+            image = self.vis_processor(self._read_image(ann))
+        except Exception:
+            return self[(index + 1) % len(self)]
+        if image is None:
+            return self[(index + 1) % len(self)]
+        equation = ann["equation"]
+        return {"image": image, "text_input": equation, "id": index}
+
+    def _read_image(self, sample, image_key="image"):
+        img_file = sample[image_key]
+        vis_root = sample["vis_root"]
+        image_path = osp.join(vis_root, img_file)
+        image = self.reader['body'](image_path)
+        if isinstance(image, bytes):
+            bytes_stream = BytesIO(image)
+            image = Image.open(bytes_stream)
+        image = image.convert("RGB")
+        return image
+
+    def init_reader(self):
+        if not isinstance(self.vis_root, str):
+            vis_root = self.vis_root[0]
+        else:
+            vis_root = self.vis_root
+        if vis_root.startswith('cluster'):
+            from petrel_client.client import Client
+            client = Client("~/petreloss.conf")
+            reader = {'type': 'PetrelReader', 'body': client.get}
+        else:
+            reader = {'type': 'LocalReader', 'body': Image.open}
+        return reader
+
+    def collater(self, samples):
+        image_list, question_list, id_list = [], [], []
+
+        for sample in samples:
+            image_list.append(sample["image"])
+            question_list.append(sample["text_input"])
+            id_list.append(sample["id"])
+
+        return {
+            "image": torch.stack(image_list, dim=0),
+            "text_input": question_list,
+            "id": id_list
+        }
diff --git a/unimernet/datasets/datasets/formula_multi_scale.py b/unimernet/datasets/datasets/formula_multi_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..19ab04648063ef7c3ef5b2586d80d7fe35e010e8
--- /dev/null
+++ b/unimernet/datasets/datasets/formula_multi_scale.py
@@ -0,0 +1,32 @@
+import torch
+from .formula import Im2LatexDataset
+
+
+class MultiScaleIm2LatexDataset(Im2LatexDataset):
+
+    def __getitem__(self, index):
+        ann = self.samples[index]
+        try:
+            pil_image = self._read_image(ann)
+            image = self.vis_processor(pil_image)
+        except Exception:
+            return self[(index + 1) % len(self)]
+        if image is None:
+            return self[(index + 1) % len(self)]
+        equation = ann["equation"]
+        return {"image": image, "text_input": equation, "id": index, "raw_image": pil_image}
+
+    def collater(self, samples):
+        self.vis_processor.reset_scale()
+        image_list, question_list, id_list = [], [], []
+
+        for sample in samples:
+            image_list.append(self.vis_processor(sample["raw_image"]))
+            question_list.append(sample["text_input"])
+            id_list.append(sample["id"])
+
+        return {
+            "image": torch.stack(image_list, dim=0),
+            "text_input": question_list,
+            "id": id_list
+        }
diff --git a/unimernet/models/__init__.py b/unimernet/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b86bd4f5376f3b0597b1f008ac893985bc0a06fc
--- /dev/null
+++ b/unimernet/models/__init__.py
@@ -0,0 +1,198 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import torch
+from omegaconf import OmegaConf
+from unimernet.common.registry import registry
+
+from unimernet.models.base_model import BaseModel
+
+from unimernet.processors.base_processor import BaseProcessor
+from unimernet.models.unimernet.unimernet import UniMERModel
+
+__all__ = [
+    "load_model",
+    "BaseModel",
+    "UniMERModel",
+]
+
+
+def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
+    """
+    Load supported models.
+
+    To list all available models and types in registry:
+    >>> from unimernet.models import model_zoo
+    >>> print(model_zoo)
+
+    Args:
+        name (str): name of the model.
+        model_type (str): type of the model.
+        is_eval (bool): whether the model is in eval mode. Default: False.
+        device (str): device to use. Default: "cpu".
+        checkpoint (str): path or URL to a checkpoint. Default: None.
+            Note that the checkpoint is expected to have the same state_dict keys as the model.
+
+    Returns:
+        model (torch.nn.Module): model.
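+
+    Example (the model name and type below are assumptions; check print(model_zoo)
+    for the names actually registered in this repository):
+
+    >>> model = load_model("unimernet", "unimernet", is_eval=True, device="cuda")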
+    """
+
+    model = registry.get_model_class(name).from_pretrained(model_type=model_type)
+
+    if checkpoint is not None:
+        model.load_checkpoint(checkpoint)
+
+    if is_eval:
+        model.eval()
+
+    if device == "cpu":
+        model = model.float()
+
+    return model.to(device)
+
+
+def load_preprocess(config):
+    """
+    Load preprocessor configs and construct preprocessors.
+
+    If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
+
+    Args:
+        config (dict): preprocessor configs.
+
+    Returns:
+        vis_processors (dict): preprocessors for visual inputs.
+        txt_processors (dict): preprocessors for text inputs.
+
+        Key is "train" or "eval" for processors used in training and evaluation respectively.
+    """
+
+    def _build_proc_from_cfg(cfg):
+        return (
+            registry.get_processor_class(cfg.name).from_config(cfg)
+            if cfg is not None
+            else BaseProcessor()
+        )
+
+    vis_processors = dict()
+    txt_processors = dict()
+
+    vis_proc_cfg = config.get("vis_processor")
+    txt_proc_cfg = config.get("text_processor")
+
+    if vis_proc_cfg is not None:
+        vis_train_cfg = vis_proc_cfg.get("train")
+        vis_eval_cfg = vis_proc_cfg.get("eval")
+    else:
+        vis_train_cfg = None
+        vis_eval_cfg = None
+
+    vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
+    vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
+
+    if txt_proc_cfg is not None:
+        txt_train_cfg = txt_proc_cfg.get("train")
+        txt_eval_cfg = txt_proc_cfg.get("eval")
+    else:
+        txt_train_cfg = None
+        txt_eval_cfg = None
+
+    txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
+    txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
+
+    return vis_processors, txt_processors
+
+
+def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
+    """
+    Load model and its related preprocessors.
+
+    List all available models and types in registry:
+    >>> from unimernet.models import model_zoo
+    >>> print(model_zoo)
+
+    Args:
+        name (str): name of the model.
+        model_type (str): type of the model.
+        is_eval (bool): whether the model is in eval mode. Default: False.
+        device (str): device to use. Default: "cpu".
+
+    Returns:
+        model (torch.nn.Module): model.
+        vis_processors (dict): preprocessors for visual inputs.
+        txt_processors (dict): preprocessors for text inputs.
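+
+    Example (model name/type are assumptions; consult print(model_zoo) for the registered values):
+
+    >>> model, vis_processors, txt_processors = load_model_and_preprocess(
+    ...     "unimernet", "unimernet", is_eval=True, device="cuda"
+    ... )
+    >>> # vis_processors["eval"] turns a PIL image into the tensor the model expects.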
+    """
+    model_cls = registry.get_model_class(name)
+
+    # load model
+    model = model_cls.from_pretrained(model_type=model_type)
+
+    if is_eval:
+        model.eval()
+
+    # load preprocess
+    cfg = OmegaConf.load(model_cls.default_config_path(model_type))
+    if cfg is not None:
+        preprocess_cfg = cfg.preprocess
+
+        vis_processors, txt_processors = load_preprocess(preprocess_cfg)
+    else:
+        vis_processors, txt_processors = None, None
+        logging.info(
+            f"""No default preprocess for model {name} ({model_type}).
+                This can happen if the model is not finetuned on downstream datasets,
+                or it is not intended for direct use without finetuning.
+            """
+        )
+
+    if device == "cpu" or device == torch.device("cpu"):
+        model = model.float()
+
+    return model.to(device), vis_processors, txt_processors
+
+
+class ModelZoo:
+    """
+    A utility class to create string representation of available model architectures and types.
+
+    >>> from unimernet.models import model_zoo
+    >>> # list all available models
+    >>> print(model_zoo)
+    >>> # show total number of models
+    >>> print(len(model_zoo))
+    """
+
+    def __init__(self) -> None:
+        self.model_zoo = {
+            k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
+            for k, v in registry.mapping["model_name_mapping"].items()
+        }
+
+    def __str__(self) -> str:
+        return (
+            "=" * 50
+            + "\n"
+            + f"{'Architectures':<30} {'Types'}\n"
+            + "=" * 50
+            + "\n"
+            + "\n".join(
+                f"{name:<30} {', '.join(types)}"
+                for name, types in self.model_zoo.items()
+            )
+        )
+
+    def __iter__(self):
+        return iter(self.model_zoo.items())
+
+    def __len__(self):
+        return sum([len(v) for v in self.model_zoo.values()])
+
+
+model_zoo = ModelZoo()
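+
+# Illustrative usage (sketch): the global ``model_zoo`` can also be iterated, e.g.
+#   for arch, types in model_zoo:
+#       print(arch, types)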
diff --git a/unimernet/models/__pycache__/__init__.cpython-310.pyc b/unimernet/models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b70668ef663906775cebc29ec58d545e464b378
Binary files /dev/null and b/unimernet/models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/models/__pycache__/base_model.cpython-310.pyc b/unimernet/models/__pycache__/base_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83f3bab0ee55232ac666e7064ab83f8228d70984
Binary files /dev/null and b/unimernet/models/__pycache__/base_model.cpython-310.pyc differ
diff --git a/unimernet/models/__pycache__/clip_vit.cpython-310.pyc b/unimernet/models/__pycache__/clip_vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c0d8c99acd91ce10388b7224ac0bf1d30f10f7a
Binary files /dev/null and b/unimernet/models/__pycache__/clip_vit.cpython-310.pyc differ
diff --git a/unimernet/models/__pycache__/eva_vit.cpython-310.pyc b/unimernet/models/__pycache__/eva_vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..021e9e282399735f4d525e5319132debb3433568
Binary files /dev/null and b/unimernet/models/__pycache__/eva_vit.cpython-310.pyc differ
diff --git a/unimernet/models/base_model.py b/unimernet/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..104fae583941d9192002cf7d2196fcf92d0f28e9
--- /dev/null
+++ b/unimernet/models/base_model.py
@@ -0,0 +1,251 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+from unimernet.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
+from unimernet.common.utils import get_abs_path, is_url
+from omegaconf import OmegaConf
+
+
+class BaseModel(nn.Module):
+    """Base class for models."""
+
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def device(self):
+        return list(self.parameters())[0].device
+
+    def load_checkpoint(self, url_or_filename):
+        """
+        Load from a finetuned checkpoint.
+
+        This should expect no mismatch in the model keys and the checkpoint keys.
+        """
+
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location="cpu")
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location="cpu")
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        if "model" in checkpoint.keys():
+            state_dict = checkpoint["model"]
+        else:
+            state_dict = checkpoint
+
+        msg = self.load_state_dict(state_dict, strict=False)
+
+        if msg.missing_keys:
+            logging.info(f"Missing keys {msg.missing_keys} when loading '{url_or_filename}'.")
+        logging.info("load checkpoint from %s" % url_or_filename)
+
+        return msg
+
+    @classmethod
+    def from_pretrained(cls, model_type):
+        """
+        Build a pretrained model from default configuration file, specified by model_type.
+
+        Args:
+            - model_type (str): model type, specifying architecture and checkpoints.
+
+        Returns:
+            - model (nn.Module): pretrained or finetuned model, depending on the configuration.
+        """
+        model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
+        model = cls.from_config(model_cfg)
+
+        return model
+
+    @classmethod
+    def default_config_path(cls, model_type):
+        assert (
+                model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
+        ), "Unknown model type {}".format(model_type)
+        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
+
+    def load_checkpoint_from_config(self, cfg, **kwargs):
+        """
+        Load checkpoint as specified in the config file.
+
+        If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
+        When loading the pretrained model, each task-specific architecture may define their
+        own load_from_pretrained() method.
+        """
+        load_pretrained = cfg.get("load_pretrained", True)
+        load_finetuned = cfg.get("load_finetuned", False)
+
+        if load_pretrained:
+            # load pre-trained weights
+            pretrain_path = cfg.get("pretrained", None)
+            assert pretrain_path, "Found load_pretrained is True, but pretrained path is None."
+            self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
+            logging.info(f"Loaded pretrained model '{pretrain_path}'.")
+
+        if load_finetuned:
+            finetune_path = cfg.get("finetuned", None)
+            assert finetune_path is not None, "Found load_finetuned is True, but finetune_path is None."
+            self.load_checkpoint(url_or_filename=finetune_path)
+            logging.info(f"Loaded finetuned model '{finetune_path}'.")
+
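+    # Illustrative config snippet read by load_checkpoint_from_config (a sketch; the URL
+    # is a placeholder, only the keys accessed above are meaningful):
+    #   model:
+    #     load_pretrained: True
+    #     pretrained: "https://example.com/checkpoint.pth"
+    #     load_finetuned: False
+    #     finetuned: ""
+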
+    def before_evaluation(self, **kwargs):
+        pass
+
+    def show_n_params(self, return_str=True):
+        tot = 0
+        for p in self.parameters():
+            w = 1
+            for x in p.shape:
+                w *= x
+            tot += w
+        if return_str:
+            if tot >= 1e6:
+                return "{:.1f}M".format(tot / 1e6)
+            else:
+                return "{:.1f}K".format(tot / 1e3)
+        else:
+            return tot
+
+
+class BaseEncoder(nn.Module):
+    """
+    Base class for primitive encoders, such as ViT, TimeSformer, etc.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward_features(self, samples, **kwargs):
+        raise NotImplementedError
+
+    @property
+    def device(self):
+        return list(self.parameters())[0].device
+
+
+class SharedQueueMixin:
+    @torch.no_grad()
+    def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
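+        # Assumes the host model registers buffers ``image_queue``/``text_queue`` of shape
+        # [feat_dim, queue_size], an optional ``idx_queue``, a one-element ``queue_ptr``
+        # buffer, and a ``queue_size`` divisible by the (gathered) batch size.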
+        # gather keys before updating queue
+        image_feats = concat_all_gather(image_feat)
+        text_feats = concat_all_gather(text_feat)
+
+        batch_size = image_feats.shape[0]
+
+        ptr = int(self.queue_ptr)
+        assert self.queue_size % batch_size == 0  # for simplicity
+
+        # replace the keys at ptr (dequeue and enqueue)
+        self.image_queue[:, ptr: ptr + batch_size] = image_feats.T
+        self.text_queue[:, ptr: ptr + batch_size] = text_feats.T
+
+        if idxs is not None:
+            idxs = concat_all_gather(idxs)
+            self.idx_queue[:, ptr: ptr + batch_size] = idxs.T
+
+        ptr = (ptr + batch_size) % self.queue_size  # move pointer
+        self.queue_ptr[0] = ptr
+
+
+class MomentumDistilationMixin:
+    @torch.no_grad()
+    def copy_params(self):
+        for model_pair in self.model_pairs:
+            for param, param_m in zip(
+                    model_pair[0].parameters(), model_pair[1].parameters()
+            ):
+                param_m.data.copy_(param.data)  # initialize
+                param_m.requires_grad = False  # not update by gradient
+
+    @torch.no_grad()
+    def _momentum_update(self):
+        for model_pair in self.model_pairs:
+            for param, param_m in zip(
+                    model_pair[0].parameters(), model_pair[1].parameters()
+            ):
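+                # Exponential moving average: param_m <- momentum * param_m + (1 - momentum) * param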
+                param_m.data = param_m.data * self.momentum + param.data * (
+                        1.0 - self.momentum
+                )
+
+
+class GatherLayer(torch.autograd.Function):
+    """
+    Gather tensors from all workers with support for backward propagation:
+    Unlike torch.distributed.all_gather, this implementation does not cut the gradients.
+    """
+
+    @staticmethod
+    def forward(ctx, x):
+        output = [
+            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
+        ]
+        torch.distributed.all_gather(output, x)
+        return tuple(output)
+
+    @staticmethod
+    def backward(ctx, *grads):
+        all_gradients = torch.stack(grads)
+        torch.distributed.all_reduce(all_gradients)
+        return all_gradients[torch.distributed.get_rank()]
+
+
+def all_gather_with_grad(tensors):
+    """
+    Performs all_gather operation on the provided tensors.
+    Graph remains connected for backward grad computation.
+    """
+    world_size = torch.distributed.get_world_size()
+    # There is no need for reduction in the single-process case
+    if world_size == 1:
+        return tensors
+
+    # Gather from all ranks while keeping the tensors in the autograd graph
+    tensor_all = GatherLayer.apply(tensors)
+
+    return torch.cat(tensor_all, dim=0)
+
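+# Illustrative usage of ``all_gather_with_grad`` (sketch): gather features from all ranks
+# while keeping gradients, e.g. for a contrastive loss over the global batch:
+#   all_feats = all_gather_with_grad(local_feats)  # shape: [world_size * B, D]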
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    # if not running distributed, return the tensor unchanged
+    if not is_dist_avail_and_initialized():
+        return tensor
+
+    tensors_gather = [
+        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
+    ]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output
+
+
+def tile(x, dim, n_tile):
+    init_dim = x.size(dim)
+    repeat_idx = [1] * x.dim()
+    repeat_idx[dim] = n_tile
+    x = x.repeat(*(repeat_idx))
+    order_index = torch.LongTensor(
+        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
+    )
+    return torch.index_select(x, dim, order_index.to(x.device))
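+
+
+# Note (added): ``tile`` repeats each index along ``dim`` n_tile times consecutively,
+# equivalent to ``torch.repeat_interleave(x, n_tile, dim=dim)``; e.g.
+# tile(torch.tensor([1, 2]), 0, 3) -> tensor([1, 1, 1, 2, 2, 2]).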
diff --git a/unimernet/models/blip2_models/Qformer.py b/unimernet/models/blip2_models/Qformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71b12375e10511858a9c505dc795181e6ce5603
--- /dev/null
+++ b/unimernet/models/blip2_models/Qformer.py
@@ -0,0 +1,1216 @@
+"""
+ * Copyright (c) 2023, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+ * Based on huggingface code base
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
+"""
+
+import math
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+from torch import Tensor, device, dtype, nn
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+import torch.nn.functional as F
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import (
+    ModelOutput,
+)
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    MaskedLMOutput,
+    MultipleChoiceModelOutput,
+    NextSentencePredictorOutput,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutput,
+    TokenClassifierOutput,
+)
+from transformers.modeling_utils import (
+    PreTrainedModel,
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    prune_linear_layer,
+)
+from transformers.utils import logging
+from transformers.models.bert.configuration_bert import BertConfig
+
+logger = logging.get_logger(__name__)
+
+
+class BertEmbeddings(nn.Module):
+    """Construct the embeddings from word and position embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(
+            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
+        )
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size
+        )
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
+        )
+        self.position_embedding_type = getattr(
+            config, "position_embedding_type", "absolute"
+        )
+
+        self.config = config
+
+    def forward(
+        self,
+        input_ids=None,
+        position_ids=None,
+        query_embeds=None,
+        past_key_values_length=0,
+    ):
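+        # Builds the embedding sequence; when ``query_embeds`` is given, the learned query
+        # tokens are prepended so the output is [query_embeds; token embeddings].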
+        if input_ids is not None:
+            seq_length = input_ids.size()[1]
+        else:
+            seq_length = 0
+
+        if position_ids is None:
+            position_ids = self.position_ids[
+                :, past_key_values_length : seq_length + past_key_values_length
+            ].clone()
+
+        if input_ids is not None:
+            embeddings = self.word_embeddings(input_ids)
+            if self.position_embedding_type == "absolute":
+                position_embeddings = self.position_embeddings(position_ids)
+                embeddings = embeddings + position_embeddings
+
+            if query_embeds is not None:
+                embeddings = torch.cat((query_embeds, embeddings), dim=1)
+        else:
+            embeddings = query_embeds
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class BertSelfAttention(nn.Module):
+    def __init__(self, config, is_cross_attention):
+        super().__init__()
+        self.config = config
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+            config, "embedding_size"
+        ):
+            raise ValueError(
+                "The hidden size (%d) is not a multiple of the number of attention "
+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
+            )
+
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(config.hidden_size, self.all_head_size)
+        if is_cross_attention:
+            self.key = nn.Linear(config.encoder_width, self.all_head_size)
+            self.value = nn.Linear(config.encoder_width, self.all_head_size)
+        else:
+            self.key = nn.Linear(config.hidden_size, self.all_head_size)
+            self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.position_embedding_type = getattr(
+            config, "position_embedding_type", "absolute"
+        )
+        if (
+            self.position_embedding_type == "relative_key"
+            or self.position_embedding_type == "relative_key_query"
+        ):
+            self.max_position_embeddings = config.max_position_embeddings
+            self.distance_embedding = nn.Embedding(
+                2 * config.max_position_embeddings - 1, self.attention_head_size
+            )
+        self.save_attention = False
+
+    def save_attn_gradients(self, attn_gradients):
+        self.attn_gradients = attn_gradients
+
+    def get_attn_gradients(self):
+        return self.attn_gradients
+
+    def save_attention_map(self, attention_map):
+        self.attention_map = attention_map
+
+    def get_attention_map(self):
+        return self.attention_map
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (
+            self.num_attention_heads,
+            self.attention_head_size,
+        )
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+        mixed_query_layer = self.query(hidden_states)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        past_key_value = (key_layer, value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        if (
+            self.position_embedding_type == "relative_key"
+            or self.position_embedding_type == "relative_key_query"
+        ):
+            seq_length = hidden_states.size()[1]
+            position_ids_l = torch.arange(
+                seq_length, dtype=torch.long, device=hidden_states.device
+            ).view(-1, 1)
+            position_ids_r = torch.arange(
+                seq_length, dtype=torch.long, device=hidden_states.device
+            ).view(1, -1)
+            distance = position_ids_l - position_ids_r
+            positional_embedding = self.distance_embedding(
+                distance + self.max_position_embeddings - 1
+            )
+            positional_embedding = positional_embedding.to(
+                dtype=query_layer.dtype
+            )  # fp16 compatibility
+
+            if self.position_embedding_type == "relative_key":
+                relative_position_scores = torch.einsum(
+                    "bhld,lrd->bhlr", query_layer, positional_embedding
+                )
+                attention_scores = attention_scores + relative_position_scores
+            elif self.position_embedding_type == "relative_key_query":
+                relative_position_scores_query = torch.einsum(
+                    "bhld,lrd->bhlr", query_layer, positional_embedding
+                )
+                relative_position_scores_key = torch.einsum(
+                    "bhrd,lrd->bhlr", key_layer, positional_embedding
+                )
+                attention_scores = (
+                    attention_scores
+                    + relative_position_scores_query
+                    + relative_position_scores_key
+                )
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        if attention_mask is not None:
+            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+            attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+        if is_cross_attention and self.save_attention:
+            self.save_attention_map(attention_probs)
+            attention_probs.register_hook(self.save_attn_gradients)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs_dropped = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs_dropped = attention_probs_dropped * head_mask
+
+        context_layer = torch.matmul(attention_probs_dropped, value_layer)
+
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        outputs = (
+            (context_layer, attention_probs) if output_attentions else (context_layer,)
+        )
+
+        outputs = outputs + (past_key_value,)
+        return outputs
+
+
+class BertSelfOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertAttention(nn.Module):
+    def __init__(self, config, is_cross_attention=False):
+        super().__init__()
+        self.self = BertSelfAttention(config, is_cross_attention)
+        self.output = BertSelfOutput(config)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads,
+            self.self.num_attention_heads,
+            self.self.attention_head_size,
+            self.pruned_heads,
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = (
+            self.self.attention_head_size * self.self.num_attention_heads
+        )
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+    ):
+        self_outputs = self.self(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            past_key_value,
+            output_attentions,
+        )
+        attention_output = self.output(self_outputs[0], hidden_states)
+
+        outputs = (attention_output,) + self_outputs[
+            1:
+        ]  # add attentions if we output them
+        return outputs
+
+
+class BertIntermediate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+    def __init__(self, config, layer_num):
+        super().__init__()
+        self.config = config
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = BertAttention(config)
+        self.layer_num = layer_num
+        if (
+            self.config.add_cross_attention
+            and layer_num % self.config.cross_attention_freq == 0
+        ):
+            self.crossattention = BertAttention(
+                config, is_cross_attention=self.config.add_cross_attention
+            )
+            self.has_cross_attention = True
+        else:
+            self.has_cross_attention = False
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+        self.intermediate_query = BertIntermediate(config)
+        self.output_query = BertOutput(config)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+        query_length=0,
+    ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = (
+            past_key_value[:2] if past_key_value is not None else None
+        )
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )
+        attention_output = self_attention_outputs[0]
+        outputs = self_attention_outputs[1:-1]
+
+        present_key_value = self_attention_outputs[-1]
+
+        if query_length > 0:
+            query_attention_output = attention_output[:, :query_length, :]
+
+            if self.has_cross_attention:
+                assert (
+                    encoder_hidden_states is not None
+                ), "encoder_hidden_states must be given for cross-attention layers"
+                cross_attention_outputs = self.crossattention(
+                    query_attention_output,
+                    attention_mask,
+                    head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    output_attentions=output_attentions,
+                )
+                query_attention_output = cross_attention_outputs[0]
+                outputs = (
+                    outputs + cross_attention_outputs[1:-1]
+                )  # add cross attentions if we output attention weights
+
+            layer_output = apply_chunking_to_forward(
+                self.feed_forward_chunk_query,
+                self.chunk_size_feed_forward,
+                self.seq_len_dim,
+                query_attention_output,
+            )
+            if attention_output.shape[1] > query_length:
+                layer_output_text = apply_chunking_to_forward(
+                    self.feed_forward_chunk,
+                    self.chunk_size_feed_forward,
+                    self.seq_len_dim,
+                    attention_output[:, query_length:, :],
+                )
+                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
+        else:
+            layer_output = apply_chunking_to_forward(
+                self.feed_forward_chunk,
+                self.chunk_size_feed_forward,
+                self.seq_len_dim,
+                attention_output,
+            )
+        outputs = (layer_output,) + outputs
+
+        outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def feed_forward_chunk(self, attention_output):
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+    def feed_forward_chunk_query(self, attention_output):
+        intermediate_output = self.intermediate_query(attention_output)
+        layer_output = self.output_query(intermediate_output, attention_output)
+        return layer_output
+
+
+class BertEncoder(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.layer = nn.ModuleList(
+            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        query_length=0,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = (
+            () if output_attentions and self.config.add_cross_attention else None
+        )
+
+        next_decoder_cache = () if use_cache else None
+
+        for i in range(self.config.num_hidden_layers):
+            layer_module = self.layer[i]
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warning(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(
+                            *inputs, past_key_value, output_attentions, query_length
+                        )
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                    query_length,
+                )
+
+            hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class BertPooler(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class BertPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class BertLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = BertPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+class BertOnlyMLMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = BertLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
+class BertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = BertConfig
+    base_model_prefix = "bert"
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+class BertModel(BertPreTrainedModel):
+    """
+    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+    cross-attention is added between the self-attention layers, following the architecture described in `Attention is
+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
+    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+    To behave as a decoder the model needs to be initialized with the :obj:`is_decoder`
+    argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
+    input to the forward pass.
+    """
+
+    def __init__(self, config, add_pooling_layer=False):
+        super().__init__(config)
+        self.config = config
+
+        self.embeddings = BertEmbeddings(config)
+
+        self.encoder = BertEncoder(config)
+
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.init_weights()
+
+    def get_input_embeddings(self):
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def get_extended_attention_mask(
+        self,
+        attention_mask: Tensor,
+        input_shape: Tuple[int],
+        device: device,
+        is_decoder: bool,
+        has_query: bool = False,
+    ) -> Tensor:
+        """
+        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (:obj:`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (:obj:`Tuple[int]`):
+                The shape of the input to the model.
+            device: (:obj:`torch.device`):
+                The device of the input to the model.
+
+        Returns:
+            :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
+        """
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if is_decoder:
+                batch_size, seq_length = input_shape
+
+                seq_ids = torch.arange(seq_length, device=device)
+                causal_mask = (
+                    seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
+                    <= seq_ids[None, :, None]
+                )
+
+                # add a prefix ones mask to the causal mask
+                # causal and attention masks must have same type with pytorch version < 1.3
+                causal_mask = causal_mask.to(attention_mask.dtype)
+
+                if causal_mask.shape[1] < attention_mask.shape[1]:
+                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+                    if has_query:  # UniLM style attention mask
+                        causal_mask = torch.cat(
+                            [
+                                torch.zeros(
+                                    (batch_size, prefix_seq_len, seq_length),
+                                    device=device,
+                                    dtype=causal_mask.dtype,
+                                ),
+                                causal_mask,
+                            ],
+                            axis=1,
+                        )
+                    causal_mask = torch.cat(
+                        [
+                            torch.ones(
+                                (batch_size, causal_mask.shape[1], prefix_seq_len),
+                                device=device,
+                                dtype=causal_mask.dtype,
+                            ),
+                            causal_mask,
+                        ],
+                        axis=-1,
+                    )
+                extended_attention_mask = (
+                    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+                )
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                    input_shape, attention_mask.shape
+                )
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(
+            dtype=self.dtype
+        )  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        return extended_attention_mask
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        query_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        is_decoder=False,
+    ):
+        r"""
+        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        """
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if input_ids is None:
+            assert (
+                query_embeds is not None
+            ), "You have to specify query_embeds when input_ids is None"
+
+        # past_key_values_length
+        past_key_values_length = (
+            past_key_values[0][0].shape[2] - self.config.query_length
+            if past_key_values is not None
+            else 0
+        )
+
+        query_length = query_embeds.shape[1] if query_embeds is not None else 0
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            query_embeds=query_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+
+        input_shape = embedding_output.size()[:-1]
+        batch_size, seq_length = input_shape
+        device = embedding_output.device
+
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                ((batch_size, seq_length + past_key_values_length)), device=device
+            )
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if is_decoder:
+            extended_attention_mask = self.get_extended_attention_mask(
+                attention_mask,
+                input_ids.shape,
+                device,
+                is_decoder,
+                has_query=(query_embeds is not None),
+            )
+        else:
+            extended_attention_mask = self.get_extended_attention_mask(
+                attention_mask, input_shape, device, is_decoder
+            )
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if encoder_hidden_states is not None:
+            if type(encoder_hidden_states) == list:
+                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
+                    0
+                ].size()
+            else:
+                (
+                    encoder_batch_size,
+                    encoder_sequence_length,
+                    _,
+                ) = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+
+            if type(encoder_attention_mask) == list:
+                encoder_extended_attention_mask = [
+                    self.invert_attention_mask(mask) for mask in encoder_attention_mask
+                ]
+            elif encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+                encoder_extended_attention_mask = self.invert_attention_mask(
+                    encoder_attention_mask
+                )
+            else:
+                encoder_extended_attention_mask = self.invert_attention_mask(
+                    encoder_attention_mask
+                )
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            query_length=query_length,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = (
+            self.pooler(sequence_output) if self.pooler is not None else None
+        )
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+        )
+
+
+class BertLMHeadModel(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.cls = BertOnlyMLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        query_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        past_key_values=None,
+        use_cache=True,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        return_logits=False,
+        is_decoder=True,
+        reduction="mean",
+    ):
+        r"""
+        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+            the model is configured as a decoder.
+        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+            the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
+            ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`):
+            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
+            decoding (see :obj:`past_key_values`).
+        Returns:
+        Example::
+            >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
+            >>> import torch
+            >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+            >>> config = BertConfig.from_pretrained("bert-base-cased")
+            >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+            >>> outputs = model(**inputs)
+            >>> prediction_logits = outputs.logits
+        """
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+        if labels is not None:
+            use_cache = False
+        if past_key_values is not None:
+            query_embeds = None
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            query_embeds=query_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            is_decoder=is_decoder,
+        )
+
+        sequence_output = outputs[0]
+        if query_embeds is not None:
+            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+
+        prediction_scores = self.cls(sequence_output)
+
+        if return_logits:
+            return prediction_scores[:, :-1, :].contiguous()
+
+        lm_loss = None
+        if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+            lm_loss = loss_fct(
+                shifted_prediction_scores.view(-1, self.config.vocab_size),
+                labels.view(-1),
+            )
+            if reduction == "none":
+                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return ((lm_loss,) + output) if lm_loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=lm_loss,
+            logits=prediction_scores,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
+    ):
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
+        attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "query_embeds": query_embeds,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
+            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
+            "is_decoder": True,
+        }
+
+    def _reorder_cache(self, past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx) for past_state in layer_past
+                ),
+            )
+        return reordered_past
+
+
+class BertForMaskedLM(BertPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.bert = BertModel(config, add_pooling_layer=False)
+        self.cls = BertOnlyMLMHead(config)
+
+        self.init_weights()
+
+    def get_output_embeddings(self):
+        return self.cls.predictions.decoder
+
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        query_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        return_logits=False,
+        is_decoder=False,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+        """
+
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            query_embeds=query_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            is_decoder=is_decoder,
+        )
+
+        sequence_output = outputs[0]
+        if query_embeds is not None:
+            sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
+        prediction_scores = self.cls(sequence_output)
+
+        if return_logits:
+            return prediction_scores
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()  # -100 index = padding token
+            masked_lm_loss = loss_fct(
+                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
+            )
+
+        if not return_dict:
+            output = (prediction_scores,) + outputs[2:]
+            return (
+                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+            )
+
+        return MaskedLMOutput(
+            loss=masked_lm_loss,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
diff --git a/unimernet/models/blip2_models/__init__.py b/unimernet/models/blip2_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unimernet/models/blip2_models/__pycache__/Qformer.cpython-310.pyc b/unimernet/models/blip2_models/__pycache__/Qformer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a040440a336c86164ff02f23bc90c0af3046ccb
Binary files /dev/null and b/unimernet/models/blip2_models/__pycache__/Qformer.cpython-310.pyc differ
diff --git a/unimernet/models/blip2_models/__pycache__/__init__.cpython-310.pyc b/unimernet/models/blip2_models/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..66b2031f3f5efe47f280420ea496156ea14f4fe1
Binary files /dev/null and b/unimernet/models/blip2_models/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/models/blip2_models/__pycache__/blip2.cpython-310.pyc b/unimernet/models/blip2_models/__pycache__/blip2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b46fcb658d78c631ebcd9a06309f3883c35f6786
Binary files /dev/null and b/unimernet/models/blip2_models/__pycache__/blip2.cpython-310.pyc differ
diff --git a/unimernet/models/blip2_models/blip2.py b/unimernet/models/blip2_models/blip2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3829d58c1a97d49893566488aafc95f1e4c8d458
--- /dev/null
+++ b/unimernet/models/blip2_models/blip2.py
@@ -0,0 +1,322 @@
+"""
+ Copyright (c) 2023, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import contextlib
+import logging
+import os
+import time
+import datetime
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+import torch.nn.functional as F
+
+import unimernet.common.dist_utils as dist_utils
+from unimernet.common.dist_utils import download_cached_file
+from unimernet.common.utils import is_url
+from unimernet.common.logger import MetricLogger
+from unimernet.models.base_model import BaseModel
+from unimernet.models.blip2_models.Qformer import BertConfig, BertLMHeadModel
+from unimernet.models.eva_vit import create_eva_vit_g
+from unimernet.models.clip_vit import create_clip_vit_L
+from transformers import BertTokenizer
+from transformers.utils import logging as tf_logging
+
+tf_logging.set_verbosity_error()
+
+
+class Blip2Base(BaseModel):
+    @classmethod
+    def init_tokenizer(cls, truncation_side="right"):
+        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side)
+        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
+        return tokenizer
+
+    def maybe_autocast(self, dtype=torch.float16):
+        # if on cpu, don't use autocast
+        # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
+        enable_autocast = self.device != torch.device("cpu")
+
+        if enable_autocast:
+            return torch.cuda.amp.autocast(dtype=dtype)
+        else:
+            return contextlib.nullcontext()
+
+    @classmethod
+    def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
+        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
+        encoder_config.encoder_width = vision_width
+        # insert cross-attention layer every other block
+        encoder_config.add_cross_attention = True
+        encoder_config.cross_attention_freq = cross_attention_freq
+        encoder_config.query_length = num_query_token
+        Qformer = BertLMHeadModel.from_pretrained(
+            "/mnt/lustre/hanxiao/work/bert-base-uncased", config=encoder_config
+        )
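+        # learnable query embeddings that cross-attend to the frozen image features through the Q-Former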
+        query_tokens = nn.Parameter(
+            torch.zeros(1, num_query_token, encoder_config.hidden_size)
+        )
+        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
+        return Qformer, query_tokens
+
+    def init_vision_encoder(
+            self, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
+    ):
+        assert model_name in [
+            "eva_clip_g",
+            "clip_L",
+        ], "vit model must be eva_clip_g or clip_L"
+        if model_name == "eva_clip_g":
+            visual_encoder = create_eva_vit_g(
+                img_size, drop_path_rate, use_grad_checkpoint, precision
+            )
+        elif model_name == "clip_L":
+            visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision)
+        ln_vision = LayerNorm(visual_encoder.num_features)
+        self.vit_name = model_name
+        return visual_encoder, ln_vision
+
+    def load_from_pretrained(self, url_or_filename):
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location="cpu")
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location="cpu")
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        state_dict = checkpoint["model"]
+
+        msg = self.load_state_dict(state_dict, strict=False)
+
+        # logging.info("Missing keys {}".format(msg.missing_keys))
+        logging.info("load checkpoint from %s" % url_or_filename)
+
+        return msg
+
+    def get_optimizer_params(self, weight_decay, lr_scale=1):
+        if self.vit_name == "eva_clip_g":
+            vit_num_layers = self.visual_encoder.get_num_layer()
+            lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2))
+
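+            # group parameters by (ViT layer, decay / no-decay); each group carries its layer's multiplier from lr_scales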
+            parameter_group_names = {}
+            parameter_group_vars = {}
+
+            for name, param in self.named_parameters():
+                if not param.requires_grad:
+                    continue  # frozen weights
+                if len(param.shape) == 1 or name.endswith(".bias"):
+                    group_name = "no_decay"
+                    this_weight_decay = 0.
+                else:
+                    group_name = "decay"
+                    this_weight_decay = weight_decay
+                if 'visual_encoder' in name:
+                    layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.', ''))
+                    group_name = "vit_layer_%d_%s" % (layer_id, group_name)
+                else:
+                    layer_id = None
+
+                if group_name not in parameter_group_names:
+                    if layer_id is not None:
+                        scale = lr_scales[layer_id]
+                    else:
+                        scale = 1
+                    parameter_group_names[group_name] = {
+                        "weight_decay": this_weight_decay,
+                        "params": [],
+                        "lr_scale": scale
+                    }
+                    parameter_group_vars[group_name] = {
+                        "weight_decay": this_weight_decay,
+                        "params": [],
+                        "lr_scale": scale
+                    }
+                parameter_group_vars[group_name]["params"].append(param)
+                parameter_group_names[group_name]["params"].append(name)
+            # import json
+            # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+            optim_params = list(parameter_group_vars.values())
+            return optim_params
+        else:
+            return super().get_optimizer_params(weight_decay, lr_scale)
+
+    def _lemmatize(self, answers):
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            except ImportError:
+                logging.error(
+                    """
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """
+                )
+                exit(1)
+
+        return self._lemmatizer
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+def compute_sim_matrix(model, data_loader, **kwargs):
+    k_test = kwargs.pop("k_test")
+
+    metric_logger = MetricLogger(delimiter="  ")
+    header = "Evaluation:"
+
+    logging.info("Computing features for evaluation...")
+    start_time = time.time()
+
+    texts = data_loader.dataset.text
+    num_text = len(texts)
+    text_bs = 256
+    text_ids = []
+    text_embeds = []
+    text_atts = []
+    for i in range(0, num_text, text_bs):
+        text = texts[i: min(num_text, i + text_bs)]
+        text_input = model.tokenizer(
+            text,
+            padding="max_length",
+            truncation=True,
+            max_length=35,
+            return_tensors="pt",
+        ).to(model.device)
+        text_feat = model.forward_text(text_input)
+        text_embed = F.normalize(model.text_proj(text_feat))
+        text_embeds.append(text_embed)
+        text_ids.append(text_input.input_ids)
+        text_atts.append(text_input.attention_mask)
+
+    text_embeds = torch.cat(text_embeds, dim=0)
+    text_ids = torch.cat(text_ids, dim=0)
+    text_atts = torch.cat(text_atts, dim=0)
+
+    vit_feats = []
+    image_embeds = []
+    for samples in data_loader:
+        image = samples["image"]
+
+        image = image.to(model.device)
+        image_feat, vit_feat = model.forward_image(image)
+        image_embed = model.vision_proj(image_feat)
+        image_embed = F.normalize(image_embed, dim=-1)
+
+        vit_feats.append(vit_feat.cpu())
+        image_embeds.append(image_embed)
+
+    vit_feats = torch.cat(vit_feats, dim=0)
+    image_embeds = torch.cat(image_embeds, dim=0)
+
+    sims_matrix = []
+    for image_embed in image_embeds:
+        sim_q2t = image_embed @ text_embeds.t()
+        sim_i2t, _ = sim_q2t.max(0)
+        sims_matrix.append(sim_i2t)
+    sims_matrix = torch.stack(sims_matrix, dim=0)
+
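+    # image-to-text: rescore each image's top-k candidate texts with the ITM head; non-candidates keep the -100 fill value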
+    score_matrix_i2t = torch.full(
+        (len(data_loader.dataset.image), len(texts)), -100.0
+    ).to(model.device)
+
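+    # shard the rows across distributed ranks; the partial score matrices are merged with all_reduce below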
+    num_tasks = dist_utils.get_world_size()
+    rank = dist_utils.get_rank()
+    step = sims_matrix.size(0) // num_tasks + 1
+    start = rank * step
+    end = min(sims_matrix.size(0), start + step)
+
+    for i, sims in enumerate(
+            metric_logger.log_every(sims_matrix[start:end], 50, header)
+    ):
+        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
+        image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
+        score = model.compute_itm(
+            image_inputs=image_inputs,
+            text_ids=text_ids[topk_idx],
+            text_atts=text_atts[topk_idx],
+        ).float()
+        score_matrix_i2t[start + i, topk_idx] = score + topk_sim
+
+    sims_matrix = sims_matrix.t()
+    score_matrix_t2i = torch.full(
+        (len(texts), len(data_loader.dataset.image)), -100.0
+    ).to(model.device)
+
+    step = sims_matrix.size(0) // num_tasks + 1
+    start = rank * step
+    end = min(sims_matrix.size(0), start + step)
+
+    for i, sims in enumerate(
+            metric_logger.log_every(sims_matrix[start:end], 50, header)
+    ):
+        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
+        image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
+        score = model.compute_itm(
+            image_inputs=image_inputs,
+            text_ids=text_ids[start + i].repeat(k_test, 1),
+            text_atts=text_atts[start + i].repeat(k_test, 1),
+        ).float()
+        score_matrix_t2i[start + i, topk_idx] = score + topk_sim
+
+    if dist_utils.is_dist_avail_and_initialized():
+        dist.barrier()
+        torch.distributed.all_reduce(
+            score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
+        )
+        torch.distributed.all_reduce(
+            score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
+        )
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    logging.info("Evaluation time {}".format(total_time_str))
+
+    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
diff --git a/unimernet/models/blip2_models/blip2_vicuna_instruct.py b/unimernet/models/blip2_models/blip2_vicuna_instruct.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a5e257b74c8d17db444e4aa78c06956a2f83027
--- /dev/null
+++ b/unimernet/models/blip2_models/blip2_vicuna_instruct.py
@@ -0,0 +1,666 @@
+"""
+Requires transformers 4.28 and above; the implementation may change to follow upstream Llama changes.
+"""
+import logging
+from packaging import version
+
+import torch
+import torch.nn as nn
+
+import transformers
+
+from unimernet.common.registry import registry
+from unimernet.models.blip2_models.blip2 import Blip2Base, disabled_train
+
+
+@registry.register_model("blip2_vicuna_instruct")
+class Blip2VicunaInstruct(Blip2Base):
+    """
+    BLIP2 Vicuna model.
+    Supported model types:
+        - vicuna7b
+        - vicuna13b
+    Usage:
+        >>> from unimernet.models import load_model
+        >>> model = load_model("blip2_vicuna_instruct", "vicuna7b")
+    """
+
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        "vicuna7b": "configs/models/blip2_instruct_vicuna7b.yaml",
+        "vicuna13b": "configs/models/blip2_instruct_vicuna13b.yaml",
+        "minigpt4_vicuna7b": "configs/models/mini_gpt4_vicuna7b.yaml",
+        "minigpt4_vicuna13b": "configs/models/mini_gpt4_vicuna13b.yaml",
+    }
+
+    def __init__(
+            self,
+            vit_model="eva_clip_g",
+            img_size=224,
+            drop_path_rate=0,
+            use_grad_checkpoint=False,
+            vit_precision="fp16",
+            freeze_vit=True,
+            freeze_vit_ln=False,
+            num_query_token=32,
+            llm_model="",
+            prompt="",
+            max_txt_len=128,
+            max_output_txt_len=256,
+            apply_lemmatizer=False,
+            qformer_text_input=True,
+            truncate_q_former_output=True
+    ):
+        super().__init__()
+        transformers_version = version.parse(transformers.__version__)
+        assert transformers_version >= version.parse("4.28"), "BLIP-2 Vicuna requires transformers>=4.28"
+        from transformers import LlamaTokenizer
+        from unimernet.models.blip2_models.modeling_llama import LlamaForCausalLM
+
+        self.tokenizer = self.init_tokenizer(truncation_side="left")
+
+        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
+            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
+        )
+        if freeze_vit:
+            for name, param in self.visual_encoder.named_parameters():
+                param.requires_grad = False
+            self.visual_encoder = self.visual_encoder.eval()
+            self.visual_encoder.train = disabled_train
+            logging.info("freeze vision encoder")
+
+        if freeze_vit_ln:
+            for name, param in self.ln_vision.named_parameters():
+                param.requires_grad = False
+            self.ln_vision = self.ln_vision.eval()
+            self.ln_vision.train = disabled_train
+            logging.info("freeze vit layner norm")
+
+        self.Qformer, self.query_tokens = self.init_Qformer(
+            num_query_token, self.visual_encoder.num_features
+        )
+
+        if not qformer_text_input:
+            self.Qformer.bert.embeddings.word_embeddings = None
+            self.Qformer.bert.embeddings.position_embeddings = None
+            for layer in self.Qformer.bert.encoder.layer:
+                layer.output = None
+                layer.intermediate = None
+        else:
+            self.Qformer.resize_token_embeddings(len(self.tokenizer))
+        self.Qformer.cls = None
+
+        self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_model, use_fast=False, truncation_side="left")
+        self.llm_tokenizer_for_generate = LlamaTokenizer.from_pretrained(llm_model, use_fast=False,
+                                                                         truncation_side="left")
+        self.llm_model = LlamaForCausalLM.from_pretrained(
+            llm_model, torch_dtype=torch.float16
+        )
+        self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+        self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
+        self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
+        # self.llm_tokenizer.pad_token = self.llm_tokenizer.unk_token
+
+        self.llm_tokenizer_for_generate.add_special_tokens({'pad_token': '[PAD]'})
+        self.llm_tokenizer_for_generate.add_special_tokens({'bos_token': '</s>'})
+        self.llm_tokenizer_for_generate.add_special_tokens({'eos_token': '</s>'})
+        self.llm_tokenizer_for_generate.add_special_tokens({'unk_token': '</s>'})
+        self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
+
+        # self.eos_token_id = self.llm_tokenizer(
+        #     self.llm_tokenizer.eos_token, add_special_tokens=False
+        # ).input_ids[0]
+
+        for name, param in self.llm_model.named_parameters():
+            param.requires_grad = False
+
+        self.llm_proj = nn.Linear(
+            self.Qformer.config.hidden_size, self.llm_model.config.hidden_size
+        )
+
+        self.max_txt_len = max_txt_len
+        self.max_output_txt_len = max_output_txt_len
+        self.prompt = prompt
+        prompt_tokens = self.llm_tokenizer(self.prompt, return_tensors="pt")
+        self.prompt_length = prompt_tokens.attention_mask.sum(1)
+
+        self._lemmatizer = None
+
+        self.qformer_text_input = qformer_text_input
+        self.truncate_q_former_output = truncate_q_former_output
+
+    def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts):
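+        # splice each target sequence (minus its BOS) right after the sample's non-padded instruction tokens, pushing the instruction padding to the tail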
+        input_part_targets_len = []
+        llm_tokens = {"input_ids": [], "attention_mask": []}
+        for i in range(input_ids.size(0)):
+            this_input_ones = input_atts[i].sum()
+            input_part_targets_len.append(this_input_ones)
+            llm_tokens['input_ids'].append(
+                torch.cat([
+                    input_ids[i][:this_input_ones],
+                    output_ids[i][1:],
+                    input_ids[i][this_input_ones:]
+                ])
+            )
+            llm_tokens['attention_mask'].append(
+                torch.cat([
+                    input_atts[i][:this_input_ones],
+                    output_atts[i][1:],
+                    input_atts[i][this_input_ones:]
+                ])
+            )
+        llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids'])
+        llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
+        return llm_tokens, input_part_targets_len
+
+    def forward(self, samples):
+        # print('-----------------')
+        # print(samples["text_input"])
+        # print(samples["text_output"])
+        # print('-----------------')
+
+        image = samples["image"]
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        bs = image.size(0)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        if self.qformer_text_input:
+            text_Qformer = self.tokenizer(
+                samples["text_input"],
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt",
+            ).to(image.device)
+            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+            query_output = self.Qformer.bert(
+                text_Qformer.input_ids,
+                attention_mask=Qformer_atts,
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+        else:
+            query_output = self.Qformer.bert(
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+        if self.truncate_q_former_output:
+            inputs_llm = self.llm_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
+        else:
+            inputs_llm = self.llm_proj(query_output.last_hidden_state)
+        atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+
+        self.llm_tokenizer.padding_side = "right"
+        self.llm_tokenizer.truncation_side = 'left'
+        text_input_tokens = self.llm_tokenizer(
+            samples['text_input'],
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(image.device)
+
+        self.llm_tokenizer.truncation_side = 'right'
+        text_output_tokens = self.llm_tokenizer(
+            [t + self.llm_tokenizer.eos_token for t in samples['text_output']],
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=self.max_output_txt_len,
+        ).to(image.device)
+
+        llm_tokens, input_part_targets_len = self.concat_text_input_output(
+            text_input_tokens.input_ids,
+            text_input_tokens.attention_mask,
+            text_output_tokens.input_ids,
+            text_output_tokens.attention_mask,
+        )
+
+        # do not apply loss to the padding
+        targets = llm_tokens['input_ids'].masked_fill(
+            llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
+        )
+
+        # do not apply loss to the text input (i.e., instruction)
+        for i, l in enumerate(input_part_targets_len):
+            targets[i][:l] = -100
+
+        # do not apply loss to the query tokens
+        empty_targets = (
+            torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
+        )
+        targets = torch.cat([empty_targets, targets], dim=1)
+
+        inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
+        inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+        attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1)
+
+        with self.maybe_autocast():
+            outputs = self.llm_model(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                return_dict=True,
+                labels=targets,
+                use_cache=False,
+            )
+
+        loss = outputs.loss
+
+        return {"loss": loss}
+
+    def get_vision_feats(self, image, prompt):
+        bs = image.size(0)
+
+        if isinstance(prompt, str):
+            prompt = [prompt] * bs
+        else:
+            assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
+
+        query_tokens = self.query_tokens.expand(bs, -1, -1)
+
+        text_Qformer = self.tokenizer(
+            prompt,
+            padding='longest',
+            truncation=True,
+            max_length=self.max_txt_len,
+            return_tensors="pt",
+        ).to(image.device)
+        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+        Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        query_output = self.Qformer.bert(
+            text_Qformer.input_ids,
+            attention_mask=Qformer_atts,
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_atts,
+            return_dict=True,
+        )
+        if self.truncate_q_former_output:
+            inputs_llm = self.llm_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
+        else:
+            inputs_llm = self.llm_proj(query_output.last_hidden_state)
+        atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+        return inputs_llm, atts_llm
+
+    def shift_padding_to_left(self, inputs_embeds, attention_mask):
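+        # generation from inputs_embeds expects left padding, so move each row's trailing pad positions to the front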
+        llm_tokens = {"input_embeds": [], "attention_mask": []}
+        for i in range(inputs_embeds.size(0)):
+            this_input_ones = attention_mask[i].sum()
+            llm_tokens['input_embeds'].append(
+                torch.cat([
+                    inputs_embeds[i][this_input_ones:],
+                    inputs_embeds[i][:this_input_ones],
+                ])
+            )
+            llm_tokens['attention_mask'].append(
+                torch.cat([
+                    attention_mask[i][this_input_ones:],
+                    attention_mask[i][:this_input_ones],
+                ])
+            )
+        llm_tokens['input_embeds'] = torch.stack(llm_tokens['input_embeds'])
+        llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask'])
+        return llm_tokens['input_embeds'], llm_tokens['attention_mask']
+
+    @torch.no_grad()
+    def generate(
+            self,
+            samples,
+            use_nucleus_sampling=False,
+            num_beams=5,
+            max_length=256,
+            min_length=1,
+            top_p=0.9,
+            repetition_penalty=1.5,
+            length_penalty=1,
+            num_captions=1,
+            temperature=1,
+    ):
+
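+        # Expected `samples` keys (illustrative): "image" (FloatTensor [B, 3, H, W]); optional "prompt" (str or list of str) and "prefix" (list of str appended to each prompt)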
+        if "prompt" in samples.keys():
+            prompt = samples["prompt"]
+        else:
+            prompt = self.prompt
+
+        image = samples["image"]
+        if isinstance(prompt, str):
+            # broadcast a single prompt string over the batch so prefix handling and tokenization stay aligned with the image batch
+            prompt = [prompt] * image.size(0)
+
+        inputs_llm, atts_llm = self.get_vision_feats(image, prompt)
+
+        self.llm_tokenizer_for_generate.padding_side = "right"
+
+        self.llm_tokenizer_for_generate.pad_token = self.llm_tokenizer_for_generate.eos_token  # debug
+        ori_pad_token_id = self.llm_model.config.pad_token_id
+        self.llm_model.config.pad_token_id = self.llm_model.config.eos_token_id  # debug
+
+        if "prefix" in samples:
+            prompt = [f"{prompt_} {prefix_}".strip() for prompt_, prefix_ in zip(prompt, samples["prefix"])]
+
+        llm_tokens = self.llm_tokenizer_for_generate(
+            prompt,
+            padding="longest",
+            return_tensors="pt",
+        ).to(image.device)
+
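+        # prepend the projected Q-Former outputs to the prompt embeddings and generate directly from embeddings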
+        inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
+        inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+        inputs_embeds = inputs_embeds.to(next(self.llm_model.parameters()).dtype)
+        attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1)
+        inputs_embeds, attention_mask = self.shift_padding_to_left(inputs_embeds, attention_mask)
+
+        with self.maybe_autocast():
+            outputs = self.llm_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                do_sample=use_nucleus_sampling,
+                top_p=top_p,
+                temperature=temperature,
+                num_beams=num_beams,
+                max_length=max_length,
+                min_length=min_length,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
+                num_return_sequences=num_captions,
+                use_cache=True
+            )
+
+        outputs[outputs == 0] = 2  # convert output id 0 to 2 (eos_token_id)
+        outputs[outputs == -1] = 1  # debug
+        output_text = self.llm_tokenizer_for_generate.batch_decode(outputs, skip_special_tokens=True)
+        output_text = [text.strip() for text in output_text]
+
+        self.llm_model.config.pad_token_id = ori_pad_token_id
+
+        return output_text
+
+    @torch.no_grad()
+    def generate_multi(
+            self,
+            samples,
+            use_nucleus_sampling=False,
+            num_beams=5,
+            max_length=256,
+            min_length=1,
+            top_p=0.9,
+            repetition_penalty=1.5,
+            length_penalty=1,
+            temperature=1,
+    ):
+
+        if "prompt" in samples.keys():
+            prompt = samples["prompt"]
+        else:
+            prompt = self.prompt
+
+        image = samples["image"]
+        if isinstance(prompt, str):
+            # broadcast a single prompt string over the batch so prefix handling and tokenization stay aligned with the image batch
+            prompt = [prompt] * image.size(0)
+
+        inputs_llm, atts_llm = self.get_vision_feats(image, prompt)
+
+        self.llm_tokenizer_for_generate.padding_side = "right"
+
+        self.llm_tokenizer_for_generate.pad_token = self.llm_tokenizer_for_generate.eos_token  # debug
+        ori_pad_token_id = self.llm_model.config.pad_token_id
+        self.llm_model.config.pad_token_id = self.llm_model.config.eos_token_id  # debug
+
+        if "prefix" in samples:
+            prompt = [f"{prompt_} {prefix_}".strip() for prompt_, prefix_ in zip(prompt, samples["prefix"])]
+
+        llm_tokens = self.llm_tokenizer_for_generate(
+            prompt,
+            padding="longest",
+            return_tensors="pt",
+        ).to(image.device)
+
+        inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids)
+        inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+        inputs_embeds = inputs_embeds.to(next(self.llm_model.parameters()).dtype)
+        attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1)
+        inputs_embeds, attention_mask = self.shift_padding_to_left(inputs_embeds, attention_mask)
+
+        with self.maybe_autocast():
+            raw_output = self.llm_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                do_sample=use_nucleus_sampling,
+                top_p=top_p,
+                temperature=temperature,
+                num_beams=num_beams,
+                max_length=max_length,
+                min_length=min_length,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
+                num_return_sequences=num_beams,
+                output_scores=True,
+                return_dict_in_generate=True,
+                use_cache=True
+            )
+        outputs = raw_output.sequences
+        outputs[outputs == 0] = 2  # convert output id 0 to 2 (eos_token_id)
+        outputs[outputs == -1] = 1  # debug
+        output_text = self.llm_tokenizer_for_generate.batch_decode(outputs, skip_special_tokens=True)
+
+        output_text = [text.strip() for text in output_text]
+        scores = torch.exp(raw_output.sequences_scores).cpu().numpy() ** 3 * 100  # TODO
+
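+        # regroup the num_beams sequences (and their scores) returned for each input sample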
+        all_texts = []
+        all_scores = []
+        for i in range(0, len(output_text), num_beams):
+            this_text = output_text[i:i + num_beams]
+            all_texts.append(this_text)
+            this_score = scores[i: i + num_beams]
+            all_scores.append(this_score)
+
+        self.llm_model.config.pad_token_id = ori_pad_token_id
+
+        return all_texts, all_scores
+
+    def predict_by_rank(
+            self,
+            samples,
+            **kwargs
+    ):
+        image = samples["image"]
+        prompt = samples["prompt"]
+        candidates = samples["candidates"][0]
+        if isinstance(prompt, str):
+            prompt = [prompt]
+        assert image.size(0) == len(prompt) == 1, "When doing predict by rank, the batch size must be 1."
+
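+        # each candidate is scored by its language-modeling loss given the image and prompt; the lowest-loss candidate is returned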
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image))
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
+
+        batch_size = len(candidates)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        if self.qformer_text_input:
+            text_Qformer = self.tokenizer(
+                prompt,
+                padding='longest',
+                truncation=True,
+                max_length=self.max_txt_len,
+                return_tensors="pt",
+            ).to(image.device)
+            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
+            Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
+
+            query_output = self.Qformer.bert(
+                text_Qformer.input_ids,
+                attention_mask=Qformer_atts,
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+        else:
+            query_output = self.Qformer.bert(
+                query_embeds=query_tokens,
+                encoder_hidden_states=image_embeds,
+                encoder_attention_mask=image_atts,
+                return_dict=True,
+            )
+
+        if self.truncate_q_former_output:
+            inputs_llm = self.llm_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
+        else:
+            inputs_llm = self.llm_proj(query_output.last_hidden_state)
+        atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device)
+
+        self.llm_tokenizer.padding_side = "right"
+        self.llm_tokenizer.truncation_side = 'left'
+        text_input_tokens = self.llm_tokenizer(
+            prompt,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=self.max_txt_len,
+        ).to(image.device)
+
+        inputs_llm = inputs_llm.repeat(batch_size, 1, 1)
+        atts_llm = atts_llm.repeat(batch_size, 1)
+        text_input_ids = text_input_tokens.input_ids.repeat(batch_size, 1)
+        text_input_mask = text_input_tokens.attention_mask.repeat(batch_size, 1)
+
+        self.llm_tokenizer.truncation_side = 'right'
+        text_output_tokens = self.llm_tokenizer(
+            [t + self.llm_tokenizer.eos_token for t in candidates],
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=self.max_output_txt_len,
+        ).to(image.device)
+
+        llm_tokens, input_part_targets_len = self.concat_text_input_output(
+            text_input_ids,
+            text_input_mask,
+            text_output_tokens.input_ids,
+            text_output_tokens.attention_mask,
+        )
+
+        # do not apply loss to the padding
+        targets = llm_tokens['input_ids'].masked_fill(
+            llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100
+        )
+
+        # do not apply loss to the text input (i.e., instruction)
+        for i, l in enumerate(input_part_targets_len):
+            targets[i][:l] = -100
+
+        # do not apply loss to the query tokens
+        empty_targets = (
+            torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100)
+        )
+        targets = torch.cat([empty_targets, targets], dim=1)
+
+        inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids'])
+        inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1)
+        attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1)
+
+        with self.maybe_autocast():
+            outputs = self.llm_model(
+                inputs_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                return_dict=True,
+                labels=targets,
+                reduction="none",
+                use_cache=False
+            )
+
+        loss = outputs.loss.view(batch_size)
+        top1 = int(torch.argsort(loss, dim=-1)[0])
+
+        return [candidates[top1]]
+
+    def _lemmatize(self, answers):
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            except ImportError:
+                logging.error(
+                    """
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """
+                )
+                exit(1)
+
+        return self._lemmatizer
+
+    @classmethod
+    def from_config(cls, cfg):
+        vit_model = cfg.get("vit_model", "eva_clip_g")
+        img_size = cfg.get("image_size")
+        num_query_token = cfg.get("num_query_token")
+        llm_model = cfg.get("llm_model")
+
+        drop_path_rate = cfg.get("drop_path_rate", 0)
+        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
+        vit_precision = cfg.get("vit_precision", "fp16")
+        freeze_vit = cfg.get("freeze_vit", True)
+        freeze_vit_ln = cfg.get("freeze_vit_ln", False)
+        prompt = cfg.get("prompt", "")
+        max_txt_len = cfg.get("max_txt_len", 128)
+        max_output_txt_len = cfg.get("max_output_txt_len", 256)
+
+        apply_lemmatizer = cfg.get("apply_lemmatizer", False)
+
+        qformer_text_input = cfg.get("qformer_text_input", True)
+        truncate_q_former_output = cfg.get("truncate_q_former_output", True)
+
+        model = cls(
+            vit_model=vit_model,
+            img_size=img_size,
+            drop_path_rate=drop_path_rate,
+            use_grad_checkpoint=use_grad_checkpoint,
+            vit_precision=vit_precision,
+            freeze_vit=freeze_vit,
+            freeze_vit_ln=freeze_vit_ln,
+            num_query_token=num_query_token,
+            llm_model=llm_model,
+            prompt=prompt,
+            max_txt_len=max_txt_len,
+            max_output_txt_len=max_output_txt_len,
+            apply_lemmatizer=apply_lemmatizer,
+            qformer_text_input=qformer_text_input,
+            truncate_q_former_output=truncate_q_former_output
+        )
+
+        model.load_checkpoint_from_config(cfg)
+
+        return model
diff --git a/unimernet/models/blip2_models/modeling_llama.py b/unimernet/models/blip2_models/modeling_llama.py
new file mode 100644
index 0000000000000000000000000000000000000000..08d16a4abfb0a83dc416888755e31ea55c5be02b
--- /dev/null
+++ b/unimernet/models/blip2_models/modeling_llama.py
@@ -0,0 +1,994 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, \
+    SequenceClassifierOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, \
+    replace_return_docstrings
+from transformers.models.llama.configuration_llama import LlamaConfig
+from einops import rearrange
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+FLASH_ATTN_FLAG = True
+try:
+    from flash_attn.flash_attn_interface import (  # pip3 install "flash-attn>=2.0"
+        flash_attn_varlen_qkvpacked_func,
+    )
+    from flash_attn.bert_padding import unpad_input, pad_input
+
+    cuda_major, cuda_minor = torch.cuda.get_device_capability()
+    if cuda_major < 8:
+        logger.warning(
+            "Flash attention is only supported on A100/H100 GPUs during training because the backward pass "
+            "requires head dim > 64 support; see "
+            "https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+        )
+        FLASH_ATTN_FLAG = False
+except ImportError:
+    FLASH_ATTN_FLAG = False
+    logger.warning("flash-attn is not installed; falling back to the standard attention implementation.")
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+        input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class LlamaRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        return (self.weight * hidden_states).to(input_dtype)
+
+
+class LlamaRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
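+    # gather the cos/sin entries at each token's position_id so rotary embeddings stay correct with KV-cache offsets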
+    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
+    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class LlamaMLP(nn.Module):
+    def __init__(
+            self,
+            hidden_size: int,
+            intermediate_size: int,
+            hidden_act: str,
+    ):
+        super().__init__()
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.act_fn = ACT2FN[hidden_act]
+
+    def forward(self, x):
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class LlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def flash_attn_forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.Tensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: bool = False,
+            use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel
+
+        attention_mask: [bsz, q_len]
+        """
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = (
+            self.q_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+        )
+        key_states = (
+            self.k_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+        )
+        value_states = (
+            self.v_proj(hidden_states)
+                .view(bsz, q_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+        )
+        # [bsz, q_len, nh, hd]
+        # [bsz, nh, q_len, hd]
+
+        kv_seq_len = key_states.shape[-2]
+        assert past_key_value is None, "past_key_value is not supported"
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids
+        )
+        # [bsz, nh, t, hd]
+        assert not output_attentions, "output_attentions is not supported"
+        assert not use_cache, "use_cache is not supported"
+
+        # Flash attention codes from
+        # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+        # transform the data into the format required by flash attention
+        qkv = torch.stack(
+            [query_states, key_states, value_states], dim=2
+        )  # [bsz, nh, 3, q_len, hd]
+        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+        # We have disabled _prepare_decoder_attention_mask in LlamaModel
+        # the attention_mask should be the same as the key_padding_mask
+        key_padding_mask = attention_mask
+
+        if key_padding_mask is None:
+            qkv = rearrange(qkv, "b s ... -> (b s) ...")
+            max_s = q_len
+            cu_q_lens = torch.arange(
+                0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+            )
+            output = flash_attn_varlen_qkvpacked_func(
+                qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+            )
+            output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+        else:
+            nheads = qkv.shape[-2]
+            x = rearrange(qkv, "b s three h d -> b s (three h d)")
+            x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+            x_unpad = rearrange(
+                x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+            )
+            output_unpad = flash_attn_varlen_qkvpacked_func(
+                x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+            )
+            output = rearrange(
+                pad_input(
+                    rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+                ),
+                "b s (h d) -> b s h d",
+                h=nheads,
+            )
+        return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: bool = False,
+            use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
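+        # take the flash-attention path when it is available and no KV cache is needed; otherwise fall back to standard attention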
+        if FLASH_ATTN_FLAG and not use_cache:
+            return self.flash_attn_forward(hidden_states, attention_mask, position_ids, past_key_value,
+                                           output_attentions, use_cache)
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
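+            # Clamp to the dtype minimum so that adding the mask cannot
+            # underflow before the softmax.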
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class LlamaDecoderLayer(nn.Module):
+    def __init__(self, config: LlamaConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = LlamaAttention(config=config)
+        self.mlp = LlamaMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: Optional[bool] = False,
+            use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`LlamaConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+    LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+    config_class = LlamaConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["LlamaDecoderLayer"]
+    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, LlamaModel):
+            module.gradient_checkpointing = value
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+    LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+    Args:
+        config: LlamaConfig
+    """
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
+
+        return combined_attention_mask
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    def forward(
+            self,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
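+        # When flash attention is active and no KV cache is used, keep the 2-D
+        # padding mask; causal masking happens inside the flash kernel.
+        # Otherwise build the usual 4-D decoder attention mask.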
+        if not (FLASH_ATTN_FLAG and (use_cache is False)):
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
+
+        hidden_states = inputs_embeds
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, output_attentions, None)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    None,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = LlamaModel(config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+            self,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            reduction: Optional[str] = "mean",
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss(reduction=reduction)
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
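+            # With reduction="none", the flattened per-token losses are reshaped
+            # back to (batch, seq_len - 1) and averaged per sample.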
+            if reduction == "none":
+                # loss = loss.view(logits.size(0), -1).sum(1)
+                loss = loss.view(logits.size(0), -1).mean(1)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
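+            # Padded positions receive a dummy value (1); they are excluded by
+            # the attention mask, so the exact value is irrelevant.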
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+        return reordered_past
+
+
+@add_start_docstrings(
+    """
+    The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it needs to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = LlamaModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    def forward(
+            self,
+            input_ids: torch.LongTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_lengths = -1
+
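+        # Pool by taking the logits at each sequence's last non-padding token
+        # (or simply the last position when no pad_token_id is available).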
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
diff --git a/unimernet/models/blip2_models/modeling_llama_.py b/unimernet/models/blip2_models/modeling_llama_.py
new file mode 100644
index 0000000000000000000000000000000000000000..372889e0f7495ae7db0cdcd1bebd748833f66e93
--- /dev/null
+++ b/unimernet/models/blip2_models/modeling_llama_.py
@@ -0,0 +1,885 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch LLaMA model."""
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+class LlamaRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
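+        # RMSNorm: scale by the reciprocal root-mean-square of the features (no
+        # mean subtraction), computed in fp32, then apply the learned weight and
+        # cast back to the input dtype.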
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        return (self.weight * hidden_states).to(input_dtype)
+
+
+class LlamaRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
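+    # Gather the cos/sin rows matching each token's position_id, then apply the
+    # rotation q_embed = q*cos + rotate_half(q)*sin (and likewise for k).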
+    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
+    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
+    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class LlamaMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+    ):
+        super().__init__()
+        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
+        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.act_fn = ACT2FN[hidden_act]
+
+    def forward(self, x):
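+        # Gated MLP: the activated gate projection scales the up projection
+        # element-wise before down_proj maps back to hidden_size.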
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+class LlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class LlamaDecoderLayer(nn.Module):
+    def __init__(self, config: LlamaConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = LlamaAttention(config=config)
+        self.mlp = LlamaMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`LlamaConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+    LLAMA_START_DOCSTRING,
+)
+class LlamaPreTrainedModel(PreTrainedModel):
+    config_class = LlamaConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["LlamaDecoderLayer"]
+    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, LlamaModel):
+            module.gradient_checkpointing = value
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+            information on the default strategy.
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+            config.n_positions - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+    LLAMA_START_DOCSTRING,
+)
+class LlamaModel(LlamaPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+    Args:
+        config: LlamaConfig
+    """
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
+            )
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+                inputs_embeds.device
+            )
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
+
+        return combined_attention_mask
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+
+        hidden_states = inputs_embeds
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, output_attentions, None)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    None,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = LlamaModel(config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        reduction: Optional[str] = "mean",
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss(reduction=reduction)
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+            if reduction == "none":
+                # loss = loss.view(logits.size(0), -1).sum(1)
+                loss = loss.view(logits.size(0), -1).mean(1)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+        return reordered_past
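+    # Illustrative note (not part of the original code): each cached key/value tensor has
+    # shape [batch_size * num_beams, num_heads, seq_len, head_dim], so index_select(0,
+    # beam_idx) reorders the batch dimension to track whichever beams survived the latest
+    # beam-search step.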
+
+
+@add_start_docstrings(
+    """
+    The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it requires to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = LlamaModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
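+        # Illustrative note (not part of the original code): with right-padded inputs such
+        # as [[5, 7, 2, pad], [9, pad, pad, pad]], sequence_lengths is [2, 0], so the gather
+        # above keeps the logits of the last non-padding token of every sequence.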
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
\ No newline at end of file
diff --git a/unimernet/models/clip_vit.py b/unimernet/models/clip_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5547f2756531e8db014a99ef7c70ee3a4ce1533
--- /dev/null
+++ b/unimernet/models/clip_vit.py
@@ -0,0 +1,254 @@
+from collections import OrderedDict
+from itertools import repeat
+import collections.abc
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+
+from unimernet.models.eva_vit import convert_weights_to_fp16
+from unimernet.common.dist_utils import download_cached_file
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu1 = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.relu2 = nn.ReLU(inplace=True)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu3 = nn.ReLU(inplace=True)
+
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(OrderedDict([
+                ("-1", nn.AvgPool2d(stride)),
+                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+                ("1", nn.BatchNorm2d(planes * self.expansion))
+            ]))
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu1(self.bn1(self.conv1(x)))
+        out = self.relu2(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu3(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x, key=x, value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False
+        )
+
+        return x[0]
+    
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
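+    # Illustrative note (not part of the original code): x * sigmoid(1.702 * x) is the
+    # fast sigmoid-based approximation of GELU used by the original CLIP implementation.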
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+        if use_grad_checkpointing:
+            self.attn = checkpoint_wrapper(self.attn)
+            self.mlp = checkpoint_wrapper(self.mlp)
+            
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)])
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.num_features = width
+        self.num_heads = heads
+        self.num_patches = (input_resolution // patch_size) ** 2
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width))
+        self.ln_pre = LayerNorm(width)
+        
+        self.transformer = Transformer(width, layers, heads, use_grad_checkpointing=use_grad_checkpointing)
+           
+#         self.ln_final = LayerNorm(width)
+
+    def forward(self, x: torch.Tensor):
+
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        
+#         x = self.ln_final(x)
+        return x
+
+
+# From PyTorch internals
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+    return parse
+to_2tuple = _ntuple(2)
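+# Illustrative note (not part of the original source): to_2tuple broadcasts a scalar into a
+# pair, e.g. to_2tuple(16) == (16, 16), while an iterable such as (14, 16) passes through
+# unchanged.
+
+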
+def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1):
+    # Rescale the grid of position embeddings when loading from state_dict
+    old_pos_embed = state_dict.get('positional_embedding', None)
+    
+    grid_size = round((model.positional_embedding.shape[0] - 1) ** 0.5)
+    if old_pos_embed is None:
+        return
+    grid_size = to_2tuple(grid_size)
+    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
+    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
+    if new_seq_len == old_pos_embed.shape[0]:
+        return
+
+    if extra_tokens:
+        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
+    else:
+        pos_emb_tok, pos_emb_img = None, old_pos_embed
+        
+    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
+
+    print('Resizing position embedding grid-size from %s to %s' % (old_grid_size, grid_size))
+    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
+    pos_emb_img = F.interpolate(
+        pos_emb_img,
+        size=grid_size,
+        mode=interpolation,
+        align_corners=True,
+    )
+    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
+    if pos_emb_tok is not None:
+        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
+    else:
+        new_pos_embed = pos_emb_img
+    state_dict['positional_embedding'] = new_pos_embed
+    
+    
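+# Illustrative note (not part of the original source): interpolate_pos_embed lets a
+# checkpoint trained at one resolution initialise a model built for another, e.g. the
+# 16 x 16 patch grid of a 224 px / patch-14 checkpoint is bicubically resized to 24 x 24
+# for 336 px inputs, while the leading class-token embedding is carried over unchanged.
+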
+def create_clip_vit_L(img_size=224, use_checkpoint=False, precision="fp16"):
+    model = VisionTransformer(
+        input_resolution=img_size,
+        patch_size=14,
+        width=1024,
+        layers=23,
+        heads=16,
+        use_grad_checkpointing=use_checkpoint,
+    )
+    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth"
+    cached_file = download_cached_file(
+        url, check_hash=False, progress=True
+    )
+    state_dict = torch.load(cached_file, map_location="cpu")
+    interpolate_pos_embed(model, state_dict)
+
+    incompatible_keys = model.load_state_dict(state_dict, strict=False)
+    # print(incompatible_keys)
+
+    if precision == "fp16":
+        convert_weights_to_fp16(model)
+    return model
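+
+
+# Minimal usage sketch (assumed, not part of the original file):
+#
+#   vit = create_clip_vit_L(img_size=224, use_checkpoint=False, precision="fp32")
+#   feats = vit(torch.randn(1, 3, 224, 224))  # expected shape: [1, 257, 1024]
+#
+# where 257 = (224 // 14) ** 2 patch tokens plus one class token and 1024 is the ViT-L width.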
diff --git a/unimernet/models/eva_vit.py b/unimernet/models/eva_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..c495afeb15a6f8be0159cbc4da95e43a3513a33d
--- /dev/null
+++ b/unimernet/models/eva_vit.py
@@ -0,0 +1,448 @@
+# Based on EVA, BEIT, timm and DeiT code bases
+# https://github.com/baaivision/EVA
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/microsoft/unilm/tree/master/beit
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+
+from unimernet.common.dist_utils import download_cached_file
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+        **kwargs
+    }
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return 'p={}'.format(self.drop_prob)
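+    # Illustrative note (not part of the original code): timm's drop_path zeroes the whole
+    # residual branch for a random subset of samples with probability drop_prob during
+    # training (rescaling the survivors by 1 / (1 - drop_prob)) and acts as the identity
+    # in eval mode or when drop_prob is 0.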
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        # x = self.drop(x)
+        # dropout here is commented out to match the original BERT implementation
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
+            proj_drop=0., window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        if attn_head_dim is not None:
+            head_dim = attn_head_dim
+        all_head_dim = head_dim * self.num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+        if qkv_bias:
+            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+        else:
+            self.q_bias = None
+            self.v_bias = None
+
+        if window_size:
+            self.window_size = window_size
+            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+            self.relative_position_bias_table = nn.Parameter(
+                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+            # cls to token & token to cls & cls to cls
+
+            # get pair-wise relative position index for each token inside the window
+            coords_h = torch.arange(window_size[0])
+            coords_w = torch.arange(window_size[1])
+            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+            relative_coords[:, :, 1] += window_size[1] - 1
+            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+            relative_position_index = \
+                torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+            relative_position_index[0, 0:] = self.num_relative_distance - 3
+            relative_position_index[0:, 0] = self.num_relative_distance - 2
+            relative_position_index[0, 0] = self.num_relative_distance - 1
+
+            self.register_buffer("relative_position_index", relative_position_index)
+        else:
+            self.window_size = None
+            self.relative_position_bias_table = None
+            self.relative_position_index = None
+
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(all_head_dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, rel_pos_bias=None):
+        B, N, C = x.shape
+        qkv_bias = None
+        if self.q_bias is not None:
+            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        if self.relative_position_bias_table is not None:
+            relative_position_bias = \
+                self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                    self.window_size[0] * self.window_size[1] + 1,
+                    self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+            attn = attn + relative_position_bias.unsqueeze(0)
+
+        if rel_pos_bias is not None:
+            attn = attn + rel_pos_bias
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 window_size=None, attn_head_dim=None):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if init_values is not None and init_values > 0:
+            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+        else:
+            self.gamma_1, self.gamma_2 = None, None
+
+    def forward(self, x, rel_pos_bias=None):
+        if self.gamma_1 is None:
+            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        else:
+            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x, **kwargs):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
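+    # Illustrative note (not part of the original code): with the EVA-g settings used
+    # further below (img_size=224, patch_size=14, embed_dim=1408), the projection maps a
+    # [B, 3, 224, 224] image to [B, 256, 1408], i.e. a 16 x 16 grid of patch tokens.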
+
+
+class RelativePositionBias(nn.Module):
+
+    def __init__(self, window_size, num_heads):
+        super().__init__()
+        self.window_size = window_size
+        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+        # cls to token & token to cls & cls to cls
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(window_size[0])
+        coords_w = torch.arange(window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+        relative_position_index = \
+            torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        relative_position_index[0, 0:] = self.num_relative_distance - 3
+        relative_position_index[0:, 0] = self.num_relative_distance - 2
+        relative_position_index[0, 0] = self.num_relative_distance - 1
+
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+    def forward(self):
+        relative_position_bias = \
+            self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+                self.window_size[0] * self.window_size[1] + 1,
+                self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
+        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+
+class VisionTransformer(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
+                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
+                 use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
+        super().__init__()
+        self.image_size = img_size
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+
+        self.patch_embed = PatchEmbed(
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        if use_abs_pos_emb:
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        else:
+            self.pos_embed = None
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        if use_shared_rel_pos_bias:
+            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
+        else:
+            self.rel_pos_bias = None
+        self.use_checkpoint = use_checkpoint
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.use_rel_pos_bias = use_rel_pos_bias
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
+            for i in range(depth)])
+        #         self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+        #         self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+        #         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        trunc_normal_(self.cls_token, std=.02)
+        # trunc_normal_(self.mask_token, std=.02)
+        #         if isinstance(self.head, nn.Linear):
+        #             trunc_normal_(self.head.weight, std=.02)
+        self.apply(self._init_weights)
+        self.fix_init_weight()
+
+    #         if isinstance(self.head, nn.Linear):
+    #             self.head.weight.data.mul_(init_scale)
+    #             self.head.bias.data.mul_(init_scale)
+
+    def fix_init_weight(self):
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, rel_pos_bias)
+            else:
+                x = blk(x, rel_pos_bias)
+        return x
+
+    #         x = self.norm(x)
+
+    #         if self.fc_norm is not None:
+    #             t = x[:, 1:, :]
+    #             return self.fc_norm(t.mean(1))
+    #         else:
+    #             return x[:, 0]
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        #         x = self.head(x)
+        return x
+
+    def get_intermediate_layers(self, x):
+        x = self.patch_embed(x)
+        batch_size, seq_len, _ = x.size()
+
+        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        features = []
+        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+        for blk in self.blocks:
+            x = blk(x, rel_pos_bias)
+            features.append(x)
+
+        return features
+
+
+def interpolate_pos_embed(model, checkpoint_model):
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
+
+
+def convert_weights_to_fp16(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.half()
+            if l.bias is not None:
+                l.bias.data = l.bias.data.half()
+
+    #         if isinstance(l, (nn.MultiheadAttention, Attention)):
+    #             for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+    #                 tensor = getattr(l, attr)
+    #                 if tensor is not None:
+    #                     tensor.data = tensor.data.half()
+
+    model.apply(_convert_weights_to_fp16)
+
+
+def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"):
+    model = VisionTransformer(
+        img_size=img_size,
+        patch_size=14,
+        use_mean_pooling=False,
+        embed_dim=1408,
+        depth=39,
+        num_heads=1408 // 88,
+        mlp_ratio=4.3637,
+        qkv_bias=True,
+        drop_path_rate=drop_path_rate,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        use_checkpoint=use_checkpoint,
+    )
+    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
+    cached_file = download_cached_file(
+        url, check_hash=False, progress=True
+    )
+    state_dict = torch.load(cached_file, map_location="cpu")
+    interpolate_pos_embed(model, state_dict)
+
+    incompatible_keys = model.load_state_dict(state_dict, strict=False)
+    #     print(incompatible_keys)
+
+    if precision == "fp16":
+        #         model.to("cuda")
+        convert_weights_to_fp16(model)
+    return model
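+
+
+# Minimal usage sketch (assumed, not part of the original file):
+#
+#   model = create_eva_vit_g(img_size=224, precision="fp32")
+#   feats = model(torch.randn(1, 3, 224, 224))  # expected shape: [1, 257, 1408]
+#
+# where 257 = (224 // 14) ** 2 patch tokens plus the prepended cls token.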
diff --git a/unimernet/models/unimernet/__init__.py b/unimernet/models/unimernet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unimernet/models/unimernet/__pycache__/__init__.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..474a8a8878e82ee7ab895af876e07d469038356d
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/configuration_unimernet_decoder.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/configuration_unimernet_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aef80e0ababefc0707c88033d942ed541903fddd
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/configuration_unimernet_decoder.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/configuration_unimernet_encoder.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/configuration_unimernet_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c69d47a15353280939c43d456713882fa257d69e
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/configuration_unimernet_encoder.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/encoder_decoder.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/encoder_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..386cda1a50aff3a9cc6dc7a0007f44f06148a219
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/encoder_decoder.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/modeling_unimernet_decoder.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/modeling_unimernet_decoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a342676a6daacc8fa1142a0cebce81ebffc567a2
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/modeling_unimernet_decoder.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/modeling_unimernet_encoder.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/modeling_unimernet_encoder.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1326b01ebe093a348d72b6c6bd79b9b9915f3f18
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/modeling_unimernet_encoder.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/processor.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5451cf820f71b4ada76773203ba3db52c27e6800
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/processor.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/__pycache__/unimernet.cpython-310.pyc b/unimernet/models/unimernet/__pycache__/unimernet.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e4fdac683ba4855eaab5174ed1e3798f8f56a59
Binary files /dev/null and b/unimernet/models/unimernet/__pycache__/unimernet.cpython-310.pyc differ
diff --git a/unimernet/models/unimernet/configuration_unimernet_decoder.py b/unimernet/models/unimernet/configuration_unimernet_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfa794fd714f1cb018ec5b5e5d2173c9cd7855c3
--- /dev/null
+++ b/unimernet/models/unimernet/configuration_unimernet_decoder.py
@@ -0,0 +1,387 @@
+# coding=utf-8
+# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MBART model configuration"""
+
+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from transformers import PreTrainedTokenizer
+from transformers.configuration_utils import PretrainedConfig
+from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from transformers.onnx.utils import compute_effective_axis_dimension
+from transformers.utils import TensorType, is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MBartConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`MBartModel`]. It is used to instantiate an MBART
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+    defaults will yield a similar configuration to that of the MBART
+    [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 50265):
+            Vocabulary size of the MBART model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`MBartModel`] or [`TFMBartModel`].
+        d_model (`int`, *optional*, defaults to 1024):
+            Dimensionality of the layers and the pooler layer.
+        encoder_layers (`int`, *optional*, defaults to 12):
+            Number of encoder layers.
+        decoder_layers (`int`, *optional*, defaults to 12):
+            Number of decoder layers.
+        encoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        decoder_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
+        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"silu"` and `"gelu_new"` are supported.
+        dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        classifier_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for classifier.
+        max_position_embeddings (`int`, *optional*, defaults to 1024):
+            The maximum sequence length that this model might ever be used with. Typically set this to something large
+            just in case (e.g., 512 or 1024 or 2048).
+        init_std (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+            for more details.
+        scale_embedding (`bool`, *optional*, defaults to `False`):
+            Scale embeddings by dividing by sqrt(d_model).
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models)
+        forced_eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
+            `eos_token_id`.
+
+    Example:
+
+    ```python
+    >>> from transformers import MBartConfig, MBartModel
+
+    >>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
+    >>> configuration = MBartConfig()
+
+    >>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
+    >>> model = MBartModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "mbart"
+    keys_to_ignore_at_inference = ["past_key_values"]
+    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+    def __init__(
+        self,
+        vocab_size=50265,
+        max_position_embeddings=1024,
+        encoder_layers=12,
+        encoder_ffn_dim=4096,
+        encoder_attention_heads=16,
+        decoder_layers=12,
+        decoder_ffn_dim=4096,
+        decoder_attention_heads=16,
+        encoder_layerdrop=0.0,
+        decoder_layerdrop=0.0,
+        use_cache=True,
+        is_encoder_decoder=True,
+        activation_function="gelu",
+        d_model=1024,
+        dropout=0.1,
+        attention_dropout=0.0,
+        activation_dropout=0.0,
+        init_std=0.02,
+        classifier_dropout=0.0,
+        scale_embedding=False,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        forced_eos_token_id=2,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.d_model = d_model
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.decoder_ffn_dim = decoder_ffn_dim
+        self.decoder_layers = decoder_layers
+        self.decoder_attention_heads = decoder_attention_heads
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_dropout = activation_dropout
+        self.activation_function = activation_function
+        self.init_std = init_std
+        self.encoder_layerdrop = encoder_layerdrop
+        self.decoder_layerdrop = decoder_layerdrop
+        self.classifier_dropout = classifier_dropout
+        self.use_cache = use_cache
+        self.num_hidden_layers = encoder_layers
+        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            is_encoder_decoder=is_encoder_decoder,
+            forced_eos_token_id=forced_eos_token_id,
+            **kwargs,
+        )
+
+
+# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig with Bart->MBart
+class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
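+        # Note (added for clarity): each mapping below names the dynamic axes used for
+        # ONNX export; axis 0 is the batch dimension, axis 1 the encoder/decoder sequence
+        # dimension, and axis 2 the past sequence length for past key/value tensors.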
+        if self.task in ["default", "seq2seq-lm"]:
+            common_inputs = OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+                ]
+            )
+
+            if self.use_past:
+                common_inputs["decoder_input_ids"] = {0: "batch"}
+                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+            else:
+                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+            if self.use_past:
+                self.fill_with_past_key_values_(common_inputs, direction="inputs")
+        elif self.task == "causal-lm":
+            # TODO: figure this case out.
+            common_inputs = OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+                ]
+            )
+            if self.use_past:
+                num_encoder_layers, _ = self.num_layers
+                for i in range(num_encoder_layers):
+                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
+                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+        else:
+            common_inputs = OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
+                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
+                ]
+            )
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task in ["default", "seq2seq-lm"]:
+            common_outputs = super().outputs
+        else:
+            common_outputs = super(OnnxConfigWithPast, self).outputs
+            if self.use_past:
+                num_encoder_layers, _ = self.num_layers
+                for i in range(num_encoder_layers):
+                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
+                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
+        return common_outputs
+
+    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+            tokenizer, batch_size, seq_length, is_pair, framework
+        )
+
+        # Generate decoder inputs
+        decoder_seq_length = seq_length if not self.use_past else 1
+        decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+            tokenizer, batch_size, decoder_seq_length, is_pair, framework
+        )
+        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
+        common_inputs = dict(**encoder_inputs, **decoder_inputs)
+
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+            batch, encoder_seq_length = common_inputs["input_ids"].shape
+            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
+            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
+            encoder_shape = (
+                batch,
+                num_encoder_attention_heads,
+                encoder_seq_length,
+                self._config.hidden_size // num_encoder_attention_heads,
+            )
+            decoder_past_length = decoder_seq_length + 3
+            decoder_shape = (
+                batch,
+                num_decoder_attention_heads,
+                decoder_past_length,
+                self._config.hidden_size // num_decoder_attention_heads,
+            )
+
+            common_inputs["decoder_attention_mask"] = torch.cat(
+                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
+            )
+
+            common_inputs["past_key_values"] = []
+            # If the number of encoder and decoder layers are present in the model configuration, both are considered
+            num_encoder_layers, num_decoder_layers = self.num_layers
+            min_num_layers = min(num_encoder_layers, num_decoder_layers)
+            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
+            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
+
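+            # Each past_key_values entry is a 4-tuple: (self-attn key, self-attn value)
+            # with `decoder_shape`, followed by (cross-attn key, cross-attn value) with
+            # `encoder_shape`.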
+            for _ in range(min_num_layers):
+                common_inputs["past_key_values"].append(
+                    (
+                        torch.zeros(decoder_shape),
+                        torch.zeros(decoder_shape),
+                        torch.zeros(encoder_shape),
+                        torch.zeros(encoder_shape),
+                    )
+                )
+            # TODO: test this.
+            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
+            for _ in range(min_num_layers, max_num_layers):
+                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
+        return common_inputs
+
+    def _generate_dummy_inputs_for_causal_lm(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+            tokenizer, batch_size, seq_length, is_pair, framework
+        )
+
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+            batch, seqlen = common_inputs["input_ids"].shape
+            # Not using the same length for past_key_values
+            past_key_values_length = seqlen + 2
+            num_encoder_layers, _ = self.num_layers
+            num_encoder_attention_heads, _ = self.num_attention_heads
+            past_shape = (
+                batch,
+                num_encoder_attention_heads,
+                past_key_values_length,
+                self._config.hidden_size // num_encoder_attention_heads,
+            )
+
+            mask_dtype = common_inputs["attention_mask"].dtype
+            common_inputs["attention_mask"] = torch.cat(
+                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
+            )
+            common_inputs["past_key_values"] = [
+                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
+            ]
+        return common_inputs
+
+    def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        # Copied from OnnxConfig.generate_dummy_inputs
+        # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
+        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+        batch_size = compute_effective_axis_dimension(
+            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
+        )
+
+        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
+        seq_length = compute_effective_axis_dimension(
+            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
+        )
+
+        # Generate dummy inputs according to compute batch and sequence
+        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
+        common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
+        return common_inputs
+
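+    # Illustrative usage sketch (not part of the original file; `tokenizer` stands in for
+    # any compatible MBart tokenizer):
+    #   onnx_config = MBartOnnxConfig(MBartConfig(), task="default")
+    #   dummy = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)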
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        if self.task in ["default", "seq2seq-lm"]:
+            common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
+                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+            )
+
+        elif self.task == "causal-lm":
+            common_inputs = self._generate_dummy_inputs_for_causal_lm(
+                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+            )
+        else:
+            common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+            )
+
+        return common_inputs
+
+    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
+        if self.task in ["default", "seq2seq-lm"]:
+            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
+        else:
+            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
+                flattened_output, name, idx, t
+            )
diff --git a/unimernet/models/unimernet/configuration_unimernet_encoder.py b/unimernet/models/unimernet/configuration_unimernet_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7bd531d9d4a9b4fd323dee58ec5462d48c17168
--- /dev/null
+++ b/unimernet/models/unimernet/configuration_unimernet_encoder.py
@@ -0,0 +1,132 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Donut Swin Transformer model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class UnimerNetConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`UnimerNetModel`]. It is used to instantiate a
+    Donut model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Donut
+    [naver-clova-ix/donut-base](https://huggingface.co/naver-clova-ix/donut-base) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 4):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        embed_dim (`int`, *optional*, defaults to 96):
+            Dimensionality of patch embedding.
+        depths (`list(int)`, *optional*, defaults to `[2, 2, 6, 2]`):
+            Depth of each layer in the Transformer encoder.
+        num_heads (`list(int)`, *optional*, defaults to `[3, 6, 12, 24]`):
+            Number of attention heads in each layer of the Transformer encoder.
+        window_size (`int`, *optional*, defaults to 7):
+            Size of windows.
+        mlp_ratio (`float`, *optional*, defaults to 4.0):
+            Ratio of MLP hidden dimensionality to embedding dimensionality.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether or not a learnable bias should be added to the queries, keys and values.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings and encoder.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        drop_path_rate (`float`, *optional*, defaults to 0.1):
+            Stochastic depth rate.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
+            `"selu"` and `"gelu_new"` are supported.
+        use_absolute_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether or not to add absolute position embeddings to the patch embeddings.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+
+    Example:
+
+    ```python
+    >>> from transformers import UnimerNetConfig, UnimerNetModel
+
+    >>> # Initializing a Donut naver-clova-ix/donut-base style configuration
+    >>> configuration = UnimerNetConfig()
+
+    >>> # Randomly initializing a model from the naver-clova-ix/donut-base style configuration
+    >>> model = UnimerNetModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "donut-swin"
+
+    attribute_map = {
+        "num_attention_heads": "num_heads",
+        "num_hidden_layers": "num_layers",
+    }
+
+    def __init__(
+        self,
+        image_size=224,
+        patch_size=4,
+        num_channels=3,
+        embed_dim=96,
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 24],
+        window_size=7,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        drop_path_rate=0.1,
+        hidden_act="gelu",
+        use_absolute_embeddings=False,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.embed_dim = embed_dim
+        self.depths = depths
+        self.num_layers = len(depths)
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.drop_path_rate = drop_path_rate
+        self.hidden_act = hidden_act
+        self.use_absolute_embeddings = use_absolute_embeddings
+        self.layer_norm_eps = layer_norm_eps
+        self.initializer_range = initializer_range
+        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
+        # this indicates the channel dimension after the last stage of the model
+        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
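+        # e.g. with the defaults embed_dim=96 and depths=[2, 2, 6, 2] this gives 96 * 2**3 = 768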
diff --git a/unimernet/models/unimernet/encoder_decoder.py b/unimernet/models/unimernet/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cf350afdea954ceeb513474b4d3bc861487682d
--- /dev/null
+++ b/unimernet/models/unimernet/encoder_decoder.py
@@ -0,0 +1,843 @@
+import re
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ftfy import fix_text
+from torch.nn import CrossEntropyLoss
+from typing import Optional, Tuple, Union, List
+from dataclasses import dataclass
+import math
+
+from transformers import PreTrainedTokenizerFast
+from transformers import VisionEncoderDecoderConfig
+from transformers import AutoModel, VisionEncoderDecoderModel, AutoImageProcessor, MBartForCausalLM
+from unimernet.models.unimernet.processor import VariableDonutProcessor, VariableDonutImageProcessor
+# from transformers.models.mbart.modeling_mbart import MBartDecoder
+from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput, CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
+# from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, DonutSwinEncoder
+from transformers.utils import logging, ModelOutput
+
+from functools import partial
+from .configuration_unimernet_encoder import UnimerNetConfig
+
+from .modeling_unimernet_encoder import UnimerNetPatchEmbeddings, UnimerNetEmbeddings, UnimerNetModel, UnimerNetEncoder
+from .modeling_unimernet_decoder import MBartDecoder
+
+logger = logging.get_logger(__name__)
+
+
+class VariableUnimerNetConfig(UnimerNetConfig):
+    pass
+
+
+def build_norm_layer(dim,
+                     norm_layer,):
+    layers = []
+    if norm_layer == 'BN':
+        layers.append(nn.BatchNorm2d(dim))
+    else:
+        raise NotImplementedError(
+            f'build_norm_layer does not support {norm_layer}')
+    return nn.Sequential(*layers)
+
+class StemLayer(nn.Module):
+    r""" Stem layer of InternImage
+    Args:
+        in_chans (int): number of input channels
+        out_chans (int): number of output channels
+        act_layer (str): activation layer
+        norm_layer (str): normalization layer
+    """
+
+    def __init__(self,
+                 in_chans=3,
+                 out_chans=96,
+                 act_layer=nn.GELU,
+                 norm_layer='BN'):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_chans,
+                               out_chans // 2,
+                               kernel_size=3,
+                               stride=2,
+                               padding=1)
+        self.norm1 = build_norm_layer(out_chans // 2, norm_layer)
+        
+        self.act = act_layer()
+        self.conv2 = nn.Conv2d(out_chans // 2,
+                               out_chans,
+                               kernel_size=3,
+                               stride=2,
+                               padding=1)
+
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.act(x)
+        x = self.conv2(x)
+        return x
+    
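+# Illustrative shape check (not part of the original module): the two stride-2
+# convolutions in StemLayer downsample each spatial dimension by 4, e.g.
+#   stem = StemLayer(in_chans=3, out_chans=96)
+#   stem(torch.randn(1, 3, 224, 224)).shape  # -> (1, 96, 56, 56)
+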
+
+class VariableUnimerNetPatchEmbeddings(UnimerNetPatchEmbeddings):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config):
+        print("VariableUnimerNetPatchEmbeddings init")
+        super().__init__(config)
+        num_channels, hidden_size = config.num_channels, config.embed_dim
+        self.projection = StemLayer(in_chans=num_channels, out_chans=hidden_size)
+
+
+class VariableUnimerNetEmbeddings(UnimerNetEmbeddings):
+    """
+    Construct the patch and position embeddings. Optionally, also the mask token.
+    """
+
+    def __init__(self, config, use_mask_token=False):
+        super().__init__(config, use_mask_token)
+
+        self.patch_embeddings = VariableUnimerNetPatchEmbeddings(config)
+        num_patches = self.patch_embeddings.num_patches
+        self.patch_grid = self.patch_embeddings.grid_size
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
+        self.position_embeddings = None
+
+        if config.use_absolute_embeddings:
+            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
+
+        self.row_embeddings = None
+        self.column_embeddings = None
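+        # `use_2d_embeddings` is not a default field of UnimerNetConfig; it is assumed to be
+        # supplied via the loaded config (PretrainedConfig stores unknown kwargs as attributes).
+        # When enabled, learnable row and column position embeddings are added per patch.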
+        if config.use_2d_embeddings:
+            self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim))
+            self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim))
+
+        self.norm = nn.LayerNorm(config.embed_dim)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(
+            self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None, interpolate_pos_encoding: bool = False,
+        ) -> Tuple[torch.Tensor]:
+        # print('before pixel_values.shape',pixel_values.shape)
+
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+
+        # print('after embeddings.shape',embeddings.shape)
+
+        # Layernorm across the last dimension (each patch is a single row)
+        embeddings = self.norm(embeddings)
+        batch_size, seq_len, embed_dim = embeddings.size()
+
+        if bool_masked_pos is not None:
+            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        if self.position_embeddings is not None:
+            embeddings = embeddings + self.position_embeddings[:, :seq_len, :]
+
+        if self.row_embeddings is not None and self.column_embeddings is not None:
+            # Repeat the x position embeddings across the y axis like 0, 1, 2, 3, 0, 1, 2, 3, ...
+            row_embeddings = self.row_embeddings[:, :output_dimensions[0], :].repeat_interleave(output_dimensions[1],
+                                                                                                dim=1)
+            column_embeddings = self.column_embeddings[:, :output_dimensions[1], :].repeat(1, output_dimensions[0], 1)
+
+            embeddings = embeddings + row_embeddings + column_embeddings
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings, output_dimensions
+
+class VariableUnimerNetModel(UnimerNetModel):
+    config_class = VariableUnimerNetConfig
+
+    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
+        print("VariableUnimerNetModel init")
+        super().__init__(config)
+        
+        self.config = config
+        self.num_layers = len(config.depths)
+        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+        self.embeddings = VariableUnimerNetEmbeddings(config, use_mask_token=use_mask_token)
+        self.encoder = UnimerNetEncoder(config, self.embeddings.patch_grid)
+
+        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+
+@dataclass
+class CausalLMOutputWithCrossAttentionsAndCounting(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+    """
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    counting: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class CustomMBartDecoder(MBartDecoder):
+    def __init__(self, config):
+        print("CustomMBartDecoder init")
+        super().__init__(config)
+        hidden_size = config.d_model
+        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
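+        # Token embeddings are rescaled by sqrt(d_model) when `scale_embedding=True`
+        # (see `inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale` below).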
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        count_pred: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                of the decoder.
+            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
+                selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+                cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
+                shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
+                `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
+                control over how to convert `input_ids` indices into associated vectors than the model's internal
+                embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self._use_flash_attention_2:
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+
+        # TODO: add counting context weight to hidden_states
+        if count_pred is not None:
+            count_context_weight = self.counting_context_weight(count_pred)
+            hidden_states = hidden_states + 0.5 * count_context_weight.unsqueeze(1)
+        hidden_states = self.layernorm_embedding(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+class CustomMBartForCausalLM(MBartForCausalLM):
+    def __init__(self, config):
+        print("CustomMBartForCausalLM init")
+        super().__init__(config)
+        # Modify the decoder within MBartDecoderWrapper
+        self.model.decoder = CustomMBartDecoder(config)
+
+    
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        count_gt: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                if the model is configured as a decoder.
+            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+                in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MBartForCausalLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+        >>> model = MBartForCausalLM.from_pretrained("facebook/mbart-large-cc25", add_cross_attention=False)
+        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> logits = outputs.logits
+        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+        >>> list(logits.shape) == expected_shape
+        True
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        
+
+        count_pred = None
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            count_pred=count_pred,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(outputs[0])
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (logits, count_pred) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentionsAndCounting(
+            loss=loss,
+            logits=logits,
+            counting=count_pred,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
+class CustomVisionEncoderDecoderModel(VisionEncoderDecoderModel):
+    def __init__(self, config):
+        print("CustomVisionEncoderDecoderModel init")
+        super().__init__(config)
+        # Replace the MBartForCausalLM with your CustomMBartForCausalLM
+        self.encoder = VariableUnimerNetModel(config.encoder)
+        self.decoder = CustomMBartForCausalLM(self.config.decoder)
+
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        r"""
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoProcessor, VisionEncoderDecoderModel
+        >>> import requests
+        >>> from PIL import Image
+        >>> import torch
+
+        >>> processor = AutoProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+        >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
+
+        >>> # load image from the IAM dataset
+        >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+
+        >>> # training
+        >>> model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
+        >>> model.config.pad_token_id = processor.tokenizer.pad_token_id
+        >>> model.config.vocab_size = model.config.decoder.vocab_size
+
+        >>> pixel_values = processor(image, return_tensors="pt").pixel_values
+        >>> text = "hello world"
+        >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
+        >>> outputs = model(pixel_values=pixel_values, labels=labels)
+        >>> loss = outputs.loss
+
+        >>> # inference (generation)
+        >>> generated_ids = model.generate(pixel_values)
+        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
+
+        kwargs_decoder = {
+            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
+        }
+
+        if encoder_outputs is None:
+            if pixel_values is None:
+                raise ValueError("You have to specify pixel_values")
+
+            encoder_outputs = self.encoder(
+                pixel_values,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+                **kwargs_encoder,
+            )
+        elif isinstance(encoder_outputs, tuple):
+            encoder_outputs = BaseModelOutput(*encoder_outputs)
+
+        encoder_hidden_states = encoder_outputs[0]
+
+        # optionally project encoder_hidden_states
+        if (
+            self.encoder.config.hidden_size != self.decoder.config.hidden_size
+            and self.decoder.config.cross_attention_hidden_size is None
+        ):
+            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
+
+        # The vision encoder does not produce an attention mask for its outputs,
+        # so no encoder attention mask is passed to the decoder.
+        encoder_attention_mask = None
+
+        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
+            decoder_input_ids = shift_tokens_right(
+                labels, self.config.pad_token_id, self.config.decoder_start_token_id
+            )
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            return_dict=return_dict,
+            **kwargs_decoder,
+        )
+
+        # Compute loss independent from decoder (as some shift the logits inside them)
+        loss = None
+        if labels is not None:
+            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
+            count_gt = kwargs_decoder.get("count_gt", None)
+
+        if not return_dict:
+            if loss is not None:
+                return (loss,) + decoder_outputs + encoder_outputs
+            else:
+                return decoder_outputs + encoder_outputs
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=decoder_outputs.logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+    
+
+class SelfAttentionBlock(nn.Module):
+    def __init__(self, embed_size, num_heads):
+        super(SelfAttentionBlock, self).__init__()
+        self.self_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads)
+        self.norm = nn.LayerNorm(embed_size)
+
+    def forward(self, x):
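+        # Note: nn.MultiheadAttention is constructed without batch_first=True, so `x`
+        # is expected in (seq_len, batch, embed_size) layout here.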
+        attn_output, _ = self.self_attention(x, x, x)
+        x = self.norm(attn_output + x)
+        return x
+
+
+class DonutEncoderDecoder(nn.Module):
+
+    def __init__(self, model_name, num_tokens, pad_token_id, bos_token_id, eos_token_id):
+        super().__init__()
+        config = VisionEncoderDecoderConfig.from_pretrained(model_name)
+        encoder_config = vars(config.encoder)
+        encoder = VariableUnimerNetConfig(**encoder_config)
+        config.encoder = encoder
+        self.config = config
+
+        AutoModel.register(VariableUnimerNetConfig, VariableUnimerNetModel)
+
+        self.model = CustomVisionEncoderDecoderModel(config=self.config)
+
+        self.model.config.decoder_start_token_id = bos_token_id
+        self.model.config.pad_token_id = pad_token_id
+        self.model.config.eos_token_id = eos_token_id
+        self.model.decoder.resize_token_embeddings(num_tokens)
+        self.pad_token_id = pad_token_id
+
+    def forward(self, pixel_values, decoder_input_ids, decoder_attention_mask, **kwargs):
+        num_channels = pixel_values.shape[1]
+        if num_channels == 1:
+            pixel_values = pixel_values.repeat(1, 3, 1, 1)
+
+        labels = decoder_input_ids * 1
+        labels = labels.masked_fill(labels == self.pad_token_id, -100)
+
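+        # Teacher forcing: the decoder consumes tokens [:-1] and is trained to predict
+        # tokens [1:]; padding positions were set to -100 above so CrossEntropyLoss
+        # (ignore_index=-100) skips them.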
+        loss = self.model(
+            pixel_values=pixel_values,
+            decoder_input_ids=decoder_input_ids[:, :-1],
+            decoder_attention_mask=decoder_attention_mask[:, :-1],
+            labels=labels[:, 1:],
+            **kwargs
+        ).loss
+        return loss
+
+    @torch.no_grad()
+    def generate(self, pixel_values, temperature, max_new_tokens, decoder_start_token_id, do_sample, top_p,
+                 **kwargs):
+
+        num_channels = pixel_values.shape[1]
+        if num_channels == 1:
+            pixel_values = pixel_values.repeat(1, 3, 1, 1)
+        outputs = self.model.generate(
+            pixel_values=pixel_values,
+            max_new_tokens=max_new_tokens,
+            decoder_start_token_id=decoder_start_token_id,
+            temperature=temperature,
+            do_sample=do_sample,
+            top_p=top_p,
+        )
+        return outputs[:, 1:]
+
+
+
+class DonutTokenizer:
+    def __init__(self, path):
+        AutoImageProcessor.register(VariableUnimerNetConfig, VariableDonutImageProcessor)
+        processor = VariableDonutProcessor.from_pretrained(path)
+        processor.train = False
+        self.tokenizer = processor.tokenizer
+        self.max_seq_len = 2048
+        self.pad_token_id = self.tokenizer.pad_token_id
+        self.bos_token_id = self.tokenizer.bos_token_id
+        self.eos_token_id = self.tokenizer.eos_token_id
+
+    def __len__(self):
+        return len(self.tokenizer)
+
+    def tokenize(self, texts, max_length=None):
+        if not max_length:
+            max_length = self.max_seq_len
+        text_inputs = self.tokenizer(
+            texts,
+            return_token_type_ids=False,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=max_length,
+        )
+        return text_inputs
+
+    @staticmethod
+    def post_process(text):
+        text = fix_text(text)
+        return text
+
+    def token2str(self, tokens) -> list:
+        generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
+        generated_text = [self.post_process(text) for text in generated_text]
+        return generated_text
+
+    def detokenize(self, tokens):
+        toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
+        for b in range(len(toks)):
+            for i in reversed(range(len(toks[b]))):
+                if toks[b][i] is None:
+                    toks[b][i] = ''
+                toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
+                if toks[b][i] in (self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token):
+                    del toks[b][i]
+        return toks
diff --git a/unimernet/models/unimernet/modeling_unimernet_decoder.py b/unimernet/models/unimernet/modeling_unimernet_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f42d46175cb8a228656cc3421a55029f9bd19575
--- /dev/null
+++ b/unimernet/models/unimernet/modeling_unimernet_decoder.py
@@ -0,0 +1,2158 @@
+# coding=utf-8
+# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch MBART model."""
+
+import copy
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+    Seq2SeqQuestionAnsweringModelOutput,
+    Seq2SeqSequenceClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+    add_code_sample_docstrings,
+    add_end_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
+    is_flash_attn_greater_or_equal_2_10,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_unimernet_decoder import MBartConfig
+
+
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "facebook/mbart-large-cc25"
+_CONFIG_FOR_DOC = "MBartConfig"
+
+# Base model docstring
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
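+    # Convert a [batch, seq_len] padding mask into the flattened indices and cumulative
+    # sequence lengths expected by the flash-attn varlen kernels.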
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
+    """
+    Shift input ids one token to the right, and wrap the last non-pad token (the <LID> token). Note that MBart does not
+    have a single `decoder_start_token_id`, in contrast to other Bart-like models.
+    """
+    prev_output_tokens = input_ids.clone()
+
+    if pad_token_id is None:
+        raise ValueError("self.model.config.pad_token_id has to be defined.")
+    # replace possible -100 values in labels by `pad_token_id`
+    prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
+
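+    # The last non-pad token of each row (the language id / eos token) becomes the new
+    # first token; all other tokens are shifted one position to the right.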
+    index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
+    decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
+    prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
+    prev_output_tokens[:, 0] = decoder_start_tokens
+
+    return prev_output_tokens
+
+
+# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MBart
+class MBartLearnedPositionalEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately. Other models don't have this hack
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
+        positions = torch.arange(
+            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+        ).expand(bsz, -1)
+
+        return super().forward(positions + self.offset)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->MBart
+class MBartScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embedding's forward by multiplying the embeddings with the embedding scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale
+
+
+# Adapted from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart, adding a qk_squeeze on the query/key projections
+class MBartSqueezeAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper, with qk_squeeze"""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        qk_squeeze: int = 2,
+        dropout: float = 0.0,
+        is_decoder: bool = False,
+        bias: bool = True,
+        is_causal: bool = False,
+        config: Optional[MBartConfig] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = embed_dim // num_heads
+        self.config = config
+
+        if (self.head_dim * num_heads) != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+                f" and `num_heads`: {num_heads})."
+            )
+        
+        self.squeeze_dim = embed_dim // qk_squeeze
+        self.squeeze_head_dim = self.squeeze_dim // num_heads
+        self.scaling = self.squeeze_head_dim**-0.5
+        self.is_decoder = is_decoder
+        self.is_causal = is_causal
+
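+        # qk_squeeze shrinks only the query/key projections to embed_dim // qk_squeeze,
+        # reducing the cost of the attention-score matmul; the value and output
+        # projections keep the full embed_dim.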
+        self.k_proj = nn.Linear(embed_dim, self.squeeze_dim, bias=bias)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+        self.q_proj = nn.Linear(embed_dim, self.squeeze_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+    def _shape_qk(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.squeeze_head_dim).transpose(1, 2).contiguous()
+
+    def _shape_v(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self.q_proj(hidden_states) * self.scaling
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0]
+            value_states = past_key_value[1]
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._shape_qk(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+        else:
+            # self_attention
+            key_states = self._shape_qk(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._shape_v(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states, value_states)
+
+        proj_shape = (bsz * self.num_heads, -1, self.squeeze_head_dim)
+        value_shape = (bsz * self.num_heads, -1, self.head_dim)
+        query_states = self._shape_qk(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*value_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if layer_head_mask is not None:
+            if layer_head_mask.size() != (self.num_heads,):
+                raise ValueError(
+                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+                    f" {layer_head_mask.size()}"
+                )
+            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped, past_key_value
+
+
+# Adapted from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart, inheriting from MBartSqueezeAttention
+class MBartFlashAttention2(MBartSqueezeAttention):
+    """
+    MBart flash attention module. This module inherits from `MBartSqueezeAttention` as the weights of the module stay
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # MBartFlashAttention2 attention does not support output_attentions
+        if output_attentions:
+            raise ValueError("MBartFlashAttention2 attention does not support output_attentions")
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+        # get key, value proj
+        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+        # is checking that the `sequence_length` of the `past_key_value` is the same as
+        # the provided `key_value_states` to support prefix tuning
+        if (
+            is_cross_attention
+            and past_key_value is not None
+            and past_key_value[0].shape[2] == key_value_states.shape[1]
+        ):
+            # reuse k,v, cross_attentions
+            key_states = past_key_value[0].transpose(1, 2)
+            value_states = past_key_value[1].transpose(1, 2)
+        elif is_cross_attention:
+            # cross_attentions
+            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+        elif past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+        else:
+            # self_attention
+            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+        # therefore the input hidden states get silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+        # to fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        the input is first unpadded, the attention scores are then computed, and the final attention output is re-padded.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+MBART_ATTENTION_CLASSES = {
+    "eager": MBartSqueezeAttention,
+    "flash_attention_2": MBartFlashAttention2,
+}
+
+
+class MBartEncoderLayer(nn.Module):
+    def __init__(self, config: MBartConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.encoder_attention_heads,
+            dropout=config.attention_dropout,
+            config=config,
+        )
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        layer_head_mask: torch.Tensor,
+        output_attentions: bool = False,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+        hidden_states, attn_weights, _ = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        if hidden_states.dtype == torch.float16 and (
+            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+        ):
+            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs
+
+
+class MBartDecoderLayer(nn.Module):
+    def __init__(self, config: MBartConfig):
+        super().__init__()
+        self.embed_dim = config.d_model
+
+        self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+            embed_dim=self.embed_dim,
+            num_heads=config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            is_causal=True,
+            config=config,
+        )
+        self.dropout = config.dropout
+        self.activation_fn = ACT2FN[config.activation_function]
+        self.activation_dropout = config.activation_dropout
+
+        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+            self.embed_dim,
+            config.decoder_attention_heads,
+            dropout=config.attention_dropout,
+            is_decoder=True,
+            config=config,
+        )
+        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = True,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            encoder_hidden_states (`torch.FloatTensor`):
+                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+                `(encoder_attention_heads,)`.
+            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+                size `(decoder_attention_heads,)`.
+            past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn_layer_norm(hidden_states)
+
+        # Self Attention
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            past_key_value=self_attn_past_key_value,
+            attention_mask=attention_mask,
+            layer_head_mask=layer_head_mask,
+            output_attentions=output_attentions,
+        )
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        # Cross-Attention Block
+        cross_attn_present_key_value = None
+        cross_attn_weights = None
+        if encoder_hidden_states is not None:
+            residual = hidden_states
+            hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+            )
+            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+            hidden_states = residual + hidden_states
+
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.activation_fn(self.fc1(hidden_states))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights, cross_attn_weights)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->MBart
+class MBartClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        inner_dim: int,
+        num_classes: int,
+        pooler_dropout: float,
+    ):
+        super().__init__()
+        self.dense = nn.Linear(input_dim, inner_dim)
+        self.dropout = nn.Dropout(p=pooler_dropout)
+        self.out_proj = nn.Linear(inner_dim, num_classes)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = torch.tanh(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        return hidden_states
+
+
+class MBartPreTrainedModel(PreTrainedModel):
+    config_class = MBartConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MBartDecoderLayer", "MBartSqueezeAttention"]
+    _supports_flash_attn_2 = True
+
+    def _init_weights(self, module):
+        std = self.config.init_std
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @property
+    def dummy_inputs(self):
+        pad_token = self.config.pad_token_id
+        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+        dummy_inputs = {
+            "attention_mask": input_ids.ne(pad_token),
+            "input_ids": input_ids,
+        }
+        return dummy_inputs
+
+
+MBART_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`MBartConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+MBART_GENERATION_EXAMPLE = r"""
+    Translation example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, MBartForConditionalGeneration
+
+    >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-en-ro")
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro")
+
+    >>> example_english_phrase = "42 is the answer"
+    >>> inputs = tokenizer(example_english_phrase, return_tensors="pt")
+
+    >>> # Translate
+    >>> generated_ids = model.generate(**inputs, num_beams=4, max_length=5)
+    >>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    '42 este răspuns'
+    ```
+
+    Mask filling example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, MBartForConditionalGeneration
+
+    >>> model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
+    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+
+    >>> # de_DE is the language symbol id <LID> for German
+    >>> TXT = "</s> Meine Freunde sind <mask> nett aber sie essen zu viel Kuchen. </s> de_DE"
+
+    >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt")["input_ids"]
+    >>> logits = model(input_ids).logits
+
+    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+    >>> probs = logits[0, masked_index].softmax(dim=0)
+    >>> values, predictions = probs.topk(5)
+
+    >>> tokenizer.decode(predictions).split()
+    ['nett', 'sehr', 'ganz', 'nicht', 'so']
+    ```
+"""
+
+MBART_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+            it.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Indices of decoder input sequence tokens in the vocabulary.
+
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+            MBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
+            varies according to source and target language, *e.g.* 25004 for *en_XX*, and 25003 for *de_DE*. If
+            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+            `past_key_values`).
+
+            For translation and summarization training, `decoder_input_ids` should be provided. If no
+            `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+            for denoising pre-training following the paper.
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+            be used by default.
+        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
+            1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+            Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
+            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
+            hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
+            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
+            input (see `past_key_values`). This is useful if you want more control over how to convert
+            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
+
+            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
+            of `inputs_embeds`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class MBartEncoder(MBartPreTrainedModel):
+    """
+    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+    [`MBartEncoderLayer`].
+
+    Args:
+        config: MBartConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
+        super().__init__(config)
+
+        self.dropout = config.dropout
+        self.layerdrop = config.encoder_layerdrop
+
+        embed_dim = config.d_model
+        self.padding_idx = config.pad_token_id
+        self.max_source_positions = config.max_position_embeddings
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+
+        self.embed_tokens = MBartScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = MBartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            embed_dim,
+        )
+        self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)])
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self.layernorm_embedding = nn.LayerNorm(embed_dim)
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _backward_compatibility_gradient_checkpointing(self):
+        # Override to not delete the attribute from the config
+        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
+            self.gradient_checkpointing_enable()
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.shape
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        embed_pos = self.embed_positions(input)
+
+        hidden_states = inputs_embeds + embed_pos.to(inputs_embeds.device)
+        hidden_states = self.layernorm_embedding(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        # expand attention_mask
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            if self._use_flash_attention_2:
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+        encoder_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+
+        # check if head_mask has a correct number of layers specified if desired
+        if head_mask is not None:
+            if head_mask.size()[0] != len(self.layers):
+                raise ValueError(
+                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+                    f" {head_mask.size()[0]}."
+                )
+        for idx, encoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+            to_drop = False
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:  # skip the layer
+                    to_drop = True
+
+            if to_drop:
+                layer_outputs = (None, None)
+            else:
+                if self.gradient_checkpointing and self.training:
+                    layer_outputs = self._gradient_checkpointing_func(
+                        encoder_layer.__call__,
+                        hidden_states,
+                        attention_mask,
+                        (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
+                    )
+                else:
+                    layer_outputs = encoder_layer(
+                        hidden_states,
+                        attention_mask,
+                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                        output_attentions=output_attentions,
+                    )
+
+                hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[1],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        if output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+        return BaseModelOutput(
+            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+        )
+
+
+class MBartDecoder(MBartPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`MBartDecoderLayer`]
+
+    Args:
+        config: MBartConfig
+        embed_tokens (nn.Embedding): output embedding
+    """
+
+    def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
+        super().__init__(config)
+        self.dropout = config.dropout
+        self.layerdrop = config.decoder_layerdrop
+        self.padding_idx = config.pad_token_id
+        self.max_target_positions = config.max_position_embeddings
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+
+        self.embed_tokens = MBartScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )
+
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
+
+        self.embed_positions = MBartLearnedPositionalEmbedding(
+            config.max_position_embeddings,
+            config.d_model,
+        )
+        self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)])
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self.layernorm_embedding = nn.LayerNorm(config.d_model)
+        self.layer_norm = nn.LayerNorm(config.d_model)
+
+        self.gradient_checkpointing = False
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                of the decoder.
+            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+                Mask to avoid performing cross-attention on padding token indices of encoder input_ids. Mask values
+                selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+                cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+                than the model's internal embedding lookup matrix.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            input = input_ids
+            input_shape = input.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        else:
+            # 4d mask is passed through the layers
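+            # (illustrative) the resulting additive mask has shape
+            # (batch_size, 1, tgt_seq_len, past_key_values_length + tgt_seq_len); masked positions are set to the
+            # dtype's most negative value so they receive ~zero weight after the attention softmax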
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask, input_shape, inputs_embeds, past_key_values_length
+            )
+
+        # expand encoder attention mask
+        if encoder_hidden_states is not None and encoder_attention_mask is not None:
+            if self._use_flash_attention_2:
+                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+            else:
+                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+                encoder_attention_mask = _prepare_4d_attention_mask(
+                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+                )
+
+        # embed positions
+        positions = self.embed_positions(input, past_key_values_length)
+
+        hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
+        hidden_states = self.layernorm_embedding(hidden_states)
+
+        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
+        next_decoder_cache = () if use_cache else None
+
+        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+            if attn_mask is not None:
+                if attn_mask.size()[0] != len(self.layers):
+                    raise ValueError(
+                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                        f" {attn_mask.size()[0]}."
+                    )
+        for idx, decoder_layer in enumerate(self.layers):
+            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
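+            # during training, the whole decoder layer is skipped with probability `self.layerdrop` (see below)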
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            if self.training:
+                dropout_probability = torch.rand([])
+                if dropout_probability < self.layerdrop:
+                    continue
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    head_mask[idx] if head_mask is not None else None,
+                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                    None,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                    cross_attn_layer_head_mask=(
+                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                    ),
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
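+                # the present key/value tuple sits at index 3 when attention weights are also returned
+                # (self- and cross-attention weights occupy indices 1 and 2), otherwise at index 1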
+                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+                if encoder_hidden_states is not None:
+                    all_cross_attentions += (layer_outputs[2],)
+
+        hidden_states = self.layer_norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
+        )
+
+
+@add_start_docstrings(
+    "The bare MBART Model outputting raw hidden-states without any specific head on top.",
+    MBART_START_DOCSTRING,
+)
+class MBartModel(MBartPreTrainedModel):
+    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+    def __init__(self, config: MBartConfig):
+        super().__init__(config)
+
+        padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+
+        self.encoder = MBartEncoder(config, self.shared)
+        self.decoder = MBartDecoder(config, self.shared)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, value):
+        self.shared = value
+        self.encoder.embed_tokens = self.shared
+        self.decoder.embed_tokens = self.shared
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.get_input_embeddings())
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.get_input_embeddings())
+
+    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=Seq2SeqModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Seq2SeqModelOutput, Tuple[torch.FloatTensor]]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # different to other models, MBart automatically creates decoder_input_ids from
+        # input_ids if no decoder_input_ids are provided
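+        # (illustrative, unpadded case: labels [x1, x2, </s>, lang_id] become decoder inputs
+        # [lang_id, x1, x2, </s>]; MBart's shift_tokens_right rotates the final language id to the
+        # front so it serves as the decoder start token)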
+        if decoder_input_ids is None and decoder_inputs_embeds is None:
+            decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
+
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+            encoder_outputs = BaseModelOutput(
+                last_hidden_state=encoder_outputs[0],
+                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+            )
+
+        # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_hidden_states=encoder_outputs[0],
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if not return_dict:
+            return decoder_outputs + encoder_outputs
+
+        return Seq2SeqModelOutput(
+            last_hidden_state=decoder_outputs.last_hidden_state,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
+            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+            encoder_hidden_states=encoder_outputs.hidden_states,
+            encoder_attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    "The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.",
+    MBART_START_DOCSTRING,
+)
+class MBartForConditionalGeneration(MBartPreTrainedModel):
+    base_model_prefix = "model"
+    _keys_to_ignore_on_load_missing = ["final_logits_bias"]
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"]
+
+    def __init__(self, config: MBartConfig):
+        super().__init__(config)
+        self.model = MBartModel(config)
+        self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+        self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_encoder(self):
+        return self.model.get_encoder()
+
+    def get_decoder(self):
+        return self.model.get_decoder()
+
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+        self._resize_final_logits_bias(new_embeddings.weight.shape[0])
+        return new_embeddings
+
+    def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
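+        # keep `final_logits_bias` in sync with the vocabulary size: truncate when the vocab shrinks,
+        # pad with zeros when it grows (newly added tokens start with a zero bias)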
+        old_num_tokens = self.final_logits_bias.shape[-1]
+        if new_num_tokens <= old_num_tokens:
+            new_bias = self.final_logits_bias[:, :new_num_tokens]
+        else:
+            extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+            new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+        self.register_buffer("final_logits_bias", new_bias)
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
+    @add_end_docstrings(MBART_GENERATION_EXAMPLE)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None:
+            if use_cache:
+                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
+            use_cache = False
+            if decoder_input_ids is None and decoder_inputs_embeds is None:
+                decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            encoder_outputs=encoder_outputs,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        masked_lm_loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + outputs[1:]
+            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=masked_lm_loss,
+            logits=lm_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        head_mask=None,
+        decoder_head_mask=None,
+        cross_attn_head_mask=None,
+        use_cache=None,
+        encoder_outputs=None,
+        **kwargs,
+    ):
+        # cut decoder_input_ids if past is used
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if decoder_input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = decoder_input_ids.shape[1] - 1
+
+            decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
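+            # e.g. (illustrative) with past_length == 5 and decoder_input_ids of shape (batch, 6), only the
+            # newest token (shape (batch, 1)) is fed to the decoder; earlier positions are read from the cache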
+
+        return {
+            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "encoder_outputs": encoder_outputs,
+            "past_key_values": past_key_values,
+            "decoder_input_ids": decoder_input_ids,
+            "attention_mask": attention_mask,
+            "head_mask": head_mask,
+            "decoder_head_mask": decoder_head_mask,
+            "cross_attn_head_mask": cross_attn_head_mask,
+            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+        }
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id)
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            # cached cross_attention states don't have to be reordered -> they are always the same
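+            # layer_past layout: (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value);
+            # only the first two depend on the beam order and are index_select-ed along the batch dimension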
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+                + layer_past[2:],
+            )
+        return reordered_past
+
+
+@add_start_docstrings(
+    """
+    MBart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE
+    tasks.
+    """,
+    MBART_START_DOCSTRING,
+)
+class MBartForSequenceClassification(MBartPreTrainedModel):
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
+
+    def __init__(self, config: MBartConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.model = MBartModel(config)
+        self.classification_head = MBartClassificationHead(
+            config.d_model,
+            config.d_model,
+            config.num_labels,
+            config.classifier_dropout,
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=Seq2SeqSequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    # Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification.forward
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        if input_ids is None and inputs_embeds is not None:
+            raise NotImplementedError(
+                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+            )
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = outputs[0]  # last hidden state
+
+        eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
+
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
+            raise ValueError("All examples must have the same number of <eos> tokens.")
+        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
+            :, -1, :
+        ]
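+        # i.e. the decoder hidden state at each sequence's final <eos> token is used as the pooled
+        # sentence representation, giving a tensor of shape (batch_size, hidden_size)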
+        logits = self.classification_head(sentence_representation)
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return Seq2SeqSequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    MBART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+    """,
+    MBART_START_DOCSTRING,
+)
+class MBartForQuestionAnswering(MBartPreTrainedModel):
+    _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        config.num_labels = 2
+        self.num_labels = config.num_labels
+
+        self.model = MBartModel(config)
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=Seq2SeqQuestionAnsweringModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    # Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering.forward
+    def forward(
+        self,
+        input_ids: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if start_positions is not None and end_positions is not None:
+            use_cache = False
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, splitting the batch can add an extra dimension; squeeze it away
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # sometimes the start/end positions are outside our model inputs, we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
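+            # e.g. (illustrative) a start position of 480 with only 384 sequence positions is clamped to
+            # ignored_index (== sequence length) and then excluded from the loss via ignore_index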
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (
+                start_logits,
+                end_logits,
+            ) + outputs[1:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return Seq2SeqQuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
+        )
+
+
+# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->MBart
+class MBartDecoderWrapper(MBartPreTrainedModel):
+    """
+    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
+    used in combination with the [`EncoderDecoderModel`] framework.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.decoder = MBartDecoder(config)
+
+    def forward(self, *args, **kwargs):
+        return self.decoder(*args, **kwargs)
+
+
+# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25
+class MBartForCausalLM(MBartPreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        config = copy.deepcopy(config)
+        config.is_decoder = True
+        config.is_encoder_decoder = False
+        super().__init__(config)
+        self.model = MBartDecoderWrapper(config)
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.decoder.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.decoder.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model.decoder = decoder
+
+    def get_decoder(self):
+        return self.model.decoder
+
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+        r"""
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                provide it.
+
+                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                [`PreTrainedTokenizer.__call__`] for details.
+
+                [What are input IDs?](../glossary#input-ids)
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+                [What are attention masks?](../glossary#attention-mask)
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                if the model is configured as a decoder.
+            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
+                in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+
+            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+                - 1 indicates the head is **not masked**,
+                - 0 indicates the head is **masked**.
+
+            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            output_hidden_states (`bool`, *optional*):
+                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                for more detail.
+            return_dict (`bool`, *optional*):
+                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, MBartForCausalLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
+        >>> model = MBartForCausalLM.from_pretrained("facebook/mbart-large-cc25", add_cross_attention=False)
+        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> logits = outputs.logits
+        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+        >>> list(logits.shape) == expected_shape
+        True
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model.decoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            head_mask=head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        logits = self.lm_head(outputs[0])
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+    ):
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+
+        if past_key_values:
+            past_length = past_key_values[0][0].shape[2]
+
+            # Some generation methods already pass only the last input ID
+            if input_ids.shape[1] > past_length:
+                remove_prefix_length = past_length
+            else:
+                # Default to old behavior: keep only final ID
+                remove_prefix_length = input_ids.shape[1] - 1
+
+            input_ids = input_ids[:, remove_prefix_length:]
+        # first step, decoder_cached_states are empty
+        return {
+            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
+            "attention_mask": attention_mask,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+        }
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
diff --git a/unimernet/models/unimernet/modeling_unimernet_encoder.py b/unimernet/models/unimernet/modeling_unimernet_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a8653071bdcb7e7460c1ec27a268b28409b2fa7
--- /dev/null
+++ b/unimernet/models/unimernet/modeling_unimernet_encoder.py
@@ -0,0 +1,1035 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch UnimerNet Transformer model.
+
+This implementation is identical to a regular Swin Transformer, except that the final layer norm on top of the final
+hidden states is omitted."""
+
+import collections.abc
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.activations import ACT2FN
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
+from transformers.utils import (
+    ModelOutput,
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    torch_int,
+)
+from .configuration_unimernet_encoder import UnimerNetConfig
+
+
+logger = logging.get_logger(__name__)
+
+# General docstring
+_CONFIG_FOR_DOC = "UnimerNetConfig"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
+_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->UnimerNet
+class UnimerNetEncoderOutput(ModelOutput):
+    """
+    UnimerNet encoder's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->UnimerNet
+class UnimerNetModelOutput(ModelOutput):
+    """
+    UnimerNet model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
+            Average pooling of the last layer hidden-state.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    pooler_output: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+
+
+# Copied from transformers.models.swin.modeling_swin.window_partition
+def window_partition(input_feature, window_size):
+    """
+    Partitions the given input into windows.
+    """
+    batch_size, height, width, num_channels = input_feature.shape
+    input_feature = input_feature.view(
+        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
+    )
+    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
+    return windows
+
+
+# Copied from transformers.models.swin.modeling_swin.window_reverse
+def window_reverse(windows, window_size, height, width):
+    """
+    Merges windows to produce higher resolution features.
+    """
+    num_channels = windows.shape[-1]
+    windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
+    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
+    return windows
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->UnimerNet
+class UnimerNetEmbeddings(nn.Module):
+    """
+    Construct the patch and position embeddings. Optionally, also the mask token.
+    """
+
+    def __init__(self, config, use_mask_token=False):
+        super().__init__()
+
+        self.patch_embeddings = UnimerNetPatchEmbeddings(config)
+        num_patches = self.patch_embeddings.num_patches
+        self.patch_grid = self.patch_embeddings.grid_size
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
+
+        if config.use_absolute_embeddings:
+            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
+        else:
+            self.position_embeddings = None
+
+        self.norm = nn.LayerNorm(config.embed_dim)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
+        images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
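+        # the pre-trained positions form a square grid; reshape, bicubic-interpolate to the new
+        # patch grid, then flatten back and re-attach the class-token position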
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor],
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        interpolate_pos_encoding: bool = False,
+    ) -> Tuple[torch.Tensor]:
+        _, num_channels, height, width = pixel_values.shape
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
+        embeddings = self.norm(embeddings)
+        batch_size, seq_len, _ = embeddings.size()
+
+        if bool_masked_pos is not None:
+            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
+        if self.position_embeddings is not None:
+            if interpolate_pos_encoding:
+                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
+            else:
+                embeddings = embeddings + self.position_embeddings
+
+        embeddings = self.dropout(embeddings)
+
+        return embeddings, output_dimensions
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->UnimerNet
+class UnimerNetPatchEmbeddings(nn.Module):
+    """
+    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
+    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+    Transformer.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        image_size, patch_size = config.image_size, config.patch_size
+        num_channels, hidden_size = config.num_channels, config.embed_dim
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.num_patches = num_patches
+        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
+
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+    def maybe_pad(self, pixel_values, height, width):
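+        # nn.functional.pad consumes pad widths starting from the last dimension, so
+        # (0, pad_w) pads the right edge and (0, 0, 0, pad_h) pads the bottom edge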
+        if width % self.patch_size[1] != 0:
+            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
+            pixel_values = nn.functional.pad(pixel_values, pad_values)
+        if height % self.patch_size[0] != 0:
+            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
+            pixel_values = nn.functional.pad(pixel_values, pad_values)
+        return pixel_values
+
+    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
+        _, num_channels, height, width = pixel_values.shape
+        # pad the input to be divisible by self.patch_size, if needed
+        pixel_values = self.maybe_pad(pixel_values, height, width)
+        embeddings = self.projection(pixel_values)
+        _, _, height, width = embeddings.shape
+        output_dimensions = (height, width)
+        embeddings = embeddings.flatten(2).transpose(1, 2)
+
+        return embeddings, output_dimensions
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
+class UnimerNetPatchMerging(nn.Module):
+    """
+    Patch Merging Layer.
+
+    Args:
+        input_resolution (`Tuple[int]`):
+            Resolution of input feature.
+        dim (`int`):
+            Number of input channels.
+        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
+            Normalization layer class.
+    """
+
+    def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def maybe_pad(self, input_feature, height, width):
+        should_pad = (height % 2 == 1) or (width % 2 == 1)
+        if should_pad:
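+            # input is (batch, height, width, channels): leave channels untouched, pad the right
+            # edge and the bottom edge so that height and width become even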
+            pad_values = (0, 0, 0, width % 2, 0, height % 2)
+            input_feature = nn.functional.pad(input_feature, pad_values)
+
+        return input_feature
+
+    def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor:
+        height, width = input_dimensions
+        # `dim` is height * width
+        batch_size, dim, num_channels = input_feature.shape
+
+        input_feature = input_feature.view(batch_size, height, width, num_channels)
+        # pad input so that its height and width are divisible by 2, if needed
+        input_feature = self.maybe_pad(input_feature, height, width)
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_0 = input_feature[:, 0::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_1 = input_feature[:, 1::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_2 = input_feature[:, 0::2, 1::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_3 = input_feature[:, 1::2, 1::2, :]
+        # batch_size height/2 width/2 4*num_channels
+        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
+        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)  # batch_size height/2*width/2 4*C
+
+        input_feature = self.norm(input_feature)
+        input_feature = self.reduction(input_feature)
+
+        return input_feature
+
+
+# Copied from transformers.models.beit.modeling_beit.drop_path
+def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
+    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
+    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
+    argument.
+    """
+    if drop_prob == 0.0 or not training:
+        return input
+    keep_prob = 1 - drop_prob
+    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
+    random_tensor.floor_()  # binarize
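+    # scale the kept samples by 1 / keep_prob so the expected value of the output matches the input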
+    output = input.div(keep_prob) * random_tensor
+    return output
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinDropPath
+class UnimerNetDropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob: Optional[float] = None) -> None:
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return drop_path(hidden_states, self.drop_prob, self.training)
+
+    def extra_repr(self) -> str:
+        return "p={}".format(self.drop_prob)
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->UnimerNet
+class UnimerNetSelfAttention(nn.Module):
+    def __init__(self, config, dim, num_heads, window_size):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError(
+                f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
+            )
+
+        self.num_attention_heads = num_heads
+        self.attention_head_size = int(dim / num_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+        self.window_size = (
+            window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
+        )
+
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
+        )
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += self.window_size[0] - 1
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)
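+        # (window_size[0] * window_size[1], window_size[0] * window_size[1]) lookup table into the bias parameters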
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
+
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        batch_size, dim, num_channels = hidden_states.shape
+        mixed_query_layer = self.query(hidden_states)
+
+        key_layer = self.transpose_for_scores(self.key(hidden_states))
+        value_layer = self.transpose_for_scores(self.value(hidden_states))
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
+        relative_position_bias = relative_position_bias.view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )
+
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
+        attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
+
+        if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in UnimerNetModel forward() function)
+            mask_shape = attention_mask.shape[0]
+            attention_scores = attention_scores.view(
+                batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
+            )
+            attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
+            attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attention_probs = attention_probs * head_mask
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(new_context_layer_shape)
+
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+        return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
+class UnimerNetSelfOutput(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(dim, dim)
+        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->UnimerNet
+class UnimerNetAttention(nn.Module):
+    def __init__(self, config, dim, num_heads, window_size):
+        super().__init__()
+        self.self = UnimerNetSelfAttention(config, dim, num_heads, window_size)
+        self.output = UnimerNetSelfOutput(config, dim)
+        self.pruned_heads = set()
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+        )
+
+        # Prune linear layers
+        self.self.query = prune_linear_layer(self.self.query, index)
+        self.self.key = prune_linear_layer(self.self.key, index)
+        self.self.value = prune_linear_layer(self.self.value, index)
+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+        # Update hyper params and store pruned heads
+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
+        attention_output = self.output(self_outputs[0], hidden_states)
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        return outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinIntermediate
+class UnimerNetIntermediate(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinOutput
+class UnimerNetOutput(nn.Module):
+    def __init__(self, config, dim):
+        super().__init__()
+        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        return hidden_states
+
+
+class ConvEnhance(nn.Module):
+    """Depth-wise convolution used to inject positional information into the token sequence."""
+
+    def __init__(self, config, dim, k=3):
+        super().__init__()
+        self.proj = nn.Conv2d(dim, dim, kernel_size=(k, k), stride=(1, 1), padding=(k // 2, k // 2), groups=dim)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x, size: Tuple[int, int]):
+        B, N, C = x.shape
+        H, W = size
+        assert N == H * W
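+        # (B, N, C) -> (B, C, H, W) for the depth-wise conv, then back to (B, N, C) with a residual add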
+
+        feat = x.transpose(1, 2).view(B, C, H, W)
+        feat = self.proj(feat)
+        feat = self.act_fn(feat)
+        feat = feat.flatten(2).transpose(1, 2)
+
+        x = x + feat
+        return x
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->UnimerNet
+class UnimerNetLayer(nn.Module):
+    def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
+        super().__init__()
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.shift_size = shift_size
+        self.window_size = config.window_size
+        self.input_resolution = input_resolution
+        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+
+        self.ce = nn.ModuleList([ConvEnhance(config, dim=dim, k=3),
+                                  ConvEnhance(config, dim=dim, k=3)])
+
+        self.attention = UnimerNetAttention(config, dim, num_heads, window_size=self.window_size)
+        self.drop_path = UnimerNetDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
+        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
+        self.intermediate = UnimerNetIntermediate(config, dim)
+        self.output = UnimerNetOutput(config, dim)
+
+    def set_shift_and_window_size(self, input_resolution):
+        if min(input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = torch_int(0)
+            self.window_size = (
+                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
+            )
+
+    def get_attn_mask(self, height, width, dtype, device):
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
+            height_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None),
+            )
+            width_slices = (
+                slice(0, -self.window_size),
+                slice(-self.window_size, -self.shift_size),
+                slice(-self.shift_size, None),
+            )
+            count = 0
+            for height_slice in height_slices:
+                for width_slice in width_slices:
+                    img_mask[:, height_slice, width_slice, :] = count
+                    count += 1
+
+            mask_windows = window_partition(img_mask, self.window_size)
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
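+            # positions from different regions get a -100 logit offset, which zeroes them out after the softmax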
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+        else:
+            attn_mask = None
+        return attn_mask
+
+    def maybe_pad(self, hidden_states, height, width):
+        pad_right = (self.window_size - width % self.window_size) % self.window_size
+        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
+        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
+        hidden_states = nn.functional.pad(hidden_states, pad_values)
+        return hidden_states, pad_values
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_dimensions: Tuple[int, int],
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+        always_partition: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if not always_partition:
+            self.set_shift_and_window_size(input_dimensions)
+        else:
+            pass
+        height, width = input_dimensions
+        batch_size, _, channels = hidden_states.size()
+
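+        # first ConvEnhance: depth-wise conv injects positional information before window attention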
+        hidden_states = self.ce[0](hidden_states, input_dimensions)
+        shortcut = hidden_states
+
+        hidden_states = self.layernorm_before(hidden_states)
+        hidden_states = hidden_states.view(batch_size, height, width, channels)
+
+        # pad hidden_states to multiples of window size
+        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+
+        _, height_pad, width_pad, _ = hidden_states.shape
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        else:
+            shifted_hidden_states = hidden_states
+
+        # partition windows
+        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
+        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
+        attn_mask = self.get_attn_mask(
+            height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
+        )
+
+        attention_outputs = self.attention(
+            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
+        )
+
+        attention_output = attention_outputs[0]
+
+        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
+        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            attention_windows = shifted_windows
+
+        was_padded = pad_values[3] > 0 or pad_values[5] > 0
+        if was_padded:
+            attention_windows = attention_windows[:, :height, :width, :].contiguous()
+
+        attention_windows = attention_windows.view(batch_size, height * width, channels)
+
+        hidden_states = shortcut + self.drop_path(attention_windows)
+
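+        # second ConvEnhance before the MLP (intermediate + output) block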
+        hidden_states = self.ce[1](hidden_states, input_dimensions)
+        layer_output = self.layernorm_after(hidden_states)
+        layer_output = self.intermediate(layer_output)
+        layer_output = hidden_states + self.output(layer_output)
+
+        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+        return layer_outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->UnimerNet
+class UnimerNetStage(nn.Module):
+    def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
+        super().__init__()
+        self.config = config
+        self.dim = dim
+        self.blocks = nn.ModuleList(
+            [
+                UnimerNetLayer(
+                    config=config,
+                    dim=dim,
+                    input_resolution=input_resolution,
+                    num_heads=num_heads,
+                    shift_size=0,
+                )
+                for i in range(depth)
+            ]
+        )
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
+        else:
+            self.downsample = None
+
+        self.pointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_dimensions: Tuple[int, int],
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+        always_partition: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        height, width = input_dimensions
+        for i, layer_module in enumerate(self.blocks):
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            layer_outputs = layer_module(
+                hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+            )
+
+            hidden_states = layer_outputs[0]
+
+        hidden_states_before_downsampling = hidden_states
+        if self.downsample is not None:
+            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
+            output_dimensions = (height, width, height_downsampled, width_downsampled)
+            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
+        else:
+            output_dimensions = (height, width, height, width)
+
+        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
+
+        if output_attentions:
+            stage_outputs += layer_outputs[1:]
+        return stage_outputs
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->UnimerNet
+class UnimerNetEncoder(nn.Module):
+    def __init__(self, config, grid_size):
+        super().__init__()
+        self.num_layers = len(config.depths)
+        self.config = config
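+        # stochastic depth: drop-path rate increases linearly from 0 to config.drop_path_rate across all blocks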
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        self.layers = nn.ModuleList(
+            [
+                UnimerNetStage(
+                    config=config,
+                    dim=int(config.embed_dim * 2**i_layer),
+                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
+                    depth=config.depths[i_layer],
+                    num_heads=config.num_heads[i_layer],
+                    drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
+                    downsample=UnimerNetPatchMerging if (i_layer < self.num_layers - 1) else None,
+                )
+                for i_layer in range(self.num_layers)
+            ]
+        )
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_dimensions: Tuple[int, int],
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
+        always_partition: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, UnimerNetEncoderOutput]:
+        all_hidden_states = () if output_hidden_states else None
+        all_reshaped_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+
+        if output_hidden_states:
+            batch_size, _, hidden_size = hidden_states.shape
+            # rearrange b (h w) c -> b c h w
+            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+            all_hidden_states += (hidden_states,)
+            all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+        for i, layer_module in enumerate(self.layers):
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.__call__,
+                    hidden_states,
+                    input_dimensions,
+                    layer_head_mask,
+                    output_attentions,
+                    always_partition,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
+                )
+
+            hidden_states = layer_outputs[0]
+            hidden_states_before_downsampling = layer_outputs[1]
+            output_dimensions = layer_outputs[2]
+
+            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                # rearrange b (h w) c -> b c h w
+                # here we use the original (not downsampled) height and width
+                reshaped_hidden_state = hidden_states_before_downsampling.view(
+                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                )
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states.shape
+                # rearrange b (h w) c -> b c h w
+                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+            if output_attentions:
+                all_self_attentions += layer_outputs[3:]
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+
+        return UnimerNetEncoderOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            reshaped_hidden_states=all_reshaped_hidden_states,
+        )
+
+
+# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->UnimerNet
+class UnimerNetPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = UnimerNetConfig
+    base_model_prefix = "swin"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["UnimerNetStage"]
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+
+SWIN_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`UnimerNetConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+SWIN_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
+            [`DonutImageProcessor.__call__`] for details.
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
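+# A minimal usage sketch (config values and shapes below are illustrative, not checked defaults):
+#
+#     config = UnimerNetConfig(image_size=192, patch_size=4, num_channels=3, embed_dim=96)
+#     model = UnimerNetModel(config)
+#     pixel_values = torch.randn(1, 3, 192, 192)
+#     outputs = model(pixel_values)
+#     outputs.last_hidden_state  # (batch_size, sequence_length, num_features)
+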
+
+@add_start_docstrings(
+    "The bare UnimerNet Model transformer outputting raw hidden-states without any specific head on top.",
+    SWIN_START_DOCSTRING,
+)
+class UnimerNetModel(UnimerNetPreTrainedModel):
+    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
+        super().__init__(config)
+        self.config = config
+        self.num_layers = len(config.depths)
+        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
+
+        self.embeddings = UnimerNetEmbeddings(config, use_mask_token=use_mask_token)
+        self.encoder = UnimerNetEncoder(config, self.embeddings.patch_grid)
+
+        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.patch_embeddings
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+        class PreTrainedModel
+        """
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layers[layer].attention.prune_heads(heads)
+
+    @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=UnimerNetModelOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="vision",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
+    )
+    def forward(
+        self,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, UnimerNetModelOutput]:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, len(self.config.depths))
+
+        embedding_output, input_dimensions = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            input_dimensions,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = encoder_outputs[0]
+
+        pooled_output = None
+        if self.pooler is not None:
+            pooled_output = self.pooler(sequence_output.transpose(1, 2))
+            pooled_output = torch.flatten(pooled_output, 1)
+
+        if not return_dict:
+            output = (sequence_output, pooled_output) + encoder_outputs[1:]
+
+            return output
+
+        return UnimerNetModelOutput(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
+        )
diff --git a/unimernet/models/unimernet/processor.py b/unimernet/models/unimernet/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dce532e08cb6c671c500c33464a283e2a178afe
--- /dev/null
+++ b/unimernet/models/unimernet/processor.py
@@ -0,0 +1,192 @@
+from typing import Dict, Union, Optional, List
+
+from transformers.utils import TensorType
+from transformers import DonutImageProcessor, DonutProcessor
+from transformers.image_processing_utils import BatchFeature
+from transformers.image_transforms import pad
+from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \
+    valid_images, to_numpy_array, is_scaled_image, get_image_size
+import numpy as np
+import PIL
+import logging
+
+logger = logging.getLogger()
+
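+# ImageNet normalization statistics (RGB order)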
+IMAGE_STD = [0.229, 0.224, 0.225]
+IMAGE_MEAN = [0.485, 0.456, 0.406]
+
+
+class VariableDonutImageProcessor(DonutImageProcessor):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def numpy_resize(self, image: np.ndarray, size, resample):
+        image = PIL.Image.fromarray(image)
+        resized = self.pil_resize(image, size, resample)
+        resized = np.array(resized, dtype=np.uint8)
+        resized_image = resized.transpose(2, 0, 1)
+
+        return resized_image
+
+    def pil_resize(self, image: PIL.Image.Image, size, resample):
+        width, height = image.size
+        max_width, max_height = size["width"], size["height"]
+        if width != max_width or height != max_height:
+            # Shrink to fit within dimensions
+            width_scale = max_width / width
+            height_scale = max_height / height
+            scale = min(width_scale, height_scale)
+
+            new_width = min(int(width * scale), max_width)
+            new_height = min(int(height * scale), max_height)
+
+            image = image.resize((new_width, new_height), resample)
+
+        image.thumbnail((max_width, max_height), resample)
+
+        assert image.width <= max_width and image.height <= max_height
+
+        return image
+
+    def process_inner(self, images: List[List], train=False):
+        # This will be in list of lists format, with height x width x channel
+        assert isinstance(images[0], (list, np.ndarray))
+
+        # convert list of lists format to array
+        if isinstance(images[0], list):
+            # numpy uint8 needed for augmentation
+            np_images = [np.array(img, dtype=np.uint8) for img in images]
+        else:
+            np_images = [img.astype(np.uint8) for img in images]
+
+        assert np_images[0].shape[2] == 3  # RGB input images, channel dim last
+
+        # This also applies the right channel dim format, to channel x height x width
+        np_images = [self.numpy_resize(img, self.max_size, self.resample) for img in np_images]
+        assert np_images[0].shape[0] == 3  # RGB input images, channel dim first
+
+        # Convert to float32 for rescale/normalize
+        np_images = [img.astype(np.float32) for img in np_images]
+
+        # Pads with 255 (whitespace)
+        # Pad to max size to improve performance
+        max_size = self.max_size
+        np_images = [
+            self.pad_image(
+                image=image,
+                size=max_size,
+                random_padding=train,  # Change amount of padding randomly during training
+                input_data_format=ChannelDimension.FIRST,
+                pad_value=255.0
+            )
+            for image in np_images
+        ]
+
+        # Rescale and normalize
+        np_images = [
+            self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
+            for img in np_images
+        ]
+        np_images = [
+            self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
+            for img in np_images
+        ]
+
+        return np_images
+
+    def preprocess(
+            self,
+            images: ImageInput,
+            do_resize: bool = None,
+            size: Dict[str, int] = None,
+            resample: PILImageResampling = None,
+            do_thumbnail: bool = None,
+            do_align_long_axis: bool = None,
+            do_pad: bool = None,
+            random_padding: bool = False,
+            do_rescale: bool = None,
+            rescale_factor: float = None,
+            do_normalize: bool = None,
+            image_mean: Optional[Union[float, List[float]]] = None,
+            image_std: Optional[Union[float, List[float]]] = None,
+            return_tensors: Optional[Union[str, TensorType]] = None,
+            data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+            input_data_format: Optional[Union[str, ChannelDimension]] = None,
+            **kwargs,
+    ) -> BatchFeature:
+        images = make_list_of_images(images)
+
+        if not valid_images(images):
+            raise ValueError(
+                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
+                "torch.Tensor, tf.Tensor or jax.ndarray."
+            )
+
+        # Convert to numpy for later processing steps
+        images = [to_numpy_array(image) for image in images]
+
+        images = self.process_inner(images, train=False)
+
+        data = {"pixel_values": images}
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def pad_image(
+            self,
+            image: np.ndarray,
+            size: Dict[str, int],
+            random_padding: bool = False,
+            data_format: Optional[Union[str, ChannelDimension]] = None,
+            input_data_format: Optional[Union[str, ChannelDimension]] = None,
+            pad_value: float = 0.0,
+    ) -> np.ndarray:
+        output_height, output_width = size["height"], size["width"]
+        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+
+        delta_width = output_width - input_width
+        delta_height = output_height - input_height
+
+        assert delta_width >= 0 and delta_height >= 0
+
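+        # during training the image is placed at a random offset inside the target canvas
+        # (padding acts as augmentation); at inference it is centered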
+        if random_padding:
+            pad_top = np.random.randint(low=0, high=delta_height + 1)
+            pad_left = np.random.randint(low=0, high=delta_width + 1)
+        else:
+            pad_top = delta_height // 2
+            pad_left = delta_width // 2
+
+        pad_bottom = delta_height - pad_top
+        pad_right = delta_width - pad_left
+
+        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
+        return pad(image, padding, data_format=data_format, input_data_format=input_data_format,
+                   constant_values=pad_value)
+
+
+class VariableDonutProcessor(DonutProcessor):
+    def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
+        if image_processor is None:
+            raise ValueError("You need to specify an `image_processor`.")
+        if tokenizer is None:
+            raise ValueError("You need to specify a `tokenizer`.")
+
+        super().__init__(image_processor, tokenizer)
+        self.current_processor = self.image_processor
+        self._in_target_context_manager = False
+        self.train = train
+
+    def __call__(self, *args, **kwargs):
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(*args, **kwargs)
+
+        images = kwargs.pop("images", None)
+        text = kwargs.pop("text", None)
+        if len(args) > 0:
+            images = args[0]
+            args = args[1:]
+
+        if images is None:
+            raise ValueError("You need to specify images to process.")
+
+        inputs = self.image_processor(images, *args, **kwargs)
+        return inputs
diff --git a/unimernet/models/unimernet/unimernet.py b/unimernet/models/unimernet/unimernet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f224ae449aaf99b250b0507401c39a40ab233ddd
--- /dev/null
+++ b/unimernet/models/unimernet/unimernet.py
@@ -0,0 +1,110 @@
+import torch
+import torch.nn.functional as F
+from unimernet.common.registry import registry
+from unimernet.models.blip2_models.blip2 import Blip2Base
+from unimernet.models.unimernet.encoder_decoder import DonutEncoderDecoder, DonutTokenizer
+
+
+@registry.register_model("unimernet")
+class UniMERModel(Blip2Base):
+    """
+    Nougat model for formula recognition.
+    Supported model types:
+        - default
+    Usage:
+        >>> from unimernet.models import load_model
+        >>> model = load_model("unimernet", "default")
+    """
+
+    PRETRAINED_MODEL_CONFIG_DICT = {
+        "default": "configs/models/unimernet_base.yaml",
+        "unimernet": "configs/models/unimernet_base.yaml",
+    }
+
+    def __init__(
+            self,
+            *,
+            model_name,
+            model_config,
+            tokenizer_name,
+            tokenizer_config,
+    ):
+        super().__init__()
+
+        self.tokenizer = DonutTokenizer(tokenizer_config.path)
+        self.model = DonutEncoderDecoder(
+            model_config.model_name,
+            num_tokens=len(self.tokenizer),
+            bos_token_id=self.tokenizer.bos_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+        )
+        self.max_seq_len = model_config.max_seq_len
+        self.tokenizer.max_seq_len = self.max_seq_len
+
+    def forward(self, samples):
+        image, text = samples["image"], samples["text_input"]
+
+        text_inputs = self.tokenizer.tokenize(text).to(image.device)
+        count_gt = self._get_count_gt(text, image.device)
+        tgt_seq, tgt_mask = text_inputs["input_ids"], text_inputs["attention_mask"]
+        with self.maybe_autocast():
+            loss = self.model(
+                pixel_values=image,
+                decoder_input_ids=tgt_seq,
+                decoder_attention_mask=tgt_mask,
+                decoder_count_gt=count_gt,
+            )
+        return {"loss": loss}
+
+    def _get_count_gt(self, text, device):
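+        # bag-of-tokens supervision: per-sample token frequency over the vocabulary,
+        # built from one-hot labels with padding positions masked out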
+        labels = self.tokenizer.tokenize(text, max_length=1536)["input_ids"].to(device)
+        mask = labels != self.tokenizer.pad_token_id
+        one_hot_labels = F.one_hot(labels, num_classes=self.tokenizer.tokenizer.vocab_size) * mask.unsqueeze(-1)
+        count_gt = torch.sum(one_hot_labels, dim=1)
+        return count_gt # (bs, vocab_size)
+
+    @torch.no_grad()
+    def generate(
+            self,
+            samples,
+            temperature: float = 0.2,
+            do_sample: bool = False,
+            top_p: float = 0.95,
+            **kwargs
+    ):
+
+        image = samples["image"]
+        with self.maybe_autocast():
+            outputs = self.model.generate(
+                pixel_values=image,
+                temperature=temperature,
+                max_new_tokens=self.max_seq_len,
+                decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
+                # decoder_end_token_id=self.tokenizer.tokenizer.eos_token_id,
+                do_sample=do_sample,
+                top_p=top_p,
+                **kwargs
+            )
+        pred_tokens = self.tokenizer.detokenize(outputs)
+        pred_str = self.tokenizer.token2str(outputs)
+        return {"pred_tokens": pred_tokens, "pred_str": pred_str, "pred_ids": outputs}
+
+    @classmethod
+    def from_config(cls, cfg):
+
+        model_name = cfg.get("model_name")
+        model_config = cfg.get("model_config")
+        tokenizer_name = cfg.get("tokenizer_name")
+        tokenizer_config = cfg.get("tokenizer_config")
+
+        model = cls(
+            model_name=model_name,
+            model_config=model_config,
+            tokenizer_name=tokenizer_name,
+            tokenizer_config=tokenizer_config
+        )
+
+        model.load_checkpoint_from_config(cfg)
+
+        return model
diff --git a/unimernet/models/unimernet/utils.py b/unimernet/models/unimernet/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce00f4656c00777f868c7740a480987b2d4d8c40
--- /dev/null
+++ b/unimernet/models/unimernet/utils.py
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+
+from . import hybrid
+from . import vit
+from . import transformer
+
+
+class Model(nn.Module):
+    def __init__(self, encoder, decoder, args):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.args = args
+
+    def data_parallel(self, x: torch.Tensor, device_ids, output_device=None, **kwargs):
+        if not device_ids or len(device_ids) == 1:
+            return self(x, **kwargs)
+        if output_device is None:
+            output_device = device_ids[0]
+        replicas = nn.parallel.replicate(self, device_ids)
+        inputs = nn.parallel.scatter(x, device_ids)  # Slices tensors into approximately equal chunks and distributes them across given GPUs.
+        kwargs = nn.parallel.scatter(kwargs, device_ids)  # Duplicates references to objects that are not tensors.
+        replicas = replicas[:len(inputs)]
+        kwargs = kwargs[:len(inputs)]
+        outputs = nn.parallel.parallel_apply(replicas, inputs, kwargs)
+        return nn.parallel.gather(outputs, output_device).mean()
+
+    def forward(self, x: torch.Tensor, tgt_seq: torch.Tensor,  **kwargs):
+        encoded = self.encoder(x)
+        out = self.decoder(tgt_seq, context=encoded, **kwargs)
+        return out
+
+    @torch.no_grad()
+    def generate(self, x: torch.Tensor, temperature: float = 0.25):
+        start_tokens = torch.LongTensor([self.args.bos_token] * len(x))[:, None].to(x.device)
+        return self.decoder.generate(
+            start_tokens,
+            self.args.max_seq_len,
+            eos_token=self.args.eos_token,
+            context=self.encoder(x),
+            temperature=temperature,
+        )
+
+
+def get_model(args):
+    if args.encoder_structure.lower() == 'vit':
+        encoder = vit.get_encoder(args)
+    elif args.encoder_structure.lower() == 'hybrid':
+        encoder = hybrid.get_encoder(args)
+    else:
+        raise NotImplementedError('Encoder structure "%s" not supported.' % args.encoder_structure)
+    decoder = transformer.get_decoder(args)
+    encoder.to(args.device)
+    decoder.to(args.device)
+    model = Model(encoder, decoder, args)
+    if args.wandb:
+        import wandb
+        wandb.watch(model)
+
+    return model
\ No newline at end of file
diff --git a/unimernet/models/vit.py b/unimernet/models/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b2c4de691f98c1d0ee2a921fa0aaf8ccc9cdfb
--- /dev/null
+++ b/unimernet/models/vit.py
@@ -0,0 +1,527 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ 
+ Based on timm code base
+ https://github.com/rwightman/pytorch-image-models/tree/master/timm
+"""
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from timm.models.vision_transformer import _cfg, PatchEmbed
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.helpers import named_apply, adapt_input_conv
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+from unimernet.models.base_model import BaseEncoder
+
+
+class Mlp(nn.Module):
+    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        act_layer=nn.GELU,
+        drop=0.0,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads=8,
+        qkv_bias=False,
+        qk_scale=None,
+        attn_drop=0.0,
+        proj_drop=0.0,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim**-0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.attn_gradients = None
+        self.attention_map = None
+
+    def save_attn_gradients(self, attn_gradients):
+        self.attn_gradients = attn_gradients
+
+    def get_attn_gradients(self):
+        return self.attn_gradients
+
+    def save_attention_map(self, attention_map):
+        self.attention_map = attention_map
+
+    def get_attention_map(self):
+        return self.attention_map
+
+    def forward(self, x, register_hook=False):
+        B, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )  # make torchscript happy (cannot use tensor as tuple)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        if register_hook:
+            self.save_attention_map(attn)
+            attn.register_hook(self.save_attn_gradients)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
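+
+    # When `register_hook=True`, the softmax attention map and its gradient are
+    # cached on this module, so callers can read them back afterwards via
+    # `get_attention_map()` / `get_attn_gradients()` (e.g. for attention
+    # visualisation or Grad-CAM style analysis).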
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+        use_grad_checkpointing=False,
+    ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+        )
+
+        if use_grad_checkpointing:
+            self.attn = checkpoint_wrapper(self.attn)
+            self.mlp = checkpoint_wrapper(self.mlp)
+
+    def forward(self, x, register_hook=False):
+        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class VisionTransformer(nn.Module):
+    """Vision Transformer
+    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
+        https://arxiv.org/abs/2010.11929
+    """
+
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        num_classes=1000,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        representation_size=None,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.0,
+        norm_layer=None,
+        use_grad_checkpointing=False,
+        ckpt_layer=0,
+    ):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            num_classes (int): number of classes for classification head
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+            drop_rate (float): dropout rate
+            attn_drop_rate (float): attention dropout rate
+            drop_path_rate (float): stochastic depth rate
+            norm_layer: (nn.Module): normalization layer
+        """
+        super().__init__()
+        self.num_features = (
+            self.embed_dim
+        ) = embed_dim  # num_features for consistency with other models
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+        self.patch_embed = PatchEmbed(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, depth)
+        ]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    dim=embed_dim,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    qk_scale=qk_scale,
+                    drop=drop_rate,
+                    attn_drop=attn_drop_rate,
+                    drop_path=dpr[i],
+                    norm_layer=norm_layer,
+                    use_grad_checkpointing=(
+                        use_grad_checkpointing and i >= depth - ckpt_layer
+                    ),
+                )
+                for i in range(depth)
+            ]
+        )
+        self.norm = norm_layer(embed_dim)
+
+        trunc_normal_(self.pos_embed, std=0.02)
+        trunc_normal_(self.cls_token, std=0.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {"pos_embed", "cls_token"}
+
+    def forward(self, x, register_blk=-1):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+
+        x = x + self.pos_embed[:, : x.size(1), :]
+        x = self.pos_drop(x)
+
+        for i, blk in enumerate(self.blocks):
+            x = blk(x, register_blk == i)
+        x = self.norm(x)
+
+        return x
+
+    @torch.jit.ignore()
+    def load_pretrained(self, checkpoint_path, prefix=""):
+        _load_weights(self, checkpoint_path, prefix)
+
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""):
+    """Load weights from .npz checkpoints for official Google Brain Flax implementation"""
+    import numpy as np
+
+    def _n2p(w, t=True):
+        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+            w = w.flatten()
+        if t:
+            if w.ndim == 4:
+                w = w.transpose([3, 2, 0, 1])
+            elif w.ndim == 3:
+                w = w.transpose([2, 0, 1])
+            elif w.ndim == 2:
+                w = w.transpose([1, 0])
+        return torch.from_numpy(w)
+
+    w = np.load(checkpoint_path)
+    if not prefix and "opt/target/embedding/kernel" in w:
+        prefix = "opt/target/"
+
+    if hasattr(model.patch_embed, "backbone"):
+        # hybrid
+        backbone = model.patch_embed.backbone
+        stem_only = not hasattr(backbone, "stem")
+        stem = backbone if stem_only else backbone.stem
+        stem.conv.weight.copy_(
+            adapt_input_conv(
+                stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"])
+            )
+        )
+        stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"]))
+        stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"]))
+        if not stem_only:
+            for i, stage in enumerate(backbone.stages):
+                for j, block in enumerate(stage.blocks):
+                    bp = f"{prefix}block{i + 1}/unit{j + 1}/"
+                    for r in range(3):
+                        getattr(block, f"conv{r + 1}").weight.copy_(
+                            _n2p(w[f"{bp}conv{r + 1}/kernel"])
+                        )
+                        getattr(block, f"norm{r + 1}").weight.copy_(
+                            _n2p(w[f"{bp}gn{r + 1}/scale"])
+                        )
+                        getattr(block, f"norm{r + 1}").bias.copy_(
+                            _n2p(w[f"{bp}gn{r + 1}/bias"])
+                        )
+                    if block.downsample is not None:
+                        block.downsample.conv.weight.copy_(
+                            _n2p(w[f"{bp}conv_proj/kernel"])
+                        )
+                        block.downsample.norm.weight.copy_(
+                            _n2p(w[f"{bp}gn_proj/scale"])
+                        )
+                        block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"]))
+        embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"])
+    else:
+        embed_conv_w = adapt_input_conv(
+            model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"])
+        )
+    model.patch_embed.proj.weight.copy_(embed_conv_w)
+    model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"]))
+    model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False))
+    pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False)
+    if pos_embed_w.shape != model.pos_embed.shape:
+        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
+            pos_embed_w,
+            model.pos_embed,
+            getattr(model, "num_tokens", 1),
+            model.patch_embed.grid_size,
+        )
+    model.pos_embed.copy_(pos_embed_w)
+    model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"]))
+    model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"]))
+    #     if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+    #         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+    #         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+    #     if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+    #         model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+    #         model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+    for i, block in enumerate(model.blocks.children()):
+        block_prefix = f"{prefix}Transformer/encoderblock_{i}/"
+        mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/"
+        block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"]))
+        block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"]))
+        block.attn.qkv.weight.copy_(
+            torch.cat(
+                [
+                    _n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T
+                    for n in ("query", "key", "value")
+                ]
+            )
+        )
+        block.attn.qkv.bias.copy_(
+            torch.cat(
+                [
+                    _n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1)
+                    for n in ("query", "key", "value")
+                ]
+            )
+        )
+        block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1))
+        block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"]))
+        for r in range(2):
+            getattr(block.mlp, f"fc{r + 1}").weight.copy_(
+                _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"])
+            )
+            getattr(block.mlp, f"fc{r + 1}").bias.copy_(
+                _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"])
+            )
+        block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"]))
+        block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"]))
+
+
+def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
+    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+    print("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape)
+    ntok_new = posemb_new.shape[1]
+    if num_tokens:
+        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+        ntok_new -= num_tokens
+    else:
+        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+    gs_old = int(math.sqrt(len(posemb_grid)))
+    if not len(gs_new):  # backwards compatibility
+        gs_new = [int(math.sqrt(ntok_new))] * 2
+    assert len(gs_new) >= 2
+    print("Position embedding grid-size from %s to %s", [gs_old, gs_old], gs_new)
+    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(
+        posemb_grid, size=gs_new, mode="bicubic", align_corners=False
+    )
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
+    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+    return posemb
+
+
+def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
+    # interpolate position embedding
+    embedding_size = pos_embed_checkpoint.shape[-1]
+    num_patches = visual_encoder.patch_embed.num_patches
+    num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
+    # height (== width) for the checkpoint position embedding
+    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+    # height (== width) for the new position embedding
+    new_size = int(num_patches**0.5)
+
+    if orig_size != new_size:
+        # class_token and dist_token are kept unchanged
+        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+        # only the position tokens are interpolated
+        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+        pos_tokens = pos_tokens.reshape(
+            -1, orig_size, orig_size, embedding_size
+        ).permute(0, 3, 1, 2)
+        pos_tokens = torch.nn.functional.interpolate(
+            pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
+        )
+        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+        print(
+            "reshape position embedding from %d to %d" % (orig_size**2, new_size**2)
+        )
+
+        return new_pos_embed
+    else:
+        return pos_embed_checkpoint
+
+
+class VisionTransformerEncoder(VisionTransformer, BaseEncoder):
+    @classmethod
+    def from_config(cls, cfg, from_pretrained=False):
+
+        vit_type = cfg.get("vit_type", "base")
+        image_size = cfg.get("image_size", 384)
+        ckpt_layer = cfg.get("vit_ckpt_layer", 0)
+        drop_path_rate = cfg.get("vit_drop_path_rate", 0)
+        norm_layer_eps = cfg.get("vit_layer_norm_epsilon", -1)
+        use_grad_checkpointing = cfg.get("vit_grad_ckpt", False)
+
+        if norm_layer_eps == -1:
+            norm_layer = None
+        else:
+            norm_layer = partial(nn.LayerNorm, eps=norm_layer_eps)
+
+        #     norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        assert vit_type in ["base", "large"], "vit parameter must be base or large"
+        if vit_type == "base":
+            vision_width = 768
+            visual_encoder = cls(
+                img_size=image_size,
+                patch_size=16,
+                embed_dim=vision_width,
+                depth=12,
+                num_heads=12,
+                use_grad_checkpointing=use_grad_checkpointing,
+                ckpt_layer=ckpt_layer,
+                drop_path_rate=0 or drop_path_rate,
+                norm_layer=norm_layer,
+            )
+
+            if from_pretrained:
+                checkpoint = torch.hub.load_state_dict_from_url(
+                    url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
+                    map_location="cpu",
+                    check_hash=True,
+                )
+                state_dict = checkpoint["model"]
+                state_dict["pos_embed"] = interpolate_pos_embed(
+                    state_dict["pos_embed"], visual_encoder
+                )
+                msg = visual_encoder.load_state_dict(state_dict, strict=False)
+
+        elif vit_type == "large":
+            vision_width = 1024
+            visual_encoder = cls(
+                img_size=image_size,
+                patch_size=16,
+                embed_dim=vision_width,
+                depth=24,
+                num_heads=16,
+                use_grad_checkpointing=use_grad_checkpointing,
+                ckpt_layer=ckpt_layer,
+                drop_path_rate=0.1 or drop_path_rate,
+                norm_layer=norm_layer,
+            )
+            if from_pretrained:
+                from timm.models.helpers import load_custom_pretrained
+                from timm.models.vision_transformer import default_cfgs
+
+                load_custom_pretrained(
+                    visual_encoder, default_cfgs["vit_large_patch16_224_in21k"]
+                )
+
+        visual_encoder.vision_width = vision_width
+        return visual_encoder
+
+    def forward_features(self, x, register_blk=-1):
+        return super().forward(x, register_blk)
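+
+
+# Minimal smoke test (a sketch: the `vit_type` / `image_size` values below are
+# illustrative and not tied to any shipped config file); it only checks that the
+# encoder produces the expected token grid.
+if __name__ == "__main__":
+    from omegaconf import OmegaConf
+
+    cfg = OmegaConf.create({"vit_type": "base", "image_size": 224})
+    encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=False)
+    dummy = torch.randn(1, 3, 224, 224)
+    feats = encoder.forward_features(dummy)
+    print(feats.shape)  # 224 // 16 = 14 patches per side, plus cls token -> (1, 197, 768)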
diff --git a/unimernet/processors/__init__.py b/unimernet/processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7548d5004f05b9717da6d87b66649ba0dfacda52
--- /dev/null
+++ b/unimernet/processors/__init__.py
@@ -0,0 +1,40 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from unimernet.processors.base_processor import BaseProcessor
+
+
+from unimernet.processors.blip_processors import (
+    BlipImageTrainProcessor,
+    Blip2ImageTrainProcessor,
+    BlipImageEvalProcessor,
+    BlipCaptionProcessor,
+)
+
+from unimernet.processors.formula_processor import (
+    FormulaImageTrainProcessor,
+    FormulaImageEvalProcessor,
+    FormulaImageMultiScaleTrainProcessor,
+)
+
+from unimernet.common.registry import registry
+
+__all__ = [
+    "BaseProcessor",
+    "BlipCaptionProcessor",
+    "BlipImageTrainProcessor",
+    "Blip2ImageTrainProcessor",
+    "BlipImageEvalProcessor",
+    "FormulaImageTrainProcessor",
+    "FormulaImageEvalProcessor",
+    "FormulaImageMultiScaleTrainProcessor",
+]
+
+
+def load_processor(name, cfg=None):
+    """
+    Example
+
+    >>> processor = load_processor("formula_image_eval", cfg=None)
+    """
+    processor = registry.get_processor_class(name).from_config(cfg)
+
+    return processor
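+
+
+# Usage sketch (the processor name below is registered in this package; the
+# image size is illustrative):
+#
+#   from omegaconf import OmegaConf
+#   eval_processor = load_processor(
+#       "formula_image_eval", OmegaConf.create({"image_size": [192, 672]})
+#   )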
diff --git a/unimernet/processors/__pycache__/__init__.cpython-310.pyc b/unimernet/processors/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b62669811f31a126af7d41f9bb43c41e34b56b87
Binary files /dev/null and b/unimernet/processors/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/processors/__pycache__/base_processor.cpython-310.pyc b/unimernet/processors/__pycache__/base_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4bcb97b26620df45215edeab07fab15192c725e8
Binary files /dev/null and b/unimernet/processors/__pycache__/base_processor.cpython-310.pyc differ
diff --git a/unimernet/processors/__pycache__/blip_processors.cpython-310.pyc b/unimernet/processors/__pycache__/blip_processors.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b7b83f42465f3a3642eafc3afd563e05ee10e40
Binary files /dev/null and b/unimernet/processors/__pycache__/blip_processors.cpython-310.pyc differ
diff --git a/unimernet/processors/__pycache__/formula_processor.cpython-310.pyc b/unimernet/processors/__pycache__/formula_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38eb3428ea69b5ce0192162de3e4ef55c42a1872
Binary files /dev/null and b/unimernet/processors/__pycache__/formula_processor.cpython-310.pyc differ
diff --git a/unimernet/processors/__pycache__/randaugment.cpython-310.pyc b/unimernet/processors/__pycache__/randaugment.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..390161bd14da673b7aa1188926b5d44e3752607b
Binary files /dev/null and b/unimernet/processors/__pycache__/randaugment.cpython-310.pyc differ
diff --git a/unimernet/processors/base_processor.py b/unimernet/processors/base_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4c9d86859270a046623661a632587f2b3136b46
--- /dev/null
+++ b/unimernet/processors/base_processor.py
@@ -0,0 +1,26 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from omegaconf import OmegaConf
+
+
+class BaseProcessor:
+    def __init__(self):
+        self.transform = lambda x: x
+        return
+
+    def __call__(self, item):
+        return self.transform(item)
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        return cls()
+
+    def build(self, **kwargs):
+        cfg = OmegaConf.create(kwargs)
+
+        return self.from_config(cfg)
diff --git a/unimernet/processors/blip_processors.py b/unimernet/processors/blip_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..28d6c4f920a126667bfebf35ef5d6a64b4294fcb
--- /dev/null
+++ b/unimernet/processors/blip_processors.py
@@ -0,0 +1,281 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import re
+
+from unimernet.common.registry import registry
+from unimernet.processors.base_processor import BaseProcessor
+from unimernet.processors.randaugment import RandomAugment
+from omegaconf import OmegaConf
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+
+class BlipImageBaseProcessor(BaseProcessor):
+    def __init__(self, mean=None, std=None):
+        if mean is None:
+            mean = (0.48145466, 0.4578275, 0.40821073)
+        if std is None:
+            std = (0.26862954, 0.26130258, 0.27577711)
+
+        self.normalize = transforms.Normalize(mean, std)
+
+
+@registry.register_processor("blip_caption")
+class BlipCaptionProcessor(BaseProcessor):
+    def __init__(self, prompt="", max_words=50):
+        self.prompt = prompt
+        self.max_words = max_words
+
+    def __call__(self, caption):
+        caption = self.prompt + self.pre_caption(caption)
+
+        return caption
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        prompt = cfg.get("prompt", "")
+        max_words = cfg.get("max_words", 50)
+
+        return cls(prompt=prompt, max_words=max_words)
+
+    def pre_caption(self, caption):
+        caption = re.sub(
+            r"([.!\"()*#:;~])",
+            " ",
+            caption.lower(),
+        )
+        caption = re.sub(
+            r"\s{2,}",
+            " ",
+            caption,
+        )
+        caption = caption.rstrip("\n")
+        caption = caption.strip(" ")
+
+        # truncate caption
+        caption_words = caption.split(" ")
+        if len(caption_words) > self.max_words:
+            caption = " ".join(caption_words[: self.max_words])
+
+        return caption
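+
+    # Illustrative behaviour (hypothetical input): with prompt="a photo of ",
+    # "A DOG!  Running." -> "a photo of a dog running"
+    # (lower-cased, punctuation stripped, repeated whitespace collapsed, then
+    # truncated to `max_words` words).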
+
+@registry.register_processor("blip_caption_instruct")
+class BlipCaptionInstructProcessor(BaseProcessor):
+    def __init__(self, prompt="", max_words=256):
+        self.prompt = prompt
+        self.max_words = max_words
+
+    def __call__(self, caption):
+        caption = self.prompt + self.pre_caption(caption)
+
+        return caption
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        prompt = cfg.get("prompt", "")
+        max_words = cfg.get("max_words", 256)
+
+        return cls(prompt=prompt, max_words=max_words)
+
+    def pre_caption(self, caption):
+        # caption = re.sub(
+        #     r"([.!\"()*#:;~])",
+        #     " ",
+        #     caption.lower(),
+        # )
+        # caption = re.sub(
+        #     r"\s{2,}",
+        #     " ",
+        #     caption,
+        # )
+        caption = caption.rstrip("\n")
+        caption = caption.strip(" ")
+
+        # # truncate caption
+        # caption_words = caption.split(" ")
+        # if len(caption_words) > self.max_words:
+        #     caption = " ".join(caption_words[: self.max_words])
+
+        return caption
+
+
+@registry.register_processor("blip_question")
+class BlipQuestionProcessor(BaseProcessor):
+    def __init__(self, max_words=50):
+        self.max_words = max_words
+
+    def __call__(self, question):
+        return self.pre_question(question)
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        max_words = cfg.get("max_words", 50)
+
+        return cls(max_words=max_words)
+
+    def pre_question(self, question):
+        question = re.sub(
+            r"([.!\"()*#:;~])",
+            "",
+            question.lower(),
+        )
+        question = question.rstrip(" ")
+
+        # truncate question
+        question_words = question.split(" ")
+        if len(question_words) > self.max_words:
+            question = " ".join(question_words[: self.max_words])
+
+        return question
+
+
+@registry.register_processor("blip_image_train")
+class BlipImageTrainProcessor(BlipImageBaseProcessor):
+    def __init__(
+        self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
+    ):
+        super().__init__(mean=mean, std=std)
+
+        self.transform = transforms.Compose(
+            [
+                transforms.RandomResizedCrop(
+                    image_size,
+                    scale=(min_scale, max_scale),
+                    interpolation=InterpolationMode.BICUBIC,
+                ),
+                transforms.RandomHorizontalFlip(),
+                RandomAugment(
+                    2,
+                    5,
+                    isPIL=True,
+                    augs=[
+                        "Identity",
+                        "AutoContrast",
+                        "Brightness",
+                        "Sharpness",
+                        "Equalize",
+                        "ShearX",
+                        "ShearY",
+                        "TranslateX",
+                        "TranslateY",
+                        "Rotate",
+                    ],
+                ),
+                transforms.ToTensor(),
+                self.normalize,
+            ]
+        )
+
+    def __call__(self, item):
+        return self.transform(item)
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        image_size = cfg.get("image_size", 384)
+
+        mean = cfg.get("mean", None)
+        std = cfg.get("std", None)
+
+        min_scale = cfg.get("min_scale", 0.5)
+        max_scale = cfg.get("max_scale", 1.0)
+
+        return cls(
+            image_size=image_size,
+            mean=mean,
+            std=std,
+            min_scale=min_scale,
+            max_scale=max_scale,
+        )
+
+
+@registry.register_processor("blip_image_eval")
+class BlipImageEvalProcessor(BlipImageBaseProcessor):
+    def __init__(self, image_size=384, mean=None, std=None):
+        super().__init__(mean=mean, std=std)
+
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(
+                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
+                ),
+                transforms.ToTensor(),
+                self.normalize,
+            ]
+        )
+
+    def __call__(self, item):
+        return self.transform(item)
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        image_size = cfg.get("image_size", 384)
+
+        mean = cfg.get("mean", None)
+        std = cfg.get("std", None)
+
+        return cls(image_size=image_size, mean=mean, std=std)
+
+
+@registry.register_processor("blip2_image_train")
+class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
+    def __init__(
+        self, image_size=364, mean=None, std=None, min_scale=0.5, max_scale=1.0
+    ):
+        super().__init__(mean=mean, std=std)
+
+        self.transform = transforms.Compose(
+            [
+                transforms.RandomResizedCrop(
+                    image_size,
+                    scale=(min_scale, max_scale),
+                    interpolation=InterpolationMode.BICUBIC,
+                ),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                self.normalize,
+            ]
+        )
+
+    def __call__(self, item):
+        return self.transform(item)
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        image_size = cfg.get("image_size", 364)
+
+        mean = cfg.get("mean", None)
+        std = cfg.get("std", None)
+
+        min_scale = cfg.get("min_scale", 0.5)
+        max_scale = cfg.get("max_scale", 1.0)
+
+        return cls(
+            image_size=image_size,
+            mean=mean,
+            std=std,
+            min_scale=min_scale,
+            max_scale=max_scale,
+        )
\ No newline at end of file
diff --git a/unimernet/processors/formula_processor.py b/unimernet/processors/formula_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..efb12b2f9e5c1cc1f68f82ab008c75f8ccc16fb8
--- /dev/null
+++ b/unimernet/processors/formula_processor.py
@@ -0,0 +1,171 @@
+from unimernet.common.registry import registry
+from omegaconf import OmegaConf
+import albumentations as alb
+from albumentations.pytorch import ToTensorV2
+from unimernet.processors.base_processor import BaseProcessor
+import numpy as np
+import cv2
+from PIL import Image, ImageOps
+from torchvision.transforms.functional import resize
+import random
+from unimernet.processors.formula_processor_helper.nougat import Bitmap, Dilation, Erosion
+from unimernet.processors.formula_processor_helper.weather import Fog, Frost, Snow, Rain, Shadow
+
+
+class FormulaImageBaseProcessor(BaseProcessor):
+
+    def __init__(self, image_size):
+        super(FormulaImageBaseProcessor, self).__init__()
+        self.input_size = [int(_) for _ in image_size]
+        assert len(self.input_size) == 2
+
+    @staticmethod
+    def crop_margin(img: Image.Image) -> Image.Image:
+        data = np.array(img.convert("L"))
+        data = data.astype(np.uint8)
+        max_val = data.max()
+        min_val = data.min()
+        if max_val == min_val:
+            return img
+        data = (data - min_val) / (max_val - min_val) * 255
+        gray = 255 * (data < 200).astype(np.uint8)
+
+        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
+        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
+        return img.crop((a, b, w + a, h + b))
+
+    def prepare_input(self, img: Image.Image, random_padding: bool = False):
+        """
+        Convert PIL Image to tensor according to specified input_size after following steps below:
+            - resize
+            - rotate (if align_long_axis is True and image is not aligned longer axis with canvas)
+            - pad
+        """
+        if img is None:
+            return
+        # crop margins
+        try:
+            img = self.crop_margin(img.convert("RGB"))
+        except OSError:
+            # might throw an error for broken files
+            return
+
+        if img.height == 0 or img.width == 0:
+            return
+
+        img = resize(img, min(self.input_size))
+        img.thumbnail((self.input_size[1], self.input_size[0]))
+        delta_width = self.input_size[1] - img.width
+        delta_height = self.input_size[0] - img.height
+        if random_padding:
+            pad_width = np.random.randint(low=0, high=delta_width + 1)
+            pad_height = np.random.randint(low=0, high=delta_height + 1)
+        else:
+            pad_width = delta_width // 2
+            pad_height = delta_height // 2
+        padding = (
+            pad_width,
+            pad_height,
+            delta_width - pad_width,
+            delta_height - pad_height,
+        )
+        return ImageOps.expand(img, padding)
+
+
+@registry.register_processor("formula_image_train")
+class FormulaImageTrainProcessor(FormulaImageBaseProcessor):
+    def __init__(self, image_size=384):
+        super().__init__(image_size)
+
+        self.transform = alb.Compose(
+            [
+                alb.Compose(
+                    [
+                        Bitmap(p=0.05),
+                        alb.OneOf([Fog(), Frost(), Snow(), Rain(), Shadow()], p=0.2),
+                        alb.OneOf([Erosion((2, 3)), Dilation((2, 3))], p=0.2),
+                        alb.ShiftScaleRotate(shift_limit=0, scale_limit=(-.15, 0), rotate_limit=1, border_mode=0,
+                                             interpolation=3,
+                                             value=[255, 255, 255],
+                                             p=1),
+                        alb.GridDistortion(distort_limit=0.1, border_mode=0, interpolation=3, value=[255, 255, 255],
+                                           p=.5)],
+                    p=.15),
+                # alb.InvertImg(p=.15),
+                alb.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.3),
+                alb.GaussNoise(10, p=.2),
+                alb.RandomBrightnessContrast(.05, (-.2, 0), True, p=0.2),
+                alb.ImageCompression(95, p=.3),
+                alb.ToGray(always_apply=True),
+                alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
+                # alb.Sharpen()
+                ToTensorV2(),
+            ]
+        )
+
+    def __call__(self, item):
+        img = self.prepare_input(item, random_padding=True)
+        if img is None:
+            return img
+        return self.transform(image=np.array(img))['image'][:1]
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        image_size = cfg.get("image_size", [384, 384])
+
+        return cls(
+            image_size=image_size,
+        )
+
+
+@registry.register_processor("formula_image_multi_scale_train")
+class FormulaImageMultiScaleTrainProcessor(FormulaImageTrainProcessor):
+    def __init__(self, all_scales):
+        for i, scales in enumerate(all_scales):
+            all_scales[i] = [int(_) for _ in scales]
+        super(FormulaImageMultiScaleTrainProcessor, self).__init__(all_scales[0])
+        self.all_scales = all_scales
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        all_scales = cfg.get("all_scales", [[384, 384]])
+        return cls(
+            all_scales=all_scales
+        )
+
+    def reset_scale(self):
+        self.input_size = random.choice(self.all_scales)
+
+
+@registry.register_processor("formula_image_eval")
+class FormulaImageEvalProcessor(FormulaImageBaseProcessor):
+    def __init__(self, image_size):
+        super().__init__(image_size)
+
+        self.transform = alb.Compose(
+            [
+                alb.ToGray(always_apply=True),
+                alb.Normalize((0.7931, 0.7931, 0.7931), (0.1738, 0.1738, 0.1738)),
+                # alb.Sharpen()
+                ToTensorV2(),
+            ]
+        )
+
+    def __call__(self, item):
+        image = self.prepare_input(item)
+        return self.transform(image=np.array(image))['image'][:1]
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        image_size = cfg.get("image_size", [384, 384])
+
+        return cls(image_size=image_size)
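+
+
+# Quick smoke test (sketch): a white canvas with a dark box stands in for a cropped
+# formula image; the expected output is a single-channel tensor of shape
+# (1, image_size[0], image_size[1]).
+if __name__ == "__main__":
+    canvas = np.full((60, 200, 3), 255, dtype=np.uint8)
+    canvas[20:40, 50:150] = 0  # fake "ink"
+    processor = FormulaImageEvalProcessor(image_size=[192, 672])
+    tensor = processor(Image.fromarray(canvas))
+    print(tensor.shape)  # torch.Size([1, 192, 672])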
diff --git a/unimernet/processors/formula_processor_helper/__init__.py b/unimernet/processors/formula_processor_helper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unimernet/processors/formula_processor_helper/__pycache__/__init__.cpython-310.pyc b/unimernet/processors/formula_processor_helper/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c43204a142ca2d57e1d4b5a2f933ac4194b55dde
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/processors/formula_processor_helper/__pycache__/nougat.cpython-310.pyc b/unimernet/processors/formula_processor_helper/__pycache__/nougat.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..837d17975ab8a3170a2ad42e3ec9df97e7c5b016
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/__pycache__/nougat.cpython-310.pyc differ
diff --git a/unimernet/processors/formula_processor_helper/__pycache__/ops.cpython-310.pyc b/unimernet/processors/formula_processor_helper/__pycache__/ops.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bcf65d49a90b75e9663617a33efa0963bbe42748
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/__pycache__/ops.cpython-310.pyc differ
diff --git a/unimernet/processors/formula_processor_helper/__pycache__/weather.cpython-310.pyc b/unimernet/processors/formula_processor_helper/__pycache__/weather.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae350972ba7f2d656db36bdcf258b53bf6e5e226
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/__pycache__/weather.cpython-310.pyc differ
diff --git a/unimernet/processors/formula_processor_helper/frost/frost1.png b/unimernet/processors/formula_processor_helper/frost/frost1.png
new file mode 100644
index 0000000000000000000000000000000000000000..c9edf9b6e1a2744d15af615af641f2aa48aa89c2
--- /dev/null
+++ b/unimernet/processors/formula_processor_helper/frost/frost1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ff9f907860bd7a835d459e32f9d588062b7f61ee267343cc7222b56753a14755
+size 1199930
diff --git a/unimernet/processors/formula_processor_helper/frost/frost2.png b/unimernet/processors/formula_processor_helper/frost/frost2.png
new file mode 100644
index 0000000000000000000000000000000000000000..48f7a861ffa41b6d7496b701fef96d5edf739282
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/frost/frost2.png differ
diff --git a/unimernet/processors/formula_processor_helper/frost/frost3.png b/unimernet/processors/formula_processor_helper/frost/frost3.png
new file mode 100644
index 0000000000000000000000000000000000000000..d47f9d25f41251ee9a66b294c0bbfad6053017c9
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/frost/frost3.png differ
diff --git a/unimernet/processors/formula_processor_helper/frost/frost4.jpg b/unimernet/processors/formula_processor_helper/frost/frost4.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f8b0c413176d70150b593e029d84b4a88c21dd4b
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/frost/frost4.jpg differ
diff --git a/unimernet/processors/formula_processor_helper/frost/frost5.jpg b/unimernet/processors/formula_processor_helper/frost/frost5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..95dc9056926d8201df760535f9bb9112f012e862
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/frost/frost5.jpg differ
diff --git a/unimernet/processors/formula_processor_helper/frost/frost6.jpg b/unimernet/processors/formula_processor_helper/frost/frost6.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..14e5d58e762a5d0808df9fa6494fd6d78ee4409b
Binary files /dev/null and b/unimernet/processors/formula_processor_helper/frost/frost6.jpg differ
diff --git a/unimernet/processors/formula_processor_helper/nougat.py b/unimernet/processors/formula_processor_helper/nougat.py
new file mode 100644
index 0000000000000000000000000000000000000000..c51b99f3233b1b5254e423035b4a2a07816be5a3
--- /dev/null
+++ b/unimernet/processors/formula_processor_helper/nougat.py
@@ -0,0 +1,98 @@
+import albumentations as alb
+import numpy as np
+import cv2
+
+
+class Erosion(alb.ImageOnlyTransform):
+    """
+    Apply erosion operation to an image.
+
+    Erosion is a morphological operation that shrinks the white regions in a binary image.
+
+    Args:
+        scale (int or tuple/list of int): The scale or range for the size of the erosion kernel.
+            If an integer is provided, a square kernel of that size will be used.
+            If a tuple or list is provided, it should contain two integers representing the minimum
+            and maximum sizes for the erosion kernel.
+        always_apply (bool, optional): Whether to always apply this transformation. Default is False.
+        p (float, optional): The probability of applying this transformation. Default is 0.5.
+
+    Returns:
+        numpy.ndarray: The transformed image.
+    """
+
+    def __init__(self, scale, always_apply=False, p=0.5):
+        super().__init__(always_apply=always_apply, p=p)
+        if type(scale) is tuple or type(scale) is list:
+            assert len(scale) == 2
+            self.scale = scale
+        else:
+            self.scale = (scale, scale)
+
+    def apply(self, img, **params):
+        kernel = cv2.getStructuringElement(
+            cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
+        )
+        img = cv2.erode(img, kernel, iterations=1)
+        return img
+
+
+class Dilation(alb.ImageOnlyTransform):
+    """
+    Apply dilation operation to an image.
+
+    Dilation is a morphological operation that expands the white regions in a binary image.
+
+    Args:
+        scale (int or tuple/list of int): The scale or range for the size of the dilation kernel.
+            If an integer is provided, a square kernel of that size will be used.
+            If a tuple or list is provided, it should contain two integers representing the minimum
+            and maximum sizes for the dilation kernel.
+        always_apply (bool, optional): Whether to always apply this transformation. Default is False.
+        p (float, optional): The probability of applying this transformation. Default is 0.5.
+
+    Returns:
+        numpy.ndarray: The transformed image.
+    """
+
+    def __init__(self, scale, always_apply=False, p=0.5):
+        super().__init__(always_apply=always_apply, p=p)
+        if type(scale) is tuple or type(scale) is list:
+            assert len(scale) == 2
+            self.scale = scale
+        else:
+            self.scale = (scale, scale)
+
+    def apply(self, img, **params):
+        kernel = cv2.getStructuringElement(
+            cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
+        )
+        img = cv2.dilate(img, kernel, iterations=1)
+        return img
+
+
+class Bitmap(alb.ImageOnlyTransform):
+    """
+    Apply a bitmap-style transformation to an image.
+
+    This transformation replaces all pixel values below a certain threshold with a specified value.
+
+    Args:
+        value (int, optional): The value to replace pixels below the threshold with. Default is 0.
+        lower (int, optional): The threshold value below which pixels will be replaced. Default is 200.
+        always_apply (bool, optional): Whether to always apply this transformation. Default is False.
+        p (float, optional): The probability of applying this transformation. Default is 0.5.
+
+    Returns:
+        numpy.ndarray: The transformed image.
+    """
+
+    def __init__(self, value=0, lower=200, always_apply=False, p=0.5):
+        super().__init__(always_apply=always_apply, p=p)
+        self.lower = lower
+        self.value = value
+
+    def apply(self, img, **params):
+        img = img.copy()
+        img[img < self.lower] = self.value
+        return img
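+
+
+# Tiny usage sketch: these are standard albumentations ImageOnlyTransforms, so they
+# are called with an `image=` keyword; the array below is an illustrative grey
+# "page" with one dark stroke.
+if __name__ == "__main__":
+    page = np.full((32, 32), 230, dtype=np.uint8)
+    page[8:12, 4:28] = 90  # dark stroke
+    out = Bitmap(lower=200, value=0, p=1.0)(image=page)["image"]
+    out = Erosion(scale=(2, 3), p=1.0)(image=out)["image"]
+    print(out.shape, int(out.min()), int(out.max()))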
diff --git a/unimernet/processors/formula_processor_helper/ops.py b/unimernet/processors/formula_processor_helper/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..b84016300e9237b87d2ce1d82e3f777f1599fe23
--- /dev/null
+++ b/unimernet/processors/formula_processor_helper/ops.py
@@ -0,0 +1,88 @@
+"""
+Common image operations
+
+Reference: https://github.com/hendrycks/robustness
+Hacked together for STR by: Rowel Atienza
+"""
+
+import cv2
+import numpy as np
+from scipy.ndimage import zoom as scizoom
+
+
+def clipped_zoom(img, zoom_factor):
+    h = img.shape[1]
+    # ceil crop height(= crop width)
+    ch = int(np.ceil(h / float(zoom_factor)))
+
+    top = (h - ch) // 2
+    img = scizoom(img[top:top + ch, top:top + ch], (zoom_factor, zoom_factor, 1), order=1)
+    # trim off any extra pixels
+    trim_top = (img.shape[0] - h) // 2
+
+    return img[trim_top:trim_top + h, trim_top:trim_top + h]
+
+
+def disk(radius, alias_blur=0.1, dtype=np.float32):
+    if radius <= 8:
+        coords = np.arange(-8, 8 + 1)
+        ksize = (3, 3)
+    else:
+        coords = np.arange(-radius, radius + 1)
+        ksize = (5, 5)
+    x, y = np.meshgrid(coords, coords)
+    aliased_disk = np.asarray((x ** 2 + y ** 2) <= radius ** 2, dtype=dtype)
+    aliased_disk /= np.sum(aliased_disk)
+
+    # supersample disk to antialias
+    return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)
+
+
+# modification of https://github.com/FLHerne/mapgen/blob/master/diamondsquare.py
+def plasma_fractal(mapsize=256, wibbledecay=3, rng=None):
+    """
+    Generate a heightmap using diamond-square algorithm.
+    Return square 2d array, side length 'mapsize', of floats in range 0-255.
+    'mapsize' must be a power of two.
+    """
+    assert (mapsize & (mapsize - 1) == 0)
+    maparray = np.empty((mapsize, mapsize), dtype=np.float64)
+    maparray[0, 0] = 0
+    stepsize = mapsize
+    wibble = 100
+    if rng is None:
+        rng = np.random.default_rng()
+
+    def wibbledmean(array):
+        return array / 4 + wibble * rng.uniform(-wibble, wibble, array.shape)
+
+    def fillsquares():
+        """For each square of points stepsize apart,
+           calculate middle value as mean of points + wibble"""
+        cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
+        squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
+        squareaccum += np.roll(squareaccum, shift=-1, axis=1)
+        maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(squareaccum)
+
+    def filldiamonds():
+        """For each diamond of points stepsize apart,
+           calculate middle value as mean of points + wibble"""
+        drgrid = maparray[stepsize // 2:mapsize:stepsize, stepsize // 2:mapsize:stepsize]
+        ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
+        ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
+        lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
+        ltsum = ldrsum + lulsum
+        maparray[0:mapsize:stepsize, stepsize // 2:mapsize:stepsize] = wibbledmean(ltsum)
+        tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
+        tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
+        ttsum = tdrsum + tulsum
+        maparray[stepsize // 2:mapsize:stepsize, 0:mapsize:stepsize] = wibbledmean(ttsum)
+
+    while stepsize >= 2:
+        fillsquares()
+        filldiamonds()
+        stepsize //= 2
+        wibble /= wibbledecay
+
+    maparray -= maparray.min()
+    return maparray / maparray.max()
\ No newline at end of file
diff --git a/unimernet/processors/formula_processor_helper/weather.py b/unimernet/processors/formula_processor_helper/weather.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa5233958ad3f076df2d07b2c65b6e1f5cd29d7a
--- /dev/null
+++ b/unimernet/processors/formula_processor_helper/weather.py
@@ -0,0 +1,245 @@
+import math
+from io import BytesIO
+
+import cv2
+import numpy as np
+from PIL import Image, ImageOps, ImageDraw
+from pkg_resources import resource_filename
+from wand.image import Image as WandImage
+import albumentations as alb
+
+from .ops import plasma_fractal
+
+
+class Fog(alb.ImageOnlyTransform):
+    def __init__(self, mag=-1, always_apply=False, p=1.):
+        super().__init__(always_apply=always_apply, p=p)
+        self.rng = np.random.default_rng()
+        self.mag = mag
+
+    def apply(self, img, **params):
+        img = Image.fromarray(img.astype(np.uint8))
+        w, h = img.size
+        c = [(1.5, 2), (2., 2), (2.5, 1.7)]
+        if self.mag < 0 or self.mag >= len(c):
+            index = self.rng.integers(0, len(c))
+        else:
+            index = self.mag
+        c = c[index]
+
+        n_channels = len(img.getbands())
+        isgray = n_channels == 1
+
+        img = np.asarray(img) / 255.
+        max_val = img.max()
+        # Make sure fog image is at least twice the size of the input image
+        max_size = 2 ** math.ceil(math.log2(max(w, h)) + 1)
+        fog = c[0] * plasma_fractal(mapsize=max_size, wibbledecay=c[1], rng=self.rng)[:h, :w][..., np.newaxis]
+        # x += c[0] * plasma_fractal(wibbledecay=c[1])[:224, :224][..., np.newaxis]
+        # return np.clip(x * max_val / (max_val + c[0]), 0, 1) * 255
+        if isgray:
+            fog = np.squeeze(fog)
+        else:
+            fog = np.repeat(fog, 3, axis=2)
+
+        img += fog
+        img = np.clip(img * max_val / (max_val + c[0]), 0, 1) * 255
+        return img.astype(np.uint8)
+
+
+class Frost(alb.ImageOnlyTransform):
+    def __init__(self, mag=-1, always_apply=False, p=1.):
+        super().__init__(always_apply=always_apply, p=p)
+        self.rng = np.random.default_rng()
+        self.mag = mag
+
+    def apply(self, img, **params):
+        img = Image.fromarray(img.astype(np.uint8))
+        w, h = img.size
+        c = [(0.78, 0.22), (0.64, 0.36), (0.5, 0.5)]
+        if self.mag < 0 or self.mag >= len(c):
+            index = self.rng.integers(0, len(c))
+        else:
+            index = self.mag
+        c = c[index]
+
+        filename = [resource_filename(__name__, 'frost/frost1.png'),
+                    resource_filename(__name__, 'frost/frost2.png'),
+                    resource_filename(__name__, 'frost/frost3.png'),
+                    resource_filename(__name__, 'frost/frost4.jpg'),
+                    resource_filename(__name__, 'frost/frost5.jpg'),
+                    resource_filename(__name__, 'frost/frost6.jpg')]
+        index = self.rng.integers(0, len(filename))
+        filename = filename[index]
+        # Some images have transparency. Remove alpha channel.
+        frost = Image.open(filename).convert('RGB')
+
+        # Resize the frost image to match the input image's dimensions
+        f_w, f_h = frost.size
+        if w / h > f_w / f_h:
+            f_h = round(f_h * w / f_w)
+            f_w = w
+        else:
+            f_w = round(f_w * h / f_h)
+            f_h = h
+        frost = np.asarray(frost.resize((f_w, f_h)))
+
+        # randomly crop
+        y_start, x_start = self.rng.integers(0, f_h - h + 1), self.rng.integers(0, f_w - w + 1)
+        frost = frost[y_start:y_start + h, x_start:x_start + w]
+
+        n_channels = len(img.getbands())
+        isgray = n_channels == 1
+
+        img = np.asarray(img)
+
+        if isgray:
+            img = np.expand_dims(img, axis=2)
+            img = np.repeat(img, 3, axis=2)
+
+        img = np.clip(np.round(c[0] * img + c[1] * frost), 0, 255)
+        img = img.astype(np.uint8)
+        if isgray:
+            img = np.squeeze(img)
+        return img
+
+
+class Snow(alb.ImageOnlyTransform):
+    def __init__(self, mag=-1, always_apply=False, p=1.):
+        super().__init__(always_apply=always_apply, p=p)
+        self.rng = np.random.default_rng()
+        self.mag = mag
+
+    def apply(self, img, **params):
+        img = Image.fromarray(img.astype(np.uint8))
+        w, h = img.size
+        c = [(0.1, 0.3, 3, 0.5, 10, 4, 0.8),
+             (0.2, 0.3, 2, 0.5, 12, 4, 0.7),
+             (0.55, 0.3, 4, 0.9, 12, 8, 0.7)]
+        if self.mag < 0 or self.mag >= len(c):
+            index = self.rng.integers(0, len(c))
+        else:
+            index = self.mag
+        c = c[index]
+
+        n_channels = len(img.getbands())
+        isgray = n_channels == 1
+
+        img = np.asarray(img, dtype=np.float32) / 255.
+        if isgray:
+            img = np.expand_dims(img, axis=2)
+            img = np.repeat(img, 3, axis=2)
+
+        snow_layer = self.rng.normal(size=img.shape[:2], loc=c[0], scale=c[1])  # [:2] for monochrome
+
+        # snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
+        snow_layer[snow_layer < c[3]] = 0
+
+        snow_layer = Image.fromarray((np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode='L')
+        output = BytesIO()
+        snow_layer.save(output, format='PNG')
+        snow_layer = WandImage(blob=output.getvalue())
+
+        snow_layer.motion_blur(radius=c[4], sigma=c[5], angle=self.rng.uniform(-135, -45))
+
+        snow_layer = cv2.imdecode(np.frombuffer(snow_layer.make_blob(), np.uint8),
+                                  cv2.IMREAD_UNCHANGED) / 255.
+
+        # snow_layer = cv2.cvtColor(snow_layer, cv2.COLOR_BGR2RGB)
+
+        snow_layer = snow_layer[..., np.newaxis]
+
+        img = c[6] * img
+        gray_img = (1 - c[6]) * np.maximum(img, cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).reshape(h, w, 1) * 1.5 + 0.5)
+        img += gray_img
+        img = np.clip(img + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255
+        img = img.astype(np.uint8)
+        if isgray:
+            img = np.squeeze(img)
+        return img
+
+
+class Rain(alb.ImageOnlyTransform):
+    def __init__(self, mag=-1, always_apply=False, p=1.):
+        super().__init__(always_apply=always_apply, p=p)
+        self.rng = np.random.default_rng()
+        self.mag = mag
+
+    def apply(self, img, **params):
+        img = Image.fromarray(img.astype(np.uint8))
+        img = img.copy()
+        w, h = img.size
+        n_channels = len(img.getbands())
+        isgray = n_channels == 1
+        line_width = self.rng.integers(1, 2)
+
+        c = [50, 70, 90]
+        if self.mag < 0 or self.mag >= len(c):
+            index = 0
+        else:
+            index = self.mag
+        c = c[index]
+
+        n_rains = self.rng.integers(c, c + 20)
+        slant = self.rng.integers(-60, 60)
+        fillcolor = 200 if isgray else (200, 200, 200)
+
+        draw = ImageDraw.Draw(img)
+        max_length = min(w, h, 10)
+        for i in range(1, n_rains):
+            length = self.rng.integers(5, max_length)
+            x1 = self.rng.integers(0, w - length)
+            y1 = self.rng.integers(0, h - length)
+            x2 = x1 + length * math.sin(slant * math.pi / 180.)
+            y2 = y1 + length * math.cos(slant * math.pi / 180.)
+            x2 = int(x2)
+            y2 = int(y2)
+            draw.line([(x1, y1), (x2, y2)], width=line_width, fill=fillcolor)
+        img = np.asarray(img).astype(np.uint8)
+        return img
+
+
+class Shadow(alb.ImageOnlyTransform):
+    def __init__(self, mag=-1, always_apply=False, p=1.):
+        super().__init__(always_apply=always_apply, p=p)
+        self.rng = np.random.default_rng()
+        self.mag = mag
+
+    def apply(self, img, **params):
+        img = Image.fromarray(img.astype(np.uint8))
+        # img = img.copy()
+        w, h = img.size
+        n_channels = len(img.getbands())
+        isgray = n_channels == 1
+
+        c = [64, 96, 128]
+        if self.mag < 0 or self.mag >= len(c):
+            index = 0
+        else:
+            index = self.mag
+        c = c[index]
+
+        img = img.convert('RGBA')
+        overlay = Image.new('RGBA', img.size, (255, 255, 255, 0))
+        draw = ImageDraw.Draw(overlay)
+        transparency = self.rng.integers(c, c + 32)
+        x1 = self.rng.integers(0, w // 2)
+        y1 = 0
+
+        x2 = self.rng.integers(w // 2, w)
+        y2 = 0
+
+        x3 = self.rng.integers(w // 2, w)
+        y3 = h - 1
+
+        x4 = self.rng.integers(0, w // 2)
+        y4 = h - 1
+
+        draw.polygon([(x1, y1), (x2, y2), (x3, y3), (x4, y4)], fill=(0, 0, 0, transparency))
+
+        img = Image.alpha_composite(img, overlay)
+        img = img.convert("RGB")
+        if isgray:
+            img = ImageOps.grayscale(img)
+        img = np.asarray(img).astype(np.uint8)
+        return img
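+
+
+# Illustrative check for the PIL/NumPy-only corruptions (Frost needs the bundled
+# frost images and Snow needs Wand/ImageMagick, so both are left out of this sketch).
+if __name__ == "__main__":
+    sample = np.full((64, 64, 3), 255, dtype=np.uint8)
+    for aug in (Fog(mag=0, p=1.0), Rain(mag=0, p=1.0), Shadow(mag=0, p=1.0)):
+        out = aug(image=sample)["image"]
+        print(type(aug).__name__, out.shape, out.dtype)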
diff --git a/unimernet/processors/randaugment.py b/unimernet/processors/randaugment.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c6a9e6d62f74358f490d19546c9829b3ac6aaef
--- /dev/null
+++ b/unimernet/processors/randaugment.py
@@ -0,0 +1,398 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import cv2
+import numpy as np
+
+import torch
+
+
+## aug functions
+def identity_func(img):
+    return img
+
+
+def autocontrast_func(img, cutoff=0):
+    """
+    same output as PIL.ImageOps.autocontrast
+    """
+    n_bins = 256
+
+    def tune_channel(ch):
+        n = ch.size
+        cut = cutoff * n // 100
+        if cut == 0:
+            high, low = ch.max(), ch.min()
+        else:
+            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+            low = np.argwhere(np.cumsum(hist) > cut)
+            low = 0 if low.shape[0] == 0 else low[0]
+            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
+            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
+        if high <= low:
+            table = np.arange(n_bins)
+        else:
+            scale = (n_bins - 1) / (high - low)
+            offset = -low * scale
+            table = np.arange(n_bins) * scale + offset
+            table[table < 0] = 0
+            table[table > n_bins - 1] = n_bins - 1
+        table = table.clip(0, 255).astype(np.uint8)
+        return table[ch]
+
+    channels = [tune_channel(ch) for ch in cv2.split(img)]
+    out = cv2.merge(channels)
+    return out
+
+
+def equalize_func(img):
+    """
+    same output as PIL.ImageOps.equalize
+    PIL's implementation is different from cv2.equalizeHist
+    """
+    n_bins = 256
+
+    def tune_channel(ch):
+        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+        non_zero_hist = hist[hist != 0].reshape(-1)
+        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
+        if step == 0:
+            return ch
+        n = np.empty_like(hist)
+        n[0] = step // 2
+        n[1:] = hist[:-1]
+        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
+        return table[ch]
+
+    channels = [tune_channel(ch) for ch in cv2.split(img)]
+    out = cv2.merge(channels)
+    return out
+
+
+def rotate_func(img, degree, fill=(0, 0, 0)):
+    """
+    like PIL, rotate by degrees, not radians
+    """
+    H, W = img.shape[0], img.shape[1]
+    center = W / 2, H / 2
+    M = cv2.getRotationMatrix2D(center, degree, 1)
+    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
+    return out
+
+
+def solarize_func(img, thresh=128):
+    """
+    same output as PIL.ImageOps.solarize
+    """
+    table = np.array([el if el < thresh else 255 - el for el in range(256)])
+    table = table.clip(0, 255).astype(np.uint8)
+    out = table[img]
+    return out
+
+
+def color_func(img, factor):
+    """
+    same output as PIL.ImageEnhance.Color
+    """
+    ## implementation according to PIL definition, quite slow
+    #  degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
+    #  out = blend(degenerate, img, factor)
+    #  M = (
+    #      np.eye(3) * factor
+    #      + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
+    #  )[np.newaxis, np.newaxis, :]
+    M = np.float32(
+        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
+    ) * factor + np.float32([[0.114], [0.587], [0.299]])
+    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
+    return out
+
+
+def contrast_func(img, factor):
+    """
+    same output as PIL.ImageEnhance.Contrast
+    """
+    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
+    table = (
+        np.array([(el - mean) * factor + mean for el in range(256)])
+        .clip(0, 255)
+        .astype(np.uint8)
+    )
+    out = table[img]
+    return out
+
+
+def brightness_func(img, factor):
+    """
+    same output as PIL.ImageEnhance.Brightness
+    """
+    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
+    out = table[img]
+    return out
+
+
+def sharpness_func(img, factor):
+    """
+    The differences between this result and PIL's are confined to the 4 boundaries;
+    the center areas are the same
+    """
+    kernel = np.ones((3, 3), dtype=np.float32)
+    kernel[1][1] = 5
+    kernel /= 13
+    degenerate = cv2.filter2D(img, -1, kernel)
+    if factor == 0.0:
+        out = degenerate
+    elif factor == 1.0:
+        out = img
+    else:
+        out = img.astype(np.float32)
+        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
+        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
+        out = out.astype(np.uint8)
+    return out
+
+
+def shear_x_func(img, factor, fill=(0, 0, 0)):
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, factor, 0], [0, 1, 0]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def translate_x_func(img, offset, fill=(0, 0, 0)):
+    """
+    same output as PIL.Image.transform
+    """
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, 0, -offset], [0, 1, 0]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def translate_y_func(img, offset, fill=(0, 0, 0)):
+    """
+    same output as PIL.Image.transform
+    """
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, 0, 0], [0, 1, -offset]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def posterize_func(img, bits):
+    """
+    same output as PIL.ImageOps.posterize
+    """
+    # mask within the uint8 range so the shifted value does not overflow np.uint8
+    out = np.bitwise_and(img, np.uint8((255 << (8 - bits)) & 255))
+    return out
+
+
+def shear_y_func(img, factor, fill=(0, 0, 0)):
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, 0, 0], [factor, 1, 0]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def cutout_func(img, pad_size, replace=(0, 0, 0)):
+    replace = np.array(replace, dtype=np.uint8)
+    H, W = img.shape[0], img.shape[1]
+    rh, rw = np.random.random(2)
+    pad_size = pad_size // 2
+    ch, cw = int(rh * H), int(rw * W)
+    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
+    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
+    out = img.copy()
+    out[x1:x2, y1:y2, :] = replace
+    return out
+
+
+### level to args
+def enhance_level_to_args(MAX_LEVEL):
+    def level_to_args(level):
+        return ((level / MAX_LEVEL) * 1.8 + 0.1,)
+
+    return level_to_args
+
+
+def shear_level_to_args(MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = (level / MAX_LEVEL) * 0.3
+        if np.random.random() > 0.5:
+            level = -level
+        return (level, replace_value)
+
+    return level_to_args
+
+
+def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = (level / MAX_LEVEL) * float(translate_const)
+        if np.random.random() > 0.5:
+            level = -level
+        return (level, replace_value)
+
+    return level_to_args
+
+
+def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = int((level / MAX_LEVEL) * cutout_const)
+        return (level, replace_value)
+
+    return level_to_args
+
+
+def solarize_level_to_args(MAX_LEVEL):
+    def level_to_args(level):
+        level = int((level / MAX_LEVEL) * 256)
+        return (level,)
+
+    return level_to_args
+
+
+def none_level_to_args(level):
+    return ()
+
+
+def posterize_level_to_args(MAX_LEVEL):
+    def level_to_args(level):
+        level = int((level / MAX_LEVEL) * 4)
+        return (level,)
+
+    return level_to_args
+
+
+def rotate_level_to_args(MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = (level / MAX_LEVEL) * 30
+        if np.random.random() < 0.5:
+            level = -level
+        return (level, replace_value)
+
+    return level_to_args
+
+
+func_dict = {
+    "Identity": identity_func,
+    "AutoContrast": autocontrast_func,
+    "Equalize": equalize_func,
+    "Rotate": rotate_func,
+    "Solarize": solarize_func,
+    "Color": color_func,
+    "Contrast": contrast_func,
+    "Brightness": brightness_func,
+    "Sharpness": sharpness_func,
+    "ShearX": shear_x_func,
+    "TranslateX": translate_x_func,
+    "TranslateY": translate_y_func,
+    "Posterize": posterize_func,
+    "ShearY": shear_y_func,
+}
+
+translate_const = 10
+MAX_LEVEL = 10
+replace_value = (128, 128, 128)
+arg_dict = {
+    "Identity": none_level_to_args,
+    "AutoContrast": none_level_to_args,
+    "Equalize": none_level_to_args,
+    "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
+    "Solarize": solarize_level_to_args(MAX_LEVEL),
+    "Color": enhance_level_to_args(MAX_LEVEL),
+    "Contrast": enhance_level_to_args(MAX_LEVEL),
+    "Brightness": enhance_level_to_args(MAX_LEVEL),
+    "Sharpness": enhance_level_to_args(MAX_LEVEL),
+    "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
+    "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+    "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+    "Posterize": posterize_level_to_args(MAX_LEVEL),
+    "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
+}
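+
+# For example, arg_dict["Rotate"](10) yields (±30.0, (128, 128, 128)) and
+# arg_dict["Posterize"](5) yields (2,); RandomAugment below applies
+# func_dict[name](img, *arg_dict[name](level)) for each sampled op.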
+
+
+class RandomAugment(object):
+    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
+        self.N = N
+        self.M = M
+        self.isPIL = isPIL
+        if augs:
+            self.augs = augs
+        else:
+            self.augs = list(arg_dict.keys())
+
+    def get_random_ops(self):
+        sampled_ops = np.random.choice(self.augs, self.N)
+        return [(op, 0.5, self.M) for op in sampled_ops]
+
+    def __call__(self, img):
+        if self.isPIL:
+            img = np.array(img)
+        ops = self.get_random_ops()
+        for name, prob, level in ops:
+            if np.random.random() > prob:
+                continue
+            args = arg_dict[name](level)
+            img = func_dict[name](img, *args)
+        return img
+
+
+class VideoRandomAugment(object):
+    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
+        self.N = N
+        self.M = M
+        self.p = p
+        self.tensor_in_tensor_out = tensor_in_tensor_out
+        if augs:
+            self.augs = augs
+        else:
+            self.augs = list(arg_dict.keys())
+
+    def get_random_ops(self):
+        sampled_ops = np.random.choice(self.augs, self.N, replace=False)
+        return [(op, self.M) for op in sampled_ops]
+
+    def __call__(self, frames):
+        assert (
+            frames.shape[-1] == 3
+        ), "Expecting the last dimension to be 3-channel RGB (b, h, w, c)."
+
+        if self.tensor_in_tensor_out:
+            frames = frames.numpy().astype(np.uint8)
+
+        num_frames = frames.shape[0]
+
+        ops = num_frames * [self.get_random_ops()]
+        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
+
+        frames = torch.stack(
+            list(map(self._aug, frames, ops, apply_or_not)), dim=0
+        ).float()
+
+        return frames
+
+    def _aug(self, img, ops, apply_or_not):
+        for i, (name, level) in enumerate(ops):
+            if not apply_or_not[i]:
+                continue
+            args = arg_dict[name](level)
+            img = func_dict[name](img, *args)
+        return torch.from_numpy(img)
+
+
+if __name__ == "__main__":
+    a = RandomAugment()
+    # use a uint8 image; the table-based ops (e.g. Solarize, Equalize) require integer dtype
+    img = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
+    a(img)
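+
+    # Illustrative extension (not in the original demo): VideoRandomAugment takes a
+    # (b, h, w, 3) tensor of uint8-valued frames and returns a float tensor.
+    frames = torch.from_numpy((np.random.rand(4, 32, 32, 3) * 255).astype(np.uint8))
+    v = VideoRandomAugment(N=2, M=5, p=0.5)
+    print(v(frames).shape)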
diff --git a/unimernet/runners/__init__.py b/unimernet/runners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2960f292b80ba543877d92c96d94eb3cddaed22
--- /dev/null
+++ b/unimernet/runners/__init__.py
@@ -0,0 +1,11 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from unimernet.runners.runner_base import RunnerBase
+from unimernet.runners.runner_iter import RunnerIter
+
+__all__ = ["RunnerBase", "RunnerIter"]
diff --git a/unimernet/runners/runner_base.py b/unimernet/runners/runner_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c5eccf38efa6e745dc06ee98bb531daa17affd3
--- /dev/null
+++ b/unimernet/runners/runner_base.py
@@ -0,0 +1,670 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import json
+import logging
+import os
+import time
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import webdataset as wds
+from unimernet.common.dist_utils import (
+    download_cached_file,
+    get_rank,
+    get_world_size,
+    is_main_process,
+    main_process,
+)
+from unimernet.common.registry import registry
+from unimernet.common.utils import is_url
+from unimernet.datasets.data_utils import reorg_datasets_by_split, concat_datasets
+from unimernet.datasets.datasets.dataloader_utils import (
+    IterLoader,
+    MultiIterLoader,
+    ConcatLoader,
+    PrefetchLoader,
+)
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data.dataset import ChainDataset
+
+
+@registry.register_runner("runner_base")
+class RunnerBase:
+    """
+    A runner class to train and evaluate a model given a task and datasets.
+
+    The runner uses PyTorch DistributedDataParallel by default. Future releases
+    will support other distributed frameworks.
+    """
+
+    def __init__(self, cfg, task, model, datasets, job_id):
+        self.config = cfg
+        self.job_id = job_id
+
+        self.task = task
+        self.datasets = datasets
+
+        self._model = model
+
+        self._wrapped_model = None
+        self._device = None
+        self._optimizer = None
+        self._scaler = None
+        self._dataloaders = None
+        self._lr_sched = None
+
+        self.start_epoch = 0
+
+        # self.setup_seeds()
+        self.setup_output_dir()
+
+    @property
+    def device(self):
+        if self._device is None:
+            self._device = torch.device(self.config.run_cfg.device)
+
+        return self._device
+
+    @property
+    def milestone(self):
+        return self.config.run_cfg.get("milestone", None)
+
+    @property
+    def use_distributed(self):
+        return self.config.run_cfg.distributed
+
+    @property
+    def model(self):
+        """
+        A property to get the DDP-wrapped model on the device.
+        """
+        # move model to device
+        if self._model.device != self.device:
+            self._model = self._model.to(self.device)
+
+            # distributed training wrapper
+            if self.use_distributed:
+                if self._wrapped_model is None:
+                    self._wrapped_model = DDP(
+                        self._model, device_ids=[self.config.run_cfg.gpu], find_unused_parameters=False
+                    )
+            else:
+                self._wrapped_model = self._model
+
+        return self._wrapped_model
+
+    @property
+    def optimizer(self):
+        # TODO make optimizer class and configurations
+        if self._optimizer is None:
+            num_parameters = 0
+            p_wd, p_non_wd = [], []
+            for n, p in self.model.named_parameters():
+                if not p.requires_grad:
+                    continue  # frozen weights
+                if p.ndim < 2 or "bias" in n or "ln" in n or "bn" in n:
+                    p_non_wd.append(p)
+                else:
+                    p_wd.append(p)
+                num_parameters += p.data.nelement()
+            logging.info("number of trainable parameters: %d" % num_parameters)
+            optim_params = [
+                {
+                    "params": p_wd,
+                    "weight_decay": float(self.config.run_cfg.weight_decay),
+                },
+                {"params": p_non_wd, "weight_decay": 0},
+            ]
+            beta2 = self.config.run_cfg.get("beta2", 0.999)
+            self._optimizer = torch.optim.AdamW(
+                optim_params,
+                lr=float(self.config.run_cfg.init_lr),
+                weight_decay=float(self.config.run_cfg.weight_decay),
+                betas=(0.9, beta2),
+            )
+
+        return self._optimizer
+
+    @property
+    def scaler(self):
+        amp = self.config.run_cfg.get("amp", False)
+
+        if amp:
+            if self._scaler is None:
+                self._scaler = torch.cuda.amp.GradScaler()
+
+        return self._scaler
+
+    @property
+    def lr_scheduler(self):
+        """
+        A property that lazily creates the learning rate scheduler when first needed.
+        """
+        if self._lr_sched is None:
+            lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)
+
+            # max_epoch = self.config.run_cfg.max_epoch
+            max_epoch = self.max_epoch
+            # min_lr = self.config.run_cfg.min_lr
+            min_lr = self.min_lr
+            # init_lr = self.config.run_cfg.init_lr
+            init_lr = self.init_lr
+
+            # optional parameters
+            decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
+            warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
+            warmup_steps = self.config.run_cfg.get("warmup_steps", 0)
+            iters_per_epoch = self.config.run_cfg.get("iters_per_inner_epoch", len(self.train_loader))
+
+            self._lr_sched = lr_sched_cls(
+                optimizer=self.optimizer,
+                max_epoch=max_epoch,
+                min_lr=min_lr,
+                init_lr=init_lr,
+                decay_rate=decay_rate,
+                warmup_start_lr=warmup_start_lr,
+                warmup_steps=warmup_steps,
+                iters_per_epoch=iters_per_epoch,
+            )
+
+        return self._lr_sched
+
+    @property
+    def dataloaders(self) -> dict:
+        """
+        A property that lazily creates dataloaders, keyed by split, when first needed.
+
+        If no train_dataset_ratio is provided, concatenate map-style datasets and
+        chain wds.DataPipe datasets separately. Training set becomes a tuple
+        (ConcatDataset, ChainDataset), both are optional but at least one of them is
+        required. The resultant ConcatDataset and ChainDataset will be sampled evenly.
+
+        If train_dataset_ratio is provided, create a MultiIterLoader to sample
+        each dataset by ratios during training.
+
+        Multiple datasets for validation and test are currently not supported.
+
+        Returns:
+            dict: {split_name: (tuples of) dataloader}
+        """
+        if self._dataloaders is None:
+            # reoganize datasets by split and concatenate/chain if necessary
+
+            datasets = reorg_datasets_by_split(self.datasets)
+            self.datasets = concat_datasets(datasets)
+
+            self.datasets = {
+                k: v[0] if len(v) == 1 else v for k, v in self.datasets.items()
+            }
+
+            # print dataset statistics after concatenation/chaining
+            for split_name in self.datasets:
+                if isinstance(self.datasets[split_name], (tuple, list)):
+                    # mixed wds.DataPipeline and torch.utils.data.Dataset
+                    num_records = sum(
+                        [
+                            len(d)
+                            if not type(d) in [wds.DataPipeline, ChainDataset]
+                            else 0
+                            for d in self.datasets[split_name]
+                        ]
+                    )
+
+                else:
+                    if hasattr(self.datasets[split_name], "__len__"):
+                        # a single map-style dataset
+                        num_records = len(self.datasets[split_name])
+                    else:
+                        # a single wds.DataPipeline
+                        num_records = -1
+                        logging.info(
+                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
+                        )
+
+                if num_records >= 0:
+                    logging.info(
+                        "Loaded {} records for {} split from the dataset.".format(
+                            num_records, split_name
+                        )
+                    )
+
+            # create dataloaders
+            split_names = sorted(self.datasets.keys())
+
+            datasets = [self.datasets[split] for split in split_names]
+            is_trains = [split in self.train_splits for split in split_names]
+
+            batch_sizes = [
+                self.config.run_cfg.batch_size_train
+                if split == "train"
+                else self.config.run_cfg.batch_size_eval
+                for split in split_names
+            ]
+
+            collate_fns = []
+            for dataset in datasets:
+                if isinstance(dataset, tuple) or isinstance(dataset, list):
+                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
+                else:
+                    collate_fns.append(getattr(dataset, "collater", None))
+
+            dataloaders = self.create_loaders(
+                datasets=datasets,
+                num_workers=self.config.run_cfg.num_workers,
+                batch_sizes=batch_sizes,
+                is_trains=is_trains,
+                collate_fns=collate_fns,
+                # concat=True
+            )
+
+            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
+
+        return self._dataloaders
+
+    @property
+    def cuda_enabled(self):
+        return self.device.type == "cuda"
+
+    @property
+    def max_epoch(self):
+        return int(self.config.run_cfg.max_epoch)
+
+    @property
+    def log_freq(self):
+        log_freq = self.config.run_cfg.get("log_freq", 50)
+        return int(log_freq)
+
+    @property
+    def init_lr(self):
+        return float(self.config.run_cfg.init_lr)
+
+    @property
+    def min_lr(self):
+        return float(self.config.run_cfg.min_lr)
+
+    @property
+    def accum_grad_iters(self):
+        return int(self.config.run_cfg.get("accum_grad_iters", 1))
+
+    @property
+    def valid_splits(self):
+        valid_splits = self.config.run_cfg.get("valid_splits", [])
+
+        if len(valid_splits) == 0:
+            logging.info("No validation splits found.")
+
+        return valid_splits
+
+    @property
+    def test_splits(self):
+        test_splits = self.config.run_cfg.get("test_splits", [])
+
+        return test_splits
+
+    @property
+    def train_splits(self):
+        train_splits = self.config.run_cfg.get("train_splits", [])
+
+        if len(train_splits) == 0:
+            logging.info("Empty train splits.")
+
+        return train_splits
+
+    @property
+    def evaluate_only(self):
+        """
+        Set to True to skip training.
+        """
+        return self.config.run_cfg.evaluate
+
+    @property
+    def use_dist_eval_sampler(self):
+        return self.config.run_cfg.get("use_dist_eval_sampler", True)
+
+    @property
+    def resume_ckpt_path(self):
+        return self.config.run_cfg.get("resume_ckpt_path", None)
+
+    @property
+    def train_loader(self):
+        train_dataloader = self.dataloaders["train"]
+
+        return train_dataloader
+
+    def setup_output_dir(self):
+        lib_root = Path(registry.get_path("library_root"))
+
+        output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
+        result_dir = output_dir / "result"
+
+        output_dir.mkdir(parents=True, exist_ok=True)
+        result_dir.mkdir(parents=True, exist_ok=True)
+
+        registry.register_path("result_dir", str(result_dir))
+        registry.register_path("output_dir", str(output_dir))
+
+        self.result_dir = result_dir
+        self.output_dir = output_dir
+
+    def train(self):
+        start_time = time.time()
+        best_agg_metric = 0
+        best_epoch = 0
+
+        self.log_config()
+
+        # resume from checkpoint if specified
+        if not self.evaluate_only and self.resume_ckpt_path is not None:
+            self._load_checkpoint(self.resume_ckpt_path)
+
+        for cur_epoch in range(self.start_epoch, self.max_epoch):
+            # training phase
+            if not self.evaluate_only:
+                logging.info("Start training")
+                train_stats = self.train_epoch(cur_epoch)
+                self.log_stats(split_name="train", stats=train_stats)
+
+            # evaluation phase
+            if len(self.valid_splits) > 0:
+                for split_name in self.valid_splits:
+                    logging.info("Evaluating on {}.".format(split_name))
+
+                    val_log = self.eval_epoch(
+                        split_name=split_name, cur_epoch=cur_epoch
+                    )
+                    if val_log is not None:
+                        if is_main_process():
+                            assert (
+                                    "agg_metrics" in val_log
+                            ), "No agg_metrics found in validation log."
+
+                            agg_metrics = val_log["agg_metrics"]
+                            if agg_metrics > best_agg_metric and split_name == "eval":
+                                best_epoch, best_agg_metric = cur_epoch, agg_metrics
+
+                                self._save_checkpoint(cur_epoch, is_best=True)
+
+                            val_log.update({"best_epoch": best_epoch})
+                            self.log_stats(val_log, split_name)
+
+            if self.evaluate_only:
+                break
+            if self.milestone and cur_epoch + 1 in self.milestone:
+                self._save_checkpoint(cur_epoch)
+            self._save_checkpoint(cur_epoch, latest=True)
+            dist.barrier()
+
+        # testing phase
+        test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
+        self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)
+
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        logging.info("Training time {}".format(total_time_str))
+
+    def evaluate(self, cur_epoch="best", skip_reload=False):
+        test_logs = dict()
+
+        if len(self.test_splits) > 0:
+            for split_name in self.test_splits:
+                test_logs[split_name] = self.eval_epoch(
+                    split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
+                )
+
+            return test_logs
+
+    def train_epoch(self, epoch):
+        # train
+        self.model.train()
+
+        return self.task.train_epoch(
+            epoch=epoch,
+            model=self.model,
+            data_loader=self.train_loader,
+            optimizer=self.optimizer,
+            scaler=self.scaler,
+            lr_scheduler=self.lr_scheduler,
+            cuda_enabled=self.cuda_enabled,
+            log_freq=self.log_freq,
+            accum_grad_iters=self.accum_grad_iters,
+        )
+
+    @torch.no_grad()
+    def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
+        """
+        Evaluate the model on a given split.
+
+        Args:
+            split_name (str): name of the split to evaluate on.
+            cur_epoch (int): current epoch.
+            skip_reload (bool): whether to skip reloading the best checkpoint.
+                During training, we reload the best checkpoint for validation.
+                During testing, we use the provided weights and skip reloading the best checkpoint.
+        """
+        data_loader = self.dataloaders.get(split_name, None)
+        assert data_loader, "data_loader for split {} is None.".format(split_name)
+
+        # TODO In validation, you need to compute loss as well as metrics
+        # TODO consider moving to model.before_evaluation()
+        model = self.unwrap_dist_model(self.model)
+        if not skip_reload and cur_epoch == "best":
+            model = self._reload_best_model(model)
+        model.eval()
+
+        self.task.before_evaluation(
+            model=model,
+            dataset=self.datasets[split_name],
+        )
+        results = self.task.evaluation(model, data_loader)
+
+        if results is not None:
+            return self.task.after_evaluation(
+                val_result=results,
+                split_name=split_name,
+                epoch=cur_epoch,
+            )
+
+    def unwrap_dist_model(self, model):
+        if self.use_distributed:
+            return model.module
+        else:
+            return model
+
+    def create_loaders(
+            self,
+            datasets,
+            num_workers,
+            batch_sizes,
+            is_trains,
+            collate_fns,
+            concat=False
+    ):
+        """
+        Create dataloaders for training and validation.
+        """
+
+        def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
+            # create a single dataloader for each split
+            if isinstance(dataset, (ChainDataset, wds.DataPipeline)):
+                # wds.WebDataset instances are chained together
+                # webdataset.DataPipeline has its own sampler and collate_fn
+                loader = iter(
+                    DataLoader(
+                        dataset,
+                        batch_size=bsz,
+                        num_workers=num_workers,
+                        pin_memory=True,
+                    )
+                )
+            else:
+                # map-style datasets are concatenated together
+                # setup distributed sampler
+                if self.use_distributed:
+                    sampler = DistributedSampler(
+                        dataset,
+                        shuffle=is_train,
+                        num_replicas=get_world_size(),
+                        rank=get_rank(),
+                    )
+                    if not self.use_dist_eval_sampler:
+                        # e.g. retrieval evaluation
+                        sampler = sampler if is_train else None
+                else:
+                    sampler = None
+
+                loader = DataLoader(
+                    dataset,
+                    batch_size=bsz,
+                    num_workers=num_workers,
+                    pin_memory=True,
+                    sampler=sampler,
+                    shuffle=sampler is None and is_train,
+                    collate_fn=collate_fn,
+                    drop_last=True if is_train else False,
+                )
+                loader = PrefetchLoader(loader)
+
+                if is_train:
+                    loader = IterLoader(loader, use_distributed=self.use_distributed)
+
+            return loader
+
+        loaders = []
+
+        for dataset, bsz, is_train, collate_fn in zip(
+                datasets, batch_sizes, is_trains, collate_fns
+        ):
+            if isinstance(dataset, list) or isinstance(dataset, tuple):
+                if not concat:
+                    sample_ratios = [d.sample_ratio for d in dataset]
+                    loader = MultiIterLoader(
+                        loaders=[
+                            _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
+                            for i, d in enumerate(dataset)
+                        ],
+                        ratios=sample_ratios
+                    )
+                else:
+                    loader = ConcatLoader(
+                        loaders=[
+                            _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
+                            for i, d in enumerate(dataset)
+                        ]
+                    )
+
+            else:
+                loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)
+
+            loaders.append(loader)
+
+        return loaders
+
+    @main_process
+    def _save_checkpoint(self, cur_epoch, is_best=False, latest=False):
+        """
+        Save the checkpoint at the current epoch.
+        """
+        assert not (is_best and latest), "'is_best' and 'latest' cannot both be set at the same time."
+        model_no_ddp = self.unwrap_dist_model(self.model)
+        param_grad_dic = {
+            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
+        }
+        state_dict = model_no_ddp.state_dict()
+        for k in list(state_dict.keys()):
+            if k in param_grad_dic.keys() and not param_grad_dic[k]:
+                # delete parameters that do not require gradient
+                del state_dict[k]
+        save_obj = {
+            "model": state_dict,
+            "optimizer": self.optimizer.state_dict(),
+            "config": self.config.to_dict(),
+            "scaler": self.scaler.state_dict() if self.scaler else None,
+            "epoch": cur_epoch,
+        }
+        if is_best:
+            save_to = os.path.join(
+                self.output_dir,
+                "checkpoint_{}.pth".format("best"),
+            )
+        elif latest:
+            save_to = os.path.join(
+                self.output_dir,
+                "checkpoint_{}.pth".format("latest"),
+            )
+        else:
+            save_to = os.path.join(
+                self.output_dir,
+                "checkpoint_{}.pth".format(cur_epoch+1),
+            )
+        logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch+1, save_to))
+        torch.save(save_obj, save_to)
+
+    def _reload_best_model(self, model):
+        """
+        Load the best checkpoint for evaluation.
+        """
+        checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
+
+        logging.info("Loading checkpoint from {}.".format(checkpoint_path))
+        checkpoint = torch.load(checkpoint_path, map_location="cpu")
+        try:
+            model.load_state_dict(checkpoint["model"])
+        except RuntimeError as e:
+            logging.warning(
+                """
+                Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
+                Trying to load the model with strict=False.
+                """
+            )
+            model.load_state_dict(checkpoint["model"], strict=False)
+        return model
+
+    def _load_checkpoint(self, url_or_filename):
+        """
+        Resume from a checkpoint.
+        """
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location=self.device)
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location=self.device)
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        state_dict = checkpoint["model"]
+        self.unwrap_dist_model(self.model).load_state_dict(state_dict)
+
+        self.optimizer.load_state_dict(checkpoint["optimizer"])
+        if self.scaler and "scaler" in checkpoint:
+            self.scaler.load_state_dict(checkpoint["scaler"])
+
+        self.start_epoch = checkpoint["epoch"]
+        logging.info("Resume checkpoint from {}".format(url_or_filename))
+
+    @main_process
+    def log_stats(self, stats, split_name):
+        if isinstance(stats, dict):
+            log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
+            with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
+                f.write(json.dumps(log_stats) + "\n")
+        elif isinstance(stats, list):
+            pass
+
+    @main_process
+    def log_config(self):
+        with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
+            f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")
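+
+
+# Typical wiring (illustrative sketch, not a prescribed API): the runner is built from a
+# parsed config together with a task, a model and the task-built datasets, then driven via
+# train(). The Config import path below is an assumption about the repo layout.
+#
+#     from unimernet.common.config import Config
+#     import unimernet.tasks as tasks
+#
+#     cfg = Config(args)                      # args from the training entry script
+#     task = tasks.setup_task(cfg)
+#     datasets = task.build_datasets(cfg)
+#     model = task.build_model(cfg)
+#     RunnerBase(cfg=cfg, task=task, model=model, datasets=datasets,
+#                job_id="example_job").train()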
diff --git a/unimernet/runners/runner_iter.py b/unimernet/runners/runner_iter.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b0a87ccb30c9c9b7993187df302525f866cb48
--- /dev/null
+++ b/unimernet/runners/runner_iter.py
@@ -0,0 +1,309 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import datetime
+import logging
+import os
+import time
+
+import torch
+import torch.distributed as dist
+import webdataset as wds
+from unimernet.common.dist_utils import download_cached_file, is_main_process, main_process
+from unimernet.common.registry import registry
+from unimernet.common.utils import is_url
+from unimernet.datasets.data_utils import reorg_datasets_by_split
+from unimernet.runners.runner_base import RunnerBase
+from torch.utils.data.dataset import ChainDataset
+
+
+@registry.register_runner("runner_iter")
+class RunnerIter(RunnerBase):
+    """
+    Run training based on the number of iterations. This is common when
+    the training dataset size is large. Under the hood, the logic is similar to
+    epoch-based training: every #iters_per_inner_epoch iterations are treated as
+    an inner epoch.
+
+    In iter-based runner, after every #iters_per_inner_epoch steps, we
+
+        1) do a validation epoch;
+        2) schedule the learning rate;
+        3) save the checkpoint.
+
+    We refer to every #iters_per_inner_epoch steps as an inner epoch.
+    """
+
+    def __init__(self, cfg, task, model, datasets, job_id):
+        super().__init__(cfg, task, model, datasets, job_id)
+
+        self.start_iters = 0
+
+        self.max_iters = int(self.config.run_cfg.get("max_iters", -1))
+        assert self.max_iters > 0, "max_iters must be greater than 0."
+
+        self.iters_per_inner_epoch = int(
+            self.config.run_cfg.get("iters_per_inner_epoch", -1)
+        )
+        assert (
+                self.iters_per_inner_epoch > 0
+        ), "iters_per_inner_epoch must be greater than 0."
+
+    @property
+    def max_epoch(self):
+        return int(self.max_iters / self.iters_per_inner_epoch)
+
+    @property
+    def cur_epoch(self):
+        try:
+            return self.train_loader.epoch
+        except AttributeError:
+            # pipeline data (e.g. LAION) is streaming and has no concept of epoch
+            return 0
+
+    def _progress(self, cur_iters):
+        return "{}_iters={}".format(self.cur_epoch, cur_iters)
+
+    def train(self):
+        start_time = time.time()
+        best_agg_metric = 0
+        best_iters = 0
+
+        self.log_config()
+
+        # resume from checkpoint if specified
+        if not self.evaluate_only and self.resume_ckpt_path is not None:
+            self._load_checkpoint(self.resume_ckpt_path)
+        cur_epoch = 0
+        for start_iters in range(
+                self.start_iters, self.max_iters, self.iters_per_inner_epoch
+        ):
+            end_iters = start_iters + self.iters_per_inner_epoch
+
+            # training phase
+            if not self.evaluate_only:
+                logging.info(
+                    "Start training, max_iters={}, in total {} inner epochs.".format(
+                        self.max_iters, int(self.max_iters / self.iters_per_inner_epoch)
+                    )
+                )
+
+                train_stats = self.train_iters(self.cur_epoch, start_iters)
+                self.log_stats(split_name="train", stats=train_stats)
+
+            # evaluation phase
+            if len(self.valid_splits) > 0:
+                for split_name in self.valid_splits:
+                    logging.info("Evaluating on {}.".format(split_name))
+
+                    val_log = self.eval_epoch(
+                        split_name=split_name, cur_epoch=self._progress(end_iters)
+                    )
+                    if val_log is not None:
+                        if is_main_process():
+                            assert (
+                                    "agg_metrics" in val_log
+                            ), "No agg_metrics found in validation log."
+
+                            agg_metrics = val_log["agg_metrics"]
+                            if agg_metrics > best_agg_metric and split_name == "eval":
+                                best_iters, best_agg_metric = end_iters, agg_metrics
+
+                                self._save_checkpoint(end_iters, is_best=True)
+                            val_log.update({"best_iters": best_iters})
+                            self.log_stats(val_log, split_name)
+                            # print evaluation metric
+                            print(f"bleu:{val_log['bleu']:.6f}, "
+                                  f"edit_distance:{val_log['edit_distance']:.6f}, "
+                                  f"token_accuracy:{val_log['token_accuracy']:.6f}")
+                            print("="*80)
+
+            if self.evaluate_only:
+                break
+            if self.milestone and cur_epoch + 1 in self.milestone:
+                self._save_checkpoint(cur_epoch)
+            self._save_checkpoint(end_iters, latest=True)
+            dist.barrier()
+            cur_epoch += 1
+
+        # testing phase
+        self.evaluate(cur_epoch=self.cur_epoch)
+
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        logging.info("Training time {}".format(total_time_str))
+
+    def train_iters(self, epoch, start_iters):
+        # train by iterations
+        self.model.train()
+
+        return self.task.train_iters(
+            epoch=epoch,
+            start_iters=start_iters,
+            iters_per_inner_epoch=self.iters_per_inner_epoch,
+            model=self.model,
+            data_loader=self.train_loader,
+            optimizer=self.optimizer,
+            scaler=self.scaler,
+            lr_scheduler=self.lr_scheduler,
+            cuda_enabled=self.cuda_enabled,
+            log_freq=self.log_freq,
+            accum_grad_iters=self.accum_grad_iters,
+        )
+
+    @main_process
+    def _save_checkpoint(self, cur_iters, is_best=False, latest=False):
+        # only save the params requires gradient
+        assert not (is_best and latest), "'is_best' and 'latest' cannot both be set at the same time."
+        unwrapped_model = self.unwrap_dist_model(self.model)
+        param_grad_dic = {
+            k: v.requires_grad for (k, v) in unwrapped_model.named_parameters()
+        }
+
+        state_dict = unwrapped_model.state_dict()
+        for k in list(state_dict.keys()):
+            if k in param_grad_dic.keys() and not param_grad_dic[k]:
+                del state_dict[k]
+
+        save_obj = {
+            "model": state_dict,
+            "optimizer": self.optimizer.state_dict(),
+            "config": self.config.to_dict(),
+            "scaler": self.scaler.state_dict() if self.scaler else None,
+            "iters": cur_iters,
+        }
+        if is_best:
+            save_to = os.path.join(
+                self.output_dir,
+                "checkpoint_{}.pth".format("best"),
+            )
+        elif latest:
+            save_to = os.path.join(
+                self.output_dir,
+                "checkpoint_{}.pth".format("latest"),
+            )
+        else:
+            save_to = os.path.join(
+                self.output_dir,
+                "checkpoint_{}.pth".format(cur_iters),
+            )
+        logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to))
+        torch.save(save_obj, save_to)
+
+    def _load_checkpoint(self, url_or_filename):
+        """
+        Resume from a checkpoint.
+        """
+        if is_url(url_or_filename):
+            cached_file = download_cached_file(
+                url_or_filename, check_hash=False, progress=True
+            )
+            checkpoint = torch.load(cached_file, map_location=self.device)
+        elif os.path.isfile(url_or_filename):
+            checkpoint = torch.load(url_or_filename, map_location=self.device)
+        else:
+            raise RuntimeError("checkpoint url or path is invalid")
+
+        state_dict = checkpoint["model"]
+        self.unwrap_dist_model(self.model).load_state_dict(state_dict)
+
+        self.optimizer.load_state_dict(checkpoint["optimizer"])
+        if self.scaler and "scaler" in checkpoint:
+            self.scaler.load_state_dict(checkpoint["scaler"])
+
+        self.start_iters = checkpoint["iters"] + 1
+        logging.info("Resume checkpoint from {}".format(url_or_filename))
+
+    @property
+    def dataloaders(self) -> dict:
+        """
+        A property that lazily creates dataloaders, keyed by split, when first needed.
+
+        If no train_dataset_ratio is provided, concatenate map-style datasets and
+        chain wds.DataPipe datasets separately. Training set becomes a tuple
+        (ConcatDataset, ChainDataset), both are optional but at least one of them is
+        required. The resultant ConcatDataset and ChainDataset will be sampled evenly.
+
+        If train_dataset_ratio is provided, create a MultiIterLoader to sample
+        each dataset by ratios during training.
+
+        Multiple datasets for validation and test are currently not supported.
+
+        Returns:
+            dict: {split_name: (tuples of) dataloader}
+        """
+        if self._dataloaders is None:
+            # reorganize datasets by split and concatenate/chain if necessary
+
+            self.datasets = reorg_datasets_by_split(self.datasets)
+            # to keep the same structure as return value of concat_datasets
+            self.datasets = {
+                k: v[0] if len(v) == 1 else v for k, v in self.datasets.items()
+            }
+
+            # print dataset statistics after concatenation/chaining
+            for split_name in self.datasets:
+                if isinstance(self.datasets[split_name], (tuple, list)):
+                    # mixed wds.DataPipeline and torch.utils.data.Dataset
+                    num_records = sum(
+                        [
+                            len(d)
+                            if not type(d) in [wds.DataPipeline, ChainDataset]
+                            else 0
+                            for d in self.datasets[split_name]
+                        ]
+                    )
+
+                else:
+                    try:
+                        # a single map-style dataset
+                        num_records = len(self.datasets[split_name])
+                    except TypeError:
+                        # a single wds.DataPipeline or ChainDataset
+                        num_records = -1
+                        logging.info(
+                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
+                        )
+
+                if num_records >= 0:
+                    logging.info(
+                        "Loaded {} records for {} split from the dataset.".format(
+                            num_records, split_name
+                        )
+                    )
+
+            # create dataloaders
+            split_names = sorted(self.datasets.keys())
+
+            datasets = [self.datasets[split] for split in split_names]
+            is_trains = [split in self.train_splits for split in split_names]
+
+            batch_sizes = [
+                self.config.run_cfg.batch_size_train
+                if split == "train"
+                else self.config.run_cfg.batch_size_eval
+                for split in split_names
+            ]
+
+            collate_fns = []
+            for dataset in datasets:
+                if isinstance(dataset, tuple) or isinstance(dataset, list):
+                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
+                else:
+                    collate_fns.append(getattr(dataset, "collater", None))
+
+            dataloaders = self.create_loaders(
+                datasets=datasets,
+                num_workers=self.config.run_cfg.num_workers,
+                batch_sizes=batch_sizes,
+                is_trains=is_trains,
+                collate_fns=collate_fns,
+            )
+
+            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}
+
+        return self._dataloaders
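+
+
+# Illustrative note (not part of this module): run_cfg must provide max_iters and
+# iters_per_inner_epoch; max_epoch is then derived as max_iters // iters_per_inner_epoch.
+# For example, with the hypothetical values max_iters=300000 and iters_per_inner_epoch=2000,
+# training runs 150 inner epochs, each followed by validation, lr scheduling and checkpointing.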
diff --git a/unimernet/tasks/__init__.py b/unimernet/tasks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aff8198a152b559b5f59c0c9154cfcf7a9df4871
--- /dev/null
+++ b/unimernet/tasks/__init__.py
@@ -0,0 +1,26 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from unimernet.common.registry import registry
+from unimernet.tasks.base_task import BaseTask
+from unimernet.tasks.unimernet_train import UniMERNet_Train
+
+
+def setup_task(cfg):
+    assert "task" in cfg.run_cfg, "Task name must be provided."
+
+    task_name = cfg.run_cfg.task
+    task = registry.get_task_class(task_name).setup_task(cfg=cfg)
+    assert task is not None, "Task {} not properly registered.".format(task_name)
+
+    return task
+
+
+__all__ = [
+    "BaseTask",
+    "UniMERNet_Train",
+]
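+
+
+# Usage sketch (illustrative): run_cfg.task must name a registered task class; the class is
+# looked up in the registry and instantiated via its setup_task classmethod:
+#
+#     task = setup_task(cfg)           # cfg.run_cfg.task is the name UniMERNet_Train
+#                                      # (or BaseTask) was registered under
+#     model = task.build_model(cfg)
+#     datasets = task.build_datasets(cfg)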
diff --git a/unimernet/tasks/__pycache__/__init__.cpython-310.pyc b/unimernet/tasks/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f78e665e826e074c75be2fb52c87bb78f363bdfd
Binary files /dev/null and b/unimernet/tasks/__pycache__/__init__.cpython-310.pyc differ
diff --git a/unimernet/tasks/__pycache__/base_task.cpython-310.pyc b/unimernet/tasks/__pycache__/base_task.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9de391d4bef30304f5b6fb95245fcc57e7ff500a
Binary files /dev/null and b/unimernet/tasks/__pycache__/base_task.cpython-310.pyc differ
diff --git a/unimernet/tasks/__pycache__/unimernet_train.cpython-310.pyc b/unimernet/tasks/__pycache__/unimernet_train.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cea4e5569eba74538f4f7da00133c3b74d3e3438
Binary files /dev/null and b/unimernet/tasks/__pycache__/unimernet_train.cpython-310.pyc differ
diff --git a/unimernet/tasks/base_task.py b/unimernet/tasks/base_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ea4e27298ac0fcd0640853f726e1748ad2ee87
--- /dev/null
+++ b/unimernet/tasks/base_task.py
@@ -0,0 +1,288 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+from unimernet.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
+from unimernet.common.logger import MetricLogger, SmoothedValue
+from unimernet.common.registry import registry
+from unimernet.datasets.data_utils import prepare_sample
+
+
+class BaseTask:
+    def __init__(self, **kwargs):
+        super().__init__()
+
+        self.inst_id_key = "instance_id"
+
+    @classmethod
+    def setup_task(cls, **kwargs):
+        return cls()
+
+    def build_model(self, cfg):
+        model_config = cfg.model_cfg
+
+        model_cls = registry.get_model_class(model_config.arch)
+        return model_cls.from_config(model_config)
+
+    def build_datasets(self, cfg):
+        """
+        Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
+        Download dataset and annotations automatically if not exist.
+
+        Args:
+            cfg (common.config.Config): global config providing datasets_cfg for each dataset builder.
+
+        Returns:
+            dict: Dictionary of torch.utils.data.Dataset objects by split.
+        """
+
+        datasets = dict()
+
+        datasets_config = cfg.datasets_cfg
+
+        assert len(datasets_config) > 0, "At least one dataset has to be specified."
+
+        for name in datasets_config:
+            dataset_config = datasets_config[name]
+
+            builder = registry.get_builder_class(name)(dataset_config)
+            dataset = builder.build_datasets()
+
+            if "train" in dataset and "sample_ratio" in dataset_config:
+                dataset["train"].sample_ratio = float(dataset_config.sample_ratio)
+
+            datasets[name] = dataset
+
+        return datasets
+
+    def train_step(self, model, samples):
+        loss_dict = model(samples)
+        loss = loss_dict["loss"]
+        return loss, loss_dict
+
+    def valid_step(self, model, samples):
+        raise NotImplementedError
+
+    def before_evaluation(self, model, dataset, **kwargs):
+        model.before_evaluation(dataset=dataset, task_type=type(self))
+
+    def after_evaluation(self, **kwargs):
+        pass
+
+    def inference_step(self):
+        raise NotImplementedError
+
+    def evaluation(self, model, data_loader, cuda_enabled=True):
+        metric_logger = MetricLogger(delimiter="  ")
+        header = "Evaluation"
+        # TODO make it configurable
+        print_freq = 10
+
+        results = []
+
+        for samples in metric_logger.log_every(data_loader, print_freq, header):
+            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+
+            eval_output = self.valid_step(model=model, samples=samples)
+            results.extend(eval_output)
+
+        if is_dist_avail_and_initialized():
+            dist.barrier()
+
+        return results
+
+    def train_epoch(
+            self,
+            epoch,
+            model,
+            data_loader,
+            optimizer,
+            lr_scheduler,
+            scaler=None,
+            cuda_enabled=False,
+            log_freq=50,
+            accum_grad_iters=1,
+    ):
+        return self._train_inner_loop(
+            epoch=epoch,
+            iters_per_epoch=len(data_loader),
+            model=model,
+            data_loader=data_loader,
+            optimizer=optimizer,
+            scaler=scaler,
+            lr_scheduler=lr_scheduler,
+            log_freq=log_freq,
+            cuda_enabled=cuda_enabled,
+            accum_grad_iters=accum_grad_iters,
+        )
+
+    def train_iters(
+            self,
+            epoch,
+            start_iters,
+            iters_per_inner_epoch,
+            model,
+            data_loader,
+            optimizer,
+            lr_scheduler,
+            scaler=None,
+            cuda_enabled=False,
+            log_freq=50,
+            accum_grad_iters=1,
+    ):
+        return self._train_inner_loop(
+            epoch=epoch,
+            start_iters=start_iters,
+            iters_per_epoch=iters_per_inner_epoch,
+            model=model,
+            data_loader=data_loader,
+            optimizer=optimizer,
+            scaler=scaler,
+            lr_scheduler=lr_scheduler,
+            log_freq=log_freq,
+            cuda_enabled=cuda_enabled,
+            accum_grad_iters=accum_grad_iters,
+        )
+
+    def _train_inner_loop(
+            self,
+            epoch,
+            iters_per_epoch,
+            model,
+            data_loader,
+            optimizer,
+            lr_scheduler,
+            scaler=None,
+            start_iters=None,
+            log_freq=50,
+            cuda_enabled=False,
+            accum_grad_iters=1,
+    ):
+        """
+        An inner training loop compatible with both epoch-based and iter-based training.
+
+        When using epoch-based, training stops after one epoch; when using iter-based,
+        training stops after #iters_per_epoch iterations.
+        """
+        use_amp = scaler is not None
+
+        if not hasattr(data_loader, "__next__"):
+            # convert to iterator if not already
+            data_loader = iter(data_loader)
+
+        metric_logger = MetricLogger(delimiter="  ")
+        metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
+        metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))
+
+        # if iter-based runner, schedule lr based on inner epoch.
+        logging.info(
+            "Start training epoch {}, {} iters per inner epoch.".format(
+                epoch, iters_per_epoch
+            )
+        )
+        header = "Train: data epoch: [{}]".format(epoch)
+        if start_iters is None:
+            # epoch-based runner
+            inner_epoch = epoch
+        else:
+            # In iter-based runner, we schedule the learning rate based on iterations.
+            inner_epoch = start_iters // iters_per_epoch
+            header = header + "; inner epoch [{}]".format(inner_epoch)
+
+        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
+            # if using iter-based runner, we stop after iters_per_epoch iterations.
+            if i >= iters_per_epoch:
+                break
+
+            samples = next(data_loader)
+
+            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
+            samples.update(
+                {
+                    "epoch": inner_epoch,
+                    "num_iters_per_epoch": iters_per_epoch,
+                    "iters": i,
+                }
+            )
+
+            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)
+
+            with torch.cuda.amp.autocast(enabled=use_amp):
+                loss, loss_dict = self.train_step(model=model, samples=samples)
+                loss /= accum_grad_iters  # TODO: this scaling should not affect the loss_dict values used for logging
+
+            # after_train_step()
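+            # With AMP, scale the loss before backward to avoid fp16 gradient underflow;
+            # scaler.step() later unscales the gradients before the optimizer update.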
+            if use_amp:
+                scaler.scale(loss).backward()
+            else:
+                loss.backward()
+
+            # update gradients every accum_grad_iters iterations
+            if (i + 1) % accum_grad_iters == 0:
+                if use_amp:
+                    scaler.step(optimizer)
+                    scaler.update()
+                else:
+                    optimizer.step()
+                optimizer.zero_grad()
+
+            metric_logger.update(**loss_dict)
+            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+        # after train_epoch()
+        # gather the stats from all processes
+        metric_logger.synchronize_between_processes()
+        logging.info("Averaged stats: " + str(metric_logger.global_avg()))
+        return {
+            k: "{:.3f}".format(meter.global_avg)
+            for k, meter in metric_logger.meters.items()
+        }
+
+    @staticmethod
+    def save_result(result, result_dir, filename, remove_duplicate=""):
+        import json
+
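+        # Each rank writes its own partial result file; after a barrier, the main
+        # process merges them into a single json, optionally de-duplicating on the
+        # remove_duplicate key.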
+        result_file = os.path.join(
+            result_dir, "%s_rank%d.json" % (filename, get_rank())
+        )
+        final_result_file = os.path.join(result_dir, "%s.json" % filename)
+
+        json.dump(result, open(result_file, "w"))
+
+        if is_dist_avail_and_initialized():
+            dist.barrier()
+
+        if is_main_process():
+            logging.warning("rank %d starts merging results." % get_rank())
+            # combine results from all processes
+            result = []
+
+            for rank in range(get_world_size()):
+                result_file = os.path.join(
+                    result_dir, "%s_rank%d.json" % (filename, rank)
+                )
+                res = json.load(open(result_file, "r"))
+                result += res
+
+            if remove_duplicate:
+                result_new = []
+                id_list = []
+                for res in result:
+                    if res[remove_duplicate] not in id_list:
+                        id_list.append(res[remove_duplicate])
+                        result_new.append(res)
+                result = result_new
+
+            json.dump(result, open(final_result_file, "w"))
+            print("result file saved to %s" % final_result_file)
+
+        return final_result_file
diff --git a/unimernet/tasks/unimernet_train.py b/unimernet/tasks/unimernet_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b2b1b5da6ded9197edcea87a6a4ffded817eb5d
--- /dev/null
+++ b/unimernet/tasks/unimernet_train.py
@@ -0,0 +1,167 @@
+import torch
+import evaluate
+import random
+
+from unimernet.common.registry import registry
+from unimernet.tasks.base_task import BaseTask
+from unimernet.common.dist_utils import main_process
+import os.path as osp
+import json
+import numpy as np
+from torchtext.data import metrics
+from rapidfuzz.distance import Levenshtein
+
+
+@registry.register_task("unimernet_train")
+class UniMERNet_Train(BaseTask):
+
+    def __init__(self, temperature, do_sample, top_p, evaluate, report_metric=True, agg_metric="edit_distance"):
+        super(UniMERNet_Train, self).__init__()
+        self.temperature = temperature
+        self.do_sample = do_sample
+        self.top_p = top_p
+        self.evaluate = evaluate
+        self.agg_metric = agg_metric
+
+        self.report_metric = report_metric
+
+    @classmethod
+    def setup_task(cls, cfg):
+        run_cfg = cfg.run_cfg
+        generate_cfg = run_cfg.generate_cfg
+
+        temperature = generate_cfg.get("temperature", 0.2)
+        do_sample = generate_cfg.get("do_sample", False)
+        top_p = generate_cfg.get("top_p", 0.95)
+
+        evaluate = run_cfg.evaluate
+        report_metric = run_cfg.get("report_metric", True)
+        agg_metric = run_cfg.get("agg_metric", "edit_distance")
+
+        return cls(
+            temperature=temperature,
+            do_sample=do_sample,
+            top_p=top_p,
+            evaluate=evaluate,
+            report_metric=report_metric,
+            agg_metric=agg_metric,
+        )
+
+    def valid_step(self, model, samples):
+        results = []
+        image, text = samples["image"], samples["text_input"]
+        preds = model.generate(
+            samples,
+            temperature=self.temperature,
+            do_sample=self.do_sample,
+            top_p=self.top_p
+        )
+        pred_tokens = preds["pred_tokens"]
+        pred_strs = preds["pred_str"]
+        pred_ids = preds["pred_ids"]  # [b, n-1]
+
+        truth_inputs = model.tokenizer.tokenize(text)
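+        # Drop the leading start token so the ground-truth ids line up with pred_ids ([b, n-1]).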
+        truth_ids = truth_inputs["input_ids"][:, 1:]
+        truth_tokens = model.tokenizer.detokenize(truth_inputs["input_ids"])
+        truth_strs = model.tokenizer.token2str(truth_inputs["input_ids"])
+
+        ids = samples["id"]
+
+        for pred_token, pred_str, pred_id, truth_token, truth_str, truth_id, id_ in zip(
+                pred_tokens, pred_strs, pred_ids, truth_tokens, truth_strs, truth_ids, ids
+        ):
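+            # Pad the shorter id sequence with pad tokens, then compute token accuracy
+            # only over positions where at least one sequence has a non-pad token.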
+            pred_id = pred_id.tolist()
+            truth_id = truth_id.tolist()
+            shape_diff = len(pred_id) - len(truth_id)
+            if shape_diff < 0:
+                pred_id = pred_id + [model.tokenizer.pad_token_id] * (-shape_diff)
+            else:
+                truth_id = truth_id + [model.tokenizer.pad_token_id] * shape_diff
+            pred_id, truth_id = torch.LongTensor(pred_id), torch.LongTensor(truth_id)
+            mask = torch.logical_or(pred_id != model.tokenizer.pad_token_id, truth_id != model.tokenizer.pad_token_id)
+            tok_acc = (pred_id == truth_id)[mask].float().mean().item()
+
+            this_item = {
+                "pred_token": pred_token,
+                "pred_str": pred_str,
+                "truth_str": truth_str,
+                "truth_token": truth_token,
+                "token_acc": tok_acc,
+                "id": id_
+            }
+            results.append(this_item)
+        return results
+
+    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
+        eval_result_file = self.save_result(
+            result=val_result,
+            result_dir=registry.get_path("result_dir"),
+            filename="{}_epoch{}".format(split_name, epoch),
+            remove_duplicate="id",
+        )
+
+        if self.report_metric:
+            metrics = self._report_metrics(
+                eval_result_file=eval_result_file, split_name=split_name
+            )
+        else:
+            metrics = {"agg_metrics": 0.0}
+
+        return metrics
+
+    @main_process
+    def _report_metrics(self, eval_result_file, split_name):
+
+        with open(eval_result_file) as f:
+            results = json.load(f)
+
+        edit_dists = []
+        all_pred_tokens = []
+        all_truth_tokens = []
+        all_pred_strs = []
+        all_truth_strs = []
+        token_accs = []
+        for result in results:
+            pred_token, pred_str = result["pred_token"], result["pred_str"]
+            truth_token, truth_str = result["truth_token"], result["truth_str"]
+            tok_acc = result["token_acc"]
+
+            if len(truth_str) > 0:
+                norm_edit_dist = Levenshtein.normalized_distance(pred_str, truth_str)
+                edit_dists.append(norm_edit_dist)
+
+            all_pred_tokens.append(pred_token)
+            all_truth_tokens.append([truth_token])
+            all_pred_strs.append(pred_str)
+            all_truth_strs.append(truth_str)
+            token_accs.append(tok_acc)
+
+        # bleu_score = metrics.bleu_score(all_pred_tokens, all_truth_tokens)
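+        # BLEU is computed with the HuggingFace evaluate package; a random experiment_id
+        # keeps concurrent runs from clashing on the shared metric cache.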
+        bleu = evaluate.load("bleu", keep_in_memory=True, experiment_id=random.randint(1, int(1e8)))
+        bleu_results = bleu.compute(predictions=all_pred_strs, references=all_truth_strs)
+        bleu_score = bleu_results['bleu']
+
+        edit_distance = np.mean(edit_dists)
+        token_accuracy = np.mean(token_accs)
+        eval_ret = {"bleu": bleu_score, "edit_distance": edit_distance, "token_accuracy": token_accuracy}
+
+        log_stats = {split_name: {k: v for k, v in eval_ret.items()}}
+
+        with open(
+                osp.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
+        ) as f:
+            f.write(json.dumps(log_stats) + "\n")
+
+        coco_res = {k: v for k, v in eval_ret.items()}
+        # agg_metrics = sum([v for v in eval_ret.values()])
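+        # Map the chosen aggregate metric to a higher-is-better score in [0, 100];
+        # edit distance is inverted so a smaller distance yields a larger score.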
+        if "edit" in self.agg_metric.lower():  # edit_distance
+            agg_metrics = (1 - edit_distance) * 100
+        elif "bleu" in self.agg_metric.lower():  # bleu_score
+            agg_metrics = bleu_score * 100
+        elif "token" in self.agg_metric.lower():  # token_accuracy
+            agg_metrics = token_accuracy * 100
+        else:
+            raise ValueError(f"Invalid metrics: '{self.agg_metric}'")
+
+        coco_res["agg_metrics"] = agg_metrics
+
+        return coco_res