Spaces:
Runtime error
Runtime error
igashov
commited on
Commit
·
f97fe8b
1
Parent(s):
49021fb
update device
Browse files
src/linker_size_lightning.py
CHANGED
@@ -38,8 +38,8 @@ class SizeClassifier(pl.LightningModule):
|
|
38 |
self.linker_id2size = linker_id2size
|
39 |
self.batch_size = batch_size
|
40 |
self.lr = lr
|
41 |
-
self.torch_device =
|
42 |
-
self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=torch_device)
|
43 |
self.gnn = SizeGNN(
|
44 |
in_node_nf=in_node_nf,
|
45 |
hidden_nf=hidden_nf,
|
|
|
38 |
self.linker_id2size = linker_id2size
|
39 |
self.batch_size = batch_size
|
40 |
self.lr = lr
|
41 |
+
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
42 |
+
self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=self.torch_device)
|
43 |
self.gnn = SizeGNN(
|
44 |
in_node_nf=in_node_nf,
|
45 |
hidden_nf=hidden_nf,
|