Spaces:
Running
Running
File size: 552 Bytes
80ebcb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
class DisableTensorToDtype:
def __enter__(self):
self.original_to = torch.Tensor.to
def modified_to(tensor, *args, **kwargs):
# remove dtype from args if present
args = [arg if not isinstance(arg, torch.dtype) else None for arg in args]
if "dtype" in kwargs:
kwargs.pop("dtype")
return self.original_to(tensor, *args, **kwargs)
torch.Tensor.to = modified_to
def __exit__(self, *args, **kwargs):
torch.Tensor.to = self.original_to
|