Yuan (Cyrus) Chiang commited on
Commit
00b56e2
·
unverified ·
1 Parent(s): e59bc30

convert `ResultRecord` to `result` if fetched from cache (#39)

Browse files

* convert `ResultRecord` to `result` if fetched from cache

* format

* hotfix chgnet

* move warnings filter before chgnet imports

* use simplefilter instead

* replace `raise_deadline_warning` as hotfix

* exclude chgnet for now

* install pymatgen from repo

* disable setup-python pip cache

* cleanup

.github/workflows/test.yaml CHANGED
@@ -16,7 +16,7 @@ jobs:
16
  strategy:
17
  # max-parallel: 2
18
  matrix:
19
- python-version: ["3.10", "3.11"]
20
 
21
 
22
  steps:
@@ -33,8 +33,6 @@ jobs:
33
  uses: actions/setup-python@v5
34
  with:
35
  python-version: ${{ matrix.python-version }}
36
- python-version-file: "pyproject.toml"
37
- cache: 'pip'
38
 
39
  - name: Install dependencies
40
  run: |
 
16
  strategy:
17
  # max-parallel: 2
18
  matrix:
19
+ python-version: ["3.10", "3.11", "3.12"]
20
 
21
 
22
  steps:
 
33
  uses: actions/setup-python@v5
34
  with:
35
  python-version: ${{ matrix.python-version }}
 
 
36
 
37
  - name: Install dependencies
38
  run: |
mlip_arena/tasks/__init__.py CHANGED
@@ -6,10 +6,6 @@ from huggingface_hub import HfApi, HfFileSystem, hf_hub_download
6
  from mlip_arena.models import MLIP
7
  from mlip_arena.models import REGISTRY as MODEL_REGISTRY
8
 
9
- # from .run import md as MD
10
-
11
- # __all__ = ["MD"]
12
-
13
  with open(Path(__file__).parent / "registry.yaml") as f:
14
  REGISTRY = yaml.safe_load(f)
15
 
 
6
  from mlip_arena.models import MLIP
7
  from mlip_arena.models import REGISTRY as MODEL_REGISTRY
8
 
 
 
 
 
9
  with open(Path(__file__).parent / "registry.yaml") as f:
10
  REGISTRY = yaml.safe_load(f)
11
 
mlip_arena/tasks/eos.py CHANGED
@@ -12,6 +12,7 @@ import numpy as np
12
  from prefect import task
13
  from prefect.cache_policies import INPUTS, TASK_SOURCE
14
  from prefect.futures import wait
 
15
  from prefect.runtime import task_run
16
  from prefect.states import State
17
 
@@ -38,10 +39,7 @@ def _generate_task_run_name():
38
 
39
 
40
  @task(
41
- name="EOS",
42
- task_run_name=_generate_task_run_name,
43
- cache_policy=TASK_SOURCE + INPUTS
44
- # cache_key_fn=task_input_hash,
45
  )
46
  def run(
47
  atoms: Atoms,
@@ -75,6 +73,7 @@ def run(
75
  max_abs_strain: The maximum absolute strain to use.
76
  npoints: The number of points to sample.
77
  concurrent: Whether to relax multiple structures concurrently.
 
78
  cache_opt: Whether to cache the intermediate optimization results.
79
 
80
  Returns:
@@ -83,7 +82,7 @@ def run(
83
 
84
  OPT_ = OPT.with_options(
85
  refresh_cache=not cache_opt,
86
- persist_result=persist_opt,
87
  )
88
 
89
  state = OPT_(
@@ -101,13 +100,13 @@ def run(
101
 
102
  if state.is_failed():
103
  return state
104
- elif state.is_completed() and state.name in ["Completed", "Cached"]:
105
- first_relax = state.result(raise_on_failure=False)
106
- elif state.is_completed() and state.name in ["Rollback"]:
107
- first_relax = state.result(raise_on_failure=False)
108
-
109
- assert isinstance(first_relax, dict)
110
- relaxed = first_relax["atoms"]
111
 
112
  # p0 = relaxed.get_positions()
113
  c0 = relaxed.get_cell()
@@ -159,12 +158,13 @@ def run(
159
  return_state=True,
160
  )
161
  states.append(state)
162
- results = [
163
- s.result(raise_on_failure=False) for s in states if s.is_completed()
164
- ]
165
 
166
- volumes = [f["atoms"].get_volume() for f in results]
167
- energies = [f["atoms"].get_potential_energy() for f in results]
 
 
 
 
168
 
169
  volumes, energies = map(
170
  list,
 
12
  from prefect import task
13
  from prefect.cache_policies import INPUTS, TASK_SOURCE
14
  from prefect.futures import wait
15
+ from prefect.results import ResultRecord
16
  from prefect.runtime import task_run
17
  from prefect.states import State
18
 
 
39
 
40
 
41
  @task(
42
+ name="EOS", task_run_name=_generate_task_run_name, cache_policy=TASK_SOURCE + INPUTS
 
 
 
43
  )
44
  def run(
45
  atoms: Atoms,
 
73
  max_abs_strain: The maximum absolute strain to use.
74
  npoints: The number of points to sample.
75
  concurrent: Whether to relax multiple structures concurrently.
76
+ persist_opt: Whether to persist the optimization results.
77
  cache_opt: Whether to cache the intermediate optimization results.
78
 
79
  Returns:
 
82
 
83
  OPT_ = OPT.with_options(
84
  refresh_cache=not cache_opt,
85
+ persist_result=persist_opt,
86
  )
87
 
88
  state = OPT_(
 
100
 
101
  if state.is_failed():
102
  return state
103
+
104
+ first_relax = state.result(raise_on_failure=False)
105
+
106
+ if isinstance(first_relax, ResultRecord):
107
+ relaxed = first_relax.result["atoms"]
108
+ else:
109
+ relaxed = first_relax["atoms"]
110
 
111
  # p0 = relaxed.get_positions()
112
  c0 = relaxed.get_cell()
 
158
  return_state=True,
159
  )
160
  states.append(state)
 
 
 
161
 
162
+ results = [s.result(raise_on_failure=False) for s in states if s.is_completed()]
163
+
164
+ results = [r.result if isinstance(r, ResultRecord) else r for r in results]
165
+
166
+ volumes = [r["atoms"].get_volume() for r in results]
167
+ energies = [r["atoms"].get_potential_energy() for r in results]
168
 
169
  volumes, energies = map(
170
  list,
pyproject.toml CHANGED
@@ -66,6 +66,7 @@ test = [
66
  "pytest",
67
  "pytest-xdist",
68
  "prefect>=3.0.4,<3.1.0",
 
69
  ]
70
  mace = [
71
  "mace-torch==0.3.4",
 
66
  "pytest",
67
  "pytest-xdist",
68
  "prefect>=3.0.4,<3.1.0",
69
+ "pymatgen@git+https://github.com/chiang-yuan/pymatgen" # hotfix for deprecated pymatgen property https://github.com/materialsproject/pymatgen/pull/4242
70
  ]
71
  mace = [
72
  "mace-torch==0.3.4",