TIMBOVILL commited on
Commit
7d7061d
·
verified ·
1 Parent(s): 811f42c

Upload device_detection.py

Browse files
Files changed (1) hide show
  1. src/modules/device_detection.py +37 -0
src/modules/device_detection.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Device detection module."""
2
+
3
+ import torch
4
+ import os
5
+ import tensorflow as tf
6
+
7
+ from modules.console_colors import ULTRASINGER_HEAD, red_highlighted, blue_highlighted
8
+
9
+ tensorflow_gpu_supported = False
10
+ pytorch_gpu_supported = False
11
+
12
+ def check_gpu_support() -> tuple[bool, bool]:
13
+ """Check worker device (e.g cuda or cpu) supported by tensorflow and pytorch"""
14
+
15
+ print(f"{ULTRASINGER_HEAD} Checking GPU support for {blue_highlighted('tensorflow')} and {blue_highlighted('pytorch')}.")
16
+
17
+ tensorflow_gpu_supported = False
18
+ pytorch_gpu_supported = False
19
+
20
+ gpus = tf.config.list_physical_devices('GPU')
21
+ if gpus:
22
+ tensorflow_gpu_supported = True
23
+ print(f"{ULTRASINGER_HEAD} {blue_highlighted('tensorflow')} - using {red_highlighted('cuda')} gpu.")
24
+ else:
25
+ print(f"{ULTRASINGER_HEAD} {blue_highlighted('tensorflow')} - there are no {red_highlighted('cuda')} devices available -> Using {red_highlighted('cpu')}.")
26
+ if os.name == 'nt':
27
+ print(f"{ULTRASINGER_HEAD} {blue_highlighted('tensorflow')} - versions above 2.10 dropped GPU support for Windows, refer to the readme for possible solutions.")
28
+
29
+ pytorch_gpu_supported = torch.cuda.is_available()
30
+ if not pytorch_gpu_supported:
31
+ print(
32
+ f"{ULTRASINGER_HEAD} {blue_highlighted('pytorch')} - there are no {red_highlighted('cuda')} devices available -> Using {red_highlighted('cpu')}."
33
+ )
34
+ else:
35
+ print(f"{ULTRASINGER_HEAD} {blue_highlighted('pytorch')} - using {red_highlighted('cuda')} gpu.")
36
+
37
+ return 'cuda' if tensorflow_gpu_supported else 'cpu', 'cuda' if pytorch_gpu_supported else 'cpu'