cyrusyc commited on
Commit
8cb1d3b
·
1 Parent(s): a909029

propagate device setting to calculator

Browse files
Files changed (1) hide show
  1. mlip_arena/tasks/utils.py +4 -2
mlip_arena/tasks/utils.py CHANGED
@@ -28,11 +28,13 @@ def get_calculator(
28
  device: str | None = None,
29
  ) -> Calculator | SumCalculator:
30
  """Get a calculator with optional dispersion correction."""
31
- device = device or str(get_freer_device())
32
 
33
- logger.info(f"Using device: {device}")
34
 
35
  calculator_kwargs = calculator_kwargs or {}
 
 
 
36
 
37
  if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
38
  calc = calculator_name.value(**calculator_kwargs)
 
28
  device: str | None = None,
29
  ) -> Calculator | SumCalculator:
30
  """Get a calculator with optional dispersion correction."""
 
31
 
32
+ device = device or str(get_freer_device())
33
 
34
  calculator_kwargs = calculator_kwargs or {}
35
+ calculator_kwargs.update({"device": device})
36
+
37
+ logger.info(f"Using device: {device}")
38
 
39
  if isinstance(calculator_name, MLIPEnum) and calculator_name in MLIPEnum:
40
  calc = calculator_name.value(**calculator_kwargs)