mimbres commited on
Commit
badf747
1 Parent(s): 4f37a95

Update amt/src/config/config.py

Browse files
Files changed (1) hide show
  1. amt/src/config/config.py +7 -1
amt/src/config/config.py CHANGED
@@ -1,5 +1,11 @@
1
  """config.py"""
2
  import numpy as np
 
 
 
 
 
 
3
  # yapf: disable
4
  """
5
  audio_cfg:
@@ -132,7 +138,7 @@ shared_cfg = {
132
  "train_sub": 12, #20, # sub-batch size is per CPU worker
133
  "train_local": 24, #40, # local batch size is per GPU in DDP mode
134
  "validation": 64, # validation batch size is per GPU in DDP mode
135
- "test": 64,
136
  },
137
  "AUGMENTATION": {
138
  "train_random_amp_range": [0.8, 1.1], # min and max amplitude scaling factor
 
1
  """config.py"""
2
  import numpy as np
3
+ import torch
4
+
5
+ if torch.cuda.is_available():
6
+ TEST_BSZ = 64
7
+ else:
8
+ TEST_BSZ = 16
9
  # yapf: disable
10
  """
11
  audio_cfg:
 
138
  "train_sub": 12, #20, # sub-batch size is per CPU worker
139
  "train_local": 24, #40, # local batch size is per GPU in DDP mode
140
  "validation": 64, # validation batch size is per GPU in DDP mode
141
+ "test": TEST_BSZ,
142
  },
143
  "AUGMENTATION": {
144
  "train_random_amp_range": [0.8, 1.1], # min and max amplitude scaling factor