File size: 1,488 Bytes
132e594 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import os
import random
from abc import ABC, abstractmethod
from functools import lru_cache, wraps
from typing import Callable, ParamSpec, TypeVar
import numpy as np
import torch
# True when this PyTorch build targets AMD ROCm (the HIP runtime is present).
IS_ROCM = torch.version.hip is not None
class Platform(ABC):
    """Abstract interface describing a GPU compute platform (CUDA or ROCm)."""

    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.
        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Declared as a classmethod to match the concrete subclasses, which all
    # implement it as `@classmethod` (with an `@lru_cache`). Instances can
    # still call it, so callers are unaffected.
    @classmethod
    @abstractmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the human-readable name of device `device_id`."""
        ...

    @abstractmethod
    def is_cuda(self) -> bool:
        """Return True when this platform targets NVIDIA CUDA."""
        ...

    @abstractmethod
    def is_rocm(self) -> bool:
        """Return True when this platform targets AMD ROCm."""
        ...
class CudaPlatform(Platform):
    """Platform implementation for NVIDIA CUDA devices."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of CUDA device `device_id` (memoized per id)."""
        # Bug fix: the previous implementation hard-coded device 0 and
        # ignored `device_id`; query the requested device instead.
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False
class RocmPlatform(Platform):
    """Platform implementation for AMD ROCm devices.

    ROCm builds of PyTorch expose devices through the `torch.cuda`
    namespace, so device queries use the same API as CUDA.
    """

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of ROCm device `device_id` (memoized per id)."""
        device_name = torch.cuda.get_device_name(device_id)
        return device_name

    def is_rocm(self) -> bool:
        """This platform is ROCm."""
        return True

    def is_cuda(self) -> bool:
        """This platform is not CUDA."""
        return False
# Module-level singleton selected at import time; callers use this to query
# the active platform (device names, CUDA/ROCm checks, seeding).
current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
|