# Source: Hugging Face Hub file view (page header captured with the file).
# Uploader: jaxmetaverse
# Commit message: Upload folder using huggingface_hub
# Commit: 82ea528 (verified)
# File size: 499 Bytes
import platform
import torch
def to_shared_memory(tensors: tuple[torch.Tensor, ...]) -> list[torch.Tensor]:
    """Return the CPU version of every non-``None`` tensor in *tensors*.

    Args:
        tensors: Sequence of tensors to process. Entries that are ``None``
            are skipped, so the result may be shorter than the input.

    Returns:
        A new list holding ``tensor.cpu()`` for each non-``None`` entry,
        in the original order. An empty input yields an empty list.

    NOTE: despite the name, no shared-memory placement is performed here.
    An earlier variant called ``share_memory_()`` on non-Windows platforms
    and fell back to ``cpu()`` on Windows; that variant was disabled, and
    every platform now takes the ``cpu()`` path.
    """
    return [tensor.cpu() for tensor in tensors if tensor is not None]
def to_device(
    tensors: tuple[torch.Tensor, ...], device: torch.device
) -> list[torch.Tensor]:
    """Move every non-``None`` tensor in *tensors* to *device*.

    Args:
        tensors: Sequence of tensors to process. Entries that are ``None``
            are skipped, so the result may be shorter than the input.
        device: Target device, passed unchanged to ``Tensor.to``.

    Returns:
        A new list holding ``tensor.to(device)`` for each non-``None``
        entry, in the original order. An empty input yields an empty list.
    """
    return [tensor.to(device) for tensor in tensors if tensor is not None]