Update utils.py
Browse files
utils.py
CHANGED
@@ -277,7 +277,8 @@ class MolecularGenerationModel():
|
|
277 |
del batch['property_names']
|
278 |
del batch['non_normalized_properties']
|
279 |
|
280 |
-
batch =
|
|
|
281 |
|
282 |
input_length = batch['input_ids'].shape[1]
|
283 |
steps = 1024 - input_length
|
|
|
277 |
del batch['property_names']
|
278 |
del batch['non_normalized_properties']
|
279 |
|
280 |
+
batch['input_ids'] = batch['input_ids'].to(self.model.device)
|
281 |
+
#batch = {k: v.to(self.model.device) for k, v in batch.items()}
|
282 |
|
283 |
input_length = batch['input_ids'].shape[1]
|
284 |
steps = 1024 - input_length
|