import torch class DisableTensorToDtype: def __enter__(self): self.original_to = torch.Tensor.to def modified_to(tensor, *args, **kwargs): # remove dtype from args if present args = [arg if not isinstance(arg, torch.dtype) else None for arg in args] if "dtype" in kwargs: kwargs.pop("dtype") return self.original_to(tensor, *args, **kwargs) torch.Tensor.to = modified_to def __exit__(self, *args, **kwargs): torch.Tensor.to = self.original_to