Sapir's picture
working.
86b1a7e
raw
history blame
220 Bytes
from enum import Enum
class AccelerationType(Enum):
    """Closed set of accelerator backends known to this module.

    Every member's value is its lowercase name as a string, so a member can
    be looked up from that string, e.g. ``AccelerationType("tpu")``.
    Declaration order (CPU, GPU, TPU, MPS) is also the iteration order.
    """

    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"
    MPS = "mps"
def execute_graph(*, acceleration_type: "AccelerationType | None" = None) -> None:
    """Trigger execution of the pending graph on the configured accelerator.

    When the resolved accelerator is ``AccelerationType.TPU`` this calls
    ``xm.mark_step()``; for every other accelerator, or when none is
    configured, it does nothing.

    Args:
        acceleration_type: Accelerator to act for. When omitted, the
            module-level ``_acceleration_type`` is used if it has been set;
            otherwise the call is a no-op.

    Returns:
        None.
    """
    if acceleration_type is None:
        # `_acceleration_type` is never assigned in this module, so a plain
        # name reference raised NameError on every call. Look it up
        # defensively so code that sets it from outside still works.
        acceleration_type = globals().get("_acceleration_type")
    if acceleration_type is None:
        # Nothing configured: there is no accelerator-specific step to run.
        return
    if acceleration_type is AccelerationType.TPU:
        # NOTE(review): `xm` is never imported in this module; presumably it is
        # `torch_xla.core.xla_model` — confirm and add the import at file top.
        xm.mark_step()