Update model_hf.py
Browse files- model_hf.py +2 -2
model_hf.py
CHANGED
@@ -433,13 +433,13 @@ class Model(PreTrainedModel,nn.Module):
|
|
433 |
config_class = SSLConfig
|
434 |
def __init__(self,config):
|
435 |
super().__init__(config)
|
436 |
-
|
437 |
# AASIST parameters
|
438 |
filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
|
439 |
gat_dims = [64, 32]
|
440 |
pool_ratios = [0.5, 0.5, 0.5, 0.5]
|
441 |
temperatures = [2.0, 2.0, 100.0, 100.0]
|
442 |
-
self.model_device = device
|
443 |
|
444 |
|
445 |
####
|
|
|
433 |
config_class = SSLConfig
|
434 |
def __init__(self,config):
|
435 |
super().__init__(config)
|
436 |
+
self.model_device ='cuda' if torch.cuda.is_available() else 'cpu'
|
437 |
# AASIST parameters
|
438 |
filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
|
439 |
gat_dims = [64, 32]
|
440 |
pool_ratios = [0.5, 0.5, 0.5, 0.5]
|
441 |
temperatures = [2.0, 2.0, 100.0, 100.0]
|
442 |
+
# self.model_device = device
|
443 |
|
444 |
|
445 |
####
|