|
import os |
|
import random |
|
from abc import ABC, abstractmethod |
|
from functools import lru_cache, wraps |
|
from typing import Callable, ParamSpec, TypeVar |
|
|
|
import numpy as np |
|
import torch |
|
|
|
# Module-level flag: True when `torch.version.hip` is populated (not None).
# It is consulted at the bottom of this module to choose the platform object.
IS_ROCM: bool = torch.version.hip is not None
|
|
|
|
|
class Platform(ABC):
    """Abstract base for the backend-specific platform helpers in this module."""

    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Seed Python's `random`, NumPy's global RNG and PyTorch with `seed`.

        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20

        Args:
            seed: Value handed unchanged to each seeding function.
        """
        # Same call order as before: stdlib, then NumPy, then PyTorch.
        for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
            seed_fn(seed)

    @abstractmethod
    def get_device_name(self, device_id: int = 0) -> str:
        """Return the name of the device at index `device_id`."""

    @abstractmethod
    def is_cuda(self) -> bool:
        """Return True if this platform is the CUDA one."""

    @abstractmethod
    def is_rocm(self) -> bool:
        """Return True if this platform is the ROCm one."""
|
|
|
|
|
class CudaPlatform(Platform):
    """Platform implementation that reports itself as CUDA (not ROCm)."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """
        Return the name of the CUDA device at index `device_id`.

        Results are memoized by `lru_cache` (up to 8 entries), so repeated
        lookups with the same arguments do not query the driver again.

        Args:
            device_id: Index of the device to query. Defaults to 0.

        Returns:
            The device name reported by `torch.cuda.get_device_name`.
        """
        # Forward the requested index; it used to be hard-coded to 0, so
        # every device was reported with device 0's name.
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        """Return True: this platform is CUDA."""
        return True

    def is_rocm(self) -> bool:
        """Return False: this platform is not ROCm."""
        return False
|
|
|
|
|
class RocmPlatform(Platform):
    """Platform implementation that reports itself as ROCm (not CUDA)."""

    def is_rocm(self) -> bool:
        """Return True: this platform is ROCm."""
        return True

    def is_cuda(self) -> bool:
        """Return False: this platform is not CUDA."""
        return False

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """
        Return the name of the device at index `device_id`.

        The name is obtained through `torch.cuda.get_device_name`; results
        are memoized by `lru_cache` (up to 8 entries).
        """
        device_name = torch.cuda.get_device_name(device_id)
        return device_name
|
|
|
|
|
# Module-level platform object, picked once at import time from IS_ROCM.
if IS_ROCM:
    current_platform = RocmPlatform()
else:
    current_platform = CudaPlatform()
|
|