diff --git a/app.py b/app.py index e01fbaea20a7f01801a371b9a7c13a1b21613cd4..1594456b213dc4d9210124becd82b1a6c2afbedc 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,8 @@ import os import subprocess -subprocess.run("pip install salesforce-lavis --no-deps", shell=True) +#subprocess.run("pip install salesforce-lavis --no-deps", shell=True) +# https://github.com/salesforce/BLIP/issues/165 from PIL import Image import gradio as gr diff --git a/lavis/__init__.py b/lavis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ab17686f819c970015351238d2db77c8c09d5243 --- /dev/null +++ b/lavis/__init__.py @@ -0,0 +1,31 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import sys + +from omegaconf import OmegaConf + +from lavis.common.registry import registry + +from lavis.datasets.builders import * +from lavis.models import * +from lavis.processors import * +from lavis.tasks import * + + +root_dir = os.path.dirname(os.path.abspath(__file__)) +default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml")) + +registry.register_path("library_root", root_dir) +repo_root = os.path.join(root_dir, "..") +registry.register_path("repo_root", repo_root) +cache_root = os.path.join(repo_root, default_cfg.env.cache_root) +registry.register_path("cache_root", cache_root) + +registry.register("MAX_INT", sys.maxsize) +registry.register("SPLIT_NAMES", ["train", "val", "test"]) diff --git a/lavis/common/config.py b/lavis/common/config.py new file mode 100644 index 0000000000000000000000000000000000000000..2264b0578fd52b805f619a871ce5ff80c0310ccb --- /dev/null +++ b/lavis/common/config.py @@ -0,0 +1,468 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import json +from typing import Dict + +from omegaconf import OmegaConf +from lavis.common.registry import registry + + +class Config: + def __init__(self, args): + self.config = {} + + self.args = args + + # Register the config and configuration for setup + registry.register("configuration", self) + + user_config = self._build_opt_list(self.args.options) + + config = OmegaConf.load(self.args.cfg_path) + + runner_config = self.build_runner_config(config) + model_config = self.build_model_config(config, **user_config) + dataset_config = self.build_dataset_config(config) + + # Validate the user-provided runner configuration + # model and dataset configuration are supposed to be validated by the respective classes + # [TODO] validate the model/dataset configuration + # self._validate_runner_config(runner_config) + + # Override the default configuration with user options. + self.config = OmegaConf.merge( + runner_config, model_config, dataset_config, user_config + ) + + def _validate_runner_config(self, runner_config): + """ + This method validates the configuration, such that + 1) all the user specified options are valid; + 2) no type mismatches between the user specified options and the config. + """ + runner_config_validator = create_runner_config_validator() + runner_config_validator.validate(runner_config) + + def _build_opt_list(self, opts): + opts_dot_list = self._convert_to_dot_list(opts) + return OmegaConf.from_dotlist(opts_dot_list) + + @staticmethod + def build_model_config(config, **kwargs): + model = config.get("model", None) + assert model is not None, "Missing model configuration file." + + model_cls = registry.get_model_class(model.arch) + assert model_cls is not None, f"Model '{model.arch}' has not been registered." + + model_type = kwargs.get("model.model_type", None) + if not model_type: + model_type = model.get("model_type", None) + # else use the model type selected by user. + + assert model_type is not None, "Missing model_type." + + model_config_path = model_cls.default_config_path(model_type=model_type) + + model_config = OmegaConf.create() + # hiararchy override, customized config > default config + model_config = OmegaConf.merge( + model_config, + OmegaConf.load(model_config_path), + {"model": config["model"]}, + ) + + return model_config + + @staticmethod + def build_runner_config(config): + return {"run": config.run} + + @staticmethod + def build_dataset_config(config): + datasets = config.get("datasets", None) + if datasets is None: + raise KeyError( + "Expecting 'datasets' as the root key for dataset configuration." + ) + + dataset_config = OmegaConf.create() + + for dataset_name in datasets: + builder_cls = registry.get_builder_class(dataset_name) + + dataset_config_type = datasets[dataset_name].get("type", "default") + dataset_config_path = builder_cls.default_config_path( + type=dataset_config_type + ) + + # hiararchy override, customized config > default config + dataset_config = OmegaConf.merge( + dataset_config, + OmegaConf.load(dataset_config_path), + {"datasets": {dataset_name: config["datasets"][dataset_name]}}, + ) + + return dataset_config + + def _convert_to_dot_list(self, opts): + if opts is None: + opts = [] + + if len(opts) == 0: + return opts + + has_equal = opts[0].find("=") != -1 + + if has_equal: + return opts + + return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])] + + def get_config(self): + return self.config + + @property + def run_cfg(self): + return self.config.run + + @property + def datasets_cfg(self): + return self.config.datasets + + @property + def model_cfg(self): + return self.config.model + + def pretty_print(self): + logging.info("\n===== Running Parameters =====") + logging.info(self._convert_node_to_json(self.config.run)) + + logging.info("\n====== Dataset Attributes ======") + datasets = self.config.datasets + + for dataset in datasets: + if dataset in self.config.datasets: + logging.info(f"\n======== {dataset} =======") + dataset_config = self.config.datasets[dataset] + logging.info(self._convert_node_to_json(dataset_config)) + else: + logging.warning(f"No dataset named '{dataset}' in config. Skipping") + + logging.info(f"\n====== Model Attributes ======") + logging.info(self._convert_node_to_json(self.config.model)) + + def _convert_node_to_json(self, node): + container = OmegaConf.to_container(node, resolve=True) + return json.dumps(container, indent=4, sort_keys=True) + + def to_dict(self): + return OmegaConf.to_container(self.config) + + +def node_to_dict(node): + return OmegaConf.to_container(node) + + +class ConfigValidator: + """ + This is a preliminary implementation to centralize and validate the configuration. + May be altered in the future. + + A helper class to validate configurations from yaml file. + + This serves the following purposes: + 1. Ensure all the options in the yaml are defined, raise error if not. + 2. when type mismatches are found, the validator will raise an error. + 3. a central place to store and display helpful messages for supported configurations. + + """ + + class _Argument: + def __init__(self, name, choices=None, type=None, help=None): + self.name = name + self.val = None + self.choices = choices + self.type = type + self.help = help + + def __str__(self): + s = f"{self.name}={self.val}" + if self.type is not None: + s += f", ({self.type})" + if self.choices is not None: + s += f", choices: {self.choices}" + if self.help is not None: + s += f", ({self.help})" + return s + + def __init__(self, description): + self.description = description + + self.arguments = dict() + + self.parsed_args = None + + def __getitem__(self, key): + assert self.parsed_args is not None, "No arguments parsed yet." + + return self.parsed_args[key] + + def __str__(self) -> str: + return self.format_help() + + def add_argument(self, *args, **kwargs): + """ + Assume the first argument is the name of the argument. + """ + self.arguments[args[0]] = self._Argument(*args, **kwargs) + + def validate(self, config=None): + """ + Convert yaml config (dict-like) to list, required by argparse. + """ + for k, v in config.items(): + assert ( + k in self.arguments + ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}.""" + + if self.arguments[k].type is not None: + try: + self.arguments[k].val = self.arguments[k].type(v) + except ValueError: + raise ValueError(f"{k} is not a valid {self.arguments[k].type}.") + + if self.arguments[k].choices is not None: + assert ( + v in self.arguments[k].choices + ), f"""{k} must be one of {self.arguments[k].choices}.""" + + return config + + def format_arguments(self): + return str([f"{k}" for k in sorted(self.arguments.keys())]) + + def format_help(self): + # description + key-value pair string for each argument + help_msg = str(self.description) + return help_msg + ", available arguments: " + self.format_arguments() + + def print_help(self): + # display help message + print(self.format_help()) + + +def create_runner_config_validator(): + validator = ConfigValidator(description="Runner configurations") + + validator.add_argument( + "runner", + type=str, + choices=["runner_base", "runner_iter"], + help="""Runner to use. The "runner_base" uses epoch-based training while iter-based + runner runs based on iters. Default: runner_base""", + ) + # add argumetns for training dataset ratios + validator.add_argument( + "train_dataset_ratios", + type=Dict[str, float], + help="""Ratios of training dataset. This is used in iteration-based runner. + Do not support for epoch-based runner because how to define an epoch becomes tricky. + Default: None""", + ) + validator.add_argument( + "max_iters", + type=float, + help="Maximum number of iterations to run.", + ) + validator.add_argument( + "max_epoch", + type=int, + help="Maximum number of epochs to run.", + ) + # add arguments for iters_per_inner_epoch + validator.add_argument( + "iters_per_inner_epoch", + type=float, + help="Number of iterations per inner epoch. This is required when runner is runner_iter.", + ) + lr_scheds_choices = registry.list_lr_schedulers() + validator.add_argument( + "lr_sched", + type=str, + choices=lr_scheds_choices, + help="Learning rate scheduler to use, from {}".format(lr_scheds_choices), + ) + task_choices = registry.list_tasks() + validator.add_argument( + "task", + type=str, + choices=task_choices, + help="Task to use, from {}".format(task_choices), + ) + # add arguments for init_lr + validator.add_argument( + "init_lr", + type=float, + help="Initial learning rate. This will be the learning rate after warmup and before decay.", + ) + # add arguments for min_lr + validator.add_argument( + "min_lr", + type=float, + help="Minimum learning rate (after decay).", + ) + # add arguments for warmup_lr + validator.add_argument( + "warmup_lr", + type=float, + help="Starting learning rate for warmup.", + ) + # add arguments for learning rate decay rate + validator.add_argument( + "lr_decay_rate", + type=float, + help="Learning rate decay rate. Required if using a decaying learning rate scheduler.", + ) + # add arguments for weight decay + validator.add_argument( + "weight_decay", + type=float, + help="Weight decay rate.", + ) + # add arguments for training batch size + validator.add_argument( + "batch_size_train", + type=int, + help="Training batch size.", + ) + # add arguments for evaluation batch size + validator.add_argument( + "batch_size_eval", + type=int, + help="Evaluation batch size, including validation and testing.", + ) + # add arguments for number of workers for data loading + validator.add_argument( + "num_workers", + help="Number of workers for data loading.", + ) + # add arguments for warm up steps + validator.add_argument( + "warmup_steps", + type=int, + help="Number of warmup steps. Required if a warmup schedule is used.", + ) + # add arguments for random seed + validator.add_argument( + "seed", + type=int, + help="Random seed.", + ) + # add arguments for output directory + validator.add_argument( + "output_dir", + type=str, + help="Output directory to save checkpoints and logs.", + ) + # add arguments for whether only use evaluation + validator.add_argument( + "evaluate", + help="Whether to only evaluate the model. If true, training will not be performed.", + ) + # add arguments for splits used for training, e.g. ["train", "val"] + validator.add_argument( + "train_splits", + type=list, + help="Splits to use for training.", + ) + # add arguments for splits used for validation, e.g. ["val"] + validator.add_argument( + "valid_splits", + type=list, + help="Splits to use for validation. If not provided, will skip the validation.", + ) + # add arguments for splits used for testing, e.g. ["test"] + validator.add_argument( + "test_splits", + type=list, + help="Splits to use for testing. If not provided, will skip the testing.", + ) + # add arguments for accumulating gradient for iterations + validator.add_argument( + "accum_grad_iters", + type=int, + help="Number of iterations to accumulate gradient for.", + ) + + # ====== distributed training ====== + validator.add_argument( + "device", + type=str, + choices=["cpu", "cuda"], + help="Device to use. Support 'cuda' or 'cpu' as for now.", + ) + validator.add_argument( + "world_size", + type=int, + help="Number of processes participating in the job.", + ) + validator.add_argument("dist_url", type=str) + validator.add_argument("distributed", type=bool) + # add arguments to opt using distributed sampler during evaluation or not + validator.add_argument( + "use_dist_eval_sampler", + type=bool, + help="Whether to use distributed sampler during evaluation or not.", + ) + + # ====== task specific ====== + # generation task specific arguments + # add arguments for maximal length of text output + validator.add_argument( + "max_len", + type=int, + help="Maximal length of text output.", + ) + # add arguments for minimal length of text output + validator.add_argument( + "min_len", + type=int, + help="Minimal length of text output.", + ) + # add arguments number of beams + validator.add_argument( + "num_beams", + type=int, + help="Number of beams used for beam search.", + ) + + # vqa task specific arguments + # add arguments for number of answer candidates + validator.add_argument( + "num_ans_candidates", + type=int, + help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""", + ) + # add arguments for inference method + validator.add_argument( + "inference_method", + type=str, + choices=["genearte", "rank"], + help="""Inference method to use for question answering. If rank, requires a answer list.""", + ) + + # ====== model specific ====== + validator.add_argument( + "k_test", + type=int, + help="Number of top k most similar samples from ITC/VTC selection to be tested.", + ) + + return validator diff --git a/lavis/common/dist_utils.py b/lavis/common/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..296a3c86f29c6e82fa8f1108c7dd9fa7d3e9ce45 --- /dev/null +++ b/lavis/common/dist_utils.py @@ -0,0 +1,137 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import functools +import os + +import torch +import torch.distributed as dist +import timm.models.hub as timm_hub + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) + args.gpu = args.rank % torch.cuda.device_count() + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + "| distributed init (rank {}, world {}): {}".format( + args.rank, args.world_size, args.dist_url + ), + flush=True, + ) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + timeout=datetime.timedelta( + days=365 + ), # allow auto-downloading and de-compressing + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def get_dist_info(): + if torch.__version__ < "1.0": + initialized = dist._initialized + else: + initialized = dist.is_initialized() + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: # non-distributed training + rank = 0 + world_size = 1 + return rank, world_size + + +def main_process(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper + + +def download_cached_file(url, check_hash=True, progress=False): + """ + Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again. + If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded. + """ + + def get_cached_file_path(): + # a hack to sync the file path across processes + parts = torch.hub.urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(timm_hub.get_cache_dir(), filename) + + return cached_file + + if is_main_process(): + timm_hub.download_cached_file(url, check_hash, progress) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return get_cached_file_path() diff --git a/lavis/common/gradcam.py b/lavis/common/gradcam.py new file mode 100644 index 0000000000000000000000000000000000000000..d53a5254d4b319eaf2cbfbd081b0ca8e38c5c7a0 --- /dev/null +++ b/lavis/common/gradcam.py @@ -0,0 +1,24 @@ +import numpy as np +from matplotlib import pyplot as plt +from scipy.ndimage import filters +from skimage import transform as skimage_transform + + +def getAttMap(img, attMap, blur=True, overlap=True): + attMap -= attMap.min() + if attMap.max() > 0: + attMap /= attMap.max() + attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant") + if blur: + attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2])) + attMap -= attMap.min() + attMap /= attMap.max() + cmap = plt.get_cmap("jet") + attMapV = cmap(attMap) + attMapV = np.delete(attMapV, 3, 2) + if overlap: + attMap = ( + 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img + + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV + ) + return attMap diff --git a/lavis/common/logger.py b/lavis/common/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1ea0d096db7b4914c6ba2c031c06a40fd793f3 --- /dev/null +++ b/lavis/common/logger.py @@ -0,0 +1,195 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import logging +import time +from collections import defaultdict, deque + +import torch +import torch.distributed as dist + +from lavis.common import dist_utils + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not dist_utils.is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + "'{}' object has no attribute '{}'".format(type(self).__name__, attr) + ) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def global_avg(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {:.4f}".format(name, meter.global_avg)) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + log_msg = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_msg.append("max mem: {memory:.0f}") + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print( + "{} Total time: {} ({:.4f} s / it)".format( + header, total_time_str, total_time / len(iterable) + ) + ) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def setup_logger(): + logging.basicConfig( + level=logging.INFO if dist_utils.is_main_process() else logging.WARN, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], + ) diff --git a/lavis/common/optims.py b/lavis/common/optims.py new file mode 100644 index 0000000000000000000000000000000000000000..28b68645d509eedbafb47e8efa213ba6e2cc63f3 --- /dev/null +++ b/lavis/common/optims.py @@ -0,0 +1,117 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import math + +from lavis.common.registry import registry + + +@registry.register_lr_scheduler("linear_warmup_step_lr") +class LinearWarmupStepLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + min_lr, + init_lr, + decay_rate=1, + warmup_start_lr=-1, + warmup_steps=0, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.min_lr = min_lr + + self.decay_rate = decay_rate + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + if cur_epoch == 0: + warmup_lr_schedule( + step=cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + step_lr_schedule( + epoch=cur_epoch, + optimizer=self.optimizer, + init_lr=self.init_lr, + min_lr=self.min_lr, + decay_rate=self.decay_rate, + ) + + +@registry.register_lr_scheduler("linear_warmup_cosine_lr") +class LinearWarmupCosineLRScheduler: + def __init__( + self, + optimizer, + max_epoch, + min_lr, + init_lr, + warmup_steps=0, + warmup_start_lr=-1, + **kwargs + ): + self.optimizer = optimizer + + self.max_epoch = max_epoch + self.min_lr = min_lr + + self.init_lr = init_lr + self.warmup_steps = warmup_steps + self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr + + def step(self, cur_epoch, cur_step): + # assuming the warmup iters less than one epoch + if cur_epoch == 0: + warmup_lr_schedule( + step=cur_step, + optimizer=self.optimizer, + max_step=self.warmup_steps, + init_lr=self.warmup_start_lr, + max_lr=self.init_lr, + ) + else: + cosine_lr_schedule( + epoch=cur_epoch, + optimizer=self.optimizer, + max_epoch=self.max_epoch, + init_lr=self.init_lr, + min_lr=self.min_lr, + ) + + +def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): + """Decay the learning rate""" + lr = (init_lr - min_lr) * 0.5 * ( + 1.0 + math.cos(math.pi * epoch / max_epoch) + ) + min_lr + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): + """Warmup the learning rate""" + lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): + """Decay the learning rate""" + lr = max(min_lr, init_lr * (decay_rate**epoch)) + for param_group in optimizer.param_groups: + param_group["lr"] = lr diff --git a/lavis/common/registry.py b/lavis/common/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..9039d8aaa580f19cc0d43ed9330bd90055045867 --- /dev/null +++ b/lavis/common/registry.py @@ -0,0 +1,329 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + + +class Registry: + mapping = { + "builder_name_mapping": {}, + "task_name_mapping": {}, + "processor_name_mapping": {}, + "model_name_mapping": {}, + "lr_scheduler_name_mapping": {}, + "runner_name_mapping": {}, + "state": {}, + "paths": {}, + } + + @classmethod + def register_builder(cls, name): + r"""Register a dataset builder to registry with key 'name' + + Args: + name: Key with which the builder will be registered. + + Usage: + + from lavis.common.registry import registry + from lavis.datasets.base_dataset_builder import BaseDatasetBuilder + """ + + def wrap(builder_cls): + from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder + + assert issubclass( + builder_cls, BaseDatasetBuilder + ), "All builders must inherit BaseDatasetBuilder class, found {}".format( + builder_cls + ) + if name in cls.mapping["builder_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["builder_name_mapping"][name] + ) + ) + cls.mapping["builder_name_mapping"][name] = builder_cls + return builder_cls + + return wrap + + @classmethod + def register_task(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from lavis.common.registry import registry + """ + + def wrap(task_cls): + from lavis.tasks.base_task import BaseTask + + assert issubclass( + task_cls, BaseTask + ), "All tasks must inherit BaseTask class" + if name in cls.mapping["task_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["task_name_mapping"][name] + ) + ) + cls.mapping["task_name_mapping"][name] = task_cls + return task_cls + + return wrap + + @classmethod + def register_model(cls, name): + r"""Register a task to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from lavis.common.registry import registry + """ + + def wrap(model_cls): + from lavis.models import BaseModel + + assert issubclass( + model_cls, BaseModel + ), "All models must inherit BaseModel class" + if name in cls.mapping["model_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["model_name_mapping"][name] + ) + ) + cls.mapping["model_name_mapping"][name] = model_cls + return model_cls + + return wrap + + @classmethod + def register_processor(cls, name): + r"""Register a processor to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from lavis.common.registry import registry + """ + + def wrap(processor_cls): + from lavis.processors import BaseProcessor + + assert issubclass( + processor_cls, BaseProcessor + ), "All processors must inherit BaseProcessor class" + if name in cls.mapping["processor_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["processor_name_mapping"][name] + ) + ) + cls.mapping["processor_name_mapping"][name] = processor_cls + return processor_cls + + return wrap + + @classmethod + def register_lr_scheduler(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from lavis.common.registry import registry + """ + + def wrap(lr_sched_cls): + if name in cls.mapping["lr_scheduler_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["lr_scheduler_name_mapping"][name] + ) + ) + cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls + return lr_sched_cls + + return wrap + + @classmethod + def register_runner(cls, name): + r"""Register a model to registry with key 'name' + + Args: + name: Key with which the task will be registered. + + Usage: + + from lavis.common.registry import registry + """ + + def wrap(runner_cls): + if name in cls.mapping["runner_name_mapping"]: + raise KeyError( + "Name '{}' already registered for {}.".format( + name, cls.mapping["runner_name_mapping"][name] + ) + ) + cls.mapping["runner_name_mapping"][name] = runner_cls + return runner_cls + + return wrap + + @classmethod + def register_path(cls, name, path): + r"""Register a path to registry with key 'name' + + Args: + name: Key with which the path will be registered. + + Usage: + + from lavis.common.registry import registry + """ + assert isinstance(path, str), "All path must be str." + if name in cls.mapping["paths"]: + raise KeyError("Name '{}' already registered.".format(name)) + cls.mapping["paths"][name] = path + + @classmethod + def register(cls, name, obj): + r"""Register an item to registry with key 'name' + + Args: + name: Key with which the item will be registered. + + Usage:: + + from lavis.common.registry import registry + + registry.register("config", {}) + """ + path = name.split(".") + current = cls.mapping["state"] + + for part in path[:-1]: + if part not in current: + current[part] = {} + current = current[part] + + current[path[-1]] = obj + + # @classmethod + # def get_trainer_class(cls, name): + # return cls.mapping["trainer_name_mapping"].get(name, None) + + @classmethod + def get_builder_class(cls, name): + return cls.mapping["builder_name_mapping"].get(name, None) + + @classmethod + def get_model_class(cls, name): + return cls.mapping["model_name_mapping"].get(name, None) + + @classmethod + def get_task_class(cls, name): + return cls.mapping["task_name_mapping"].get(name, None) + + @classmethod + def get_processor_class(cls, name): + return cls.mapping["processor_name_mapping"].get(name, None) + + @classmethod + def get_lr_scheduler_class(cls, name): + return cls.mapping["lr_scheduler_name_mapping"].get(name, None) + + @classmethod + def get_runner_class(cls, name): + return cls.mapping["runner_name_mapping"].get(name, None) + + @classmethod + def list_runners(cls): + return sorted(cls.mapping["runner_name_mapping"].keys()) + + @classmethod + def list_models(cls): + return sorted(cls.mapping["model_name_mapping"].keys()) + + @classmethod + def list_tasks(cls): + return sorted(cls.mapping["task_name_mapping"].keys()) + + @classmethod + def list_processors(cls): + return sorted(cls.mapping["processor_name_mapping"].keys()) + + @classmethod + def list_lr_schedulers(cls): + return sorted(cls.mapping["lr_scheduler_name_mapping"].keys()) + + @classmethod + def list_datasets(cls): + return sorted(cls.mapping["builder_name_mapping"].keys()) + + @classmethod + def get_path(cls, name): + return cls.mapping["paths"].get(name, None) + + @classmethod + def get(cls, name, default=None, no_warning=False): + r"""Get an item from registry with key 'name' + + Args: + name (string): Key whose value needs to be retrieved. + default: If passed and key is not in registry, default value will + be returned with a warning. Default: None + no_warning (bool): If passed as True, warning when key doesn't exist + will not be generated. Useful for MMF's + internal operations. Default: False + """ + original_name = name + name = name.split(".") + value = cls.mapping["state"] + for subname in name: + value = value.get(subname, default) + if value is default: + break + + if ( + "writer" in cls.mapping["state"] + and value == default + and no_warning is False + ): + cls.mapping["state"]["writer"].warning( + "Key {} is not present in registry, returning default value " + "of {}".format(original_name, default) + ) + return value + + @classmethod + def unregister(cls, name): + r"""Remove an item from registry with key 'name' + + Args: + name: Key which needs to be removed. + Usage:: + + from mmf.common.registry import registry + + config = registry.unregister("config") + """ + return cls.mapping["state"].pop(name, None) + + +registry = Registry() diff --git a/lavis/common/utils.py b/lavis/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..29b35e5159634f9b114f2ddd608da1db67893840 --- /dev/null +++ b/lavis/common/utils.py @@ -0,0 +1,424 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import io +import json +import logging +import os +import pickle +import re +import shutil +import urllib +import urllib.error +import urllib.request +from typing import Optional +from urllib.parse import urlparse + +import numpy as np +import pandas as pd +import yaml +from iopath.common.download import download +from iopath.common.file_io import file_lock, g_pathmgr +from lavis.common.registry import registry +from torch.utils.model_zoo import tqdm +from torchvision.datasets.utils import ( + check_integrity, + download_file_from_google_drive, + extract_archive, +) + + +def now(): + from datetime import datetime + + return datetime.now().strftime("%Y%m%d%H%M")[:-1] + + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + + +def get_cache_path(rel_path): + return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path)) + + +def get_abs_path(rel_path): + return os.path.join(registry.get_path("library_root"), rel_path) + + +def load_json(filename): + with open(filename, "r") as f: + return json.load(f) + + +# The following are adapted from torchvision and vissl +# torchvision: https://github.com/pytorch/vision +# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + print(f"Error creating directory: {dir_path}") + return is_success + + +def get_redirected_url(url: str): + """ + Given a URL, returns the URL it redirects to or the + original URL in case of no indirection + """ + import requests + + with requests.Session() as session: + with session.get(url, stream=True, allow_redirects=True) as response: + if response.history: + return response.url + else: + return url + + +def to_google_drive_download_url(view_url: str) -> str: + """ + Utility function to transform a view URL of google drive + to a download URL for google drive + Example input: + https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view + Example output: + https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp + """ + splits = view_url.split("/") + assert splits[-1] == "view" + file_id = splits[-2] + return f"https://drive.google.com/uc?export=download&id={file_id}" + + +def download_google_drive_url(url: str, output_path: str, output_file_name: str): + """ + Download a file from google drive + Downloading an URL from google drive requires confirmation when + the file of the size is too big (google drive notifies that + anti-viral checks cannot be performed on such files) + """ + import requests + + with requests.Session() as session: + + # First get the confirmation token and append it to the URL + with session.get(url, stream=True, allow_redirects=True) as response: + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + url = url + "&confirm=" + v + + # Then download the content of the file + with session.get(url, stream=True, verify=True) as response: + makedir(output_path) + path = os.path.join(output_path, output_file_name) + total_size = int(response.headers.get("Content-length", 0)) + with open(path, "wb") as file: + from tqdm import tqdm + + with tqdm(total=total_size) as progress_bar: + for block in response.iter_content( + chunk_size=io.DEFAULT_BUFFER_SIZE + ): + file.write(block) + progress_bar.update(len(block)) + + +def _get_google_drive_file_id(url: str) -> Optional[str]: + parts = urlparse(url) + + if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: + return None + + match = re.match(r"/file/d/(?P[^/]*)", parts.path) + if match is None: + return None + + return match.group("id") + + +def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen( + urllib.request.Request(url, headers={"User-Agent": "vissl"}) + ) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), ""): + if not chunk: + break + pbar.update(chunk_size) + fh.write(chunk) + + +def download_url( + url: str, + root: str, + filename: Optional[str] = None, + md5: Optional[str] = None, +) -> None: + """Download a file from a url and place it in root. + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. + If None, use the basename of the URL. + md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + makedir(root) + + # check if file is already present locally + if check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + return + + # expand redirect chain if needed + url = get_redirected_url(url) + + # check if file is located on Google Drive + file_id = _get_google_drive_file_id(url) + if file_id is not None: + return download_file_from_google_drive(file_id, root, filename, md5) + + # download the file + try: + print("Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] + if url[:5] == "https": + url = url.replace("https:", "http:") + print( + "Failed download. Trying https -> http instead." + " Downloading " + url + " to " + fpath + ) + _urlretrieve(url, fpath) + else: + raise e + + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def download_and_extract_archive( + url: str, + download_root: str, + extract_root: Optional[str] = None, + filename: Optional[str] = None, + md5: Optional[str] = None, + remove_finished: bool = False, +) -> None: + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print("Extracting {} to {}".format(archive, extract_root)) + extract_archive(archive, extract_root, remove_finished) + + +def cache_url(url: str, cache_dir: str) -> str: + """ + This implementation downloads the remote resource and caches it locally. + The resource will only be downloaded if not previously requested. + """ + parsed_url = urlparse(url) + dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/"))) + makedir(dirname) + filename = url.split("/")[-1] + cached = os.path.join(dirname, filename) + with file_lock(cached): + if not os.path.isfile(cached): + logging.info(f"Downloading {url} to {cached} ...") + cached = download(url, dirname, filename=filename) + logging.info(f"URL {url} cached in {cached}") + return cached + + +# TODO (prigoyal): convert this into RAII-style API +def create_file_symlink(file1, file2): + """ + Simply create the symlinks for a given file1 to file2. + Useful during model checkpointing to symlinks to the + latest successful checkpoint. + """ + try: + if g_pathmgr.exists(file2): + g_pathmgr.rm(file2) + g_pathmgr.symlink(file1, file2) + except Exception as e: + logging.info(f"Could NOT create symlink. Error: {e}") + + +def save_file(data, filename, append_to_json=True, verbose=True): + """ + Common i/o utility to handle saving data to various file formats. + Supported: + .pkl, .pickle, .npy, .json + Specifically for .json, users have the option to either append (default) + or rewrite by passing in Boolean value to append_to_json. + """ + if verbose: + logging.info(f"Saving data to file: {filename}") + file_ext = os.path.splitext(filename)[1] + if file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "wb") as fopen: + pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL) + elif file_ext == ".npy": + with g_pathmgr.open(filename, "wb") as fopen: + np.save(fopen, data) + elif file_ext == ".json": + if append_to_json: + with g_pathmgr.open(filename, "a") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + else: + with g_pathmgr.open(filename, "w") as fopen: + fopen.write(json.dumps(data, sort_keys=True) + "\n") + fopen.flush() + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "w") as fopen: + dump = yaml.dump(data) + fopen.write(dump) + fopen.flush() + else: + raise Exception(f"Saving {file_ext} is not supported yet") + + if verbose: + logging.info(f"Saved data to file: {filename}") + + +def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False): + """ + Common i/o utility to handle loading data from various file formats. + Supported: + .pkl, .pickle, .npy, .json + For the npy files, we support reading the files in mmap_mode. + If the mmap_mode of reading is not successful, we load data without the + mmap_mode. + """ + if verbose: + logging.info(f"Loading data from file: {filename}") + + file_ext = os.path.splitext(filename)[1] + if file_ext == ".txt": + with g_pathmgr.open(filename, "r") as fopen: + data = fopen.readlines() + elif file_ext in [".pkl", ".pickle"]: + with g_pathmgr.open(filename, "rb") as fopen: + data = pickle.load(fopen, encoding="latin1") + elif file_ext == ".npy": + if mmap_mode: + try: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load( + fopen, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + except ValueError as e: + logging.info( + f"Could not mmap {filename}: {e}. Trying without g_pathmgr" + ) + data = np.load( + filename, + allow_pickle=allow_pickle, + encoding="latin1", + mmap_mode=mmap_mode, + ) + logging.info("Successfully loaded without g_pathmgr") + except Exception: + logging.info("Could not mmap without g_pathmgr. Trying without mmap") + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + else: + with g_pathmgr.open(filename, "rb") as fopen: + data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1") + elif file_ext == ".json": + with g_pathmgr.open(filename, "r") as fopen: + data = json.load(fopen) + elif file_ext == ".yaml": + with g_pathmgr.open(filename, "r") as fopen: + data = yaml.load(fopen, Loader=yaml.FullLoader) + elif file_ext == ".csv": + with g_pathmgr.open(filename, "r") as fopen: + data = pd.read_csv(fopen) + else: + raise Exception(f"Reading from {file_ext} is not supported yet") + return data + + +def abspath(resource_path: str): + """ + Make a path absolute, but take into account prefixes like + "http://" or "manifold://" + """ + regex = re.compile(r"^\w+://") + if regex.match(resource_path) is None: + return os.path.abspath(resource_path) + else: + return resource_path + + +def makedir(dir_path): + """ + Create the directory if it does not exist. + """ + is_success = False + try: + if not g_pathmgr.exists(dir_path): + g_pathmgr.mkdirs(dir_path) + is_success = True + except BaseException: + logging.info(f"Error creating directory: {dir_path}") + return is_success + + +def is_url(input_url): + """ + Check if an input string is a url. look for http(s):// and ignoring the case + """ + is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None + return is_url + + +def cleanup_dir(dir): + """ + Utility for deleting a directory. Useful for cleaning the storage space + that contains various training artifacts like checkpoints, data etc. + """ + if os.path.exists(dir): + logging.info(f"Deleting directory: {dir}") + shutil.rmtree(dir) + logging.info(f"Deleted contents of directory: {dir}") + + +def get_file_size(filename): + """ + Given a file, get the size of file in MB + """ + size_in_mb = os.path.getsize(filename) / float(1024**2) + return size_in_mb diff --git a/lavis/common/vqa_tools/__init__.py b/lavis/common/vqa_tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b98da85428159ad0dcfab7685c080848ecf8c7b --- /dev/null +++ b/lavis/common/vqa_tools/__init__.py @@ -0,0 +1,8 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +__author__ = "aagrawal" diff --git a/lavis/common/vqa_tools/vqa.py b/lavis/common/vqa_tools/vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..a386b9094b0528b33e7511aff4027f30459a7ff7 --- /dev/null +++ b/lavis/common/vqa_tools/vqa.py @@ -0,0 +1,211 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +__author__ = "aagrawal" +__version__ = "0.9" + +# Interface for accessing the VQA dataset. + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). + +# The following functions are defined: +# VQA - VQA class that loads VQA annotation file and prepares data structures. +# getQuesIds - Get question ids that satisfy given filter conditions. +# getImgIds - Get image ids that satisfy given filter conditions. +# loadQA - Load questions and answers with the specified question ids. +# showQA - Display the specified questions and answers. +# loadRes - Load result file and create result object. + +# Help on each function can be accessed by: "help(COCO.function)" + +import json +import datetime +import copy + + +class VQA: + def __init__(self, annotation_file=None, question_file=None): + """ + Constructor of VQA helper class for reading and visualizing questions and answers. + :param annotation_file (str): location of VQA annotation file + :return: + """ + # load dataset + self.dataset = {} + self.questions = {} + self.qa = {} + self.qqa = {} + self.imgToQA = {} + if not annotation_file == None and not question_file == None: + print("loading VQA annotations and questions into memory...") + time_t = datetime.datetime.utcnow() + dataset = json.load(open(annotation_file, "r")) + questions = json.load(open(question_file, "r")) + self.dataset = dataset + self.questions = questions + self.createIndex() + + def createIndex(self): + # create index + print("creating index...") + imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]} + qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} + qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} + for ann in self.dataset["annotations"]: + imgToQA[ann["image_id"]] += [ann] + qa[ann["question_id"]] = ann + for ques in self.questions["questions"]: + qqa[ques["question_id"]] = ques + print("index created!") + + # create class members + self.qa = qa + self.qqa = qqa + self.imgToQA = imgToQA + + def info(self): + """ + Print information about the VQA annotation file. + :return: + """ + for key, value in self.datset["info"].items(): + print("%s: %s" % (key, value)) + + def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): + """ + Get question ids that satisfy given filter conditions. default skips that filter + :param imgIds (int array) : get question ids for given imgs + quesTypes (str array) : get question ids for given question types + ansTypes (str array) : get question ids for given answer types + :return: ids (int array) : integer array of question ids + """ + imgIds = imgIds if type(imgIds) == list else [imgIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset["annotations"] + else: + if not len(imgIds) == 0: + anns = sum( + [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], + [], + ) + else: + anns = self.dataset["annotations"] + anns = ( + anns + if len(quesTypes) == 0 + else [ann for ann in anns if ann["question_type"] in quesTypes] + ) + anns = ( + anns + if len(ansTypes) == 0 + else [ann for ann in anns if ann["answer_type"] in ansTypes] + ) + ids = [ann["question_id"] for ann in anns] + return ids + + def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): + """ + Get image ids that satisfy given filter conditions. default skips that filter + :param quesIds (int array) : get image ids for given question ids + quesTypes (str array) : get image ids for given question types + ansTypes (str array) : get image ids for given answer types + :return: ids (int array) : integer array of image ids + """ + quesIds = quesIds if type(quesIds) == list else [quesIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset["annotations"] + else: + if not len(quesIds) == 0: + anns = sum( + [self.qa[quesId] for quesId in quesIds if quesId in self.qa], [] + ) + else: + anns = self.dataset["annotations"] + anns = ( + anns + if len(quesTypes) == 0 + else [ann for ann in anns if ann["question_type"] in quesTypes] + ) + anns = ( + anns + if len(ansTypes) == 0 + else [ann for ann in anns if ann["answer_type"] in ansTypes] + ) + ids = [ann["image_id"] for ann in anns] + return ids + + def loadQA(self, ids=[]): + """ + Load questions and answers with the specified question ids. + :param ids (int array) : integer ids specifying question ids + :return: qa (object array) : loaded qa objects + """ + if type(ids) == list: + return [self.qa[id] for id in ids] + elif type(ids) == int: + return [self.qa[ids]] + + def showQA(self, anns): + """ + Display the specified annotations. + :param anns (array of object): annotations to display + :return: None + """ + if len(anns) == 0: + return 0 + for ann in anns: + quesId = ann["question_id"] + print("Question: %s" % (self.qqa[quesId]["question"])) + for ans in ann["answers"]: + print("Answer %d: %s" % (ans["answer_id"], ans["answer"])) + + def loadRes(self, resFile, quesFile): + """ + Load result file and return a result object. + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = VQA() + res.questions = json.load(open(quesFile)) + res.dataset["info"] = copy.deepcopy(self.questions["info"]) + res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"]) + res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"]) + res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"]) + res.dataset["license"] = copy.deepcopy(self.questions["license"]) + + print("Loading and preparing results... ") + time_t = datetime.datetime.utcnow() + anns = json.load(open(resFile)) + assert type(anns) == list, "results is not an array of objects" + annsQuesIds = [ann["question_id"] for ann in anns] + assert set(annsQuesIds) == set( + self.getQuesIds() + ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file." + for ann in anns: + quesId = ann["question_id"] + if res.dataset["task_type"] == "Multiple Choice": + assert ( + ann["answer"] in self.qqa[quesId]["multiple_choices"] + ), "predicted answer is not one of the multiple choices" + qaAnn = self.qa[quesId] + ann["image_id"] = qaAnn["image_id"] + ann["question_type"] = qaAnn["question_type"] + ann["answer_type"] = qaAnn["answer_type"] + print( + "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds()) + ) + + res.dataset["annotations"] = anns + res.createIndex() + return res diff --git a/lavis/common/vqa_tools/vqa_eval.py b/lavis/common/vqa_tools/vqa_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..ee808b349bb6166c744338b02af2bc84a68650ff --- /dev/null +++ b/lavis/common/vqa_tools/vqa_eval.py @@ -0,0 +1,324 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +# coding=utf-8 + +__author__ = "aagrawal" + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). +import sys +import re + + +class VQAEval: + def __init__(self, vqa=None, vqaRes=None, n=2): + self.n = n + self.accuracy = {} + self.evalQA = {} + self.evalQuesType = {} + self.evalAnsType = {} + self.vqa = vqa + self.vqaRes = vqaRes + if vqa is not None: + self.params = {"question_id": vqa.getQuesIds()} + self.contractions = { + "aint": "ain't", + "arent": "aren't", + "cant": "can't", + "couldve": "could've", + "couldnt": "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + "didnt": "didn't", + "doesnt": "doesn't", + "dont": "don't", + "hadnt": "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + "hasnt": "hasn't", + "havent": "haven't", + "hed": "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + "hes": "he's", + "howd": "how'd", + "howll": "how'll", + "hows": "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + "Im": "I'm", + "Ive": "I've", + "isnt": "isn't", + "itd": "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + "itll": "it'll", + "let's": "let's", + "maam": "ma'am", + "mightnt": "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + "mightve": "might've", + "mustnt": "mustn't", + "mustve": "must've", + "neednt": "needn't", + "notve": "not've", + "oclock": "o'clock", + "oughtnt": "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + "shant": "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + "shouldve": "should've", + "shouldnt": "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": "somebodyd", + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + "somebodyll": "somebody'll", + "somebodys": "somebody's", + "someoned": "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + "someonell": "someone'll", + "someones": "someone's", + "somethingd": "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + "somethingll": "something'll", + "thats": "that's", + "thered": "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + "therere": "there're", + "theres": "there's", + "theyd": "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + "theyll": "they'll", + "theyre": "they're", + "theyve": "they've", + "twas": "'twas", + "wasnt": "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + "weve": "we've", + "werent": "weren't", + "whatll": "what'll", + "whatre": "what're", + "whats": "what's", + "whatve": "what've", + "whens": "when's", + "whered": "where'd", + "wheres": "where's", + "whereve": "where've", + "whod": "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + "wholl": "who'll", + "whos": "who's", + "whove": "who've", + "whyll": "why'll", + "whyre": "why're", + "whys": "why's", + "wont": "won't", + "wouldve": "would've", + "wouldnt": "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + "yall": "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + "youd": "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + "youll": "you'll", + "youre": "you're", + "youve": "you've", + } + self.manualMap = { + "none": "0", + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + self.articles = ["a", "an", "the"] + + self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") + self.commaStrip = re.compile("(\d)(,)(\d)") + self.punct = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", + ] + + def evaluate(self, quesIds=None): + if quesIds == None: + quesIds = [quesId for quesId in self.params["question_id"]] + gts = {} + res = {} + for quesId in quesIds: + gts[quesId] = self.vqa.qa[quesId] + res[quesId] = self.vqaRes.qa[quesId] + + # ================================================= + # Compute accuracy + # ================================================= + accQA = [] + accQuesType = {} + accAnsType = {} + print("computing accuracy") + step = 0 + for quesId in quesIds: + resAns = res[quesId]["answer"] + resAns = resAns.replace("\n", " ") + resAns = resAns.replace("\t", " ") + resAns = resAns.strip() + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + gtAcc = [] + gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]] + if len(set(gtAnswers)) > 1: + for ansDic in gts[quesId]["answers"]: + ansDic["answer"] = self.processPunctuation(ansDic["answer"]) + for gtAnsDatum in gts[quesId]["answers"]: + otherGTAns = [ + item for item in gts[quesId]["answers"] if item != gtAnsDatum + ] + matchingAns = [item for item in otherGTAns if item["answer"] == resAns] + acc = min(1, float(len(matchingAns)) / 3) + gtAcc.append(acc) + quesType = gts[quesId]["question_type"] + ansType = gts[quesId]["answer_type"] + avgGTAcc = float(sum(gtAcc)) / len(gtAcc) + accQA.append(avgGTAcc) + if quesType not in accQuesType: + accQuesType[quesType] = [] + accQuesType[quesType].append(avgGTAcc) + if ansType not in accAnsType: + accAnsType[ansType] = [] + accAnsType[ansType].append(avgGTAcc) + self.setEvalQA(quesId, avgGTAcc) + self.setEvalQuesType(quesId, quesType, avgGTAcc) + self.setEvalAnsType(quesId, ansType, avgGTAcc) + if step % 100 == 0: + self.updateProgress(step / float(len(quesIds))) + step = step + 1 + + self.setAccuracy(accQA, accQuesType, accAnsType) + print("Done computing accuracy") + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + " " in inText or " " + p in inText) or ( + re.search(self.commaStrip, inText) != None + ): + outText = outText.replace(p, "") + else: + outText = outText.replace(p, " ") + outText = self.periodStrip.sub("", outText, re.UNICODE) + return outText + + def processDigitArticle(self, inText): + outText = [] + tempText = inText.lower().split() + for word in tempText: + word = self.manualMap.setdefault(word, word) + if word not in self.articles: + outText.append(word) + else: + pass + for wordId, word in enumerate(outText): + if word in self.contractions: + outText[wordId] = self.contractions[word] + outText = " ".join(outText) + return outText + + def setAccuracy(self, accQA, accQuesType, accAnsType): + self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n) + self.accuracy["perQuestionType"] = { + quesType: round( + 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]), + self.n, + ) + for quesType in accQuesType + } + self.accuracy["perAnswerType"] = { + ansType: round( + 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n + ) + for ansType in accAnsType + } + + def setEvalQA(self, quesId, acc): + self.evalQA[quesId] = round(100 * acc, self.n) + + def setEvalQuesType(self, quesId, quesType, acc): + if quesType not in self.evalQuesType: + self.evalQuesType[quesType] = {} + self.evalQuesType[quesType][quesId] = round(100 * acc, self.n) + + def setEvalAnsType(self, quesId, ansType, acc): + if ansType not in self.evalAnsType: + self.evalAnsType[ansType] = {} + self.evalAnsType[ansType][quesId] = round(100 * acc, self.n) + + def updateProgress(self, progress): + barLength = 20 + status = "" + if isinstance(progress, int): + progress = float(progress) + if not isinstance(progress, float): + progress = 0 + status = "error: progress var must be float\r\n" + if progress < 0: + progress = 0 + status = "Halt...\r\n" + if progress >= 1: + progress = 1 + status = "Done...\r\n" + block = int(round(barLength * progress)) + text = "\rFinshed Percent: [{0}] {1}% {2}".format( + "#" * block + "-" * (barLength - block), int(progress * 100), status + ) + sys.stdout.write(text) + sys.stdout.flush() diff --git a/lavis/configs/datasets/aokvqa/defaults.yaml b/lavis/configs/datasets/aokvqa/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2562db220cd9b08c0dd02c6b76dae070242e3c20 --- /dev/null +++ b/lavis/configs/datasets/aokvqa/defaults.yaml @@ -0,0 +1,35 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + aok_vqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json + storage: + - aokvqa/annotations/aokvqa_v1p0_train.json + val: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/specialized_vocab_train.json + storage: + - aokvqa/annotations/aokvqa_v1p0_val.json + - aokvqa/annotations/specialized_vocab_train_lavis.json + # - aokvqa/annotations/large_vocab_train_lavis.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_test.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/specialized_vocab_train.json + storage: + - aokvqa/annotations/aokvqa_v1p0_test.json + - aokvqa/annotations/specialized_vocab_train_lavis.json + images: + storage: coco/images/ diff --git a/lavis/configs/datasets/avsd/defaults_dial.yaml b/lavis/configs/datasets/avsd/defaults_dial.yaml new file mode 100644 index 0000000000000000000000000000000000000000..939ac9bcc1916c7ab09fe86692aaaeffc780dd22 --- /dev/null +++ b/lavis/configs/datasets/avsd/defaults_dial.yaml @@ -0,0 +1,24 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + avsd_dialogue: # name of the dataset builder + dataset_card: dataset_card/avsd_dialogue.md + data_type: features #extracted features of videos (I3D, VGGish) # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/avsd_dstc7_train.json + storage: avsd/annotations/train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/avsd_dstc7_val.json + storage: avsd/annotations/val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/avsd_dstc7_test.json + storage: avsd/annotations/test.json + features: + storage: avsd/features/ diff --git a/lavis/configs/datasets/coco/defaults_cap.yaml b/lavis/configs/datasets/coco/defaults_cap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9f9ffc8d293fed0bac7d745ee8c00f53ce39565d --- /dev/null +++ b/lavis/configs/datasets/coco/defaults_cap.yaml @@ -0,0 +1,28 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + coco_caption: # name of the dataset builder + dataset_card: dataset_card/coco_caption.md + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json + md5: aa31ac474cf6250ebb81d18348a07ed8 + storage: coco/annotations/coco_karpathy_train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json + md5: b273847456ef5580e33713b1f7de52a0 + storage: coco/annotations/coco_karpathy_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json + md5: 3ff34b0ef2db02d01c37399f6a2a6cd1 + storage: coco/annotations/coco_karpathy_test.json + images: + storage: coco/images/ diff --git a/lavis/configs/datasets/coco/defaults_ret.yaml b/lavis/configs/datasets/coco/defaults_ret.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4bcc8a07b23bd77e0457ff5055b5037df7c9112f --- /dev/null +++ b/lavis/configs/datasets/coco/defaults_ret.yaml @@ -0,0 +1,27 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + coco_retrieval: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json + md5: aa31ac474cf6250ebb81d18348a07ed8 + storage: coco/annotations/coco_karpathy_train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json + md5: b273847456ef5580e33713b1f7de52a0 + storage: coco/annotations/coco_karpathy_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json + md5: 3ff34b0ef2db02d01c37399f6a2a6cd1 + storage: coco/annotations/coco_karpathy_test.json + images: + storage: coco/images/ diff --git a/lavis/configs/datasets/coco/defaults_vqa.yaml b/lavis/configs/datasets/coco/defaults_vqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..08e036d2fd55408afd6c9a799ce8b8f7c97abd90 --- /dev/null +++ b/lavis/configs/datasets/coco/defaults_vqa.yaml @@ -0,0 +1,41 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + coco_vqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json + storage: + - coco/annotations/vqa_train.json + - coco/annotations/vqa_val.json + val: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val_eval.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_OpenEnded_mscoco_val2014_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json + storage: + - coco/annotations/vqa_val_eval.json + - coco/annotations/answer_list.json + - coco/annotations/v2_OpenEnded_mscoco_val2014_questions.json + - coco/annotations/v2_mscoco_val2014_annotations.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_test.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json + storage: + - coco/annotations/vqa_test.json + - coco/annotations/answer_list.json + images: + storage: coco/images/ diff --git a/lavis/configs/datasets/coco/eval_vqa.yaml b/lavis/configs/datasets/coco/eval_vqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bfc637955aa6d7972150d671368fa6aa7d235cfd --- /dev/null +++ b/lavis/configs/datasets/coco/eval_vqa.yaml @@ -0,0 +1,27 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + coco_vqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + val: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val_eval.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_OpenEnded_mscoco_val2014_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json + storage: + - coco/annotations/vqa_val_eval.json + - coco/annotations/answer_list.json + - coco/annotations/v2_OpenEnded_mscoco_val2014_questions.json + - coco/annotations/v2_mscoco_val2014_annotations.json + images: + storage: coco/images/ diff --git a/lavis/configs/datasets/conceptual_caption/defaults_12m.yaml b/lavis/configs/datasets/conceptual_caption/defaults_12m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7f62cd3a2e5f69cc821ddb683c5eb642700d2274 --- /dev/null +++ b/lavis/configs/datasets/conceptual_caption/defaults_12m.yaml @@ -0,0 +1,20 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + conceptual_caption_12m: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + - /export/home/workspace/datasets/cc12m.json + storage: + - conceptual_caption/annotations/cc12m.json + images: + storage: conceptual_caption/images_12m diff --git a/lavis/configs/datasets/conceptual_caption/defaults_3m.yaml b/lavis/configs/datasets/conceptual_caption/defaults_3m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fcba29b0ec781b3424ef06ffe59b474c82cd14f3 --- /dev/null +++ b/lavis/configs/datasets/conceptual_caption/defaults_3m.yaml @@ -0,0 +1,20 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + conceptual_caption_3m: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + - /export/home/workspace/datasets/cc3m.json + storage: + - conceptual_caption/annotations/cc3m.json + images: + storage: conceptual_caption/images diff --git a/lavis/configs/datasets/didemo/defaults_ret.yaml b/lavis/configs/datasets/didemo/defaults_ret.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7923d03ce84fea806b7605ff425e0b362506fe62 --- /dev/null +++ b/lavis/configs/datasets/didemo/defaults_ret.yaml @@ -0,0 +1,25 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + didemo_retrieval: # name of the dataset builder + # data_dir: ${env.data_dir}/datasets + data_type: videos # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/retrieval_train.json + storage: didemo/annotations/retrieval_train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/retrieval_val.json + storage: didemo/annotations/retrieval_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/retrieval_test.json + storage: didemo/annotations/retrieval_test.json + videos: + storage: didemo/videos + # storage: /export/share/dongxuli/data/didemo_retrieval/videos diff --git a/lavis/configs/datasets/flickr30k/defaults.yaml b/lavis/configs/datasets/flickr30k/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9127cab813b8caa3286a7bf79533f33babbebde3 --- /dev/null +++ b/lavis/configs/datasets/flickr30k/defaults.yaml @@ -0,0 +1,24 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + flickr30k: + # data_dir: ${env.data_dir}/datasets + data_type: images + + build_info: + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json + storage: flickr30k/annotations/train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json + storage: flickr30k/annotations/val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json + storage: flickr30k/annotations/test.json + images: + storage: flickr30k/images + # storage: /export/share/datasets/vision/flickr30k diff --git a/lavis/configs/datasets/gqa/balanced_testdev.yaml b/lavis/configs/datasets/gqa/balanced_testdev.yaml new file mode 100644 index 0000000000000000000000000000000000000000..86114fb964cda8e58f9848c216c5f4ae5f28ca70 --- /dev/null +++ b/lavis/configs/datasets/gqa/balanced_testdev.yaml @@ -0,0 +1,30 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + gqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json + storage: + - gqa/annotations/train_balanced_questions.json + val: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json + storage: + - gqa/annotations/testdev_balanced_questions.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json + storage: + - gqa/annotations/test_balanced_questions.json + images: + storage: gqa/images/ diff --git a/lavis/configs/datasets/gqa/balanced_val.yaml b/lavis/configs/datasets/gqa/balanced_val.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ca420dfdcef381bebb261b4178b9a288bc331d5f --- /dev/null +++ b/lavis/configs/datasets/gqa/balanced_val.yaml @@ -0,0 +1,30 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + gqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json + storage: + - gqa/annotations/train_balanced_questions.json + val: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/val_balanced_questions.json + storage: + - gqa/annotations/val_balanced_questions.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json + storage: + - gqa/annotations/test_balanced_questions.json + images: + storage: gqa/images/ diff --git a/lavis/configs/datasets/gqa/defaults.yaml b/lavis/configs/datasets/gqa/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7c2d87cd7afcedd52f65c5103d5d9d697f5de7f7 --- /dev/null +++ b/lavis/configs/datasets/gqa/defaults.yaml @@ -0,0 +1,36 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + gqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + - /export/share/datasets/vision/GQA/questions1.2/train_all_questions/train_all_questions_0.json + - /export/share/datasets/vision/GQA/questions1.2/val_all_questions.json + storage: + - gqa/annotations/train_all_questions_0.json + - gqa/annotations/val_all_questions.json + val: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/large_vocab_train_lavis.json + storage: + - aokvqa/annotations/aokvqa_v1p0_val.json + - aokvqa/annotations/large_vocab_train_lavis.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_test.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/large_vocab_train_lavis.json + storage: + - aokvqa/annotations/aokvqa_v1p0_test.json + - aokvqa/annotations/large_vocab_train_lavis.json + images: + storage: gqa/images/ diff --git a/lavis/configs/datasets/imagenet/defaults.yaml b/lavis/configs/datasets/imagenet/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a20779b43df16508f096d2159db3357e8b3ee4d --- /dev/null +++ b/lavis/configs/datasets/imagenet/defaults.yaml @@ -0,0 +1,15 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + imagenet: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + splits: ["val"] + images: + storage: /export/share/datasets/vision/imagenet diff --git a/lavis/configs/datasets/laion/defaults_2B_multi.yaml b/lavis/configs/datasets/laion/defaults_2B_multi.yaml new file mode 100644 index 0000000000000000000000000000000000000000..279a736fa1b8a5fd79821a7732c4d3bd7c4d5214 --- /dev/null +++ b/lavis/configs/datasets/laion/defaults_2B_multi.yaml @@ -0,0 +1,13 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + laion2B_multi: + + data_type: images + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + storage: /export/laion/laion2B-multi/part-00000/{00000..01743}.tar diff --git a/lavis/configs/datasets/msrvtt/defaults_cap.yaml b/lavis/configs/datasets/msrvtt/defaults_cap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a3385b46931b640ec18113a20232010f3a11e233 --- /dev/null +++ b/lavis/configs/datasets/msrvtt/defaults_cap.yaml @@ -0,0 +1,24 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + msrvtt_cap: # name of the dataset builder + # data_dir: ${env.data_dir}/datasets + data_type: videos # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/cap_train.json + storage: msrvtt/annotations/cap_train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/cap_val.json + storage: msrvtt/annotations/cap_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/cap_test.json + storage: msrvtt/annotations/cap_test.json + videos: + storage: msrvtt/videos diff --git a/lavis/configs/datasets/msrvtt/defaults_qa.yaml b/lavis/configs/datasets/msrvtt/defaults_qa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..df1c4ad0a79117604ea903d567948e15ec941382 --- /dev/null +++ b/lavis/configs/datasets/msrvtt/defaults_qa.yaml @@ -0,0 +1,27 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + msrvtt_qa: # name of the dataset builder + # data_dir: ${env.data_dir}/datasets + data_type: videos # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/qa_train.json + storage: msrvtt/annotations/qa_train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/qa_val.json + storage: msrvtt/annotations/qa_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/qa_test.json + storage: msrvtt/annotations/qa_test.json + ans2label: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/train_ans2label.json + storage: msrvtt/annotations/qa_ans2label.json + videos: + storage: msrvtt/videos diff --git a/lavis/configs/datasets/msrvtt/defaults_ret.yaml b/lavis/configs/datasets/msrvtt/defaults_ret.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f0cc55b39c9ba69f3aa61f17e4c4519923db7df0 --- /dev/null +++ b/lavis/configs/datasets/msrvtt/defaults_ret.yaml @@ -0,0 +1,24 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + msrvtt_retrieval: # name of the dataset builder + # data_dir: ${env.data_dir}/datasets + data_type: videos # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/retrieval_train.json + storage: msrvtt/annotations/retrieval_train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/retrieval_val.json + storage: msrvtt/annotations/retrieval_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msrvtt/retrieval_test.json + storage: msrvtt/annotations/retrieval_test.json + videos: + storage: msrvtt/videos diff --git a/lavis/configs/datasets/msvd/defaults_cap.yaml b/lavis/configs/datasets/msvd/defaults_cap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d9e870bd4b6b045034aefe47dbb5f8cff9bdc45b --- /dev/null +++ b/lavis/configs/datasets/msvd/defaults_cap.yaml @@ -0,0 +1,24 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + msvd_cap: # name of the dataset builder + # data_dir: ${env.data_dir}/datasets + data_type: videos # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/cap_train.json + storage: msvd/annotations/cap_train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/cap_val.json + storage: msvd/annotations/cap_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/cap_test.json + storage: msvd/annotations/cap_test.json + videos: + storage: msvd/videos diff --git a/lavis/configs/datasets/msvd/defaults_qa.yaml b/lavis/configs/datasets/msvd/defaults_qa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9b4bbbd3fb44bf0bb40d14c68bc07b825a541577 --- /dev/null +++ b/lavis/configs/datasets/msvd/defaults_qa.yaml @@ -0,0 +1,29 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + msvd_qa: # name of the dataset builder + # data_dir: ${env.data_dir}/datasets + data_type: videos # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/qa_train.json + storage: msvd/annotations/qa_train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/qa_val.json + storage: msvd/annotations/qa_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/qa_test.json + storage: msvd/annotations/qa_test.json + ans2label: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/msvd/train_ans2label.json + storage: msvd/annotations/qa_ans2label.json + videos: + storage: msvd/videos + + instance_id_key: question_id diff --git a/lavis/configs/datasets/nlvr/defaults.yaml b/lavis/configs/datasets/nlvr/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96a985598259861e9d23b47d2dffcf7d06b22e69 --- /dev/null +++ b/lavis/configs/datasets/nlvr/defaults.yaml @@ -0,0 +1,24 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + nlvr: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/nlvr/nlvr_train.json + storage: nlvr/annotations/train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/nlvr/nlvr_dev.json + storage: nlvr/annotations/dev.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/nlvr/nlvr_dev.json + storage: nlvr/annotations/test.json + images: + storage: /export/share/datasets/vision/NLVR2/ diff --git a/lavis/configs/datasets/nocaps/defaults.yaml b/lavis/configs/datasets/nocaps/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..062b03c00e1bf24461e201a3dcfb9250a456b9d7 --- /dev/null +++ b/lavis/configs/datasets/nocaps/defaults.yaml @@ -0,0 +1,22 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + nocaps: # name of the dataset builder + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + val: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json + storage: nocaps/annotations/nocaps_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json + storage: nocaps/annotations/nocaps_test.json + images: + storage: nocaps/images + # storage: /export/share/datasets/vision/nocaps/ diff --git a/lavis/configs/datasets/okvqa/defaults.yaml b/lavis/configs/datasets/okvqa/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a442c9bb2f12951ae812b924b5b12a038bcce75d --- /dev/null +++ b/lavis/configs/datasets/okvqa/defaults.yaml @@ -0,0 +1,37 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + ok_vqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json + # - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_train2014_questions.json + # - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_train2014_annotations.json + storage: + - okvqa/annotations/okvqa_train.json + # - okvqa/annotations/OpenEnded_mscoco_train2014_questions.json + # - okvqa/annotations/mscoco_train2014_annotations.json + test: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_val_eval.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_answer_list_train.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_val2014_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json + storage: + - okvqa/annotations/vqa_val_eval.json + - okvqa/annotations/answer_list.json + - okvqa/annotations/OpenEnded_mscoco_val2014_questions.json + - okvqa/annotations/mscoco_val2014_annotations.json + images: + storage: coco/images/ diff --git a/lavis/configs/datasets/sbu_caption/defaults.yaml b/lavis/configs/datasets/sbu_caption/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6a5a22053831056a241556ed1a37321595f00794 --- /dev/null +++ b/lavis/configs/datasets/sbu_caption/defaults.yaml @@ -0,0 +1,22 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + sbu_caption: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/sbu/sbu.json + # - /export/share/dongxuli/data/lavis/sbu/annotation/sbu.json + storage: + - sbu_captions/annotations/sbu.json + images: + storage: sbu_captions/images + # storage: /export/share/datasets/vision_language/sbu_resize diff --git a/lavis/configs/datasets/snli_ve/defaults.yaml b/lavis/configs/datasets/snli_ve/defaults.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91b6cf7fd9b79b1d6a26ae25eed38cda61b83d01 --- /dev/null +++ b/lavis/configs/datasets/snli_ve/defaults.yaml @@ -0,0 +1,25 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + snli_ve: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: /export/share/dongxuli/data/lavis/snli/annotation/ve_train.json + storage: snli/annotations/ve_train.json + val: + url: /export/share/dongxuli/data/lavis/snli/annotation/ve_dev.json + storage: snli/annotations/ve_dev.json + test: + url: /export/share/dongxuli/data/lavis/snli/annotation/ve_test.json + storage: snli/annotations/ve_test.json + images: + storage: flickr30k/images/flickr30k-images + # storage: /export/share/datasets/vision/flickr30k/flickr30k-images diff --git a/lavis/configs/datasets/vatex/defaults_cap.yaml b/lavis/configs/datasets/vatex/defaults_cap.yaml new file mode 100644 index 0000000000000000000000000000000000000000..888f66d60b1c94dda9a314fc7193f1763c471b1d --- /dev/null +++ b/lavis/configs/datasets/vatex/defaults_cap.yaml @@ -0,0 +1,24 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + msvd_cap: # name of the dataset builder + # data_dir: ${env.data_dir}/datasets + data_type: videos # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vatex/cap_train.json + storage: vatex/annotations/cap_train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vatex/cap_val.json + storage: vatex/annotations/cap_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vatex/cap_private_test.json + storage: vatex/annotations/cap_test.json + videos: + storage: /export/share/dongxuli/data/vatex diff --git a/lavis/configs/datasets/vg/defaults_caption.yaml b/lavis/configs/datasets/vg/defaults_caption.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ed303b58d8976ab5a4b1da7c234405a14d559fff --- /dev/null +++ b/lavis/configs/datasets/vg/defaults_caption.yaml @@ -0,0 +1,18 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + vg_caption: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/visual_genome/vg_caption.json + storage: vg/annotations/vg_caption.json + images: + storage: vg/images/ diff --git a/lavis/configs/datasets/vg/defaults_vqa.yaml b/lavis/configs/datasets/vg/defaults_vqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e12e5c860a0db616a80967f7515b47abedba519e --- /dev/null +++ b/lavis/configs/datasets/vg/defaults_vqa.yaml @@ -0,0 +1,18 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + vg_vqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/visual_genome/vg_qa.json + storage: vg/annotations/vg_qa.json + images: + storage: vg/images/ diff --git a/lavis/configs/default.yaml b/lavis/configs/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f58d32e264250895ab02b3d2e78a2ba6dfd3c125 --- /dev/null +++ b/lavis/configs/default.yaml @@ -0,0 +1,10 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +env: + # For default users + # cache_root: "cache" + # For internal use with persistent storage + cache_root: "/export/home/.cache/lavis" diff --git a/lavis/configs/models/albef_classification_ve.yaml b/lavis/configs/models/albef_classification_ve.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a2accab99fad7e2a880944515baefab496b18a7 --- /dev/null +++ b/lavis/configs/models/albef_classification_ve.yaml @@ -0,0 +1,40 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: albef_classification + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_snli_ve_lavis.pt" + pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth" + + num_classes: 3 + + use_distill: True + momentum: 0.995 + alpha: 0.4 + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + vit_layer_norm_epsilon: 1e-6 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_config_albef.json" + +preprocess: + vis_processor: + train: + name: "blip_image_train" + eval: + name: "blip_image_eval" + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/albef_feature_extractor.yaml b/lavis/configs/models/albef_feature_extractor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7def58e04a7b567e0a836e54f3dffdc62e1748ee --- /dev/null +++ b/lavis/configs/models/albef_feature_extractor.yaml @@ -0,0 +1,30 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: albef_pretrain + pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth" + + # vit encoder + vit_type: "base" + image_size: 224 + vit_ckpt_layer: 0 + vit_drop_path_rate: 0 + vit_layer_norm_epsilon: 1e-6 + vit_grad_ckpt: False + + # bert config + med_config_path: "configs/models/med_config_albef.json" + + embed_dim: 256 + +preprocess: + vis_processor: + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" diff --git a/lavis/configs/models/albef_nlvr.yaml b/lavis/configs/models/albef_nlvr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..86f17224aa0dfaa4739725e7c0516df4c679aa2d --- /dev/null +++ b/lavis/configs/models/albef_nlvr.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: albef_nlvr + load_finetuned: True + + pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/pretrain_model_nlvr.pth" + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_nlvr_lavis.pt" + + num_classes: 2 + + use_distill: True + momentum: 0.995 + alpha: 0.4 + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + vit_layer_norm_epsilon: 1e-6 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_config_albef.json" + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 384 + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/albef_pretrain_base.yaml b/lavis/configs/models/albef_pretrain_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26e00efa423345b4a78332635d1a7c2e368fb02e --- /dev/null +++ b/lavis/configs/models/albef_pretrain_base.yaml @@ -0,0 +1,38 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: albef_pretrain + + load_pretrained: True + pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth" + + # vit encoder + vit_type: "base" + image_size: 224 + vit_ckpt_layer: 0 + vit_drop_path_rate: 0 + vit_layer_norm_epsilon: 1e-6 + vit_grad_ckpt: False + + # bert config + med_config_path: "configs/models/med_config_albef.json" + mlm_mask_prob: 0.15 + + embed_dim: 256 + momentum: 0.995 + alpha: 0.4 + temp: 0.07 + + max_txt_len: 30 + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 256 + text_processor: + train: + name: "blip_caption" diff --git a/lavis/configs/models/albef_retrieval_coco.yaml b/lavis/configs/models/albef_retrieval_coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9971e6ca5d9aa85790ee2aefd9b7251e8a8b200c --- /dev/null +++ b/lavis/configs/models/albef_retrieval_coco.yaml @@ -0,0 +1,46 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: albef_retrieval + load_finetuned: True + + pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth" + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_coco_retrieval_lavis.pt" + + queue_size: 65536 + + # vit encoder + vit_type: "base" + image_size: 384 + vit_ckpt_layer: 0 + vit_drop_path_rate: 0 + vit_layer_norm_epsilon: 1e-6 + vit_grad_ckpt: False + + # bert config + med_config_path: "configs/models/med_config_albef.json" + + embed_dim: 256 + momentum: 0.995 + alpha: 0.4 + temp: 0.07 + use_distill: True + + max_txt_len: 30 + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 384 + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/albef_retrieval_flickr.yaml b/lavis/configs/models/albef_retrieval_flickr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5f77f0f99912d0f2c501e567dd0360e5c2b9336 --- /dev/null +++ b/lavis/configs/models/albef_retrieval_flickr.yaml @@ -0,0 +1,46 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: albef_retrieval + load_finetuned: True + + pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth" + finetuned: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_flickr_retrieval_lavis.pt + + queue_size: 65536 + + # vit encoder + vit_type: "base" + image_size: 384 + vit_ckpt_layer: 0 + vit_drop_path_rate: 0 + vit_layer_norm_epsilon: 1e-6 + vit_grad_ckpt: False + + # bert config + med_config_path: "configs/models/med_config_albef.json" + + embed_dim: 256 + momentum: 0.995 + alpha: 0.4 + temp: 0.07 + use_distill: True + + max_txt_len: 30 + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 384 + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/albef_vqav2.yaml b/lavis/configs/models/albef_vqav2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e35559f356bd77f9eedaa76b43d393a142f40239 --- /dev/null +++ b/lavis/configs/models/albef_vqav2.yaml @@ -0,0 +1,40 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: albef_vqa + load_finetuned: True + + pretrained: "https://storage.googleapis.com/sfr-pcl-data-research/ALBEF/ALBEF.pth" + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALBEF/albef_vqav2_lavis.pt" + + use_distill: True + momentum: 0.995 + alpha: 0.4 + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + vit_layer_norm_epsilon: 1e-6 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_config_albef.json" + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 384 + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + train: + name: "blip_question" + eval: + name: "blip_question" diff --git a/lavis/configs/models/alpro_qa_msrvtt.yaml b/lavis/configs/models/alpro_qa_msrvtt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3f58a1308c0d2a2075c037f6defcd4500e29b1b --- /dev/null +++ b/lavis/configs/models/alpro_qa_msrvtt.yaml @@ -0,0 +1,44 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: alpro_qa + num_classes: 1500 + + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_msrvtt_qa.pth" + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_pretrain.pt" + + timesformer: + n_frms: 16 + image_size: 224 + + patch_size: 16 + attn_drop_rate: 0. + drop_rate: 0. + drop_path_rate: 0.1 + + use_grad_ckpt: True + ckpt_layer: 12 + + # bert config + med_config_path: "configs/models/bert_config_alpro.json" + +preprocess: + vis_processor: + train: + name: "alpro_video_train" + n_frms: 16 + image_size: 224 + eval: + name: "alpro_video_eval" + n_frms: 16 + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/alpro_qa_msvd.yaml b/lavis/configs/models/alpro_qa_msvd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17d606fcc0fd8fb8adedbb992db49f6e56e67c5f --- /dev/null +++ b/lavis/configs/models/alpro_qa_msvd.yaml @@ -0,0 +1,43 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: alpro_qa + num_classes: 2423 + + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_msvd_qa.pth" + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_pretrain.pt" + + timesformer: + n_frms: 16 + image_size: 224 + + patch_size: 16 + attn_drop_rate: 0. + drop_rate: 0. + drop_path_rate: 0.1 + use_grad_ckpt: True + ckpt_layer: 12 + + # bert config + med_config_path: "configs/models/bert_config_alpro.json" + +preprocess: + vis_processor: + train: + name: "alpro_video_train" + n_frms: 16 + image_size: 224 + eval: + name: "alpro_video_eval" + n_frms: 16 + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/alpro_retrieval_didemo.yaml b/lavis/configs/models/alpro_retrieval_didemo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bd021c5a5d2e93e53e74ef4cf2a94bb921a6cd83 --- /dev/null +++ b/lavis/configs/models/alpro_retrieval_didemo.yaml @@ -0,0 +1,35 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: alpro_retrieval + + load_finetuned: True + + finetuned: https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_didemo_retrieval.pt + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_pretrain.pt" + + timesformer: + n_frms: 8 + image_size: 224 + + patch_size: 16 + attn_drop_rate: 0. + drop_rate: 0. + drop_path_rate: 0.1 + use_grad_ckpt: False + + # bert config + med_config_path: "configs/models/bert_config_alpro.json" + +preprocess: + vis_processor: + eval: + name: "alpro_video_eval" + n_frms: 8 + image_size: 224 + text_processor: + eval: + name: "blip_caption" diff --git a/lavis/configs/models/alpro_retrieval_msrvtt.yaml b/lavis/configs/models/alpro_retrieval_msrvtt.yaml new file mode 100644 index 0000000000000000000000000000000000000000..431aa3ea65f83a6213c88ae07465e0c1ff7cb3ea --- /dev/null +++ b/lavis/configs/models/alpro_retrieval_msrvtt.yaml @@ -0,0 +1,41 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: alpro_retrieval + + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_msrvtt_retrieval.pt" + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/ALPRO/alpro_pretrain.pt" + + timesformer: + n_frms: 8 + image_size: 224 + + patch_size: 16 + attn_drop_rate: 0. + drop_rate: 0. + drop_path_rate: 0.1 + use_grad_ckpt: False + + # bert config + med_config_path: "configs/models/bert_config_alpro.json" + +preprocess: + vis_processor: + train: + name: "alpro_video_train" + n_frms: 8 + image_size: 224 + eval: + name: "alpro_video_eval" + n_frms: 8 + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/bert_config.json b/lavis/configs/models/bert_config.json new file mode 100644 index 0000000000000000000000000000000000000000..477a9f42513d0afb774735f07177161bdd1ae94b --- /dev/null +++ b/lavis/configs/models/bert_config.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "add_type_embeddings": false, + "vocab_size": 30522, + "encoder_width": 768, + "add_cross_attention": true +} \ No newline at end of file diff --git a/lavis/configs/models/bert_config_alpro.json b/lavis/configs/models/bert_config_alpro.json new file mode 100644 index 0000000000000000000000000000000000000000..a21b3a2c9344651c1d88797338de5830ca3fc043 --- /dev/null +++ b/lavis/configs/models/bert_config_alpro.json @@ -0,0 +1,23 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "add_type_embeddings": true, + "type_vocab_size": 2, + "vocab_size": 30522, + "encoder_width": 768, + "add_cross_attention": false, + "fusion_layer": 6 +} \ No newline at end of file diff --git a/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml b/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6591e15c1a5c9c6052a95caba26c2b635a842785 --- /dev/null +++ b/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: caption_coco_flant5xl + load_finetuned: True + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xl.pth" + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_caption_flant5xl.pth" + + # vit encoder + image_size: 364 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp32" + freeze_vit: False + + # Q-Former + num_query_token: 32 + + # T5 + t5_model: "google/flan-t5-xl" + + # generation configs + prompt: "a photo of" + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 364 + eval: + name: "blip_image_eval" + image_size: 364 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml b/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5005fb72ada67d0e304483e5b98428f4be7c0236 --- /dev/null +++ b/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: caption_coco_opt2.7b + load_finetuned: True + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt2.7b.pth" + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_caption_opt2.7b.pth" + + # vit encoder + image_size: 364 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp32" + freeze_vit: False + + # Q-Former + num_query_token: 32 + + # OPT + opt_model: "facebook/opt-2.7b" + + # generation configs + prompt: "a photo of" + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 364 + eval: + name: "blip_image_eval" + image_size: 364 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml b/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..464da1bb28668f6aa9106b3aac44cb500f85d727 --- /dev/null +++ b/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: caption_coco_opt6.7b + load_finetuned: True + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt6.7b.pth" + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_caption_opt6.7b.pth" + + # vit encoder + image_size: 364 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp32" + freeze_vit: False + + # Q-Former + num_query_token: 32 + + # OPT + opt_model: "facebook/opt-6.7b" + + # generation configs + prompt: "a photo of" + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 364 + eval: + name: "blip_image_eval" + image_size: 364 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip2/blip2_coco.yaml b/lavis/configs/models/blip2/blip2_coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..03abc369b866db180c4e7bff8b00de637bc55cf0 --- /dev/null +++ b/lavis/configs/models/blip2/blip2_coco.yaml @@ -0,0 +1,36 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: coco + load_finetuned: True + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth" + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_finetune_coco.pth" + + # vit encoder + image_size: 364 + drop_path_rate: 0 + use_grad_checkpoint: True + vit_precision: "fp32" + freeze_vit: False + + # Q-Former + num_query_token: 32 + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 364 + eval: + name: "blip_image_eval" + image_size: 364 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip2/blip2_pretrain.yaml b/lavis/configs/models/blip2/blip2_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..126025ebaeb20ec88ebc2af61d16acd37843125d --- /dev/null +++ b/lavis/configs/models/blip2/blip2_pretrain.yaml @@ -0,0 +1,36 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pretrain + load_finetuned: False + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth" + finetuned: "" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # Q-Former + num_query_token: 32 + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml b/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cf90da225618de43a3b5fa70954b363227fcd804 --- /dev/null +++ b/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pretrain_flant5xl + load_finetuned: False + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xl.pth" + finetuned: "" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # Q-Former + num_query_token: 32 + + # T5 + t5_model: "google/flan-t5-xl" + + # generation configs + prompt: "" + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml b/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fca3e9a0aa053245d08d376594f75336ba0150b7 --- /dev/null +++ b/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml @@ -0,0 +1,43 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pretrain_flant5xl + load_finetuned: False + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xl_vitL.pth" + finetuned: "" + + # vit encoder + vit_model: "clip_L" + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # Q-Former + num_query_token: 32 + + # T5 + t5_model: "google/flan-t5-xl" + + # generation configs + prompt: "" + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml b/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8240904d01dde5b1dfd74baca6bb83421d92ac3e --- /dev/null +++ b/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pretrain_flant5xxl + load_finetuned: False + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth" + finetuned: "" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # Q-Former + num_query_token: 32 + + # T5 + t5_model: "google/flan-t5-xxl" + + # generation configs + prompt: "" + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml b/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a6e0bccd3fa69814bbcc294bb0a28089f3a62e5a --- /dev/null +++ b/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pretrain_opt2.7b + load_finetuned: False + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt2.7b.pth" + finetuned: "" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # Q-Former + num_query_token: 32 + + # OPT + opt_model: "facebook/opt-2.7b" + + # generation configs + prompt: "" + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml b/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89adbfe363272a90c5bc80fbdb8ca33f05e0033c --- /dev/null +++ b/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pretrain_opt6.7b + load_finetuned: False + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_opt6.7b.pth" + finetuned: "" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # Q-Former + num_query_token: 32 + + # OPT + opt_model: "facebook/opt-6.7b" + + # generation configs + prompt: "" + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml b/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a0a0fc6464abcfea3e08655e43e381c9456f62b5 --- /dev/null +++ b/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml @@ -0,0 +1,37 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pretrain + load_finetuned: False + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_vitL.pth" + finetuned: "" + + # vit encoder + vit_model: "clip_L" + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # Q-Former + num_query_token: 32 + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip_caption_base_coco.yaml b/lavis/configs/models/blip_caption_base_coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2ee481c234290fef7d74667c2ce3e8c66fc7a3ab --- /dev/null +++ b/lavis/configs/models/blip_caption_base_coco.yaml @@ -0,0 +1,38 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_caption + load_finetuned: True + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth" + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_coco_caption_base.pth" + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_config.json" + + # generation configs + prompt: "a picture of " + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + eval: + name: "blip_image_eval" + text_processor: + train: + name: "blip_caption" + prompt: "a picture of " + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip_caption_large_coco.yaml b/lavis/configs/models/blip_caption_large_coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a0e8ae93c3f5236aac93669c53db448d312aa5eb --- /dev/null +++ b/lavis/configs/models/blip_caption_large_coco.yaml @@ -0,0 +1,37 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_caption + load_finetuned: True + + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth" + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth" + + vit_type: "large" + vit_grad_ckpt: True + vit_ckpt_layer: 5 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + # generation configs + prompt: "a picture of " + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + eval: + name: "blip_image_eval" + text_processor: + train: + name: "blip_caption" + prompt: "a picture of " + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip_classification_base.yaml b/lavis/configs/models/blip_classification_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bad38f200daeb3177dce269807ffada275e61ac3 --- /dev/null +++ b/lavis/configs/models/blip_classification_base.yaml @@ -0,0 +1,22 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_classification + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth" + + use_distill: True + momentum: 0.995 + alpha: 0.4 + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_config.json" diff --git a/lavis/configs/models/blip_feature_extractor_base.yaml b/lavis/configs/models/blip_feature_extractor_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eaee381415c9eb7e0bf787ad5cf9b61bf2690489 --- /dev/null +++ b/lavis/configs/models/blip_feature_extractor_base.yaml @@ -0,0 +1,29 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_pretrain + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth" + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + + image_size: 224 + + # bert config + med_config_path: "configs/models/med_config.json" + + embed_dim: 256 + +preprocess: + vis_processor: + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip_itm_base.yaml b/lavis/configs/models/blip_itm_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9c79db89d3cb55575b5f4b8aa499859c5915b183 --- /dev/null +++ b/lavis/configs/models/blip_itm_base.yaml @@ -0,0 +1,31 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_image_text_matching + + load_finetuned: True + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth" + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_config.json" + + embed_dim: 256 + +preprocess: + vis_processor: + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip_itm_large.yaml b/lavis/configs/models/blip_itm_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9bcbf4850d2eb159c506e52a8fa88de59d3a87d7 --- /dev/null +++ b/lavis/configs/models/blip_itm_large.yaml @@ -0,0 +1,31 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_image_text_matching + + load_finetuned: True + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth" + + # vit encoder + vit_type: "large" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + embed_dim: 256 + +preprocess: + vis_processor: + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip_nlvr.yaml b/lavis/configs/models/blip_nlvr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..02ecb13f11bdd02b161633d0d8c3c74eab64ba21 --- /dev/null +++ b/lavis/configs/models/blip_nlvr.yaml @@ -0,0 +1,39 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_nlvr + model_type: nlvr + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth" + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth" + + num_classes: 2 + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + vit_layer_norm_epsilon: 1e-6 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_config.json" + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 384 + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip_pretrain_base.yaml b/lavis/configs/models/blip_pretrain_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e265b832a618304d50e17a9dbf242bfe4df720db --- /dev/null +++ b/lavis/configs/models/blip_pretrain_base.yaml @@ -0,0 +1,35 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_pretrain + + load_pretrained: True + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth" + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + + image_size: 224 + alpha: 0.4 + + # bert config + med_config_path: "configs/models/bert_config.json" + + embed_dim: 256 + + # generation configs + prompt: "a picture of " + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" diff --git a/lavis/configs/models/blip_pretrain_large.yaml b/lavis/configs/models/blip_pretrain_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d01cbe3baf09dd118d3e127c1ce1d8e3ea2238a6 --- /dev/null +++ b/lavis/configs/models/blip_pretrain_large.yaml @@ -0,0 +1,22 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_pretrain + + # vit encoder + vit_type: "large" + vit_grad_ckpt: True + vit_ckpt_layer: 5 + + image_size: 224 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + embed_dim: 256 + + # generation configs + prompt: "a picture of " diff --git a/lavis/configs/models/blip_retrieval_coco.yaml b/lavis/configs/models/blip_retrieval_coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..30eb79028f12266224e5286e563381ba963bd756 --- /dev/null +++ b/lavis/configs/models/blip_retrieval_coco.yaml @@ -0,0 +1,39 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_retrieval + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_coco_retrieval.pth" + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth" + + queue_size: 57600 + + # vit encoder + vit_type: "base" + vit_grad_ckpt: True + vit_ckpt_layer: 4 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_config.json" + + embed_dim: 256 + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 384 + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip_retrieval_flickr.yaml b/lavis/configs/models/blip_retrieval_flickr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e4bf1fbc2db796a3ce0f08dfa357fe982856d8a0 --- /dev/null +++ b/lavis/configs/models/blip_retrieval_flickr.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_retrieval + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_flickr_retrieval.pth" + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth" + + queue_size: 57600 + alpha: 0.4 + + negative_all_rank: False + + # vit encoder + vit_type: "base" + vit_grad_ckpt: True + vit_ckpt_layer: 4 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_config.json" + + embed_dim: 256 + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 384 + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/lavis/configs/models/blip_vqa_aokvqa.yaml b/lavis/configs/models/blip_vqa_aokvqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b3afe3e7a2e3a55c569a8c7fce3d83d1ef3ddabe --- /dev/null +++ b/lavis/configs/models/blip_vqa_aokvqa.yaml @@ -0,0 +1,36 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_vqa + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_aokvqa.pth" + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth" + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + vit_drop_path_rate: 0.1 + + image_size: 480 + + # bert config + med_config_path: "configs/models/med_config.json" + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 480 + eval: + name: "blip_image_eval" + image_size: 480 + text_processor: + train: + name: "blip_question" + eval: + name: "blip_question" diff --git a/lavis/configs/models/blip_vqa_okvqa.yaml b/lavis/configs/models/blip_vqa_okvqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..eb66ccbbf1f2faed4dfe916b042263861798d951 --- /dev/null +++ b/lavis/configs/models/blip_vqa_okvqa.yaml @@ -0,0 +1,36 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_vqa + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP/blip_okvqa.pth" + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth" + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + vit_drop_path_rate: 0.1 + + image_size: 480 + + # bert config + med_config_path: "configs/models/med_config.json" + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 480 + eval: + name: "blip_image_eval" + image_size: 480 + text_processor: + train: + name: "blip_question" + eval: + name: "blip_question" diff --git a/lavis/configs/models/blip_vqav2.yaml b/lavis/configs/models/blip_vqav2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4f0ce8daac2d23d47d342f17630ca86f7002cc50 --- /dev/null +++ b/lavis/configs/models/blip_vqav2.yaml @@ -0,0 +1,36 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip_vqa + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth" + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth" + + # vit encoder + vit_type: "base" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + vit_drop_path_rate: 0.1 + + image_size: 480 + + # bert config + med_config_path: "configs/models/med_config.json" + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 480 + eval: + name: "blip_image_eval" + image_size: 480 + text_processor: + train: + name: "blip_question" + eval: + name: "blip_question" diff --git a/lavis/configs/models/clip/RN101-quickgelu.json b/lavis/configs/models/clip/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..1dbd19be9d289887b4e41bd50acdbdc78709efd3 --- /dev/null +++ b/lavis/configs/models/clip/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/RN101.json b/lavis/configs/models/clip/RN101.json new file mode 100644 index 0000000000000000000000000000000000000000..bf5babbc5a3ef48653083f10a549f42afe14727a --- /dev/null +++ b/lavis/configs/models/clip/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/RN50-quickgelu.json b/lavis/configs/models/clip/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/lavis/configs/models/clip/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/RN50.json b/lavis/configs/models/clip/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..ad98b4b8822d72b5196ddafcb732329ecad2ce56 --- /dev/null +++ b/lavis/configs/models/clip/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/RN50x16.json b/lavis/configs/models/clip/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..66576383a0cbd2ffcdd7a050e5fcbab420c7fecb --- /dev/null +++ b/lavis/configs/models/clip/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/RN50x4.json b/lavis/configs/models/clip/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..a41cb630517cc155c1ee6aa8660f6c7948f3ee4b --- /dev/null +++ b/lavis/configs/models/clip/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-B-16-plus-240.json b/lavis/configs/models/clip/ViT-B-16-plus-240.json new file mode 100644 index 0000000000000000000000000000000000000000..9347280c60a2a19233ac027d810ded21c26ea867 --- /dev/null +++ b/lavis/configs/models/clip/ViT-B-16-plus-240.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 240, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-B-16-plus.json b/lavis/configs/models/clip/ViT-B-16-plus.json new file mode 100644 index 0000000000000000000000000000000000000000..f9cc3e3b0084590581d1ec3e81b930a9a190e036 --- /dev/null +++ b/lavis/configs/models/clip/ViT-B-16-plus.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 896, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-B-16.json b/lavis/configs/models/clip/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..9afeef0fbc807f130f2b2bc65c1dd85abc9eba72 --- /dev/null +++ b/lavis/configs/models/clip/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-B-32-plus-256.json b/lavis/configs/models/clip/ViT-B-32-plus-256.json new file mode 100644 index 0000000000000000000000000000000000000000..27ae13857a0bdf0c7825ba7768de0071bda3e82e --- /dev/null +++ b/lavis/configs/models/clip/ViT-B-32-plus-256.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 256, + "layers": 12, + "width": 896, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-B-32-quickgelu.json b/lavis/configs/models/clip/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..f5a063adbf96df9e169706286643ab9a261b251c --- /dev/null +++ b/lavis/configs/models/clip/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-B-32.json b/lavis/configs/models/clip/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..abd1f7973dc856ba56004ad0538f4f74f5e08a6d --- /dev/null +++ b/lavis/configs/models/clip/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-H-14.json b/lavis/configs/models/clip/ViT-H-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d2c01733dcab1293858bf8aa200f05cdb0b6f56c --- /dev/null +++ b/lavis/configs/models/clip/ViT-H-14.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} diff --git a/lavis/configs/models/clip/ViT-H-16.json b/lavis/configs/models/clip/ViT-H-16.json new file mode 100644 index 0000000000000000000000000000000000000000..942ed56bf6e24a0c19a41fad87db304444402b4f --- /dev/null +++ b/lavis/configs/models/clip/ViT-H-16.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 32, + "width": 1280, + "head_width": 80, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} diff --git a/lavis/configs/models/clip/ViT-L-14-280.json b/lavis/configs/models/clip/ViT-L-14-280.json new file mode 100644 index 0000000000000000000000000000000000000000..c8e5fbac8a14c4c66c57df166ffe5dceb188e436 --- /dev/null +++ b/lavis/configs/models/clip/ViT-L-14-280.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 280, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-L-14-336.json b/lavis/configs/models/clip/ViT-L-14-336.json new file mode 100644 index 0000000000000000000000000000000000000000..4db3a1e77c891cda4d32ea3b9da9bef2c2aade0c --- /dev/null +++ b/lavis/configs/models/clip/ViT-L-14-336.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 336, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-L-14.json b/lavis/configs/models/clip/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..98951b0cbff3776e90b0c2685ce4d04f1f874343 --- /dev/null +++ b/lavis/configs/models/clip/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-L-16-320.json b/lavis/configs/models/clip/ViT-L-16-320.json new file mode 100644 index 0000000000000000000000000000000000000000..cc09c4877d27597fb0f50332e7cbcf8028586ce2 --- /dev/null +++ b/lavis/configs/models/clip/ViT-L-16-320.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 320, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-L-16.json b/lavis/configs/models/clip/ViT-L-16.json new file mode 100644 index 0000000000000000000000000000000000000000..78601e7a6822382e3466c1c00459392ee7768024 --- /dev/null +++ b/lavis/configs/models/clip/ViT-L-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/ViT-g-14.json b/lavis/configs/models/clip/ViT-g-14.json new file mode 100644 index 0000000000000000000000000000000000000000..b5c4231a67a82d1c30b675719f3004daed84299b --- /dev/null +++ b/lavis/configs/models/clip/ViT-g-14.json @@ -0,0 +1,18 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": 40, + "width": 1408, + "head_width": 88, + "mlp_ratio": 4.3637, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 1024, + "heads": 16, + "layers": 24 + } +} diff --git a/lavis/configs/models/clip/timm-efficientnetv2_rw_s.json b/lavis/configs/models/clip/timm-efficientnetv2_rw_s.json new file mode 100644 index 0000000000000000000000000000000000000000..fa4bfb1df0240d72552e7b09dd4d17ee48a1c0e6 --- /dev/null +++ b/lavis/configs/models/clip/timm-efficientnetv2_rw_s.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "timm_model_name": "efficientnetv2_rw_s", + "timm_model_pretrained": false, + "timm_pool": "abs_attn", + "timm_proj": "", + "image_size": 288 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/timm-resnet50d.json b/lavis/configs/models/clip/timm-resnet50d.json new file mode 100644 index 0000000000000000000000000000000000000000..7bb0957cd23e3dd0fb461764c959a75e04cae743 --- /dev/null +++ b/lavis/configs/models/clip/timm-resnet50d.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "resnet50d", + "timm_model_pretrained": false, + "timm_pool": "abs_attn", + "timm_proj": "", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/timm-resnetaa50d.json b/lavis/configs/models/clip/timm-resnetaa50d.json new file mode 100644 index 0000000000000000000000000000000000000000..c011e0c02b5d63b1ace51e4625d383adc6aedb50 --- /dev/null +++ b/lavis/configs/models/clip/timm-resnetaa50d.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "resnetaa50d", + "timm_model_pretrained": false, + "timm_pool": "abs_attn", + "timm_proj": "", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/timm-resnetblur50.json b/lavis/configs/models/clip/timm-resnetblur50.json new file mode 100644 index 0000000000000000000000000000000000000000..05d0b209ac44198bd0b45c6931dee71eac9b1eab --- /dev/null +++ b/lavis/configs/models/clip/timm-resnetblur50.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "timm_model_name": "resnetblur50", + "timm_model_pretrained": false, + "timm_pool": "abs_attn", + "timm_proj": "", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/timm-swin_base_patch4_window7_224.json b/lavis/configs/models/clip/timm-swin_base_patch4_window7_224.json new file mode 100644 index 0000000000000000000000000000000000000000..bc08f2b78543857445d22eec7d288c5fe86391a9 --- /dev/null +++ b/lavis/configs/models/clip/timm-swin_base_patch4_window7_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "swin_base_patch4_window7_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/timm-vit_base_patch16_224.json b/lavis/configs/models/clip/timm-vit_base_patch16_224.json new file mode 100644 index 0000000000000000000000000000000000000000..133b88f2f919de44c19df8318c7297824accbdce --- /dev/null +++ b/lavis/configs/models/clip/timm-vit_base_patch16_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_base_patch16_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/timm-vit_base_patch32_224.json b/lavis/configs/models/clip/timm-vit_base_patch32_224.json new file mode 100644 index 0000000000000000000000000000000000000000..9dcc6ffbfda4fb9d206bb693f6c3d53f2757aff8 --- /dev/null +++ b/lavis/configs/models/clip/timm-vit_base_patch32_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_base_patch32_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip/timm-vit_small_patch16_224.json b/lavis/configs/models/clip/timm-vit_small_patch16_224.json new file mode 100644 index 0000000000000000000000000000000000000000..8c3ae01ab318ce07c19b7b6326c07aaec1f321a4 --- /dev/null +++ b/lavis/configs/models/clip/timm-vit_small_patch16_224.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "timm_model_name": "vit_small_patch16_224", + "timm_model_pretrained": false, + "timm_pool": "", + "timm_proj": "linear", + "image_size": 224 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/lavis/configs/models/clip_resnet50.yaml b/lavis/configs/models/clip_resnet50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce3a2d429646b4b58706715d07da0ecb6c0d767b --- /dev/null +++ b/lavis/configs/models/clip_resnet50.yaml @@ -0,0 +1,11 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: clip + + model_type: RN50 + + pretrained: openai diff --git a/lavis/configs/models/clip_vit_base16.yaml b/lavis/configs/models/clip_vit_base16.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2a06fa180993c42e63cecee38ec01134c18de7c8 --- /dev/null +++ b/lavis/configs/models/clip_vit_base16.yaml @@ -0,0 +1,17 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: clip + + model_type: ViT-B-16 + + pretrained: openai + +preprocess: + vis_processor: + eval: + name: "clip_image_eval" + image_size: 224 diff --git a/lavis/configs/models/clip_vit_base32.yaml b/lavis/configs/models/clip_vit_base32.yaml new file mode 100644 index 0000000000000000000000000000000000000000..056e3d967853f5c01426514a9f98622bc92241b8 --- /dev/null +++ b/lavis/configs/models/clip_vit_base32.yaml @@ -0,0 +1,52 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: clip + + model_type: ViT-B-32 +# ['RN50', +# 'RN50-quickgelu', +# 'RN50x4', +# 'RN50x16', +# 'RN101', +# 'RN101-quickgelu', +# 'timm-efficientnetv2_rw_s', +# 'timm-resnet50d', +# 'timm-resnetaa50d', +# 'timm-resnetblur50', +# 'timm-swin_base_patch4_window7_224', +# 'timm-vit_base_patch16_224', +# 'timm-vit_base_patch32_224', +# 'timm-vit_small_patch16_224', +# 'ViT-B-16', +# 'ViT-B-16-plus', +# 'ViT-B-16-plus-240', +# 'ViT-B-32', +# 'ViT-B-32-plus-256', +# 'ViT-B-32-quickgelu', +# 'ViT-g-14', +# 'ViT-H-14', +# 'ViT-H-16', +# 'ViT-L-14', +# 'ViT-L-14-280', +# 'ViT-L-14-336', +# 'ViT-L-16', +# 'ViT-L-16-320'] + + pretrained: openai + # "openai" + # following not available for all models + # "yfcc15m" + # "cc12m" + # "laion400m_e31" + # "laion400m_e32" + # "laion400m_avg" + +preprocess: + vis_processor: + eval: + name: "clip_image_eval" + image_size: 224 diff --git a/lavis/configs/models/clip_vit_large14.yaml b/lavis/configs/models/clip_vit_large14.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8ab9f2610f1ae9e0164f39565a8302ab33123548 --- /dev/null +++ b/lavis/configs/models/clip_vit_large14.yaml @@ -0,0 +1,52 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: clip + + model_type: ViT-L-14 +# ['RN50', +# 'RN50-quickgelu', +# 'RN50x4', +# 'RN50x16', +# 'RN101', +# 'RN101-quickgelu', +# 'timm-efficientnetv2_rw_s', +# 'timm-resnet50d', +# 'timm-resnetaa50d', +# 'timm-resnetblur50', +# 'timm-swin_base_patch4_window7_224', +# 'timm-vit_base_patch16_224', +# 'timm-vit_base_patch32_224', +# 'timm-vit_small_patch16_224', +# 'ViT-B-16', +# 'ViT-B-16-plus', +# 'ViT-B-16-plus-240', +# 'ViT-B-32', +# 'ViT-B-32-plus-256', +# 'ViT-B-32-quickgelu', +# 'ViT-g-14', +# 'ViT-H-14', +# 'ViT-H-16', +# 'ViT-L-14', +# 'ViT-L-14-280', +# 'ViT-L-14-336', +# 'ViT-L-16', +# 'ViT-L-16-320'] + + pretrained: openai + # "openai" + # following not available for all models + # "yfcc15m" + # "cc12m" + # "laion400m_e31" + # "laion400m_e32" + # "laion400m_avg" + +preprocess: + vis_processor: + eval: + name: "clip_image_eval" + image_size: 224 diff --git a/lavis/configs/models/clip_vit_large14_336.yaml b/lavis/configs/models/clip_vit_large14_336.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a6510d73763fd4f0e5c6512c10c5c0ad8242499b --- /dev/null +++ b/lavis/configs/models/clip_vit_large14_336.yaml @@ -0,0 +1,52 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: clip + + model_type: ViT-L-14-336 +# ['RN50', +# 'RN50-quickgelu', +# 'RN50x4', +# 'RN50x16', +# 'RN101', +# 'RN101-quickgelu', +# 'timm-efficientnetv2_rw_s', +# 'timm-resnet50d', +# 'timm-resnetaa50d', +# 'timm-resnetblur50', +# 'timm-swin_base_patch4_window7_224', +# 'timm-vit_base_patch16_224', +# 'timm-vit_base_patch32_224', +# 'timm-vit_small_patch16_224', +# 'ViT-B-16', +# 'ViT-B-16-plus', +# 'ViT-B-16-plus-240', +# 'ViT-B-32', +# 'ViT-B-32-plus-256', +# 'ViT-B-32-quickgelu', +# 'ViT-g-14', +# 'ViT-H-14', +# 'ViT-H-16', +# 'ViT-L-14', +# 'ViT-L-14-280', +# 'ViT-L-14-336', +# 'ViT-L-16', +# 'ViT-L-16-320'] + + pretrained: openai + # "openai" + # following not available for all models + # "yfcc15m" + # "cc12m" + # "laion400m_e31" + # "laion400m_e32" + # "laion400m_avg" + +preprocess: + vis_processor: + eval: + name: "clip_image_eval" + image_size: 336 diff --git a/lavis/configs/models/gpt_dialogue_base.yaml b/lavis/configs/models/gpt_dialogue_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7bbdae83fbe10b7e7d9001292eb88ba3da4e2e04 --- /dev/null +++ b/lavis/configs/models/gpt_dialogue_base.yaml @@ -0,0 +1,25 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: gpt_dialogue + # pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth" + # pretrained: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth" + + len_tokenizer: 50264 # 50257 tokens from gpt2 default tokenizer + additional special tokens + + len_video_ft: 4224 # i3d_rgb: 2048 i3d_flow: 2048 vggish: 128 + +preprocess: + vis_processor: + train: + name: "gpt_video_ft" + eval: + name: "gpt_video_ft" + text_processor: + train: + name: "gpt_dialogue" + eval: + name: "gpt_dialogue" \ No newline at end of file diff --git a/lavis/configs/models/img2prompt-vqa/img2prompt_vqa_base.yaml b/lavis/configs/models/img2prompt-vqa/img2prompt_vqa_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fac355c4312bf54d3d87057d9bc7d665f1f03a06 --- /dev/null +++ b/lavis/configs/models/img2prompt-vqa/img2prompt_vqa_base.yaml @@ -0,0 +1,58 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: img2prompt_vqa + model_type: base + + image_question_matching_model: + arch: blip_image_text_matching + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco_train2014.pth" + + # vit encoder + vit_type: "large" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + embed_dim: 256 + + image_captioning_model: + arch: blip_caption + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption_coco_train2014.pth" + + vit_type: "large" + vit_grad_ckpt: True + vit_ckpt_layer: 5 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + # generation configs + prompt: "a picture of " + + question_generation_moodel: + pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/projects/img2prompt/T5_large_QG.pth" + + + +preprocess: + vis_processor: + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + eval: + name: "blip_caption" diff --git a/lavis/configs/models/med_config.json b/lavis/configs/models/med_config.json new file mode 100644 index 0000000000000000000000000000000000000000..a566c17bbc185f5bf8b83c7ed7dcb02e1a0ba1f9 --- /dev/null +++ b/lavis/configs/models/med_config.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "add_type_embeddings": false, + "vocab_size": 30524, + "encoder_width": 768, + "add_cross_attention": true +} \ No newline at end of file diff --git a/lavis/configs/models/med_config_albef.json b/lavis/configs/models/med_config_albef.json new file mode 100644 index 0000000000000000000000000000000000000000..529636d733bf35cdb82ec4c7950ede79a5ce80fc --- /dev/null +++ b/lavis/configs/models/med_config_albef.json @@ -0,0 +1,22 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "add_type_embeddings": false, + "vocab_size": 30522, + "encoder_width": 768, + "add_cross_attention": true, + "fusion_layer": 6 +} \ No newline at end of file diff --git a/lavis/configs/models/med_large_config.json b/lavis/configs/models/med_large_config.json new file mode 100644 index 0000000000000000000000000000000000000000..d5090b06f13c6c1e42d91e30d2cd76c2b6264d3a --- /dev/null +++ b/lavis/configs/models/med_large_config.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "add_type_embeddings": false, + "vocab_size": 30524, + "encoder_width": 1024, + "add_cross_attention": true +} \ No newline at end of file diff --git a/lavis/configs/models/pnp-vqa/pnp_vqa_3b.yaml b/lavis/configs/models/pnp-vqa/pnp_vqa_3b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..31f43778865db534e0070249db1512f50d937238 --- /dev/null +++ b/lavis/configs/models/pnp-vqa/pnp_vqa_3b.yaml @@ -0,0 +1,60 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pnp_vqa + model_type: 3b + + image_question_matching_model: + arch: blip_image_text_matching + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco_train2014.pth" + + # vit encoder + vit_type: "large" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + embed_dim: 256 + + image_captioning_model: + arch: blip_caption + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption_coco_train2014.pth" + + vit_type: "large" + vit_grad_ckpt: True + vit_ckpt_layer: 5 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + # generation configs + prompt: "a picture of " + + question_answering_model: + arch: pnp_unifiedqav2_fid + + pretrained: "allenai/unifiedqa-v2-t5-3b-1363200" + + t5_config_path: "configs/models/pnp-vqa/unifiedqav2_3b_config.json" + +preprocess: + vis_processor: + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + eval: + name: "blip_caption" diff --git a/lavis/configs/models/pnp-vqa/pnp_vqa_base.yaml b/lavis/configs/models/pnp-vqa/pnp_vqa_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5630578bbe24f4788396fbe40ae365580911d1aa --- /dev/null +++ b/lavis/configs/models/pnp-vqa/pnp_vqa_base.yaml @@ -0,0 +1,59 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pnp_vqa + model_type: base + + image_question_matching_model: + arch: blip_image_text_matching + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco_train2014.pth" + + # vit encoder + vit_type: "large" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + embed_dim: 256 + + image_captioning_model: + arch: blip_caption + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption_coco_train2014.pth" + + vit_type: "large" + vit_grad_ckpt: True + vit_ckpt_layer: 5 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + # generation configs + prompt: "a picture of " + question_answering_model: + arch: pnp_unifiedqav2_fid + + pretrained: "allenai/unifiedqa-v2-t5-base-1363200" + + t5_config_path: "configs/models/pnp-vqa/unifiedqav2_base_config.json" + +preprocess: + vis_processor: + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + eval: + name: "blip_caption" diff --git a/lavis/configs/models/pnp-vqa/pnp_vqa_large.yaml b/lavis/configs/models/pnp-vqa/pnp_vqa_large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bea044c9079c33a7f7ec3a31c13f2da311d042e0 --- /dev/null +++ b/lavis/configs/models/pnp-vqa/pnp_vqa_large.yaml @@ -0,0 +1,60 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pnp_vqa + model_type: large + + image_question_matching_model: + arch: blip_image_text_matching + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco_train2014.pth" + + # vit encoder + vit_type: "large" + vit_grad_ckpt: False + vit_ckpt_layer: 0 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + embed_dim: 256 + + image_captioning_model: + arch: blip_caption + load_finetuned: True + + finetuned: "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption_coco_train2014.pth" + + vit_type: "large" + vit_grad_ckpt: True + vit_ckpt_layer: 5 + + image_size: 384 + + # bert config + med_config_path: "configs/models/med_large_config.json" + + # generation configs + prompt: "a picture of " + + question_answering_model: + arch: pnp_unifiedqav2_fid + + pretrained: "allenai/unifiedqa-v2-t5-large-1363200" + + t5_config_path: "configs/models/pnp-vqa/unifiedqav2_large_config.json" + +preprocess: + vis_processor: + eval: + name: "blip_image_eval" + image_size: 384 + text_processor: + eval: + name: "blip_caption" diff --git a/lavis/configs/models/pnp-vqa/unifiedqav2_3b_config.json b/lavis/configs/models/pnp-vqa/unifiedqav2_3b_config.json new file mode 100644 index 0000000000000000000000000000000000000000..e5220dc592c03afd94f1a9d2077a2a87a3320856 --- /dev/null +++ b/lavis/configs/models/pnp-vqa/unifiedqav2_3b_config.json @@ -0,0 +1,60 @@ +{ + "architectures": [ + "T5ForConditionalGeneration" + ], + "d_ff": 16384, + "d_kv": 128, + "d_model": 1024, + "decoder_start_token_id": 0, + "dense_act_fn": "relu", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "relu", + "gradient_checkpointing": false, + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": false, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "n_positions": 512, + "num_decoder_layers": 24, + "num_heads": 32, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "task_specific_params": { + "summarization": { + "early_stopping": true, + "length_penalty": 2.0, + "max_length": 200, + "min_length": 30, + "no_repeat_ngram_size": 3, + "num_beams": 4, + "prefix": "summarize: " + }, + "translation_en_to_de": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to German: " + }, + "translation_en_to_fr": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to French: " + }, + "translation_en_to_ro": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to Romanian: " + } + }, + "torch_dtype": "float32", + "transformers_version": "4.21.3", + "use_cache": true, + "vocab_size": 32128 +} \ No newline at end of file diff --git a/lavis/configs/models/pnp-vqa/unifiedqav2_base_config.json b/lavis/configs/models/pnp-vqa/unifiedqav2_base_config.json new file mode 100644 index 0000000000000000000000000000000000000000..24ffa8d18a0f317f3c18e5c67bf97ede953d6436 --- /dev/null +++ b/lavis/configs/models/pnp-vqa/unifiedqav2_base_config.json @@ -0,0 +1,59 @@ +{ + "architectures": [ + "T5ForConditionalGeneration" + ], + "d_ff": 3072, + "d_kv": 64, + "d_model": 768, + "decoder_start_token_id": 0, + "dense_act_fn": "relu", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "relu", + "gradient_checkpointing": false, + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": false, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "n_positions": 512, + "num_decoder_layers": 12, + "num_heads": 12, + "num_layers": 12, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "task_specific_params": { + "summarization": { + "early_stopping": true, + "length_penalty": 2.0, + "max_length": 200, + "min_length": 30, + "no_repeat_ngram_size": 3, + "num_beams": 4, + "prefix": "summarize: " + }, + "translation_en_to_de": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to German: " + }, + "translation_en_to_fr": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to French: " + }, + "translation_en_to_ro": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to Romanian: " + } + }, + "transformers_version": "4.21.3", + "use_cache": true, + "vocab_size": 32128 +} \ No newline at end of file diff --git a/lavis/configs/models/pnp-vqa/unifiedqav2_large_config.json b/lavis/configs/models/pnp-vqa/unifiedqav2_large_config.json new file mode 100644 index 0000000000000000000000000000000000000000..4f87ec69734d35cdc0d76b1b3f11f9e80df3cdc1 --- /dev/null +++ b/lavis/configs/models/pnp-vqa/unifiedqav2_large_config.json @@ -0,0 +1,59 @@ +{ + "architectures": [ + "T5ForConditionalGeneration" + ], + "d_ff": 4096, + "d_kv": 64, + "d_model": 1024, + "decoder_start_token_id": 0, + "dense_act_fn": "relu", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "relu", + "gradient_checkpointing": false, + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": false, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "n_positions": 512, + "num_decoder_layers": 24, + "num_heads": 16, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "task_specific_params": { + "summarization": { + "early_stopping": true, + "length_penalty": 2.0, + "max_length": 200, + "min_length": 30, + "no_repeat_ngram_size": 3, + "num_beams": 4, + "prefix": "summarize: " + }, + "translation_en_to_de": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to German: " + }, + "translation_en_to_fr": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to French: " + }, + "translation_en_to_ro": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to Romanian: " + } + }, + "transformers_version": "4.21.3", + "use_cache": true, + "vocab_size": 32128 +} \ No newline at end of file diff --git a/lavis/datasets/builders/__init__.py b/lavis/datasets/builders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43d540d6217fb8487f2ba4fa6754e5250063cda6 --- /dev/null +++ b/lavis/datasets/builders/__init__.py @@ -0,0 +1,118 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from lavis.datasets.builders.base_dataset_builder import load_dataset_config +from lavis.datasets.builders.caption_builder import ( + COCOCapBuilder, + MSRVTTCapBuilder, + MSVDCapBuilder, + VATEXCapBuilder, +) +from lavis.datasets.builders.image_text_pair_builder import ( + ConceptualCaption12MBuilder, + ConceptualCaption3MBuilder, + VGCaptionBuilder, + SBUCaptionBuilder, +) +from lavis.datasets.builders.classification_builder import ( + NLVRBuilder, + SNLIVisualEntailmentBuilder, +) +from lavis.datasets.builders.imagefolder_builder import ImageNetBuilder +from lavis.datasets.builders.video_qa_builder import MSRVTTQABuilder, MSVDQABuilder +from lavis.datasets.builders.vqa_builder import ( + COCOVQABuilder, + OKVQABuilder, + VGVQABuilder, + GQABuilder, +) +from lavis.datasets.builders.retrieval_builder import ( + MSRVTTRetrievalBuilder, + DiDeMoRetrievalBuilder, + COCORetrievalBuilder, + Flickr30kBuilder, +) +from lavis.datasets.builders.dialogue_builder import AVSDDialBuilder + +from lavis.common.registry import registry + +__all__ = [ + "COCOCapBuilder", + "COCORetrievalBuilder", + "COCOVQABuilder", + "ConceptualCaption12MBuilder", + "ConceptualCaption3MBuilder", + "DiDeMoRetrievalBuilder", + "Flickr30kBuilder", + "GQABuilder", + "ImageNetBuilder", + "MSRVTTCapBuilder", + "MSRVTTQABuilder", + "MSRVTTRetrievalBuilder", + "MSVDCapBuilder", + "MSVDQABuilder", + "NLVRBuilder", + "OKVQABuilder", + "SBUCaptionBuilder", + "SNLIVisualEntailmentBuilder", + "VATEXCapBuilder", + "VGCaptionBuilder", + "VGVQABuilder", + "AVSDDialBuilder", +] + + +def load_dataset(name, cfg_path=None, vis_path=None, data_type=None): + """ + Example + + >>> dataset = load_dataset("coco_caption", cfg=None) + >>> splits = dataset.keys() + >>> print([len(dataset[split]) for split in splits]) + + """ + if cfg_path is None: + cfg = None + else: + cfg = load_dataset_config(cfg_path) + + try: + builder = registry.get_builder_class(name)(cfg) + except TypeError: + print( + f"Dataset {name} not found. Available datasets:\n" + + ", ".join([str(k) for k in dataset_zoo.get_names()]) + ) + exit(1) + + if vis_path is not None: + if data_type is None: + # use default data type in the config + data_type = builder.config.data_type + + assert ( + data_type in builder.config.build_info + ), f"Invalid data_type {data_type} for {name}." + + builder.config.build_info.get(data_type).storage = vis_path + + dataset = builder.build_datasets() + return dataset + + +class DatasetZoo: + def __init__(self) -> None: + self.dataset_zoo = { + k: list(v.DATASET_CONFIG_DICT.keys()) + for k, v in sorted(registry.mapping["builder_name_mapping"].items()) + } + + def get_names(self): + return list(self.dataset_zoo.keys()) + + +dataset_zoo = DatasetZoo() diff --git a/lavis/datasets/builders/base_dataset_builder.py b/lavis/datasets/builders/base_dataset_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..e233a4c766011e529a52c44baf68c514cbbb0388 --- /dev/null +++ b/lavis/datasets/builders/base_dataset_builder.py @@ -0,0 +1,234 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os +import shutil +import warnings + +import lavis.common.utils as utils +import torch.distributed as dist +from lavis.common.dist_utils import is_dist_avail_and_initialized, is_main_process +from lavis.common.registry import registry +from lavis.datasets.data_utils import extract_archive +from lavis.processors.base_processor import BaseProcessor +from omegaconf import OmegaConf +from torchvision.datasets.utils import download_url + + +class BaseDatasetBuilder: + train_dataset_cls, eval_dataset_cls = None, None + + def __init__(self, cfg=None): + super().__init__() + + if cfg is None: + # help to create datasets from default config. + self.config = load_dataset_config(self.default_config_path()) + elif isinstance(cfg, str): + self.config = load_dataset_config(cfg) + else: + # when called from task.build_dataset() + self.config = cfg + + self.data_type = self.config.data_type + + self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()} + + def build_datasets(self): + # download, split, etc... + # only called on 1 GPU/TPU in distributed + + if is_main_process(): + self._download_data() + + if is_dist_avail_and_initialized(): + dist.barrier() + + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("Building datasets...") + datasets = self.build() # dataset['train'/'val'/'test'] + + return datasets + + def build_processors(self): + vis_proc_cfg = self.config.get("vis_processor") + txt_proc_cfg = self.config.get("text_processor") + + if vis_proc_cfg is not None: + vis_train_cfg = vis_proc_cfg.get("train") + vis_eval_cfg = vis_proc_cfg.get("eval") + + self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg) + self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + + self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg) + self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg) + + @staticmethod + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else None + ) + + @classmethod + def default_config_path(cls, type="default"): + return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type]) + + def _download_data(self): + self._download_ann() + self._download_vis() + + def _download_ann(self): + """ + Download annotation files if necessary. + All the vision-language datasets should have annotations of unified format. + + storage_path can be: + (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative. + (2) basename/dirname: will be suffixed with base name of URL if dirname is provided. + + Local annotation paths should be relative. + """ + anns = self.config.build_info.annotations + + splits = anns.keys() + + cache_root = registry.get_path("cache_root") + + for split in splits: + info = anns[split] + + urls, storage_paths = info.get("url", None), info.storage + + if isinstance(urls, str): + urls = [urls] + if isinstance(storage_paths, str): + storage_paths = [storage_paths] + + assert len(urls) == len(storage_paths) + + for url_or_filename, storage_path in zip(urls, storage_paths): + # if storage_path is relative, make it full by prefixing with cache_root. + if not os.path.isabs(storage_path): + storage_path = os.path.join(cache_root, storage_path) + + dirname = os.path.dirname(storage_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + + if os.path.isfile(url_or_filename): + src, dst = url_or_filename, storage_path + if not os.path.exists(dst): + shutil.copyfile(src=src, dst=dst) + else: + logging.info("Using existing file {}.".format(dst)) + else: + if os.path.isdir(storage_path): + # if only dirname is provided, suffix with basename of URL. + raise ValueError( + "Expecting storage_path to be a file path, got directory {}".format( + storage_path + ) + ) + else: + filename = os.path.basename(storage_path) + + download_url(url=url_or_filename, root=dirname, filename=filename) + + def _download_vis(self): + + storage_path = self.config.build_info.get(self.data_type).storage + storage_path = utils.get_cache_path(storage_path) + + if not os.path.exists(storage_path): + warnings.warn( + f""" + The specified path {storage_path} for visual inputs does not exist. + Please provide a correct path to the visual inputs or + refer to datasets/download_scripts/README.md for downloading instructions. + """ + ) + + def build(self): + """ + Create by split datasets inheriting torch.utils.data.Datasets. + + # build() can be dataset-specific. Overwrite to customize. + """ + self.build_processors() + + build_info = self.config.build_info + + ann_info = build_info.annotations + vis_info = build_info.get(self.data_type) + + datasets = dict() + for split in ann_info.keys(): + if split not in ["train", "val", "test"]: + continue + + is_train = split == "train" + + # processors + vis_processor = ( + self.vis_processors["train"] + if is_train + else self.vis_processors["eval"] + ) + text_processor = ( + self.text_processors["train"] + if is_train + else self.text_processors["eval"] + ) + + # annotation path + ann_paths = ann_info.get(split).storage + if isinstance(ann_paths, str): + ann_paths = [ann_paths] + + abs_ann_paths = [] + for ann_path in ann_paths: + if not os.path.isabs(ann_path): + ann_path = utils.get_cache_path(ann_path) + abs_ann_paths.append(ann_path) + ann_paths = abs_ann_paths + + # visual data storage path + vis_path = vis_info.storage + + if not os.path.isabs(vis_path): + # vis_path = os.path.join(utils.get_cache_path(), vis_path) + vis_path = utils.get_cache_path(vis_path) + + if not os.path.exists(vis_path): + warnings.warn("storage path {} does not exist.".format(vis_path)) + + # create datasets + dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls + datasets[split] = dataset_cls( + vis_processor=vis_processor, + text_processor=text_processor, + ann_paths=ann_paths, + vis_root=vis_path, + ) + + return datasets + + +def load_dataset_config(cfg_path): + cfg = OmegaConf.load(cfg_path).datasets + cfg = cfg[list(cfg.keys())[0]] + + return cfg diff --git a/lavis/datasets/builders/caption_builder.py b/lavis/datasets/builders/caption_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..43326115a887f61e4e3cf657c8cee785937272b8 --- /dev/null +++ b/lavis/datasets/builders/caption_builder.py @@ -0,0 +1,68 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from lavis.datasets.datasets.coco_caption_datasets import ( + COCOCapDataset, + COCOCapEvalDataset, + NoCapsEvalDataset, +) + +from lavis.common.registry import registry +from lavis.datasets.datasets.video_caption_datasets import ( + VideoCaptionDataset, + VideoCaptionEvalDataset, +) + + +@registry.register_builder("coco_caption") +class COCOCapBuilder(BaseDatasetBuilder): + train_dataset_cls = COCOCapDataset + eval_dataset_cls = COCOCapEvalDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/coco/defaults_cap.yaml", + } + + +@registry.register_builder("nocaps") +class COCOCapBuilder(BaseDatasetBuilder): + eval_dataset_cls = NoCapsEvalDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/nocaps/defaults.yaml", + } + + +@registry.register_builder("msrvtt_caption") +class MSRVTTCapBuilder(BaseDatasetBuilder): + train_dataset_cls = VideoCaptionDataset + eval_dataset_cls = VideoCaptionEvalDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/msrvtt/defaults_cap.yaml", + } + + +@registry.register_builder("msvd_caption") +class MSVDCapBuilder(BaseDatasetBuilder): + train_dataset_cls = VideoCaptionDataset + eval_dataset_cls = VideoCaptionEvalDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/msvd/defaults_cap.yaml", + } + + +@registry.register_builder("vatex_caption") +class VATEXCapBuilder(BaseDatasetBuilder): + train_dataset_cls = VideoCaptionDataset + eval_dataset_cls = VideoCaptionEvalDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/vatex/defaults_cap.yaml", + } diff --git a/lavis/datasets/builders/classification_builder.py b/lavis/datasets/builders/classification_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa4787bea4eae08114f12112ada29f7105ec686 --- /dev/null +++ b/lavis/datasets/builders/classification_builder.py @@ -0,0 +1,27 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from lavis.common.registry import registry +from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from lavis.datasets.datasets.nlvr_datasets import NLVRDataset, NLVREvalDataset +from lavis.datasets.datasets.snli_ve_datasets import SNLIVisualEntialmentDataset + + +@registry.register_builder("nlvr") +class NLVRBuilder(BaseDatasetBuilder): + train_dataset_cls = NLVRDataset + eval_dataset_cls = NLVREvalDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/nlvr/defaults.yaml"} + + +@registry.register_builder("snli_ve") +class SNLIVisualEntailmentBuilder(BaseDatasetBuilder): + train_dataset_cls = SNLIVisualEntialmentDataset + eval_dataset_cls = SNLIVisualEntialmentDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/snli_ve/defaults.yaml"} diff --git a/lavis/datasets/builders/dialogue_builder.py b/lavis/datasets/builders/dialogue_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..08a54f2aa4da710af98dc36aac36e2eec5d3dad4 --- /dev/null +++ b/lavis/datasets/builders/dialogue_builder.py @@ -0,0 +1,21 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from lavis.common.registry import registry +from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from lavis.datasets.datasets.avsd_dialogue_datasets import ( + AVSDDialDataset, + AVSDDialEvalDataset, +) + + +@registry.register_builder("avsd_dialogue") +class AVSDDialBuilder(BaseDatasetBuilder): + train_dataset_cls = AVSDDialDataset + eval_dataset_cls = AVSDDialEvalDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/avsd/defaults_dial.yaml"} diff --git a/lavis/datasets/builders/image_text_pair_builder.py b/lavis/datasets/builders/image_text_pair_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..90e411eb6d41c23c15dbf5a0c67e2b68d467b43b --- /dev/null +++ b/lavis/datasets/builders/image_text_pair_builder.py @@ -0,0 +1,77 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from lavis.common.registry import registry + +from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from lavis.datasets.datasets.image_text_pair_datasets import ImageTextPairDataset +from lavis.datasets.datasets.laion_dataset import LaionDataset + + +@registry.register_builder("conceptual_caption_3m") +class ConceptualCaption3MBuilder(BaseDatasetBuilder): + train_dataset_cls = ImageTextPairDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/conceptual_caption/defaults_3m.yaml" + } + + +@registry.register_builder("conceptual_caption_12m") +class ConceptualCaption12MBuilder(BaseDatasetBuilder): + train_dataset_cls = ImageTextPairDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/conceptual_caption/defaults_12m.yaml" + } + + +@registry.register_builder("sbu_caption") +class SBUCaptionBuilder(BaseDatasetBuilder): + train_dataset_cls = ImageTextPairDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/sbu_caption/defaults.yaml"} + + +@registry.register_builder("vg_caption") +class VGCaptionBuilder(BaseDatasetBuilder): + train_dataset_cls = ImageTextPairDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/vg/defaults_caption.yaml"} + + +@registry.register_builder("laion2B_multi") +class Laion2BMultiBuilder(BaseDatasetBuilder): + train_dataset_cls = LaionDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults_2B_multi.yaml"} + + def _download_ann(self): + pass + + def _download_vis(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + datasets = dict() + split = "train" # laion dataset only has train split + + # create datasets + # [NOTE] return inner_datasets (wds.DataPipeline) + dataset_cls = self.train_dataset_cls + datasets[split] = dataset_cls( + vis_processor=self.vis_processors[split], + text_processor=self.text_processors[split], + location=build_info.storage, + ).inner_dataset + + return datasets diff --git a/lavis/datasets/builders/imagefolder_builder.py b/lavis/datasets/builders/imagefolder_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..6c71fbe216156f7e18f3a0d49004d558508980e8 --- /dev/null +++ b/lavis/datasets/builders/imagefolder_builder.py @@ -0,0 +1,1061 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os + +from lavis.common.registry import registry +from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from lavis.datasets.datasets.imagefolder_dataset import ImageFolderDataset + + +@registry.register_builder("imagenet") +class ImageNetBuilder(BaseDatasetBuilder): + train_dataset_cls = ImageFolderDataset + eval_dataset_cls = ImageFolderDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/imagenet/defaults.yaml"} + + def _download_ann(self): + pass + + def build(self): + self.build_processors() + + build_info = self.config.build_info + + vis_info = build_info.get(self.data_type) + + datasets = dict() + for split in build_info.splits: + assert split in [ + "train", + "val", + ], "Invalid split name {}, must be one of 'train', 'val' and 'test'." + + is_train = split == "train" + + vis_processor = ( + self.vis_processors["train"] + if is_train + else self.vis_processors["eval"] + ) + + vis_path = os.path.join(vis_info.storage, split) + + # create datasets + dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls + datasets[split] = dataset_cls( + vis_processor=vis_processor, + vis_root=vis_path, + classnames=imagenet_classnames, + ) + + return datasets + + +imagenet_classnames = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "rooster", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite (bird of prey)", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder rattlesnake", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peafowl", + "quail", + "partridge", + "african grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern bird", + "crane bird", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + "Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany dog", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniel", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael dog", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres dog", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + "Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland dog", + "Great Pyrenees dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "brussels griffon", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog (xoloitzcuintli)", + "grey wolf", + "Alaskan tundra wolf", + "red wolf or maned wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket insect", + "stick insect", + "cockroach", + "praying mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral butterfly", + "ringlet butterfly", + "monarch butterfly", + "small white butterfly", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel horse", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram (adult male sheep)", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala (antelope)", + "gazelle", + "arabian camel", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi monkey", + "Geoffroy's spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + "snoek fish", + "eel", + "silver salmon", + "rock beauty fish", + "clownfish", + "sturgeon", + "gar fish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "trash can", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster / handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military hat (bearskin or shako)", + "beer bottle", + "beer glass", + "bell tower", + "baby bib", + "tandem bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + "bookstore", + "bottle cap", + "hunting bow", + "bow tie", + "brass memorial plaque", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "cardboard box / carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "storage chest", + "chiffonier", + "bell or wind chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "spiral or coil", + "combination lock", + "computer keyboard", + "candy store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "construction crane", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather boa", + "filing cabinet", + "fireboat", + "fire truck", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask or respirator", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "radiator grille", + "grocery store", + "guillotine", + "hair clip", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "combine harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "gymnastic horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "carved pumpkin", + "jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "music speaker", + "loupe magnifying glass", + "sawmill", + "magnetic compass", + "messenger bag", + "mailbox", + "tights", + "one-piece bathing suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine cabinet", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "ford model t", + "modem", + "monastery", + "monitor", + "moped", + "mortar and pestle", + "graduation cap", + "mosque", + "mosquito net", + "vespa", + "mountain bike", + "tent", + "computer mouse", + "mousetrap", + "moving van", + "muzzle", + "metal nail", + "neck brace", + "necklace", + "baby pacifier", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "pipe organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "product packet / packaging", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "railroad car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "drink pitcher", + "block plane", + "planetarium", + "plastic bag", + "plate rack", + "farm plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "soda bottle", + "plant pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "missile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "fishing casting reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler measuring stick", + "sneaker", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT monitor", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji screen / room divider", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "balaclava ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "keyboard space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglasses", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swim trunks / shorts", + "swing", + "electrical switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "hot tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vaulted or arched ceiling", + "velvet fabric", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + "sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "hair wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "airplane wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "sailboat", + "yurt", + "website", + "comic book", + "crossword", + "traffic or street sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "popsicle", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potatoes", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith apple", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "cherimoya (custard apple)", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "tea cup", + "eggnog", + "mountain", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "sandbar", + "beach", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star fungus", + "hen of the woods mushroom", + "bolete", + "corn cob", + "toilet paper", +] diff --git a/lavis/datasets/builders/retrieval_builder.py b/lavis/datasets/builders/retrieval_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..98ca3bdf572fe007ea1bd97d75aefcb8ae02fe3d --- /dev/null +++ b/lavis/datasets/builders/retrieval_builder.py @@ -0,0 +1,48 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from lavis.datasets.datasets.retrieval_datasets import ( + RetrievalDataset, + RetrievalEvalDataset, + VideoRetrievalDataset, + VideoRetrievalEvalDataset, +) + +from lavis.common.registry import registry + + +@registry.register_builder("msrvtt_retrieval") +class MSRVTTRetrievalBuilder(BaseDatasetBuilder): + train_dataset_cls = VideoRetrievalDataset + eval_dataset_cls = VideoRetrievalEvalDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/msrvtt/defaults_ret.yaml"} + + +@registry.register_builder("didemo_retrieval") +class DiDeMoRetrievalBuilder(BaseDatasetBuilder): + train_dataset_cls = VideoRetrievalDataset + eval_dataset_cls = VideoRetrievalEvalDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/didemo/defaults_ret.yaml"} + + +@registry.register_builder("coco_retrieval") +class COCORetrievalBuilder(BaseDatasetBuilder): + train_dataset_cls = RetrievalDataset + eval_dataset_cls = RetrievalEvalDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/coco/defaults_ret.yaml"} + + +@registry.register_builder("flickr30k") +class Flickr30kBuilder(BaseDatasetBuilder): + train_dataset_cls = RetrievalDataset + eval_dataset_cls = RetrievalEvalDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/flickr30k/defaults.yaml"} diff --git a/lavis/datasets/builders/video_qa_builder.py b/lavis/datasets/builders/video_qa_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a3a1fb95eddf80943ae05bc770c3b53d42b648 --- /dev/null +++ b/lavis/datasets/builders/video_qa_builder.py @@ -0,0 +1,44 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from lavis.common.registry import registry +from lavis.common.utils import get_cache_path +from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder +from lavis.datasets.datasets.video_vqa_datasets import VideoQADataset + + +class VideoQABuilder(BaseDatasetBuilder): + train_dataset_cls = VideoQADataset + eval_dataset_cls = VideoQADataset + + def build(self): + datasets = super().build() + + ans2label = self.config.build_info.annotations.get("ans2label") + if ans2label is None: + raise ValueError("ans2label is not specified in build_info.") + + ans2label = get_cache_path(ans2label.storage) + + for split in datasets: + datasets[split]._build_class_labels(ans2label) + + return datasets + + +@registry.register_builder("msrvtt_qa") +class MSRVTTQABuilder(VideoQABuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/msrvtt/defaults_qa.yaml", + } + + +@registry.register_builder("msvd_qa") +class MSVDQABuilder(VideoQABuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/msvd/defaults_qa.yaml", + } diff --git a/lavis/datasets/builders/vqa_builder.py b/lavis/datasets/builders/vqa_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c08a6d1406cf7a53ca6f8180671ca53e096bfdcc --- /dev/null +++ b/lavis/datasets/builders/vqa_builder.py @@ -0,0 +1,58 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder + +from lavis.common.registry import registry +from lavis.datasets.datasets.aok_vqa_datasets import AOKVQADataset, AOKVQAEvalDataset +from lavis.datasets.datasets.coco_vqa_datasets import COCOVQADataset, COCOVQAEvalDataset +from lavis.datasets.datasets.vg_vqa_datasets import VGVQADataset +from lavis.datasets.datasets.gqa_datasets import GQADataset, GQAEvalDataset + + +@registry.register_builder("coco_vqa") +class COCOVQABuilder(BaseDatasetBuilder): + train_dataset_cls = COCOVQADataset + eval_dataset_cls = COCOVQAEvalDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/coco/defaults_vqa.yaml", + "eval": "configs/datasets/coco/eval_vqa.yaml", + } + + +@registry.register_builder("vg_vqa") +class VGVQABuilder(BaseDatasetBuilder): + train_dataset_cls = VGVQADataset + DATASET_CONFIG_DICT = {"default": "configs/datasets/vg/defaults_vqa.yaml"} + + +@registry.register_builder("ok_vqa") +class OKVQABuilder(COCOVQABuilder): + DATASET_CONFIG_DICT = { + "default": "configs/datasets/okvqa/defaults.yaml", + } + + +@registry.register_builder("aok_vqa") +class AOKVQABuilder(BaseDatasetBuilder): + train_dataset_cls = AOKVQADataset + eval_dataset_cls = AOKVQAEvalDataset + + DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa/defaults.yaml"} + + +@registry.register_builder("gqa") +class GQABuilder(BaseDatasetBuilder): + train_dataset_cls = GQADataset + eval_dataset_cls = GQAEvalDataset + + DATASET_CONFIG_DICT = { + "default": "configs/datasets/gqa/defaults.yaml", + "balanced_val": "configs/datasets/gqa/balanced_val.yaml", + "balanced_testdev": "configs/datasets/gqa/balanced_testdev.yaml", + } \ No newline at end of file diff --git a/lavis/datasets/data_utils.py b/lavis/datasets/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0ea4f44d356bf95ab6ca0621c3128bccef76054c --- /dev/null +++ b/lavis/datasets/data_utils.py @@ -0,0 +1,284 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import gzip +import logging +import os +import random as rnd +import tarfile +import zipfile + +import decord +import webdataset as wds +import numpy as np +import torch +from torch.utils.data.dataset import IterableDataset, ChainDataset +from decord import VideoReader +from lavis.common.registry import registry +from lavis.datasets.datasets.base_dataset import ConcatDataset +from tqdm import tqdm + +decord.bridge.set_bridge("torch") +MAX_INT = registry.get("MAX_INT") + + +def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform"): + vr = VideoReader(uri=video_path, height=height, width=width) + + vlen = len(vr) + start, end = 0, vlen + + n_frms = min(n_frms, vlen) + + if sampling == "uniform": + indices = np.arange(start, end, vlen / n_frms).astype(int) + elif sampling == "headtail": + indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2)) + indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2)) + indices = indices_h + indices_t + else: + raise NotImplementedError + + # get_batch -> T, H, W, C + frms = vr.get_batch(indices).permute(3, 0, 1, 2).float() # (C, T, H, W) + + return frms + + +def apply_to_sample(f, sample): + if len(sample) == 0: + return {} + + def _apply(x): + if torch.is_tensor(x): + return f(x) + elif isinstance(x, dict): + return {key: _apply(value) for key, value in x.items()} + elif isinstance(x, list): + return [_apply(x) for x in x] + else: + return x + + return _apply(sample) + + +def move_to_cuda(sample): + def _move_to_cuda(tensor): + return tensor.cuda() + + return apply_to_sample(_move_to_cuda, sample) + + +def prepare_sample(samples, cuda_enabled=True): + if cuda_enabled: + samples = move_to_cuda(samples) + + # TODO fp16 support + + return samples + + +def reorg_datasets_by_split(datasets): + """ + Organizes datasets by split. + + Args: + datasets: dict of torch.utils.data.Dataset objects by name. + + Returns: + Dict of datasets by split {split_name: List[Datasets]}. + """ + # if len(datasets) == 1: + # return datasets[list(datasets.keys())[0]] + # else: + reorg_datasets = dict() + + # reorganize by split + for _, dataset in datasets.items(): + for split_name, dataset_split in dataset.items(): + if split_name not in reorg_datasets: + reorg_datasets[split_name] = [dataset_split] + else: + reorg_datasets[split_name].append(dataset_split) + + return reorg_datasets + + +def concat_datasets(datasets): + """ + Concatenates multiple datasets into a single dataset. + + It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support + generic IterableDataset because it requires creating separate samplers. + + Now only supports conctenating training datasets and assuming validation and testing + have only a single dataset. This is because metrics should not be computed on the concatenated + datasets. + + Args: + datasets: dict of torch.utils.data.Dataset objects by split. + + Returns: + Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets, + "val" and "test" remain the same. + + If the input training datasets contain both map-style and DataPipeline datasets, returns + a tuple, where the first element is a concatenated map-style dataset and the second + element is a chained DataPipeline dataset. + + """ + # concatenate datasets in the same split + for split_name in datasets: + if split_name != "train": + assert ( + len(datasets[split_name]) == 1 + ), "Do not support multiple {} datasets.".format(split_name) + datasets[split_name] = datasets[split_name][0] + else: + iterable_datasets, map_datasets = [], [] + for dataset in datasets[split_name]: + if isinstance(dataset, wds.DataPipeline): + logging.info( + "Dataset {} is IterableDataset, can't be concatenated.".format( + dataset + ) + ) + iterable_datasets.append(dataset) + elif isinstance(dataset, IterableDataset): + raise NotImplementedError( + "Do not support concatenation of generic IterableDataset." + ) + else: + map_datasets.append(dataset) + + # if len(iterable_datasets) > 0: + # concatenate map-style datasets and iterable-style datasets separately + chained_datasets = ( + ChainDataset(iterable_datasets) if len(iterable_datasets) > 0 else None + ) + concat_datasets = ( + ConcatDataset(map_datasets) if len(map_datasets) > 0 else None + ) + + train_datasets = concat_datasets, chained_datasets + train_datasets = tuple([x for x in train_datasets if x is not None]) + train_datasets = ( + train_datasets[0] if len(train_datasets) == 1 else train_datasets + ) + + datasets[split_name] = train_datasets + + return datasets + + +def extract_archive(from_path, to_path=None, overwrite=False): + """Extract archive. + + Args: + from_path: the path of the archive. + to_path: the root path of the extracted files (directory of from_path) + overwrite: overwrite existing files (False) + + Returns: + List of paths to extracted files even if not overwritten. + + Examples: + >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz' + >>> from_path = './validation.tar.gz' + >>> to_path = './' + >>> torchtext.utils.download_from_url(url, from_path) + >>> torchtext.utils.extract_archive(from_path, to_path) + >>> ['.data/val.de', '.data/val.en'] + >>> torchtext.utils.download_from_url(url, from_path) + >>> torchtext.utils.extract_archive(from_path, to_path) + >>> ['.data/val.de', '.data/val.en'] + + """ + + if to_path is None: + to_path = os.path.dirname(from_path) + + if from_path.endswith((".tar.gz", ".tgz")): + logging.info("Opening tar file {} to {}.".format(from_path, to_path)) + with tarfile.open(from_path, "r") as tar: + files = [] + for file_ in tqdm(tar): + file_path = os.path.join(to_path, file_.name) + if file_.isfile(): + files.append(file_path) + if os.path.exists(file_path): + logging.info("{} already extracted.".format(file_path)) + if not overwrite: + continue + tar.extract(file_, to_path) + logging.info("Finished extracting tar file {}.".format(from_path)) + return files + + elif from_path.endswith(".zip"): + assert zipfile.is_zipfile(from_path), from_path + logging.info("Opening zip file {} to {}.".format(from_path, to_path)) + with zipfile.ZipFile(from_path, "r") as zfile: + files = [] + for file_ in tqdm(zfile.namelist()): + file_path = os.path.join(to_path, file_) + files.append(file_path) + if os.path.exists(file_path): + logging.info("{} already extracted.".format(file_path)) + if not overwrite: + continue + zfile.extract(file_, to_path) + files = [f for f in files if os.path.isfile(f)] + logging.info("Finished extracting zip file {}.".format(from_path)) + return files + + elif from_path.endswith(".gz"): + logging.info("Opening gz file {} to {}.".format(from_path, to_path)) + default_block_size = 65536 + filename = from_path[:-3] + files = [filename] + with gzip.open(from_path, "rb") as gzfile, open(filename, "wb") as d_file: + while True: + block = gzfile.read(default_block_size) + if not block: + break + else: + d_file.write(block) + d_file.write(block) + logging.info("Finished extracting gz file {}.".format(from_path)) + return files + + else: + raise NotImplementedError( + "We currently only support tar.gz, .tgz, .gz and zip achives." + ) + + +def save_frames_grid(img_array, out_path): + import torch + from PIL import Image + from torchvision.utils import make_grid + + if len(img_array.shape) == 3: + img_array = img_array.unsqueeze(0) + elif len(img_array.shape) == 5: + b, t, c, h, w = img_array.shape + img_array = img_array.view(-1, c, h, w) + elif len(img_array.shape) == 4: + pass + else: + raise NotImplementedError( + "Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored." + ) + + assert img_array.shape[1] == 3, "Exepcting input shape of (H, W, 3), i.e. RGB-only." + + grid = make_grid(img_array) + ndarr = grid.permute(1, 2, 0).to("cpu", torch.uint8).numpy() + + img = Image.fromarray(ndarr) + + img.save(out_path) diff --git a/lavis/datasets/datasets/aok_vqa_datasets.py b/lavis/datasets/datasets/aok_vqa_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..a77b770936f353ceba799d53acacdbfb2809aaa4 --- /dev/null +++ b/lavis/datasets/datasets/aok_vqa_datasets.py @@ -0,0 +1,154 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from collections import OrderedDict +import json +import os +import torch + +from PIL import Image + +from lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "direct_answers": "; ".join(ann["direct_answers"]), + "choices": "; ".join(ann["choices"]), + "correct_choice": ann["choices"][ann["correct_choice_idx"]], + "image": sample["image"], + } + ) + + +class AOKVQADataset(VQADataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + answer_key = "direct_answers" + + answer_weight = {} + for answer in ann[answer_key]: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len(ann[answer_key]) + else: + answer_weight[answer] = 1 / len(ann[answer_key]) + + answers = list(answer_weight.keys()) + weights = list(answer_weight.values()) + + return { + "image": image, + "text_input": question, + "answers": answers, + "weights": weights, + } + + +class AOKVQAEvalDataset(VQAEvalDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + + self.vis_root = vis_root + + self.annotation = json.load(open(ann_paths[0])) + + answer_list_path = ann_paths[1] + if os.path.exists(answer_list_path): + self.answer_list = json.load(open(answer_list_path)) + else: + self.answer_list = None + + try: + self.coco_fmt_qust_file = ann_paths[2] + self.coco_fmt_anno_file = ann_paths[3] + except IndexError: + self.coco_fmt_qust_file = None + self.coco_fmt_anno_file = None + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def collater(self, samples): + ( + image_list, + question_list, + question_id_list, + instance_id_list, + choices_list, + correct_choice_idx_list, + direct_answers_list, + ) = ([], [], [], [], [], [], []) + + for sample in samples: + image_list.append(sample["image"]) + question_list.append(sample["text_input"]) + question_id_list.append(sample["question_id"]) + instance_id_list.append(sample["instance_id"]) + choices_list.append(sample["choices"]) + correct_choice_idx_list.append(sample["correct_choice_idx"]) + direct_answers_list.append(sample["direct_answers"]) + + return { + "image": torch.stack(image_list, dim=0), + "text_input": question_list, + "question_id": question_id_list, + "instance_id": instance_id_list, + "choices": choices_list, + "correct_choice_idx": correct_choice_idx_list, + "direct_answers": direct_answers_list, + } + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + choices = ann["choices"] + if "correct_choice_idx" in ann: + correct_choice_idx = ann["correct_choice_idx"] + else: + correct_choice_idx = None + + if "direct_answers" in ann: + direct_answers = ann["direct_answers"] + else: + direct_answers = None + + return { + "image": image, + "text_input": question, + "question_id": ann["question_id"], + "instance_id": ann["instance_id"], + "choices": choices, + "correct_choice_idx": correct_choice_idx, + "direct_answers": direct_answers, + } diff --git a/lavis/datasets/datasets/avsd_dialogue_datasets.py b/lavis/datasets/datasets/avsd_dialogue_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..3e2d599e6117e7abebb70a4b24a42a09b7162dd0 --- /dev/null +++ b/lavis/datasets/datasets/avsd_dialogue_datasets.py @@ -0,0 +1,166 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch +from lavis.datasets.datasets.dialogue_datasets import ( + DialogueDataset, + DialogueEvalDataset, +) + + +class AVSDDialDataset(DialogueDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + vname = ann["image_id"] + + video = self.vis_processor(self.vis_root, vname) + + dialogue = self.text_processor(ann) + + # "image_id" is kept to stay compatible with the COCO evaluation format + return { + "video_fts": video["video_fts"], + "video_token_type_ids": video["token_type_ids"], + "input_ids": dialogue["input_ids"], + "token_type_ids": dialogue["token_type_ids"], + "labels": dialogue["labels"], + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } + + def collater(self, samples): + + input_ids, token_type_ids, labels, video_fts, video_token_type_ids = ( + [], + [], + [], + [], + [], + ) + + for i in samples: + input_ids.append(i["input_ids"]) + token_type_ids.append(i["token_type_ids"]) + labels.append(i["labels"]) + video_fts.append(i["video_fts"]) + video_token_type_ids.append(i["video_token_type_ids"]) + + input_ids = self.text_processor.padding(input_ids) + + labels = self.text_processor.padding( + labels, -1 + ) # ignore token indice -1 by default + video_fts = self.vis_processor.padding(video_fts) + + token_type_ids = self.text_processor.padding(token_type_ids) + video_token_type_ids = self.text_processor.padding(video_token_type_ids) + token_type_ids = torch.cat([video_token_type_ids, token_type_ids], dim=1) + + attn_mask = self.text_processor.get_attention_mask(input_ids) + video_mask = self.vis_processor.get_attention_mask(video_fts) + attn_mask = torch.cat([video_mask, attn_mask], dim=1) + + video_labels = ( + torch.ones((video_fts.size(0), video_fts.size(1))).long() * -1 + ) # ignore token indice -1 by default + labels = torch.cat([video_labels, labels], dim=1) + + samples = {} + samples["input_ids"] = input_ids + samples["token_type_ids"] = token_type_ids + samples["labels"] = labels + samples["video_fts"] = video_fts + samples["attn_mask"] = attn_mask + + return samples + + +class AVSDDialEvalDataset(DialogueEvalDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + vname = ann["image_id"] + + video = self.vis_processor(self.vis_root, vname) + + dialogue = self.text_processor(ann) + + # "image_id" is kept to stay compatible with the COCO evaluation format + return { + "video_fts": video["video_fts"], + "video_token_type_ids": video["token_type_ids"], + "input_ids": dialogue["input_ids"], + "token_type_ids": dialogue["token_type_ids"], + "labels": dialogue["labels"], + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } + + def collater(self, samples): + + input_ids, token_type_ids, labels, video_fts, video_token_type_ids = ( + [], + [], + [], + [], + [], + ) + + for i in samples: + input_ids.append(i["input_ids"]) + token_type_ids.append(i["token_type_ids"]) + labels.append(i["labels"]) + video_fts.append(i["video_fts"]) + video_token_type_ids.append(i["video_token_type_ids"]) + + input_ids = self.text_processor.padding(input_ids) + + labels = self.text_processor.padding( + labels, -1 + ) # ignore token indice -1 by default + video_fts = self.vis_processor.padding(video_fts) + + token_type_ids = self.text_processor.padding(token_type_ids) + video_token_type_ids = self.text_processor.padding(video_token_type_ids) + token_type_ids = torch.cat([video_token_type_ids, token_type_ids], dim=1) + + attn_mask = self.text_processor.get_attention_mask(input_ids) + video_mask = self.vis_processor.get_attention_mask(video_fts) + attn_mask = torch.cat([video_mask, attn_mask], dim=1) + + video_labels = ( + torch.ones((video_fts.size(0), video_fts.size(1))).long() * -1 + ) # ignore token indice -1 by default + labels = torch.cat([video_labels, labels], dim=1) + + samples = {} + samples["input_ids"] = input_ids + samples["token_type_ids"] = token_type_ids + samples["labels"] = labels + samples["video_fts"] = video_fts + samples["attn_mask"] = attn_mask + + return samples diff --git a/lavis/datasets/datasets/base_dataset.py b/lavis/datasets/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a02d3dfa6aaa4bd8b73a71f778544907d1cd2f --- /dev/null +++ b/lavis/datasets/datasets/base_dataset.py @@ -0,0 +1,68 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import json +from typing import Iterable + +from torch.utils.data import Dataset, ConcatDataset +from torch.utils.data.dataloader import default_collate + + +class BaseDataset(Dataset): + def __init__( + self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[] + ): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.annotation = [] + for ann_path in ann_paths: + self.annotation.extend(json.load(open(ann_path, "r"))) + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def __len__(self): + return len(self.annotation) + + def collater(self, samples): + return default_collate(samples) + + def set_processors(self, vis_processor, text_processor): + self.vis_processor = vis_processor + self.text_processor = text_processor + + def _add_instance_ids(self, key="instance_id"): + for idx, ann in enumerate(self.annotation): + ann[key] = str(idx) + + +class ConcatDataset(ConcatDataset): + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__(datasets) + + def collater(self, samples): + # TODO For now only supports datasets with same underlying collater implementations + + all_keys = set() + for s in samples: + all_keys.update(s) + + shared_keys = all_keys + for s in samples: + shared_keys = shared_keys & set(s.keys()) + + samples_shared_keys = [] + for s in samples: + samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) + + return self.datasets[0].collater(samples_shared_keys) diff --git a/lavis/datasets/datasets/caption_datasets.py b/lavis/datasets/datasets/caption_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebe7ae19ddf4743ada9ed7e1f53ff05cd8e80db --- /dev/null +++ b/lavis/datasets/datasets/caption_datasets.py @@ -0,0 +1,84 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from collections import OrderedDict + +from lavis.datasets.datasets.base_dataset import BaseDataset +from PIL import Image + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "caption": ann["caption"], + "image": sample["image"], + } + ) + + +class CaptionDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = self.text_processor(ann["caption"]) + + return { + "image": image, + "text_input": caption, + "image_id": self.img_ids[ann["image_id"]], + } + + +class CaptionEvalDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + return { + "image": image, + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } diff --git a/lavis/datasets/datasets/coco_caption_datasets.py b/lavis/datasets/datasets/coco_caption_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..400750a75ea947ff5ae230747c5de6f5fe721e55 --- /dev/null +++ b/lavis/datasets/datasets/coco_caption_datasets.py @@ -0,0 +1,70 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import json + +from PIL import Image +from PIL import ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + +from lavis.datasets.datasets.caption_datasets import CaptionDataset, CaptionEvalDataset + +COCOCapDataset = CaptionDataset + + +class COCOCapEvalDataset(CaptionEvalDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1] + + return { + "image": image, + "image_id": img_id, + "instance_id": ann["instance_id"], + } + + +class NoCapsEvalDataset(CaptionEvalDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + img_id = ann["img_id"] + + return { + "image": image, + "image_id": img_id, + "instance_id": ann["instance_id"], + } diff --git a/lavis/datasets/datasets/coco_vqa_datasets.py b/lavis/datasets/datasets/coco_vqa_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e07f09a69d0c65bcffe7ae545ceaddb28ab2a0 --- /dev/null +++ b/lavis/datasets/datasets/coco_vqa_datasets.py @@ -0,0 +1,107 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import json + +from PIL import Image + +from lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset + +from collections import OrderedDict + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "answers": "; ".join(ann["answer"]), + "image": sample["image"], + } + ) + + +class COCOVQADataset(VQADataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + answer_weight = {} + for answer in ann["answer"]: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len(ann["answer"]) + else: + answer_weight[answer] = 1 / len(ann["answer"]) + + answers = list(answer_weight.keys()) + weights = list(answer_weight.values()) + + return { + "image": image, + "text_input": question, + "answers": answers, + "weights": weights, + } + + +class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + + self.vis_root = vis_root + + self.annotation = json.load(open(ann_paths[0])) + + answer_list_path = ann_paths[1] + if os.path.exists(answer_list_path): + self.answer_list = json.load(open(answer_list_path)) + else: + self.answer_list = None + + try: + self.coco_fmt_qust_file = ann_paths[2] + self.coco_fmt_anno_file = ann_paths[3] + except IndexError: + self.coco_fmt_qust_file = None + self.coco_fmt_anno_file = None + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + return { + "image": image, + "text_input": question, + "question_id": ann["question_id"], + "instance_id": ann["instance_id"], + } diff --git a/lavis/datasets/datasets/dataloader_utils.py b/lavis/datasets/datasets/dataloader_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fd6a88094c4366fc4a1dd7f0c5c62d3a245a4a37 --- /dev/null +++ b/lavis/datasets/datasets/dataloader_utils.py @@ -0,0 +1,162 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import time +import random +import torch +from lavis.datasets.data_utils import move_to_cuda +from torch.utils.data import DataLoader + + +class MultiIterLoader: + """ + A simple wrapper for iterating over multiple iterators. + + Args: + loaders (List[Loader]): List of Iterator loaders. + ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly. + """ + + def __init__(self, loaders, ratios=None): + # assert all loaders has __next__ method + for loader in loaders: + assert hasattr( + loader, "__next__" + ), "Loader {} has no __next__ method.".format(loader) + + if ratios is None: + ratios = [1.0] * len(loaders) + else: + assert len(ratios) == len(loaders) + ratios = [float(ratio) / sum(ratios) for ratio in ratios] + + self.loaders = loaders + self.ratios = ratios + + def __next__(self): + # random sample from each loader by ratio + loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] + return next(self.loaders[loader_idx]) + + +class PrefetchLoader(object): + """ + Modified from https://github.com/ChenRocks/UNITER. + + overlap compute and cuda data transfer + (copied and then modified from nvidia apex) + """ + + def __init__(self, loader): + self.loader = loader + self.stream = torch.cuda.Stream() + + def __iter__(self): + loader_it = iter(self.loader) + self.preload(loader_it) + batch = self.next(loader_it) + while batch is not None: + is_tuple = isinstance(batch, tuple) + if is_tuple: + task, batch = batch + + if is_tuple: + yield task, batch + else: + yield batch + batch = self.next(loader_it) + + def __len__(self): + return len(self.loader) + + def preload(self, it): + try: + self.batch = next(it) + except StopIteration: + self.batch = None + return + # if record_stream() doesn't work, another option is to make sure + # device inputs are created on the main stream. + # self.next_input_gpu = torch.empty_like(self.next_input, + # device='cuda') + # self.next_target_gpu = torch.empty_like(self.next_target, + # device='cuda') + # Need to make sure the memory allocated for next_* is not still in use + # by the main stream at the time we start copying to next_*: + # self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.batch = move_to_cuda(self.batch) + # more code for the alternative if record_stream() doesn't work: + # copy_ will record the use of the pinned source tensor in this + # side stream. + # self.next_input_gpu.copy_(self.next_input, non_blocking=True) + # self.next_target_gpu.copy_(self.next_target, non_blocking=True) + # self.next_input = self.next_input_gpu + # self.next_target = self.next_target_gpu + + def next(self, it): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + if batch is not None: + record_cuda_stream(batch) + self.preload(it) + return batch + + def __getattr__(self, name): + method = self.loader.__getattribute__(name) + return method + + +def record_cuda_stream(batch): + if isinstance(batch, torch.Tensor): + batch.record_stream(torch.cuda.current_stream()) + elif isinstance(batch, list) or isinstance(batch, tuple): + for t in batch: + record_cuda_stream(t) + elif isinstance(batch, dict): + for t in batch.values(): + record_cuda_stream(t) + else: + pass + + +class IterLoader: + """ + A wrapper to convert DataLoader as an infinite iterator. + + Modified from: + https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py + """ + + def __init__(self, dataloader: DataLoader, use_distributed: bool = False): + self._dataloader = dataloader + self.iter_loader = iter(self._dataloader) + self._use_distributed = use_distributed + self._epoch = 0 + + @property + def epoch(self) -> int: + return self._epoch + + def __next__(self): + try: + data = next(self.iter_loader) + except StopIteration: + self._epoch += 1 + if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed: + self._dataloader.sampler.set_epoch(self._epoch) + time.sleep(2) # Prevent possible deadlock during epoch transition + self.iter_loader = iter(self._dataloader) + data = next(self.iter_loader) + + return data + + def __iter__(self): + return self + + def __len__(self): + return len(self._dataloader) diff --git a/lavis/datasets/datasets/dialogue_datasets.py b/lavis/datasets/datasets/dialogue_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..7596da65f42812d185d91c8c7bcf7776e8362444 --- /dev/null +++ b/lavis/datasets/datasets/dialogue_datasets.py @@ -0,0 +1,141 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from collections import OrderedDict + +from PIL import Image + +from lavis.datasets.datasets.base_dataset import BaseDataset + +import json +import copy + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "dialogue": ann["dialogue"], + "image": sample["image"], + } + ) + + +class DialogueDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + + self.vis_root = vis_root + + self.annotation = [] + for ann_path in ann_paths: + dialogs = json.load(open(ann_path, "r"))["dialogs"] + for dialog in dialogs: + all_turns = dialog["dialog"] + dialogue_context = [] + for turn in all_turns: + dialog_instance = copy.deepcopy(dialog) + question = turn["question"] + answer = turn["answer"] + + dialog_instance["dialog"] = copy.deepcopy(dialogue_context) + dialog_instance["question"] = question + dialog_instance["answer"] = answer + self.annotation.append(dialog_instance) + dialogue_context.append(turn) + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = self.text_processor(ann["caption"]) + + return { + "image": image, + "text_input": caption, + "image_id": self.img_ids[ann["image_id"]], + } + + +class DialogueEvalDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + + self.vis_root = vis_root + + self.annotation = [] + for ann_path in ann_paths: + dialogs = json.load(open(ann_path, "r"))["dialogs"] + for dialog in dialogs: + all_turns = dialog["dialog"] + dialogue_context = all_turns[:-1] + last_turn = all_turns[-1] + + question = last_turn["question"] + answer = last_turn["answer"] + + dialog["dialog"] = dialogue_context + dialog["question"] = question + dialog["answer"] = answer + + self.annotation.append(dialog) + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + return { + "image": image, + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } diff --git a/lavis/datasets/datasets/gqa_datasets.py b/lavis/datasets/datasets/gqa_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..073c57040d7852bffc273ce6177c246a4fce1ab8 --- /dev/null +++ b/lavis/datasets/datasets/gqa_datasets.py @@ -0,0 +1,101 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import json + +from PIL import Image + +from lavis.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset + +from collections import OrderedDict + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "answers": "; ".join(ann["answer"]), + "image": sample["image"], + } + ) + + +class GQADataset(VQADataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + answers = [ann["answer"]] + weights = [1] + + return { + "image": image, + "text_input": question, + "answers": answers, + "weights": weights, + } + + +class GQAEvalDataset(VQAEvalDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. gqa/images/) + ann_root (string): directory to store the annotation file + """ + + self.vis_root = vis_root + + self.annotation = json.load(open(ann_paths[0])) + + ## TODO: support inference method == 'ranking' + answer_list_path = ann_paths[1] if len(ann_paths) > 1 else '' + if os.path.exists(answer_list_path): + self.answer_list = json.load(open(answer_list_path)) + else: + self.answer_list = None + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self._add_instance_ids() + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + if "answer" in ann: + # answer is a string + answer = ann["answer"] + else: + answer = None + + return { + "image": image, + "text_input": question, + "answer": answer, + "question_id": ann["question_id"], + "instance_id": ann["instance_id"], + } diff --git a/lavis/datasets/datasets/image_text_pair_datasets.py b/lavis/datasets/datasets/image_text_pair_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6e039eeff97a85975bb68ff5bb55fb29d67e15 --- /dev/null +++ b/lavis/datasets/datasets/image_text_pair_datasets.py @@ -0,0 +1,47 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from collections import OrderedDict + +from lavis.datasets.datasets.base_dataset import BaseDataset +from PIL import Image + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": os.path.basename(ann["image"]), + "caption": ann["caption"], + "image": sample["image"], + } + ) + + +class ImageTextPairDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = self.text_processor(ann["caption"]) + + return {"image": image, "text_input": caption} diff --git a/lavis/datasets/datasets/imagefolder_dataset.py b/lavis/datasets/datasets/imagefolder_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8057b7e946d5df5c837499a4a92d46e8c56cf03e --- /dev/null +++ b/lavis/datasets/datasets/imagefolder_dataset.py @@ -0,0 +1,59 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from collections import OrderedDict + +from lavis.datasets.datasets.base_dataset import BaseDataset +from PIL import Image +from torchvision import datasets + + +class ImageFolderDataset(BaseDataset): + def __init__(self, vis_processor, vis_root, classnames=[], **kwargs): + super().__init__(vis_processor=vis_processor, vis_root=vis_root) + + self.inner_dataset = datasets.ImageFolder(vis_root) + + self.annotation = [ + {"image": elem[0], "label": elem[1], "image_id": elem[0]} + for elem in self.inner_dataset.imgs + ] + + self.classnames = classnames + + self._add_instance_ids() + + def __len__(self): + return len(self.inner_dataset) + + def __getitem__(self, index): + ann = self.annotation[index] + + img_fn = ann["image"] + image_path = os.path.join(self.vis_root, img_fn) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + return { + "image": image, + "label": ann["label"], + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } + + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "label": self.classnames[ann["label"]], + "image": sample["image"], + } + ) diff --git a/lavis/datasets/datasets/laion_dataset.py b/lavis/datasets/datasets/laion_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9f283d8d27c11d663d8032b5b2b3012fdd0ec2 --- /dev/null +++ b/lavis/datasets/datasets/laion_dataset.py @@ -0,0 +1,62 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import webdataset as wds +from lavis.datasets.datasets.base_dataset import BaseDataset + + +class LaionDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, location): + super().__init__(vis_processor=vis_processor, text_processor=text_processor) + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode("pilrgb", handler=wds.warn_and_continue), + wds.to_tuple("jpg", "json", handler=wds.warn_and_continue), + wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue), + wds.map(self.to_dict, handler=wds.warn_and_continue), + ) + + def to_dict(self, sample): + return { + "image": sample[0], + "text_input": self.text_processor(sample[1]["caption"]), + } + + +if __name__ == "__main__": + from torchvision import transforms + + def to_image_text_pair(sample): + return sample[0], sample[1]["caption"] + + normalize = transforms.Normalize( + (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) + ) + + transform_train = transforms.Compose( + [ + transforms.RandomResizedCrop(256, scale=(0.2, 1.0)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ) + + dataset = LaionDataset( + vis_processor=transform_train, + text_processor=lambda x: x, + location="/export/laion/laion2B-multi/part-00000/{00000..01743}.tar", + ) + + import torch + + loader = torch.utils.data.DataLoader(dataset.inner_dataset, batch_size=2) + + print(next(iter(loader))["text_input"]) diff --git a/lavis/datasets/datasets/multimodal_classification_datasets.py b/lavis/datasets/datasets/multimodal_classification_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..c1b4fe02ed39bcec396e160bda6fe43246cb4d03 --- /dev/null +++ b/lavis/datasets/datasets/multimodal_classification_datasets.py @@ -0,0 +1,20 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from abc import abstractmethod +from lavis.datasets.datasets.base_dataset import BaseDataset + + +class MultimodalClassificationDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.class_labels = None + + @abstractmethod + def _build_class_labels(self): + pass diff --git a/lavis/datasets/datasets/nlvr_datasets.py b/lavis/datasets/datasets/nlvr_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc818c6ac7592686ce104bea345bfe95d727aa0 --- /dev/null +++ b/lavis/datasets/datasets/nlvr_datasets.py @@ -0,0 +1,94 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import random + +from collections import OrderedDict + +from lavis.datasets.datasets.multimodal_classification_datasets import ( + MultimodalClassificationDataset, +) +from PIL import Image + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file_L": ann["images"][0], + "file_R": ann["images"][1], + "sentence": ann["sentence"], + "label": ann["label"], + "image": [sample["image0"], sample["image1"]], + } + ) + + +class NLVRDataset(MultimodalClassificationDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.class_labels = self._build_class_labels() + + def _build_class_labels(self): + return {"False": 0, "True": 1} + + @staticmethod + def _flip(samples): + sentence = samples["text_input"] + image0, image1 = samples["image0"], samples["image1"] + + if "left" not in sentence and "right" not in sentence: + if random.random() < 0.5: + image0, image1 = image1, image0 + else: + if random.random() < 0.5: + sentence = sentence.replace("left", "[TEMP_TOKEN]") + sentence = sentence.replace("right", "left") + sentence = sentence.replace("[TEMP_TOKEN]", "right") + + image0, image1 = image1, image0 + + samples["text_input"] = sentence + samples["image0"] = image0 + samples["image1"] = image1 + + return samples + + def __getitem__(self, index): + ann = self.annotation[index] + + image0_path = os.path.join(self.vis_root, ann["images"][0]) + image0 = Image.open(image0_path).convert("RGB") + image0 = self.vis_processor(image0) + + image1_path = os.path.join(self.vis_root, ann["images"][1]) + image1 = Image.open(image1_path).convert("RGB") + image1 = self.vis_processor(image1) + + sentence = self.text_processor(ann["sentence"]) + label = self.class_labels[ann["label"]] + + return self._flip( + { + "image0": image0, + "image1": image1, + "text_input": sentence, + "label": label, + # "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } + ) + + +class NLVREvalDataset(NLVRDataset): + @staticmethod + def _flip(samples): + return samples diff --git a/lavis/datasets/datasets/retrieval_datasets.py b/lavis/datasets/datasets/retrieval_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..9cee7a4f800c67524fffbd3ce1e4fc068fba67e1 --- /dev/null +++ b/lavis/datasets/datasets/retrieval_datasets.py @@ -0,0 +1,162 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from collections import OrderedDict + +from lavis.datasets.datasets.base_dataset import BaseDataset +from PIL import Image + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + visual_key = "image" if "image" in ann else "video" + + return OrderedDict( + { + "file": ann[visual_key], + "caption": ann["caption"], + visual_key: sample[visual_key], + } + ) + + +class RetrievalDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = self.text_processor(ann["caption"]) + + return { + "image": image, + "text_input": caption, + "image_id": self.img_ids[ann["image_id"]], + "instance_id": ann["instance_id"], + } + + +class RetrievalEvalDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.text = [] + self.image = [] + self.txt2img = {} + self.img2txt = {} + + txt_id = 0 + for img_id, ann in enumerate(self.annotation): + self.image.append(ann["image"]) + self.img2txt[img_id] = [] + for i, caption in enumerate(ann["caption"]): + self.text.append(self.text_processor(caption)) + self.img2txt[img_id].append(txt_id) + self.txt2img[txt_id] = img_id + txt_id += 1 + + def __getitem__(self, index): + + image_path = os.path.join(self.vis_root, self.annotation[index]["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + + return {"image": image, "index": index} + + +class VideoRetrievalDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of videos. + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann["video"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __getitem__(self, index): + + ann = self.annotation[index] + + vpath = os.path.join(self.vis_root, ann["video"]) + + video = self.vis_processor(vpath) + caption = self.text_processor(ann["caption"]) + + # return image, caption, self.img_ids[ann['image_id']] + return { + "video": video, + "text_input": caption, + "image_id": self.img_ids[ann["video"]], + } + + +class VideoRetrievalEvalDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of videos. + ann_root (string): directory to store the annotation file + split (string): val or test + """ + + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.text = [] + self.image = [] + self.txt2img = {} + self.img2txt = {} + + txt_id = 0 + for img_id, ann in enumerate(self.annotation): + self.image.append(ann["video"]) + self.img2txt[img_id] = [] + for i, caption in enumerate(ann["caption"]): + self.text.append(self.text_processor(caption)) + self.img2txt[img_id].append(txt_id) + self.txt2img[txt_id] = img_id + txt_id += 1 + + def __getitem__(self, index): + ann = self.annotation[index] + + vpath = os.path.join(self.vis_root, ann["video"]) + video = self.vis_processor(vpath) + + return {"video": video, "index": index} diff --git a/lavis/datasets/datasets/snli_ve_datasets.py b/lavis/datasets/datasets/snli_ve_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..446f3caea6221087dd60d8fb96fecfc4d2e1a2b8 --- /dev/null +++ b/lavis/datasets/datasets/snli_ve_datasets.py @@ -0,0 +1,56 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from collections import OrderedDict + +from lavis.datasets.datasets.multimodal_classification_datasets import ( + MultimodalClassificationDataset, +) +from PIL import Image + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": os.path.basename(ann["image"]), + "sentence": ann["sentence"], + "label": ann["label"], + "image": sample["image"], + } + ) + + +class SNLIVisualEntialmentDataset(MultimodalClassificationDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.class_labels = self._build_class_labels() + + def _build_class_labels(self): + return {"contradiction": 0, "neutral": 1, "entailment": 2} + + def __getitem__(self, index): + ann = self.annotation[index] + + image_id = ann["image"] + image_path = os.path.join(self.vis_root, "%s.jpg" % image_id) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + sentence = self.text_processor(ann["sentence"]) + + return { + "image": image, + "text_input": sentence, + "label": self.class_labels[ann["label"]], + "image_id": image_id, + "instance_id": ann["instance_id"], + } diff --git a/lavis/datasets/datasets/vg_vqa_datasets.py b/lavis/datasets/datasets/vg_vqa_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..08bd909db553c49495d46ea60e8327d801a52bf5 --- /dev/null +++ b/lavis/datasets/datasets/vg_vqa_datasets.py @@ -0,0 +1,37 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os + +from PIL import Image + +from lavis.datasets.datasets.vqa_datasets import VQADataset + + +class VGVQADataset(VQADataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + answers = [ann["answer"]] + # TODO this should be configured better + weights = [0.2] + + return { + "image": image, + "text_input": question, + "answers": answers, + "weights": weights, + } diff --git a/lavis/datasets/datasets/video_caption_datasets.py b/lavis/datasets/datasets/video_caption_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..584b0bcf845bba469457113a9d343a5dd7dc3d64 --- /dev/null +++ b/lavis/datasets/datasets/video_caption_datasets.py @@ -0,0 +1,63 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from lavis.datasets.datasets.base_dataset import BaseDataset + +from lavis.datasets.datasets.caption_datasets import CaptionDataset + + +class VideoCaptionDataset(CaptionDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + vname = ann["video"] + video_path = os.path.join(self.vis_root, vname) + + video = self.vis_processor(video_path) + caption = self.text_processor(ann["caption"]) + + # "image_id" is kept to stay compatible with the COCO evaluation format + return { + "video": video, + "text_input": caption, + "image_id": self.img_ids[ann["image_id"]], + } + + +class VideoCaptionEvalDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def __getitem__(self, index): + + ann = self.annotation[index] + + vname = ann["video"] + video_path = os.path.join(self.vis_root, vname) + + video = self.vis_processor(video_path) + + return { + "video": video, + "image_id": ann["image_id"], + "instance_id": ann["instance_id"], + } diff --git a/lavis/datasets/datasets/video_vqa_datasets.py b/lavis/datasets/datasets/video_vqa_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..7d50ee736b8104c21a3e43bacc0897f45c46a9ae --- /dev/null +++ b/lavis/datasets/datasets/video_vqa_datasets.py @@ -0,0 +1,63 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import json +import os +from collections import OrderedDict + +from lavis.datasets.datasets.multimodal_classification_datasets import ( + MultimodalClassificationDataset, +) + + +class __DisplMixin: + def displ_item(self, index): + ann = self.annotation[index] + + vname = ann["video"] + vpath = os.path.join(self.vis_root, vname) + + return OrderedDict( + {"file": vpath, "question": ann["question"], "answer": ann["answer"]} + ) + + +class VideoQADataset(MultimodalClassificationDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def _build_class_labels(self, ans_path): + ans2label = json.load(open(ans_path)) + + self.class_labels = ans2label + + def _get_answer_label(self, answer): + if answer in self.class_labels: + return self.class_labels[answer] + else: + return len(self.class_labels) + + def __getitem__(self, index): + assert ( + self.class_labels + ), f"class_labels of {__class__.__name__} is not built yet." + + ann = self.annotation[index] + + vname = ann["video"] + vpath = os.path.join(self.vis_root, vname) + + frms = self.vis_processor(vpath) + question = self.text_processor(ann["question"]) + + return { + "video": frms, + "text_input": question, + "answers": self._get_answer_label(ann["answer"]), + "question_id": ann["question_id"], + "instance_id": ann["instance_id"], + } diff --git a/lavis/datasets/datasets/vqa_datasets.py b/lavis/datasets/datasets/vqa_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..8803b25d7fdff8f80764d95db1f8bb0bd26e93ae --- /dev/null +++ b/lavis/datasets/datasets/vqa_datasets.py @@ -0,0 +1,44 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch + +from lavis.datasets.datasets.base_dataset import BaseDataset + + +class VQADataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + def collater(self, samples): + image_list, question_list, answer_list, weight_list = [], [], [], [] + + num_answers = [] + + for sample in samples: + image_list.append(sample["image"]) + question_list.append(sample["text_input"]) + + weight_list.extend(sample["weights"]) + + answers = sample["answers"] + + answer_list.extend(answers) + num_answers.append(len(answers)) + + return { + "image": torch.stack(image_list, dim=0), + "text_input": question_list, + "answer": answer_list, + "weight": torch.Tensor(weight_list), + "n_answers": torch.LongTensor(num_answers), + } + + +class VQAEvalDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) diff --git a/lavis/datasets/download_scripts/DownloadConceptualCaptions/LICENSE b/lavis/datasets/download_scripts/DownloadConceptualCaptions/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4062f42c8e3870f993fee16032dde12965e391a3 --- /dev/null +++ b/lavis/datasets/download_scripts/DownloadConceptualCaptions/LICENSE @@ -0,0 +1,25 @@ +// Copyright 2022 Dongxu Li, Junnan Li, Hung Le, Guangsen Wang, Silvio Savarese, Steven Hoi. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +MIT License + +Copyright (c) 2019 Igor Brigadir + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/lavis/datasets/download_scripts/DownloadConceptualCaptions/README.md b/lavis/datasets/download_scripts/DownloadConceptualCaptions/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0dd0b9d5bfe304770d06b2adc363f33a6c390ced --- /dev/null +++ b/lavis/datasets/download_scripts/DownloadConceptualCaptions/README.md @@ -0,0 +1,22 @@ + + +# Download Conceptual Captions Data + +Place data from: https://ai.google.com/research/ConceptualCaptions/download in this folder + +`Train_GCC-training.tsv / cc3m.tsv` Training Split (3,318,333) + +run `download_data_cc3m.py` or `download_data_cc12m.py`. + +Images will be in default LAVIS cache folders. You can stop and resume, the settings for splitting downloads into chunks / threads are not optimal, but it maxed out my connection so i kept them as is. + +Note: A previous version of this script used a different file naming scheme, this changed and if you are resuming a previously started download, you will get duplicates. + +A bunch of them will fail to download, and return web pages instead. These will need to be cleaned up later. See `downloaded_validation_report.tsv` after it downloads for HTTP errors. Around 8% of images are gone, based on validation set results. Setting the user agent could fix some errors too maybe - not sure if any requests are rejected by sites based on this. + +It should take about a day or two to download the training data, keep an eye on disk space. diff --git a/lavis/datasets/download_scripts/DownloadConceptualCaptions/create_annotation_12m.ipynb b/lavis/datasets/download_scripts/DownloadConceptualCaptions/create_annotation_12m.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..7ca886ffb664a0c02b45b5b0fabc5159284ef88e --- /dev/null +++ b/lavis/datasets/download_scripts/DownloadConceptualCaptions/create_annotation_12m.ipynb @@ -0,0 +1,227 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from lavis.common.utils import get_abs_path, get_cache_path" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "cc12m = pd.read_csv(\"downloaded_cc12m_report.tsv.gz\", compression=\"gzip\", sep=\"\\t\", names=[\"caption\", \"path\", \"dataset\", \"mimetype\", \"size\", \"status\", \"url\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "caption a very typical bus station\n", + "path /export/home/.cache/lavis/conceptual_caption/i...\n", + "dataset cc3m\n", + "mimetype image/jpeg\n", + "size 36078\n", + "status 200\n", + "url http://lh6.ggpht.com/-IvRtNLNcG8o/TpFyrudaT6I/...\n", + "Name: 0, dtype: object" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cc12m.iloc[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3318333" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(cc12m)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 3130587/3130587 [17:28<00:00, 2986.08it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2759017 valid records\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "cnt = 0\n", + "\n", + "valid_records = []\n", + "\n", + "for i, path in tqdm(enumerate(cc12m.path.unique()), total=len(cc12m.path.unique())):\n", + " path = str(path)\n", + " if os.path.exists(path):\n", + " record = cc12m.iloc[i]\n", + " valid_records.append({\"image\": record[\"path\"], \"caption\": record[\"caption\"]})\n", + "\n", + " cnt += 1\n", + "\n", + "print(\"Found {} valid records\".format(cnt))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2759017" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(valid_records)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'image': '/export/home/.cache/lavis/conceptual_caption/images/1_3239086386.jpg',\n", + " 'caption': 'sierra looked stunning in this top and this skirt while performing with person at their former university'}" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "valid_records[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/export/home/.cache/lavis/conceptual_caption/annotations/cc3m.json already exists\n" + ] + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details." + ] + } + ], + "source": [ + "from omegaconf import OmegaConf\n", + "\n", + "\n", + "config_path = get_abs_path(\"configs/datasets/conceptual_caption/defaults_12m.yaml\")\n", + "\n", + "ann_path = OmegaConf.load(\n", + " config_path\n", + ").datasets.conceptual_caption_12m.build_info.annotations.train.storage[0]\n", + "\n", + "ann_path = get_cache_path(ann_path)\n", + "\n", + "if os.path.exists(ann_path):\n", + " # abort\n", + " print(\"{} already exists\".format(ann_path))\n", + "else:\n", + " # Save the valid records to a json file\n", + " with open(ann_path, \"w\") as f:\n", + " f.write(json.dumps(valid_records))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.10 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/lavis/datasets/download_scripts/DownloadConceptualCaptions/create_annotation_3m.ipynb b/lavis/datasets/download_scripts/DownloadConceptualCaptions/create_annotation_3m.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ce08209d0f16120d2b1c11be095de6482d9fb71f --- /dev/null +++ b/lavis/datasets/download_scripts/DownloadConceptualCaptions/create_annotation_3m.ipynb @@ -0,0 +1,227 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from lavis.common.utils import get_abs_path, get_cache_path" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "cc3m = pd.read_csv(\"downloaded_cc3m_report.tsv.gz\", compression=\"gzip\", sep=\"\\t\", names=[\"caption\", \"path\", \"dataset\", \"mimetype\", \"size\", \"status\", \"url\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "caption a very typical bus station\n", + "path /export/home/.cache/lavis/conceptual_caption/i...\n", + "dataset cc3m\n", + "mimetype image/jpeg\n", + "size 36078\n", + "status 200\n", + "url http://lh6.ggpht.com/-IvRtNLNcG8o/TpFyrudaT6I/...\n", + "Name: 0, dtype: object" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cc3m.iloc[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3318333" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(cc3m)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 3130587/3130587 [17:28<00:00, 2986.08it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found 2759017 valid records\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "cnt = 0\n", + "\n", + "valid_records = []\n", + "\n", + "for i, path in tqdm(enumerate(cc3m.path.unique()), total=len(cc3m.path.unique())):\n", + " path = str(path)\n", + " if os.path.exists(path):\n", + " record = cc3m.iloc[i]\n", + " valid_records.append({\"image\": record[\"path\"], \"caption\": record[\"caption\"]})\n", + "\n", + " cnt += 1\n", + "\n", + "print(\"Found {} valid records\".format(cnt))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2759017" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(valid_records)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'image': '/export/home/.cache/lavis/conceptual_caption/images/1_3239086386.jpg',\n", + " 'caption': 'sierra looked stunning in this top and this skirt while performing with person at their former university'}" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "valid_records[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/export/home/.cache/lavis/conceptual_caption/annotations/cc3m.json already exists\n" + ] + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details." + ] + } + ], + "source": [ + "from omegaconf import OmegaConf\n", + "\n", + "\n", + "config_path = get_abs_path(\"configs/datasets/conceptual_caption/defaults_3m.yaml\")\n", + "\n", + "ann_path = OmegaConf.load(\n", + " config_path\n", + ").datasets.conceptual_caption_3m.build_info.annotations.train.storage[0]\n", + "\n", + "ann_path = get_cache_path(ann_path)\n", + "\n", + "if os.path.exists(ann_path):\n", + " # abort\n", + " print(\"{} already exists\".format(ann_path))\n", + "else:\n", + " # Save the valid records to a json file\n", + " with open(ann_path, \"w\") as f:\n", + " f.write(json.dumps(valid_records))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.10 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py b/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py new file mode 100644 index 0000000000000000000000000000000000000000..c60b6fb8e5ae81783f9fafa71648c147871798ec --- /dev/null +++ b/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc12m.py @@ -0,0 +1,232 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import time +from PIL import Image +from lavis.common.utils import get_abs_path, get_cache_path +from multiprocessing import Pool +from omegaconf import OmegaConf +from pathlib import Path +from torchvision.transforms import functional as TF +from tqdm import tqdm +import glob +import io +import json +import magic # pip install python-magic +import numpy as np +import os +import pandas as pd +import requests +import shelve +import zlib + +headers = { + #'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36', + "User-Agent": "Googlebot-Image/1.0", # Pretend to be googlebot + "X-Forwarded-For": "64.18.15.200", +} + + +def _df_split_apply(tup_arg): + split_ind, subset, func = tup_arg + r = subset.apply(func, axis=1) + return (split_ind, r) + + +def df_multiprocess(df, processes, chunk_size, func, dataset_name): + print("Generating parts...") + with shelve.open( + "%s_%s_%s_results.tmp" % (dataset_name, func.__name__, chunk_size) + ) as results: + + pbar = tqdm(total=len(df), position=0) + # Resume: + finished_chunks = set([int(k) for k in results.keys()]) + pbar.desc = "Resuming" + for k in results.keys(): + pbar.update(len(results[str(k)][1])) + + pool_data = ( + (index, df[i : i + chunk_size], func) + for index, i in enumerate(range(0, len(df), chunk_size)) + if index not in finished_chunks + ) + print( + int(len(df) / chunk_size), + "parts.", + chunk_size, + "per part.", + "Using", + processes, + "processes", + ) + + pbar.desc = "Downloading" + with Pool(processes) as pool: + for i, result in enumerate( + pool.imap_unordered(_df_split_apply, pool_data, 2) + ): + results[str(result[0])] = result + pbar.update(len(result[1])) + pbar.close() + + print("Finished Downloading.") + return + + +# Unique name based on url +def _file_name(row): + name = ( + "%s/%s_%s" + % ( + # row["folder"], + storage_dir, + row.name, + (zlib.crc32(row["url"].encode("utf-8")) & 0xFFFFFFFF), + ) + + ".jpg" + ) + return name + + +# For checking mimetypes separately without download +def check_mimetype(row): + if os.path.isfile(str(row["file"])): + row["mimetype"] = magic.from_file(row["file"], mime=True) + row["size"] = os.stat(row["file"]).st_size + return row + + +# Don't download image, just check with a HEAD request, can't resume. +# Can use this instead of download_image to get HTTP status codes. +def check_download(row): + fname = _file_name(row) + try: + # not all sites will support HEAD + response = requests.head( + row["url"], stream=False, timeout=5, allow_redirects=True, headers=headers + ) + row["status"] = response.status_code + row["headers"] = dict(response.headers) + except: + # log errors later, set error as 408 timeout + row["status"] = 408 + return row + if response.ok: + row["file"] = fname + return row + + +def resize_img(req): + image = Image.open(req).convert("RGB") + image = TF.resize( + # image, size=(resize_size, resize_size) + image, + size=resize_size, + ) # , interpolation=Image.LANCZOS) + return image + + +def download_image(row): + fname = _file_name(row) + # Skip Already downloaded, retry others later + if os.path.isfile(fname): + row["status"] = 200 + row["file"] = fname + row["mimetype"] = magic.from_file(row["file"], mime=True) + row["size"] = os.stat(row["file"]).st_size + return row + + try: + # use smaller timeout to skip errors, but can result in failed downloads + response = requests.get( + row["url"], stream=False, timeout=5, allow_redirects=True, headers=headers + ) + row["status"] = response.status_code + # row['headers'] = dict(response.headers) + except Exception as e: + # log errors later, set error as 408 timeout + row["status"] = 408 + return row + + if response.ok: + try: + # some sites respond with gzip transport encoding + response.raw.decode_content = True + img = resize_img(io.BytesIO(response.content)) + img.save(fname) + + row["mimetype"] = magic.from_file(fname, mime=True) + row["size"] = os.stat(fname).st_size + + except Exception as e: + # # This is if it times out during a download or decode + row["status"] = 408 + + row["file"] = fname + return row + + +def open_tsv(fname, folder): + print("Opening %s Data File..." % fname) + df = pd.read_csv( + fname, sep="\t", names=["url", "caption"] + ) # , usecols=range(1, 2)) + df["folder"] = folder + print("Processing", len(df), " Images:") + return df + + +def df_from_shelve(chunk_size, func, dataset_name): + print("Generating Dataframe from results...") + with shelve.open( + "%s_%s_%s_results.tmp" % (dataset_name, func.__name__, chunk_size) + ) as results: + keylist = sorted([int(k) for k in results.keys()]) + df = pd.concat([results[str(k)][1] for k in keylist], sort=True) + return df + + +resize_size = 384 + +config_path = get_abs_path("configs/datasets/conceptual_caption/defaults_12m.yaml") + +storage_dir = OmegaConf.load( + config_path +).datasets.conceptual_caption_12m.build_info.images.storage +storage_dir = Path(get_cache_path(storage_dir)) + +os.makedirs(storage_dir, exist_ok=True) + +# number of processes in the pool can be larger than cores +num_processes = 96 +# num_processes = 1 +# chunk_size is how many images per chunk per process - changing this resets progress when restarting. +images_per_part = 100 + +data_name = "cc12m" +# os.makedirs(data_name, exist_ok=True) + +df = open_tsv("cc12m.tsv", data_name) +df_multiprocess( + df=df, + processes=num_processes, + chunk_size=images_per_part, + func=download_image, + dataset_name=data_name, +) +df = df_from_shelve( + chunk_size=images_per_part, func=download_image, dataset_name=data_name +) +df.to_csv( + "downloaded_%s_report.tsv.gz" % data_name, + compression="gzip", + sep="\t", + header=False, + index=False, +) +print("Saved.") diff --git a/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py b/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py new file mode 100644 index 0000000000000000000000000000000000000000..2edd7a224436f7fa2d923501caadd40db040f8a1 --- /dev/null +++ b/lavis/datasets/download_scripts/DownloadConceptualCaptions/download_data_cc3m.py @@ -0,0 +1,229 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import glob +from pathlib import Path +import time +from omegaconf import OmegaConf +import pandas as pd +import numpy as np +import requests +import zlib +import os +import io +import shelve +from lavis.common.utils import get_abs_path, get_cache_path +import magic # pip install python-magic +import json +from multiprocessing import Pool +from tqdm import tqdm +from PIL import Image +from torchvision.transforms import functional as TF + +headers = { + #'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36', + "User-Agent": "Googlebot-Image/1.0", # Pretend to be googlebot + "X-Forwarded-For": "64.18.15.200", +} + + +def _df_split_apply(tup_arg): + split_ind, subset, func = tup_arg + r = subset.apply(func, axis=1) + return (split_ind, r) + + +def df_multiprocess(df, processes, chunk_size, func, dataset_name): + print("Generating parts...") + with shelve.open( + "%s_%s_%s_results.tmp" % (dataset_name, func.__name__, chunk_size) + ) as results: + + pbar = tqdm(total=len(df), position=0) + # Resume: + finished_chunks = set([int(k) for k in results.keys()]) + pbar.desc = "Resuming" + for k in results.keys(): + pbar.update(len(results[str(k)][1])) + + pool_data = ( + (index, df[i : i + chunk_size], func) + for index, i in enumerate(range(0, len(df), chunk_size)) + if index not in finished_chunks + ) + print( + int(len(df) / chunk_size), + "parts.", + chunk_size, + "per part.", + "Using", + processes, + "processes", + ) + + pbar.desc = "Downloading" + with Pool(processes) as pool: + for i, result in enumerate( + pool.imap_unordered(_df_split_apply, pool_data, 2) + ): + results[str(result[0])] = result + pbar.update(len(result[1])) + pbar.close() + + print("Finished Downloading.") + return + + +# Unique name based on url +def _file_name(row): + name = ( + "%s/%s_%s" + % ( + # row["folder"], + storage_dir, + row.name, + (zlib.crc32(row["url"].encode("utf-8")) & 0xFFFFFFFF), + ) + + ".jpg" + ) + return name + + +# For checking mimetypes separately without download +def check_mimetype(row): + if os.path.isfile(str(row["file"])): + row["mimetype"] = magic.from_file(row["file"], mime=True) + row["size"] = os.stat(row["file"]).st_size + return row + + +# Don't download image, just check with a HEAD request, can't resume. +# Can use this instead of download_image to get HTTP status codes. +def check_download(row): + fname = _file_name(row) + try: + # not all sites will support HEAD + response = requests.head( + row["url"], stream=False, timeout=5, allow_redirects=True, headers=headers + ) + row["status"] = response.status_code + row["headers"] = dict(response.headers) + except: + # log errors later, set error as 408 timeout + row["status"] = 408 + return row + if response.ok: + row["file"] = fname + return row + + +def resize_img(req): + image = Image.open(req).convert("RGB") + image = TF.resize( + # image, size=(resize_size, resize_size) + image, + size=resize_size, + ) # , interpolation=Image.LANCZOS) + return image + + +def download_image(row): + fname = _file_name(row) + # Skip Already downloaded, retry others later + if os.path.isfile(fname): + row["status"] = 200 + row["file"] = fname + row["mimetype"] = magic.from_file(row["file"], mime=True) + row["size"] = os.stat(row["file"]).st_size + return row + + try: + # use smaller timeout to skip errors, but can result in failed downloads + response = requests.get( + row["url"], stream=False, timeout=5, allow_redirects=True, headers=headers + ) + row["status"] = response.status_code + # row['headers'] = dict(response.headers) + except Exception as e: + # log errors later, set error as 408 timeout + row["status"] = 408 + return row + + if response.ok: + try: + # some sites respond with gzip transport encoding + response.raw.decode_content = True + img = resize_img(io.BytesIO(response.content)) + img.save(fname) + + row["mimetype"] = magic.from_file(fname, mime=True) + row["size"] = os.stat(fname).st_size + + except Exception as e: + # # This is if it times out during a download or decode + row["status"] = 408 + + row["file"] = fname + return row + + +def open_tsv(fname, folder): + print("Opening %s Data File..." % fname) + df = pd.read_csv( + fname, sep="\t", names=["caption", "url"] + ) # , usecols=range(1, 2)) + df["folder"] = folder + print("Processing", len(df), " Images:") + return df + + +def df_from_shelve(chunk_size, func, dataset_name): + print("Generating Dataframe from results...") + with shelve.open( + "%s_%s_%s_results.tmp" % (dataset_name, func.__name__, chunk_size) + ) as results: + keylist = sorted([int(k) for k in results.keys()]) + df = pd.concat([results[str(k)][1] for k in keylist], sort=True) + return df + + +resize_size = 384 + +config_path = get_abs_path("configs/datasets/conceptual_caption/defaults_3m.yaml") + +storage_dir = OmegaConf.load( + config_path +).datasets.conceptual_caption_3m.build_info.images.storage +storage_dir = Path(get_cache_path(storage_dir)) + +os.makedirs(storage_dir, exist_ok=True) + +# number of processes in the pool can be larger than cores +num_processes = 32 +# chunk_size is how many images per chunk per process - changing this resets progress when restarting. +images_per_part = 100 + +data_name = "cc3m" +df = open_tsv("Train_GCC-training.tsv", data_name) +df_multiprocess( + df=df, + processes=num_processes, + chunk_size=images_per_part, + func=download_image, + dataset_name=data_name, +) +df = df_from_shelve( + chunk_size=images_per_part, func=download_image, dataset_name=data_name +) +df.to_csv( + "downloaded_%s_report.tsv.gz" % data_name, + compression="gzip", + sep="\t", + header=False, + index=False, +) +print("Saved.") diff --git a/lavis/datasets/download_scripts/download_coco.py b/lavis/datasets/download_scripts/download_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..283448aed1b745a975bc89b5c531a853efdd31f4 --- /dev/null +++ b/lavis/datasets/download_scripts/download_coco.py @@ -0,0 +1,57 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from pathlib import Path + +from omegaconf import OmegaConf + +from lavis.common.utils import ( + cleanup_dir, + download_and_extract_archive, + get_abs_path, + get_cache_path, +) + + +DATA_URL = { + "train": "http://images.cocodataset.org/zips/train2014.zip", # md5: 0da8c0bd3d6becc4dcb32757491aca88 + "val": "http://images.cocodataset.org/zips/val2014.zip", # md5: a3d79f5ed8d289b7a7554ce06a5782b3 + "test": "http://images.cocodataset.org/zips/test2014.zip", # md5: 04127eef689ceac55e3a572c2c92f264 + "test2015": "http://images.cocodataset.org/zips/test2015.zip", # md5: 04127eef689ceac55e3a572c2c92f264 +} + + +def download_datasets(root, url): + download_and_extract_archive(url=url, download_root=root, extract_root=storage_dir) + + +if __name__ == "__main__": + + config_path = get_abs_path("configs/datasets/coco/defaults_cap.yaml") + + storage_dir = OmegaConf.load( + config_path + ).datasets.coco_caption.build_info.images.storage + + download_dir = Path(get_cache_path(storage_dir)).parent / "download" + storage_dir = Path(get_cache_path(storage_dir)) + + if storage_dir.exists(): + print(f"Dataset already exists at {storage_dir}. Aborting.") + exit(0) + + try: + for k, v in DATA_URL.items(): + print("Downloading {} to {}".format(v, k)) + download_datasets(download_dir, v) + except Exception as e: + # remove download dir if failed + cleanup_dir(download_dir) + print("Failed to download or extracting datasets. Aborting.") + + cleanup_dir(download_dir) diff --git a/lavis/datasets/download_scripts/download_didemo.py b/lavis/datasets/download_scripts/download_didemo.py new file mode 100644 index 0000000000000000000000000000000000000000..376b71c4de1e83442a0209796c95f55da6b3e71a --- /dev/null +++ b/lavis/datasets/download_scripts/download_didemo.py @@ -0,0 +1,70 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from pathlib import Path + +from omegaconf import OmegaConf + +from lavis.common.utils import ( + cleanup_dir, + download_and_extract_archive, + get_abs_path, + get_cache_path, +) + +DATA_URL = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/didemo/didemo_videos.tar.gz" + + +def download_datasets(root, url): + """ + Download the Imagenet-R dataset archives and expand them + in the folder provided as parameter + """ + download_and_extract_archive(url=url, download_root=root) + + +def move_files(download_path, storage_path): + """ + Move files from download_path to storage_path + """ + print("Moving to {}".format(storage_path)) + + os.makedirs(storage_path, exist_ok=True) + + for file_name in os.listdir(download_path): + os.rename( + os.path.join(download_path, file_name), + os.path.join(storage_path, file_name), + ) + + +if __name__ == "__main__": + + config_path = get_abs_path("configs/datasets/didemo/defaults_ret.yaml") + + storage_dir = OmegaConf.load( + config_path + ).datasets.didemo_retrieval.build_info.videos.storage + + download_dir = Path(get_cache_path(storage_dir)).parent / "download" + storage_dir = Path(get_cache_path(storage_dir)) + + if storage_dir.exists(): + print(f"Dataset already exists at {storage_dir}. Aborting.") + exit(0) + + try: + print("Downloading {} to {}".format(DATA_URL, download_dir)) + download_datasets(download_dir, DATA_URL) + except Exception as e: + # remove download dir if failed + cleanup_dir(download_dir) + print("Failed to download or extracting datasets. Aborting.") + + move_files(download_dir / "videos", storage_dir) + cleanup_dir(download_dir) diff --git a/lavis/datasets/download_scripts/download_flickr.py b/lavis/datasets/download_scripts/download_flickr.py new file mode 100644 index 0000000000000000000000000000000000000000..3075f02299110b729ccb0f4b34f7b9cf23046b6c --- /dev/null +++ b/lavis/datasets/download_scripts/download_flickr.py @@ -0,0 +1,78 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from pathlib import Path + +from omegaconf import OmegaConf + +from lavis.common.utils import ( + cleanup_dir, + get_abs_path, + get_cache_path, +) + +import opendatasets as od + + +DATA_URL = "https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset" + +print( + """ + To download the dataset, you need to have a Kaggle account and the associated key. + See https://www.kaggle.com/docs/api to create account and a new API token. + """ +) + + +def move_directory(src_dir, dst_dir): + """ + Move files from download_path to storage_path + """ + print("Moving to {}".format(dst_dir)) + + os.makedirs(dst_dir, exist_ok=True) + + for file_name in os.listdir(src_dir): + os.rename( + os.path.join(src_dir, file_name), + os.path.join(dst_dir, file_name), + ) + + +if __name__ == "__main__": + + config_path = get_abs_path("configs/datasets/flickr30k/defaults.yaml") + + storage_dir = OmegaConf.load( + config_path + ).datasets.flickr30k.build_info.images.storage + + storage_dir = Path(get_cache_path(storage_dir)) + download_dir = storage_dir.parent / "download" + + if storage_dir.exists(): + print(f"Dataset already exists at {storage_dir}. Aborting.") + exit(0) + + os.makedirs(download_dir) + + try: + print("Downloading {} to {}".format(DATA_URL, download_dir)) + od.download(DATA_URL, download_dir) + except Exception as e: + print(e) + # remove download dir if failed + cleanup_dir(download_dir) + exit(1) + + move_directory( + download_dir / "flickr-image-dataset" / "flickr30k_images" / "flickr30k_images", + storage_dir / "flickr30k-images", + ) + + cleanup_dir(download_dir) diff --git a/lavis/datasets/download_scripts/download_gqa.py b/lavis/datasets/download_scripts/download_gqa.py new file mode 100644 index 0000000000000000000000000000000000000000..0bce71408c9f8d8973ef8f7fa9419d328127978e --- /dev/null +++ b/lavis/datasets/download_scripts/download_gqa.py @@ -0,0 +1,51 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from pathlib import Path + +from omegaconf import OmegaConf + +from lavis.common.utils import ( + cleanup_dir, + download_and_extract_archive, + get_abs_path, + get_cache_path, +) + + +DATA_URL = "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip" + + +def download_datasets(root, url): + download_and_extract_archive(url=url, download_root=root, extract_root=storage_dir.parent) + + +if __name__ == "__main__": + + config_path = get_abs_path("configs/datasets/gqa/defaults.yaml") + + storage_dir = OmegaConf.load( + config_path + ).datasets.gqa.build_info.images.storage + + download_dir = Path(get_cache_path(storage_dir)).parent / "download" + storage_dir = Path(get_cache_path(storage_dir)) + + if storage_dir.exists(): + print(f"Dataset already exists at {storage_dir}. Aborting.") + exit(0) + + try: + print("Downloading {}".format(DATA_URL)) + download_datasets(download_dir, DATA_URL) + except Exception as e: + # remove download dir if failed + cleanup_dir(download_dir) + print("Failed to download or extracting datasets. Aborting.") + + cleanup_dir(download_dir) diff --git a/lavis/datasets/download_scripts/download_msrvtt.py b/lavis/datasets/download_scripts/download_msrvtt.py new file mode 100644 index 0000000000000000000000000000000000000000..3e9dc1cd942ad3a17d0debe0c2b94e6edbc56c61 --- /dev/null +++ b/lavis/datasets/download_scripts/download_msrvtt.py @@ -0,0 +1,105 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from pathlib import Path + +from omegaconf import OmegaConf + +from lavis.common.utils import ( + cleanup_dir, + download_and_extract_archive, + get_abs_path, + get_cache_path, +) + + +# TODO +# 1. Go to https://www.mediafire.com/file/czh8sezbo9s4692/test_videos.zip/file +# and https://www.mediafire.com/file/x3rrbe4hwp04e6w/train_val_videos.zip/file +# 2. Right-click the Download button and copy the link address +# e.g. +# DATA_URL = { +# "train": "https://download1602.mediafire.com/xxxxxxxxxxxx/x3rrbe4hwp04e6w/train_val_videos.zip", +# "test": "https://download2390.mediafire.com/xxxxxxxxxxxx/czh8sezbo9s4692/test_videos.zip", +# } +# 3. Paste the link address to DATA_URL + +DATA_URL = { + "train": "https://download2295.mediafire.com/4bb7p74xrbgg/x3rrbe4hwp04e6w/train_val_videos.zip", + "test": "https://download2390.mediafire.com/79hfq3592lqg/czh8sezbo9s4692/test_videos.zip", +} + + +def download_datasets(root, url): + """ + Download the Imagenet-R dataset archives and expand them + in the folder provided as parameter + """ + download_and_extract_archive(url=url, download_root=root) + + +def merge_datasets(download_path, storage_path): + """ + Merge datasets in download_path to storage_path + """ + + # Merge train and test datasets + train_path = os.path.join(download_path, "TrainValVideo") + test_path = os.path.join(download_path, "TestVideo") + train_test_path = storage_path + + print("Merging to {}".format(train_test_path)) + + os.makedirs(train_test_path, exist_ok=True) + + for file_name in os.listdir(train_path): + os.rename( + os.path.join(train_path, file_name), + os.path.join(train_test_path, file_name), + ) + + for file_name in os.listdir(test_path): + os.rename( + os.path.join(test_path, file_name), + os.path.join(train_test_path, file_name), + ) + + +if __name__ == "__main__": + + config_path = get_abs_path("configs/datasets/msrvtt/defaults_cap.yaml") + + storage_dir = OmegaConf.load( + config_path + ).datasets.msrvtt_cap.build_info.videos.storage + + download_dir = Path(get_cache_path(storage_dir)).parent / "download" + storage_dir = Path(get_cache_path(storage_dir)) + + if storage_dir.exists(): + print(f"Dataset already exists at {storage_dir}. Aborting.") + exit(0) + + try: + for k, v in DATA_URL.items(): + print("Downloading {} to {}".format(v, k)) + download_datasets(download_dir, v) + except Exception as e: + # remove download dir if failed + cleanup_dir(download_dir) + print("Failed to download or extracting datasets. Aborting.") + + try: + merge_datasets(download_dir, storage_dir) + except Exception as e: + # remove storage dir if failed + cleanup_dir(download_dir) + cleanup_dir(storage_dir) + print("Failed to merging datasets. Aborting.") + + cleanup_dir(download_dir) diff --git a/lavis/datasets/download_scripts/download_msvd.py b/lavis/datasets/download_scripts/download_msvd.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bf5467f3af7acdde7f7a25a38d28c599525771 --- /dev/null +++ b/lavis/datasets/download_scripts/download_msvd.py @@ -0,0 +1,67 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from pathlib import Path + +from omegaconf import OmegaConf + +from lavis.common.utils import ( + cleanup_dir, + download_and_extract_archive, + get_abs_path, + get_cache_path, +) + + +DATA_URL = "https://www.cs.utexas.edu/users/ml/clamp/videoDescription/YouTubeClips.tar" + + +def download_datasets(root, url): + download_and_extract_archive(url=url, download_root=root) + + +def move_files(download_path, storage_path): + """ + Move files from download_path to storage_path + """ + print("Moving to {}".format(storage_path)) + + os.makedirs(storage_path, exist_ok=True) + + for file_name in os.listdir(download_path): + os.rename( + os.path.join(download_path, file_name), + os.path.join(storage_path, file_name), + ) + + +if __name__ == "__main__": + + config_path = get_abs_path("configs/datasets/msvd/defaults_cap.yaml") + + storage_dir = OmegaConf.load( + config_path + ).datasets.msvd_cap.build_info.videos.storage + + download_dir = Path(get_cache_path(storage_dir)).parent / "download" + storage_dir = Path(get_cache_path(storage_dir)) + + if storage_dir.exists(): + print(f"Dataset already exists at {storage_dir}. Aborting.") + exit(0) + + try: + print("Downloading {}".format(DATA_URL)) + download_datasets(download_dir, DATA_URL) + except Exception as e: + # remove download dir if failed + cleanup_dir(download_dir) + print("Failed to download or extracting datasets. Aborting.") + + move_files(download_dir / "YouTubeClips", storage_dir) + cleanup_dir(download_dir) diff --git a/lavis/datasets/download_scripts/download_nocaps.py b/lavis/datasets/download_scripts/download_nocaps.py new file mode 100644 index 0000000000000000000000000000000000000000..ab56a7c10d958e6debb3968ca1c4def3da3beb0a --- /dev/null +++ b/lavis/datasets/download_scripts/download_nocaps.py @@ -0,0 +1,134 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import json +import logging +import os +import time +from multiprocessing import Pool + +import numpy as np +import requests +import tqdm +from lavis.common.utils import cleanup_dir, get_abs_path, get_cache_path +from omegaconf import OmegaConf + +header_mzl = { + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36", + # "User-Agent": "Googlebot-Image/1.0", # Pretend to be googlebot + # "X-Forwarded-For": "64.18.15.200", +} + +header_gbot = { + "User-Agent": "Googlebot-Image/1.0", # Pretend to be googlebot +} + +headers = [header_mzl, header_gbot] + +# Setup +logging.basicConfig(filename="download_nocaps.log", filemode="w", level=logging.INFO) +requests.packages.urllib3.disable_warnings( + requests.packages.urllib3.exceptions.InsecureRequestWarning +) + + +def download_file(url, filename): + max_retries = 20 + cur_retries = 0 + + header = headers[0] + + while cur_retries < max_retries: + try: + r = requests.get(url, headers=header, timeout=10) + with open(filename, "wb") as f: + f.write(r.content) + + break + except Exception as e: + logging.info(" ".join(repr(e).splitlines())) + logging.error(url) + cur_retries += 1 + + # random sample a header from headers + header = headers[np.random.randint(0, len(headers))] + + time.sleep(3 + cur_retries * 2) + + +def download_image_from_url_val(url): + basename = os.path.basename(url) + filename = os.path.join(storage_dir, "val", basename) + + download_file(url, filename) + + +def download_image_from_url_test(url): + basename = os.path.basename(url) + filename = os.path.join(storage_dir, "test", basename) + + download_file(url, filename) + + +if __name__ == "__main__": + os.makedirs("tmp", exist_ok=True) + + # storage dir + config_path = get_abs_path("configs/datasets/nocaps/defaults.yaml") + + storage_dir = OmegaConf.load(config_path).datasets.nocaps.build_info.images.storage + storage_dir = get_cache_path(storage_dir) + # make sure the storage dir exists + os.makedirs(storage_dir, exist_ok=True) + print("Storage dir:", storage_dir) + + # make sure the storage dir for val and test exists + os.makedirs(os.path.join(storage_dir, "val"), exist_ok=True) + os.makedirs(os.path.join(storage_dir, "test"), exist_ok=True) + + # download annotations + val_url = "https://nocaps.s3.amazonaws.com/nocaps_val_4500_captions.json" + tst_url = "https://s3.amazonaws.com/nocaps/nocaps_test_image_info.json" + + print("Downloading validation annotations from %s" % val_url) + download_file(val_url, "tmp/nocaps_val_ann.json") + print("Downloading testing annotations from %s" % tst_url) + download_file(tst_url, "tmp/nocaps_tst_ann.json") + + # open annotations + val_ann = json.load(open("tmp/nocaps_val_ann.json")) + tst_ann = json.load(open("tmp/nocaps_tst_ann.json")) + + # collect image urls + val_info = val_ann["images"] + tst_info = tst_ann["images"] + + val_urls = [info["coco_url"] for info in val_info] + tst_urls = [info["coco_url"] for info in tst_info] + + # setup multiprocessing + # large n_procs possibly causes server to reject requests + n_procs = 16 + + with Pool(n_procs) as pool: + print("Downloading validation images...") + list( + tqdm.tqdm( + pool.imap(download_image_from_url_val, val_urls), total=len(val_urls) + ) + ) + + with Pool(n_procs) as pool: + print("Downloading test images...") + list( + tqdm.tqdm( + pool.imap(download_image_from_url_test, tst_urls), total=len(tst_urls) + ) + ) + + # clean tmp + cleanup_dir("tmp") diff --git a/lavis/datasets/download_scripts/download_sbu.py b/lavis/datasets/download_scripts/download_sbu.py new file mode 100644 index 0000000000000000000000000000000000000000..9ffbf43c670d471f7eb160bcb8a9b6bd887aaf65 --- /dev/null +++ b/lavis/datasets/download_scripts/download_sbu.py @@ -0,0 +1,82 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import io +import os +import pathlib +import urllib +import tqdm + +from concurrent.futures import ThreadPoolExecutor + +from lavis.common.utils import get_abs_path, get_cache_path +from lavis.datasets.builders import load_dataset +from omegaconf import OmegaConf +from PIL import Image + +# DATA_URL = {"train": "http://www.cs.rice.edu/~vo9/sbucaptions/sbu_images.tar"} + +USER_AGENT = ( + "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:15.0) Gecko/20100101 Firefox/15.0.1" +) + + +def fetch_single_image(image_url, timeout=None, retries=0): + for _ in range(retries + 1): + try: + request = urllib.request.Request( + image_url, + data=None, + headers={"user-agent": USER_AGENT}, + ) + with urllib.request.urlopen(request, timeout=timeout) as req: + image = Image.open(io.BytesIO(req.read())) + break + except Exception: + image = None + return image + + +def download_and_save_image(ann, save_dir, timeout=None, retries=0): + image = fetch_single_image(ann["url"], timeout=timeout, retries=retries) + + if image is not None: + image_path = os.path.join(save_dir, ann["image"]) + print(image_path) + image.save(image_path) + + +if __name__ == "__main__": + + config_path = get_abs_path("configs/datasets/sbu_caption/defaults.yaml") + + storage_dir = OmegaConf.load( + config_path + ).datasets.sbu_caption.build_info.images.storage + + storage_dir = pathlib.Path(get_cache_path(storage_dir)) + + if storage_dir.exists(): + print(f"Dataset already exists at {storage_dir}. Aborting.") + exit(0) + + storage_dir.mkdir(parents=True, exist_ok=True) + + num_threads = 20 + dset = load_dataset("sbu_caption")["train"].annotation + + print("Downloading dataset...") + # multiprocessing + with ThreadPoolExecutor(max_workers=num_threads) as executor: + for ann in tqdm.tqdm(dset): + executor.submit( + download_and_save_image, + ann, + storage_dir, + timeout=30, + retries=10, + ) diff --git a/lavis/datasets/download_scripts/download_vg.py b/lavis/datasets/download_scripts/download_vg.py new file mode 100644 index 0000000000000000000000000000000000000000..7fbb7828f035f2cc9b32471129f0d2ec0f916f8e --- /dev/null +++ b/lavis/datasets/download_scripts/download_vg.py @@ -0,0 +1,55 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +from pathlib import Path + +from omegaconf import OmegaConf + +from lavis.common.utils import ( + cleanup_dir, + download_and_extract_archive, + get_abs_path, + get_cache_path, +) + + +DATA_URL = { + "train": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip", + "train2": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip", +} + + +def download_datasets(root, url): + download_and_extract_archive(url=url, download_root=root, extract_root=storage_dir) + + +if __name__ == "__main__": + + config_path = get_abs_path("configs/datasets/vg/defaults_caption.yaml") + + storage_dir = OmegaConf.load( + config_path + ).datasets.vg_caption.build_info.images.storage + + download_dir = Path(get_cache_path(storage_dir)).parent / "download" + storage_dir = Path(get_cache_path(storage_dir)) + + if storage_dir.exists(): + print(f"Dataset already exists at {storage_dir}. Aborting.") + exit(0) + + try: + for k, v in DATA_URL.items(): + print("Downloading {} to {}".format(v, k)) + download_datasets(download_dir, v) + except Exception as e: + # remove download dir if failed + cleanup_dir(download_dir) + print("Failed to download or extracting datasets. Aborting.") + + cleanup_dir(download_dir) diff --git a/lavis/models/__init__.py b/lavis/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c658eb8f6ce3b77c07218fda1151b47162666ec6 --- /dev/null +++ b/lavis/models/__init__.py @@ -0,0 +1,260 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import torch +from omegaconf import OmegaConf +from lavis.common.registry import registry + +from lavis.models.base_model import BaseModel + +from lavis.models.albef_models.albef_classification import AlbefClassification +from lavis.models.albef_models.albef_feature_extractor import AlbefFeatureExtractor +from lavis.models.albef_models.albef_nlvr import AlbefNLVR +from lavis.models.albef_models.albef_pretrain import AlbefPretrain +from lavis.models.albef_models.albef_retrieval import AlbefRetrieval +from lavis.models.albef_models.albef_vqa import AlbefVQA +from lavis.models.alpro_models.alpro_qa import AlproQA +from lavis.models.alpro_models.alpro_retrieval import AlproRetrieval + +from lavis.models.blip_models.blip import BlipBase +from lavis.models.blip_models.blip_caption import BlipCaption +from lavis.models.blip_models.blip_classification import BlipClassification +from lavis.models.blip_models.blip_feature_extractor import BlipFeatureExtractor +from lavis.models.blip_models.blip_image_text_matching import BlipITM +from lavis.models.blip_models.blip_nlvr import BlipNLVR +from lavis.models.blip_models.blip_pretrain import BlipPretrain +from lavis.models.blip_models.blip_retrieval import BlipRetrieval +from lavis.models.blip_models.blip_vqa import BlipVQA + +from lavis.models.blip2_models.blip2 import Blip2Base +from lavis.models.blip2_models.blip2_opt import Blip2OPT +from lavis.models.blip2_models.blip2_t5 import Blip2T5 +from lavis.models.blip2_models.blip2_qformer import Blip2Qformer +from lavis.models.blip2_models.blip2_image_text_matching import Blip2ITM + +from lavis.models.pnp_vqa_models.pnp_vqa import PNPVQA +from lavis.models.pnp_vqa_models.pnp_unifiedqav2_fid import PNPUnifiedQAv2FiD +from lavis.models.img2prompt_models.img2prompt_vqa import Img2PromptVQA +from lavis.models.med import XBertLMHeadDecoder +from lavis.models.vit import VisionTransformerEncoder +from lavis.models.clip_models.model import CLIP + +from lavis.models.gpt_models.gpt_dialogue import GPTDialogue + +from lavis.processors.base_processor import BaseProcessor + + +__all__ = [ + "load_model", + "AlbefClassification", + "AlbefFeatureExtractor", + "AlbefNLVR", + "AlbefVQA", + "AlbefPretrain", + "AlbefRetrieval", + "AlproQA", + "AlproRetrieval", + "BaseModel", + "BlipBase", + "BlipFeatureExtractor", + "BlipCaption", + "BlipClassification", + "BlipITM", + "BlipNLVR", + "BlipPretrain", + "BlipRetrieval", + "BlipVQA", + "Blip2Qformer", + "Blip2Base", + "Blip2ITM", + "Blip2OPT", + "Blip2T5", + "PNPVQA", + "Img2PromptVQA", + "PNPUnifiedQAv2FiD", + "CLIP", + "VisionTransformerEncoder", + "XBertLMHeadDecoder", + "GPTDialogue", +] + + +def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None): + """ + Load supported models. + + To list all available models and types in registry: + >>> from lavis.models import model_zoo + >>> print(model_zoo) + + Args: + name (str): name of the model. + model_type (str): type of the model. + is_eval (bool): whether the model is in eval mode. Default: False. + device (str): device to use. Default: "cpu". + checkpoint (str): path or to checkpoint. Default: None. + Note that expecting the checkpoint to have the same keys in state_dict as the model. + + Returns: + model (torch.nn.Module): model. + """ + + model = registry.get_model_class(name).from_pretrained(model_type=model_type) + + if checkpoint is not None: + model.load_checkpoint(checkpoint) + + if is_eval: + model.eval() + + if device == "cpu": + model = model.float() + + return model.to(device) + + +def load_preprocess(config): + """ + Load preprocessor configs and construct preprocessors. + + If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing. + + Args: + config (dict): preprocessor configs. + + Returns: + vis_processors (dict): preprocessors for visual inputs. + txt_processors (dict): preprocessors for text inputs. + + Key is "train" or "eval" for processors used in training and evaluation respectively. + """ + + def _build_proc_from_cfg(cfg): + return ( + registry.get_processor_class(cfg.name).from_config(cfg) + if cfg is not None + else BaseProcessor() + ) + + vis_processors = dict() + txt_processors = dict() + + vis_proc_cfg = config.get("vis_processor") + txt_proc_cfg = config.get("text_processor") + + if vis_proc_cfg is not None: + vis_train_cfg = vis_proc_cfg.get("train") + vis_eval_cfg = vis_proc_cfg.get("eval") + else: + vis_train_cfg = None + vis_eval_cfg = None + + vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg) + vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg) + + if txt_proc_cfg is not None: + txt_train_cfg = txt_proc_cfg.get("train") + txt_eval_cfg = txt_proc_cfg.get("eval") + else: + txt_train_cfg = None + txt_eval_cfg = None + + txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg) + txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg) + + return vis_processors, txt_processors + + +def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"): + """ + Load model and its related preprocessors. + + List all available models and types in registry: + >>> from lavis.models import model_zoo + >>> print(model_zoo) + + Args: + name (str): name of the model. + model_type (str): type of the model. + is_eval (bool): whether the model is in eval mode. Default: False. + device (str): device to use. Default: "cpu". + + Returns: + model (torch.nn.Module): model. + vis_processors (dict): preprocessors for visual inputs. + txt_processors (dict): preprocessors for text inputs. + """ + model_cls = registry.get_model_class(name) + + # load model + model = model_cls.from_pretrained(model_type=model_type) + + if is_eval: + model.eval() + + # load preprocess + cfg = OmegaConf.load(model_cls.default_config_path(model_type)) + if cfg is not None: + preprocess_cfg = cfg.preprocess + + vis_processors, txt_processors = load_preprocess(preprocess_cfg) + else: + vis_processors, txt_processors = None, None + logging.info( + f"""No default preprocess for model {name} ({model_type}). + This can happen if the model is not finetuned on downstream datasets, + or it is not intended for direct use without finetuning. + """ + ) + + if device == "cpu" or device == torch.device("cpu"): + model = model.float() + + return model.to(device), vis_processors, txt_processors + + +class ModelZoo: + """ + A utility class to create string representation of available model architectures and types. + + >>> from lavis.models import model_zoo + >>> # list all available models + >>> print(model_zoo) + >>> # show total number of models + >>> print(len(model_zoo)) + """ + + def __init__(self) -> None: + self.model_zoo = { + k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys()) + for k, v in registry.mapping["model_name_mapping"].items() + } + + def __str__(self) -> str: + return ( + "=" * 50 + + "\n" + + f"{'Architectures':<30} {'Types'}\n" + + "=" * 50 + + "\n" + + "\n".join( + [ + f"{name:<30} {', '.join(types)}" + for name, types in self.model_zoo.items() + ] + ) + ) + + def __iter__(self): + return iter(self.model_zoo.items()) + + def __len__(self): + return sum([len(v) for v in self.model_zoo.values()]) + + +model_zoo = ModelZoo() diff --git a/lavis/models/albef_models/__init__.py b/lavis/models/albef_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..512729237e32354ec0aca598343320afaf7d4acd --- /dev/null +++ b/lavis/models/albef_models/__init__.py @@ -0,0 +1,202 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import logging +import os +import time + +import lavis.common.dist_utils as dist_utils +import torch +import torch.distributed as dist +import torch.nn.functional as F +from lavis.common.dist_utils import download_cached_file +from lavis.common.logger import MetricLogger +from lavis.common.utils import is_url +from lavis.models.base_model import BaseModel +from lavis.models.vit import interpolate_pos_embed +from transformers import BertTokenizer + + +class AlbefBase(BaseModel): + @classmethod + def init_tokenizer(cls): + return BertTokenizer.from_pretrained("bert-base-uncased") + + def load_from_pretrained(self, url_or_filename, rename_text_keys=True): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + if "model" in checkpoint: + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( + state_dict["visual_encoder.pos_embed"], self.visual_encoder + ) + if ( + "visual_encoder_m.pos_embed" in self.state_dict().keys() + and "visual_encoder_m.pos_embed" in state_dict + ): + state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed( + state_dict["visual_encoder_m.pos_embed"], self.visual_encoder_m + ) + + if rename_text_keys: + for key in list(state_dict.keys()): + if "bert" in key: + new_key = key.replace("bert.", "") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + for key in self.state_dict().keys(): + if key in state_dict.keys(): + if state_dict[key].shape != self.state_dict()[key].shape: + del state_dict[key] + + msg = self.load_state_dict(state_dict, strict=False) + + logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + return msg + + +def compute_sim_matrix(model, data_loader, **kwargs): + k_test = kwargs.pop("k_test") + + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation:" + + logging.info("Computing features for evaluation...") + start_time = time.time() + + texts = data_loader.dataset.text + num_text = len(texts) + text_bs = 256 + text_ids = [] + text_embeds = [] + text_atts = [] + for i in range(0, num_text, text_bs): + text = texts[i : min(num_text, i + text_bs)] + text_input = model.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=35, + return_tensors="pt", + ).to(model.device) + text_output = model.text_encoder.forward_text(text_input) + text_embed = F.normalize( + model.text_proj(text_output.last_hidden_state[:, 0, :]) + ) + text_embeds.append(text_embed) + text_ids.append(text_input.input_ids) + text_atts.append(text_input.attention_mask) + + text_embeds = torch.cat(text_embeds, dim=0) + text_ids = torch.cat(text_ids, dim=0) + text_atts = torch.cat(text_atts, dim=0) + if hasattr(model.tokenizer, "enc_token_id"): + text_ids[:, 0] = model.tokenizer.enc_token_id + + image_feats = [] + image_embeds = [] + for samples in data_loader: + image = samples["image"] + + image = image.to(model.device) + image_feat = model.visual_encoder.forward_features(image) + image_embed = model.vision_proj(image_feat[:, 0, :]) + image_embed = F.normalize(image_embed, dim=-1) + + image_feats.append(image_feat.cpu()) + image_embeds.append(image_embed) + + image_feats = torch.cat(image_feats, dim=0) + image_embeds = torch.cat(image_embeds, dim=0) + + sims_matrix = image_embeds @ text_embeds.t() + score_matrix_i2t = torch.full( + (len(data_loader.dataset.image), len(texts)), -100.0 + ).to(model.device) + + num_tasks = dist_utils.get_world_size() + rank = dist_utils.get_rank() + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + # topk_sim, topk_idx = sims.topk(k=config["k_test"], dim=0) + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + + encoder_output = image_feats[start + i].repeat(k_test, 1, 1).to(model.device) + encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to( + model.device + ) + output = model.text_encoder( + text_ids[topk_idx], + attention_mask=text_atts[topk_idx], + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1] + score_matrix_i2t[start + i, topk_idx] = score + topk_sim + + sims_matrix = sims_matrix.t() + score_matrix_t2i = torch.full( + (len(texts), len(data_loader.dataset.image)), -100.0 + ).to(model.device) + + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + encoder_output = image_feats[topk_idx.cpu()].to(model.device) + encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to( + model.device + ) + output = model.text_encoder( + text_ids[start + i].repeat(k_test, 1), + attention_mask=text_atts[start + i].repeat(k_test, 1), + encoder_hidden_states=encoder_output, + encoder_attention_mask=encoder_att, + return_dict=True, + ) + score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1] + score_matrix_t2i[start + i, topk_idx] = score + topk_sim + + if dist_utils.is_dist_avail_and_initialized(): + dist.barrier() + torch.distributed.all_reduce( + score_matrix_i2t, op=torch.distributed.ReduceOp.SUM + ) + torch.distributed.all_reduce( + score_matrix_t2i, op=torch.distributed.ReduceOp.SUM + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Evaluation time {}".format(total_time_str)) + + return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() diff --git a/lavis/models/albef_models/albef_classification.py b/lavis/models/albef_models/albef_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..2d82de9a40ab53f443bf67fc6cfe24c6b6ed81cd --- /dev/null +++ b/lavis/models/albef_models/albef_classification.py @@ -0,0 +1,182 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import warnings +from copy import deepcopy + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.models.albef_models import AlbefBase +from lavis.models.albef_models.albef_outputs import ( + AlbefIntermediateOutput, + AlbefOutputWithLogits, +) +from lavis.models.base_model import MomentumDistilationMixin +from lavis.models.med import XBertEncoder +from lavis.models.vit import VisionTransformerEncoder +from torch import nn + + +@registry.register_model("albef_classification") +class AlbefClassification(AlbefBase, MomentumDistilationMixin): + PRETRAINED_MODEL_CONFIG_DICT = { + "ve": "configs/models/albef_classification_ve.yaml", + } + + def __init__( + self, + image_encoder, + text_encoder, + num_classes, + momentum=0.995, + alpha=0.4, + use_distill=True, + max_txt_len=40, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + self.max_txt_len = max_txt_len + + self.use_distill = use_distill + + self.visual_encoder = image_encoder + self.text_encoder = text_encoder + + hidden_size = text_encoder.config.hidden_size + + if num_classes > 0: + self.cls_head = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, num_classes), + ) + else: + warnings.warn( + f"Found num_classes=0, initializing {type(self)} without classifier." + ) + + if self.use_distill: + self.visual_encoder_m = deepcopy(self.visual_encoder) + self.text_encoder_m = deepcopy(self.text_encoder) + self.cls_head_m = deepcopy(self.cls_head) + + self.momentum = momentum + self.alpha = alpha + + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.cls_head, self.cls_head_m], + ] + + self.copy_params() + + def _rampup_factor(self, epoch, iters, num_iters_per_epoch): + return min(1, (epoch * num_iters_per_epoch + iters) / num_iters_per_epoch) + + def forward(self, samples, is_train=True): + sentences = samples["text_input"] + sentences = self.tokenizer( + sentences, + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + samples.update({"tokenized_text": sentences}) + + targets = samples["label"] + + image_embeds = self.visual_encoder.forward_features(samples["image"]) + encoder_output = self.text_encoder.forward_automask( + samples["tokenized_text"], image_embeds + ) + + prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :]) + + if is_train: + if self.use_distill: + with torch.no_grad(): + self._momentum_update() + + image_embeds_m = self.visual_encoder_m(samples["image"]) + encoder_output_m = self.text_encoder_m.forward_automask( + samples["tokenized_text"], image_embeds_m + ) + + prediction_m = self.cls_head_m( + encoder_output_m.last_hidden_state[:, 0, :] + ) + + alpha = self.alpha * self._rampup_factor( + epoch=samples["epoch"], + iters=samples["iters"], + num_iters_per_epoch=samples["num_iters_per_epoch"], + ) + + loss = (1 - alpha) * F.cross_entropy( + prediction, targets + ) - alpha * torch.sum( + F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1), + dim=1, + ).mean() + else: + loss = F.cross_entropy(prediction, targets) + + image_embeds_m, encoder_output_m, prediction_m = None, None, None + + # return {"loss": loss} + return AlbefOutputWithLogits( + loss=loss, + intermediate_output=AlbefIntermediateOutput( + image_embeds=image_embeds, + image_embeds_m=image_embeds_m, + encoder_output=encoder_output, + encoder_output_m=encoder_output_m, + ), + logits=prediction, + logits_m=prediction_m, + ) + else: + return {"predictions": prediction, "targets": targets} + + def predict(self, samples): + output = self.forward(samples, is_train=False) + return output + + @classmethod + def from_config(cls, cfg=None): + image_encoder = VisionTransformerEncoder.from_config(cfg) + + # text encoder + multimodal encoder + text_encoder = XBertEncoder.from_config(cfg) + + alpha = cfg.get("alpha", 0.4) + momentum = cfg.get("momentum", 0.995) + use_distill = cfg.get("use_distill", True) + num_classes = cfg.get("num_classes", -1) + max_txt_len = cfg.get("max_txt_len", 40) + + assert num_classes > 1, "Invalid number of classes provided, found {}".format( + num_classes + ) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + use_distill=use_distill, + alpha=alpha, + num_classes=num_classes, + momentum=momentum, + max_txt_len=max_txt_len, + ) + + model.load_checkpoint_from_config(cfg) + + return model diff --git a/lavis/models/albef_models/albef_feature_extractor.py b/lavis/models/albef_models/albef_feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..df7632c6d8e7eac7e6ae019379e53febd3f7ef0c --- /dev/null +++ b/lavis/models/albef_models/albef_feature_extractor.py @@ -0,0 +1,204 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import warnings + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.common.utils import get_abs_path +from lavis.models.albef_models import AlbefBase +from lavis.models.albef_models.albef_outputs import AlbefOutputFeatures +from lavis.models.med import BertForMaskedLM +from lavis.models.vit import VisionTransformerEncoder +from torch import nn +from transformers import BertConfig + + +@registry.register_model("albef_feature_extractor") +class AlbefFeatureExtractor(AlbefBase): + PRETRAINED_MODEL_CONFIG_DICT = { + "base": "configs/models/albef_feature_extractor.yaml", + } + + def __init__(self, image_encoder, text_encoder, embed_dim=256, max_txt_len=30): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder = image_encoder + self.text_encoder = text_encoder + + text_width = text_encoder.config.hidden_size + vision_width = image_encoder.vision_width + + self.embed_dim = embed_dim + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.max_txt_len = max_txt_len + + self.temp = nn.Parameter(0.07 * torch.ones([])) + + @torch.no_grad() + def extract_features(self, samples, mode="multimodal"): + """ + Extract features for multimodal or unimodal samples. + + Args: + samples (dict): A dictionary of samples, containing the following keys: + - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image. + Raw images should be preprocessed before being passed to feature extractor. + - text_input (list): A list of strings containing the text, length B. + mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image". + If "multimodal", return image features and multimodal features; + if "text", return text features; + if "image", return image features. + Default: "multimodal". + + Returns: + An AlbefOutputFeatures object, see lavis/models/albef_models/albef_outputs.py for details. + + Examples: + ```python + >>> from PIL import Image + >>> from lavis.models import load_model_and_preprocess + >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB") + >>> caption = "a large fountain spewing water into the air" + >>> model, vis_processors, txt_processors = load_model_and_preprocess("albef_feature_extractor", is_eval=True) + >>> image = vis_processors["eval"](raw_image).unsqueeze(0) + >>> text_input = txt_processors["eval"](caption) + + >>> sample = {"image": image, "text_input": [text_input]} + + >>> features_multimodal = model.extract_features(sample) + >>> features_multimodal.keys() + odict_keys(['image_embeds', 'multimodal_embeds']) + >>> features_multimodal.image_embeds.shape + torch.Size([1, 197, 768]) + >>> features_multimodal.multimodal_embeds.shape + torch.Size([1, 12, 768]) + + >>> features_text = model.extract_features(sample, mode="text") + >>> features_text.keys() + odict_keys(['text_embeds', 'text_features']) + >>> features_text.text_embeds.shape + torch.Size([1, 12, 768]) + >>> features_text.text_features.shape + torch.Size([1, 12, 256]) + + >>> features_image = model.extract_features(sample, mode="image") + >>> features_image.keys() + odict_keys(['image_embeds', 'image_features']) + >>> features_image.image_embeds.shape + torch.Size([1, 197, 768]) + >>> features_image.image_features.shape + torch.Size([1, 197, 256]) + ``` + """ + image = samples["image"] + caption = samples["text_input"] + + if isinstance(mode, str): + mode = [mode] + + for m in mode: + assert m in [ + "multimodal", + "image", + "text", + ], "mode must be one of [multimodal, image, text], but got {}".format(m) + + # initalize output + image_embeds, text_embeds, multimodal_embeds = None, None, None + image_features, text_features = None, None + + if "image" in mode or "multimodal" in mode: + assert ( + image is not None + ), "image must be provided if mode is 'image' or 'multimodal'" + + image_embeds = self.visual_encoder.forward_features(image) + image_features = F.normalize(self.vision_proj(image_embeds), dim=-1) + + if "text" in mode or "multimodal" in mode: + assert ( + caption is not None + ), "text must be provided if mode is 'text' or 'multimodal'" + + text = self.tokenizer( + caption, + padding=True, + return_tensors="pt", + ).to(self.device) + + text_output = self.text_encoder.bert( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True, + mode="text", + ) + text_embeds = text_output.last_hidden_state + text_features = F.normalize(self.text_proj(text_embeds), dim=-1) + + if "multimodal" in mode: + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + + # forward the positve image-text pair + output = self.text_encoder.bert( + encoder_embeds=text_embeds, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + mode="fusion", + ) + + multimodal_embeds = output.last_hidden_state + + return AlbefOutputFeatures( + image_embeds=image_embeds, + image_embeds_proj=image_features, + text_embeds=text_embeds, + text_embeds_proj=text_features, + multimodal_embeds=multimodal_embeds, + ) + + @classmethod + def from_config(cls, cfg=None): + image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=True) + config_text_encoder = BertConfig.from_json_file( + get_abs_path(cfg["med_config_path"]) + ) + config_text_encoder.fusion_layer = 6 + text_encoder = BertForMaskedLM.from_pretrained( + "bert-base-uncased", config=config_text_encoder + ) + + embed_dim = cfg.get("embed_dim", 256) + max_txt_len = cfg.get("max_txt_len", 30) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + embed_dim=embed_dim, + max_txt_len=max_txt_len, + ) + + # load pre-trained weights + pretrain_path = cfg.get("pretrained", None) + if pretrain_path is not None: + msg = model.load_from_pretrained( + url_or_filename=pretrain_path, rename_text_keys=False + ) + else: + warnings.warn("No pretrained weights are loaded.") + + return model diff --git a/lavis/models/albef_models/albef_nlvr.py b/lavis/models/albef_models/albef_nlvr.py new file mode 100644 index 0000000000000000000000000000000000000000..5df836b18479600f0c1dedd7d56200b2b6b054d9 --- /dev/null +++ b/lavis/models/albef_models/albef_nlvr.py @@ -0,0 +1,260 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from copy import deepcopy + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.common.utils import get_abs_path +from lavis.models.albef_models import AlbefBase +from lavis.models.albef_models.albef_outputs import AlbefIntermediateOutput, AlbefOutput +from lavis.models.base_model import MomentumDistilationMixin +from lavis.models.med import BertModel +from lavis.models.vit import VisionTransformerEncoder +from torch import nn +from transformers import BertConfig + + +@registry.register_model("albef_nlvr") +class AlbefNLVR(AlbefBase, MomentumDistilationMixin): + PRETRAINED_MODEL_CONFIG_DICT = { + "nlvr": "configs/models/albef_nlvr.yaml", + } + + def __init__( + self, + image_encoder, + text_encoder, + num_classes, + momentum=0.995, + alpha=0.4, + use_distill=True, + max_txt_len=40, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + self.max_txt_len = max_txt_len + + self.use_distill = use_distill + + self.visual_encoder = image_encoder + self.text_encoder = text_encoder + + hidden_size = text_encoder.config.hidden_size + self.cls_head = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, num_classes), + ) + + self.share_cross_attention(self.text_encoder.encoder) + + if self.use_distill: + self.visual_encoder_m = deepcopy(self.visual_encoder) + self.text_encoder_m = deepcopy(self.text_encoder) + self.cls_head_m = deepcopy(self.cls_head) + + self.share_cross_attention(self.text_encoder_m.encoder) + + self.momentum = momentum + self.alpha = alpha + + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.cls_head, self.cls_head_m], + ] + + self.copy_params() + + def _rampup_factor(self, epoch, iters, num_iters_per_epoch): + return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch)) + + def forward(self, samples, is_train=True): + """ + Forward function for training and evaluation. + + Args: + samples (dict): a dict of input samples, which contains the following keys: + - image0 (torch.Tensor): input image 0, shape (batch_size, 3, H, W), default H=384, W=384. + - image1 (torch.Tensor): input image 1, shape (batch_size, 3, H, W), default H=384, W=384. + - text_input (list): list of strings, each string is a natural language sentence. + - label (torch.LongTensor): ground truth label with shape (batch_size,). + is_train (bool): whether the model is in training mode. + If True, the model will return the loss; + If False, the model will return the prediction. + + Examples: + >>> import torch + >>> from lavis.models import load_model + >>> model = load_model("albef_nlvr") + >>> samples = { + ... "image0": torch.randn(2, 3, 384, 384), + ... "image1": torch.randn(2, 3, 384, 384), + ... "text_input": ["there is a ferret in tall grass", "there are lips in one of the images"], + ... "label": torch.tensor([0, 1]), + ... } + >>> output = model(samples) + >>> output.keys() + odict_keys(['intermediate_output', 'loss']) + """ + text = samples["text_input"] + text = self.tokenizer( + text, + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + + targets = samples["label"] + + image0 = samples["image0"] + image1 = samples["image1"] + images = torch.cat([image0, image1], dim=0) + + image_embeds = self.visual_encoder.forward_features(images) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0)) + + encoder_output = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=[image0_embeds, image1_embeds], + encoder_attention_mask=[ + image_atts[: image0_embeds.size(0)], + image_atts[image0_embeds.size(0) :], + ], + return_dict=True, + ) + + prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :]) + + if is_train: + if self.use_distill: + with torch.no_grad(): + self._momentum_update() + + image_embeds_m = self.visual_encoder_m(images) + image0_embeds_m, image1_embeds_m = torch.split( + image_embeds_m, targets.size(0) + ) + encoder_output_m = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=[image0_embeds_m, image1_embeds_m], + encoder_attention_mask=[ + image_atts[: image0_embeds_m.size(0)], + image_atts[image0_embeds_m.size(0) :], + ], + return_dict=True, + ) + + prediction_m = self.cls_head_m( + encoder_output_m.last_hidden_state[:, 0, :] + ) + + alpha = self.alpha * self._rampup_factor( + epoch=samples["epoch"], + iters=samples["iters"], + num_iters_per_epoch=samples["num_iters_per_epoch"], + ) + + loss = (1 - alpha) * F.cross_entropy( + prediction, targets + ) - alpha * torch.sum( + F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1), + dim=1, + ).mean() + else: + loss = F.cross_entropy(prediction, targets) + + encoder_output_m = None + image0_embeds_m, image1_embeds_m = None, None + + # return {"loss": loss} + return AlbefOutput( + loss=loss, + intermediate_output=AlbefIntermediateOutput( + image_embeds=torch.stack([image0_embeds, image1_embeds], dim=0), + image_embeds_m=torch.stack( + [image0_embeds_m, image1_embeds_m], dim=0 + ), + encoder_output=encoder_output, + encoder_output_m=encoder_output_m, + ), + ) + else: + return {"predictions": prediction, "targets": targets} + + def share_cross_attention(self, model): + for i in range(6): + layer_num = 6 + i * 2 + modules_0 = model.layer[layer_num].crossattention.self._modules + modules_1 = model.layer[layer_num + 1].crossattention.self._modules + + for name in modules_0.keys(): + if "key" in name or "value" in name: + module_0 = modules_0[name] + module_1 = modules_1[name] + if hasattr(module_0, "weight"): + module_0.weight = module_1.weight + if hasattr(module_0, "bias"): + module_0.bias = module_1.bias + + def predict(self, samples): + output = self.forward(samples, is_train=False) + return output + + def load_from_pretrained(self, url_or_filename, use_distill=True): + _, msg = super().load_from_pretrained(url_or_filename) + + if use_distill and any(["_m" in k for k in msg.missing_keys]): + # this is required when initializing the model from TA pre-trained weights + self.copy_params() + + return msg + + @classmethod + def from_config(cls, cfg=None): + image_encoder = VisionTransformerEncoder.from_config(cfg) + + # text encoder + multimodal encoder + bert_config = BertConfig.from_json_file(get_abs_path(cfg["med_config_path"])) + bert_config.num_hidden_layers = 18 + + text_encoder = BertModel.from_pretrained( + "bert-base-uncased", config=bert_config, add_pooling_layer=False + ) + + alpha = cfg.get("alpha", 0.4) + momentum = cfg.get("momentum", 0.995) + use_distill = cfg.get("use_distill", True) + num_classes = cfg.get("num_classes", -1) + max_txt_len = cfg.get("max_txt_len", 40) + + assert num_classes > 1, "Invalid number of classes provided, found {}".format( + num_classes + ) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + use_distill=use_distill, + alpha=alpha, + num_classes=num_classes, + momentum=momentum, + max_txt_len=max_txt_len, + ) + + model.load_checkpoint_from_config(cfg) + + return model diff --git a/lavis/models/albef_models/albef_outputs.py b/lavis/models/albef_models/albef_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f73f39cf175319aa095cb24f30e9496f305a74 --- /dev/null +++ b/lavis/models/albef_models/albef_outputs.py @@ -0,0 +1,97 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.modeling_outputs import ( + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + ModelOutput, +) + + +@dataclass +class AlbefSimilarity(ModelOutput): + sim_i2t: torch.FloatTensor = None + sim_t2i: torch.FloatTensor = None + + sim_i2t_m: Optional[torch.FloatTensor] = None + sim_t2i_m: Optional[torch.FloatTensor] = None + + sim_i2t_targets: Optional[torch.FloatTensor] = None + sim_t2i_targets: Optional[torch.FloatTensor] = None + + +@dataclass +class AlbefIntermediateOutput(ModelOutput): + # uni-modal features + image_embeds: torch.FloatTensor = None + text_embeds: Optional[torch.FloatTensor] = None + + image_embeds_m: Optional[torch.FloatTensor] = None + text_embeds_m: Optional[torch.FloatTensor] = None + + # intermediate outputs of multimodal encoder + encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + encoder_output_m: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + + itm_logits: Optional[torch.FloatTensor] = None + itm_labels: Optional[torch.LongTensor] = None + + # intermediate outputs of multimodal decoder + decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None + decoder_labels: Optional[torch.LongTensor] = None + + +@dataclass +class AlbefOutput(ModelOutput): + # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. + sims: Optional[AlbefSimilarity] = None + + intermediate_output: AlbefIntermediateOutput = None + + loss: Optional[torch.FloatTensor] = None + + loss_itc: Optional[torch.FloatTensor] = None + + loss_itm: Optional[torch.FloatTensor] = None + + loss_mlm: Optional[torch.FloatTensor] = None + + +@dataclass +class AlbefOutputWithLogits(AlbefOutput): + logits: torch.FloatTensor = None + logits_m: torch.FloatTensor = None + + +@dataclass +class AlbefOutputFeatures(ModelOutput): + """ + Data class of features from AlbefFeatureExtractor. + + Args: + image_embeds: `torch.FloatTensor` of shape `(batch_size, num_patches+1, embed_dim)`, `optional` + image_features: `torch.FloatTensor` of shape `(batch_size, num_patches+1, feature_dim)`, `optional` + text_embeds: `torch.FloatTensor` of shape `(batch_size, sequence_length+1, embed_dim)`, `optional` + text_features: `torch.FloatTensor` of shape `(batch_size, sequence_length+1, feature_dim)`, `optional` + + The first embedding or feature is for the [CLS] token. + + Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. + """ + + image_embeds: Optional[torch.FloatTensor] = None + image_embeds_proj: Optional[torch.FloatTensor] = None + + text_embeds: Optional[torch.FloatTensor] = None + text_embeds_proj: Optional[torch.FloatTensor] = None + + multimodal_embeds: Optional[torch.FloatTensor] = None diff --git a/lavis/models/albef_models/albef_pretrain.py b/lavis/models/albef_models/albef_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..e25baf30a65f3218bb7b9ab8ebed6b01f74c773b --- /dev/null +++ b/lavis/models/albef_models/albef_pretrain.py @@ -0,0 +1,416 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.common.utils import get_abs_path +from lavis.models.albef_models import AlbefBase +from lavis.models.albef_models.albef_outputs import ( + AlbefIntermediateOutput, + AlbefOutput, + AlbefSimilarity, +) +from lavis.models.base_model import MomentumDistilationMixin, SharedQueueMixin +from lavis.models.med import BertForMaskedLM +from lavis.models.vit import VisionTransformerEncoder +from torch import nn +from transformers import BertConfig + + +@registry.register_model("albef_pretrain") +class AlbefPretrain(AlbefBase, MomentumDistilationMixin, SharedQueueMixin): + """ + ALBEF pretrain model. + + Supported model types: + - base: ALBEF base model used for pretraining. + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "base": "configs/models/albef_pretrain_base.yaml", + } + + def __init__( + self, + image_encoder, + text_encoder, + queue_size, + embed_dim=256, + mlm_mask_prob=0.15, + temp=0.07, + momentum=0.995, + alpha=0.4, + max_txt_len=30, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder = image_encoder + self.text_encoder = text_encoder + + text_width = text_encoder.config.hidden_size + vision_width = image_encoder.vision_width + + self.embed_dim = embed_dim + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + # create the momentum encoder + self.visual_encoder_m = deepcopy(self.visual_encoder) + self.text_encoder_m = deepcopy(self.text_encoder) + + self.vision_proj_m = deepcopy(self.vision_proj) + self.text_proj_m = deepcopy(self.text_proj) + + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.vision_proj, self.vision_proj_m], + [self.text_proj, self.text_proj_m], + ] + self.copy_params() + + # create the queue + self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + self.image_queue = nn.functional.normalize(self.image_queue, dim=0) + self.text_queue = nn.functional.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + self.momentum = momentum + self.temp = nn.Parameter(temp * torch.ones([])) + + self.alpha = alpha + self.max_txt_len = max_txt_len + + self.mlm_probability = mlm_mask_prob + + def _rampup_factor(self, epoch, iters, num_iters_per_epoch): + return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch)) + + def forward(self, samples): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). The input images. Default: H=224, W=224. + - text_input (list): A list of length batch_size, each element is a string of text/caption. + - epoch (int): The current epoch. + - iters (int): The current iteration. + - num_iters_per_epoch (int): The number of iterations per epoch. + + Returns: + BlipOutput: A BlipOutput object containing loss and intermediate output. See ``lavis.models.blip_models.blip_outputs.BlipOutput`` for more details. + + Examples: + >>> import torch + >>> from lavis.models import load_model + >>> model = load_model("albef_pretrain") + >>> images = torch.randn(4, 3, 224, 224) + >>> text_input = ["caption of image 1", "another caption of image 1", "caption of image 2", "caption of image 3"] + >>> samples = {"image": images, "text_input": text_input, "epoch": 0, "iters": 0, "num_iters_per_epoch": 100} + >>> output = model(samples) + >>> output.keys() + odict_keys(['sims', 'intermediate_output', 'loss', 'loss_itc', 'loss_itm', 'loss_mlm']) + """ + image = samples["image"] + caption = samples["text_input"] + + alpha = self.alpha * self._rampup_factor( + epoch=samples["epoch"], + iters=samples["iters"], + num_iters_per_epoch=samples["num_iters_per_epoch"], + ) + + with torch.no_grad(): + self.temp.clamp_(0.001, 0.5) + + image_embeds = self.visual_encoder.forward_features(image) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + + text = self.tokenizer( + caption, + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + + image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) + + text_output = self.text_encoder.bert( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True, + mode="text", + ) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1) + + # get momentum features + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m(image) + image_feat_m = F.normalize( + self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1 + ) + image_feat_all = torch.cat( + [image_feat_m.t(), self.image_queue.clone().detach()], dim=1 + ) + text_output_m = self.text_encoder_m.bert( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True, + mode="text", + ) + text_embeds_m = text_output_m.last_hidden_state + text_feat_m = F.normalize(self.text_proj_m(text_embeds_m[:, 0, :]), dim=-1) + text_feat_all = torch.cat( + [text_feat_m.t(), self.text_queue.clone().detach()], dim=1 + ) + + sim_i2t_m = image_feat_m @ text_feat_all / self.temp + sim_t2i_m = text_feat_m @ image_feat_all / self.temp + + sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) + sim_targets.fill_diagonal_(1) + + sim_i2t_targets = ( + alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets + ) + sim_t2i_targets = ( + alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets + ) + + sim_i2t = image_feat @ text_feat_all / self.temp + sim_t2i = text_feat @ image_feat_all / self.temp + + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1 + ).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1 + ).mean() + + loss_itc = (loss_i2t + loss_t2i) / 2 + + self._dequeue_and_enqueue(image_feat_m, text_feat_m) + + # forward the positve image-text pair + encoder_output_pos = self.text_encoder.bert( + encoder_embeds=text_embeds, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + mode="fusion", + ) + with torch.no_grad(): + bs = image.size(0) + + weights_i2t = sim_i2t[:, :bs].clone() + weights_t2i = sim_t2i[:, :bs].clone() + + weights_i2t.fill_diagonal_(-np.Inf) + weights_t2i.fill_diagonal_(-np.Inf) + + weights_i2t = F.softmax(weights_i2t, dim=1) + weights_t2i = F.softmax(weights_t2i, dim=1) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_embeds_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_embeds_neg.append(text_embeds[neg_idx]) + text_atts_neg.append(text.attention_mask[neg_idx]) + text_embeds_neg = torch.stack(text_embeds_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) + text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0) + + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) + image_atts_all = torch.cat([image_atts, image_atts], dim=0) + + encoder_output_neg = self.text_encoder.bert( + encoder_embeds=text_embeds_all, + attention_mask=text_atts_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + mode="fusion", + ) + + vl_embeddings = torch.cat( + [ + encoder_output_pos.last_hidden_state[:, 0, :], + encoder_output_neg.last_hidden_state[:, 0, :], + ], + dim=0, + ) + itm_logits = self.itm_head(vl_embeddings) + + itm_labels = torch.cat( + [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)], + dim=0, + ).to(self.device) + loss_itm = F.cross_entropy(itm_logits, itm_labels) + + # MLM + input_ids = text.input_ids.clone() + labels = input_ids.clone() + + probability_matrix = torch.full(labels.shape, self.mlm_probability) + input_ids, labels = self.mask( + input_ids, + self.text_encoder.config.vocab_size, + self.device, + targets=labels, + probability_matrix=probability_matrix, + ) + + with torch.no_grad(): + logits_m = self.text_encoder_m( + input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds_m, + encoder_attention_mask=image_atts, + return_dict=True, + return_logits=True, + ) + mlm_output = self.text_encoder( + input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + labels=labels, + soft_labels=F.softmax(logits_m, dim=-1), + alpha=alpha, + ) + loss_mlm = mlm_output.loss + + return AlbefOutput( + loss=loss_itc + loss_itm + loss_mlm, + loss_itc=loss_itc, + loss_itm=loss_itm, + loss_mlm=loss_mlm, + sims=AlbefSimilarity( + sim_i2t=sim_i2t, + sim_t2i=sim_t2i, + sim_i2t_m=sim_i2t_m, + sim_t2i_m=sim_t2i_m, + sim_i2t_targets=sim_i2t_targets, + sim_t2i_targets=sim_t2i_targets, + ), + intermediate_output=AlbefIntermediateOutput( + image_embeds=image_embeds, + image_embeds_m=image_embeds_m, + text_embeds=text_embeds, + text_embeds_m=text_embeds_m, + encoder_output=encoder_output_pos, + encoder_output_neg=encoder_output_neg, + itm_logits=itm_logits, + itm_labels=itm_labels, + ), + ) + + def mask( + self, + input_ids, + vocab_size, + device, + targets=None, + masked_indices=None, + probability_matrix=None, + ): + """ + Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. + """ + if masked_indices is None: + masked_indices = torch.bernoulli(probability_matrix).bool() + + masked_indices[input_ids == self.tokenizer.pad_token_id] = False + masked_indices[input_ids == self.tokenizer.cls_token_id] = False + + if targets is not None: + targets[~masked_indices] = -100 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = ( + torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices + ) + input_ids[indices_replaced] = self.tokenizer.mask_token_id + + # 10% of the time, we replace masked input tokens with random word + indices_random = ( + torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() + & masked_indices + & ~indices_replaced + ) + random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to( + device + ) + input_ids[indices_random] = random_words[indices_random] + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + + if targets is not None: + return input_ids, targets + else: + return input_ids + + @classmethod + def from_config(cls, cfg=None): + image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=True) + config_text_encoder = BertConfig.from_json_file( + get_abs_path(cfg["med_config_path"]) + ) + config_text_encoder.fusion_layer = 6 + text_encoder = BertForMaskedLM.from_pretrained( + "bert-base-uncased", config=config_text_encoder + ) + + embed_dim = cfg.get("embed_dim", 256) + momentum = cfg.get("momentum", 0.995) + alpha = cfg.get("alpha", 0.4) + mlm_mask_prob = cfg.get("mlm_mask_prob", 0.15) + temp = cfg.get("temp", 0.07) + max_txt_len = cfg.get("max_txt_len", 30) + queue_size = cfg.get("queue_size", 65536) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + queue_size=queue_size, + embed_dim=embed_dim, + mlm_mask_prob=mlm_mask_prob, + temp=temp, + momentum=momentum, + alpha=alpha, + max_txt_len=max_txt_len, + ) + + return model diff --git a/lavis/models/albef_models/albef_retrieval.py b/lavis/models/albef_models/albef_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..dafea6d806445bb851dc6b4d47281d65d81508cf --- /dev/null +++ b/lavis/models/albef_models/albef_retrieval.py @@ -0,0 +1,344 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from copy import deepcopy + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.models.albef_models import AlbefBase, compute_sim_matrix +from lavis.models.albef_models.albef_outputs import ( + AlbefIntermediateOutput, + AlbefOutput, + AlbefSimilarity, +) +from lavis.models.base_model import MomentumDistilationMixin, SharedQueueMixin +from lavis.models.med import XBertEncoder +from lavis.models.vit import VisionTransformerEncoder +from torch import nn + + +@registry.register_model("albef_retrieval") +class AlbefRetrieval(AlbefBase, MomentumDistilationMixin, SharedQueueMixin): + """ + ALBEF retrieval model. + + Supported model types: + - coco: fine-tuned ALBEF base model on COCO dataset (Karparthy split). + - flickr: fine-tuned ALBEF base model on Flickr30k dataset. + + Usage: + >>> from lavis.models import load_model + >>> model = load_model("albef_retrieval", "coco") + >>> model = load_model("albef_retrieval", "flickr") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "coco": "configs/models/albef_retrieval_coco.yaml", + "flickr": "configs/models/albef_retrieval_flickr.yaml", + } + + def __init__( + self, + image_encoder, + text_encoder, + queue_size, + embed_dim=256, + temp=0.07, + use_distill=True, + momentum=0.995, + alpha=0.4, + max_txt_len=30, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder = image_encoder + self.text_encoder = text_encoder + + text_width = text_encoder.config.hidden_size + vision_width = image_encoder.vision_width + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + # create the momentum encoder + self.visual_encoder_m = deepcopy(self.visual_encoder) + self.text_encoder_m = deepcopy(self.text_encoder) + + self.vision_proj_m = deepcopy(self.vision_proj) + self.text_proj_m = deepcopy(self.text_proj) + + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.vision_proj, self.vision_proj_m], + [self.text_proj, self.text_proj_m], + ] + self.copy_params() + + # create the queue + self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("idx_queue", torch.full((1, queue_size), -100)) + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + self.image_queue = nn.functional.normalize(self.image_queue, dim=0) + self.text_queue = nn.functional.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + self.momentum = momentum + self.temp = nn.Parameter(temp * torch.ones([])) + + self.alpha = alpha + self.max_txt_len = max_txt_len + self.use_distill = use_distill + + def _rampup_factor(self, epoch, iters, num_iters_per_epoch): + return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch)) + + def forward(self, samples): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). The input images. + - text_input (list): A list of length batch_size, each element is a string of text/caption. + - image_id (torch.Tensor): A tensor of shape (batch_size, ). The image ids, used to identify same images in batch. + - epoch (int): The current epoch. + - iters (int): The current iteration. + - num_iters_per_epoch (int): The number of iterations per epoch. + + Returns: + BlipOutput: A BlipOutput object. See ``lavis.models.blip_models.blip_outputs.BlipOutput`` for more details. + + Examples: + >>> import torch + >>> from lavis.models import load_model + >>> model = load_model("albef_retrieval", "coco") + >>> images = torch.randn(4, 3, 384, 384) + >>> text_input = ["caption of image 1", "another caption of image 1", "caption of image 2", "caption of image 3"] + >>> image_id = torch.tensor([1, 1, 2, 3]) + >>> samples = {"image": images, "text_input": text_input, "image_id": image_id, "epoch": 0, "iters": 0, "num_iters_per_epoch": 100} + >>> output = model(samples) + >>> output.keys() + odict_keys(['sims', 'intermediate_output', 'loss', 'loss_itc', 'loss_itm']) + """ + image = samples["image"] + caption = samples["text_input"] + idx = samples["image_id"] + + alpha = self.alpha * self._rampup_factor( + epoch=samples["epoch"], + iters=samples["iters"], + num_iters_per_epoch=samples["num_iters_per_epoch"], + ) + + with torch.no_grad(): + self.temp.clamp_(0.001, 0.5) + + image_embeds = self.visual_encoder.forward_features(image) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + + image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) + + text = self.tokenizer( + caption, + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + + text_output = self.text_encoder.forward_text(text) + + text_embeds = text_output.last_hidden_state + text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1) + + idx = idx.view(-1, 1) + idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1) + pos_idx = torch.eq(idx, idx_all).float() + sim_targets = pos_idx / pos_idx.sum(1, keepdim=True) + + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m(image) + image_feat_m = F.normalize( + self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1 + ) + image_feat_all = torch.cat( + [image_feat_m.t(), self.image_queue.clone().detach()], dim=1 + ) + text_output_m = self.text_encoder_m.forward_text(text) + text_embeds_m = text_output_m.last_hidden_state + text_feat_m = F.normalize(self.text_proj_m(text_embeds_m[:, 0, :]), dim=-1) + text_feat_all = torch.cat( + [text_feat_m.t(), self.text_queue.clone().detach()], dim=1 + ) + + if self.use_distill: + sim_i2t_m = image_feat_m @ text_feat_all / self.temp + sim_t2i_m = text_feat_m @ image_feat_all / self.temp + + sim_i2t_targets = ( + alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets + ) + sim_t2i_targets = ( + alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets + ) + + sim_i2t = image_feat @ text_feat_all / self.temp + sim_t2i = text_feat @ image_feat_all / self.temp + + if self.use_distill: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1 + ).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1 + ).mean() + else: + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1 + ).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1 + ).mean() + + loss_itc = (loss_i2t + loss_t2i) / 2 + + self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx) + + encoder_output_pos = self.text_encoder( + encoder_embeds=text_embeds, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + mode="fusion", + ) + + with torch.no_grad(): + bs = image.size(0) + weights_i2t = F.softmax(sim_i2t[:, :bs] + 1e-4, dim=1) + weights_t2i = F.softmax(sim_t2i[:, :bs] + 1e-4, dim=1) + + mask = torch.eq(idx, idx.T) + weights_i2t.masked_fill_(mask, 0) + weights_t2i.masked_fill_(mask, 0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_embeds_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_embeds_neg.append(text_embeds[neg_idx]) + text_atts_neg.append(text.attention_mask[neg_idx]) + text_embeds_neg = torch.stack(text_embeds_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) + text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0) + + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) + image_atts_all = torch.cat([image_atts, image_atts], dim=0) + + encoder_output_neg = self.text_encoder( + encoder_embeds=text_embeds_all, + attention_mask=text_atts_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + mode="fusion", + ) + + vl_embeddings = torch.cat( + [ + encoder_output_pos.last_hidden_state[:, 0, :], + encoder_output_neg.last_hidden_state[:, 0, :], + ], + dim=0, + ) + itm_logits = self.itm_head(vl_embeddings) + + itm_labels = torch.cat( + [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)], + dim=0, + ).to(self.device) + loss_itm = F.cross_entropy(itm_logits, itm_labels) + + return AlbefOutput( + loss=loss_itc + loss_itm, + loss_itc=loss_itc, + loss_itm=loss_itm, + sims=AlbefSimilarity( + sim_i2t=sim_i2t, + sim_t2i=sim_t2i, + sim_i2t_m=sim_i2t_m, + sim_t2i_m=sim_t2i_m, + sim_i2t_targets=sim_i2t_targets, + sim_t2i_targets=sim_t2i_targets, + ), + intermediate_output=AlbefIntermediateOutput( + image_embeds=image_embeds, + image_embeds_m=image_embeds_m, + text_embeds=text_embeds, + text_embeds_m=text_embeds_m, + encoder_output=encoder_output_pos, + encoder_output_neg=encoder_output_neg, + itm_logits=itm_logits, + itm_labels=itm_labels, + ), + ) + + @classmethod + def from_config(cls, cfg=None): + image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=False) + text_encoder = XBertEncoder.from_config(cfg) + + embed_dim = cfg.get("embed_dim", 256) + momentum = cfg.get("momentum", 0.995) + alpha = cfg.get("alpha", 0.4) + temp = cfg.get("temp", 0.07) + max_txt_len = cfg.get("max_txt_len", 30) + queue_size = cfg.get("queue_size", 0) + use_distill = cfg.get("use_distill", True) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + queue_size=queue_size, + embed_dim=embed_dim, + temp=temp, + momentum=momentum, + alpha=alpha, + max_txt_len=max_txt_len, + use_distill=use_distill, + ) + + model.load_checkpoint_from_config(cfg) + + return model + + def compute_sim_matrix(self, data_loader, task_cfg): + """ + Compute similarity i2t, t2i matrix for the given data loader. + """ + k_test = task_cfg.k_test + + return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test) diff --git a/lavis/models/albef_models/albef_vqa.py b/lavis/models/albef_models/albef_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4dcbb9cd34a28637d3420ef4bdad5be47563b3 --- /dev/null +++ b/lavis/models/albef_models/albef_vqa.py @@ -0,0 +1,442 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os +from copy import deepcopy + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.common.utils import get_abs_path, is_url +from lavis.models.albef_models import AlbefBase +from lavis.models.albef_models.albef_outputs import AlbefIntermediateOutput, AlbefOutput +from lavis.models.base_model import MomentumDistilationMixin, tile +from lavis.models.med import BertConfig, BertLMHeadModel, XBertEncoder +from lavis.models.vit import VisionTransformerEncoder, interpolate_pos_embed +from lavis.common.dist_utils import download_cached_file + + +@registry.register_model("albef_vqa") +class AlbefVQA(AlbefBase, MomentumDistilationMixin): + """ + ALBEF VQA models. + + Supported model types: + - base: vqa model initialized with pre-trained ALBEF base model on 115M image-text pairs after CapFilt; not fine-tuned. + - vqav2: fine-tuned ALBEF base model on VQA v2.0 dataset. + + Usage: + >>> from lavis.models import load_model + >>> model = load_model("albef_vqa", "vqav2") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "vqav2": "configs/models/albef_vqav2.yaml", + } + + def __init__( + self, + image_encoder, + text_encoder, + text_decoder, + use_distill=True, + momentum=0.995, + alpha=0.4, + max_txt_len=35, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + self.max_txt_len = max_txt_len + + self.use_distill = use_distill + + self.visual_encoder = image_encoder + + self.text_encoder = text_encoder + self.text_decoder = text_decoder + + if self.use_distill: + self.visual_encoder_m = deepcopy(self.visual_encoder) + self.text_encoder_m = deepcopy(self.text_encoder) + self.text_decoder_m = deepcopy(self.text_decoder) + + self.momentum = momentum + self.alpha = alpha + + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.text_decoder, self.text_decoder_m], + ] + + self.copy_params() + + def _rampup_factor(self, epoch, iters, num_iters_per_epoch): + return min(1, (epoch * num_iters_per_epoch + iters) / num_iters_per_epoch) + + def forward(self, samples): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480. + - text_input (list): A list of strings, each string is a question + - answer (list): A list of strings, each string is an answer + - weight (torch.Tensor): A tensor used to weigh each answer in the loss computation. + The shape of the tensor is (sum(n_answers),) + - n_answers (torch.Tensor): A tensor shape (batch_size,) containing the number of answers + for each question in the batch. + + Returns: + An AlbefOutput object containing loss and intermediate outputs; + see lavis/models/albef_models/albef_outputs.py for more details. + + Examples: + >>> import torch + >>> from lavis.models import load_model + >>> model = load_model("albef_vqa") + >>> samples = { + ... "image": torch.rand(2, 3, 384, 384), + ... "text_input": ["What is this?", "What is that?"], + ... "answer": ["cat", "cat", "dog"], + ... "weight": torch.tensor([1.0, 1.0, 1.0]), + ... "n_answers": torch.tensor([2, 1]), + ... "epoch": 0, "iters": 0, "num_iters_per_epoch": 1000, + ... } + >>> output = model(samples) + >>> output.keys() + odict_keys(['intermediate_output', 'loss']) + """ + ( + encoder_output, + encoder_output_m, + image_embeds, + image_embeds_m, + ) = self.forward_encoder(samples) + loss, decoder_output, decoder_targets = self.forward_decoder( + samples, encoder_out=(encoder_output, encoder_output_m) + ) + + return AlbefOutput( + loss=loss, + intermediate_output=AlbefIntermediateOutput( + image_embeds=image_embeds, + image_embeds_m=image_embeds_m, + encoder_output=encoder_output, + encoder_output_m=encoder_output_m, + decoder_output=decoder_output, + decoder_labels=decoder_targets, + ), + ) + + def forward_encoder(self, samples): + questions = samples["text_input"] + questions = self.tokenizer( + questions, + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + samples.update({"tokenized_text": questions}) + + image_embeds = self.visual_encoder.forward_features(samples["image"]) + encoder_output = self.text_encoder.forward_automask( + tokenized_text=samples["tokenized_text"], visual_embeds=image_embeds + ) + + if self.use_distill: + self._momentum_update() + with torch.no_grad(): + image_embeds_m = self.visual_encoder_m(samples["image"]) + encoder_output_m = self.text_encoder_m.forward_automask( + tokenized_text=samples["tokenized_text"], + visual_embeds=image_embeds_m, + ) + else: + encoder_output_m = None + image_embeds_m = None + + return encoder_output, encoder_output_m, image_embeds, image_embeds_m + + def forward_decoder(self, samples, encoder_out, **kwargs): + answers = self.tokenizer( + samples["answer"], padding="longest", return_tensors="pt" + ).to(self.device) + answer_targets = answers.input_ids.masked_fill( + answers.input_ids == self.tokenizer.pad_token_id, -100 + ) + + question_states = [] + question_atts = [] + + question = samples["tokenized_text"] + question_output, question_output_m = encoder_out + + for b, n in enumerate(samples["n_answers"]): + question_states += [question_output.last_hidden_state[b]] * n + question_atts += [question.attention_mask[b]] * n + + question_states = torch.stack(question_states, dim=0) + question_atts = torch.stack(question_atts, dim=0) + + if self.use_distill: + with torch.no_grad(): + question_states_m = [] + for b, n in enumerate(samples["n_answers"]): + question_states_m += [question_output_m.last_hidden_state[b]] * n + question_states_m = torch.stack(question_states_m, 0) + + logits_m = self.text_decoder_m( + answers.input_ids, + attention_mask=answers.attention_mask, + encoder_hidden_states=question_states_m, + encoder_attention_mask=question_atts, + return_logits=True, + ) + + alpha = self.alpha * self._rampup_factor( + epoch=samples["epoch"], + iters=samples["iters"], + num_iters_per_epoch=samples["num_iters_per_epoch"], + ) + + answer_output = self.text_decoder( + answers.input_ids, + attention_mask=answers.attention_mask, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=answer_targets, + soft_labels=F.softmax(logits_m, dim=-1), + alpha=alpha, + return_dict=True, + reduction="none", + ) + + loss = samples["weight"] * answer_output.loss + bsz = samples["image"].size(0) + + loss = loss.sum() / bsz + + return loss, answer_output, answer_targets + + def predict_answers(self, samples, answer_list, num_ans_candidates=128, **kwargs): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480. + - text_input (str or [str]): String or a list of strings, each string is a question. + The number of questions must be equal to the batch size. If a single string, will be converted to a list of string, with length 1 first. + num_ans_candidates (int): Number of answer candidates, used to filter out answers with low probability. + answer_list (list): A list of strings, each string is an answer. + + Returns: + List: A list of strings, each string is an answer. + + Examples: + >>> from PIL import Image + >>> from lavis.models import load_model_and_preprocess + >>> model, vis_processors, txt_processors = load_model_and_preprocess("albef_vqa", "vqav2") + >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB") + >>> question = "Which city is this photo taken?" + >>> image = vis_processors["eval"](raw_image).unsqueeze(0) + >>> question = txt_processors["eval"](question) + >>> samples = {"image": image, "text_input": [question]} + >>> answer_list = ["Singapore", "London", "Palo Alto", "Tokyo"] + >>> answers = model.predict_answers(samples, answer_list=answer_list) + >>> answers + ['Singapore'] + """ + + if isinstance(samples["text_input"], str): + samples["text_input"] = [samples["text_input"]] + + assert len(samples["text_input"]) == samples["image"].size( + 0 + ), "The number of questions must be equal to the batch size." + + num_ans_candidates = min(num_ans_candidates, len(answer_list)) + + return self.rank_answers( + samples, answer_list=answer_list, num_ans_candidates=num_ans_candidates + ) + + def rank_answers(self, samples, answer_list, num_ans_candidates): + """ + Generate the first token of answers using decoder and select ${num_ans_candidates} + most probable ones. Then select answers from answer list, which start with the probable tokens. + Lastly, use the selected answers as the ground-truth labels for decoding and calculating LM loss. + Return the answers that minimize the losses as result. + + """ + answer_candidates = self.tokenizer( + answer_list, padding="longest", return_tensors="pt" + ).to(self.device) + # answer_candidates.input_ids[:, 0] = self.tokenizer.bos_token_id + + answer_ids = answer_candidates.input_ids + answer_atts = answer_candidates.attention_mask + + question_output, _, _, _ = self.forward_encoder(samples) + question_states = question_output.last_hidden_state + + tokenized_question = samples["tokenized_text"] + question_atts = tokenized_question.attention_mask + + num_ques = question_states.size(0) + start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token + + start_output = self.text_decoder( + start_ids, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + return_dict=True, + reduction="none", + ) + logits = start_output.logits[:, 0, :] # first token's logit + + # topk_probs: top-k probability + # topk_ids: [num_question, k] + answer_first_token = answer_ids[:, 1] + prob_first_token = F.softmax(logits, dim=1).index_select( + dim=1, index=answer_first_token + ) + topk_probs, topk_ids = prob_first_token.topk(num_ans_candidates, dim=1) + + # answer input: [num_question*k, answer_len] + input_ids = [] + input_atts = [] + for b, topk_id in enumerate(topk_ids): + input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) + input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) + input_ids = torch.cat(input_ids, dim=0) + input_atts = torch.cat(input_atts, dim=0) + + targets_ids = input_ids.masked_fill( + input_ids == self.tokenizer.pad_token_id, -100 + ) + + # repeat encoder's output for top-k answers + question_states = tile(question_states, 0, num_ans_candidates) + question_atts = tile(question_atts, 0, num_ans_candidates) + + output = self.text_decoder( + input_ids, + attention_mask=input_atts, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=targets_ids, + return_dict=True, + reduction="none", + ) + + log_probs_sum = -output.loss + log_probs_sum = log_probs_sum.view(num_ques, num_ans_candidates) + + max_topk_ids = log_probs_sum.argmax(dim=1) + max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids] + + answers = [answer_list[max_id] for max_id in max_ids] + + return answers + + @classmethod + def from_config(cls, cfg=None): + image_encoder = VisionTransformerEncoder.from_config(cfg) + + text_encoder = XBertEncoder.from_config(cfg) + + config_decoder = BertConfig.from_json_file(get_abs_path(cfg["med_config_path"])) + config_decoder.fusion_layer = 0 + config_decoder.num_hidden_layers = 6 + text_decoder = BertLMHeadModel.from_pretrained( + "bert-base-uncased", config=config_decoder + ) + + alpha = cfg.get("alpha", 0.4) + momentum = cfg.get("momentum", 0.995) + use_distill = cfg.get("use_distill", True) + max_txt_len = cfg.get("max_txt_len", 25) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + text_decoder=text_decoder, + use_distill=use_distill, + momentum=momentum, + alpha=alpha, + max_txt_len=max_txt_len, + ) + + # load pre-trained weights + model.load_checkpoint_from_config(cfg) + + return model + + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + if "model" in checkpoint: + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + # reshape positional embedding to accomodate for image resolution change + pos_embed_reshaped = interpolate_pos_embed( + state_dict["visual_encoder.pos_embed"], self.visual_encoder + ) + state_dict["visual_encoder.pos_embed"] = pos_embed_reshaped + + m_pos_embed_reshaped = interpolate_pos_embed( + state_dict["visual_encoder_m.pos_embed"], self.visual_encoder_m + ) + state_dict["visual_encoder_m.pos_embed"] = m_pos_embed_reshaped + + for key in list(state_dict.keys()): + if "bert" in key: + encoder_key = key.replace("bert.", "") + state_dict[encoder_key] = state_dict[key] + + # intialize text decoder as multimodal encoder (last 6 layers of model.text_encoder) + if "text_encoder" in key: + if "layer" in key: + encoder_keys = key.split(".") + layer_num = int(encoder_keys[4]) + + if layer_num < 6: + del state_dict[key] + continue + else: + decoder_layer_num = layer_num - 6 + encoder_keys[4] = str(decoder_layer_num) + encoder_key = ".".join(encoder_keys) + else: + encoder_key = key + decoder_key = encoder_key.replace("text_encoder", "text_decoder") + state_dict[decoder_key] = state_dict[key] + + del state_dict[key] + + for key in self.state_dict().keys(): + if key in state_dict.keys(): + if state_dict[key].shape != self.state_dict()[key].shape: + del state_dict[key] + + msg = self.load_state_dict(state_dict, strict=False) + logging.info("load checkpoint from %s" % url_or_filename) + logging.info(f"missing keys: {msg.missing_keys}") + + return msg diff --git a/lavis/models/alpro_models/__init__.py b/lavis/models/alpro_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1dfd29514a54a24857c775b461a8937243c06784 --- /dev/null +++ b/lavis/models/alpro_models/__init__.py @@ -0,0 +1,103 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os + +import torch +import torch.nn.functional as F +from lavis.common.dist_utils import download_cached_file +from lavis.common.utils import is_url +from lavis.models.base_model import BaseModel +from transformers import BertTokenizer + + +class AlproBase(BaseModel): + @classmethod + def init_tokenizer(cls): + return BertTokenizer.from_pretrained("bert-base-uncased") + + def load_from_pretrained(self, url_or_filename, num_frames, num_patches): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + if "model" in checkpoint: + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + for key in list(state_dict.keys()): + if "bert" in key: + new_key = key.replace("bert.", "") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + spatial_embed_key = "visual_encoder.model.pos_embed" + temporal_embed_key = "visual_encoder.model.time_embed" + + ## Resizing spatial embeddings in case they don't match + if num_patches + 1 != state_dict[spatial_embed_key].size(1): + state_dict[spatial_embed_key] = resize_spatial_embedding( + state_dict, spatial_embed_key, num_patches + ) + else: + logging.info( + "The length of spatial position embedding matches. No need to resize." + ) + + ## Resizing time embeddings in case they don't match + if temporal_embed_key in state_dict and num_frames != state_dict[ + temporal_embed_key + ].size(1): + state_dict[temporal_embed_key] = resize_temporal_embedding( + state_dict, temporal_embed_key, num_frames + ) + else: + logging.info( + "No temporal encoding found. Or the length of temporal position embedding matches. No need to resize." + ) + + msg = self.load_state_dict(state_dict, strict=False) + logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + +def resize_spatial_embedding(state_dict, key, num_patches): + logging.info( + f"Resizing spatial position embedding from {state_dict[key].size(1)} to {num_patches + 1}" + ) + + pos_embed = state_dict[key] + + cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1) + other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2) + + new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode="nearest") + new_pos_embed = new_pos_embed.transpose(1, 2) + new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) + + return new_pos_embed + + +def resize_temporal_embedding(state_dict, key, num_frames): + logging.info( + f"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}" + ) + + time_embed = state_dict[key].transpose(1, 2) + new_time_embed = F.interpolate(time_embed, size=(num_frames), mode="nearest") + + return new_time_embed.transpose(1, 2) diff --git a/lavis/models/alpro_models/alpro_outputs.py b/lavis/models/alpro_models/alpro_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..68a11a9cfbd95c866597cf0e8d5a126134587de6 --- /dev/null +++ b/lavis/models/alpro_models/alpro_outputs.py @@ -0,0 +1,59 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.modeling_outputs import ( + BaseModelOutputWithPoolingAndCrossAttentions, + ModelOutput, +) + + +@dataclass +class AlproSimilarity(ModelOutput): + sim_v2t: torch.FloatTensor = None + sim_t2v: torch.FloatTensor = None + + sim_v2t_targets: Optional[torch.FloatTensor] = None + sim_t2v_targets: Optional[torch.FloatTensor] = None + + +@dataclass +class AlproIntermediateOutput(ModelOutput): + # uni-modal features + video_embeds: torch.FloatTensor = None + text_embeds: Optional[torch.FloatTensor] = None + + # intermediate outputs of multimodal encoder + encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + + vtm_logits: Optional[torch.FloatTensor] = None + vtm_labels: Optional[torch.LongTensor] = None + + +@dataclass +class AlproOutput(ModelOutput): + # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. + sims: Optional[AlproSimilarity] = None + + intermediate_output: AlproIntermediateOutput = None + + loss: Optional[torch.FloatTensor] = None + + loss_vtc: Optional[torch.FloatTensor] = None + + loss_vtm: Optional[torch.FloatTensor] = None + + loss_mlm: Optional[torch.FloatTensor] = None + + +@dataclass +class AlproOutputWithLogits(AlproOutput): + logits: torch.FloatTensor = None diff --git a/lavis/models/alpro_models/alpro_qa.py b/lavis/models/alpro_models/alpro_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..2a931be0e23f2c218431288b8390f7a3304702c8 --- /dev/null +++ b/lavis/models/alpro_models/alpro_qa.py @@ -0,0 +1,141 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from warnings import warn + +import torch +import torch.nn.functional as F +from lavis.common.config import node_to_dict +from lavis.common.registry import registry +from lavis.models.alpro_models import AlproBase +from lavis.models.alpro_models.alpro_outputs import ( + AlproIntermediateOutput, + AlproOutputWithLogits, +) +from lavis.models.med import XBertEncoder +from lavis.models.timesformer.vit import TimeSformer +from torch import nn + + +@registry.register_model("alpro_qa") +class AlproQA(AlproBase): + PRETRAINED_MODEL_CONFIG_DICT = { + "msrvtt": "configs/models/alpro_qa_msrvtt.yaml", + "msvd": "configs/models/alpro_qa_msvd.yaml", + } + + def __init__( + self, visual_encoder, text_encoder, hidden_size, num_classes, max_txt_len=40 + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder = visual_encoder + + self.text_encoder = text_encoder + + if num_classes > 0: + self.classifier = nn.Sequential( + nn.Linear(hidden_size, hidden_size * 2), + nn.ReLU(True), + nn.Linear(hidden_size * 2, num_classes), + ) + else: + warn(f"num_classes is 0. Initialized {type(self)} without classifier.") + + self.max_txt_len = max_txt_len + + def forward(self, samples, is_train=True): + visual_inputs = samples["video"] + question = samples["text_input"] + targets = samples["answers"] + + # forward text + text = self.tokenizer( + question, + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + + text_output = self.text_encoder.forward_text( + text, + token_type_ids=torch.zeros( + text.input_ids.shape, dtype=torch.long, device=self.device + ), + ) + text_embeds = text_output.last_hidden_state + + # forward visual + # timeSformer asks for (b, c, t, h, w) as input. + video_embeds = self.visual_encoder.forward_features(visual_inputs) + video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + + # forward cross-encoder + attention_mask = torch.cat([text.attention_mask, video_atts], dim=1) + embedding_output = torch.cat([text_embeds, video_embeds], dim=1) + + encoder_output = self.text_encoder( + encoder_embeds=embedding_output, + attention_mask=attention_mask, + return_dict=True, + mode="fusion", + ) + + prediction = self.classifier(encoder_output.last_hidden_state[:, 0, :]) + if is_train: + loss = F.cross_entropy(prediction, targets) + # return {"loss": loss} + return AlproOutputWithLogits( + loss=loss, + intermediate_output=AlproIntermediateOutput( + video_embeds=video_embeds, + text_embeds=text_embeds, + encoder_output=encoder_output, + ), + logits=prediction, + ) + else: + return {"predictions": prediction, "targets": targets} + + def predict(self, samples): + output = self.forward(samples, is_train=False) + return output + + @classmethod + def from_config(cls, cfg): + # vision encoder + visual_encoder_config = node_to_dict(cfg.timesformer) + visual_encoder = TimeSformer(**visual_encoder_config) + + # text encoder + text_encoder = XBertEncoder.from_config(cfg) + + num_classes = cfg.get("num_classes", -1) + hidden_size = cfg.get("hidden_size", 768) + + model = cls( + visual_encoder=visual_encoder, + text_encoder=text_encoder, + hidden_size=hidden_size, + num_classes=num_classes, + ) + + num_patches = ( + visual_encoder_config["image_size"] // visual_encoder_config["patch_size"] + ) ** 2 + num_frames = visual_encoder_config["n_frms"] + + model.load_checkpoint_from_config( + cfg, num_frames=num_frames, num_patches=num_patches + ) + + return model diff --git a/lavis/models/alpro_models/alpro_retrieval.py b/lavis/models/alpro_models/alpro_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..f574ad42bdfd2f13b21dc30430b72f9c278d0ced --- /dev/null +++ b/lavis/models/alpro_models/alpro_retrieval.py @@ -0,0 +1,422 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import logging +import time + +import lavis.common.dist_utils as dist_utils +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +from lavis.common.config import node_to_dict +from lavis.common.dist_utils import get_rank +from lavis.common.logger import MetricLogger +from lavis.common.registry import registry +from lavis.models.alpro_models import AlproBase +from lavis.models.alpro_models.alpro_outputs import AlproIntermediateOutput, AlproOutput +from lavis.models.base_model import all_gather_with_grad +from lavis.models.med import XBertEncoder +from lavis.models.timesformer.vit import TimeSformer +from torch import nn + + +@registry.register_model("alpro_retrieval") +class AlproRetrieval(AlproBase): + PRETRAINED_MODEL_CONFIG_DICT = { + "msrvtt": "configs/models/alpro_retrieval_msrvtt.yaml", + "didemo": "configs/models/alpro_retrieval_didemo.yaml", + } + + def __init__( + self, + visual_encoder, + text_encoder, + vision_width=768, + text_width=768, + embed_dim=256, + max_txt_len=35, + temp=0.07, + ): + super().__init__() + + self.temp = nn.Parameter(torch.ones([]) * temp) + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder = visual_encoder + self.text_encoder = text_encoder + + vision_width = vision_width + text_width = text_width + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + self.max_txt_len = max_txt_len + + def forward(self, samples): + with torch.no_grad(): + self.temp.clamp_(0.001, 0.5) + + visual_inputs = samples["video"] + caption = samples["text_input"] + + b, t, c, h, w = visual_inputs.shape + + # forward text + text = self.tokenizer( + caption, + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + + text_output = self.text_encoder.forward_text( + text, + token_type_ids=torch.zeros( + text.input_ids.shape, dtype=torch.long, device=self.device + ), + ) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1) + + # forward visual + # timeSformer asks for (b, c, t, h, w) as input. + video_embeds = self.visual_encoder.forward_features(visual_inputs) + video_feat = F.normalize(self.vision_proj(video_embeds[:, 0, :]), dim=-1) + video_atts = torch.ones(video_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + + # ========== (in-batch) ITC loss ========== + gathered_video_feats = all_gather_with_grad(video_feat) + gathered_text_feats = all_gather_with_grad(text_feat) + + sim_v2t = video_feat @ gathered_text_feats.t() / self.temp + sim_t2v = text_feat @ gathered_video_feats.t() / self.temp + + sim_targets = torch.zeros_like(sim_v2t) + + local_rank = get_rank() + b_start, b_end = b * local_rank, b * (local_rank + 1) + sim_targets[:, b_start:b_end] = torch.eye(b) + + loss_v2t = -torch.sum(F.log_softmax(sim_v2t, dim=1) * sim_targets, dim=1).mean() + loss_t2v = -torch.sum(F.log_softmax(sim_t2v, dim=1) * sim_targets, dim=1).mean() + + vtc_loss = (loss_v2t + loss_t2v) / 2 + + ( + vtm_loss, + vtm_logits, + vtm_labels, + encoder_output, + encoder_output_neg, + ) = self.compute_vtm( + text_embeds=text_embeds, + text_atts=text.attention_mask, + image_embeds=video_embeds, + image_atts=video_atts, + sim_i2t=sim_v2t.clone(), # for hard mining + sim_t2i=sim_t2v.clone(), # for hard mining + ) + + loss = vtc_loss + vtm_loss + + # return {"loss": loss} + return AlproOutput( + loss=loss, + loss_vtc=vtc_loss, + loss_vtm=vtm_loss, + intermediate_output=AlproIntermediateOutput( + video_embeds=video_embeds, + text_embeds=text_embeds, + encoder_output=encoder_output, + encoder_output_neg=encoder_output_neg, + vtm_logits=vtm_logits, + vtm_labels=vtm_labels, + ), + ) + + def compute_vtm( + self, text_embeds, text_atts, image_embeds, image_atts, sim_i2t, sim_t2i + ): + device = self.device + + # ====== positive pairs ======= + attention_mask = torch.cat([text_atts, image_atts], dim=1) + embedding_output_pos = torch.cat([text_embeds, image_embeds], dim=1) + + encoder_outputs_pos = self.text_encoder( + encoder_embeds=embedding_output_pos, + attention_mask=attention_mask, + return_dict=True, + mode="fusion", + ) + + # ====== negative pairs ======= + bs = text_embeds.shape[0] + + local_rank = get_rank() + b_start, b_end = bs * local_rank, bs * (local_rank + 1) + + with torch.no_grad(): + weights_v2t = sim_i2t[:, b_start:b_end] + weights_t2v = sim_t2i[:, b_start:b_end] + + # never select self as negative + weights_v2t.fill_diagonal_(-np.Inf) + weights_t2v.fill_diagonal_(-np.Inf) + + weights_v2t = F.softmax(weights_v2t, dim=1) + weights_t2v = F.softmax(weights_t2v, dim=1) + + # select a negative image for each text + # FIXME to optimize using indexing operations + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2v[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_embeds_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_v2t[b], 1).item() + text_embeds_neg.append(text_embeds[neg_idx]) + text_atts_neg.append(text_atts[neg_idx]) + + text_embeds_neg = torch.stack(text_embeds_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0) + text_atts_all = torch.cat([text_atts, text_atts_neg], dim=0) + + video_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) + video_atts_all = torch.cat([image_atts, image_atts], dim=0) + + attention_mask_all = torch.cat([text_atts_all, video_atts_all], dim=1) + embedding_output_all = torch.cat([text_embeds_all, video_embeds_all], dim=1) + + # forward negative pairs via cross encoder + encoder_outputs_neg = self.text_encoder( + encoder_embeds=embedding_output_all, + attention_mask=attention_mask_all, + return_dict=True, + mode="fusion", + ) + + vl_embeddings = torch.cat( + [ + encoder_outputs_pos.last_hidden_state[:, 0, :], + encoder_outputs_neg.last_hidden_state[:, 0, :], + ], + dim=0, + ) + vtm_logits = self.itm_head(vl_embeddings) + + vtm_labels = torch.cat( + [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)], + dim=0, + ).to(device) + vtm_loss = F.cross_entropy(vtm_logits, vtm_labels) + + return ( + vtm_loss, + vtm_logits, + vtm_labels, + encoder_outputs_pos, + encoder_outputs_neg, + ) + + def compute_sim_matrix(self, data_loader, task_cfg): + k_test = task_cfg.get("k_test") + + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation:" + + logging.info("Computing features for evaluation...") + start_time = time.time() + + texts = data_loader.dataset.text + num_text = len(texts) + text_bs = 256 + text_ids = [] + text_embeds = [] + text_feats = [] + text_atts = [] + for i in range(0, num_text, text_bs): + text = texts[i : min(num_text, i + text_bs)] + text_input = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + text_output = self.text_encoder.forward_text( + text_input, + token_type_ids=torch.zeros( + text_input.input_ids.shape, dtype=torch.long, device=self.device + ), + ) + text_feats.append(text_output.last_hidden_state.cpu()) + text_embed = F.normalize( + self.text_proj(text_output.last_hidden_state[:, 0, :]) + ) + text_embeds.append(text_embed) + text_ids.append(text_input.input_ids) + text_atts.append(text_input.attention_mask) + + text_embeds = torch.cat(text_embeds, dim=0) + text_ids = torch.cat(text_ids, dim=0) + text_atts = torch.cat(text_atts, dim=0) + text_feats = torch.cat(text_feats, dim=0) + + video_feats = [] + video_embeds = [] + for samples in data_loader: + video = samples["video"] + + video = video.to(self.device) + video_feat = self.visual_encoder.forward_features(video) + video_embed = self.vision_proj(video_feat[:, 0, :]) + video_embed = F.normalize(video_embed, dim=-1) + + video_feats.append(video_feat.cpu()) + video_embeds.append(video_embed) + + video_feats = torch.cat(video_feats, dim=0) + video_embeds = torch.cat(video_embeds, dim=0) + + sims_matrix = video_embeds @ text_embeds.t() + score_matrix_v2t = torch.full( + (len(data_loader.dataset.image), len(texts)), -100.0 + ).to(self.device) + + num_tasks = dist_utils.get_world_size() + rank = dist_utils.get_rank() + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + # video-to-text + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + + video_feats_repeat = ( + video_feats[start + i].repeat(k_test, 1, 1).to(self.device) + ) + video_atts_repeat = torch.ones( + video_feats_repeat.size()[:-1], dtype=torch.long + ).to(self.device) + + attention_mask = torch.cat([text_atts[topk_idx], video_atts_repeat], dim=1) + embedding_output = torch.cat( + [text_feats[topk_idx].to(self.device), video_feats_repeat], dim=1 + ) + + output = self.text_encoder( + encoder_embeds=embedding_output, + attention_mask=attention_mask, + return_dict=True, + mode="fusion", + ) + + score = self.itm_head(output.last_hidden_state[:, 0, :])[:, 1] + score_matrix_v2t[start + i, topk_idx] = score + topk_sim + + # text-to-video + sims_matrix = sims_matrix.t() + score_matrix_t2v = torch.full( + (len(texts), len(data_loader.dataset.image)), -100.0 + ).to(self.device) + + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + + text_feats_repeat = ( + text_feats[start + i].repeat(k_test, 1, 1).to(self.device) + ) + text_atts_repeat = text_atts[start + i].repeat(k_test, 1).to(self.device) + + video_atts = torch.ones( + video_feats[topk_idx].size()[:-1], dtype=torch.long + ).to(self.device) + + embedding_output = torch.cat( + [text_feats_repeat, video_feats[topk_idx].to(self.device)], dim=1 + ) + attention_mask = torch.cat([text_atts_repeat, video_atts], dim=1) + + output = self.text_encoder( + encoder_embeds=embedding_output, + attention_mask=attention_mask, + return_dict=True, + mode="fusion", + ) + + score = self.itm_head(output.last_hidden_state[:, 0, :])[:, 1] + score_matrix_t2v[start + i, topk_idx] = score + topk_sim + + if dist_utils.is_dist_avail_and_initialized(): + dist.barrier() + torch.distributed.all_reduce( + score_matrix_v2t, op=torch.distributed.ReduceOp.SUM + ) + torch.distributed.all_reduce( + score_matrix_t2v, op=torch.distributed.ReduceOp.SUM + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Evaluation time {}".format(total_time_str)) + + return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy() + + @classmethod + def from_config(cls, cfg): + # vision encoder + visual_encoder_config = node_to_dict(cfg.timesformer) + visual_encoder = TimeSformer(**visual_encoder_config) + + # text encoder + text_encoder = XBertEncoder.from_config(cfg) + + max_txt_len = cfg.get("max_txt_len", 35) + + model = cls( + visual_encoder=visual_encoder, + text_encoder=text_encoder, + max_txt_len=max_txt_len, + ) + + num_patches = ( + visual_encoder_config["image_size"] // visual_encoder_config["patch_size"] + ) ** 2 + num_frames = visual_encoder_config["n_frms"] + + model.load_checkpoint_from_config( + cfg, num_frames=num_frames, num_patches=num_patches + ) + + return model diff --git a/lavis/models/base_model.py b/lavis/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1a3b3b1e6290c15a634251d118dab37adea30c --- /dev/null +++ b/lavis/models/base_model.py @@ -0,0 +1,247 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os + +import numpy as np +import torch +import torch.nn as nn +from lavis.common.dist_utils import download_cached_file, is_dist_avail_and_initialized +from lavis.common.utils import get_abs_path, is_url +from omegaconf import OmegaConf + + +class BaseModel(nn.Module): + """Base class for models.""" + + def __init__(self): + super().__init__() + + @property + def device(self): + return list(self.parameters())[0].device + + def load_checkpoint(self, url_or_filename): + """ + Load from a finetuned checkpoint. + + This should expect no mismatch in the model keys and the checkpoint keys. + """ + + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + if "model" in checkpoint.keys(): + state_dict = checkpoint["model"] + else: + state_dict = checkpoint + + msg = self.load_state_dict(state_dict, strict=False) + + logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + @classmethod + def from_pretrained(cls, model_type): + """ + Build a pretrained model from default configuration file, specified by model_type. + + Args: + - model_type (str): model type, specifying architecture and checkpoints. + + Returns: + - model (nn.Module): pretrained or finetuned model, depending on the configuration. + """ + model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model + model = cls.from_config(model_cfg) + + return model + + @classmethod + def default_config_path(cls, model_type): + assert ( + model_type in cls.PRETRAINED_MODEL_CONFIG_DICT + ), "Unknown model type {}".format(model_type) + return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) + + def load_checkpoint_from_config(self, cfg, **kwargs): + """ + Load checkpoint as specified in the config file. + + If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model. + When loading the pretrained model, each task-specific architecture may define their + own load_from_pretrained() method. + """ + load_finetuned = cfg.get("load_finetuned", True) + if load_finetuned: + finetune_path = cfg.get("finetuned", None) + assert ( + finetune_path is not None + ), "Found load_finetuned is True, but finetune_path is None." + self.load_checkpoint(url_or_filename=finetune_path) + else: + # load pre-trained weights + pretrain_path = cfg.get("pretrained", None) + assert "Found load_finetuned is False, but pretrain_path is None." + self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs) + + def before_evaluation(self, **kwargs): + pass + + def show_n_params(self, return_str=True): + tot = 0 + for p in self.parameters(): + w = 1 + for x in p.shape: + w *= x + tot += w + if return_str: + if tot >= 1e6: + return "{:.1f}M".format(tot / 1e6) + else: + return "{:.1f}K".format(tot / 1e3) + else: + return tot + + +class BaseEncoder(nn.Module): + """ + Base class for primitive encoders, such as ViT, TimeSformer, etc. + """ + + def __init__(self): + super().__init__() + + def forward_features(self, samples, **kwargs): + raise NotImplementedError + + @property + def device(self): + return list(self.parameters())[0].device + + +class SharedQueueMixin: + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None): + # gather keys before updating queue + image_feats = concat_all_gather(image_feat) + text_feats = concat_all_gather(text_feat) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr : ptr + batch_size] = image_feats.T + self.text_queue[:, ptr : ptr + batch_size] = text_feats.T + + if idxs is not None: + idxs = concat_all_gather(idxs) + self.idx_queue[:, ptr : ptr + batch_size] = idxs.T + + ptr = (ptr + batch_size) % self.queue_size # move pointer + self.queue_ptr[0] = ptr + + +class MomentumDistilationMixin: + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip( + model_pair[0].parameters(), model_pair[1].parameters() + ): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for param, param_m in zip( + model_pair[0].parameters(), model_pair[1].parameters() + ): + param_m.data = param_m.data * self.momentum + param.data * ( + 1.0 - self.momentum + ) + + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [ + torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + torch.distributed.all_reduce(all_gradients) + return all_gradients[torch.distributed.get_rank()] + + +def all_gather_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = torch.distributed.get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + + # tensor_all = GatherLayer.apply(tensors) + tensor_all = GatherLayer.apply(tensors) + + return torch.cat(tensor_all, dim=0) + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + # if use distributed training + if not is_dist_avail_and_initialized(): + return tensor + + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) + ) + return torch.index_select(x, dim, order_index.to(x.device)) diff --git a/lavis/models/blip2_models/Qformer.py b/lavis/models/blip2_models/Qformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e71b12375e10511858a9c505dc795181e6ce5603 --- /dev/null +++ b/lavis/models/blip2_models/Qformer.py @@ -0,0 +1,1216 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/lavis/models/blip2_models/__init__.py b/lavis/models/blip2_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lavis/models/blip2_models/blip2.py b/lavis/models/blip2_models/blip2.py new file mode 100644 index 0000000000000000000000000000000000000000..2259e1ac8be46f5523ae02ff5100540681fa5e5a --- /dev/null +++ b/lavis/models/blip2_models/blip2.py @@ -0,0 +1,229 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import contextlib +import logging +import os +import time +import datetime + +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.nn.functional as F + +import lavis.common.dist_utils as dist_utils +from lavis.common.dist_utils import download_cached_file +from lavis.common.utils import is_url +from lavis.common.logger import MetricLogger +from lavis.models.base_model import BaseModel +from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel +from lavis.models.eva_vit import create_eva_vit_g +from lavis.models.clip_vit import create_clip_vit_L +from transformers import BertTokenizer + + +class Blip2Base(BaseModel): + @classmethod + def init_tokenizer(cls): + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + return tokenizer + + def maybe_autocast(self, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = self.device != torch.device("cpu") + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + + @classmethod + def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel.from_pretrained( + "bert-base-uncased", config=encoder_config + ) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens + + @classmethod + def init_vision_encoder( + cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision + ): + assert model_name in [ + "eva_clip_g", + "clip_L", + ], "vit model must be eva_clip_g or clip_L" + if model_name == "eva_clip_g": + visual_encoder = create_eva_vit_g( + img_size, drop_path_rate, use_grad_checkpoint, precision + ) + elif model_name == "clip_L": + visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision) + ln_vision = LayerNorm(visual_encoder.num_features) + return visual_encoder, ln_vision + + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + + msg = self.load_state_dict(state_dict, strict=False) + + # logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +def compute_sim_matrix(model, data_loader, **kwargs): + k_test = kwargs.pop("k_test") + + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation:" + + logging.info("Computing features for evaluation...") + start_time = time.time() + + texts = data_loader.dataset.text + num_text = len(texts) + text_bs = 256 + text_ids = [] + text_embeds = [] + text_atts = [] + for i in range(0, num_text, text_bs): + text = texts[i : min(num_text, i + text_bs)] + text_input = model.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=35, + return_tensors="pt", + ).to(model.device) + text_feat = model.forward_text(text_input) + text_embed = F.normalize(model.text_proj(text_feat)) + text_embeds.append(text_embed) + text_ids.append(text_input.input_ids) + text_atts.append(text_input.attention_mask) + + text_embeds = torch.cat(text_embeds, dim=0) + text_ids = torch.cat(text_ids, dim=0) + text_atts = torch.cat(text_atts, dim=0) + + vit_feats = [] + image_embeds = [] + for samples in data_loader: + image = samples["image"] + + image = image.to(model.device) + image_feat, vit_feat = model.forward_image(image) + image_embed = model.vision_proj(image_feat) + image_embed = F.normalize(image_embed, dim=-1) + + vit_feats.append(vit_feat.cpu()) + image_embeds.append(image_embed) + + vit_feats = torch.cat(vit_feats, dim=0) + image_embeds = torch.cat(image_embeds, dim=0) + + sims_matrix = [] + for image_embed in image_embeds: + sim_q2t = image_embed @ text_embeds.t() + sim_i2t, _ = sim_q2t.max(0) + sims_matrix.append(sim_i2t) + sims_matrix = torch.stack(sims_matrix, dim=0) + + score_matrix_i2t = torch.full( + (len(data_loader.dataset.image), len(texts)), -100.0 + ).to(model.device) + + num_tasks = dist_utils.get_world_size() + rank = dist_utils.get_rank() + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) + score = model.compute_itm( + image_inputs=image_inputs, + text_ids=text_ids[topk_idx], + text_atts=text_atts[topk_idx], + ).float() + score_matrix_i2t[start + i, topk_idx] = score + topk_sim + + sims_matrix = sims_matrix.t() + score_matrix_t2i = torch.full( + (len(texts), len(data_loader.dataset.image)), -100.0 + ).to(model.device) + + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + image_inputs = vit_feats[topk_idx.cpu()].to(model.device) + score = model.compute_itm( + image_inputs=image_inputs, + text_ids=text_ids[start + i].repeat(k_test, 1), + text_atts=text_atts[start + i].repeat(k_test, 1), + ).float() + score_matrix_t2i[start + i, topk_idx] = score + topk_sim + + if dist_utils.is_dist_avail_and_initialized(): + dist.barrier() + torch.distributed.all_reduce( + score_matrix_i2t, op=torch.distributed.ReduceOp.SUM + ) + torch.distributed.all_reduce( + score_matrix_t2i, op=torch.distributed.ReduceOp.SUM + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Evaluation time {}".format(total_time_str)) + + return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() diff --git a/lavis/models/blip2_models/blip2_image_text_matching.py b/lavis/models/blip2_models/blip2_image_text_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..f32db24d09e23c92d61453e569d58f1d7da18969 --- /dev/null +++ b/lavis/models/blip2_models/blip2_image_text_matching.py @@ -0,0 +1,116 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.models.blip2_models.blip2_qformer import Blip2Qformer + + +@registry.register_model("blip2_image_text_matching") +class Blip2ITM(Blip2Qformer): + """ + BLIP Image-Text Matching (ITM) model. + Supported model types: + - pretrained: pretrained model + - coco: fintuned model on coco + Usage: + >>> from lavis.models import load_model + >>> model = load_model("blip2_image_text_matching", "pretrained") + >>> model = load_model("blip2_image_text_matching", "coco") + """ + + def __init__( + self, + vit_model="eva_clip_g", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + num_query_token=32, + cross_attention_freq=2, + embed_dim=256, + max_txt_len=32, + ): + super().__init__( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + num_query_token=num_query_token, + cross_attention_freq=cross_attention_freq, + embed_dim=embed_dim, + max_txt_len=max_txt_len, + ) + + def forward(self, samples, match_head="itm"): + image = samples["image"] + caption = samples["text_input"] + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_embeds = image_embeds.float() + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + text = self.tokenizer( + caption, + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + + if match_head == "itm": + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( + image.device + ) + attention_mask = torch.cat([query_atts, text.attention_mask], dim=1) + output_itm = self.Qformer.bert( + text.input_ids, + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + itm_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :] + itm_logit = self.itm_head(itm_embeddings) + itm_logit = itm_logit.mean(dim=1) + + return itm_logit + + elif match_head == "itc": + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + image_feats = F.normalize( + self.vision_proj(query_output.last_hidden_state), dim=-1 + ) + + text_output = self.Qformer.bert( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True, + ) + text_feat = F.normalize( + self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 + ) + + sims = torch.bmm(image_feats, text_feat.unsqueeze(-1)) + sim, _ = torch.max(sims, dim=1) + + return sim diff --git a/lavis/models/blip2_models/blip2_opt.py b/lavis/models/blip2_models/blip2_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..14cb4ea5d6be7e85374f806111140e3f7bb02e23 --- /dev/null +++ b/lavis/models/blip2_models/blip2_opt.py @@ -0,0 +1,272 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import logging + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +from lavis.common.registry import registry +from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train +from lavis.models.blip2_models.modeling_opt import OPTForCausalLM, OPTConfig +from transformers import AutoTokenizer + + +@registry.register_model("blip2_opt") +class Blip2OPT(Blip2Base): + """ + BLIP2 OPT model. + Supported model types: + - pretrained_opt2.7b: pretrained model with OPT2.7b + - pretrained_opt6.7b: pretrained model with OPT6.7b + - caption_coco_opt2.7b: fintuned image captioning model with OPT2.7b + - caption_coco_opt6.7b: fintuned image captioning model with OPT6.7b + Usage: + >>> from lavis.models import load_model + >>> model = load_model("blip2_opt", "caption_coco_opt2.7b") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_opt2.7b": "configs/models/blip2/blip2_pretrain_opt2.7b.yaml", + "pretrain_opt6.7b": "configs/models/blip2/blip2_pretrain_opt6.7b.yaml", + "caption_coco_opt2.7b": "configs/models/blip2/blip2_caption_opt2.7b.yaml", + "caption_coco_opt6.7b": "configs/models/blip2/blip2_caption_opt6.7b.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + num_query_token=32, + opt_model="facebook/opt-2.7b", + prompt="", + max_txt_len=32, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + if freeze_vit: + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + logging.info("freeze vision encoder") + + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + self.opt_tokenizer = AutoTokenizer.from_pretrained(opt_model, use_fast=False) + self.opt_model = OPTForCausalLM.from_pretrained( + opt_model, torch_dtype=torch.float16 + ) + for name, param in self.opt_model.named_parameters(): + param.requires_grad = False + self.eos_token_id = self.opt_tokenizer( + "\n", add_special_tokens=False + ).input_ids[0] + + self.opt_proj = nn.Linear( + self.Qformer.config.hidden_size, self.opt_model.config.hidden_size + ) + + self.max_txt_len = max_txt_len + self.prompt = prompt + prompt_tokens = self.opt_tokenizer(self.prompt, return_tensors="pt") + self.prompt_length = prompt_tokens.attention_mask.sum(1) + + def forward(self, samples): + image = samples["image"] + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_opt = self.opt_proj(query_output.last_hidden_state) + atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(image.device) + + self.opt_tokenizer.padding_side = "right" + + text = [t + "\n" for t in samples["text_input"]] + + opt_tokens = self.opt_tokenizer( + text, + return_tensors="pt", + padding="longest", + truncation=True, + max_length=self.max_txt_len, + ).to(image.device) + + targets = opt_tokens.input_ids.masked_fill( + opt_tokens.input_ids == self.opt_tokenizer.pad_token_id, -100 + ) + if self.prompt: + targets[:, : self.prompt_length] = -100 # do not apply loss to the prompt + + empty_targets = ( + torch.ones(atts_opt.size(), dtype=torch.long).to(image.device).fill_(-100) + ) + targets = torch.cat([empty_targets, targets], dim=1) + + inputs_embeds = self.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids) + inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1) + attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1) + + with self.maybe_autocast(): + outputs = self.opt_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + + return {"loss": loss} + + @torch.no_grad() + def generate( + self, + samples, + use_nucleus_sampling=False, + num_beams=5, + max_length=30, + min_length=1, + top_p=0.9, + repetition_penalty=1.0, + length_penalty=1.0, + num_captions=1, + temperature=1, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling. + num_beams (int): Number of beams for beam search. 1 means no beam search. + max_length (int): The maximum length of the sequence to be generated. + min_length (int): The minimum length of the sequence to be generated. + top_p (float): The cumulative probability for nucleus sampling. + repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. + num_captions (int): Number of captions to be generated for each image. + Returns: + captions (list): A list of strings of length batch_size * num_captions. + """ + image = samples["image"] + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_opt = self.opt_proj(query_output.last_hidden_state) + atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to( + image.device + ) + + if "prompt" in samples.keys(): + prompt = samples["prompt"] + else: + prompt = self.prompt + + prompt = [prompt] * image.size(0) + + opt_tokens = self.opt_tokenizer(prompt, return_tensors="pt").to( + image.device + ) + input_ids = opt_tokens.input_ids + attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1) + + if use_nucleus_sampling: + query_embeds = inputs_opt.repeat_interleave(num_captions, dim=0) + num_beams = 1 + else: + query_embeds = inputs_opt.repeat_interleave(num_beams, dim=0) + + outputs = self.opt_model.generate( + input_ids=input_ids, + query_embeds=query_embeds, + attention_mask=attention_mask, + do_sample=use_nucleus_sampling, + top_p=top_p, + temperature=temperature, + num_beams=num_beams, + max_new_tokens=max_length, + min_length=min_length, + eos_token_id=self.eos_token_id, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + num_return_sequences=num_captions, + ) + + prompt_length = opt_tokens.input_ids.shape[1] + output_text = self.opt_tokenizer.batch_decode( + outputs[:, prompt_length:], skip_special_tokens=True + ) + output_text = [text.strip() for text in output_text] + return output_text + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + opt_model = cfg.get("opt_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + + prompt = cfg.get("prompt", "") + max_txt_len = cfg.get("max_txt_len", 32) + + model = cls( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + num_query_token=num_query_token, + opt_model=opt_model, + prompt=prompt, + max_txt_len=max_txt_len, + ) + model.load_checkpoint_from_config(cfg) + + return model diff --git a/lavis/models/blip2_models/blip2_qformer.py b/lavis/models/blip2_models/blip2_qformer.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb078042384fa6db72c0312d06094e04be51173 --- /dev/null +++ b/lavis/models/blip2_models/blip2_qformer.py @@ -0,0 +1,518 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import logging + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.cuda.amp import autocast as autocast +from torch.nn import functional as F + +from lavis.common.registry import registry +from lavis.models.base_model import all_gather_with_grad, concat_all_gather +from lavis.models.blip2_models.blip2 import ( + Blip2Base, + compute_sim_matrix, + disabled_train, +) +from lavis.models.blip_models.blip_outputs import BlipOutput, BlipOutputFeatures + + +@registry.register_model("blip2") +@registry.register_model("blip2_feature_extractor") +class Blip2Qformer(Blip2Base): + """ + BLIP2 first-stage model with Q-former and ViT. + Supported model types: + - pretrained: pretrained model with vit-g + - pretrain_vitL: pretrained model with vit-large + - coco: fintuned model on coco + Usage: + >>> from lavis.models import load_model + >>> model = load_model("blip2", "pretrain") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain": "configs/models/blip2/blip2_pretrain.yaml", + "pretrain_vitL": "configs/models/blip2/blip2_pretrain_vitL.yaml", + "coco": "configs/models/blip2/blip2_coco.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + num_query_token=32, + cross_attention_freq=2, + embed_dim=256, + max_txt_len=32, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + if freeze_vit: + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + logging.info("freeze vision encoder") + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features, cross_attention_freq + ) + self.Qformer.resize_token_embeddings(len(self.tokenizer)) + state_dict = self.Qformer.state_dict() + for name, param in self.Qformer.named_parameters(): + if "_query" in name: + key_orig = name.replace("_query", "") + param.data.copy_(state_dict[key_orig]) + + self.vision_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim) + self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim) + + self.itm_head = nn.Linear(self.Qformer.config.hidden_size, 2) + + self.temp = nn.Parameter(0.07 * torch.ones([])) + + self.max_txt_len = max_txt_len + + def forward(self, samples): + image = samples["image"] + text = samples["text_input"] + + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + use_cache=True, + return_dict=True, + ) + + image_feats = F.normalize( + self.vision_proj(query_output.last_hidden_state), dim=-1 + ) + + text_tokens = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + text_output = self.Qformer.bert( + text_tokens.input_ids, + attention_mask=text_tokens.attention_mask, + return_dict=True, + ) + text_feat = F.normalize( + self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 + ) + + ###============== Image-text Contrastive ===================### + image_feats_all = concat_all_gather( + image_feats + ) # [batch_size*num_gpu, num_query_tokens, embed_dim] + text_feat_all = concat_all_gather(text_feat) # [batch_size*num_gpu, embed_dim] + + sim_q2t = torch.matmul( + image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1) + ).squeeze() + # [batch_size, batch_size*num_gpu, num_query_tokens] + + # image-text similarity: aggregate across all query tokens + sim_i2t, _ = sim_q2t.max(-1) + sim_i2t = sim_i2t / self.temp + + # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens] + sim_t2q = torch.matmul( + text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1) + ).squeeze() + + # text-image similarity: aggregate across all query tokens + sim_t2i, _ = sim_t2q.max(-1) + sim_t2i = sim_t2i / self.temp # [batch_size, batch_size*num_gpu] + + rank = dist.get_rank() + bs = image.size(0) + targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( + image.device + ) + + loss_itc = ( + F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1) + ) / 2 + + ###============== Image-text Matching ===================### + text_input_ids_world = concat_all_gather(text_tokens.input_ids) + text_attention_mask_world = concat_all_gather(text_tokens.attention_mask) + image_embeds_world = all_gather_with_grad(image_embeds) + with torch.no_grad(): + weights_t2i = F.softmax(sim_t2i, dim=1) + 1e-4 + weights_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(0) + weights_i2t = F.softmax(sim_i2t, dim=1) + 1e-4 + weights_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(text_input_ids_world[neg_idx]) + text_atts_neg.append(text_attention_mask_world[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat( + [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0 + ) # pos, pos, neg + text_atts_all = torch.cat( + [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg], + dim=0, + ) + + query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1) + query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to( + image.device + ) + attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1) + + image_embeds_all = torch.cat( + [image_embeds, image_embeds_neg, image_embeds], dim=0 + ) # pos, neg, pos + image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to( + image.device + ) + + output_itm = self.Qformer.bert( + text_ids_all, + query_embeds=query_tokens_itm, + attention_mask=attention_mask_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + ) + + vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] + vl_output = self.itm_head(vl_embeddings) + logits = vl_output.mean(dim=1) + + itm_labels = torch.cat( + [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)], + dim=0, + ).to(image.device) + loss_itm = F.cross_entropy(logits, itm_labels) + + ##================= Image Captioning ========================## + decoder_input_ids = text_tokens.input_ids.clone() + decoder_input_ids[:, 0] = self.tokenizer.bos_token_id + labels = decoder_input_ids.masked_fill( + decoder_input_ids == self.tokenizer.pad_token_id, -100 + ) + + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( + image.device + ) + attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1) + lm_output = self.Qformer( + decoder_input_ids, + attention_mask=attention_mask, + past_key_values=query_output.past_key_values, + return_dict=True, + labels=labels, + ) + + loss_lm = lm_output.loss + + return BlipOutput( + loss=loss_itc + loss_itm + loss_lm, + loss_itc=loss_itc, + loss_itm=loss_itm, + loss_lm=loss_lm, + ) + + @torch.no_grad() + def generate( + self, + samples, + use_nucleus_sampling=False, + num_beams=3, + max_length=30, + min_length=10, + top_p=0.9, + repetition_penalty=1.0, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling. + num_beams (int): Number of beams for beam search. 1 means no beam search. + max_length (int): The maximum length of the sequence to be generated. + min_length (int): The minimum length of the sequence to be generated. + top_p (float): The cumulative probability for nucleus sampling. + repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. + num_captions (int): Number of captions to be generated for each image. + Returns: + captions (list): A list of strings of length batch_size * num_captions. + """ + image = samples["image"] + image_embeds = self.ln_vision(self.visual_encoder(image)) + + if not use_nucleus_sampling: + image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) + else: + num_beams = 1 + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + model_kwargs = { + "encoder_hidden_states": image_embeds, + "encoder_attention_mask": image_atts, + } + + input_ids = ( + torch.LongTensor(image.size(0), 1) + .fill_(self.tokenizer.bos_token_id) + .to(image.device) + ) + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + + outputs = self.Qformer.generate( + input_ids=input_ids, + query_embeds=query_tokens, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + do_sample=use_nucleus_sampling, + top_p=top_p, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + **model_kwargs + ) + captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + return captions + + def forward_image(self, image): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + return query_output.last_hidden_state, image_embeds + + def forward_text(self, text_tokens): + text_output = self.Qformer.bert( + text_tokens.input_ids, + attention_mask=text_tokens.attention_mask, + return_dict=True, + ) + return text_output.last_hidden_state[:, 0, :] + + def compute_itm(self, image_inputs, text_ids, text_atts): + image_atts = torch.ones(image_inputs.size()[:-1], dtype=torch.long).to( + image_inputs.device + ) + query_tokens = self.query_tokens.expand(image_inputs.shape[0], -1, -1) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( + image_inputs.device + ) + attention_mask = torch.cat([query_atts, text_atts], dim=1) + output_itm = self.Qformer.bert( + text_ids, + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=image_inputs, + encoder_attention_mask=image_atts, + return_dict=True, + ) + vl_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :] + itm_logit = self.itm_head(vl_embeddings) + itm_logit = itm_logit[:, :, 1].mean(dim=1) + return itm_logit + + @torch.no_grad() + def extract_features(self, samples, mode="multimodal"): + """ + Extract features for multimodal or unimodal samples. + Args: + samples (dict): A dictionary of samples, containing the following keys: + - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image. + Raw images should be preprocessed before being passed to feature extractor. + - text_input (list): A list of strings containing the text, length B. + mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image". + If "multimodal", return image features and multimodal features; + if "text", return text features; + if "image", return image features. + Default: "multimodal". + Returns: + BlipOutputFeatures: A BlipOutputFeatures object containing the features. + See lavis/models/blip_models/blip_outputs.py for more details. + """ + image = samples.get("image") + caption = samples.get("text_input") + + # assert mode is one of "image", "text", "multimodal" + assert mode in [ + "image", + "text", + "multimodal", + ], "mode must be one of 'image', 'text', 'multimodal'" + + # initalize output + image_embeds, text_embeds, multimodal_embeds = None, None, None + image_features, text_features = None, None + + if mode == "image": + assert ( + image is not None + ), "Image is not provided for mode 'image' or 'multimodal'" + # return query features + with self.maybe_autocast(): + image_embeds_frozen = self.ln_vision(self.visual_encoder(image)) + image_embeds_frozen = image_embeds_frozen.float() + image_atts = torch.ones( + image_embeds_frozen.size()[:-1], dtype=torch.long + ).to(self.device) + query_tokens = self.query_tokens.expand( + image_embeds_frozen.shape[0], -1, -1 + ) + + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds_frozen, + encoder_attention_mask=image_atts, + return_dict=True, + ) + image_embeds = query_output.last_hidden_state + image_features = F.normalize(self.vision_proj(image_embeds), dim=-1) + + elif mode == "text": + assert ( + caption is not None + ), "text input is None for mode 'text' or 'multimodal'" + + # return text features + text = self.tokenizer(caption, return_tensors="pt", padding=True).to( + self.device + ) + + text_output = self.Qformer.bert( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True, + ) + text_embeds = text_output.last_hidden_state + text_features = self.text_proj(text_embeds) + text_features = F.normalize(text_features, dim=-1) + + elif mode == "multimodal": + # return multimodel query features + with self.maybe_autocast(): + image_embeds_frozen = self.ln_vision(self.visual_encoder(image)) + image_embeds_frozen = image_embeds_frozen.float() + image_atts = torch.ones( + image_embeds_frozen.size()[:-1], dtype=torch.long + ).to(self.device) + query_tokens = self.query_tokens.expand( + image_embeds_frozen.shape[0], -1, -1 + ) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( + self.device + ) + + text = self.tokenizer(caption, return_tensors="pt", padding=True).to( + self.device + ) + attention_mask = torch.cat([query_atts, text.attention_mask], dim=1) + + output = self.Qformer.bert( + text.input_ids, + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds_frozen, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + multimodal_embeds = output.last_hidden_state[:, : query_tokens.size(1), :] + + return BlipOutputFeatures( + image_embeds=image_embeds, + image_embeds_proj=image_features, + text_embeds=text_embeds, + text_embeds_proj=text_features, + multimodal_embeds=multimodal_embeds, + ) + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + cross_attention_freq = cfg.get("cross_attention_freq", 2) + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + + max_txt_len = cfg.get("max_txt_len", 32) + + model = cls( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + num_query_token=num_query_token, + cross_attention_freq=cross_attention_freq, + max_txt_len=max_txt_len, + ) + model.load_checkpoint_from_config(cfg) + + return model + + def compute_sim_matrix(self, data_loader, task_cfg): + """ + Compute similarity i2t, t2i matrix for the given data loader. + """ + k_test = task_cfg.k_test + + return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test) diff --git a/lavis/models/blip2_models/blip2_t5.py b/lavis/models/blip2_models/blip2_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..ba98e431854674ef92d6616a3b0daad432e4801e --- /dev/null +++ b/lavis/models/blip2_models/blip2_t5.py @@ -0,0 +1,383 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import logging + +import torch +import torch.nn as nn +from torch.cuda.amp import autocast as autocast +from transformers import T5TokenizerFast + +from lavis.common.registry import registry +from lavis.models.blip2_models.blip2 import Blip2Base, disabled_train +from lavis.models.blip2_models.modeling_t5 import T5Config, T5ForConditionalGeneration + + +@registry.register_model("blip2_t5") +class Blip2T5(Blip2Base): + """ + BLIP2 T5 model. + Supported model types: + - pretrain_flant5xl: pretrained model with FlanT5-XL + - pretrain_flant5xl_vitL: pretrained model with FlanT5-XL + - pretrain_flant5xxl: pretrained model with FlanT5-XXL + - caption_coco_flant5xl: fintuned image captioning model with FlanT5-XL + Usage: + >>> from lavis.models import load_model + >>> model = load_model("blip2_t5", "pretrain_flant5xl") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_flant5xl": "configs/models/blip2/blip2_pretrain_flant5xl.yaml", + "pretrain_flant5xl_vitL": "configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml", + "pretrain_flant5xxl": "configs/models/blip2/blip2_pretrain_flant5xxl.yaml", + "caption_coco_flant5xl": "configs/models/blip2/blip2_caption_flant5xl.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + num_query_token=32, + t5_model="google/flan-t5-xl", + prompt="", + max_txt_len=32, + apply_lemmatizer=False, + ): + """ + apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas. + """ + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + if freeze_vit: + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + logging.info("freeze vision encoder") + + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + self.Qformer.cls = None + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model) + t5_config = T5Config.from_pretrained(t5_model) + t5_config.dense_act_fn = "gelu" + self.t5_model = T5ForConditionalGeneration.from_pretrained( + t5_model, config=t5_config + ) + + for name, param in self.t5_model.named_parameters(): + param.requires_grad = False + param.data = param.data.bfloat16() + + self.t5_proj = nn.Linear( + self.Qformer.config.hidden_size, self.t5_model.config.hidden_size + ) + + self.max_txt_len = max_txt_len + self.prompt = prompt + + self._apply_lemmatizer = apply_lemmatizer + self._lemmatizer = None + + def forward(self, samples): + image = samples["image"] + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_t5 = self.t5_proj(query_output.last_hidden_state) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + with self.maybe_autocast(dtype=torch.bfloat16): + input_tokens = self.t5_tokenizer( + samples["text_input"], + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + output_tokens = self.t5_tokenizer( + samples["text_output"], + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + targets = output_tokens.input_ids.masked_fill( + output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100 + ) + + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + outputs = self.t5_model( + inputs_embeds=inputs_embeds, + attention_mask=encoder_atts, + decoder_attention_mask=output_tokens.attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + + return {"loss": loss} + + @torch.no_grad() + def generate( + self, + samples, + use_nucleus_sampling=False, + num_beams=5, + max_length=30, + min_length=1, + top_p=0.9, + repetition_penalty=1.0, + length_penalty=1.0, + num_captions=1, + temperature=1, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling. + num_beams (int): Number of beams for beam search. 1 means no beam search. + max_length (int): The maximum length of the sequence to be generated. + min_length (int): The minimum length of the sequence to be generated. + top_p (float): The cumulative probability for nucleus sampling. + repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. + num_captions (int): Number of captions to be generated for each image. + Returns: + captions (list): A list of strings of length batch_size * num_captions. + """ + image = samples["image"] + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_embeds = image_embeds.float() + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_t5 = self.t5_proj(query_output.last_hidden_state) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + if "prompt" in samples.keys(): + prompt = samples["prompt"] + else: + prompt = self.prompt + + if isinstance(prompt, str): + prompt = [prompt] * image.size(0) + else: + assert len(prompt) == image.size( + 0 + ), "The number of prompts must be equal to the batch size." + + input_tokens = self.t5_tokenizer( + prompt, padding="longest", return_tensors="pt" + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + with self.maybe_autocast(dtype=torch.bfloat16): + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + outputs = self.t5_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=encoder_atts, + do_sample=use_nucleus_sampling, + top_p=top_p, + temperature=temperature, + num_beams=num_beams, + max_new_tokens=max_length, + min_length=min_length, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + num_return_sequences=num_captions, + ) + output_text = self.t5_tokenizer.batch_decode( + outputs, skip_special_tokens=True + ) + + return output_text + + def predict_answers( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=128, + answer_list=None, + prompt="", + length_penalty=-1, + **kwargs + ): + image = samples["image"] + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_embeds = image_embeds.float() + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_t5 = self.t5_proj(query_output.last_hidden_state) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + if isinstance(samples["text_input"], str): + samples["text_input"] = [samples["text_input"]] + if prompt: + text_input = [prompt.format(question) for question in samples["text_input"]] + else: + text_input = samples["text_input"] + + input_tokens = self.t5_tokenizer( + text_input, padding="longest", return_tensors="pt" + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + with self.maybe_autocast(dtype=torch.bfloat16): + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + outputs = self.t5_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=encoder_atts, + do_sample=False, + num_beams=num_beams, + max_new_tokens=max_len, + min_length=min_len, + length_penalty=length_penalty, + ) + output_text = self.t5_tokenizer.batch_decode( + outputs, skip_special_tokens=True + ) + + if self._apply_lemmatizer: + output_text = self._lemmatize(output_text) + + return output_text + + def _lemmatize(self, answers): + def apply(answer): + doc = self.lemmatizer(answer) + + words = [] + for token in doc: + if token.pos_ in ["NOUN", "VERB"]: + words.append(token.lemma_) + else: + words.append(token.text) + answer = " ".join(words) + + return answer + + return [apply(answer) for answer in answers] + + @property + def lemmatizer(self): + if self._lemmatizer is None: + try: + import spacy + + self._lemmatizer = spacy.load("en_core_web_sm") + except ImportError: + logging.error( + """ + Please install spacy and en_core_web_sm model to apply lemmatization. + python -m spacy download en_core_web_sm + OR + import spacy.cli + spacy.cli.download("en_core_web_sm") + """ + ) + exit(1) + + return self._lemmatizer + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + t5_model = cfg.get("t5_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + + prompt = cfg.get("prompt", "") + max_txt_len = cfg.get("max_txt_len", 32) + + apply_lemmatizer = cfg.get("apply_lemmatizer", False) + + model = cls( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + num_query_token=num_query_token, + t5_model=t5_model, + prompt=prompt, + max_txt_len=max_txt_len, + apply_lemmatizer=apply_lemmatizer, + ) + model.load_checkpoint_from_config(cfg) + + return model diff --git a/lavis/models/blip2_models/modeling_opt.py b/lavis/models/blip2_models/modeling_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..1d4077c83a706825131be82702deba5e344b87e0 --- /dev/null +++ b/lavis/models/blip2_models/modeling_opt.py @@ -0,0 +1,1113 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch OPT model.""" +import random +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.opt.configuration_opt import OPTConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/opt-350m" +_CONFIG_FOR_DOC = "OPTConfig" +_TOKENIZER_FOR_DOC = "GPT2Tokenizer" + +# Base model docstring +_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024] + +# SequenceClassification docstring +_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc" +_SEQ_CLASS_EXPECTED_LOSS = 1.71 +_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'" + +# QuestionAnswering docstring +_QA_EXPECTED_OUTPUT = "'a nice puppet'" +_QA_EXPECTED_LOSS = 7.41 +_QA_TARGET_START_INDEX = 14 +_QA_TARGET_END_INDEX = 15 + +OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/opt-125m", + "facebook/opt-350m", + "facebook/opt-1.3b", + "facebook/opt-2.7b", + "facebook/opt-6.7b", + "facebook/opt-13b", + "facebook/opt-30b", + # See all OPT models at https://huggingface.co/models?filter=opt +] + + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1 + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, attention_mask: torch.LongTensor, past_key_values_length: int = 0 + ): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = ( + torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask + ).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) + attn_weights = torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 + if attn_weights.dtype == torch.float16: + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(torch.float16) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned aross GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +class OPTDecoderLayer(nn.Module): + def __init__(self, config: OPTConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = OPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.do_layer_norm_before = config.do_layer_norm_before + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim) + self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + hidden_states_shape = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) + residual = hidden_states + + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + + hidden_states = (residual + hidden_states).view(hidden_states_shape) + + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +OPT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`OPTConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTPreTrainedModel(PreTrainedModel): + + config_class = OPTConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["OPTDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (OPTDecoder)): + module.gradient_checkpointing = value + + +OPT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class OPTDecoder(OPTPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`] + + Args: + config: OPTConfig + """ + + def __init__(self, config: OPTConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.word_embed_proj_dim, self.padding_idx + ) + self.embed_positions = OPTLearnedPositionalEmbedding( + config.max_position_embeddings, config.hidden_size + ) + + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = nn.Linear( + config.hidden_size, config.word_embed_proj_dim, bias=False + ) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = nn.Linear( + config.word_embed_proj_dim, config.hidden_size, bias=False + ) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility + # with checkpoints that have been fine-tuned before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm(config.hidden_size) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList( + [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + past_key_values_length=past_key_values_length, + ).to(inputs_embeds.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) + combined_attention_mask = ( + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if query_embeds is not None: + inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1) + input_shape = inputs_embeds.size()[:-1] + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device + ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + if self.project_in is not None: + inputs_embeds = self.project_in(inputs_embeds) + + hidden_states = inputs_embeds + pos_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # check if head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask], ["head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + dropout_probability = random.uniform(0, 1) + if self.training and (dropout_probability < self.layerdrop): + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + ) + else: + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@add_start_docstrings( + "The bare OPT Model outputting raw hidden-states without any specific head on top.", + OPT_START_DOCSTRING, +) +class OPTModel(OPTPreTrainedModel): + def __init__(self, config: OPTConfig): + super().__init__(config) + self.decoder = OPTDecoder(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.decoder.embed_tokens = value + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + + return BaseModelOutputWithPast( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + ) + + +class OPTForCausalLM(OPTPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = OPTModel(config) + + # the lm_head weight is automatically tied to the embed tokens weight + self.lm_head = nn.Linear( + config.word_embed_proj_dim, config.vocab_size, bias=False + ) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + query_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import GPT2Tokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + query_embeds=query_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]).contiguous() + + loss = None + if labels is not None: + logits = logits[:, -labels.size(1) :, :] + + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + loss = loss_fct( + shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1) + ) + if reduction == "none": + loss = loss.view(shift_logits.size(0), -1).sum(1) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids=None, + query_embeds=None, + past=None, + attention_mask=None, + use_cache=None, + **kwargs, + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + if input_ids is not None: + attention_mask = input_ids.new_ones(input_ids.shape) + if past: + input_ids = input_ids[:, -1:] + query_embeds = None + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "use_cache": use_cache, + } + + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past diff --git a/lavis/models/blip2_models/modeling_t5.py b/lavis/models/blip2_models/modeling_t5.py new file mode 100644 index 0000000000000000000000000000000000000000..10e4d56f2c21b0cbe639e0f568bd352a6cb76351 --- /dev/null +++ b/lavis/models/blip2_models/modeling_t5.py @@ -0,0 +1,2063 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model.""" + + +import copy +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.t5.configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" +_CHECKPOINT_FOR_DOC = "t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif ( + scope_names[0] == "wi" + and len(scope_names) > 1 + and scope_names[1].isdigit() + ): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained("t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained("t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + T5LayerNorm = FusedRMSNorm # noqa + + logger.info( + "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm" + ) +except ImportError: + # using the normal T5LayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(T5LayerNorm) + + +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config) + else: + self.DenseReluDense = T5DenseActDense(config) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads + ) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + real_seq_length += ( + past_key_value[0].shape[2] if query_length is None else query_length + ) + + key_length = ( + real_seq_length if key_value_states is None else key_value_states.shape[1] + ) + + def shape(states): + """projection""" + return states.view( + batch_size, -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) + + def unshape(states): + """reshape""" + return ( + states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + ) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape( + self.q(hidden_states) + ) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + ) + value_states = project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape( + torch.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + ) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + if not self.is_decoder: + logger.warning( + "`past_key_values` is passed to the encoder. Please make sure this is intended." + ) + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[ + 2: + ] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if ( + hidden_states.dtype == torch.float16 + and torch.isinf(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = ( + present_key_value_state + cross_attention_outputs[1] + ) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["T5Block"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = ( + self.config.initializer_factor + ) # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedActDense): + module.wi_0.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_( + mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) + ) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_( + mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5) + ) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_( + mean=0.0, std=factor * ((d_model) ** -0.5) + ) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (T5Attention, T5Stack)): + module.gradient_checkpointing = value + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id." + " See T5 docs for more information" + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full( + input_ids.shape[:-1] + (1,), decoder_start_token_id + ) + shifted_input_ids = torch.cat( + [shifted_input_ids, input_ids[..., :-1]], dim=-1 + ) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert ( + pad_token_id is not None + ), "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [ + T5Block(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = ( + "cpu" + if "cpu" in self.device_map.keys() + else "cuda:" + str(min(self.device_map.keys())) + ) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) + + if inputs_embeds is None: + assert ( + self.embed_tokens is not None + ), "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values[0][0].shape[2] + seq_length + if past_key_values is not None + else seq_length + ) + + if use_cache is True: + assert ( + self.is_decoder + ), f"`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) + if ( + self.is_decoder + and encoder_attention_mask is None + and encoder_hidden_states is not None + ): + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, + encoder_seq_length, + device=inputs_embeds.device, + dtype=torch.long, + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device + ) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask( + cross_attn_head_mask, self.config.num_layers + ) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate( + zip(self.block, past_key_values) + ): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to( + hidden_states.device + ) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = ( + encoder_extended_attention_mask.to(hidden_states.device) + ) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to( + hidden_states.device + ) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to( + hidden_states.device + ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[ + 4 if output_attentions else 3 + ] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + ( + present_key_value_state, + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`T5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder.embed_tokens.weight", + r"decoder.embed_tokens.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, T5Model + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5Model.from_pretrained("t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING +) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder.embed_tokens.weight", + r"decoder.embed_tokens.weight", + r"lm_head.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import T5Tokenizer, T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") + + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if ( + labels is not None + and decoder_input_ids is None + and decoder_inputs_embeds is None + ): + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100, reduction=reduction) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + if reduction == "none": + loss = loss.view(lm_logits.size(0), -1).sum(1) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning( + "You might want to consider setting `use_cache=True` to speed up decoding" + ) + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device) + ), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + ( + reordered_layer_past_states, + ) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5EncoderModel(T5PreTrainedModel): + authorized_missing_keys = [ + r"encoder.embed_tokens.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, T5EncoderModel + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5EncoderModel.from_pretrained("t5-small") + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/lavis/models/blip_models/__init__.py b/lavis/models/blip_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b88146b9eb3d60dd10ee2aed8e0a33cba924746 --- /dev/null +++ b/lavis/models/blip_models/__init__.py @@ -0,0 +1,90 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +from typing import List + +from torch import nn + + +def tie_encoder_decoder_weights( + encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key: str +): + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + logging.info( + f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." + ) + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + skip_key: str, + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" + if hasattr(decoder_pointer, "weight") and skip_key not in module_name: + assert hasattr(encoder_pointer, "weight") + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, "bias"): + assert hasattr(encoder_pointer, "bias") + encoder_pointer.bias = decoder_pointer.bias + print(module_name + " is tied") + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert ( + len(encoder_modules) > 0 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" + + all_encoder_weights = set( + [module_name + "/" + sub_name for sub_name in encoder_modules.keys()] + ) + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance( + decoder_modules[decoder_name], + type(encoder_modules[encoder_name]), + ) and len(encoder_modules) != len(decoder_modules): + # this can happen if the name corresponds to the position in a list module list of layers + # in this case the decoder has added a cross-attention that the encoder does not have + # thus skip this step and subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." + ) + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + "/" + name, + uninitialized_encoder_weights, + skip_key, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + "/" + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively( + decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key + ) diff --git a/lavis/models/blip_models/blip.py b/lavis/models/blip_models/blip.py new file mode 100644 index 0000000000000000000000000000000000000000..44bf7aca81bcc7c039e1aa52442148986d5cd824 --- /dev/null +++ b/lavis/models/blip_models/blip.py @@ -0,0 +1,59 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import os + +import torch +from lavis.common.dist_utils import download_cached_file +from lavis.common.utils import is_url +from lavis.models.base_model import BaseModel +from lavis.models.vit import interpolate_pos_embed +from transformers import BertTokenizer + + +class BlipBase(BaseModel): + @classmethod + def init_tokenizer(cls): + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + tokenizer.add_special_tokens({"additional_special_tokens": ["[ENC]"]}) + tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] + return tokenizer + + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + + state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( + state_dict["visual_encoder.pos_embed"], self.visual_encoder + ) + if "visual_encoder_m.pos_embed" in self.state_dict().keys(): + state_dict["visual_encoder_m.pos_embed"] = interpolate_pos_embed( + state_dict["visual_encoder_m.pos_embed"], self.visual_encoder_m + ) + + for key in self.state_dict().keys(): + if key in state_dict.keys(): + if state_dict[key].shape != self.state_dict()[key].shape: + del state_dict[key] + + msg = self.load_state_dict(state_dict, strict=False) + + logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg diff --git a/lavis/models/blip_models/blip_caption.py b/lavis/models/blip_models/blip_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..26f0690a596039a33edfb90b34b3fc0a62ef28ce --- /dev/null +++ b/lavis/models/blip_models/blip_caption.py @@ -0,0 +1,219 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch +from lavis.common.registry import registry + +from lavis.models.blip_models.blip import BlipBase +from lavis.models.blip_models.blip_outputs import ( + BlipOutput, + BlipIntermediateOutput, +) +from lavis.models.med import XBertLMHeadDecoder +from lavis.models.vit import VisionTransformerEncoder + + +@registry.register_model("blip_caption") +class BlipCaption(BlipBase): + """ + BLIP captioning model. + + Supported model types: + - base_coco: fine-tuned BLIP base model on COCO caption dataset (Karparthy split). + - large_coco: fine-tuned BLIP large model on COCO caption dataset (Karparthy split). + + Usage: + >>> from lavis.models import load_model + >>> model = load_model("blip_caption", "base_coco") + >>> model = load_model("blip_caption", "large_coco") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "base_coco": "configs/models/blip_caption_base_coco.yaml", + "large_coco": "configs/models/blip_caption_large_coco.yaml", + } + + def __init__(self, image_encoder, text_decoder, prompt=None, max_txt_len=40): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder = image_encoder + self.text_decoder = text_decoder + + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 + + self.max_txt_len = max_txt_len + + def forward_encoder(self, samples): + image_embeds = self.visual_encoder.forward_features(samples["image"]) + return image_embeds + + def forward_decoder(self, samples, image_embeds): + # prepare inputs for forwarding decoder + raw_text = samples["text_input"] + text = self.tokenizer( + raw_text, + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + text.input_ids[:, 0] = self.tokenizer.bos_token_id + + # prepare targets for forwarding decoder + decoder_targets = text.input_ids.masked_fill( + text.input_ids == self.tokenizer.pad_token_id, -100 + ) + decoder_targets[:, : self.prompt_length] = -100 + + # forward decoder + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + decoder_output = self.text_decoder( + input_ids=text.input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + labels=decoder_targets, + return_dict=True, + ) + + return decoder_output, decoder_targets + + def forward(self, samples): + r""" + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - text_input (list): A list of strings of length batch_size. + Returns: + output (BlipOutput): A BlipOutput object containing the following + attributes: + - loss (torch.Tensor): A scalar tensor containing the total loss. For BlipCaption, this is the same as the LM loss. + - loss_lm (torch.Tensor): A scalar tensor containing the LM loss. + - intermediate_outputs (BlipIntermediateOutput): A BlipIntermediateOutput object containing intermediate outputs. + see :class:`lavis.models.blip_models.blip_outputs.BlipOutput` for more details. + + Example: + ```python + >>> from PIL import Image + >>> from lavis.models import load_model_and_preprocess + >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_caption") + >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB") + >>> image = vis_processors["eval"](raw_image).unsqueeze(0) + >>> text_input = ["a large statue of a person spraying water from a fountain"] + >>> samples = {"image": image, "text_input": text_input} + >>> output = model(samples) + >>> output.keys() + odict_keys(['intermediate_output', 'loss', 'loss_lm']) + >>> output.intermediate_output.image_embeds.shape + torch.Size([1, 577, 768]) + >>> output.intermediate_output.decoder_labels.shape + torch.Size([1, 13]) + ```""" + + image_embeds = self.forward_encoder(samples) + decoder_output, decoder_targets = self.forward_decoder(samples, image_embeds) + + # return decoder_out + return BlipOutput( + loss=decoder_output.loss, + loss_lm=decoder_output.loss, + intermediate_output=BlipIntermediateOutput( + image_embeds=image_embeds, + decoder_output=decoder_output, + decoder_labels=decoder_targets, + ), + ) + + def generate( + self, + samples, + use_nucleus_sampling=False, + num_beams=3, + max_length=30, + min_length=10, + top_p=0.9, + repetition_penalty=1.0, + num_captions=1, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling. + num_beams (int): Number of beams for beam search. 1 means no beam search. + max_length (int): The maximum length of the sequence to be generated. + min_length (int): The minimum length of the sequence to be generated. + top_p (float): The cumulative probability for nucleus sampling. + repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. + num_captions (int): Number of captions to be generated for each image. + Returns: + captions (list): A list of strings of length batch_size * num_captions. + + Example: + ```python + >>> from PIL import Image + >>> from lavis.models import load_model_and_preprocess + >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_caption") + >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB") + >>> image = vis_processors["eval"](raw_image).unsqueeze(0) + >>> samples = {"image": image} + >>> captions = model.generate(samples) + >>> captions + ['a large statue of a person spraying water from a fountain'] + >>> captions = model.generate(samples, use_nucleus_sampling=True, num_captions=3) + >>> captions # example output, results may vary due to randomness + ['singapore showing the view of some building', + 'the singapore harbor in twilight, as the weather is going down', + 'the famous singapore fountain at sunset'] + """ + # prepare inputs for decoder generation. + encoder_out = self.forward_encoder(samples) + image_embeds = torch.repeat_interleave(encoder_out, num_captions, 0) + + prompt = [self.prompt] * image_embeds.size(0) + prompt = self.tokenizer(prompt, return_tensors="pt").to(self.device) + prompt.input_ids[:, 0] = self.tokenizer.bos_token_id + prompt.input_ids = prompt.input_ids[:, :-1] + + # get decoded text + decoder_out = self.text_decoder.generate_from_encoder( + tokenized_prompt=prompt, + visual_embeds=image_embeds, + sep_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + use_nucleus_sampling=use_nucleus_sampling, + num_beams=num_beams, + max_length=max_length, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + ) + + outputs = self.tokenizer.batch_decode(decoder_out, skip_special_tokens=True) + captions = [output[len(self.prompt) :] for output in outputs] + + return captions + + @classmethod + def from_config(cls, cfg): + # vision encoder + image_encoder = VisionTransformerEncoder.from_config(cfg) + # text encoder + multimodal decoder + text_decoder = XBertLMHeadDecoder.from_config(cfg) + + prompt = cfg.get("prompt", None) + max_txt_len = cfg.get("max_txt_len", 40) + + model = cls(image_encoder, text_decoder, prompt=prompt, max_txt_len=max_txt_len) + model.load_checkpoint_from_config(cfg) + + return model diff --git a/lavis/models/blip_models/blip_classification.py b/lavis/models/blip_models/blip_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..46c49099d6170fd74b8cbdfae8b1925707e493b6 --- /dev/null +++ b/lavis/models/blip_models/blip_classification.py @@ -0,0 +1,177 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from copy import deepcopy + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.models.base_model import MomentumDistilationMixin +from lavis.models.blip_models.blip import BlipBase +from lavis.models.blip_models.blip_outputs import ( + BlipIntermediateOutput, + BlipOutputWithLogits, +) +from lavis.models.med import XBertEncoder +from lavis.models.vit import VisionTransformerEncoder +from torch import nn + + +@registry.register_model("blip_classification") +class BlipClassification(BlipBase, MomentumDistilationMixin): + PRETRAINED_MODEL_CONFIG_DICT = { + "base": "configs/models/blip_classification_base.yaml", + } + + def __init__( + self, + image_encoder, + text_encoder, + num_classes, + momentum=0.995, + alpha=0.4, + max_txt_len=40, + use_distill=True, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.use_distill = use_distill + + self.visual_encoder = image_encoder + self.text_encoder = text_encoder + + hidden_size = text_encoder.config.hidden_size + self.cls_head = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, num_classes), + ) + + if self.use_distill: + self.visual_encoder_m = deepcopy(self.visual_encoder) + self.text_encoder_m = deepcopy(self.text_encoder) + self.cls_head_m = deepcopy(self.cls_head) + + self.momentum = momentum + self.alpha = alpha + + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.cls_head, self.cls_head_m], + ] + + self.copy_params() + + self.max_txt_len = max_txt_len + + def _rampup_factor(self, epoch, iters, num_iters_per_epoch): + return min(1, (epoch * num_iters_per_epoch + iters) / num_iters_per_epoch) + + def forward(self, samples, is_train=True): + sentences = samples["text_input"] + sentences = self.tokenizer( + sentences, + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + samples.update({"tokenized_text": sentences}) + + targets = samples["label"] + + image_embeds = self.visual_encoder.forward_features(samples["image"]) + encoder_output = self.text_encoder.forward_automask( + samples["tokenized_text"], image_embeds + ) + + prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :]) + + if is_train: + if self.use_distill: + with torch.no_grad(): + self._momentum_update() + + image_embeds_m = self.visual_encoder_m(samples["image"]) + encoder_output_m = self.text_encoder_m.forward_automask( + samples["tokenized_text"], image_embeds_m + ) + + prediction_m = self.cls_head_m( + encoder_output_m.last_hidden_state[:, 0, :] + ) + + alpha = self.alpha * self._rampup_factor( + epoch=samples["epoch"], + iters=samples["iters"], + num_iters_per_epoch=samples["num_iters_per_epoch"], + ) + + loss = (1 - alpha) * F.cross_entropy( + prediction, targets + ) - alpha * torch.sum( + F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1), + dim=1, + ).mean() + else: + loss = F.cross_entropy(prediction, targets) + + # return {"loss": loss} + return BlipOutputWithLogits( + loss=loss, + intermediate_output=BlipIntermediateOutput( + image_embeds=image_embeds, + image_embeds_m=image_embeds_m, + encoder_output=encoder_output, + encoder_output_m=encoder_output_m, + ), + logits=prediction, + logits_m=prediction_m, + ) + + else: + return {"predictions": prediction, "targets": targets} + + def predict(self, samples): + output = self.forward(samples, is_train=False) + return output + + @classmethod + def from_config(cls, cfg=None): + image_encoder = VisionTransformerEncoder.from_config(cfg) + + # text encoder + multimodal encoder + text_encoder = XBertEncoder.from_config(cfg) + use_distill = cfg.get("use_distill", True) + momentum = cfg.get("momentum", 0.995) + num_classes = cfg.get("num_classes", -1) + alpha = cfg.get("alpha", 0.4) + max_txt_len = cfg.get("max_txt_len", 40) + + assert num_classes > 1, "Invalid number of classes provided, found {}".format( + num_classes + ) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + use_distill=use_distill, + alpha=alpha, + num_classes=num_classes, + momentum=momentum, + max_txt_len=max_txt_len, + ) + + # load pre-trained weights + pretrain_path = cfg.get("pretrained", None) + if pretrain_path is not None: + msg = model.load_from_pretrained(url_or_filename=pretrain_path) + + return model diff --git a/lavis/models/blip_models/blip_feature_extractor.py b/lavis/models/blip_models/blip_feature_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..31df635b2e36b57dc2c5d211b76f3735a3e273df --- /dev/null +++ b/lavis/models/blip_models/blip_feature_extractor.py @@ -0,0 +1,212 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import warnings + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.models.blip_models.blip import BlipBase +from lavis.models.blip_models.blip_outputs import BlipOutputFeatures +from lavis.models.med import XBertEncoder +from lavis.models.vit import VisionTransformerEncoder +from torch import nn + + +@registry.register_model("blip_feature_extractor") +class BlipFeatureExtractor(BlipBase): + """ + Class for BLIP feature extractor. + + Supported model types: + - base: BLIP base model with pre-trained weights from capfilt by BLIP large model. + + Usage: + >>> from lavis.models import load_model + >>> model = load_model("blip_feature_extractor", "base") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "base": "configs/models/blip_feature_extractor_base.yaml", + # "large": "configs/models/blip_feature_extractor_large.yaml", + } + + def __init__(self, image_encoder, text_encoder, embed_dim, max_txt_len=40): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder = image_encoder + self.text_encoder = text_encoder + + # creating projection layers for ITC + text_width = text_encoder.config.hidden_size + vision_width = image_encoder.vision_width + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.max_txt_len = max_txt_len + + self.temp = nn.Parameter(0.07 * torch.ones([])) + + @torch.no_grad() + def extract_features(self, samples, mode="multimodal"): + """ + Extract features for multimodal or unimodal samples. + + Args: + samples (dict): A dictionary of samples, containing the following keys: + - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image. + Raw images should be preprocessed before being passed to feature extractor. + - text_input (list): A list of strings containing the text, length B. + mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image". + If "multimodal", return image features and multimodal features; + if "text", return text features; + if "image", return image features. + Default: "multimodal". + + Returns: + BlipOutputFeatures: A BlipOutputFeatures object containing the features. + See lavis/models/blip_models/blip_outputs.py for more details. + + Examples: + ```python + >>> from PIL import Image + >>> from lavis.models import load_model_and_preprocess + >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB") + >>> caption = "a large fountain spewing water into the air" + >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_feature_extractor", is_eval=True) + >>> image = vis_processors["eval"](raw_image).unsqueeze(0) + >>> text_input = txt_processors["eval"](caption) + + >>> sample = {"image": image, "text_input": [text_input]} + + >>> features_multimodal = model.extract_features(sample) + >>> features_multimodal.keys() + odict_keys(['image_embeds', 'multimodal_embeds']) + >>> features_multimodal.image_embeds.shape + torch.Size([1, 197, 768]) + >>> features_multimodal.multimodal_embeds.shape + torch.Size([1, 12, 768]) + + >>> features_text = model.extract_features(sample, mode="text") + >>> features_text.keys() + odict_keys(['text_embeds', 'text_features']) + >>> features_text.text_embeds.shape + torch.Size([1, 12, 768]) + >>> features_text.text_features.shape + torch.Size([1, 12, 256]) + + >>> features_image = model.extract_features(sample, mode="image") + >>> features_image.keys() + odict_keys(['image_embeds', 'image_features']) + >>> features_image.image_embeds.shape + torch.Size([1, 197, 768]) + >>> features_image.image_features.shape + torch.Size([1, 197, 256]) + ``` + """ + image = samples.get("image") + caption = samples.get("text_input") + + # assert mode is one of "image", "text", "multimodal" + assert mode in [ + "image", + "text", + "multimodal", + ], "mode must be one of 'image', 'text', 'multimodal'" + + # initalize output + image_embeds, text_embeds, multimodal_embeds = None, None, None + image_features, text_features = None, None + + if mode == "image": + assert ( + image is not None + ), "Image is not provided for mode 'image' or 'multimodal'" + # return image features + image_embeds = self.visual_encoder.forward_features(image) + + image_features = self.vision_proj(image_embeds) + image_features = F.normalize(image_features, dim=-1) + + elif mode == "text": + assert ( + caption is not None + ), "text input is None for mode 'text' or 'multimodal'" + + text = self.tokenizer(caption, return_tensors="pt", padding=True).to( + self.device + ) + + # return text features + text_output = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True, + mode="text", + ) + text_embeds = text_output.last_hidden_state + + text_features = self.text_proj(text_embeds) + text_features = F.normalize(text_features, dim=-1) + + elif mode == "multimodal": + # return multimodel features + image_embeds = self.visual_encoder.forward_features(image) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + + text = self.tokenizer(caption, return_tensors="pt", padding=True).to( + self.device + ) + text.input_ids[:, 0] = self.tokenizer.enc_token_id + + output = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + multimodal_embeds = output.last_hidden_state + + return BlipOutputFeatures( + image_embeds=image_embeds, + image_embeds_proj=image_features, + text_embeds=text_embeds, + text_embeds_proj=text_features, + multimodal_embeds=multimodal_embeds, + ) + + @classmethod + def from_config(cls, cfg=None): + # set from_pretrained=True to load weights for 'bert-base-uncased' + image_encoder = VisionTransformerEncoder.from_config(cfg) + text_encoder = XBertEncoder.from_config(cfg) + + embed_dim = cfg.get("embed_dim", 256) + max_txt_len = cfg.get("max_txt_len", 30) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + embed_dim=embed_dim, + max_txt_len=max_txt_len, + ) + + # load pre-trained weights + pretrain_path = cfg.get("pretrained", None) + if pretrain_path is not None: + msg = model.load_from_pretrained(url_or_filename=pretrain_path) + else: + warnings.warn("No pretrained weights are loaded.") + + return model diff --git a/lavis/models/blip_models/blip_image_text_matching.py b/lavis/models/blip_models/blip_image_text_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..a691984f8eec5540e476f7c188e36c1fccab5ea7 --- /dev/null +++ b/lavis/models/blip_models/blip_image_text_matching.py @@ -0,0 +1,199 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.models.blip_models.blip import BlipBase +from torch import nn +from lavis.models.med import XBertEncoder + +from lavis.models.vit import VisionTransformerEncoder + + +@registry.register_model("blip_image_text_matching") +class BlipITM(BlipBase): + """ + BLIP Image-Text Matching (ITM) model. + + Supported model types: + - base: fine-tuned BLIP retrieval weights on COCO dataset (Karpathy split). + - large: fine-tuned BLIP retrieval weights on COCO dataset (Karpathy split). + + Usage: + >>> from lavis.models import load_model + >>> model = load_model("blip_image_text_matching", "base") + >>> model = load_model("blip_image_text_matching", "large") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "base": "configs/models/blip_itm_base.yaml", + "large": "configs/models/blip_itm_large.yaml", + } + + def __init__(self, image_encoder, text_encoder, embed_dim=256, max_txt_len=35): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.text_encoder = text_encoder + + self.visual_encoder = image_encoder + + self.max_txt_len = max_txt_len + + # creating projection layers for ITC + text_width = text_encoder.config.hidden_size + vision_width = image_encoder.vision_width + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + def forward(self, samples, match_head="itm"): + image = samples["image"] + caption = samples["text_input"] + + image_embeds = self.visual_encoder.forward_features(image) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + text = self.tokenizer( + caption, + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + if match_head == "itm": + encoder_input_ids = text.input_ids.clone() + encoder_input_ids[:, 0] = self.tokenizer.enc_token_id # extra code + output = self.text_encoder( + encoder_input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + itm_output = self.itm_head(output.last_hidden_state[:, 0, :]) + return itm_output + + elif match_head == "itc": + text_output = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True, + mode="text", + ) + image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) + text_feat = F.normalize( + self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 + ) + + sim = image_feat @ text_feat.t() + return sim + def itm_rank(self, image_embeds, image_atts, encoder_input_ids, match_head='itm'): + # breakpoint() + encoder_input_ids = encoder_input_ids.clone() + encoder_input_ids = encoder_input_ids[:, 3:] + text_attention_mask = (encoder_input_ids != self.tokenizer.pad_token_id).long() + + if match_head == 'itm': + # encoder_input_ids = encoder_input_ids.clone() + encoder_input_ids[:, 0] = self.tokenizer.enc_token_id + output = self.text_encoder(encoder_input_ids, + attention_mask=text_attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + # print(output.last_hidden_state.shape) + itm_output = self.itm_head(output.last_hidden_state[:, 0, :]) + itm_output = F.softmax(itm_output, dim=1)[:,1] + return itm_output #, mask, token_length + + elif match_head == 'itc': + encoder_input_ids[:, 0] = self.tokenizer.cls_token_id + text_output = self.text_encoder(encoder_input_ids, attention_mask=text_attention_mask, + return_dict=True, mode='text') + image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) + text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1) + + sim = image_feat @ text_feat.t() + return sim + + @classmethod + def from_config(cls, cfg=None): + image_encoder = VisionTransformerEncoder.from_config(cfg) + text_encoder = XBertEncoder.from_config(cfg) + + embed_dim = cfg.get("embed_dim", 256) + max_txt_len = cfg.get("max_txt_len", 35) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + embed_dim=embed_dim, + max_txt_len=max_txt_len, + ) + + model.load_checkpoint_from_config(cfg) + + return model + + +def compute_gradcam(model, visual_input, text_input, tokenized_text, block_num=6): + model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.save_attention = True + + output = model({"image": visual_input, "text_input": text_input}, match_head="itm") + loss = output[:, 1].sum() + + model.zero_grad() + loss.backward() + with torch.no_grad(): + mask = tokenized_text.attention_mask.view( + tokenized_text.attention_mask.size(0), 1, -1, 1, 1 + ) # (bsz,1,token_len, 1,1) + token_length = tokenized_text.attention_mask.sum(dim=-1) - 2 + token_length = token_length.cpu() + # grads and cams [bsz, num_head, seq_len, image_patch] + grads = model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.get_attn_gradients() + cams = model.text_encoder.base_model.base_model.encoder.layer[ + block_num + ].crossattention.self.get_attention_map() + + # assume using vit with 576 num image patch + cams = cams[:, :, :, 1:].reshape(visual_input.size(0), 12, -1, 24, 24) * mask + grads = ( + grads[:, :, :, 1:].clamp(0).reshape(visual_input.size(0), 12, -1, 24, 24) + * mask + ) + + gradcams = cams * grads + gradcam_list = [] + + for ind in range(visual_input.size(0)): + token_length_ = token_length[ind] + gradcam = gradcams[ind].mean(0).cpu().detach() + # [enc token gradcam, average gradcam across token, gradcam for individual token] + gradcam = torch.cat( + ( + gradcam[0:1, :], + gradcam[1 : token_length_ + 1, :].sum(dim=0, keepdim=True) + / token_length_, + gradcam[1:, :], + ) + ) + gradcam_list.append(gradcam) + + return gradcam_list, output diff --git a/lavis/models/blip_models/blip_nlvr.py b/lavis/models/blip_models/blip_nlvr.py new file mode 100644 index 0000000000000000000000000000000000000000..a67d7a1b2c27a200efaae5dda5da1c5fc9ca78e8 --- /dev/null +++ b/lavis/models/blip_models/blip_nlvr.py @@ -0,0 +1,187 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os + +import torch +import torch.nn.functional as F +from lavis.common.dist_utils import download_cached_file +from lavis.common.registry import registry +from lavis.common.utils import get_abs_path, is_url +from lavis.models.base_model import MomentumDistilationMixin +from lavis.models.blip_models.blip import BlipBase +from lavis.models.blip_models.blip_outputs import BlipIntermediateOutput, BlipOutput +from lavis.models.blip_models.nlvr_encoder import BertModel +from lavis.models.vit import VisionTransformerEncoder, interpolate_pos_embed +from torch import nn +from transformers import BertConfig + + +@registry.register_model("blip_nlvr") +class BlipNLVR(BlipBase, MomentumDistilationMixin): + """ + Class for BLIP NLVR model. + + Supported model types: + - base: model with pre-trained BLIP weights, used as initialization for fine-tuning. + - nlvr: finetuned model on NLVR2 dataset. + + Usage: + >>> from lavis.models import load_model + >>> model = load_model("blip_nlvr", "nlvr") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "nlvr": "configs/models/blip_nlvr.yaml", + } + + def __init__(self, image_encoder, text_encoder, num_classes): + super().__init__() + + self.tokenizer = self.init_tokenizer() + self.visual_encoder = image_encoder + self.text_encoder = text_encoder + + hidden_size = text_encoder.config.hidden_size + self.cls_head = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, num_classes), + ) + + def forward(self, samples, is_train=True): + """ + Forward function for training and evaluation. + + Args: + samples (dict): a dict of input samples, which contains the following keys: + - image0 (torch.Tensor): input image 0, shape (batch_size, 3, H, W), default H=384, W=384. + - image1 (torch.Tensor): input image 1, shape (batch_size, 3, H, W), default H=384, W=384. + - text_input (list): list of strings, each string is a natural language sentence. + - label (torch.LongTensor): ground truth label with shape (batch_size,). + is_train (bool): whether the model is in training mode. + If True, the model will return the loss; + If False, the model will return the prediction. + + Examples: + >>> import torch + >>> from lavis.models import load_model + >>> model = load_model("blip_nlvr", "nlvr") + >>> samples = { + ... "image0": torch.randn(2, 3, 384, 384), + ... "image1": torch.randn(2, 3, 384, 384), + ... "text_input": ["there is a ferret in tall grass", "there are lips in one of the images"], + ... "label": torch.tensor([0, 1]), + ... } + >>> output = model(samples) + >>> output.keys() + odict_keys(['intermediate_output', 'loss']) + """ + text = samples["text_input"] + text = self.tokenizer(text, padding="longest", return_tensors="pt").to( + self.device + ) + text.input_ids[:, 0] = self.tokenizer.enc_token_id + + targets = samples["label"] + + image0 = samples["image0"] + image1 = samples["image1"] + images = torch.cat([image0, image1], dim=0) + + image_embeds = self.visual_encoder.forward_features(images) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0)) + + encoder_output = self.text_encoder( + text.input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=[image0_embeds, image1_embeds], + encoder_attention_mask=[ + image_atts[: image0_embeds.size(0)], + image_atts[image0_embeds.size(0) :], + ], + return_dict=True, + ) + + prediction = self.cls_head(encoder_output.last_hidden_state[:, 0, :]) + + if is_train: + loss = F.cross_entropy(prediction, targets) + # return {"loss": loss} + return BlipOutput( + loss=loss, + intermediate_output=BlipIntermediateOutput( + image_embeds=torch.stack([image0_embeds, image1_embeds], dim=0), + encoder_output=encoder_output, + ), + ) + else: + return {"predictions": prediction, "targets": targets} + + def predict(self, samples): + output = self.forward(samples, is_train=False) + return output + + @classmethod + def from_config(cls, cfg=None): + image_encoder = VisionTransformerEncoder.from_config(cfg) + + # text encoder + multimodal encoder + bert_config = BertConfig.from_json_file(get_abs_path(cfg["med_config_path"])) + text_encoder = BertModel(config=bert_config, add_pooling_layer=False) + + num_classes = cfg.get("num_classes", 3) + + assert num_classes > 1, "Invalid number of classes provided, found {}".format( + num_classes + ) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + num_classes=num_classes, + ) + + model.load_checkpoint_from_config(cfg) + + return model + + def load_from_pretrained(self, url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + state_dict = checkpoint["model"] + + state_dict["visual_encoder.pos_embed"] = interpolate_pos_embed( + state_dict["visual_encoder.pos_embed"], self.visual_encoder + ) + + for key in list(state_dict.keys()): + if "crossattention.self." in key: + new_key0 = key.replace("self", "self0") + new_key1 = key.replace("self", "self1") + state_dict[new_key0] = state_dict[key] + state_dict[new_key1] = state_dict[key] + elif "crossattention.output.dense." in key: + new_key0 = key.replace("dense", "dense0") + new_key1 = key.replace("dense", "dense1") + state_dict[new_key0] = state_dict[key] + state_dict[new_key1] = state_dict[key] + + msg = self.load_state_dict(state_dict, strict=False) + print("load checkpoint from %s" % url_or_filename) + print(f"missing keys {msg.missing_keys}") + return msg diff --git a/lavis/models/blip_models/blip_outputs.py b/lavis/models/blip_models/blip_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..9d18ddcabb68f09e1b4952d337d0206efcd8e8ad --- /dev/null +++ b/lavis/models/blip_models/blip_outputs.py @@ -0,0 +1,116 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.modeling_outputs import ( + ModelOutput, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) + + +@dataclass +class BlipSimilarity(ModelOutput): + sim_i2t: torch.FloatTensor = None + sim_t2i: torch.FloatTensor = None + + sim_i2t_m: Optional[torch.FloatTensor] = None + sim_t2i_m: Optional[torch.FloatTensor] = None + + sim_i2t_targets: Optional[torch.FloatTensor] = None + sim_t2i_targets: Optional[torch.FloatTensor] = None + + +@dataclass +class BlipIntermediateOutput(ModelOutput): + """ + Data class for intermediate outputs of BLIP models. + + image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). + text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). + + image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). + text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). + + encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. + encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. + + decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. + decoder_labels (torch.LongTensor): labels for the captioning loss. + + itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). + itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) + + """ + + # uni-modal features + image_embeds: torch.FloatTensor = None + text_embeds: Optional[torch.FloatTensor] = None + + image_embeds_m: Optional[torch.FloatTensor] = None + text_embeds_m: Optional[torch.FloatTensor] = None + + # intermediate outputs of multimodal encoder + encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + + itm_logits: Optional[torch.FloatTensor] = None + itm_labels: Optional[torch.LongTensor] = None + + # intermediate outputs of multimodal decoder + decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None + decoder_labels: Optional[torch.LongTensor] = None + + +@dataclass +class BlipOutput(ModelOutput): + # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. + sims: Optional[BlipSimilarity] = None + + intermediate_output: BlipIntermediateOutput = None + + loss: Optional[torch.FloatTensor] = None + + loss_itc: Optional[torch.FloatTensor] = None + + loss_itm: Optional[torch.FloatTensor] = None + + loss_lm: Optional[torch.FloatTensor] = None + + +@dataclass +class BlipOutputWithLogits(BlipOutput): + logits: torch.FloatTensor = None + logits_m: torch.FloatTensor = None + + +@dataclass +class BlipOutputFeatures(ModelOutput): + """ + Data class of features from BlipFeatureExtractor. + + Args: + image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional + image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional + text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional + text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional + + The first embedding or feature is for the [CLS] token. + + Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. + """ + + image_embeds: Optional[torch.FloatTensor] = None + image_embeds_proj: Optional[torch.FloatTensor] = None + + text_embeds: Optional[torch.FloatTensor] = None + text_embeds_proj: Optional[torch.FloatTensor] = None + + multimodal_embeds: Optional[torch.FloatTensor] = None diff --git a/lavis/models/blip_models/blip_pretrain.py b/lavis/models/blip_models/blip_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..a8077cae11eb35b8e09d4fdfd77ea3c58ef6ea0f --- /dev/null +++ b/lavis/models/blip_models/blip_pretrain.py @@ -0,0 +1,394 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from copy import deepcopy + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.models.base_model import MomentumDistilationMixin, SharedQueueMixin +from lavis.models.blip_models import tie_encoder_decoder_weights +from lavis.models.blip_models.blip import BlipBase +from lavis.models.blip_models.blip_outputs import ( + BlipOutput, + BlipSimilarity, + BlipIntermediateOutput, +) +from lavis.models.med import XBertEncoder, XBertLMHeadDecoder +from lavis.models.vit import VisionTransformerEncoder +from torch import nn + + +@registry.register_model("blip_pretrain") +class BlipPretrain(BlipBase, SharedQueueMixin, MomentumDistilationMixin): + """ + BLIP pretrain model. + + Supported model types: + - base: BLIP base model before pretraining. + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "base": "configs/models/blip_pretrain_base.yaml", + # "large": "configs/models/blip_pretrain_large.yaml", + } + + def __init__( + self, + image_encoder, + text_encoder, + text_decoder, + queue_size, + alpha=0.4, + embed_dim=256, + momentum=0.995, + tie_enc_dec_weights=True, + max_txt_len=30, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + text_encoder.resize_token_embeddings(len(self.tokenizer)) + text_decoder.resize_token_embeddings(len(self.tokenizer)) + + if tie_enc_dec_weights: + tie_encoder_decoder_weights( + encoder=text_encoder, + decoder=text_decoder.bert, + base_model_prefix="", + skip_key="/attention", + ) + + self.visual_encoder = image_encoder + + self.text_encoder = text_encoder + self.text_decoder = text_decoder + + # creating projection layers for ITC + text_width = text_encoder.config.hidden_size + vision_width = image_encoder.vision_width + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + # create the momentum encoder + self.visual_encoder_m = deepcopy(self.visual_encoder) + self.text_encoder_m = deepcopy(self.text_encoder) + + self.vision_proj_m = deepcopy(self.vision_proj) + self.text_proj_m = deepcopy(self.text_proj) + + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.vision_proj, self.vision_proj_m], + [self.text_proj, self.text_proj_m], + ] + self.copy_params() + + # create the queue + self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + self.image_queue = nn.functional.normalize(self.image_queue, dim=0) + self.text_queue = nn.functional.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + self.momentum = momentum + self.temp = nn.Parameter(0.07 * torch.ones([])) + + self.alpha = alpha + self.max_txt_len = max_txt_len + + def _rampup_factor(self, epoch, iters, num_iters_per_epoch): + return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch)) + + def forward(self, samples): + + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). The input images. Default: H=224, W=224. + - text_input (list): A list of length batch_size, each element is a string of text/caption. + - epoch (int): The current epoch. + - iters (int): The current iteration. + - num_iters_per_epoch (int): The number of iterations per epoch. + + Returns: + BlipOutput: A BlipOutput object containing loss and intermediate output. See ``lavis.models.blip_models.blip_outputs.BlipOutput`` for more details. + + Examples: + >>> import torch + >>> from lavis.models import load_model + >>> model = load_model("blip_pretrain", "base") + >>> images = torch.randn(4, 3, 224, 224) + >>> text_input = ["caption of image 1", "another caption of image 1", "caption of image 2", "caption of image 3"] + >>> samples = {"image": images, "text_input": text_input, "epoch": 0, "iters": 0, "num_iters_per_epoch": 100} + >>> output = model(samples) + >>> output.keys() + odict_keys(['sims', 'intermediate_output', 'loss', 'loss_itc', 'loss_itm', 'loss_lm']) + + >>> output.intermediate_output.keys() + odict_keys(['image_embeds', 'text_embeds', 'image_embeds_m', 'text_embeds_m', 'encoder_output', 'encoder_output_neg', 'itm_logits', 'itm_labels', 'decoder_output', 'decoder_labels']) + >>> output.intermediate_output.image_embeds.shape + >>> # shape: (batch_size, num_patches, embed_dim) + torch.Size([4, 197, 768]) + >>> output.intermediate_output.text_embeds.shape + >>> # shape: (batch_size, max_txt_len, embed_dim) + torch.Size([4, 30, 768]) + >>> output.intermediate_output.image_embeds_m.shape + >>> # shape: (batch_size, num_patches, embed_dim) + torch.Size([4, 197, 768]) + >>> output.intermediate_output.text_embeds_m.shape + >>> # shape: (batch_size, max_txt_len, embed_dim) + torch.Size([4, 30, 768]) + >>> output.intermediate_output.itm_logits.shape + >>> # shape: (batch_size * 3, 2) + torch.Size([12, 2]) + >>> output.intermediate_output.itm_labels.shape + >>> # shape: (batch_size * 3,) + torch.Size([12]) + >>> output.intermediate_output.encoder_output.last_hidden_state.shape + >>> # shape: (batch_size, max_txt_len, embed_dim) + torch.Size([4, 30, 768]) + >>> output.intermediate_output.encoder_output_m.last_hidden_state.shape + >>> # shape: (batch_size, max_txt_len, embed_dim) + torch.Size([4, 30, 768]) + >>> output.intermediate_output.decoder_output.logits.shape + >>> # shape: (batch_size, max_txt_len, vocab_size) + torch.Size([4, 30, 30524]) + >>> output.intermediate_output.decoder_labels.shape + >>> # shape: (batch_size, max_txt_len) + torch.Size([4, 30]) + """ + + image = samples["image"] + caption = samples["text_input"] + + alpha = self.alpha * self._rampup_factor( + epoch=samples["epoch"], + iters=samples["iters"], + num_iters_per_epoch=samples["num_iters_per_epoch"], + ) + + with torch.no_grad(): + self.temp.clamp_(0.001, 0.5) + + # image embeddings and features + image_embeds = self.visual_encoder.forward_features(image) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) + + text = self.tokenizer( + caption, + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + + # text embeddings and features + text_output = self.text_encoder.forward_text(text) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1) + + # get momentum features + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m(image) + image_feat_m = F.normalize( + self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1 + ) + image_feat_all = torch.cat( + [image_feat_m.t(), self.image_queue.clone().detach()], dim=1 + ) + + text_output_m = self.text_encoder_m.forward_text(text) + text_embeds_m = text_output_m.last_hidden_state + text_feat_m = F.normalize(self.text_proj_m(text_embeds_m[:, 0, :]), dim=-1) + text_feat_all = torch.cat( + [text_feat_m.t(), self.text_queue.clone().detach()], dim=1 + ) + + sim_i2t_m = image_feat_m @ text_feat_all / self.temp + sim_t2i_m = text_feat_m @ image_feat_all / self.temp + + sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) + sim_targets.fill_diagonal_(1) + + sim_i2t_targets = ( + alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets + ) + sim_t2i_targets = ( + alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets + ) + + sim_i2t = image_feat @ text_feat_all / self.temp + sim_t2i = text_feat @ image_feat_all / self.temp + + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1 + ).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1 + ).mean() + + loss_itc = (loss_i2t + loss_t2i) / 2 + + self._dequeue_and_enqueue(image_feat_m, text_feat_m) + + # Image-text Matching + encoder_input_ids = text.input_ids.clone() + encoder_input_ids[:, 0] = self.tokenizer.enc_token_id + + # forward the positve image-text pair + bs = image.size(0) + output_pos = self.text_encoder( + encoder_input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + with torch.no_grad(): + weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1) + 1e-4 + weights_t2i.fill_diagonal_(0) + weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1) + 1e-4 + weights_i2t.fill_diagonal_(0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(encoder_input_ids[neg_idx]) + text_atts_neg.append(text.attention_mask[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0) + text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0) + + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) + image_atts_all = torch.cat([image_atts, image_atts], dim=0) + + output_neg = self.text_encoder( + text_ids_all, + attention_mask=text_atts_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + ) + + vl_embeddings = torch.cat( + [ + output_pos.last_hidden_state[:, 0, :], + output_neg.last_hidden_state[:, 0, :], + ], + dim=0, + ) + itm_logits = self.itm_head(vl_embeddings) + + itm_labels = torch.cat( + [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)], + dim=0, + ).to(image.device) + loss_itm = F.cross_entropy(itm_logits, itm_labels) + + # LM + decoder_input_ids = text.input_ids.clone() + decoder_input_ids[:, 0] = self.tokenizer.bos_token_id + decoder_targets = decoder_input_ids.masked_fill( + decoder_input_ids == self.tokenizer.pad_token_id, -100 + ) + + decoder_output = self.text_decoder( + decoder_input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + labels=decoder_targets, + return_dict=True, + ) + + loss_lm = decoder_output.loss + + return BlipOutput( + loss=loss_itc + loss_itm + loss_lm, + loss_itc=loss_itc, + loss_itm=loss_itm, + loss_lm=loss_lm, + sims=BlipSimilarity( + sim_i2t=sim_i2t, + sim_t2i=sim_t2i, + sim_i2t_m=sim_i2t_m, + sim_t2i_m=sim_t2i_m, + sim_i2t_targets=sim_i2t_targets, + sim_t2i_targets=sim_t2i_targets, + ), + intermediate_output=BlipIntermediateOutput( + image_embeds=image_embeds, + text_embeds=text_embeds, + image_embeds_m=image_embeds_m, + text_embeds_m=text_embeds_m, + encoder_output=output_pos, + encoder_output_neg=output_neg, + itm_logits=itm_logits, + itm_labels=itm_labels, + decoder_output=decoder_output, + decoder_labels=decoder_targets, + ), + ) + + def reset_queue_ptr(self): + self.queue_ptr = torch.zeros(1, dtype=torch.long) + + @classmethod + def from_config(cls, cfg=None): + # set from_pretrained=True to load weights for 'bert-base-uncased' + image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=True) + text_encoder = XBertEncoder.from_config(cfg, from_pretrained=True) + text_decoder = XBertLMHeadDecoder.from_config(cfg, from_pretrained=True) + + embed_dim = cfg.get("embed_dim", 256) + momentum = cfg.get("momentum", 0.995) + alpha = cfg.get("alpha", 0.4) + max_txt_len = cfg.get("max_txt_len", 30) + queue_size = cfg.get("queue_size", 57600) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + text_decoder=text_decoder, + embed_dim=embed_dim, + queue_size=queue_size, + momentum=momentum, + alpha=alpha, + tie_enc_dec_weights=True, + max_txt_len=max_txt_len, + ) + + # [IMPORTANT] to reset queue pointer to 0. + # Otherwise when updating last batch in the queue, the batch size and remaining queue length may be un-equal. + model.reset_queue_ptr() + + return model diff --git a/lavis/models/blip_models/blip_retrieval.py b/lavis/models/blip_models/blip_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..44e9c5c998d60400c2443112f69f4be5ad415048 --- /dev/null +++ b/lavis/models/blip_models/blip_retrieval.py @@ -0,0 +1,396 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from copy import deepcopy + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.models.albef_models import compute_sim_matrix +from lavis.models.base_model import ( + MomentumDistilationMixin, + SharedQueueMixin, + all_gather_with_grad, + concat_all_gather, +) +from lavis.models.blip_models.blip import BlipBase +from lavis.models.blip_models.blip_outputs import ( + BlipOutput, + BlipSimilarity, + BlipIntermediateOutput, +) +from lavis.models.med import XBertEncoder +from lavis.models.vit import VisionTransformerEncoder +from torch import nn + + +@registry.register_model("blip_retrieval") +class BlipRetrieval(BlipBase, MomentumDistilationMixin, SharedQueueMixin): + """ + BLIP retrieval model. + + Supported model types: + - coco: fine-tuned BLIP base model on COCO dataset (Karpathy split). + - flickr: fine-tuned BLIP base model on Flickr30k dataset. + + Usage: + >>> from lavis.models import load_model + >>> model = load_model("blip_retrieval", "coco") + >>> model = load_model("blip_retrieval", "flickr") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "coco": "configs/models/blip_retrieval_coco.yaml", + "flickr": "configs/models/blip_retrieval_flickr.yaml", + } + + def __init__( + self, + image_encoder, + text_encoder, + queue_size, + alpha=0.4, + embed_dim=256, + momentum=0.995, + negative_all_rank=False, + max_txt_len=35, + ): + """ """ + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder = image_encoder + + self.text_encoder = text_encoder + + # creating projection layers for ITC + text_width = text_encoder.config.hidden_size + vision_width = image_encoder.vision_width + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + # create the momentum encoder + self.visual_encoder_m = deepcopy(self.visual_encoder) + self.text_encoder_m = deepcopy(self.text_encoder) + + self.vision_proj_m = deepcopy(self.vision_proj) + self.text_proj_m = deepcopy(self.text_proj) + + self.model_pairs = [ + [self.visual_encoder, self.visual_encoder_m], + [self.text_encoder, self.text_encoder_m], + [self.vision_proj, self.vision_proj_m], + [self.text_proj, self.text_proj_m], + ] + self.copy_params() + + # create the queue + self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("idx_queue", torch.full((1, queue_size), -100)) + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + self.image_queue = nn.functional.normalize(self.image_queue, dim=0) + self.text_queue = nn.functional.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + self.momentum = momentum + self.temp = nn.Parameter(0.07 * torch.ones([])) + + self.alpha = alpha + self.max_txt_len = max_txt_len + + self.negative_all_rank = negative_all_rank + + def _rampup_factor(self, epoch, iters, num_iters_per_epoch): + return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch)) + + def forward(self, samples): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). The input images. + - text_input (list): A list of length batch_size, each element is a string of text/caption. + - image_id (torch.Tensor): A tensor of shape (batch_size, ). The image ids, used to identify same images in batch. + - epoch (int): The current epoch. + - iters (int): The current iteration. + - num_iters_per_epoch (int): The number of iterations per epoch. + + Returns: + BlipOutput: A BlipOutput object. See ``lavis.models.blip_models.blip_outputs.BlipOutput`` for more details. + + Examples: + >>> import torch + >>> from lavis.models import load_model + >>> model = load_model("blip_retrieval", "coco") + >>> images = torch.randn(4, 3, 384, 384) + >>> text_input = ["caption of image 1", "another caption of image 1", "caption of image 2", "caption of image 3"] + >>> image_id = torch.tensor([1, 1, 2, 3]) + >>> samples = {"image": images, "text_input": text_input, "image_id": image_id, "epoch": 0, "iters": 0, "num_iters_per_epoch": 100} + >>> output = model(samples) + >>> output.keys() + odict_keys(['sims', 'intermediate_output', 'loss', 'loss_itc', 'loss_itm']) + """ + image = samples["image"] + caption = samples["text_input"] + idx = samples["image_id"] + + alpha = self.alpha * self._rampup_factor( + epoch=samples["epoch"], + iters=samples["iters"], + num_iters_per_epoch=samples["num_iters_per_epoch"], + ) + + with torch.no_grad(): + self.temp.clamp_(0.001, 0.5) + + image_embeds = self.visual_encoder.forward_features(image) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) + + text = self.tokenizer( + caption, + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + + text_output = self.text_encoder.forward_text(text) + text_embeds = text_output.last_hidden_state + text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1) + + # Image-text Contrastive Learning + idx = idx.view(-1, 1) + idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1) + pos_idx = torch.eq(idx, idx_all).float() + sim_targets = pos_idx / pos_idx.sum(1, keepdim=True) + + # get momentum features + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m(image) + image_feat_m = F.normalize( + self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1 + ) + image_feat_m_all = torch.cat( + [image_feat_m.t(), self.image_queue.clone().detach()], dim=1 + ) + + text_output_m = self.text_encoder_m.forward_text(text) + text_embeds_m = text_output_m.last_hidden_state + text_feat_m = F.normalize(self.text_proj_m(text_embeds_m[:, 0, :]), dim=-1) + text_feat_m_all = torch.cat( + [text_feat_m.t(), self.text_queue.clone().detach()], dim=1 + ) + + sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp + sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp + + sim_i2t_targets = ( + alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets + ) + sim_t2i_targets = ( + alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets + ) + + sim_i2t = image_feat @ text_feat_m_all / self.temp + sim_t2i = text_feat @ image_feat_m_all / self.temp + + loss_i2t = -torch.sum( + F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1 + ).mean() + loss_t2i = -torch.sum( + F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1 + ).mean() + + loss_itc = (loss_i2t + loss_t2i) / 2 + + self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx) + + # Image-text Matching + encoder_input_ids = text.input_ids.clone() + encoder_input_ids[:, 0] = self.tokenizer.enc_token_id + + # forward the positve image-text pair + bs = image.size(0) + output_pos = self.text_encoder( + encoder_input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + idxs = concat_all_gather(idx) + if self.negative_all_rank: + # compute sample similarity + with torch.no_grad(): + mask = torch.eq(idx, idxs.t()) + + image_feat_world = concat_all_gather(image_feat) + text_feat_world = concat_all_gather(text_feat) + + sim_i2t = image_feat @ text_feat_world.t() / self.temp + sim_t2i = text_feat @ image_feat_world.t() / self.temp + + weights_i2t = F.softmax(sim_i2t, dim=1) + weights_i2t.masked_fill_(mask, 0) + + weights_t2i = F.softmax(sim_t2i, dim=1) + weights_t2i.masked_fill_(mask, 0) + + image_embeds_world = all_gather_with_grad(image_embeds) + + # select a negative image (from all ranks) for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text (from all ranks) for each image + input_ids_world = concat_all_gather(encoder_input_ids) + att_mask_world = concat_all_gather(text.attention_mask) + + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(input_ids_world[neg_idx]) + text_atts_neg.append(att_mask_world[neg_idx]) + + else: + with torch.no_grad(): + mask = torch.eq(idx, idx.t()) + + sim_i2t = image_feat @ text_feat.t() / self.temp + sim_t2i = text_feat @ image_feat.t() / self.temp + + weights_i2t = F.softmax(sim_i2t, dim=1) + weights_i2t.masked_fill_(mask, 0) + + weights_t2i = F.softmax(sim_t2i, dim=1) + weights_t2i.masked_fill_(mask, 0) + + # select a negative image (from same rank) for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text (from same rank) for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(encoder_input_ids[neg_idx]) + text_atts_neg.append(text.attention_mask[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat([encoder_input_ids, text_ids_neg], dim=0) + text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0) + + image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0) + image_atts_all = torch.cat([image_atts, image_atts], dim=0) + + output_neg = self.text_encoder( + text_ids_all, + attention_mask=text_atts_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + ) + + vl_embeddings = torch.cat( + [ + output_pos.last_hidden_state[:, 0, :], + output_neg.last_hidden_state[:, 0, :], + ], + dim=0, + ) + itm_logits = self.itm_head(vl_embeddings) + + itm_labels = torch.cat( + [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)], + dim=0, + ).to(self.device) + loss_itm = F.cross_entropy(itm_logits, itm_labels) + + return BlipOutput( + loss=loss_itc + loss_itm, + loss_itc=loss_itc, + loss_itm=loss_itm, + sims=BlipSimilarity( + sim_i2t=sim_i2t, + sim_t2i=sim_t2i, + sim_i2t_m=sim_i2t_m, + sim_t2i_m=sim_t2i_m, + sim_i2t_targets=sim_i2t_targets, + sim_t2i_targets=sim_t2i_targets, + ), + intermediate_output=BlipIntermediateOutput( + image_embeds=image_embeds, + image_embeds_m=image_embeds_m, + text_embeds=text_embeds, + text_embeds_m=text_embeds_m, + encoder_output=output_pos, + encoder_output_neg=output_neg, + itm_logits=itm_logits, + itm_labels=itm_labels, + ), + ) + + def reset_queue_ptr(self): + self.queue_ptr = torch.zeros(1, dtype=torch.long) + + @classmethod + def from_config(cls, cfg=None): + # set from_pretrained=True to load weights for 'bert-base-uncased' + image_encoder = VisionTransformerEncoder.from_config(cfg) + text_encoder = XBertEncoder.from_config(cfg) + + embed_dim = cfg.get("embed_dim", 256) + momentum = cfg.get("momentum", 0.995) + alpha = cfg.get("alpha", 0.4) + negative_all_rank = cfg.get("negative_all_rank", False) + + queue_size = cfg.get("queue_size", 0) + max_txt_len = cfg.get("max_txt_len", 35) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + queue_size=queue_size, + alpha=alpha, + embed_dim=embed_dim, + momentum=momentum, + negative_all_rank=negative_all_rank, + max_txt_len=max_txt_len, + ) + + model.load_checkpoint_from_config(cfg) + model.reset_queue_ptr() + + return model + + def compute_sim_matrix(self, data_loader, task_cfg): + """ + Compute similarity i2t, t2i matrix for the given data loader. + """ + k_test = task_cfg.k_test + + return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test) diff --git a/lavis/models/blip_models/blip_vqa.py b/lavis/models/blip_models/blip_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..dd6e4144b8243e251d4c1c6451f88f97ef641a8b --- /dev/null +++ b/lavis/models/blip_models/blip_vqa.py @@ -0,0 +1,375 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.models.base_model import tile +from lavis.models.blip_models.blip import BlipBase +from lavis.models.blip_models.blip_outputs import ( + BlipOutput, + BlipIntermediateOutput, +) +from lavis.models.med import XBertEncoder, XBertLMHeadDecoder +from lavis.models.vit import VisionTransformerEncoder + + +@registry.register_model("blip_vqa") +class BlipVQA(BlipBase): + """ + BLIP VQA models. + + Supported model types: + - base: vqa model initialized with pre-trained BLIP base model on 115M image-text pairs after CapFilt; not fine-tuned. + - vqav2: fine-tuned BLIP base model on VQA v2.0 dataset. + + Usage: + >>> from lavis.models import load_model + >>> model = load_model("blip_vqa", "vqav2") + >>> model = load_model("blip_vqa", "okvqa") + >>> model = load_model("blip_vqa", "aokvqa") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "vqav2": "configs/models/blip_vqav2.yaml", + "okvqa": "configs/models/blip_vqa_okvqa.yaml", + "aokvqa": "configs/models/blip_vqa_aokvqa.yaml", + } + + def __init__(self, image_encoder, text_encoder, text_decoder, max_txt_len=35): + super().__init__() + self.tokenizer = self.init_tokenizer() + + self.visual_encoder = image_encoder + + self.text_encoder = text_encoder + self.text_decoder = text_decoder + + self.max_txt_len = max_txt_len + + def forward(self, samples): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480. + - text_input (list): A list of strings, each string is a question + - answer (list): A list of strings, each string is an answer + - weight (torch.Tensor): A tensor used to weigh each answer in the loss computation. + The shape of the tensor is (sum(n_answers),) + - n_answers (torch.Tensor): A tensor shape (batch_size,) containing the number of answers + for each question in the batch. + + Returns: + A BlipOutput object containing loss and intermediate outputs, + see :class:`lavis.models.blip_outputs.BlipOutput` for more details. + + Examples: + ```python + >>> import torch + >>> from lavis.models import load_model + >>> model = load_model("blip_vqa") + >>> samples = { + ... "image": torch.rand(2, 3, 480, 480), + ... "text_input": ["What is this?", "What is that?"], + ... "answer": ["cat", "cat", "dog"], + ... "weight": torch.tensor([1.0, 1.0, 1.0]), + ... "n_answers": torch.tensor([2, 1]), + ... } + >>> output = model(samples) + >>> output.keys() + odict_keys(['intermediate_output', 'loss']) + >>> output.intermediate_output.keys() + odict_keys(['image_embeds', 'encoder_output', 'decoder_output', 'decoder_labels']) + ``` + """ + encoder_output, image_embeds = self.forward_encoder(samples) + loss, decoder_output, decoder_targets = self.forward_decoder( + samples=samples, encoder_out=encoder_output + ) + + return BlipOutput( + loss=loss, + intermediate_output=BlipIntermediateOutput( + image_embeds=image_embeds, + encoder_output=encoder_output, + decoder_output=decoder_output, + decoder_labels=decoder_targets, + ), + ) + + def forward_encoder(self, samples): + questions = samples["text_input"] + questions = self.tokenizer( + questions, + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(self.device) + questions.input_ids[:, 0] = self.tokenizer.enc_token_id + samples.update({"tokenized_text": questions}) + + image_embeds = self.visual_encoder.forward_features(samples["image"]) + encoder_output = self.text_encoder.forward_automask( + tokenized_text=samples["tokenized_text"], visual_embeds=image_embeds + ) + + return encoder_output, image_embeds + + def forward_decoder(self, samples, encoder_out, **kwargs): + answers = self.tokenizer( + samples["answer"], padding="longest", return_tensors="pt" + ).to(self.device) + answers.input_ids[:, 0] = self.tokenizer.bos_token_id + answer_targets = answers.input_ids.masked_fill( + answers.input_ids == self.tokenizer.pad_token_id, -100 + ) + + question_states = [] + question_atts = [] + + question = samples["tokenized_text"] + question_output = encoder_out + + for b, n in enumerate(samples["n_answers"]): + question_states += [question_output.last_hidden_state[b]] * n + question_atts += [question.attention_mask[b]] * n + + question_states = torch.stack(question_states, dim=0) + question_atts = torch.stack(question_atts, dim=0) + + answer_output = self.text_decoder( + answers.input_ids, + attention_mask=answers.attention_mask, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=answer_targets, + return_dict=True, + reduction="none", + ) + + loss = samples["weight"] * answer_output.loss + bsz = samples["image"].size(0) + + loss = loss.sum() / bsz + + return loss, answer_output, answer_targets + + def predict_answers( + self, + samples, + num_beams=3, + inference_method="rank", + max_len=10, + min_len=1, + num_ans_candidates=128, + answer_list=None, + **kwargs + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480. + - text_input (str or [str]): String or a list of strings, each string is a question. + The number of questions must be equal to the batch size. If a single string, will be converted to a list of string, with length 1 first. + num_beams (int): Number of beams for beam search. 1 means no beam search. + inference_method (str): Inference method. One of "rank", "generate". + - If "rank", the model will return answers with the highest probability from the answer list. + - If "generate", the model will generate answers. + max_len (int): Maximum length of generated answers. + min_len (int): Minimum length of generated answers. + num_ans_candidates (int): Number of answer candidates, used to filter out answers with low probability. + answer_list (list): A list of strings, each string is an answer. + + Returns: + List: A list of strings, each string is an answer. + + Examples: + ```python + >>> from PIL import Image + >>> from lavis.models import load_model_and_preprocess + >>> model, vis_processors, txt_processors = load_model_and_preprocess("blip_vqa", "vqav2") + >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB") + >>> question = "Which city is this photo taken?" + >>> image = vis_processors["eval"](raw_image).unsqueeze(0) + >>> question = txt_processors["eval"](question) + >>> samples = {"image": image, "text_input": [question]} + >>> answers = model.predict_answers(samples) + >>> answers + ['singapore'] + >>> answer_list = ["Singapore", "London", "Palo Alto", "Tokyo"] + >>> answers = model.predict_answers(samples, answer_list=answer_list) + >>> answers + ['Singapore'] + ``` + """ + assert inference_method in [ + "rank", + "generate", + ], "Inference method must be one of 'rank' or 'generate', got {}.".format( + inference_method + ) + + if isinstance(samples["text_input"], str): + samples["text_input"] = [samples["text_input"]] + + assert len(samples["text_input"]) == samples["image"].size( + 0 + ), "The number of questions must be equal to the batch size." + + if inference_method == "generate": + return self._generate_answers( + samples, num_beams=num_beams, max_length=max_len, min_length=min_len + ) + elif inference_method == "rank": + assert answer_list is not None, "answer_list must be provided for ranking" + + num_ans_candidates = min(num_ans_candidates, len(answer_list)) + + return self._rank_answers( + samples, answer_list=answer_list, num_ans_candidates=num_ans_candidates + ) + + def _generate_answers(self, samples, num_beams=3, max_length=10, min_length=1): + encoder_out, _ = self.forward_encoder(samples) + + question_output = encoder_out + + question_states = question_output.last_hidden_state.repeat_interleave( + num_beams, dim=0 + ) + question_atts = torch.ones(question_states.size()[:-1], dtype=torch.long).to( + self.device + ) + + model_kwargs = { + "encoder_hidden_states": question_states, + "encoder_attention_mask": question_atts, + } + + bsz = samples["image"].size(0) + bos_ids = torch.full( + (bsz, 1), fill_value=self.tokenizer.bos_token_id, device=self.device + ) + + outputs = self.text_decoder.generate( + input_ids=bos_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + **model_kwargs + ) + + # collect answers + answers = [] + for output in outputs: + answer = self.tokenizer.decode(output, skip_special_tokens=True) + answers.append(answer) + + return answers + + def _rank_answers(self, samples, answer_list, num_ans_candidates): + """ + Generate the first token of answers using decoder and select ${num_ans_candidates} + most probable ones. Then select answers from answer list, which start with the probable tokens. + Lastly, use the selected answers as the ground-truth labels for decoding and calculating LM loss. + Return the answers that minimize the losses as result. + + """ + answer_candidates = self.tokenizer( + answer_list, padding="longest", return_tensors="pt" + ).to(self.device) + answer_candidates.input_ids[:, 0] = self.tokenizer.bos_token_id + + answer_ids = answer_candidates.input_ids + answer_atts = answer_candidates.attention_mask + + question_output, _ = self.forward_encoder(samples) + question_states = question_output.last_hidden_state + + tokenized_question = samples["tokenized_text"] + question_atts = tokenized_question.attention_mask + + num_ques = question_states.size(0) + start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token + + start_output = self.text_decoder( + start_ids, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + return_dict=True, + reduction="none", + ) + logits = start_output.logits[:, 0, :] # first token's logit + + # topk_probs: top-k probability + # topk_ids: [num_question, k] + answer_first_token = answer_ids[:, 1] + prob_first_token = F.softmax(logits, dim=1).index_select( + dim=1, index=answer_first_token + ) + topk_probs, topk_ids = prob_first_token.topk(num_ans_candidates, dim=1) + + # answer input: [num_question*k, answer_len] + input_ids = [] + input_atts = [] + for b, topk_id in enumerate(topk_ids): + input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) + input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) + input_ids = torch.cat(input_ids, dim=0) + input_atts = torch.cat(input_atts, dim=0) + + targets_ids = input_ids.masked_fill( + input_ids == self.tokenizer.pad_token_id, -100 + ) + + # repeat encoder's output for top-k answers + question_states = tile(question_states, 0, num_ans_candidates) + question_atts = tile(question_atts, 0, num_ans_candidates) + + output = self.text_decoder( + input_ids, + attention_mask=input_atts, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=targets_ids, + return_dict=True, + reduction="none", + ) + + log_probs_sum = -output.loss + log_probs_sum = log_probs_sum.view(num_ques, num_ans_candidates) + + max_topk_ids = log_probs_sum.argmax(dim=1) + max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids] + + answers = [answer_list[max_id] for max_id in max_ids] + + return answers + + @classmethod + def from_config(cls, cfg=None): + image_encoder = VisionTransformerEncoder.from_config(cfg) + + # text encoder + multimodal encoder + text_encoder = XBertEncoder.from_config(cfg) + text_decoder = XBertLMHeadDecoder.from_config(cfg) + + max_txt_len = cfg.get("max_txt_len", 35) + + model = cls( + image_encoder=image_encoder, + text_encoder=text_encoder, + text_decoder=text_decoder, + max_txt_len=max_txt_len, + ) + + model.load_checkpoint_from_config(cfg) + + return model diff --git a/lavis/models/blip_models/nlvr_encoder.py b/lavis/models/blip_models/nlvr_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2b12b1c34c1c5d5a5acc43b9dc3f26ef876515c2 --- /dev/null +++ b/lavis/models/blip_models/nlvr_encoder.py @@ -0,0 +1,960 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import math +from typing import Tuple + +import torch +import torch.utils.checkpoint +from torch import Tensor, device, nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.models.bert.configuration_bert import BertConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config, twin=False, merge=False): + super().__init__() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if twin: + self.dense0 = nn.Linear(config.hidden_size, config.hidden_size) + self.dense1 = nn.Linear(config.hidden_size, config.hidden_size) + else: + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if merge: + self.act = ACT2FN[config.hidden_act] + self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.merge = True + else: + self.merge = False + + def forward(self, hidden_states, input_tensor): + if type(hidden_states) == list: + hidden_states0 = self.dense0(hidden_states[0]) + hidden_states1 = self.dense1(hidden_states[1]) + if self.merge: + # hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1))) + hidden_states = self.merge_layer( + torch.cat([hidden_states0, hidden_states1], dim=-1) + ) + else: + hidden_states = (hidden_states0 + hidden_states1) / 2 + else: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_num=-1): + super().__init__() + if is_cross_attention: + self.self0 = BertSelfAttention(config, is_cross_attention) + self.self1 = BertSelfAttention(config, is_cross_attention) + else: + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput( + config, + twin=is_cross_attention, + merge=(is_cross_attention and layer_num >= 6), + ) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + if type(encoder_hidden_states) == list: + self_outputs0 = self.self0( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[0], + encoder_attention_mask[0], + past_key_value, + output_attentions, + ) + self_outputs1 = self.self1( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[1], + encoder_attention_mask[1], + past_key_value, + output_attentions, + ) + attention_output = self.output( + [self_outputs0[0], self_outputs1[0]], hidden_states + ) + + outputs = (attention_output,) + self_outputs0[ + 1: + ] # add attentions if we output them + else: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention( + config, + is_cross_attention=self.config.add_cross_attention, + layer_num=layer_num, + ) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode == "multimodal": + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode="multimodal", + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode="multimodal", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + "You have to specify either input_ids or inputs_embeds or encoder_embeds" + ) + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) diff --git a/lavis/models/clip_models/__init__.py b/lavis/models/clip_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..325e25255550a00fdd082deb82a8a0da567cadb0 --- /dev/null +++ b/lavis/models/clip_models/__init__.py @@ -0,0 +1,14 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/mlfoundations/open_clip +""" + +""" OpenAI pretrained model functions +Adapted from https://github.com/mlfoundations/open_clip and https://github.com/openai/CLIP. + +Originally MIT License, Copyright (c) 2021 OpenAI. +""" diff --git a/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz b/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/lavis/models/clip_models/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/lavis/models/clip_models/clip_outputs.py b/lavis/models/clip_models/clip_outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..3a7bb032e01189d923c4e78b63bec94138d481f7 --- /dev/null +++ b/lavis/models/clip_models/clip_outputs.py @@ -0,0 +1,43 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/mlfoundations/open_clip +""" + +from dataclasses import dataclass + +from typing import Optional + +import torch +from transformers.modeling_outputs import ModelOutput + + +@dataclass +class ClipOutputFeatures(ModelOutput): + """ + Data class of features from AlbefFeatureExtractor. + + Args: + image_embeds: `torch.FloatTensor` of shape `(batch_size, 1, embed_dim)`, `optional` + image_features: `torch.FloatTensor` of shape `(batch_size, 1, feature_dim)`, `optional` + text_embeds: `torch.FloatTensor` of shape `(batch_size, 1, embed_dim)`, `optional` + text_features: `torch.FloatTensor` of shape `(batch_size, 1, feature_dim)`, `optional` + """ + + image_embeds: Optional[torch.FloatTensor] = None + image_embeds_proj: Optional[torch.FloatTensor] = None + + text_embeds: Optional[torch.FloatTensor] = None + text_embeds_proj: Optional[torch.FloatTensor] = None + + +@dataclass +class ClipOutput(ModelOutput): + intermediate_output: Optional[ClipOutputFeatures] = None + + logit_scale_exp: Optional[torch.FloatTensor] = None + + loss: Optional[torch.FloatTensor] = None diff --git a/lavis/models/clip_models/loss.py b/lavis/models/clip_models/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..da92413b1a26df994eb48c714a4c03be6c409fcf --- /dev/null +++ b/lavis/models/clip_models/loss.py @@ -0,0 +1,141 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import torch +import torch.distributed.nn +from torch import distributed as dist, nn as nn +from torch.nn import functional as F + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def gather_features( + image_features, + text_features, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False, +): + if use_horovod: + assert hvd is not None, "Please install horovod" + if gather_with_grad: + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + else: + with torch.no_grad(): + all_image_features = hvd.allgather(image_features) + all_text_features = hvd.allgather(text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features = list( + all_image_features.chunk(world_size, dim=0) + ) + gathered_text_features = list( + all_text_features.chunk(world_size, dim=0) + ) + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_image_features = torch.cat( + torch.distributed.nn.all_gather(image_features), dim=0 + ) + all_text_features = torch.cat( + torch.distributed.nn.all_gather(text_features), dim=0 + ) + else: + gathered_image_features = [ + torch.zeros_like(image_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_image_features, image_features) + dist.all_gather(gathered_text_features, text_features) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_image_features[rank] = image_features + gathered_text_features[rank] = text_features + all_image_features = torch.cat(gathered_image_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + + return all_image_features, all_text_features + + +class ClipLoss(nn.Module): + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward(self, image_features, text_features, logit_scale): + device = image_features.device + if self.world_size > 1: + all_image_features, all_text_features = gather_features( + image_features, + text_features, + self.local_loss, + self.gather_with_grad, + self.rank, + self.world_size, + self.use_horovod, + ) + + if self.local_loss: + logits_per_image = logit_scale * image_features @ all_text_features.T + logits_per_text = logit_scale * text_features @ all_image_features.T + else: + logits_per_image = ( + logit_scale * all_image_features @ all_text_features.T + ) + logits_per_text = logits_per_image.T + else: + logits_per_image = logit_scale * image_features @ text_features.T + logits_per_text = logit_scale * text_features @ image_features.T + + # calculated ground-truth and cache if enabled + num_logits = logits_per_image.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + total_loss = ( + F.cross_entropy(logits_per_image, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + return total_loss diff --git a/lavis/models/clip_models/model.py b/lavis/models/clip_models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3d5651848c6935e584abab1f9ecaad873b5392 --- /dev/null +++ b/lavis/models/clip_models/model.py @@ -0,0 +1,1254 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/mlfoundations/open_clip +""" + +""" CLIP Model +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import datetime +import json +import logging +import os +import re +import time +import warnings +from collections import OrderedDict +from copy import deepcopy +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from lavis.common.registry import registry +from lavis.common.utils import get_abs_path +from lavis.models.base_model import BaseModel +from lavis.models.clip_models.clip_outputs import ClipOutput, ClipOutputFeatures +from lavis.models.clip_models.timm_model import TimmModel +from lavis.models.clip_models.transform import image_transform +from lavis.models.clip_models.utils import freeze_batch_norm_2d +from lavis.tasks.multimodal_classification import MultimodalClassificationTask +from torch import nn + +from .pretrained import ( + download_pretrained, + get_pretrained_url, + list_pretrained_tag_models, +) + +_MODEL_CONFIG_PATHS = [Path(__file__).parent.parent.parent / f"configs/models/clip/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( + 2, 0, 1 + ) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + def stem(self, x): + for conv, bn in [ + (self.conv1, self.bn1), + (self.conv2, self.bn2), + (self.conv3, self.bn3), + ]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock(width, heads, act_layer=act_layer) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisualTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + act_layer: Callable = nn.GELU, + ): + super().__init__() + self.image_size = image_size + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn((image_size // patch_size) ** 2 + 1, width) + ) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads, act_layer=act_layer) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + timm_model_name: str = ( + None # a valid model name overrides layers, width, patch_size + ) + timm_model_pretrained: bool = ( + False # use (imagenet) pretrained weights for named model + ) + timm_pool: str = ( + "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + ) + timm_proj: str = ( + "linear" # linear projection for timm model output ('linear', 'mlp', '') + ) + + +@dataclass +class CLIPTextCfg: + context_length: int + vocab_size: int + width: int + heads: int + layers: int + + +@registry.register_model("clip") +@registry.register_model("clip_feature_extractor") +class CLIP(BaseModel): + PRETRAINED_MODEL_CONFIG_DICT = { + "ViT-B-32": "configs/models/clip_vit_base32.yaml", + "ViT-B-16": "configs/models/clip_vit_base16.yaml", + "ViT-L-14": "configs/models/clip_vit_large14.yaml", + "ViT-L-14-336": "configs/models/clip_vit_large14_336.yaml", + "RN50": "configs/models/clip_resnet50.yaml", + } + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + ): + from .tokenizer import tokenize + + super().__init__() + + self.tokenizer = tokenize + self._loss = None + + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + self.context_length = text_cfg.context_length + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + self.visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + act_layer = ( + nn.GELU + ) # so that text transformer doesn't use QuickGELU w/ timm models + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // 64 + self.visual = VisualTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + output_dim=embed_dim, + act_layer=act_layer, + ) + + self.transformer = Transformer( + width=text_cfg.width, + layers=text_cfg.layers, + heads=text_cfg.heads, + act_layer=act_layer, + ) + + self.vocab_size = text_cfg.vocab_size + self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, text_cfg.width) + ) + self.ln_final = LayerNorm(text_cfg.width) + + self.text_projection = nn.Parameter(torch.empty(text_cfg.width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) + + self.prompt_templates = openai_imagenet_template + self.classifier = None + + self.init_parameters() + + @property + def loss(self): + if self._loss is None: + from lavis.models.clip_models.loss import ClipLoss + from torch import distributed as dist + + self._loss = ClipLoss( + world_size=dist.get_world_size(), + rank=dist.get_rank(), + local_loss=False, + gather_with_grad=False, + use_horovod=False, + ) + + return self._loss + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + nn.init.constant_(self.logit_scale, np.log(1 / 0.07)) + + if hasattr(self.visual, "init_parameters"): + self.visual.init_parameters() + + proj_std = (self.transformer.width**-0.5) * ( + (2 * self.transformer.layers) ** -0.5 + ) + attn_std = self.transformer.width**-0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock( + unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats + ) + + def encode_image(self, image): + return self.visual(image) + + def encode_text(self, text): + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + # def forward(self, image, text): + def forward(self, samples): + image = samples.get("image") + text = samples.get("text_input") + + if text is not None: + text = self.tokenizer(text).to(self.device) + + if image is None: + return self.encode_text(text) + elif text is None: + return self.encode_image(image) + image_embeds = self.encode_image(image) + image_features = F.normalize(image_embeds, dim=-1) + + text_embeds = self.encode_text(text) + text_features = F.normalize(text_embeds, dim=-1) + + loss = self.loss(image_features, text_features, self.logit_scale.exp()) + + # return image_features, text_features, self.logit_scale.exp() + # return {"loss": loss} + return ClipOutput( + intermediate_output=ClipOutputFeatures( + image_embeds=image_embeds, + image_embeds_proj=image_features, + text_embeds=text_embeds, + text_embeds_proj=text_features, + ), + loss=loss, + logit_scale_exp=self.logit_scale.exp(), + ) + + def extract_features(self, samples): + """ + Extract features from the model for samples. + + Keys allowed are "image" and "text_input" in samples. + If either key is missing, the corresponding features are not extracted. + + Args: + samples: dict of samples to extract features from. + + Returns: + ClipOutputFeatures object with features for the samples. + """ + image = samples.get("image") + text = samples.get("text_input") + + if text is not None: + text = self.tokenizer(text).to(self.device) + + if image is None: + return self.encode_text(text) + elif text is None: + return self.encode_image(image) + + image_embeds = self.encode_image(image) + image_features = F.normalize(image_embeds, dim=-1) + + text_embeds = self.encode_text(text) + text_features = F.normalize(text_embeds, dim=-1) + + return ClipOutputFeatures( + image_embeds=image_embeds, + image_embeds_proj=image_features, + text_embeds=text_embeds, + text_embeds_proj=text_features, + ) + + def predict(self, samples): + image = samples["image"] + targets = samples["label"] + + image_features = self.encode_image(image) + image_features = F.normalize(image_features, dim=-1) + + logits = 100.0 * image_features @ self.classifier + + return {"predictions": logits, "targets": targets} + + def before_evaluation(self, dataset, task_type, **kwargs): + if task_type == MultimodalClassificationTask: + self.classifier = self.zero_shot_classifier( + classnames=dataset.classnames, + templates=self.prompt_templates, + ) + + def zero_shot_classifier(self, classnames, templates): + with torch.no_grad(): + zeroshot_weights = [] + for classname in classnames: + texts = [ + template(classname) for template in templates + ] # format with class + texts = self.tokenizer(texts).to(self.device) # tokenize + + class_embeddings = self.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device) + return zeroshot_weights + + @classmethod + def default_config_path(cls, model_type="base"): + model_type = "ViT-B-32" if model_type == "base" else model_type + + assert ( + model_type in cls.PRETRAINED_MODEL_CONFIG_DICT + ), "Unknown model type {}. \n Available types: {}".format( + model_type, cls.PRETRAINED_MODEL_CONFIG_DICT.keys() + ) + return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type]) + + @classmethod + def from_config(cls, cfg=None): + model_name = cfg.model_type + pretrained = cfg.pretrained + + precision = cfg.get("precision", "fp32") + + return create_model( + model_name=model_name, pretrained=pretrained, precision=precision + ) + + def zero_shot_predict(self, image_path, categories): + assert isinstance( + categories, list + ), f"categories must be a list, got {type(categories)}." + assert os.path.exists(image_path), f"File {image_path} does not exist." + + from lavis.processors.clip_processors import ClipImageEvalProcessor + from PIL import Image + + image_preprocess = ClipImageEvalProcessor() + image = image_preprocess(Image.open(image_path)).unsqueeze(0) + + text = self.tokenizer(categories) + + with torch.no_grad(): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + image_features /= image_features.norm(dim=-1, keepdim=True) + text_features /= text_features.norm(dim=-1, keepdim=True) + + text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) + + print("Label probs:", text_probs) # prints: [[1., 0., 0.]] + + def compute_sim_matrix(self, data_loader, **kwargs): + logging.info("Computing features for evaluation...") + start_time = time.time() + + texts = data_loader.dataset.text + num_text = len(texts) + text_bs = 256 + text_features = [] + + for i in range(0, num_text, text_bs): + + text = texts[i : min(num_text, i + text_bs)] + text_input = self.tokenizer(text).to(self.device) + + text_feat = self.encode_text(text_input) + text_feat = F.normalize(text_feat, dim=-1) + + text_features.append(text_feat) + + text_features = torch.cat(text_features, dim=0) + + image_features = [] + for samples in data_loader: + image = samples["image"] + + image = image.to(self.device) + image_feat = self.encode_image(image) + image_feat = F.normalize(image_feat, dim=-1) + + image_features.append(image_feat) + + image_features = torch.cat(image_features, dim=0) + + sims_matrix_i2t = image_features @ text_features.t() + sims_matrix_t2i = sims_matrix_i2t.t() + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Evaluation time {}".format(total_time_str)) + + return sims_matrix_i2t.cpu().numpy(), sims_matrix_t2i.cpu().numpy() + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [ + *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], + "in_proj_bias", + "bias_k", + "bias_v", + ]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model_from_openai_state_dict(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [ + k + for k in state_dict.keys() + if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") + ] + ) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round( + (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 + ) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"visual.layer{b}") + ) + ) + for b in [1, 2, 3, 4] + ] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round( + (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 + ) + vision_patch_size = None + assert ( + output_width**2 + 1 + == state_dict["visual.attnpool.positional_embedding"].shape[0] + ) + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"transformer.resblocks") + ) + ) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=True, # OpenAI models were trained with QuickGELU + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + convert_weights_to_fp16(model) + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device("cpu")): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros( + (batch_size, model.context_length), dtype=torch.int, device=device + ) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,), + ), + ) + model.visual.image_size = image_size + return + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = (".json",) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f"*{ext}")) + + for cf in config_files: + with open(cf, "r") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ("embed_dim", "vision_cfg", "text_cfg")): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = { + k: v + for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) + } + + +_rescan_model_configs() # initial populate of model config registry + + +def load_state_dict(checkpoint_path: str, map_location="cpu"): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith("module"): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + + +def create_model( + model_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + pretrained_image: bool = False, +): + model_name = model_name.replace( + "/", "-" + ) # for callers using old naming with / in ViT names + + if pretrained.lower() == "openai": + logging.info(f"Loading pretrained {model_name} from OpenAI.") + model = load_openai_model(model_name, device=device, jit=jit) + # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 + if precision == "amp" or precision == "fp32": + model = model.float() + else: + logging.info(f"No pretrained weights loaded for {model_name} model.") + if model_name in _MODEL_CONFIGS: + logging.info(f"Loading {model_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[model_name]) + else: + logging.error( + f"Model config for {model_name} not found; available models {list_models()}." + ) + raise RuntimeError(f"Model config for {model_name} not found.") + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if pretrained_image: + if "timm_model_name" in model_cfg.get("vision_cfg", {}): + # pretrained weight loading for timm models set via vision_cfg + model_cfg["vision_cfg"]["timm_model_pretrained"] = True + else: + assert ( + False + ), "pretrained image towers currently only supported for timm models" + + model = CLIP(**model_cfg) + + if pretrained: + checkpoint_path = "" + url = get_pretrained_url(model_name, pretrained) + if url: + checkpoint_path = download_pretrained(url) + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f"Loading pretrained {model_name} weights ({pretrained}).") + model.load_state_dict(load_state_dict(checkpoint_path)) + else: + logging.warning( + f"Pretrained weights ({pretrained}) not found for model {model_name}." + ) + raise RuntimeError( + f"Pretrained weights ({pretrained}) not found for model {model_name}." + ) + + model.to(device=device) + if precision == "fp16": + assert device.type != "cpu" + convert_weights_to_fp16(model) + + if jit: + model = torch.jit.script(model) + + return model + + +def create_model_and_transforms( + model_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + pretrained_image: bool = False, +): + model = create_model( + model_name, + pretrained, + precision, + device, + jit, + force_quick_gelu=force_quick_gelu, + pretrained_image=pretrained_image, + ) + preprocess_train = image_transform(model.visual.image_size, is_train=True) + preprocess_val = image_transform(model.visual.image_size, is_train=False) + return model, preprocess_train, preprocess_val + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_tag_models("openai") + + +def load_openai_model( + name: str, + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit=True, +): + """Load a CLIP model + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + Returns + ------- + model : torch.nn.Module + The CLIP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if get_pretrained_url(name, "openai"): + model_path = download_pretrained(get_pretrained_url(name, "openai")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError( + f"Model {name} not found; available models = {list_openai_models()}" + ) + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn( + f"File {model_path} is not a JIT archive. Loading as a state dict instead" + ) + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + try: + model = build_model_from_openai_state_dict( + state_dict or model.state_dict() + ).to(device) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict(sd).to(device) + + if str(device) == "cpu": + model.float() + return model + + # patch the device names + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] + ) + device_node = [ + n + for n in device_holder.graph.findAllNodes("prim::Constant") + if "Device" in repr(n) + ][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith( + "cuda" + ): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[] + ) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [ + 1, + 2, + ]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + model.float() + + # ensure image_size attr available at consistent location for both jit and non-jit + model.visual.image_size = model.input_resolution.item() + return model + + +openai_imagenet_template = [ + lambda c: f"a bad photo of a {c}.", + lambda c: f"a photo of many {c}.", + lambda c: f"a sculpture of a {c}.", + lambda c: f"a photo of the hard to see {c}.", + lambda c: f"a low resolution photo of the {c}.", + lambda c: f"a rendering of a {c}.", + lambda c: f"graffiti of a {c}.", + lambda c: f"a bad photo of the {c}.", + lambda c: f"a cropped photo of the {c}.", + lambda c: f"a tattoo of a {c}.", + lambda c: f"the embroidered {c}.", + lambda c: f"a photo of a hard to see {c}.", + lambda c: f"a bright photo of a {c}.", + lambda c: f"a photo of a clean {c}.", + lambda c: f"a photo of a dirty {c}.", + lambda c: f"a dark photo of the {c}.", + lambda c: f"a drawing of a {c}.", + lambda c: f"a photo of my {c}.", + lambda c: f"the plastic {c}.", + lambda c: f"a photo of the cool {c}.", + lambda c: f"a close-up photo of a {c}.", + lambda c: f"a black and white photo of the {c}.", + lambda c: f"a painting of the {c}.", + lambda c: f"a painting of a {c}.", + lambda c: f"a pixelated photo of the {c}.", + lambda c: f"a sculpture of the {c}.", + lambda c: f"a bright photo of the {c}.", + lambda c: f"a cropped photo of a {c}.", + lambda c: f"a plastic {c}.", + lambda c: f"a photo of the dirty {c}.", + lambda c: f"a jpeg corrupted photo of a {c}.", + lambda c: f"a blurry photo of the {c}.", + lambda c: f"a photo of the {c}.", + lambda c: f"a good photo of the {c}.", + lambda c: f"a rendering of the {c}.", + lambda c: f"a {c} in a video game.", + lambda c: f"a photo of one {c}.", + lambda c: f"a doodle of a {c}.", + lambda c: f"a close-up photo of the {c}.", + lambda c: f"a photo of a {c}.", + lambda c: f"the origami {c}.", + lambda c: f"the {c} in a video game.", + lambda c: f"a sketch of a {c}.", + lambda c: f"a doodle of the {c}.", + lambda c: f"a origami {c}.", + lambda c: f"a low resolution photo of a {c}.", + lambda c: f"the toy {c}.", + lambda c: f"a rendition of the {c}.", + lambda c: f"a photo of the clean {c}.", + lambda c: f"a photo of a large {c}.", + lambda c: f"a rendition of a {c}.", + lambda c: f"a photo of a nice {c}.", + lambda c: f"a photo of a weird {c}.", + lambda c: f"a blurry photo of a {c}.", + lambda c: f"a cartoon {c}.", + lambda c: f"art of a {c}.", + lambda c: f"a sketch of the {c}.", + lambda c: f"a embroidered {c}.", + lambda c: f"a pixelated photo of a {c}.", + lambda c: f"itap of the {c}.", + lambda c: f"a jpeg corrupted photo of the {c}.", + lambda c: f"a good photo of a {c}.", + lambda c: f"a plushie {c}.", + lambda c: f"a photo of the nice {c}.", + lambda c: f"a photo of the small {c}.", + lambda c: f"a photo of the weird {c}.", + lambda c: f"the cartoon {c}.", + lambda c: f"art of the {c}.", + lambda c: f"a drawing of the {c}.", + lambda c: f"a photo of the large {c}.", + lambda c: f"a black and white photo of a {c}.", + lambda c: f"the plushie {c}.", + lambda c: f"a dark photo of a {c}.", + lambda c: f"itap of a {c}.", + lambda c: f"graffiti of the {c}.", + lambda c: f"a toy {c}.", + lambda c: f"itap of my {c}.", + lambda c: f"a photo of a cool {c}.", + lambda c: f"a photo of a small {c}.", + lambda c: f"a tattoo of the {c}.", +] diff --git a/lavis/models/clip_models/pics/CLIP.png b/lavis/models/clip_models/pics/CLIP.png new file mode 100644 index 0000000000000000000000000000000000000000..a1b5ec9171fd7a51e36e845a02304eb837142ba1 Binary files /dev/null and b/lavis/models/clip_models/pics/CLIP.png differ diff --git a/lavis/models/clip_models/pretrained.py b/lavis/models/clip_models/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d9834952263a0cd19c775d2576628e4ee580cd --- /dev/null +++ b/lavis/models/clip_models/pretrained.py @@ -0,0 +1,182 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/mlfoundations/open_clip +""" + +import hashlib +import os +import urllib +import warnings + +from tqdm import tqdm + +_RN50 = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", +) + +_RN50_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", +) + +_RN101 = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", +) + +_RN101_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", +) + +_RN50x4 = dict( + openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", +) + +_RN50x16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", +) + +_RN50x64 = dict( + openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", +) + +_VITB32 = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB32_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", +) + +_VITL14 = dict( + openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", +) + +_VITL14_336 = dict( + openai="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt" +) + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, +} + + +def list_pretrained(as_str: bool = False): + """returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [ + ":".join([k, t]) if as_str else (k, t) + for k in _PRETRAINED.keys() + for t in _PRETRAINED[k].keys() + ] + + +def list_pretrained_tag_models(tag: str): + """return all models having the specified pretrain tag""" + models = [] + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_model_tags(model: str): + """return all pretrain tags for the specified model architecture""" + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def get_pretrained_url(model: str, tag: str): + if model not in _PRETRAINED: + return "" + model_pretrained = _PRETRAINED[model] + tag = tag.lower() + if tag not in model_pretrained: + return "" + return model_pretrained[tag] + + +def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + if "openaipublic" in url: + expected_sha256 = url.split("/")[-2] + else: + expected_sha256 = "" + + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if ( + hashlib.sha256(open(download_target, "rb").read()).hexdigest() + == expected_sha256 + ): + return download_target + else: + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if ( + expected_sha256 + and hashlib.sha256(open(download_target, "rb").read()).hexdigest() + != expected_sha256 + ): + raise RuntimeError( + f"Model has been downloaded but the SHA256 checksum does not not match" + ) + + return download_target diff --git a/lavis/models/clip_models/timm_model.py b/lavis/models/clip_models/timm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..13bf04fc23e2691902f8b7da67ac99d19a696116 --- /dev/null +++ b/lavis/models/clip_models/timm_model.py @@ -0,0 +1,561 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/mlfoundations/open_clip +""" + +""" timm model adapter +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +import math +import warnings +from collections import OrderedDict +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + + # from timm.models.layers.attention_pool2d import RotAttentionPool2d + # from timm.models.layers.attention_pool2d import ( + # AttentionPool2d as AbsAttentionPool2d, + # ) + +except ImportError as e: + timm = None + +from lavis.models.clip_models.utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool="avg", + proj="linear", + drop=0.0, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + self.trunk = timm.create_model(model_name, pretrained=pretrained) + feat_size = self.trunk.default_cfg.get("pool_size", None) + feature_ndim = 1 if not feat_size else 2 + if pool in ("abs_attn", "rot_attn"): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool="") + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == "abs_attn": + head_layers["pool"] = AttentionPool2d( + prev_chs, feat_size=feat_size, out_features=embed_dim + ) + prev_chs = embed_dim + elif pool == "rot_attn": + head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, "projection layer needed if non-attention pooling is used." + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == "linear": + head_layers["drop"] = nn.Dropout(drop) + head_layers["proj"] = nn.Linear(prev_chs, embed_dim) + elif proj == "mlp": + head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_modules, group_parameters + except ImportError: + raise RuntimeError( + "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" + ) + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x + + +class RotAttentionPool2d(nn.Module): + """Attention based 2D feature pooling w/ rotary (relative) pos embedding. + This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. + Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed. + https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from + train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW + """ + + def __init__( + self, + in_features: int, + out_features: int = None, + embed_dim: int = None, + num_heads: int = 4, + qkv_bias: bool = True, + ): + super().__init__() + embed_dim = embed_dim or in_features + out_features = out_features or in_features + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, out_features) + self.num_heads = num_heads + assert embed_dim % num_heads == 0 + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim**-0.5 + self.pos_embed = RotaryEmbedding(self.head_dim) + + trunc_normal_(self.qkv.weight, std=in_features**-0.5) + nn.init.zeros_(self.qkv.bias) + + def forward(self, x): + B, _, H, W = x.shape + N = H * W + x = x.reshape(B, -1, N).permute(0, 2, 1) + + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + + x = ( + self.qkv(x) + .reshape(B, N + 1, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = x[0], x[1], x[2] + + qc, q = q[:, :, :1], q[:, :, 1:] + sin_emb, cos_emb = self.pos_embed.get_embed((H, W)) + q = apply_rot_embed(q, sin_emb, cos_emb) + q = torch.cat([qc, q], dim=2) + + kc, k = k[:, :, :1], k[:, :, 1:] + k = apply_rot_embed(k, sin_emb, cos_emb) + k = torch.cat([kc, k], dim=2) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + x = self.proj(x) + return x[:, 0] + + +class AttentionPool2d(nn.Module): + """Attention based 2D feature pooling w/ learned (absolute) pos embedding. + This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. + It was based on impl in CLIP by OpenAI + https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py + NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network. + """ + + def __init__( + self, + in_features: int, + feat_size: Union[int, Tuple[int, int]], + out_features: int = None, + embed_dim: int = None, + num_heads: int = 4, + qkv_bias: bool = True, + ): + super().__init__() + + embed_dim = embed_dim or in_features + out_features = out_features or in_features + assert embed_dim % num_heads == 0 + self.feat_size = to_2tuple(feat_size) + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, out_features) + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.scale = self.head_dim**-0.5 + + spatial_dim = self.feat_size[0] * self.feat_size[1] + self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features)) + trunc_normal_(self.pos_embed, std=in_features**-0.5) + trunc_normal_(self.qkv.weight, std=in_features**-0.5) + nn.init.zeros_(self.qkv.bias) + + def forward(self, x): + B, _, H, W = x.shape + N = H * W + assert self.feat_size[0] == H + assert self.feat_size[1] == W + x = x.reshape(B, -1, N).permute(0, 2, 1) + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + x = x + self.pos_embed.unsqueeze(0).to(x.dtype) + + x = ( + self.qkv(x) + .reshape(B, N + 1, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = x[0], x[1], x[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + x = self.proj(x) + return x[:, 0] + + +def pixel_freq_bands( + num_bands: int, + max_freq: float = 224.0, + linear_bands: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + if linear_bands: + bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device) + else: + bands = 2 ** torch.linspace( + 0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device + ) + return bands * torch.pi + + +def inv_freq_bands( + num_bands: int, + temperature: float = 100000.0, + step: int = 2, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + inv_freq = 1.0 / ( + temperature + ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands) + ) + return inv_freq + + +def build_sincos2d_pos_embed( + feat_shape: List[int], + dim: int = 64, + temperature: float = 10000.0, + reverse_coord: bool = False, + interleave_sin_cos: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """ + Args: + feat_shape: + dim: + temperature: + reverse_coord: stack grid order W, H instead of H, W + interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos + dtype: + device: + Returns: + """ + assert ( + dim % 4 == 0 + ), "Embed dimension must be divisible by 4 for sin-cos 2D position embedding" + pos_dim = dim // 4 + bands = inv_freq_bands( + pos_dim, temperature=temperature, step=1, dtype=dtype, device=device + ) + + if reverse_coord: + feat_shape = feat_shape[::-1] # stack W, H instead of H, W + grid = ( + torch.stack( + torch.meshgrid( + [torch.arange(s, device=device, dtype=dtype) for s in feat_shape] + ) + ) + .flatten(1) + .transpose(0, 1) + ) + pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0) + # FIXME add support for unflattened spatial dim? + + stack_dim = ( + 2 if interleave_sin_cos else 1 + ) # stack sin, cos, sin, cos instead of sin sin cos cos + pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1) + return pos_emb + + +def build_fourier_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + num_bands: int = 64, + max_res: int = 224, + linear_bands: bool = False, + include_grid: bool = False, + concat_out: bool = True, + in_pixels: bool = True, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + if bands is None: + if in_pixels: + bands = pixel_freq_bands( + num_bands, + float(max_res), + linear_bands=linear_bands, + dtype=dtype, + device=device, + ) + else: + bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device) + else: + if device is None: + device = bands.device + if dtype is None: + dtype = bands.dtype + + if in_pixels: + grid = torch.stack( + torch.meshgrid( + [ + torch.linspace(-1.0, 1.0, steps=s, device=device, dtype=dtype) + for s in feat_shape + ] + ), + dim=-1, + ) + else: + grid = torch.stack( + torch.meshgrid( + [torch.arange(s, device=device, dtype=dtype) for s in feat_shape] + ), + dim=-1, + ) + grid = grid.unsqueeze(-1) + pos = grid * bands + + pos_sin, pos_cos = pos.sin(), pos.cos() + out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos) + # FIXME torchscript doesn't like multiple return types, probably need to always cat? + if concat_out: + out = torch.cat(out, dim=-1) + return out + + +class FourierEmbed(nn.Module): + def __init__( + self, + max_res: int = 224, + num_bands: int = 64, + concat_grid=True, + keep_spatial=False, + ): + super().__init__() + self.max_res = max_res + self.num_bands = num_bands + self.concat_grid = concat_grid + self.keep_spatial = keep_spatial + self.register_buffer( + "bands", pixel_freq_bands(max_res, num_bands), persistent=False + ) + + def forward(self, x): + B, C = x.shape[:2] + feat_shape = x.shape[2:] + emb = build_fourier_pos_embed( + feat_shape, + self.bands, + include_grid=self.concat_grid, + dtype=x.dtype, + device=x.device, + ) + emb = emb.transpose(-1, -2).flatten(len(feat_shape)) + batch_expand = (B,) + (-1,) * (x.ndim - 1) + + # FIXME support nD + if self.keep_spatial: + x = torch.cat( + [x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1 + ) + else: + x = torch.cat( + [x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1 + ) + x = x.reshape(B, feat_shape.numel(), -1) + + return x + + +def rot(x): + return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) + + +def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): + return x * cos_emb + rot(x) * sin_emb + + +def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): + if isinstance(x, torch.Tensor): + x = [x] + return [t * cos_emb + rot(t) * sin_emb for t in x] + + +def apply_rot_embed_split(x: torch.Tensor, emb): + split = emb.shape[-1] // 2 + return x * emb[:, :split] + rot(x) * emb[:, split:] + + +def build_rotary_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + dim: int = 64, + max_freq: float = 224, + linear_bands: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +): + """ + NOTE: shape arg should include spatial dim only + """ + feat_shape = torch.Size(feat_shape) + + sin_emb, cos_emb = build_fourier_pos_embed( + feat_shape, + bands=bands, + num_bands=dim // 4, + max_res=max_freq, + linear_bands=linear_bands, + concat_out=False, + device=device, + dtype=dtype, + ) + N = feat_shape.numel() + sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1) + cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1) + return sin_emb, cos_emb + + +class RotaryEmbedding(nn.Module): + """Rotary position embedding + NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not + been well tested, and will likely change. It will be moved to its own file. + The following impl/resources were referenced for this impl: + * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py + * https://blog.eleuther.ai/rotary-embeddings/ + """ + + def __init__(self, dim, max_res=224, linear_bands: bool = False): + super().__init__() + self.dim = dim + self.register_buffer( + "bands", + pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), + persistent=False, + ) + + def get_embed(self, shape: List[int]): + return build_rotary_pos_embed(shape, self.bands) + + def forward(self, x): + # assuming channel-first tensor where spatial dim are >= 2 + sin_emb, cos_emb = self.get_embed(x.shape[2:]) + return apply_rot_embed(x, sin_emb, cos_emb) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/lavis/models/clip_models/tokenizer.py b/lavis/models/clip_models/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7e19124df29ace4b7e0599d1082e80d38aca0748 --- /dev/null +++ b/lavis/models/clip_models/tokenizer.py @@ -0,0 +1,203 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/mlfoundations/open_clip +""" + +""" CLIP tokenizer +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" + ) + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + if not special_tokens: + special_tokens = ["", ""] + else: + special_tokens = ["", ""] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize( + texts: Union[str, List[str]], context_length: int = 77 +) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + result[i, : len(tokens)] = torch.tensor(tokens) + + return result diff --git a/lavis/models/clip_models/transform.py b/lavis/models/clip_models/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..e1104418cf7fb3f9bf169d52a0f8a051b9200c42 --- /dev/null +++ b/lavis/models/clip_models/transform.py @@ -0,0 +1,111 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/mlfoundations/open_clip +""" + +from typing import Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torchvision.transforms.functional as F + + +from torchvision.transforms import ( + Normalize, + Compose, + RandomResizedCrop, + InterpolationMode, + ToTensor, + Resize, + CenterCrop, +) + + +class ResizeMaxSize(nn.Module): + def __init__( + self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0 + ): + super().__init__() + if not isinstance(max_size, int): + raise TypeError(f"Size should be int. Got {type(max_size)}") + self.max_size = max_size + self.interpolation = interpolation + self.fn = min if fn == "min" else min + self.fill = fill + + def forward(self, img): + if isinstance(img, torch.Tensor): + height, width = img.shape[:2] + else: + width, height = img.size + scale = self.max_size / float(max(height, width)) + if scale != 1.0: + new_size = tuple(round(dim * scale) for dim in (height, width)) + img = F.resize(img, new_size, self.interpolation) + pad_h = self.max_size - new_size[0] + pad_w = self.max_size - new_size[1] + img = F.pad( + img, + padding=[ + pad_w // 2, + pad_h // 2, + pad_w - pad_w // 2, + pad_h - pad_h // 2, + ], + fill=self.fill, + ) + return img + + +def _convert_to_rgb(image): + return image.convert("RGB") + + +def image_transform( + image_size: int, + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_longest_max: bool = False, + fill_color: int = 0, +): + mean = mean or (0.48145466, 0.4578275, 0.40821073) # OpenAI dataset mean + std = std or (0.26862954, 0.26130258, 0.27577711) # OpenAI dataset std + if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: + # for square size, pass size as int so that Resize() uses aspect preserving shortest edge + image_size = image_size[0] + + normalize = Normalize(mean=mean, std=std) + if is_train: + return Compose( + [ + RandomResizedCrop( + image_size, + scale=(0.9, 1.0), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) + else: + if resize_longest_max: + transforms = [ResizeMaxSize(image_size, fill=fill_color)] + else: + transforms = [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + ] + transforms.extend( + [ + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) + return Compose(transforms) diff --git a/lavis/models/clip_models/utils.py b/lavis/models/clip_models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba9191a8c8043ed07d96144b4c10fcffb08cc9c --- /dev/null +++ b/lavis/models/clip_models/utils.py @@ -0,0 +1,49 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/mlfoundations/open_clip +""" + +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=""): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + Returns: + torch.nn.Module: Resulting module + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance( + module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) + ): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = ".".join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res diff --git a/lavis/models/clip_vit.py b/lavis/models/clip_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..48a5b804c559bb70d8d34c64d8f2ee6bfa4bdb1f --- /dev/null +++ b/lavis/models/clip_vit.py @@ -0,0 +1,254 @@ +from collections import OrderedDict +from itertools import repeat +import collections.abc +import math + +import torch +import torch.nn.functional as F +from torch import nn + +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +from lavis.models.eva_vit import convert_weights_to_fp16 +from lavis.common.dist_utils import download_cached_file + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + if use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, use_grad_checkpointing=False): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, use_grad_checkpointing and i>12) for i in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, use_grad_checkpointing: bool): + super().__init__() + self.input_resolution = input_resolution + self.num_features = width + self.num_heads = heads + self.num_patches = (input_resolution // patch_size) ** 2 + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers-1, heads, use_grad_checkpointing=use_grad_checkpointing) + +# self.ln_final = LayerNorm(width) + + def forward(self, x: torch.Tensor): + + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + +# x = self.ln_final(x) + return x + + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse +to_2tuple = _ntuple(2) +def interpolate_pos_embed(model, state_dict, interpolation: str = 'bicubic', seq_dim=1): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('positional_embedding', None) + + grid_size = round((model.positional_embedding.shape[0] - 1) ** 0.5) + if old_pos_embed is None: + return + grid_size = to_2tuple(grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + + old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + print('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + align_corners=True, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['positional_embedding'] = new_pos_embed + + +def create_clip_vit_L(img_size=224,use_checkpoint=False,precision="fp16"): + model = VisionTransformer( + input_resolution=img_size, + patch_size=14, + width=1024, + layers=22, + heads=16, + use_grad_checkpointing=use_checkpoint, + ) + url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/clip_vit_L.pth" + cached_file = download_cached_file( + url, check_hash=False, progress=True + ) + state_dict = torch.load(cached_file, map_location="cpu") + interpolate_pos_embed(model,state_dict) + + incompatible_keys = model.load_state_dict(state_dict, strict=False) + # print(incompatible_keys) + + if precision == "fp16": + convert_weights_to_fp16(model) + return model \ No newline at end of file diff --git a/lavis/models/eva_vit.py b/lavis/models/eva_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..082892b26c64405335f682288a0b145548ed43e4 --- /dev/null +++ b/lavis/models/eva_vit.py @@ -0,0 +1,442 @@ +# Based on EVA, BEIT, timm and DeiT code bases +# https://github.com/baaivision/EVA +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/microsoft/unilm/tree/master/beit +# https://github.com/facebookresearch/deit/ +# https://github.com/facebookresearch/dino +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + +from lavis.common.dist_utils import download_cached_file + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + **kwargs + } + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the orignal BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., window_size=None, attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values is not None and init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, + use_mean_pooling=True, init_scale=0.001, use_checkpoint=False): + super().__init__() + self.image_size = img_size + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + self.use_checkpoint = use_checkpoint + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) + for i in range(depth)]) +# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) +# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None +# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) +# if isinstance(self.head, nn.Linear): +# trunc_normal_(self.head.weight, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() +# if isinstance(self.head, nn.Linear): +# self.head.weight.data.mul_(init_scale) +# self.head.bias.data.mul_(init_scale) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, rel_pos_bias) + else: + x = blk(x, rel_pos_bias) + return x +# x = self.norm(x) + +# if self.fc_norm is not None: +# t = x[:, 1:, :] +# return self.fc_norm(t.mean(1)) +# else: +# return x[:, 0] + + def forward(self, x): + x = self.forward_features(x) +# x = self.head(x) + return x + + def get_intermediate_layers(self, x): + x = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + features = [] + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + for blk in self.blocks: + x = blk(x, rel_pos_bias) + features.append(x) + + return features + + +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'].float() + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + +# if isinstance(l, (nn.MultiheadAttention, Attention)): +# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: +# tensor = getattr(l, attr) +# if tensor is not None: +# tensor.data = tensor.data.half() + + model.apply(_convert_weights_to_fp16) + + +def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"): + model = VisionTransformer( + img_size=img_size, + patch_size=14, + use_mean_pooling=False, + embed_dim=1408, + depth=39, + num_heads=1408//88, + mlp_ratio=4.3637, + qkv_bias=True, + drop_path_rate=drop_path_rate, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + use_checkpoint=use_checkpoint, + ) + url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" + cached_file = download_cached_file( + url, check_hash=False, progress=True + ) + state_dict = torch.load(cached_file, map_location="cpu") + interpolate_pos_embed(model,state_dict) + + incompatible_keys = model.load_state_dict(state_dict, strict=False) +# print(incompatible_keys) + + if precision == "fp16": +# model.to("cuda") + convert_weights_to_fp16(model) + return model \ No newline at end of file diff --git a/lavis/models/gpt_models/gpt_dialogue.py b/lavis/models/gpt_models/gpt_dialogue.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea769701c7119a3a11b43627519cf51b8f66adf --- /dev/null +++ b/lavis/models/gpt_models/gpt_dialogue.py @@ -0,0 +1,110 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch +import torch.nn as nn +from lavis.common.registry import registry +from lavis.models.base_model import BaseModel +from torch.nn import CrossEntropyLoss, MSELoss +from transformers import GPT2LMHeadModel +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + + +@registry.register_model("gpt_dialogue") +class GPTDialogue(BaseModel, GPT2LMHeadModel): + + PRETRAINED_MODEL_CONFIG_DICT = {"base": "configs/models/gpt_dialogue_base.yaml"} + + def __init__(self, config, len_video_ft=4224): + + super().__init__(config) + + self.video_ff = nn.Linear(len_video_ft, config.n_embd) + self.video_ff_out = nn.Linear(config.n_embd, len_video_ft) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + samples, + past_key_values=None, + position_ids=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + input_embs = self.transformer.wte(samples["input_ids"]) + video_embs = self.video_ff(samples["video_fts"]) + input_embs = torch.cat([video_embs, input_embs], dim=1) + + transformer_outputs = self.transformer( + attention_mask=samples["attn_mask"], + token_type_ids=samples["token_type_ids"], + inputs_embeds=input_embs, + position_ids=position_ids, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if samples["labels"] is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = samples["labels"][..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-1) + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + + if samples["video_fts"] is not None: + len_video_fts = samples["video_fts"].shape[1] + video_logits = self.video_ff_out(hidden_states[:, :len_video_fts, :]) + # Shift so that tokens < n predict n + shift_logits = video_logits[..., :-1, :].contiguous() + shift_labels = samples["video_fts"][..., 1:, :].contiguous() + # Flatten the tokens + loss_fct = MSELoss(reduction="mean") + video_loss = loss_fct(shift_logits, shift_labels) + + if loss is not None: + loss = loss + video_loss + else: + loss = video_loss + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @classmethod + def from_config(cls, cfg): + model = cls.__bases__[1].from_pretrained("gpt2") + model.resize_token_embeddings(cfg["len_tokenizer"]) + return model diff --git a/lavis/models/img2prompt_models/__init__.py b/lavis/models/img2prompt_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf9d1ec8674c89a26e8a1d374c6ea80a16bc6c5b --- /dev/null +++ b/lavis/models/img2prompt_models/__init__.py @@ -0,0 +1,11 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch + + + diff --git a/lavis/models/img2prompt_models/img2prompt_vqa.py b/lavis/models/img2prompt_models/img2prompt_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..00cda00a8f029841771ef041c5321e45441fdfbd --- /dev/null +++ b/lavis/models/img2prompt_models/img2prompt_vqa.py @@ -0,0 +1,582 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import random + +import spacy +import torch +import torch.nn.functional as F +from transformers import T5ForConditionalGeneration, T5Tokenizer + +from lavis.common.dist_utils import download_cached_file +from lavis.common.registry import registry +from lavis.models.base_model import BaseModel +from lavis.models.blip_models.blip_image_text_matching import compute_gradcam + +open_pos = ["NOUN", "VERB", "ADJ", "ADV", "NUM"] + + + +@registry.register_model("img2prompt_vqa") +class Img2PromptVQA(BaseModel): + """ + Img2Prompt_VQA model consists of three submodels for zero-shot VQA: + 1. Image-questioning matching model + 2. Image captioning model + 3. Large Language model + + Supported model types: + - base: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-base) + - large: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-large) + - 3b: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-3b) + + Usage: + >>> from lavis.models import load_model + >>> model = load_model("img2prompt_vqa", "base", is_eval=True) + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "base": "configs/models/img2prompt-vqa/img2prompt_vqa_base.yaml", + } + + def __init__( + self, + image_question_matching_model, + image_captioning_model, + question_generation_model, + question_generation_tokenizer, + offload_model=False, + ): + super().__init__() + + self.image_question_matching_model = image_question_matching_model + self.image_captioning_model = image_captioning_model + self.question_generation_model = question_generation_model + self.question_generation_tokenizer = question_generation_tokenizer + self.offload_model = offload_model + self.nlp = spacy.load("en_core_web_sm") + + def forward_itm(self, samples, block_num=7): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - text_input (list): A list of strings of length batch_size + block_num (int): The index of cross-attention block for gradcam computation. + + Returns: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - text_input (list): A list of strings of length batch_size + - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W) + """ + image = samples["image"] + question = [text.strip("?") for text in samples["text_input"]] + tokenized_text = self.image_question_matching_model.tokenizer( + question, padding="longest", truncation=True, return_tensors="pt" + ).to(self.image_question_matching_model.device) + with torch.set_grad_enabled(True): + gradcams, _ = compute_gradcam( + model=self.image_question_matching_model, + visual_input=image, + text_input=question, + tokenized_text=tokenized_text, + block_num=block_num, + ) + + gradcams = [gradcam_[1] for gradcam_ in gradcams] + samples["gradcams"] = torch.stack(gradcams).reshape( + samples["image"].size(0), -1 + ) + + return samples + + def itm_rank(self, image_embeds, image_atts, encoder_input_ids, match_head="itm"): + # breakpoint() + encoder_input_ids = encoder_input_ids.clone() + encoder_input_ids = encoder_input_ids[:, self.prompt_length - 1 :] + text_attention_mask = (encoder_input_ids != self.tokenizer.pad_token_id).long() + + if match_head == "itm": + # encoder_input_ids = encoder_input_ids.clone() + encoder_input_ids[:, 0] = self.tokenizer.enc_token_id + output = self.text_encoder( + encoder_input_ids, + attention_mask=text_attention_mask, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + itm_output = self.itm_head(output.last_hidden_state[:, 0, :]) + return itm_output # , mask, token_length + + elif match_head == "itc": + encoder_input_ids[:, 0] = self.tokenizer.cls_token_id + text_output = self.text_encoder( + encoder_input_ids, + attention_mask=text_attention_mask, + return_dict=True, + mode="text", + ) + image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1) + text_feat = F.normalize( + self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 + ) + + sim = image_feat @ text_feat.t() + return sim + + def forward_cap( + self, + samples, + cap_max_length=20, + cap_min_length=0, + top_p=1, + top_k=50, + repetition_penalty=1.0, + num_captions=100, + num_patches=20, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - text_input (list): A list of strings of length batch_size + - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W) + cap_max_length (int): The maximum length of the caption to be generated. + cap_min_length (int): The minimum length of the caption to be generated. + top_p (float): The cumulative probability for nucleus sampling. + top_k (float): The number of the highest probability tokens for top-k sampling. + repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. + num_captions (int): Number of captions generated for each image. + num_patches (int): Number of patches sampled for each image. + + Returns: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - text_input (list): A list of strings of length batch_size + - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W) + - captions (nested list): A nested list of strings of total length batch_size * num_captions + """ + encoder_out = self.image_captioning_model.forward_encoder(samples) + captions = [[] for _ in range(encoder_out.size(0))] + + min_num_captions = 0 + + while min_num_captions < num_captions: + encoder_out_samples = [] + for i in range(num_captions): + patch_id = ( + torch.multinomial( + samples["gradcams"].to(self.image_captioning_model.device), + num_patches, + ).reshape(encoder_out.size(0), -1) + + 1 + ) + patch_id = ( + patch_id.sort(dim=1) + .values.unsqueeze(-1) + .expand(-1, -1, encoder_out.size(2)) + ) + encoder_out_sample = torch.gather(encoder_out, 1, patch_id) + encoder_out_samples.append(encoder_out_sample) + + stacked = torch.stack(encoder_out_samples, dim=1) + image_embeds = torch.flatten( + stacked, start_dim=0, end_dim=1 + ) # (bsz*num_seq, num_patch, dim) + + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + self.image_captioning_model.device + ) + model_kwargs = { + "encoder_hidden_states": image_embeds, + "encoder_attention_mask": image_atts, + } + + prompt = [self.image_captioning_model.prompt] * image_embeds.size(0) + prompt = self.image_captioning_model.tokenizer( + prompt, return_tensors="pt" + ).to(self.image_captioning_model.device) + prompt.input_ids[:, 0] = self.image_captioning_model.tokenizer.bos_token_id + prompt.input_ids = prompt.input_ids[:, :-1] + + decoder_out = self.image_captioning_model.text_decoder.generate( + input_ids=prompt.input_ids, + max_length=cap_max_length, + min_length=cap_min_length, + do_sample=True, + top_p=top_p, + top_k=top_k, + num_return_sequences=1, + eos_token_id=self.image_captioning_model.tokenizer.sep_token_id, + pad_token_id=self.image_captioning_model.tokenizer.pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs + ) + + itm_outputs = self.image_question_matching_model.itm_rank( + image_embeds, image_atts, encoder_input_ids=decoder_out + ) # caption filter + + outputs = self.image_captioning_model.tokenizer.batch_decode( + decoder_out, skip_special_tokens=True + ) + + for counter, output in enumerate(outputs): + ind = counter // num_captions + if len(captions[ind]) < num_captions: + caption = output[len(self.image_captioning_model.prompt) :] + overlap_caption = [1 for caps in captions[ind] if caption in caps] + # print(itm_outputs) + if ( + len(overlap_caption) == 0 and itm_outputs[counter] >= 0.5 + ): # image filter + captions[ind].append(caption) + + min_num_captions = min([len(i) for i in captions]) + + samples["captions"] = captions + + return samples + + def answer_extraction(self, caption, num_question_generation=30): + cap_use = "" + # print(caption) + caption = caption + ans_to_cap_dict = {} + answers = [] + for cap_idx, cap in enumerate(caption): + # print(cap) + cap_use += cap + cap = cap.strip().strip(".") + # print(cap) + cap = self.nlp(cap) + for token in cap: # Noun /Verb/Adj//NUM + if token.pos_ in open_pos: + if token.text.lower() not in ans_to_cap_dict: + ans_to_cap_dict[token.text.lower()] = [cap_idx] + else: + if cap_idx not in ans_to_cap_dict[token.text.lower()]: + ans_to_cap_dict[token.text.lower()].append(cap_idx) + answers.append(token.text) + for ent in cap.ents: + + if ent.text not in answers: + if ent.text.lower() not in ans_to_cap_dict: + ans_to_cap_dict[ent.text.lower()] = [cap_idx] + else: + if cap_idx not in ans_to_cap_dict[ent.text.lower()]: + ans_to_cap_dict[ent.text.lower()].append(cap_idx) + answers.append(ent.text) + for chunk in cap.noun_chunks: + if len(chunk.text.split()) < 4: + if chunk.text.lower() not in ans_to_cap_dict: + ans_to_cap_dict[chunk.text.lower()] = [cap_idx] + else: + if cap_idx not in ans_to_cap_dict[chunk.text.lower()]: + ans_to_cap_dict[chunk.text.lower()].append(cap_idx) + # print(chunk.text) + answers.append(chunk.text) + answers = sorted(answers, key=answers.count, reverse=True) + real_answers = [] + for i in answers: + i = i + "." + if i not in real_answers: + real_answers.append(i) + + contexts_for_question_generation = [] + answers = [] + for ans in real_answers[ + :num_question_generation + ]: # Generate questions for 30 answers with max frequencies. + contexts_for_question_generation.append( + "answer: %s context: %s." % (ans, cap_use) + ) + answers.append(ans) + contexts_for_question_generation.append( + "answer: %s context: %s." % ("yes.", cap_use) + ) + answers.append("yes.") + return contexts_for_question_generation, answers, ans_to_cap_dict + + def forward_qa_generation(self, samples): + caption = samples["captions"][0] + ( + contexts_for_question_generation, + answers, + ans_to_cap_dict, + ) = self.answer_extraction(caption) + inputs = self.question_generation_tokenizer( + contexts_for_question_generation, + padding="longest", + truncation=True, + max_length=2048, + return_tensors="pt", + ).to(self.device) + question_size = inputs.input_ids.shape[0] + cur_b = 0 + true_input_size = 10 + outputs_list = [] + while cur_b < question_size: + outputs = self.question_generation_model.generate( + input_ids=inputs.input_ids[cur_b : cur_b + true_input_size], + attention_mask=inputs.attention_mask[cur_b : cur_b + true_input_size], + num_beams=3, + max_length=30, + ) + questions = self.question_generation_tokenizer.batch_decode( + outputs, skip_special_tokens=True + ) + outputs_list += questions + cur_b += true_input_size + questions = outputs_list + samples["questions"] = questions + samples["answers"] = answers + samples["ans_to_cap_dict"] = ans_to_cap_dict + # results.append({"question_id": ques_id, "question":questions,"answer":answers}) + return samples + + def create_context_prompt(self, samples, num_caps_per_img=30): + ans_dict_queid = samples["ans_to_cap_dict"] + # print(ans_dict_queid) + caption = samples["captions"][0] + answers = samples["answers"] + Context_Prompt = "" + mycontexts_id = [] + for idx in range(num_caps_per_img): + cap_id_list = ans_dict_queid.get( + answers[(len(answers) - 1 - idx) % len(answers)][:-1].lower(), [0] + ) + for cap_id in cap_id_list: + if cap_id not in mycontexts_id: + Context_Prompt += caption[cap_id] + mycontexts_id.append(cap_id) + break # We just take one cap for each answer + samples["Context_Prompt"] = Context_Prompt + return Context_Prompt + + def create_task_prompt( + self, samples, question_type="neural", num_question_per_img=30 + ): + syn_question_queid = samples["questions"] + syn_ans_queid = samples["answers"] + Task_Prompt = "" + for idx in range(num_question_per_img): + # if config['random_question']: + # qa_idx = random.randint(0, len(syn_question_queid) - 1) + # else: + qa_idx = idx + if ( + question_type != "rule" and num_question_per_img > 0 and idx < 1 + ): ## yes and no questions for vqav2 + # Task_Prompt += "Question:" + # Task_Prompt += syn_question_queid_next[-1] + # Task_Prompt += '\n' + # Task_Prompt += "Answer:no\n" + Task_Prompt += "Question:" + Task_Prompt += syn_question_queid[-1] + Task_Prompt += "\n" + Task_Prompt += "Answer:" + Task_Prompt += "yes\n" + Task_Prompt += "Question:Is this a toilet?\n" + Task_Prompt += "Answer:no\n" + if "question_type" == "rule": # Rule-Based Question Generation + Noun_Questions = [ + "What item is this in this picture?", + "What item is that in this picture?", + ] + + Verb_Questions = [ + "What action is being done in this picture?", + "Why is this item doing in this picture?", + "Which action is being taken in this picture?", + "What action is item doing in this picture?", + "What action is item performing in this picture?", + ] + + Adj_Questions = [ + "How to describe one item in this picture?", + "What is item's ADJ TYPE in this picture?", + "What is the ADJ TYPE in this picture?", + ] + + Task_Prompt += "Question:" + doc = self.nlp(syn_ans_queid[(qa_idx) % len(syn_ans_queid)][:-1].lower()) + if doc[-1].pos_ == "NOUN": + Task_Prompt += Noun_Questions[ + random.randint(0, len(Noun_Questions) - 1) + ] + elif doc[-1].pos_ == "VERB": + Task_Prompt += Verb_Questions[ + random.randint(0, len(Verb_Questions) - 1) + ] + elif doc[-1].pos_ == "ADJ": + Task_Prompt += Adj_Questions[ + random.randint(0, len(Adj_Questions) - 1) + ] + + Task_Prompt += "\n" + + Task_Prompt += "Answer:" + Task_Prompt += syn_ans_queid[(qa_idx) % len(syn_ans_queid)][:-1].lower() + Task_Prompt += "\n" + samples["Task_Prompt"] = Task_Prompt + # print(Task_Prompt) + return Task_Prompt + + def prompts_construction( + self, + samples, + question_type="neural", + num_caps_per_img=30, + num_question_per_img=30, + ): + Prompt = "Please reason the answer of the questions according to the given contexts.\n" + + Context_Prompt = self.create_context_prompt(samples, num_caps_per_img) + + Task_Prompt = self.create_task_prompt( + samples, question_type, num_question_per_img + ) + + Img2Prompt = ( + Prompt + + "Contexts:" + + Context_Prompt + + "\n" + + Task_Prompt + + "Question:" + + samples["text_input"][0] + + "\nAnswer:" + ) + return Img2Prompt + + def prepare_LLM_input( + self, + samples, + num_beams=1, + inference_method="generate", + max_len=20, + min_len=0, + internal_bsz_fid=1, + num_captions=50, + num_captions_fid=1, + cap_max_length=20, + cap_min_length=10, + top_k=50, + top_p=1, + repetition_penalty=1, + num_patches=20, + block_num=7, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480. + - text_input (str or [str]): String or a list of strings, each string is a question. + The number of questions must be equal to the batch size. If a single string, will be converted to a list of string, with length 1 first. + num_beams (int): Number of beams for beam search. 1 means no beam search. + inference_method (str): Inference method. Must be "generate". The model will generate answers. + max_len (int): Maximum length of generated answers. + min_len (int): Minimum length of generated answers. + internal_bsz_fid (int): Internal batch size when using FiD decoding. + num_captions (int): Number of captions generated for each image. + num_captions_fid (int): Number of captions concatenated with a question during FiD decoding. + cap_max_length (int): The maximum length of the caption to be generated. + cap_min_length (int): The minimum length of the caption to be generated. + top_k (float): The number of the highest probability tokens for top-k sampling. + top_p (float): The cumulative probability for nucleus sampling. + repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. + num_patches (int): Number of patches sampled for each image. + block_num (int): The index of cross-attention block for gradcam computation. + + Returns: + List: A list of strings, each string is an answer. + gradcams (torch.Tensor): A tensor of shape (batch_size, H*W) + captions (nested list): A nested list of strings of total length batch_size * num_captions + """ + assert inference_method in [ + "generate", + ], "Inference method must be 'generate', got {}.".format(inference_method) + + if isinstance(samples["text_input"], str): + samples["text_input"] = [samples["text_input"]] + + assert len(samples["text_input"]) == samples["image"].size( + 0 + ), "The number of questions must be equal to the batch size." + + samples = self.forward_itm(samples, block_num=block_num) + + samples = self.forward_cap( + samples, + cap_max_length=cap_max_length, + cap_min_length=cap_min_length, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + num_captions=num_captions, + num_patches=num_patches, + ) + + if self.offload_model: + samples["image"] = samples["image"].to("cpu") + self.image_question_matching_model.to("cpu") + self.image_captioning_model.to("cpu") + torch.cuda.empty_cache() + + pred_answers = self.forward_qa( + samples, + num_beams=num_beams, + max_len=max_len, + min_len=min_len, + internal_bsz_fid=internal_bsz_fid, + num_captions=num_captions, + num_captions_fid=num_captions_fid, + ) + + if self.offload_model: + self.image_question_matching_model.to(self.question_answering_model.device) + self.image_captioning_model.to(self.question_answering_model.device) + + return pred_answers, samples["captions"], samples["gradcams"] + + @classmethod + def from_config(cls, model_config): + itm_config = model_config.image_question_matching_model + cap_config = model_config.image_captioning_model + + itm_cls = registry.get_model_class(itm_config.arch) + cap_cls = registry.get_model_class(cap_config.arch) + + image_question_matching_model = itm_cls.from_config(itm_config) + image_captioning_model = cap_cls.from_config(cap_config) + + question_generation_tokenizer = T5Tokenizer.from_pretrained( + "google/t5-large-lm-adapt" + ) + question_generation_model = T5ForConditionalGeneration.from_pretrained( + "google/t5-large-lm-adapt" + ) + cached_file = download_cached_file( + "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/projects/img2prompt/T5_large_QG.pth", + check_hash=False, + progress=True, + ) + checkpoint = torch.load(cached_file, map_location="cpu") + state_dict = checkpoint["model"] + question_generation_model.load_state_dict(state_dict) + model = cls( + image_question_matching_model=image_question_matching_model, + image_captioning_model=image_captioning_model, + question_generation_model=question_generation_model, + question_generation_tokenizer=question_generation_tokenizer, + offload_model=False, + ) + + return model diff --git a/lavis/models/med.py b/lavis/models/med.py new file mode 100644 index 0000000000000000000000000000000000000000..fe3b326fa85e07175462e49c9e8e9da8423e4fed --- /dev/null +++ b/lavis/models/med.py @@ -0,0 +1,1416 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on huggingface code base + https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import Tensor, device +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F +from transformers import BatchEncoding, PreTrainedTokenizer + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig +from lavis.common.utils import get_abs_path + +from lavis.models.base_model import BaseEncoder + +logging.set_verbosity_error() +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + if config.add_type_embeddings: + self.token_type_embeddings = nn.Embedding( + config.type_vocab_size, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if token_type_ids is not None: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + else: + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + + # compatibility for ALBEF and BLIP + try: + # ALBEF & ALPRO + fusion_layer = self.config.fusion_layer + add_cross_attention = ( + fusion_layer <= layer_num and self.config.add_cross_attention + ) + + self.fusion_layer = fusion_layer + except AttributeError: + # BLIP + self.fusion_layer = self.config.num_hidden_layers + add_cross_attention = self.config.add_cross_attention + + # if self.config.add_cross_attention: + if add_cross_attention: + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + # TODO line 482 in albef/models/xbert.py + # compatibility for ALBEF and BLIP + if mode in ["multimodal", "fusion"] and hasattr(self, "crossattention"): + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + + if isinstance(encoder_hidden_states, list): + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states[ + (self.layer_num - self.fusion_layer) + % len(encoder_hidden_states) + ], + encoder_attention_mask[ + (self.layer_num - self.fusion_layer) + % len(encoder_hidden_states) + ], + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] + + else: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode="multimodal", + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + try: + # ALBEF + fusion_layer = self.config.fusion_layer + except AttributeError: + # BLIP + fusion_layer = self.config.num_hidden_layers + + if mode == "text": + start_layer = 0 + # output_layer = self.config.fusion_layer + output_layer = fusion_layer + + elif mode == "fusion": + # start_layer = self.config.fusion_layer + start_layer = fusion_layer + output_layer = self.config.num_hidden_layers + + elif mode == "multimodal": + start_layer = 0 + output_layer = self.config.num_hidden_layers + + # compatibility for ALBEF and BLIP + # for i in range(self.config.num_hidden_layers): + for i in range(start_layer, output_layer): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + # TODO pay attention to this. + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, seq_length, prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode="multimodal", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError( + "You have to specify either input_ids or inputs_embeds or encoder_embeds" + ) + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + # token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode="multimodal", + soft_labels=None, + alpha=0, + return_logits=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + # token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_embeds=encoder_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if soft_labels is not None: + loss_distill = -torch.sum( + F.log_softmax(prediction_scores, dim=-1) * soft_labels, dim=-1 + ) + loss_distill = loss_distill[labels != -100].mean() + masked_lm_loss = (1 - alpha) * masked_lm_loss + alpha * loss_distill + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, attention_mask=None, **model_kwargs + ): + input_shape = input_ids.shape + effective_batch_size = input_shape[0] + + # add a dummy token + assert ( + self.config.pad_token_id is not None + ), "The PAD token should be defined for generation" + attention_mask = torch.cat( + [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], + dim=-1, + ) + dummy_token = torch.full( + (effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device, + ) + input_ids = torch.cat([input_ids, dummy_token], dim=1) + + return {"input_ids": input_ids, "attention_mask": attention_mask} + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + mode="multimodal", + soft_labels=None, + alpha=0, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if soft_labels is not None: + loss_distill = -torch.sum( + F.log_softmax(shifted_prediction_scores, dim=-1) * soft_labels, dim=-1 + ) + loss_distill = (loss_distill * (labels != -100)).sum(1) + lm_loss = (1 - alpha) * lm_loss + alpha * loss_distill + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past=None, attention_mask=None, **model_kwargs + ): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class XBertLMHeadDecoder(BertLMHeadModel): + """ + This class decouples the decoder forward logic from the VL model. + In this way, different VL models can share this decoder as long as + they feed encoder_embeds as required. + """ + + @classmethod + def from_config(cls, cfg, from_pretrained=False): + + med_config_path = get_abs_path(cfg.get("med_config_path")) + med_config = BertConfig.from_json_file(med_config_path) + + if from_pretrained: + return cls.from_pretrained("bert-base-uncased", config=med_config) + else: + return cls(config=med_config) + + def generate_from_encoder( + self, + tokenized_prompt, + visual_embeds, + sep_token_id, + pad_token_id, + use_nucleus_sampling=False, + num_beams=3, + max_length=30, + min_length=10, + top_p=0.9, + repetition_penalty=1.0, + **kwargs + ): + + #if not use_nucleus_sampling: + # num_beams = num_beams + # visual_embeds = visual_embeds.repeat_interleave(num_beams, dim=0) + + image_atts = torch.ones(visual_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + + model_kwargs = { + "encoder_hidden_states": visual_embeds, + "encoder_attention_mask": image_atts, + } + + if use_nucleus_sampling: + # nucleus sampling + outputs = self.generate( + input_ids=tokenized_prompt.input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=1.1, + **model_kwargs + ) + else: + # beam search + outputs = self.generate( + input_ids=tokenized_prompt.input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=sep_token_id, + pad_token_id=pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs + ) + + return outputs + + +class XBertEncoder(BertModel, BaseEncoder): + @classmethod + def from_config(cls, cfg, from_pretrained=False): + + med_config_path = get_abs_path(cfg.get("med_config_path")) + med_config = BertConfig.from_json_file(med_config_path) + + if from_pretrained: + return cls.from_pretrained( + "bert-base-uncased", config=med_config, add_pooling_layer=False + ) + else: + return cls(config=med_config, add_pooling_layer=False) + + def forward_automask(self, tokenized_text, visual_embeds, **kwargs): + image_atts = torch.ones(visual_embeds.size()[:-1], dtype=torch.long).to( + self.device + ) + + text = tokenized_text + text_output = super().forward( + text.input_ids, + attention_mask=text.attention_mask, + encoder_hidden_states=visual_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + return text_output + + def forward_text(self, tokenized_text, **kwargs): + text = tokenized_text + token_type_ids = kwargs.get("token_type_ids", None) + + text_output = super().forward( + text.input_ids, + attention_mask=text.attention_mask, + token_type_ids=token_type_ids, + return_dict=True, + mode="text", + ) + + return text_output diff --git a/lavis/models/pnp_vqa_models/__init__.py b/lavis/models/pnp_vqa_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44178e5503d448c954785201b5261eaa0df71ec5 --- /dev/null +++ b/lavis/models/pnp_vqa_models/__init__.py @@ -0,0 +1,29 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch + + +def prepare_qa_input(sample, num_captions, num_captions_fid): + sample_question_captions = [] + + for question, captions in zip(sample['text_input'], sample['captions']): + assert isinstance(captions, list) + question_captions = [] + question_caption = '' + for cap_id, cap_ in enumerate(captions[0:num_captions]): + question_caption += (cap_.strip() + '. ') + if (cap_id + 1) != num_captions and ((cap_id + 1) % num_captions_fid == 0): + question_caption = question.lower().strip() + " \\n " + question_caption.lower().strip() + question_captions.append(question_caption) + question_caption = '' + if (cap_id + 1) == num_captions: + question_caption = question.lower().strip() + " \\n " + question_caption.lower().strip() + question_captions.append(question_caption) + sample_question_captions.append(question_captions) + + sample['question_captions'] = sample_question_captions diff --git a/lavis/models/pnp_vqa_models/pnp_unifiedqav2_fid.py b/lavis/models/pnp_vqa_models/pnp_unifiedqav2_fid.py new file mode 100644 index 0000000000000000000000000000000000000000..43da9ac1452aa2aa4d5de48409ced8628b34b093 --- /dev/null +++ b/lavis/models/pnp_vqa_models/pnp_unifiedqav2_fid.py @@ -0,0 +1,87 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on facebookresearch code base + https://github.com/facebookresearch/FiD +""" + +import torch +import torch.nn as nn +from lavis.common.registry import registry +from lavis.models.base_model import BaseModel +from lavis.common.utils import get_abs_path +from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration + + +@registry.register_model("pnp_unifiedqav2_fid") +class PNPUnifiedQAv2FiD(T5ForConditionalGeneration, BaseModel): + + PRETRAINED_MODEL_CONFIG_DICT = {} + + def __init__(self, config, model_path): + super().__init__(config) + + self.tokenizer = T5Tokenizer.from_pretrained(model_path) + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + if input_ids != None: + if input_ids.dim() == 3: + self.encoder.num_contexts = input_ids.size(1) + input_ids = input_ids.view(input_ids.size(0), -1) + if attention_mask != None: + attention_mask = attention_mask.view(attention_mask.size(0), -1) + + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs + ) + + def generate(self, input_ids, attention_mask, num_beams=1, min_length=0, max_length=20): + self.encoder.num_contexts = input_ids.size(1) + + return super().generate( + input_ids=input_ids.view(input_ids.size(0), -1), + attention_mask=attention_mask.view(attention_mask.size(0), -1), + num_beams=num_beams, + min_length=min_length, + max_length=max_length + ) + + def load_unifiedqa(self, state_dict): + self.load_state_dict(state_dict) + self.encoder = T5EncoderWrapper(self.encoder) + + @classmethod + def from_config(cls, cfg): + model_path = cfg.get('pretrained') + t5_config_path = get_abs_path(cfg.get("t5_config_path")) + t5_config = T5Config.from_json_file(t5_config_path) + model = cls(t5_config, model_path) + model.load_unifiedqa(T5ForConditionalGeneration.from_pretrained(model_path).state_dict()) + + return model + + +class T5EncoderWrapper(torch.nn.Module): + + def __init__(self, encoder): + super().__init__() + + self.encoder = encoder + self.block = self.encoder.block + self.parallelize = self.encoder.parallelize + self.main_input_name = encoder.main_input_name + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + bsz, total_length = input_ids.shape + context_length = total_length // self.num_contexts + input_ids = input_ids.view(bsz*self.num_contexts, context_length) + attention_mask = attention_mask.view(bsz*self.num_contexts, context_length) + outputs = self.encoder(input_ids, attention_mask, **kwargs) + outputs = (outputs[0].view(bsz, self.num_contexts*context_length, -1), ) + outputs[1:] + + return outputs \ No newline at end of file diff --git a/lavis/models/pnp_vqa_models/pnp_vqa.py b/lavis/models/pnp_vqa_models/pnp_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..59b9d888bdcb999ca65eabfda7c457b7041524c4 --- /dev/null +++ b/lavis/models/pnp_vqa_models/pnp_vqa.py @@ -0,0 +1,340 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch +import torch.nn as nn +from itertools import chain +from lavis.common.registry import registry +from lavis.models.base_model import BaseModel +from torch.nn import CrossEntropyLoss, MSELoss +from transformers import T5ForConditionalGeneration +from lavis.models.pnp_vqa_models import prepare_qa_input +from lavis.models.blip_models.blip_image_text_matching import compute_gradcam +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions + + +@registry.register_model("pnp_vqa") +class PNPVQA(BaseModel): + """ + PNPVQA model consists of three submodels for zero-shot VQA: + 1. Image-questioning matching model + 2. Image captioning model + 3. Question answering model + + Supported model types: + - base: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-base) + - large: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-large) + - 3b: BLIPITM, BLIPCaption, PNPUnifiedQAv2FiD (t5-3b) + + Usage: + >>> from lavis.models import load_model + >>> model = load_model("pnp_vqa", "base", is_eval=True) + >>> model = load_model("pnp_vqa", "large", is_eval=True) + >>> model = load_model("pnp_vqa", "3b", is_eval=True) + """ + + PRETRAINED_MODEL_CONFIG_DICT = {"base": "configs/models/pnp-vqa/pnp_vqa_base.yaml", + "large": "configs/models/pnp-vqa/pnp_vqa_large.yaml", + "3b": "configs/models/pnp-vqa/pnp_vqa_3b.yaml", + } + + def __init__(self, image_question_matching_model, image_captioning_model, + question_answering_model, offload_model=False): + super().__init__() + + self.image_question_matching_model = image_question_matching_model + self.image_captioning_model = image_captioning_model + self.question_answering_model = question_answering_model + self.offload_model = offload_model + + def forward_itm(self, samples, block_num=7): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - text_input (list): A list of strings of length batch_size + block_num (int): The index of cross-attention block for gradcam computation. + + Returns: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - text_input (list): A list of strings of length batch_size + - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W) + """ + image = samples['image'] + question = [text.strip('?') for text in samples['text_input']] + tokenized_text = self.image_question_matching_model.tokenizer(question, padding='longest', truncation=True, + return_tensors="pt").to(self.image_question_matching_model.device) + with torch.set_grad_enabled(True): + gradcams, _ = compute_gradcam(model=self.image_question_matching_model, + visual_input=image, + text_input=question, + tokenized_text=tokenized_text, + block_num=block_num) + + gradcams = [gradcam_[1] for gradcam_ in gradcams] + samples['gradcams'] = torch.stack(gradcams).reshape(samples['image'].size(0), -1) + + return samples + + def forward_cap( + self, + samples, + cap_max_length=20, + cap_min_length=0, + top_p=1, + top_k=50, + repetition_penalty=1.0, + num_captions=100, + num_patches=20, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - text_input (list): A list of strings of length batch_size + - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W) + cap_max_length (int): The maximum length of the caption to be generated. + cap_min_length (int): The minimum length of the caption to be generated. + top_p (float): The cumulative probability for nucleus sampling. + top_k (float): The number of the highest probability tokens for top-k sampling. + repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. + num_captions (int): Number of captions generated for each image. + num_patches (int): Number of patches sampled for each image. + + Returns: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - text_input (list): A list of strings of length batch_size + - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W) + - captions (nested list): A nested list of strings of total length batch_size * num_captions + """ + encoder_out = self.image_captioning_model.forward_encoder(samples) + captions = [[] for _ in range(encoder_out.size(0))] + + min_num_captions = 0 + + while min_num_captions < num_captions: + encoder_out_samples = [] + for i in range(num_captions): + patch_id = torch.multinomial(samples['gradcams'].to(self.image_captioning_model.device), + num_patches).reshape(encoder_out.size(0), -1) + 1 + patch_id = patch_id.sort(dim=1).values.unsqueeze(-1).expand(-1, -1, encoder_out.size(2)) + encoder_out_sample = torch.gather(encoder_out, 1, patch_id) + encoder_out_samples.append(encoder_out_sample) + + stacked = torch.stack(encoder_out_samples, dim=1) + image_embeds = torch.flatten(stacked, start_dim=0, end_dim=1) #(bsz*num_seq, num_patch, dim) + + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(self.image_captioning_model.device) + model_kwargs = { + "encoder_hidden_states": image_embeds, + "encoder_attention_mask": image_atts, + } + + prompt = [self.image_captioning_model.prompt] * image_embeds.size(0) + prompt = self.image_captioning_model.tokenizer(prompt, + return_tensors="pt").to(self.image_captioning_model.device) + prompt.input_ids[:, 0] = self.image_captioning_model.tokenizer.bos_token_id + prompt.input_ids = prompt.input_ids[:, :-1] + + decoder_out = self.image_captioning_model.text_decoder.generate( + input_ids=prompt.input_ids, + max_length=cap_max_length, + min_length=cap_min_length, + do_sample=True, + top_p=top_p, + top_k=top_k, + num_return_sequences=1, + eos_token_id=self.image_captioning_model.tokenizer.sep_token_id, + pad_token_id=self.image_captioning_model.tokenizer.pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + outputs = self.image_captioning_model.tokenizer.batch_decode(decoder_out, skip_special_tokens=True) + + for counter, output in enumerate(outputs): + ind = counter//num_captions + if len(captions[ind]) < num_captions: + caption = output[len(self.image_captioning_model.prompt):] + overlap_caption = [1 for caps in captions[ind] if caption in caps] + if len(overlap_caption) == 0: + captions[ind].append(caption) + + min_num_captions = min([len(i) for i in captions]) + + samples['captions'] = captions + + return samples + + def forward_qa( + self, + samples, + num_beams=1, + max_len=20, + min_len=0, + internal_bsz_fid=1, + num_captions=100, + num_captions_fid=1, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - text_input (list): A list of strings of length batch_size + - gradcams (torch.Tensor): A tensor of shape (batch_size, H*W) + - captions (nested list): A nested list of strings of total length batch_size * num_captions + - question_captions (nested list): A nested list of concatenated strings of questions and captions + num_beams (int): Number of beams for beam search. 1 means no beam search. + max_len (int): Maximum length of generated answers. + min_len (int): Minimum length of generated answers. + internal_bsz_fid (int): Internal batch size when using FiD decoding. + num_captions (int): Number of captions generated for each image. + num_captions_fid (int): Number of captions concatenated with a question during FiD decoding. + + Returns: + List: A list of strings, each string is an answer. + """ + prepare_qa_input(samples, num_captions=num_captions, num_captions_fid=num_captions_fid) + + pred_answers = [] + question_captions = samples['question_captions'] + question_captions_chunk = [question_captions[i:i + internal_bsz_fid] + for i in range(0, len(question_captions), internal_bsz_fid)] + question_captions_chunk = list(chain(*question_captions_chunk)) + + for question_caption in question_captions_chunk: + question_caption_input = self.question_answering_model.tokenizer(question_caption, padding='longest', + truncation=True, return_tensors="pt").to(self.question_answering_model.device) + + question_caption_input.input_ids = question_caption_input.input_ids.reshape( + internal_bsz_fid, -1, question_caption_input.input_ids.size(1)) + question_caption_input.attention_mask = question_caption_input.attention_mask.reshape( + internal_bsz_fid, -1, question_caption_input.attention_mask.size(1)) + + outputs = self.question_answering_model.generate(input_ids=question_caption_input.input_ids, + attention_mask=question_caption_input.attention_mask, + num_beams=num_beams, + min_length=min_len, + max_length=max_len, + ) + + for output in outputs: + pred_answer = self.question_answering_model.tokenizer.decode(output, skip_special_tokens=True) + pred_answers.append(pred_answer) + + return pred_answers + + def predict_answers( + self, + samples, + num_beams=1, + inference_method="generate", + max_len=20, + min_len=0, + internal_bsz_fid=1, + num_captions=50, + num_captions_fid=1, + cap_max_length=20, + cap_min_length=10, + top_k=50, + top_p=1, + repetition_penalty=1, + num_patches=50, + block_num=7, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W). Default H=480, W=480. + - text_input (str or [str]): String or a list of strings, each string is a question. + The number of questions must be equal to the batch size. If a single string, will be converted to a list of string, with length 1 first. + num_beams (int): Number of beams for beam search. 1 means no beam search. + inference_method (str): Inference method. Must be "generate". The model will generate answers. + max_len (int): Maximum length of generated answers. + min_len (int): Minimum length of generated answers. + internal_bsz_fid (int): Internal batch size when using FiD decoding. + num_captions (int): Number of captions generated for each image. + num_captions_fid (int): Number of captions concatenated with a question during FiD decoding. + cap_max_length (int): The maximum length of the caption to be generated. + cap_min_length (int): The minimum length of the caption to be generated. + top_k (float): The number of the highest probability tokens for top-k sampling. + top_p (float): The cumulative probability for nucleus sampling. + repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. + num_patches (int): Number of patches sampled for each image. + block_num (int): The index of cross-attention block for gradcam computation. + + Returns: + List: A list of strings, each string is an answer. + gradcams (torch.Tensor): A tensor of shape (batch_size, H*W) + captions (nested list): A nested list of strings of total length batch_size * num_captions + """ + assert inference_method in [ + "generate", + ], "Inference method must be 'generate', got {}.".format( + inference_method + ) + + if isinstance(samples["text_input"], str): + samples["text_input"] = [samples["text_input"]] + + assert len(samples["text_input"]) == samples["image"].size( + 0 + ), "The number of questions must be equal to the batch size." + + samples = self.forward_itm(samples, block_num=block_num) + + samples = self.forward_cap(samples, + cap_max_length=cap_max_length, + cap_min_length=cap_min_length, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + num_captions=num_captions, + num_patches=num_patches) + + if self.offload_model: + samples['image'] = samples['image'].to('cpu') + self.image_question_matching_model.to('cpu') + self.image_captioning_model.to('cpu') + torch.cuda.empty_cache() + + pred_answers = self.forward_qa(samples, + num_beams=num_beams, + max_len=max_len, + min_len=min_len, + internal_bsz_fid=internal_bsz_fid, + num_captions=num_captions, + num_captions_fid=num_captions_fid) + + if self.offload_model: + self.image_question_matching_model.to(self.question_answering_model.device) + self.image_captioning_model.to(self.question_answering_model.device) + + return pred_answers, samples['captions'], samples['gradcams'] + + @classmethod + def from_config(cls, model_config): + itm_config = model_config.image_question_matching_model + cap_config = model_config.image_captioning_model + qa_config = model_config.question_answering_model + + itm_cls = registry.get_model_class(itm_config.arch) + cap_cls = registry.get_model_class(cap_config.arch) + qa_cls = registry.get_model_class(qa_config.arch) + + image_question_matching_model = itm_cls.from_config(itm_config) + image_captioning_model = cap_cls.from_config(cap_config) + question_answering_model = qa_cls.from_config(qa_config) + + model = cls(image_question_matching_model=image_question_matching_model, + image_captioning_model=image_captioning_model, + question_answering_model=question_answering_model, + offload_model= True if model_config.model_type == '3b' else False, + ) + + return model \ No newline at end of file diff --git a/lavis/models/timesformer/__init__.py b/lavis/models/timesformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1da75fc4c6577d8629ccb82f7a2b97b116c5b2bc --- /dev/null +++ b/lavis/models/timesformer/__init__.py @@ -0,0 +1,8 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/facebookresearch/TimeSformer +""" diff --git a/lavis/models/timesformer/conv2d_same.py b/lavis/models/timesformer/conv2d_same.py new file mode 100644 index 0000000000000000000000000000000000000000..ad23cc5a75e48d08137053c5e481a2feb8356b50 --- /dev/null +++ b/lavis/models/timesformer/conv2d_same.py @@ -0,0 +1,116 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/facebookresearch/TimeSformer +""" + +# Copyright 2020 Ross Wightman +# Conv2d w/ Same Padding + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional + +import math +from typing import List, Tuple + +from .vit_utils import is_static_pad, get_padding + +# Dynamically pad input x with 'SAME' padding for conv with specified args +def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): + ih, iw = x.size()[-2:] + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding( + iw, k[1], s[1], d[1] + ) + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, + [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], + value=value, + ) + return x + + +# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution +def get_same_padding(x: int, k: int, s: int, d: int): + return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == "same": + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == "valid": + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def conv2d_same( + x, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), + dilation: Tuple[int, int] = (1, 1), + groups: int = 1, +): + x = pad_same(x, weight.shape[-2:], stride, dilation) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """Tensorflow like 'SAME' convolution wrapper for 2D convolutions""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias + ) + + def forward(self, x): + return conv2d_same( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop("padding", "") + kwargs.setdefault("bias", False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) diff --git a/lavis/models/timesformer/features.py b/lavis/models/timesformer/features.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ef6bb31fae6253a1e3f23a2570c290d5cdf432 --- /dev/null +++ b/lavis/models/timesformer/features.py @@ -0,0 +1,308 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/facebookresearch/TimeSformer +""" + +# Copyright 2020 Ross Wightman + +from collections import OrderedDict, defaultdict +from copy import deepcopy +from functools import partial +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn + + +class FeatureInfo: + def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): + prev_reduction = 1 + for fi in feature_info: + # sanity check the mandatory fields, there may be additional fields depending on the model + assert "num_chs" in fi and fi["num_chs"] > 0 + assert "reduction" in fi and fi["reduction"] >= prev_reduction + prev_reduction = fi["reduction"] + assert "module" in fi + self.out_indices = out_indices + self.info = feature_info + + def from_other(self, out_indices: Tuple[int]): + return FeatureInfo(deepcopy(self.info), out_indices) + + def get(self, key, idx=None): + """Get value by key at specified index (indices) + if idx == None, returns value for key at each output index + if idx is an integer, return value for that feature module index (ignoring output indices) + if idx is a list/tupple, return value for each module index (ignoring output indices) + """ + if idx is None: + return [self.info[i][key] for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i][key] for i in idx] + else: + return self.info[idx][key] + + def get_dicts(self, keys=None, idx=None): + """return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)""" + if idx is None: + if keys is None: + return [self.info[i] for i in self.out_indices] + else: + return [{k: self.info[i][k] for k in keys} for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [ + self.info[i] if keys is None else {k: self.info[i][k] for k in keys} + for i in idx + ] + else: + return ( + self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} + ) + + def channels(self, idx=None): + """feature channels accessor""" + return self.get("num_chs", idx) + + def reduction(self, idx=None): + """feature reduction (output stride) accessor""" + return self.get("reduction", idx) + + def module_name(self, idx=None): + """feature module name accessor""" + return self.get("module", idx) + + def __getitem__(self, item): + return self.info[item] + + def __len__(self): + return len(self.info) + + +class FeatureHooks: + """Feature Hook Helper + This module helps with the setup and extraction of hooks for extracting features from + internal nodes in a model by node name. This works quite well in eager Python but needs + redesign for torcscript. + """ + + def __init__(self, hooks, named_modules, out_map=None, default_hook_type="forward"): + # setup feature hooks + modules = {k: v for k, v in named_modules} + for i, h in enumerate(hooks): + hook_name = h["module"] + m = modules[hook_name] + hook_id = out_map[i] if out_map else hook_name + hook_fn = partial(self._collect_output_hook, hook_id) + hook_type = h["hook_type"] if "hook_type" in h else default_hook_type + if hook_type == "forward_pre": + m.register_forward_pre_hook(hook_fn) + elif hook_type == "forward": + m.register_forward_hook(hook_fn) + else: + assert False, "Unsupported hook type" + self._feature_outputs = defaultdict(OrderedDict) + + def _collect_output_hook(self, hook_id, *args): + x = args[ + -1 + ] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][hook_id] = x + + def get_output(self, device) -> Dict[str, torch.tensor]: + output = self._feature_outputs[device] + self._feature_outputs[device] = OrderedDict() # clear after reading + return output + + +def _module_list(module, flatten_sequential=False): + # a yield/iter would be better for this but wouldn't be compatible with torchscript + ml = [] + for name, module in module.named_children(): + if flatten_sequential and isinstance(module, nn.Sequential): + # first level of Sequential containers is flattened into containing model + for child_name, child_module in module.named_children(): + combined = [name, child_name] + ml.append(("_".join(combined), ".".join(combined), child_module)) + else: + ml.append((name, name, module)) + return ml + + +def _get_feature_info(net, out_indices): + feature_info = getattr(net, "feature_info") + if isinstance(feature_info, FeatureInfo): + return feature_info.from_other(out_indices) + elif isinstance(feature_info, (list, tuple)): + return FeatureInfo(net.feature_info, out_indices) + else: + assert False, "Provided feature_info is not valid" + + +def _get_return_layers(feature_info, out_map): + module_names = feature_info.module_name() + return_layers = {} + for i, name in enumerate(module_names): + return_layers[name] = ( + out_map[i] if out_map is not None else feature_info.out_indices[i] + ) + return return_layers + + +class FeatureDictNet(nn.ModuleDict): + """Feature extractor with OrderedDict return + Wrap a model and extract features as specified by the out indices, the network is + partially re-built from contained modules. + There is a strong assumption that the modules have been registered into the model in the same + order as they are used. There should be no reuse of the same nn.Module more than once, including + trivial modules like `self.relu = nn.ReLU`. + Only submodules that are directly assigned to the model class (`model.feature1`) or at most + one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. + All Sequential containers that are directly assigned to the original model will have their + modules assigned to this module with the name `model.features.1` being changed to `model.features_1` + Arguments: + model (nn.Module): model from which we will extract the features + out_indices (tuple[int]): model output indices to extract features for + out_map (sequence): list or tuple specifying desired return id for each out index, + otherwise str(index) is used + feature_concat (bool): whether to concatenate intermediate features that are lists or tuples + vs select element [0] + flatten_sequential (bool): whether to flatten sequential modules assigned to model + """ + + def __init__( + self, + model, + out_indices=(0, 1, 2, 3, 4), + out_map=None, + feature_concat=False, + flatten_sequential=False, + ): + super(FeatureDictNet, self).__init__() + self.feature_info = _get_feature_info(model, out_indices) + self.concat = feature_concat + self.return_layers = {} + return_layers = _get_return_layers(self.feature_info, out_map) + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = set(return_layers.keys()) + layers = OrderedDict() + for new_name, old_name, module in modules: + layers[new_name] = module + if old_name in remaining: + # return id has to be consistently str type for torchscript + self.return_layers[new_name] = str(return_layers[old_name]) + remaining.remove(old_name) + if not remaining: + break + assert not remaining and len(self.return_layers) == len( + return_layers + ), f"Return layers ({remaining}) are not present in model" + self.update(layers) + + def _collect(self, x) -> (Dict[str, torch.Tensor]): + out = OrderedDict() + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + out_id = self.return_layers[name] + if isinstance(x, (tuple, list)): + # If model tap is a tuple or list, concat or select first element + # FIXME this may need to be more generic / flexible for some nets + out[out_id] = torch.cat(x, 1) if self.concat else x[0] + else: + out[out_id] = x + return out + + def forward(self, x) -> Dict[str, torch.Tensor]: + return self._collect(x) + + +class FeatureListNet(FeatureDictNet): + """Feature extractor with list return + See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. + In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool. + """ + + def __init__( + self, + model, + out_indices=(0, 1, 2, 3, 4), + out_map=None, + feature_concat=False, + flatten_sequential=False, + ): + super(FeatureListNet, self).__init__( + model, + out_indices=out_indices, + out_map=out_map, + feature_concat=feature_concat, + flatten_sequential=flatten_sequential, + ) + + def forward(self, x) -> (List[torch.Tensor]): + return list(self._collect(x).values()) + + +class FeatureHookNet(nn.ModuleDict): + """FeatureHookNet + Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. + If `no_rewrite` is True, features are extracted via hooks without modifying the underlying + network in any way. + If `no_rewrite` is False, the model will be re-written as in the + FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. + FIXME this does not currently work with Torchscript, see FeatureHooks class + """ + + def __init__( + self, + model, + out_indices=(0, 1, 2, 3, 4), + out_map=None, + out_as_dict=False, + no_rewrite=False, + feature_concat=False, + flatten_sequential=False, + default_hook_type="forward", + ): + super(FeatureHookNet, self).__init__() + assert not torch.jit.is_scripting() + self.feature_info = _get_feature_info(model, out_indices) + self.out_as_dict = out_as_dict + layers = OrderedDict() + hooks = [] + if no_rewrite: + assert not flatten_sequential + if hasattr(model, "reset_classifier"): # make sure classifier is removed? + model.reset_classifier(0) + layers["body"] = model + hooks.extend(self.feature_info.get_dicts()) + else: + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = { + f["module"]: f["hook_type"] if "hook_type" in f else default_hook_type + for f in self.feature_info.get_dicts() + } + for new_name, old_name, module in modules: + layers[new_name] = module + for fn, fm in module.named_modules(prefix=old_name): + if fn in remaining: + hooks.append(dict(module=fn, hook_type=remaining[fn])) + del remaining[fn] + if not remaining: + break + assert ( + not remaining + ), f"Return layers ({remaining}) are not present in model" + self.update(layers) + self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) + + def forward(self, x): + for name, module in self.items(): + x = module(x) + out = self.hooks.get_output(x.device) + return out if self.out_as_dict else list(out.values()) diff --git a/lavis/models/timesformer/helpers.py b/lavis/models/timesformer/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8ebd1415fff35cd0f1e365a6f666dcb2f04fee --- /dev/null +++ b/lavis/models/timesformer/helpers.py @@ -0,0 +1,400 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/facebookresearch/TimeSformer +""" + +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Copyright 2020 Ross Wightman +# Modified model creation / weight loading / state_dict helpers + +import logging, warnings +import os +import math +from collections import OrderedDict + +import torch +import torch.utils.model_zoo as model_zoo +import torch.nn.functional as F + + +def load_state_dict(checkpoint_path, use_ema=False): + if checkpoint_path and os.path.isfile(checkpoint_path): + checkpoint = torch.load(checkpoint_path, map_location="cpu") + state_dict_key = "state_dict" + if isinstance(checkpoint, dict): + if use_ema and "state_dict_ema" in checkpoint: + state_dict_key = "state_dict_ema" + if state_dict_key and state_dict_key in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint[state_dict_key].items(): + # strip `module.` prefix + name = k[7:] if k.startswith("module") else k + new_state_dict[name] = v + state_dict = new_state_dict + elif "model_state" in checkpoint: + state_dict_key = "model_state" + new_state_dict = OrderedDict() + for k, v in checkpoint[state_dict_key].items(): + # strip `model.` prefix + name = k[6:] if k.startswith("model") else k + new_state_dict[name] = v + state_dict = new_state_dict + else: + state_dict = checkpoint + logging.info( + "Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path) + ) + return state_dict + else: + logging.error("No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): + state_dict = load_state_dict(checkpoint_path, use_ema) + model.load_state_dict(state_dict, strict=strict) + + +# def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): +# resume_epoch = None +# if os.path.isfile(checkpoint_path): +# checkpoint = torch.load(checkpoint_path, map_location='cpu') +# if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: +# if log_info: +# _logger.info('Restoring model state from checkpoint...') +# new_state_dict = OrderedDict() +# for k, v in checkpoint['state_dict'].items(): +# name = k[7:] if k.startswith('module') else k +# new_state_dict[name] = v +# model.load_state_dict(new_state_dict) + +# if optimizer is not None and 'optimizer' in checkpoint: +# if log_info: +# _logger.info('Restoring optimizer state from checkpoint...') +# optimizer.load_state_dict(checkpoint['optimizer']) + +# if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: +# if log_info: +# _logger.info('Restoring AMP loss scaler state from checkpoint...') +# loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) + +# if 'epoch' in checkpoint: +# resume_epoch = checkpoint['epoch'] +# if 'version' in checkpoint and checkpoint['version'] > 1: +# resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save + +# if log_info: +# _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) +# else: +# model.load_state_dict(checkpoint) +# if log_info: +# _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) +# return resume_epoch +# else: +# _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) +# raise FileNotFoundError() + + +def load_pretrained( + model, + cfg=None, + num_classes=1000, + in_chans=3, + filter_fn=None, + img_size=224, + num_frames=8, + num_patches=196, + attention_type="divided_space_time", + pretrained_model="", + strict=True, +): + if cfg is None: + cfg = getattr(model, "default_cfg") + if cfg is None or "url" not in cfg or not cfg["url"]: + logging.warning("Pretrained model URL is invalid, using random initialization.") + return + + if len(pretrained_model) == 0: + if cfg is None: + logging.info(f"loading from default config {model.default_cfg}.") + state_dict = model_zoo.load_url(cfg["url"], progress=False, map_location="cpu") + else: + try: + state_dict = load_state_dict(pretrained_model)["model"] + except: + state_dict = load_state_dict(pretrained_model) + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + if in_chans == 1: + conv1_name = cfg["first_conv"] + logging.info( + "Converting first conv (%s) pretrained weights from 3 to 1 channel" + % conv1_name + ) + conv1_weight = state_dict[conv1_name + ".weight"] + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I > 3: + assert conv1_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) + conv1_weight = conv1_weight.sum(dim=2, keepdim=False) + else: + conv1_weight = conv1_weight.sum(dim=1, keepdim=True) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + ".weight"] = conv1_weight + elif in_chans != 3: + conv1_name = cfg["first_conv"] + conv1_weight = state_dict[conv1_name + ".weight"] + conv1_type = conv1_weight.dtype + conv1_weight = conv1_weight.float() + O, I, J, K = conv1_weight.shape + if I != 3: + logging.warning( + "Deleting first conv (%s) from pretrained weights." % conv1_name + ) + del state_dict[conv1_name + ".weight"] + strict = False + else: + logging.info( + "Repeating first conv (%s) weights in channel dim." % conv1_name + ) + repeat = int(math.ceil(in_chans / 3)) + conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv1_weight *= 3 / float(in_chans) + conv1_weight = conv1_weight.to(conv1_type) + state_dict[conv1_name + ".weight"] = conv1_weight + + classifier_name = cfg["classifier"] + if num_classes == 1000 and cfg["num_classes"] == 1001: + # special case for imagenet trained models with extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + ".weight"] + state_dict[classifier_name + ".weight"] = classifier_weight[1:] + classifier_bias = state_dict[classifier_name + ".bias"] + state_dict[classifier_name + ".bias"] = classifier_bias[1:] + elif num_classes != state_dict[classifier_name + ".weight"].size(0): + # print('Removing the last fully connected layer due to dimensions mismatch ('+str(num_classes)+ ' != '+str(state_dict[classifier_name + '.weight'].size(0))+').', flush=True) + # completely discard fully connected for all other differences between pretrained and created model + del state_dict[classifier_name + ".weight"] + del state_dict[classifier_name + ".bias"] + strict = False + + ## Resizing the positional embeddings in case they don't match + logging.info( + f"Resizing spatial position embedding from {state_dict['pos_embed'].size(1)} to {num_patches + 1}" + ) + if num_patches + 1 != state_dict["pos_embed"].size(1): + pos_embed = state_dict["pos_embed"] + cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1) + other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2) + new_pos_embed = F.interpolate( + other_pos_embed, size=(num_patches), mode="nearest" + ) + new_pos_embed = new_pos_embed.transpose(1, 2) + new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) + state_dict["pos_embed"] = new_pos_embed + + ## Resizing time embeddings in case they don't match + if "time_embed" in state_dict and num_frames != state_dict["time_embed"].size(1): + logging.info( + f"Resizing temporal position embedding from {state_dict['time_embed'].size(1)} to {num_frames}" + ) + time_embed = state_dict["time_embed"].transpose(1, 2) + new_time_embed = F.interpolate(time_embed, size=(num_frames), mode="nearest") + state_dict["time_embed"] = new_time_embed.transpose(1, 2) + + ## Initializing temporal attention + if attention_type == "divided_space_time": + new_state_dict = state_dict.copy() + for key in state_dict: + if "blocks" in key and "attn" in key: + new_key = key.replace("attn", "temporal_attn") + if not new_key in state_dict: + new_state_dict[new_key] = state_dict[key] + else: + new_state_dict[new_key] = state_dict[new_key] + if "blocks" in key and "norm1" in key: + new_key = key.replace("norm1", "temporal_norm1") + if not new_key in state_dict: + new_state_dict[new_key] = state_dict[key] + else: + new_state_dict[new_key] = state_dict[new_key] + state_dict = new_state_dict + + ## Loading the weights + model.load_state_dict(state_dict, strict=False) + + +def load_pretrained_imagenet( + model, + pretrained_model, + cfg=None, + ignore_classifier=True, + num_frames=8, + num_patches=196, + **kwargs, +): + import timm + + logging.info(f"Loading vit_base_patch16_224 checkpoints.") + loaded_state_dict = timm.models.vision_transformer.vit_base_patch16_224( + pretrained=True + ).state_dict() + + del loaded_state_dict["head.weight"] + del loaded_state_dict["head.bias"] + + ## Initializing temporal attention + new_state_dict = loaded_state_dict.copy() + for key in loaded_state_dict: + if "blocks" in key and "attn" in key: + new_key = key.replace("attn", "temporal_attn") + if not new_key in loaded_state_dict: + new_state_dict[new_key] = loaded_state_dict[key] + else: + new_state_dict[new_key] = loaded_state_dict[new_key] + if "blocks" in key and "norm1" in key: + new_key = key.replace("norm1", "temporal_norm1") + if not new_key in loaded_state_dict: + new_state_dict[new_key] = loaded_state_dict[key] + else: + new_state_dict[new_key] = loaded_state_dict[new_key] + + loaded_state_dict = new_state_dict + + loaded_keys = loaded_state_dict.keys() + model_keys = model.state_dict().keys() + + load_not_in_model = [k for k in loaded_keys if k not in model_keys] + model_not_in_load = [k for k in model_keys if k not in loaded_keys] + + toload = dict() + mismatched_shape_keys = [] + for k in model_keys: + if k in loaded_keys: + if model.state_dict()[k].shape != loaded_state_dict[k].shape: + mismatched_shape_keys.append(k) + else: + toload[k] = loaded_state_dict[k] + + logging.info("Keys in loaded but not in model:") + logging.info(f"In total {len(load_not_in_model)}, {sorted(load_not_in_model)}") + logging.info("Keys in model but not in loaded:") + logging.info(f"In total {len(model_not_in_load)}, {sorted(model_not_in_load)}") + logging.info("Keys in model and loaded, but shape mismatched:") + logging.info( + f"In total {len(mismatched_shape_keys)}, {sorted(mismatched_shape_keys)}" + ) + + model.load_state_dict(toload, strict=False) + + +def load_pretrained_kinetics( + model, + pretrained_model, + cfg=None, + ignore_classifier=True, + num_frames=8, + num_patches=196, + **kwargs, +): + if cfg is None: + cfg = getattr(model, "default_cfg") + if cfg is None or "url" not in cfg or not cfg["url"]: + logging.warning("Pretrained model URL is invalid, using random initialization.") + return + + assert ( + len(pretrained_model) > 0 + ), "Path to pre-trained Kinetics weights not provided." + + state_dict = load_state_dict(pretrained_model) + + classifier_name = cfg["classifier"] + if ignore_classifier: + + classifier_weight_key = classifier_name + ".weight" + classifier_bias_key = classifier_name + ".bias" + + state_dict[classifier_weight_key] = model.state_dict()[classifier_weight_key] + state_dict[classifier_bias_key] = model.state_dict()[classifier_bias_key] + + else: + raise NotImplementedError( + "[dxli] Not supporting loading Kinetics-pretrained ckpt with classifier." + ) + + ## Resizing the positional embeddings in case they don't match + if num_patches + 1 != state_dict["pos_embed"].size(1): + new_pos_embed = resize_spatial_embedding(state_dict, "pos_embed", num_patches) + state_dict["pos_embed"] = new_pos_embed + + ## Resizing time embeddings in case they don't match + if "time_embed" in state_dict and num_frames != state_dict["time_embed"].size(1): + state_dict["time_embed"] = resize_temporal_embedding( + state_dict, "time_embed", num_frames + ) + + ## Loading the weights + try: + model.load_state_dict(state_dict, strict=True) + logging.info("Succeeded in loading Kinetics pre-trained weights.") + except: + logging.error("Error in loading Kinetics pre-trained weights.") + + +def resize_spatial_embedding(state_dict, key, num_patches): + logging.info( + f"Resizing spatial position embedding from {state_dict[key].size(1)} to {num_patches + 1}" + ) + + pos_embed = state_dict[key] + + cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1) + other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2) + + new_pos_embed = F.interpolate(other_pos_embed, size=(num_patches), mode="nearest") + new_pos_embed = new_pos_embed.transpose(1, 2) + new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) + + return new_pos_embed + + +def resize_temporal_embedding(state_dict, key, num_frames): + logging.info( + f"Resizing temporal position embedding from {state_dict[key].size(1)} to {num_frames}" + ) + + time_embed = state_dict[key].transpose(1, 2) + new_time_embed = F.interpolate(time_embed, size=(num_frames), mode="nearest") + + return new_time_embed.transpose(1, 2) + + +def detach_variable(inputs): + if isinstance(inputs, tuple): + out = [] + for inp in inputs: + x = inp.detach() + x.requires_grad = inp.requires_grad + out.append(x) + return tuple(out) + else: + raise RuntimeError( + "Only tuple of tensors is supported. Got Unsupported input type: ", + type(inputs).__name__, + ) + + +def check_backward_validity(inputs): + if not any(inp.requires_grad for inp in inputs): + warnings.warn( + "None of the inputs have requires_grad=True. Gradients will be None" + ) diff --git a/lavis/models/timesformer/linear.py b/lavis/models/timesformer/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa849b93479d796c8cf0c4fde999ed028f9ae45 --- /dev/null +++ b/lavis/models/timesformer/linear.py @@ -0,0 +1,21 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +""" Linear layer (alternate definition) +""" +import torch +import torch.nn.functional as F +from torch import nn as nn + + +class Linear(nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting(): + bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None + return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) + else: + return F.linear(input, self.weight, self.bias) diff --git a/lavis/models/timesformer/vit.py b/lavis/models/timesformer/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..a6202b40636684b4a9c4feb8af5be227180ca966 --- /dev/null +++ b/lavis/models/timesformer/vit.py @@ -0,0 +1,634 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/facebookresearch/TimeSformer +""" + +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# Copyright 2020 Ross Wightman +# Modified Model definition + +import logging +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +from einops import rearrange +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +from .helpers import load_pretrained, load_pretrained_imagenet, load_pretrained_kinetics +from .vit_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + DropPath, + to_2tuple, + trunc_normal_, +) + + +def _cfg(url="", **kwargs): + return { + "url": url, + "num_classes": 1000, + "input_size": (3, 224, 224), + "pool_size": None, + "crop_pct": 0.9, + "interpolation": "bicubic", + "mean": IMAGENET_DEFAULT_MEAN, + "std": IMAGENET_DEFAULT_STD, + "first_conv": "patch_embed.proj", + "classifier": "head", + **kwargs, + } + + +default_cfgs = { + "vit_base_patch16_224": _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth", + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + ), +} + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + with_qkv=True, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.with_qkv = with_qkv + if self.with_qkv: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_drop = nn.Dropout(attn_drop) + + def forward(self, x): + B, N, C = x.shape + if self.with_qkv: + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + else: + qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute( + 0, 2, 1, 3 + ) + q, k, v = qkv, qkv, qkv + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + if self.with_qkv: + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + layer_num, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.1, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + attention_type="divided_space_time", + use_grad_checkpointing=False, + ): + super().__init__() + self.attention_type = attention_type + assert attention_type in [ + "divided_space_time", + "space_only", + "joint_space_time", + ] + + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + # Temporal Attention Parameters + if self.attention_type == "divided_space_time": + self.temporal_norm1 = norm_layer(dim) + self.temporal_attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.temporal_fc = nn.Linear(dim, dim) + + # drop path + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # [dxli] + self.layer_num = layer_num + self.use_grad_checkpointing = use_grad_checkpointing + + if use_grad_checkpointing: + self.temporal_attn = checkpoint_wrapper(self.temporal_attn) + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def forward(self, x, B, T, W): + num_spatial_tokens = (x.size(1) - 1) // T + H = num_spatial_tokens // W + + if self.attention_type in ["space_only", "joint_space_time"]: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + elif self.attention_type == "divided_space_time": + # Temporal + xt = x[:, 1:, :] + xt = rearrange(xt, "b (h w t) m -> (b h w) t m", b=B, h=H, w=W, t=T) + + temporal_attn_out = self.temporal_attn(self.temporal_norm1(xt)) + + res_temporal = self.drop_path(temporal_attn_out) + + res_temporal = rearrange( + res_temporal, "(b h w) t m -> b (h w t) m", b=B, h=H, w=W, t=T + ) + res_temporal = self.temporal_fc(res_temporal) + xt = x[:, 1:, :] + res_temporal + + # Spatial + init_cls_token = x[:, 0, :].unsqueeze(1) + cls_token = init_cls_token.repeat(1, T, 1) + cls_token = rearrange(cls_token, "b t m -> (b t) m", b=B, t=T).unsqueeze(1) + xs = xt + xs = rearrange(xs, "b (h w t) m -> (b t) (h w) m", b=B, h=H, w=W, t=T) + xs = torch.cat((cls_token, xs), 1) + + spatial_attn_out = self.attn(self.norm1(xs)) + res_spatial = self.drop_path(spatial_attn_out) + + # Taking care of CLS token + cls_token = res_spatial[:, 0, :] + cls_token = rearrange(cls_token, "(b t) m -> b t m", b=B, t=T) + # averaging for every frame + cls_token = torch.mean(cls_token, 1, True) + res_spatial = res_spatial[:, 1:, :] + res_spatial = rearrange( + res_spatial, "(b t) (h w) m -> b (h w t) m", b=B, h=H, w=W, t=T + ) + res = res_spatial + x = xt + + # Mlp + x = torch.cat((init_cls_token, x), 1) + torch.cat((cls_token, res), 1) + + x_res = x + + x = self.norm2(x) + # x = x + self.drop_path(self.mlp(self.norm2(x))) + + # MLP + mlp_out = self.mlp(x) + + x = x_res + self.drop_path(mlp_out) + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x): + B, C, T, H, W = x.shape + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.proj(x) + W = x.size(-1) + x = x.flatten(2).transpose(1, 2) + return x, T, W + + +class VisionTransformer(nn.Module): + """Vision Transformere""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + hybrid_backbone=None, + norm_layer=nn.LayerNorm, + num_frames=8, + attention_type="divided_space_time", + dropout=0.0, + use_grad_checkpointing=False, + ckpt_layer=0, + ): + super().__init__() + + self.attention_type = attention_type + self.depth = depth + self.dropout = nn.Dropout(dropout) + self.num_classes = num_classes + # num_features for consistency with other models + self.num_features = self.embed_dim = embed_dim + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + # Positional Embeddings + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + if self.attention_type != "space_only": + self.time_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim)) + self.time_drop = nn.Dropout(p=drop_rate) + + # Attention Blocks + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, self.depth) + ] # stochastic depth decay rule + self.blocks = nn.ModuleList( + [ + Block( + layer_num=i, + use_grad_checkpointing=( + use_grad_checkpointing and i >= self.depth - ckpt_layer + ), + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + attention_type=self.attention_type, + ) + for i in range(self.depth) + ] + ) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = ( + nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + trunc_normal_(self.pos_embed, std=0.02) + trunc_normal_(self.cls_token, std=0.02) + self.apply(self._init_weights) + + # initialization of temporal attention weights + if self.attention_type == "divided_space_time": + i = 0 + for m in self.blocks.modules(): + m_str = str(m) + if "Block" in m_str: + if i > 0: + nn.init.constant_(m.temporal_fc.weight, 0) + nn.init.constant_(m.temporal_fc.bias, 0) + i += 1 + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token", "time_embed"} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=""): + self.num_classes = num_classes + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def remove_classifier(self): + self.num_classes = 0 + self.head = None + + def forward_features(self, x): + B = x.shape[0] + x, T, W = self.patch_embed(x) + cls_tokens = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # resizing the positional embeddings in case they don't match the input at inference + if x.size(1) != self.pos_embed.size(1): + pos_embed = self.pos_embed + cls_pos_embed = pos_embed[0, 0, :].unsqueeze(0).unsqueeze(1) + other_pos_embed = pos_embed[0, 1:, :].unsqueeze(0).transpose(1, 2) + P = int(other_pos_embed.size(2) ** 0.5) + H = x.size(1) // W + other_pos_embed = other_pos_embed.reshape(1, x.size(2), P, P) + new_pos_embed = F.interpolate(other_pos_embed, size=(H, W), mode="nearest") + new_pos_embed = new_pos_embed.flatten(2) + new_pos_embed = new_pos_embed.transpose(1, 2) + new_pos_embed = torch.cat((cls_pos_embed, new_pos_embed), 1) + x = x + new_pos_embed + else: + x = x + self.pos_embed + x = self.pos_drop(x) + + # Time Embeddings + if self.attention_type != "space_only": + cls_tokens = x[:B, 0, :].unsqueeze(1) + x = x[:, 1:] + x = rearrange(x, "(b t) n m -> (b n) t m", b=B, t=T) + # Resizing time embeddings in case they don't match + if T != self.time_embed.size(1): + time_embed = self.time_embed.transpose(1, 2) + new_time_embed = F.interpolate(time_embed, size=(T), mode="nearest") + new_time_embed = new_time_embed.transpose(1, 2) + x = x + new_time_embed + else: + x = x + self.time_embed + x = self.time_drop(x) + x = rearrange(x, "(b n) t m -> b (n t) m", b=B, t=T) + x = torch.cat((cls_tokens, x), dim=1) + + # Attention blocks + for blk in self.blocks: + x = blk(x, B, T, W) + + # Predictions for space-only baseline + if self.attention_type == "space_only": + x = rearrange(x, "(b t) n m -> b t n m", b=B, t=T) + x = torch.mean(x, 1) # averaging predictions for every frame + + x = self.norm(x) + + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _conv_filter(state_dict, patch_size=16): + """convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if "patch_embed.proj.weight" in k: + if v.shape[-1] != patch_size: + patch_size = v.shape[-1] + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict + + +class vit_base_patch16_224(nn.Module): + def __init__(self, cfg, **kwargs): + super(vit_base_patch16_224, self).__init__() + self.pretrained = True + patch_size = 16 + self.model = VisionTransformer( + img_size=cfg.DATA.TRAIN_CROP_SIZE, + num_classes=cfg.MODEL.NUM_CLASSES, + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + num_frames=cfg.DATA.NUM_FRAMES, + attention_type=cfg.TIMESFORMER.ATTENTION_TYPE, + **kwargs, + ) + + self.attention_type = cfg.TIMESFORMER.ATTENTION_TYPE + self.model.default_cfg = default_cfgs["vit_base_patch16_224"] + self.num_patches = (cfg.DATA.TRAIN_CROP_SIZE // patch_size) * ( + cfg.DATA.TRAIN_CROP_SIZE // patch_size + ) + pretrained_model = cfg.TIMESFORMER.PRETRAINED_MODEL + if self.pretrained: + load_pretrained( + self.model, + num_classes=self.model.num_classes, + in_chans=kwargs.get("in_chans", 3), + filter_fn=_conv_filter, + img_size=cfg.DATA.TRAIN_CROP_SIZE, + num_patches=self.num_patches, + attention_type=self.attention_type, + pretrained_model=pretrained_model, + ) + + def forward(self, x): + x = self.model(x) + return x + + +class TimeSformer(nn.Module): + def __init__( + self, + image_size=224, + patch_size=16, + n_frms=8, + attn_drop_rate=0.0, + drop_path_rate=0.1, + drop_rate=0, + use_grad_ckpt=False, + ckpt_layer=0, + remove_classifier=True, + **kwargs, + ): + super(TimeSformer, self).__init__() + + self.img_size = image_size + self.patch_size = patch_size + self.num_frames = n_frms + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + self.drop_rate = drop_rate + self.use_grad_ckpt = use_grad_ckpt + self.ckpt_layer = ckpt_layer + + self.attention_type = "divided_space_time" + + logging.info( + f"Initializing TimeSformer with img_size={self.img_size}, patch_size={self.patch_size}, num_frames={self.num_frames}" + ) + + # will be ignored when loading official pretrained ckpt + self.num_classes = 400 + + self.model = VisionTransformer( + img_size=self.img_size, + num_classes=self.num_classes, + patch_size=self.patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop_rate=self.drop_rate, + attn_drop_rate=self.attn_drop_rate, + drop_path_rate=self.drop_path_rate, + num_frames=self.num_frames, + attention_type=self.attention_type, + use_grad_checkpointing=self.use_grad_ckpt, + ckpt_layer=self.ckpt_layer, + **kwargs, + ) + + if remove_classifier: + self.model.remove_classifier() + + self.model.default_cfg = default_cfgs[ + "vit_base_patch" + str(self.patch_size) + "_224" + ] + self.num_patches = (self.img_size // self.patch_size) * ( + self.img_size // self.patch_size + ) + + def forward(self, x): + x = self.model(x) + return x + + def forward_features(self, x): + # b, c, t, h, w = x.shape + x = self.model.forward_features(x) + + ## apply pooling + W = H = self.img_size // self.patch_size + T = self.num_frames + + cls_tokens = x[:, 0, :].unsqueeze(1) + other_tokens = x[:, 1:, :] + + x = rearrange(other_tokens, "b (h w t) m -> b t (h w) m", h=H, w=W, t=T) + + x = torch.mean(x, dim=1) + x = torch.cat((cls_tokens, x), dim=1) + + return x + + def load_state_dict(self, pretrained_ckpt_path): + logging.info( + "Loading TimeSformer checkpoints from {}".format(pretrained_ckpt_path) + ) + + if pretrained_ckpt_path == "vit_base_patch16_224": + load_ckpt_func = load_pretrained_imagenet + else: + load_ckpt_func = load_pretrained_kinetics + + load_ckpt_func( + self.model, + num_classes=self.model.num_classes, + in_chans=3, + filter_fn=_conv_filter, + img_size=self.img_size, + num_frames=self.num_frames, + num_patches=self.num_patches, + attention_type=self.attention_type, + pretrained_model=pretrained_ckpt_path, + ) diff --git a/lavis/models/timesformer/vit_utils.py b/lavis/models/timesformer/vit_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5045d586495ca8ddab3f52d5f0a1b207fe263762 --- /dev/null +++ b/lavis/models/timesformer/vit_utils.py @@ -0,0 +1,189 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on https://github.com/facebookresearch/TimeSformer +""" + +# Copyright 2020 Ross Wightman +# Various utility functions + +import torch +import torch.nn as nn +import math +import warnings +import torch.nn.functional as F + +from itertools import repeat +import collections.abc as container_abcs + +DEFAULT_CROP_PCT = 0.875 +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) +IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) +IMAGENET_DPN_STD = tuple([1 / (0.0167 * 255)] * 3) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def get_padding_value(padding, kernel_size, **kwargs): + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == "same": + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == "valid": + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic + + +# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution +def get_same_padding(x: int, k: int, s: int, d: int): + return max((int(math.ceil(x // s)) - 1) * s + (k - 1) * d + 1 - x, 0) + + +# Can SAME padding for given args be done statically? +def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +# Dynamically pad input x with 'SAME' padding for conv with specified args +# def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): +def pad_same(x, k, s, d=(1, 1), value=0): + ih, iw = x.size()[-2:] + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding( + iw, k[1], s[1], d[1] + ) + if pad_h > 0 or pad_w > 0: + x = F.pad( + x, + [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], + value=value, + ) + return x + + +def adaptive_pool_feat_mult(pool_type="avg"): + if pool_type == "catavgmax": + return 2 + else: + return 1 + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/lavis/models/vit.py b/lavis/models/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..f35b7bb6886f8e4455330cf7c330a18e57f11db7 --- /dev/null +++ b/lavis/models/vit.py @@ -0,0 +1,527 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + + Based on timm code base + https://github.com/rwightman/pytorch-image-models/tree/master/timm +""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.vision_transformer import _cfg, PatchEmbed +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_, DropPath +from timm.models.helpers import named_apply, adapt_input_conv + +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper +from lavis.models.base_model import BaseEncoder + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_gradients = None + self.attention_map = None + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, register_hook=False): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_grad_checkpointing=False, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + if use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def forward(self, x, register_hook=False): + x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + representation_size=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=None, + use_grad_checkpointing=False, + ckpt_layer=0, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + use_grad_checkpointing=( + use_grad_checkpointing and i >= depth - ckpt_layer + ), + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + trunc_normal_(self.pos_embed, std=0.02) + trunc_normal_(self.cls_token, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def forward(self, x, register_blk=-1): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:, : x.size(1), :] + x = self.pos_drop(x) + + for i, blk in enumerate(self.blocks): + x = blk(x, register_blk == i) + x = self.norm(x) + + return x + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=""): + _load_weights(self, checkpoint_path, prefix) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""): + """Load weights from .npz checkpoints for official Google Brain Flax implementation""" + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and "opt/target/embedding/kernel" in w: + prefix = "opt/target/" + + if hasattr(model.patch_embed, "backbone"): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, "stem") + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_( + adapt_input_conv( + stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"]) + ) + ) + stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"])) + stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f"{prefix}block{i + 1}/unit{j + 1}/" + for r in range(3): + getattr(block, f"conv{r + 1}").weight.copy_( + _n2p(w[f"{bp}conv{r + 1}/kernel"]) + ) + getattr(block, f"norm{r + 1}").weight.copy_( + _n2p(w[f"{bp}gn{r + 1}/scale"]) + ) + getattr(block, f"norm{r + 1}").bias.copy_( + _n2p(w[f"{bp}gn{r + 1}/bias"]) + ) + if block.downsample is not None: + block.downsample.conv.weight.copy_( + _n2p(w[f"{bp}conv_proj/kernel"]) + ) + block.downsample.norm.weight.copy_( + _n2p(w[f"{bp}gn_proj/scale"]) + ) + block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"])) + embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"]) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"]) + ) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"])) + model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False)) + pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, + model.pos_embed, + getattr(model, "num_tokens", 1), + model.patch_embed.grid_size, + ) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"])) + model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"])) + # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f"{prefix}Transformer/encoderblock_{i}/" + mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/" + block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"])) + block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"])) + block.attn.qkv.weight.copy_( + torch.cat( + [ + _n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T + for n in ("query", "key", "value") + ] + ) + ) + block.attn.qkv.bias.copy_( + torch.cat( + [ + _n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1) + for n in ("query", "key", "value") + ] + ) + ) + block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"])) + for r in range(2): + getattr(block.mlp, f"fc{r + 1}").weight.copy_( + _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"]) + ) + getattr(block.mlp, f"fc{r + 1}").bias.copy_( + _n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"]) + ) + block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"])) + block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"])) + + +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + print("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + print("Position embedding grid-size from %s to %s", [gs_old, gs_old], gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate( + posemb_grid, size=gs_new, mode="bicubic", align_corners=False + ) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return + + +def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): + # interpolate position embedding + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = visual_encoder.patch_embed.num_patches + num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + + if orig_size != new_size: + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + print( + "reshape position embedding from %d to %d" % (orig_size**2, new_size**2) + ) + + return new_pos_embed + else: + return pos_embed_checkpoint + + +class VisionTransformerEncoder(VisionTransformer, BaseEncoder): + @classmethod + def from_config(cls, cfg, from_pretrained=False): + + vit_type = cfg.get("vit_type", "base") + image_size = cfg.get("image_size", 384) + ckpt_layer = cfg.get("vit_ckpt_layer", 0) + drop_path_rate = cfg.get("vit_drop_path_rate", 0) + norm_layer_eps = cfg.get("vit_layer_norm_epsilon", -1) + use_grad_checkpointing = cfg.get("vit_grad_ckpt", False) + + if norm_layer_eps == -1: + norm_layer = None + else: + norm_layer = partial(nn.LayerNorm, eps=norm_layer_eps) + + # norm_layer=partial(nn.LayerNorm, eps=1e-6), + assert vit_type in ["base", "large"], "vit parameter must be base or large" + if vit_type == "base": + vision_width = 768 + visual_encoder = cls( + img_size=image_size, + patch_size=16, + embed_dim=vision_width, + depth=12, + num_heads=12, + use_grad_checkpointing=use_grad_checkpointing, + ckpt_layer=ckpt_layer, + drop_path_rate=0 or drop_path_rate, + norm_layer=norm_layer, + ) + + if from_pretrained: + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", + map_location="cpu", + check_hash=True, + ) + state_dict = checkpoint["model"] + state_dict["pos_embed"] = interpolate_pos_embed( + state_dict["pos_embed"], visual_encoder + ) + msg = visual_encoder.load_state_dict(state_dict, strict=False) + + elif vit_type == "large": + vision_width = 1024 + visual_encoder = cls( + img_size=image_size, + patch_size=16, + embed_dim=vision_width, + depth=24, + num_heads=16, + use_grad_checkpointing=use_grad_checkpointing, + ckpt_layer=ckpt_layer, + drop_path_rate=0.1 or drop_path_rate, + norm_layer=norm_layer, + ) + if from_pretrained: + from timm.models.helpers import load_custom_pretrained + from timm.models.vision_transformer import default_cfgs + + load_custom_pretrained( + visual_encoder, default_cfgs["vit_large_patch16_224_in21k"] + ) + + visual_encoder.vision_width = vision_width + return visual_encoder + + def forward_features(self, x, register_blk=-1): + return super().forward(x, register_blk) diff --git a/lavis/processors/__init__.py b/lavis/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6278b1ca247f148fafea080aa8aa0c3abc8c37d4 --- /dev/null +++ b/lavis/processors/__init__.py @@ -0,0 +1,53 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from lavis.processors.base_processor import BaseProcessor + +from lavis.processors.alpro_processors import ( + AlproVideoTrainProcessor, + AlproVideoEvalProcessor, +) +from lavis.processors.blip_processors import ( + BlipImageTrainProcessor, + Blip2ImageTrainProcessor, + BlipImageEvalProcessor, + BlipCaptionProcessor, +) +from lavis.processors.gpt_processors import ( + GPTVideoFeatureProcessor, + GPTDialogueProcessor, +) +from lavis.processors.clip_processors import ClipImageTrainProcessor + +from lavis.common.registry import registry + +__all__ = [ + "BaseProcessor", + # ALPRO + "AlproVideoTrainProcessor", + "AlproVideoEvalProcessor", + # BLIP + "BlipImageTrainProcessor", + "Blip2ImageTrainProcessor", + "BlipImageEvalProcessor", + "BlipCaptionProcessor", + "ClipImageTrainProcessor", + # GPT + "GPTVideoFeatureProcessor", + "GPTDialogueProcessor", +] + + +def load_processor(name, cfg=None): + """ + Example + + >>> processor = load_processor("alpro_video_train", cfg=None) + """ + processor = registry.get_processor_class(name).from_config(cfg) + + return processor diff --git a/lavis/processors/alpro_processors.py b/lavis/processors/alpro_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..fbcc03211f176c5f63529659674200eed759ec70 --- /dev/null +++ b/lavis/processors/alpro_processors.py @@ -0,0 +1,216 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import torch +from lavis.common.registry import registry +from lavis.datasets.data_utils import load_video +from lavis.processors import transforms_video +from lavis.processors.base_processor import BaseProcessor +from lavis.processors.randaugment import VideoRandomAugment +from lavis.processors import functional_video as F +from omegaconf import OmegaConf +from torchvision import transforms + +MAX_INT = registry.get("MAX_INT") + + +class AlproVideoBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None, n_frms=MAX_INT): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms_video.NormalizeVideo(mean, std) + + self.n_frms = n_frms + + +class ToUint8(object): + def __init__(self): + pass + + def __call__(self, tensor): + return tensor.to(torch.uint8) + + def __repr__(self): + return self.__class__.__name__ + + +class ToTHWC(object): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (C, T, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, H, W, C) + """ + + def __init__(self): + pass + + def __call__(self, tensor): + return tensor.permute(1, 2, 3, 0) + + def __repr__(self): + return self.__class__.__name__ + + +class ResizeVideo(object): + def __init__(self, target_size, interpolation_mode="bilinear"): + self.target_size = target_size + self.interpolation_mode = interpolation_mode + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: central cropping of video clip. Size is + (C, T, crop_size, crop_size) + """ + return F.resize(clip, self.target_size, self.interpolation_mode) + + def __repr__(self): + return self.__class__.__name__ + "(resize_size={0})".format(self.target_size) + + +@registry.register_processor("alpro_video_train") +class AlproVideoTrainProcessor(AlproVideoBaseProcessor): + def __init__( + self, + image_size=384, + mean=None, + std=None, + min_scale=0.5, + max_scale=1.0, + n_frms=MAX_INT, + ): + super().__init__(mean=mean, std=std, n_frms=n_frms) + + self.image_size = image_size + + self.transform = transforms.Compose( + [ + # Video size is (C, T, H, W) + transforms_video.RandomResizedCropVideo( + image_size, + scale=(min_scale, max_scale), + interpolation_mode="bicubic", + ), + transforms_video.RandomHorizontalFlipVideo(), + ToTHWC(), # C, T, H, W -> T, H, W, C + VideoRandomAugment( + 2, + 5, + augs=[ + "Identity", + "AutoContrast", + "Brightness", + "Sharpness", + "Equalize", + "ShearX", + "ShearY", + "TranslateX", + "TranslateY", + "Rotate", + ], + ), + ToUint8(), + transforms_video.ToTensorVideo(), # T, H, W, C -> C, T, H, W + self.normalize, + ] + ) + + def __call__(self, vpath): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: video clip after transforms. Size is (C, T, size, size). + """ + clip = load_video( + video_path=vpath, + n_frms=self.n_frms, + height=self.image_size, + width=self.image_size, + sampling="headtail", + ) + + return self.transform(clip) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 256) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.5) + max_scale = cfg.get("max_scale", 1.0) + + n_frms = cfg.get("n_frms", MAX_INT) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + n_frms=n_frms, + ) + + +@registry.register_processor("alpro_video_eval") +class AlproVideoEvalProcessor(AlproVideoBaseProcessor): + def __init__(self, image_size=256, mean=None, std=None, n_frms=MAX_INT): + super().__init__(mean=mean, std=std, n_frms=n_frms) + + self.image_size = image_size + + # Input video size is (C, T, H, W) + self.transform = transforms.Compose( + [ + # frames will be resized during decord loading. + ToUint8(), # C, T, H, W + ToTHWC(), # T, H, W, C + transforms_video.ToTensorVideo(), # C, T, H, W + self.normalize, # C, T, H, W + ] + ) + + def __call__(self, vpath): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + Returns: + torch.tensor: video clip after transforms. Size is (C, T, size, size). + """ + clip = load_video( + video_path=vpath, + n_frms=self.n_frms, + height=self.image_size, + width=self.image_size, + ) + + return self.transform(clip) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 256) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + n_frms = cfg.get("n_frms", MAX_INT) + + return cls(image_size=image_size, mean=mean, std=std, n_frms=n_frms) diff --git a/lavis/processors/base_processor.py b/lavis/processors/base_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c9d86859270a046623661a632587f2b3136b46 --- /dev/null +++ b/lavis/processors/base_processor.py @@ -0,0 +1,26 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from omegaconf import OmegaConf + + +class BaseProcessor: + def __init__(self): + self.transform = lambda x: x + return + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + return cls() + + def build(self, **kwargs): + cfg = OmegaConf.create(kwargs) + + return self.from_config(cfg) diff --git a/lavis/processors/blip_processors.py b/lavis/processors/blip_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..abaafda9041167cfa0e11a08e0f70cca3c8eea56 --- /dev/null +++ b/lavis/processors/blip_processors.py @@ -0,0 +1,239 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import re + +from lavis.common.registry import registry +from lavis.processors.base_processor import BaseProcessor +from lavis.processors.randaugment import RandomAugment +from omegaconf import OmegaConf +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + + +class BlipImageBaseProcessor(BaseProcessor): + def __init__(self, mean=None, std=None): + if mean is None: + mean = (0.48145466, 0.4578275, 0.40821073) + if std is None: + std = (0.26862954, 0.26130258, 0.27577711) + + self.normalize = transforms.Normalize(mean, std) + + +@registry.register_processor("blip_caption") +class BlipCaptionProcessor(BaseProcessor): + def __init__(self, prompt="", max_words=50): + self.prompt = prompt + self.max_words = max_words + + def __call__(self, caption): + caption = self.prompt + self.pre_caption(caption) + + return caption + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + prompt = cfg.get("prompt", "") + max_words = cfg.get("max_words", 50) + + return cls(prompt=prompt, max_words=max_words) + + def pre_caption(self, caption): + caption = re.sub( + r"([.!\"()*#:;~])", + " ", + caption.lower(), + ) + caption = re.sub( + r"\s{2,}", + " ", + caption, + ) + caption = caption.rstrip("\n") + caption = caption.strip(" ") + + # truncate caption + caption_words = caption.split(" ") + if len(caption_words) > self.max_words: + caption = " ".join(caption_words[: self.max_words]) + + return caption + + +@registry.register_processor("blip_question") +class BlipQuestionProcessor(BaseProcessor): + def __init__(self, max_words=50): + self.max_words = max_words + + def __call__(self, question): + return self.pre_question(question) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + max_words = cfg.get("max_words", 50) + + return cls(max_words=max_words) + + def pre_question(self, question): + question = re.sub( + r"([.!\"()*#:;~])", + "", + question.lower(), + ) + question = question.rstrip(" ") + + # truncate question + question_words = question.split(" ") + if len(question_words) > self.max_words: + question = " ".join(question_words[: self.max_words]) + + return question + + +@registry.register_processor("blip_image_train") +class BlipImageTrainProcessor(BlipImageBaseProcessor): + def __init__( + self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0 + ): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + image_size, + scale=(min_scale, max_scale), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.RandomHorizontalFlip(), + RandomAugment( + 2, + 5, + isPIL=True, + augs=[ + "Identity", + "AutoContrast", + "Brightness", + "Sharpness", + "Equalize", + "ShearX", + "ShearY", + "TranslateX", + "TranslateY", + "Rotate", + ], + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 384) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.5) + max_scale = cfg.get("max_scale", 1.0) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + ) + + +@registry.register_processor("blip_image_eval") +class BlipImageEvalProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=384, mean=None, std=None): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 384) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls(image_size=image_size, mean=mean, std=std) + + +@registry.register_processor("blip2_image_train") +class Blip2ImageTrainProcessor(BlipImageBaseProcessor): + def __init__( + self, image_size=364, mean=None, std=None, min_scale=0.5, max_scale=1.0 + ): + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + image_size, + scale=(min_scale, max_scale), + interpolation=InterpolationMode.BICUBIC, + ), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + self.normalize, + ] + ) + + def __call__(self, item): + return self.transform(item) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 364) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.5) + max_scale = cfg.get("max_scale", 1.0) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + ) \ No newline at end of file diff --git a/lavis/processors/clip_processors.py b/lavis/processors/clip_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..08bd066de69e01c8a90ca9f8546ab046ae08cd78 --- /dev/null +++ b/lavis/processors/clip_processors.py @@ -0,0 +1,92 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from lavis.common.registry import registry +from lavis.processors.blip_processors import BlipImageBaseProcessor +from omegaconf import OmegaConf +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + + +def _convert_to_rgb(image): + return image.convert("RGB") + + +@registry.register_processor("clip_image_train") +class ClipImageTrainProcessor(BlipImageBaseProcessor): + def __init__( + self, image_size=224, mean=None, std=None, min_scale=0.9, max_scale=1.0 + ): + + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + image_size, + scale=(min_scale, max_scale), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + transforms.ToTensor(), + self.normalize, + ] + ) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + min_scale = cfg.get("min_scale", 0.9) + max_scale = cfg.get("max_scale", 1.0) + + return cls( + image_size=image_size, + mean=mean, + std=std, + min_scale=min_scale, + max_scale=max_scale, + ) + + +@registry.register_processor("clip_image_eval") +class ClipImageEvalProcessor(BlipImageBaseProcessor): + def __init__(self, image_size=224, mean=None, std=None): + + super().__init__(mean=mean, std=std) + + self.transform = transforms.Compose( + [ + transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC), + transforms.CenterCrop(image_size), + _convert_to_rgb, + transforms.ToTensor(), + self.normalize, + ] + ) + + @classmethod + def from_config(cls, cfg=None): + if cfg is None: + cfg = OmegaConf.create() + + image_size = cfg.get("image_size", 224) + + mean = cfg.get("mean", None) + std = cfg.get("std", None) + + return cls( + image_size=image_size, + mean=mean, + std=std, + ) diff --git a/lavis/processors/functional_video.py b/lavis/processors/functional_video.py new file mode 100644 index 0000000000000000000000000000000000000000..597a29315d4e1a575e7209edb0618eeaf4fc024a --- /dev/null +++ b/lavis/processors/functional_video.py @@ -0,0 +1,121 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import warnings + +import torch + + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def crop(clip, i, j, h, w): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + """ + if len(clip.size()) != 4: + raise ValueError("clip should be a 4D tensor") + return clip[..., i : i + h, j : j + w] + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError( + f"target size should be tuple (height, width), instead got {target_size}" + ) + return torch.nn.functional.interpolate( + clip, size=target_size, mode=interpolation_mode, align_corners=False + ) + + +def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): + """ + Do spatial cropping and resizing to the video clip + Args: + clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) + i (int): i in (i,j) i.e coordinates of the upper left corner. + j (int): j in (i,j) i.e coordinates of the upper left corner. + h (int): Height of the cropped region. + w (int): Width of the cropped region. + size (tuple(int, int)): height and width of resized clip + Returns: + clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + clip = crop(clip, i, j, h, w) + clip = resize(clip, size, interpolation_mode) + return clip + + +def center_crop(clip, crop_size): + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + h, w = clip.size(-2), clip.size(-1) + th, tw = crop_size + if h < th or w < tw: + raise ValueError("height and width must be no smaller than crop_size") + + i = int(round((h - th) / 2.0)) + j = int(round((w - tw) / 2.0)) + return crop(clip, i, j, th, tw) + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C) + Return: + clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError( + "clip tensor should have data type uint8. Got %s" % str(clip.dtype) + ) + return clip.float().permute(3, 0, 1, 2) / 255.0 + + +def normalize(clip, mean, std, inplace=False): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) + mean (tuple): pixel RGB mean. Size is (3) + std (tuple): pixel standard deviation. Size is (3) + Returns: + normalized clip (torch.tensor): Size is (C, T, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + if not inplace: + clip = clip.clone() + mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) + std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + return clip + + +def hflip(clip): + """ + Args: + clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W) + Returns: + flipped clip (torch.tensor): Size is (C, T, H, W) + """ + if not _is_tensor_video_clip(clip): + raise ValueError("clip should be a 4D torch.tensor") + return clip.flip(-1) diff --git a/lavis/processors/gpt_processors.py b/lavis/processors/gpt_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..2fe6204c674a2f4b500a0b0ef79a9f02068dbb66 --- /dev/null +++ b/lavis/processors/gpt_processors.py @@ -0,0 +1,171 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import re + +from lavis.common.registry import registry +from lavis.processors.base_processor import BaseProcessor +from lavis.processors.randaugment import RandomAugment +from omegaconf import OmegaConf +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode +import os +from itertools import chain +import numpy as np +import torch +from transformers import GPT2Tokenizer + +SPECIAL_TOKENS_DICT = { + "bos_token": "", + "eos_token": "", + "additional_special_tokens": ["", "", "