Spaces:
Running
Running
import torch | |
class DisableTensorToDtype: | |
def __enter__(self): | |
self.original_to = torch.Tensor.to | |
def modified_to(tensor, *args, **kwargs): | |
# remove dtype from args if present | |
args = [arg if not isinstance(arg, torch.dtype) else None for arg in args] | |
if "dtype" in kwargs: | |
kwargs.pop("dtype") | |
return self.original_to(tensor, *args, **kwargs) | |
torch.Tensor.to = modified_to | |
def __exit__(self, *args, **kwargs): | |
torch.Tensor.to = self.original_to | |