XavierJiezou's picture
Upload folder using huggingface_hub
3c8ff2e verified
import re
import collections.abc
import torch
from torch.nn import functional as F
# Matches NumPy dtype strings containing a byte-string (S, a), unicode (U) or
# object (O) type code; pad_collate searches array dtypes with it and raises
# TypeError on a match instead of converting such arrays to tensors.
np_str_obj_array_pattern = re.compile(r"[SaUO]")
def str2list(config, list_args):
    """Convert list-like string attributes of ``config`` into lists of ints.

    For every attribute of ``config`` whose name is in ``list_args`` and whose
    value is a string such as ``"[1,2,3]"``, the brackets are removed, the
    remainder is split on commas and each item is converted with ``int``.
    Attributes that are not strings (including ``None``) are left untouched.

    Args:
        config: Object whose attributes are readable through ``vars()``
            (e.g. an ``argparse.Namespace``). It is modified in place.
        list_args: Container of attribute names to convert.

    Returns:
        The same ``config`` object, for convenience.

    Raises:
        ValueError: If a non-blank item cannot be parsed by ``int``.
    """
    # Iterate over a snapshot because attributes are reassigned in the loop.
    for name, value in list(vars(config).items()):
        if name not in list_args or not isinstance(value, str):
            continue
        tokens = value.replace("[", "").replace("]", "").split(",")
        # Skip blank tokens so that "[]" maps to an empty list and a trailing
        # comma does not make int("") raise.
        setattr(config, name, [int(tok) for tok in tokens if tok.strip()])
    return config
def pad_tensor(x, l, pad_value=0):
    """Pad ``x`` at the end of its first dimension by ``l - x.shape[0]`` entries.

    Args:
        x: Tensor with at least one dimension.
        l: Target length of the first dimension.
        pad_value: Constant written into the added positions.

    Returns:
        The result of ``F.pad``; all dimensions other than the first keep
        their size.
    """
    missing = l - x.shape[0]
    n_trailing_dims = len(x.shape) - 1
    # F.pad expects (before, after) pairs ordered from the last dimension to
    # the first, so the pair for dimension 0 goes at the very end.
    pad_spec = [0] * (2 * n_trailing_dims)
    pad_spec.extend((0, missing))
    return F.pad(x, pad=pad_spec, value=pad_value)
def pad_collate(batch, pad_value=0):
    """Collate a list of samples into a batch, padding tensors along dim 0.

    Modified ``default_collate`` from the official pytorch repo
    https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
    Tensors whose first dimension differs across the batch are padded at the
    end of that dimension, up to the longest one, before being stacked.

    Args:
        batch: Non-empty list of samples. A sample may be a tensor, a NumPy
            array or NumPy scalar, a mapping, a namedtuple or another
            sequence whose items are themselves supported samples.
        pad_value: Constant used for the padded positions. It is forwarded
            to every nested container so inner tensors use it as well.

    Returns:
        A stacked tensor with a new leading batch dimension, or a
        dict / namedtuple / list with the same structure as the samples
        whose leaves are such tensors.

    Raises:
        TypeError: If the sample type is not handled (including strings and
            NumPy arrays of string or object dtype).
        RuntimeError: If sequence samples do not all have the same length.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if len(elem.shape) > 0:
            sizes = [e.shape[0] for e in batch]
            longest = max(sizes)
            if any(s != longest for s in sizes):
                # pad tensors which have a temporal dimension
                batch = [pad_tensor(e, longest, pad_value=pad_value) for e in batch]
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == "numpy" and elem_type.__name__ not in (
        "str_",
        "string_",
    ):
        if elem_type.__name__ in ("ndarray", "memmap"):
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError("Format not managed : {}".format(elem.dtype))
            return pad_collate(
                [torch.as_tensor(b) for b in batch], pad_value=pad_value
            )
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, collections.abc.Mapping):
        return {
            key: pad_collate([d[key] for d in batch], pad_value=pad_value)
            for key in elem
        }
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(
            *(pad_collate(samples, pad_value=pad_value) for samples in zip(*batch))
        )
    elif isinstance(elem, collections.abc.Sequence) and not isinstance(
        elem, (str, bytes)
    ):
        # Strings are excluded above: transposing them yields one-character
        # strings again, which would recurse without end. They fall through
        # to the TypeError below instead.
        # check to make sure that the elements in batch have consistent size
        expected_len = len(elem)
        if any(len(sample) != expected_len for sample in batch):
            raise RuntimeError("each element in list of batch should be of equal size")
        return [
            pad_collate(samples, pad_value=pad_value) for samples in zip(*batch)
        ]
    raise TypeError("Format not managed : {}".format(elem_type))
def get_ntrainparams(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters yielded by ``model.parameters()`` whose ``requires_grad``
    flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total