Willie Maddox
commited on
Update general.py (#823)
Browse filesFixes #822
`init_seeds` from `torch_utils` import is being overwritten by function `init_seeds` in `general.py`
- utils/general.py +3 -2
utils/general.py
CHANGED
@@ -23,7 +23,8 @@ from scipy.cluster.vq import kmeans
|
|
23 |
from scipy.signal import butter, filtfilt
|
24 |
from tqdm import tqdm
|
25 |
|
26 |
-
from utils.torch_utils import init_seeds
|
|
|
27 |
|
28 |
# Set printoptions
|
29 |
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
@@ -55,7 +56,7 @@ def set_logging(rank=-1):
|
|
55 |
def init_seeds(seed=0):
|
56 |
random.seed(seed)
|
57 |
np.random.seed(seed)
|
58 |
-
|
59 |
|
60 |
|
61 |
def get_latest_run(search_dir='./runs'):
|
|
|
23 |
from scipy.signal import butter, filtfilt
|
24 |
from tqdm import tqdm
|
25 |
|
26 |
+
from utils.torch_utils import init_seeds as init_torch_seeds
|
27 |
+
from utils.torch_utils import is_parallel
|
28 |
|
29 |
# Set printoptions
|
30 |
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
|
|
56 |
def init_seeds(seed=0):
|
57 |
random.seed(seed)
|
58 |
np.random.seed(seed)
|
59 |
+
init_torch_seeds(seed=seed)
|
60 |
|
61 |
|
62 |
def get_latest_run(search_dir='./runs'):
|