Charbel Malo
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -38,7 +38,34 @@ parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=F
|
|
38 |
parser.add_argument(
|
39 |
"--colab", action="store_true", help="Enable colab mode", default=False
|
40 |
)
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
## ------------------------------ DEFAULTS ------------------------------
|
44 |
|
|
|
38 |
parser.add_argument(
|
39 |
"--colab", action="store_true", help="Enable colab mode", default=False
|
40 |
)
|
41 |
+
|
42 |
+
parser.add_argument("--device", default="cuda:0", type=str)
|
43 |
+
args = parser.parse_args()
|
44 |
+
|
45 |
+
@spaces.GPU
|
46 |
+
def find_cuda():
|
47 |
+
# Check if CUDA_HOME or CUDA_PATH environment variables are set
|
48 |
+
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
|
49 |
+
|
50 |
+
if cuda_home and os.path.exists(cuda_home):
|
51 |
+
return cuda_home
|
52 |
+
|
53 |
+
# Search for the nvcc executable in the system's PATH
|
54 |
+
nvcc_path = shutil.which('nvcc')
|
55 |
+
|
56 |
+
if nvcc_path:
|
57 |
+
# Remove the 'bin/nvcc' part to get the CUDA installation path
|
58 |
+
cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
|
59 |
+
return cuda_path
|
60 |
+
|
61 |
+
return None
|
62 |
+
|
63 |
+
cuda_path = find_cuda()
|
64 |
+
|
65 |
+
if cuda_path:
|
66 |
+
print(f"CUDA installation found at: {cuda_path}")
|
67 |
+
else:
|
68 |
+
print("CUDA installation not found")
|
69 |
|
70 |
## ------------------------------ DEFAULTS ------------------------------
|
71 |
|