File size: 251 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 |
# mypy: allow-untyped-defs
import torch
from typing import TypeVar
# Module-level generic type variable.
T = TypeVar("T")
def all_same_mode(modes):
    """Return whether every element of ``modes`` compares equal to the first.

    Each element (including the first itself) is compared with
    ``mode == modes[0]``. An empty ``modes`` yields ``True`` because no
    comparison, and hence no ``modes[0]`` lookup, is ever performed.

    Args:
        modes: an indexable sequence such as a list of modes (presumably
            dispatch-mode objects -- NOTE(review): confirm against callers).

    Returns:
        ``True`` if all elements equal ``modes[0]``, otherwise ``False``.
    """
    # Pass a bare generator (not a tuple) so ``all`` stops at the first
    # mismatch instead of evaluating every comparison up front.
    return all(mode == modes[0] for mode in modes)
# Public alias for ``torch._C._DisableTorchDispatch``.
# NOTE(review): presumably a context manager that disables
# ``__torch_dispatch__`` handling inside its ``with`` block -- confirm
# against the torch._C bindings.
no_dispatch = torch._C._DisableTorchDispatch
|