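"""Fit bulk equations of state for WBM structures with every registered MLIP model.

Each structure is relaxed with FIRE and then strained to fit an EOS curve;
results are pickled per model and structure. Work is distributed over a
SLURM-backed Dask cluster through Prefect's DaskTaskRunner.
"""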
import functools
from pathlib import Path

import pandas as pd
from ase.db import connect
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from prefect import Task, flow, task
from prefect.client.schemas.objects import TaskRun
from prefect.states import State
from prefect_dask import DaskTaskRunner

from mlip_arena.models import REGISTRY, MLIPEnum
from mlip_arena.tasks.eos import run as EOS
from mlip_arena.tasks.optimize import run as OPT
from mlip_arena.tasks.utils import get_calculator


@task
def load_wbm_structures():
    """
    Load the WBM structures from an ASE DB file.

    Each row is converted to an Atoms object with its database metadata
    (including the WBM ID) attached under `atoms.info`.
    """
    with connect("../wbm_structures.db") as db:
        for row in db.select():
            yield row.toatoms(add_additional_information=True)


def save_result(
    tsk: Task,
    run: TaskRun,
    state: State,
    model_name: str,
    id: str,
):
    """Prefect on_completion hook: pickle one EOS result to <model_name>/<id>.pkl."""
    result = run.state.result()

    assert isinstance(result, dict)

    result["method"] = model_name
    result["id"] = id
    result.pop("atoms", None)  # drop the Atoms object before serializing

    fpath = Path(f"{model_name}")
    fpath.mkdir(exist_ok=True)

    fpath = fpath / f"{result['id']}.pkl"

    df = pd.DataFrame([result])
    df.to_pickle(fpath)


@task
def eos_bulk(atoms, model):
    """Relax a structure with FIRE, then fit its bulk equation of state."""
    calculator = get_calculator(
        model
    )  # avoid sending the entire model over Prefect and select a freer GPU

    result = OPT.with_options(
        refresh_cache=True,
    )(
        atoms,
        calculator,
        optimizer="FIRE",
        criterion=dict(
            fmax=0.1,
        ),
    )
    return EOS.with_options(
        refresh_cache=True,
        on_completion=[functools.partial(
            save_result,
            model_name=model.name,
            id=atoms.info["key_value_pairs"]["wbm_id"],
        )],
    )(
        atoms=result["atoms"],
        calculator=calculator,
        optimizer="FIRE",
        npoints=21,
        max_abs_strain=0.2,
        concurrent=False,
    )


@flow
def run_all():
    futures = []
    for atoms in load_wbm_structures():
        for model in MLIPEnum:
            # Only submit models registered for the eos_bulk GPU task.
            if "eos_bulk" not in REGISTRY[model.name].get("gpu-tasks", []):
                continue
            result = eos_bulk.submit(atoms, model)
            futures.append(result)
    return [f.result(raise_on_failure=False) for f in futures]


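# SLURM/Dask cluster configuration: each adaptive job requests one exclusive
# GPU node on the regular queue and activates the mlip-arena environment.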
nodes_per_alloc = 1
gpus_per_alloc = 1
ntasks = 1

cluster_kwargs = dict(
    cores=4,
    memory="64 GB",
    shebang="#!/bin/bash",
    account="m3828",
    walltime="00:50:00",
    job_mem="0",
    job_script_prologue=[
        "source ~/.bashrc",
        "module load python",
        "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena",
    ],
    job_directives_skip=["-n", "--cpus-per-task", "-J"],
    job_extra_directives=[
        "-J eos_bulk",
        "-q regular",
        f"-N {nodes_per_alloc}",
        "-C gpu",
        f"-G {gpus_per_alloc}",
        "--exclusive",
    ],
)

cluster = SLURMCluster(**cluster_kwargs)
print(cluster.job_script())
cluster.adapt(minimum_jobs=20, maximum_jobs=40)  # scale between 20 and 40 SLURM jobs
client = Client(cluster)

# Run the flow, dispatching tasks to the Dask workers on the SLURM cluster.
run_all.with_options(
    task_runner=DaskTaskRunner(address=client.scheduler.address),
    log_prints=True,
)()