balthou commited on
Commit
86d104b
·
1 Parent(s): 08480a6

disable numba imports

Browse files
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  interactive-pipe>=0.7.8
2
  torch>=2.0.0
3
  tqdm
 
4
 
5
 
 
1
  interactive-pipe>=0.7.8
2
  torch>=2.0.0
3
  tqdm
4
+ numba
5
 
6
 
scripts/save_deadleaves.py CHANGED
@@ -9,12 +9,16 @@ from pathlib import Path
9
  from time import perf_counter
10
  import matplotlib.pyplot as plt
11
  from typing import Tuple
12
-
13
  import numpy as np
14
  import torch
15
  import torch.nn.functional as F
16
  from torch.utils.data import Dataset
17
- from numba import cuda
 
 
 
 
18
  from tqdm import tqdm
19
  import argparse
20
  from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
 
9
  from time import perf_counter
10
  import matplotlib.pyplot as plt
11
  from typing import Tuple
12
+ import logging
13
  import numpy as np
14
  import torch
15
  import torch.nn.functional as F
16
  from torch.utils.data import Dataset
17
+ try:
18
+ from numba import cuda
19
+ except ImportError:
20
+ logging.warning("Numba not installed, GPU acceleration will not be available")
21
+ cuda = None
22
  from tqdm import tqdm
23
  import argparse
24
  from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
src/rstor/synthetic_data/dead_leaves_gpu.py CHANGED
@@ -3,7 +3,11 @@ from rstor.properties import SAMPLER_UNIFORM
3
  from typing import Tuple, Optional
4
  from rstor.synthetic_data.dead_leaves_cpu import define_dead_leaves_chart
5
  import numpy as np
6
- from numba import cuda
 
 
 
 
7
  import math
8
 
9
 
 
3
  from typing import Tuple, Optional
4
  from rstor.synthetic_data.dead_leaves_cpu import define_dead_leaves_chart
5
  import numpy as np
6
+ try:
7
+ from numba import cuda
8
+ except ImportError:
9
+ logging.warning("Numba not installed, GPU acceleration will not be available")
10
+ cuda = None
11
  import math
12
 
13
 
src/rstor/utils.py CHANGED
@@ -1,5 +1,8 @@
1
  import numpy as np
2
- import numba
 
 
 
3
  import torch
4
 
5
  THREADS_PER_BLOCK = 32 # 32 or 16
 
1
  import numpy as np
2
+ try:
3
+ import numba
4
+ except ImportError:
5
+ numba = None
6
  import torch
7
 
8
  THREADS_PER_BLOCK = 32 # 32 or 16