Spaces:
Running
Running
File size: 1,059 Bytes
91fb4ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from typing import Dict, Optional, Union
import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module
def unwrap_model(accelerator: Accelerator, model):
    """Strip Accelerate and compile wrappers from ``model``.

    The model is first passed through ``accelerator.unwrap_model``. If the
    result is detected as a compiled module by ``is_compiled_module``, the
    wrapped module stored on its ``_orig_mod`` attribute is returned instead.

    Args:
        accelerator: The ``Accelerator`` that prepared ``model``.
        model: The (possibly wrapped) model.

    Returns:
        The innermost module after both unwrapping steps.
    """
    unwrapped = accelerator.unwrap_model(model)
    if not is_compiled_module(unwrapped):
        return unwrapped
    # Compiled wrappers keep the original module on ``_orig_mod``.
    return unwrapped._orig_mod
def align_device_and_dtype(
    x: Union[torch.Tensor, Dict[str, torch.Tensor]],
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Move a tensor, or every tensor inside a (possibly nested) dict, to ``device``/``dtype``.

    Args:
        x: A tensor, or a dict whose values are tensors, nested dicts, or other
            objects. Values that are neither tensors nor dicts are passed
            through unchanged.
        device: Target device; ``None`` leaves the device as-is.
        dtype: Target dtype; ``None`` leaves the dtype as-is.

    Returns:
        The converted tensor, or a new dict with converted values. When both
        ``device`` and ``dtype`` are ``None``, or ``x`` is neither a tensor nor
        a dict, ``x`` itself is returned.
    """
    if device is None and dtype is None:
        # Nothing to convert; return the input object untouched (no dict copy).
        return x
    if isinstance(x, torch.Tensor):
        # A single ``.to`` call applies both conversions without an
        # intermediate tensor; ``None`` for either argument leaves it unchanged.
        return x.to(device=device, dtype=dtype)
    if isinstance(x, dict):
        # One pass is enough: each recursive call applies device AND dtype.
        return {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
    return x
def expand_tensor_dims(tensor, ndim):
    """Append trailing size-1 dimensions until ``tensor`` has at least ``ndim`` dims.

    Args:
        tensor: Object exposing ``.shape`` and ``.unsqueeze`` (e.g. a torch tensor).
        ndim: Minimum number of dimensions wanted.

    Returns:
        ``tensor`` with ``unsqueeze(-1)`` applied once per missing dimension;
        the input object itself when it already has ``ndim`` or more dims.
    """
    missing = ndim - len(tensor.shape)
    # A non-positive count yields an empty range, leaving ``tensor`` as-is.
    for _ in range(missing):
        tensor = tensor.unsqueeze(-1)
    return tensor
|