DawnC committed
Commit dad3bfd · verified
1 Parent(s): 7b6cc71

Delete device_manager.py

Files changed (1)
  1. device_manager.py +0 -89
device_manager.py DELETED
@@ -1,89 +0,0 @@
- from functools import wraps
- import torch
- from huggingface_hub import HfApi
- import os
- import logging
-
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- class DeviceManager:
-     _instance = None
-
-     def __new__(cls):
-         if cls._instance is None:
-             cls._instance = super(DeviceManager, cls).__new__(cls)
-             cls._instance._initialized = False
-         return cls._instance
-
-     def __init__(self):
-         if self._initialized:
-             return
-
-         self._initialized = True
-         self._current_device = None
-         self._zero_gpu_available = None
-
-     def check_zero_gpu_availability(self):
-         try:
-             api = HfApi()
-             # Check environment variables (or other signals) to confirm we are running in a Spaces environment
-             if 'SPACE_ID' in os.environ:
-                 # More specific ZeroGPU availability checks could be added here
-                 self._zero_gpu_available = True
-                 return True
-         except Exception as e:
-             logger.warning(f"Error checking ZeroGPU availability: {e}")
-
-         self._zero_gpu_available = False
-         return False
-
-     def get_optimal_device(self):
-         if self._current_device is None:
-             if self.check_zero_gpu_availability():
-                 self._current_device = torch.device('cuda')
-                 logger.info("Using ZeroGPU")
-             else:
-                 self._current_device = torch.device('cpu')
-                 logger.info("Using CPU")
-         return self._current_device
-
-     def move_to_device(self, tensor_or_model):
-         device = self.get_optimal_device()
-         if hasattr(tensor_or_model, 'to'):
-             return tensor_or_model.to(device)
-         return tensor_or_model
-
- def device_handler(func):
-     """Decorator for handling device placement"""
-     @wraps(func)
-     async def wrapper(*args, **kwargs):
-         device_mgr = DeviceManager()
-
-         # Handle device placement for input arguments
-         def process_arg(arg):
-             if torch.is_tensor(arg) or hasattr(arg, 'to'):
-                 return device_mgr.move_to_device(arg)
-             return arg
-
-         processed_args = [process_arg(arg) for arg in args]
-         processed_kwargs = {k: process_arg(v) for k, v in kwargs.items()}
-
-         try:
-             result = await func(*processed_args, **processed_kwargs)
-
-             # Handle device placement for the output
-             if torch.is_tensor(result):
-                 return device_mgr.move_to_device(result)
-             elif isinstance(result, tuple):
-                 return tuple(device_mgr.move_to_device(r) if torch.is_tensor(r) else r for r in result)
-             return result
-
-         except RuntimeError as e:
-             if "out of memory" in str(e):
-                 logger.warning("GPU memory exceeded, falling back to CPU")
-                 device_mgr._current_device = torch.device('cpu')
-                 return await wrapper(*args, **kwargs)
-             raise e
-
-     return wrapper
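
For context, before this deletion the module would presumably have been consumed roughly as sketched below: the decorator moves tensor arguments to the device chosen by DeviceManager (CUDA on a ZeroGPU Space, otherwise CPU) and retries on CPU after a CUDA out-of-memory error. The run_inference function, tensor shapes, and asyncio driver here are hypothetical illustrations, not taken from this commit.

import asyncio
import torch
from device_manager import DeviceManager, device_handler

@device_handler
async def run_inference(batch: torch.Tensor) -> torch.Tensor:
    # Hypothetical workload: a weight placed via the manager, applied to the batch.
    weight = DeviceManager().move_to_device(torch.randn(batch.shape[-1], 4))
    return batch @ weight

result = asyncio.run(run_inference(torch.randn(2, 8)))
print(result.shape)  # torch.Size([2, 4]), on CUDA if ZeroGPU was detected, else CPU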