Spaces:
Running
Running
import dataclasses | |
import warnings | |
import numpy as np | |
import torch | |
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device (and optionally dtype) of an object recursively.

    Containers are rebuilt with the same type as the input:
    - ``dict``
    - dataclass instances
    - namedtuples
    - ``list`` / ``tuple``, including plain tuple subclasses such as
      ``torch.Size``

    ``numpy.ndarray`` leaves are converted with ``torch.from_numpy`` and then
    moved. Any other object is returned unchanged.

    Args:
        data: The (possibly nested) object to convert.
        device: Target device, forwarded to ``torch.Tensor.to``.
        dtype: Target dtype, forwarded to ``torch.Tensor.to``. It is applied
            to every tensor found, integer tensors included.
        non_blocking: Forwarded to ``torch.Tensor.to``.
        copy: Forwarded to ``torch.Tensor.to``.

    Returns:
        An object with the same structure as ``data`` whose tensors have been
        moved/cast.
    """

    def recurse(value):
        return to_device(value, device, dtype, non_blocking, copy)

    if isinstance(data, dict):
        return {k: recurse(v) for k, v in data.items()}
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Read the fields directly rather than via dataclasses.astuple().
        # astuple() deep-copies every leaf and turns nested dataclasses into
        # plain tuples, so nested dataclasses would lose their type.
        # Fields declared with init=False cannot be passed to the constructor.
        # Keyword arguments also work for kw_only fields.
        kwargs = {
            f.name: recurse(getattr(data, f.name))
            for f in dataclasses.fields(data)
            if f.init
        }
        return type(data)(**kwargs)
    elif isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple: its constructor takes the fields as positional arguments.
        return type(data)(*[recurse(o) for o in data])
    elif isinstance(data, (list, tuple)):
        # list, tuple and other tuple subclasses (e.g. torch.Size) are built
        # from a single iterable.
        return type(data)([recurse(v) for v in data])
    elif isinstance(data, np.ndarray):
        return recurse(torch.from_numpy(data))
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data
def force_gatherable(data, device):
    """Change an object to be gatherable by torch.nn.DataParallel, recursively.

    The difference from to_device() is that Python ``float``/``int`` values
    and numpy integer/floating/bool scalars are converted to 1-element
    tensors. 0-dim tensors are also reshaped to 1-dim.

    The restriction to the returned value in DataParallel:
        The object must be
        - torch.cuda.Tensor
        - 1 or more dimension. 0-dimension-tensor sends warning.
        or a list, tuple, dict.

    Args:
        data: The (possibly nested) object to convert.
        device: Device every resulting tensor is moved to.

    Returns:
        An object with the same container structure as ``data``:
        - tensor-like leaves become tensors with at least one dimension on
          ``device``;
        - ``None`` stays ``None``;
        - any other leaf is returned unchanged after a warning.
    """
    if isinstance(data, dict):
        return {k: force_gatherable(v, device) for k, v in data.items()}
    # DataParallel can't handle NamedTuple well.
    # Detect namedtuples via ``_fields`` so that other tuple subclasses
    # (whose constructor takes a single iterable) are not unpacked.
    elif isinstance(data, tuple) and hasattr(data, "_fields"):
        return type(data)(*[force_gatherable(o, device) for o in data])
    elif isinstance(data, (list, tuple, set)):
        return type(data)([force_gatherable(v, device) for v in data])
    elif isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        if data.dim() == 0:
            # To 1-dim array
            data = data[None]
        return data.to(device)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    elif isinstance(data, int):
        # bool is a subclass of int, so True/False also end up here.
        return torch.tensor([data], dtype=torch.long, device=device)
    elif isinstance(data, (np.integer, np.floating, np.bool_)):
        # numpy scalars (except np.float64, which subclasses float) are not
        # Python int/float instances. Unwrap them so they get the same
        # treatment as the matching Python scalar.
        return force_gatherable(data.item(), device)
    elif data is None:
        return None
    else:
        warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
        return data