import os
import random
from abc import ABC, abstractmethod
from functools import lru_cache

import numpy as np
import torch

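# torch.version.hip is only set on ROCm builds of PyTorch; MPS availability
# is reported by the Metal backend on Apple silicon.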
IS_ROCM = torch.version.hip is not None
IS_MPS = torch.backends.mps.is_available()


class Platform(ABC):
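    """Thin abstraction over the accelerator backends supported here."""
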
    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Set the seed of each random module.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    @classmethod
    @abstractmethod
    def get_device_name(cls, device_id: int = 0) -> str: ...

    @abstractmethod
    def is_cuda(self) -> bool: ...

    @abstractmethod
    def is_rocm(self) -> bool: ...

    @abstractmethod
    def is_mps(self) -> bool: ...


class CudaPlatform(Platform):
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return True

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return False


class RocmPlatform(Platform):
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
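        # ROCm builds of PyTorch expose HIP devices through the torch.cuda
        # API, so this call is valid on AMD GPUs as well.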
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return True

    def is_mps(self) -> bool:
        return False


class MpsPlatform(Platform):
    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        # PyTorch has no device-name API for MPS, and torch.cuda would fail on
        # a CUDA-less Apple machine, so report the backend name directly.
        return "mps"

    def is_cuda(self) -> bool:
        return False

    def is_rocm(self) -> bool:
        return False

    def is_mps(self) -> bool:
        return True

# ROCm builds also report torch.cuda.is_available(), so IS_ROCM must be
# checked before falling back to CUDA. On CPU-only hosts this is None.
current_platform: Platform | None = (
    RocmPlatform() if IS_ROCM else
    MpsPlatform() if IS_MPS else
    CudaPlatform() if torch.cuda.is_available() else
    None
)
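

# Minimal usage sketch (illustrative only, not part of the module API): seed
# all RNGs through the detected platform and report the active device.
if __name__ == "__main__":
    if current_platform is None:
        print("No supported accelerator detected; running on CPU.")
    else:
        current_platform.seed_everything(0)
        print(f"Detected device: {current_platform.get_device_name(0)}")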