PyTorch
ssl-aasist
custom_code
ash56 commited on
Commit
90b7baa
·
verified ·
1 Parent(s): db4e88d

Update model_hf.py

Browse files
Files changed (1) hide show
  1. model_hf.py +2 -2
model_hf.py CHANGED
@@ -433,13 +433,13 @@ class Model(PreTrainedModel,nn.Module):
433
  config_class = SSLConfig
434
  def __init__(self,config):
435
  super().__init__(config)
436
- # self.model_device ='cuda' if torch.cuda.is_available() else 'cpu'
437
  # AASIST parameters
438
  filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
439
  gat_dims = [64, 32]
440
  pool_ratios = [0.5, 0.5, 0.5, 0.5]
441
  temperatures = [2.0, 2.0, 100.0, 100.0]
442
- self.model_device = device
443
 
444
 
445
  ####
 
433
  config_class = SSLConfig
434
  def __init__(self,config):
435
  super().__init__(config)
436
+ self.model_device ='cuda' if torch.cuda.is_available() else 'cpu'
437
  # AASIST parameters
438
  filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
439
  gat_dims = [64, 32]
440
  pool_ratios = [0.5, 0.5, 0.5, 0.5]
441
  temperatures = [2.0, 2.0, 100.0, 100.0]
442
+ # self.model_device = device
443
 
444
 
445
  ####