import collections.abc
import re

import torch
from torch.nn import functional as F

np_str_obj_array_pattern = re.compile(r"[SaUO]")


def str2list(config, list_args):
    """Map string-encoded list arguments (e.g. "[1,2,3]") to lists of ints on the config."""
    for k, v in vars(config).items():
        if k in list_args and v is not None and isinstance(v, str):
            v = v.replace("[", "").replace("]", "")
            setattr(config, k, list(map(int, v.split(","))))
    return config


def pad_tensor(x, l, pad_value=0):
    """Pad tensor x with pad_value along its first (temporal) dimension up to length l."""
    padlen = l - x.shape[0]
    pad = [0 for _ in range(2 * len(x.shape[1:]))] + [0, padlen]
    return F.pad(x, pad=pad, value=pad_value)


def pad_collate(batch, pad_value=0):
    """Collate function that pads tensors with mismatched first dimensions.

    Modified default_collate from the official pytorch repo:
    https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if len(elem.shape) > 0:
            sizes = [e.shape[0] for e in batch]
            m = max(sizes)
            if not all(s == m for s in sizes):
                # pad tensors which have a temporal dimension
                batch = [pad_tensor(e, m, pad_value=pad_value) for e in batch]
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError("Format not managed : {}".format(elem.dtype))
            return pad_collate(
                [torch.as_tensor(b) for b in batch], pad_value=pad_value
            )
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, collections.abc.Mapping):
        return {
            key: pad_collate([d[key] for d in batch], pad_value=pad_value)
            for key in elem
        }
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(
            *(pad_collate(samples, pad_value=pad_value) for samples in zip(*batch))
        )
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [pad_collate(samples, pad_value=pad_value) for samples in transposed]

    raise TypeError("Format not managed : {}".format(elem_type))


def get_ntrainparams(model):
    """Number of trainable parameters of a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
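

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): plugging pad_collate into a
# DataLoader so that variable-length temporal tensors in a batch are padded
# to a common length. The toy VariableLengthDataset below is a hypothetical
# stand-in, not part of the original module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from functools import partial

    from torch.utils.data import DataLoader, Dataset

    class VariableLengthDataset(Dataset):
        """Toy dataset returning (sequence, label) pairs with varying sequence length."""

        def __init__(self, lengths=(3, 5, 4, 7)):
            self.lengths = lengths

        def __len__(self):
            return len(self.lengths)

        def __getitem__(self, idx):
            seq = torch.randn(self.lengths[idx], 10)  # (T, C) with varying T
            label = torch.tensor(idx % 2)
            return seq, label

    # partial() lets you pick the padding value passed to pad_collate
    loader = DataLoader(
        VariableLengthDataset(),
        batch_size=4,
        collate_fn=partial(pad_collate, pad_value=0),
    )
    sequences, labels = next(iter(loader))
    print(sequences.shape)  # torch.Size([4, 7, 10]) -- padded to the longest T
    print(labels.shape)     # torch.Size([4])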