File size: 1,488 Bytes
132e594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import random
from abc import ABC, abstractmethod
from functools import lru_cache, wraps
from typing import Callable, ParamSpec, TypeVar

import numpy as np
import torch

# A non-None `torch.version.hip` is treated as "this PyTorch build targets ROCm".
IS_ROCM: bool = torch.version.hip is not None


class Platform(ABC):
    """Abstract base describing the accelerator backend in use."""

    @classmethod
    def seed_everything(cls, seed: int) -> None:
        """
        Seed Python's `random`, NumPy and PyTorch with the same value.
        `torch.manual_seed` will set seed on all devices.

        Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
        """
        # Applied in order: stdlib RNG, NumPy global RNG, then torch.
        for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
            apply_seed(seed)

    @abstractmethod
    def get_device_name(self, device_id: int = 0) -> str:
        """Return the name of the device at index `device_id`."""

    @abstractmethod
    def is_cuda(self) -> bool:
        """Return whether this platform is CUDA."""

    @abstractmethod
    def is_rocm(self) -> bool:
        """Return whether this platform is ROCm."""


class CudaPlatform(Platform):
    """Platform implementation that reports itself as CUDA (not ROCm)."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """
        Return the name of the device at index `device_id`.

        Args:
            device_id: Index of the device to query; defaults to 0.

        Returns:
            The value of `torch.cuda.get_device_name(device_id)`.

        Results are memoized by `lru_cache` (up to 8 entries), keyed on
        the class and the arguments as passed.
        """
        # Forward the requested index; previously this always queried device 0,
        # so every `device_id` returned (and cached) the name of device 0.
        return torch.cuda.get_device_name(device_id)

    def is_cuda(self) -> bool:
        """Return True: this platform is CUDA."""
        return True

    def is_rocm(self) -> bool:
        """Return False: this platform is not ROCm."""
        return False


class RocmPlatform(Platform):
    """Platform implementation that reports itself as ROCm (not CUDA)."""

    @classmethod
    @lru_cache(maxsize=8)
    def get_device_name(cls, device_id: int = 0) -> str:
        """
        Return the name of the device at index `device_id`.

        The lookup goes through `torch.cuda.get_device_name` and is memoized
        by `lru_cache` (up to 8 entries).
        """
        device_name = torch.cuda.get_device_name(device_id)
        return device_name

    def is_cuda(self) -> bool:
        """Return False: this platform is not CUDA."""
        return False

    def is_rocm(self) -> bool:
        """Return True: this platform is ROCm."""
        return True


# Module-level platform instance, chosen once at import time from IS_ROCM.
if IS_ROCM:
    current_platform = RocmPlatform()
else:
    current_platform = CudaPlatform()