Spaces:
Running
Running
from pathlib import Path | |
import numpy as np | |
import pandas as pd | |
from ase.db import connect | |
from dask.distributed import Client | |
from dask_jobqueue import SLURMCluster | |
from prefect import flow, task | |
from prefect.runtime import task_run | |
from prefect_dask import DaskTaskRunner | |
from prefect.cache_policies import INPUTS, TASK_SOURCE | |
from mlip_arena.models import REGISTRY, MLIPEnum | |
from mlip_arena.tasks.utils import get_calculator | |
def load_wbm_structures(): | |
""" | |
Load the WBM structures from an ASE database file. | |
Reads structures from 'wbm_structures.db' and yields them as ASE Atoms objects | |
with additional metadata preserved from the database. | |
Yields: | |
ase.Atoms: Individual atomic structures from the WBM database with preserved | |
metadata in the .info dictionary. | |
""" | |
with connect("../wbm_structures.db") as db: | |
for row in db.select(): | |
yield row.toatoms(add_additional_information=True) | |
def ev_scan(atoms, model): | |
""" | |
Perform an energy-volume scan for a given model and atomic structure. | |
This function applies uniaxial strain to the structure in all three dimensions, | |
maintaining the fractional coordinates of atoms, and computes the energy at each | |
deformation point using the specified model. | |
Args: | |
atoms: ASE atoms object containing the structure to analyze. | |
model: MLIPEnum model to use for the energy calculations. | |
Returns: | |
dict: Results dictionary containing: | |
- method (str): The name of the model used | |
- id (str): The WBM ID of the structure | |
- eos (dict): Energy of state data with: | |
- volumes (list): Volume of the unit cell at each strain point | |
- energies (list): Computed potential energy at each strain point | |
Note: | |
The strain range is fixed at ±20% with 21 evenly spaced points. | |
Results are also saved as a JSON file in a directory named after the model. | |
""" | |
calculator = get_calculator( | |
model | |
) # avoid sending entire model over prefect and select freer GPU | |
wbm_id = atoms.info["key_value_pairs"]["wbm_id"] | |
c0 = atoms.get_cell() | |
max_abs_strain = 0.2 | |
npoints = 21 | |
volumes = [] | |
energies = [] | |
for uniaxial_strain in np.linspace(-max_abs_strain, max_abs_strain, npoints): | |
cloned = atoms.copy() | |
scale_factor = uniaxial_strain + 1 | |
cloned.set_cell(c0 * scale_factor, scale_atoms=True) | |
cloned.calc = calculator | |
volumes.append(cloned.get_volume()) | |
energies.append(cloned.get_potential_energy()) | |
data = { | |
"method": model.name, | |
"id": wbm_id, | |
"eos": { | |
"volumes": volumes, "energies": energies | |
} | |
} | |
fpath = Path(f"{model.name}") / f"{wbm_id}.json" | |
fpath.parent.mkdir(exist_ok=True) | |
df = pd.DataFrame([data]) | |
df.to_json(fpath) | |
return df | |
def submit_tasks(): | |
""" | |
Create and submit energy-volume scan tasks for subsampled WBM structures and applicable models. | |
This flow function: | |
1. Loads all structures from the WBM database | |
2. Iterates through available models in MLIPEnum | |
3. Filters models based on their capability to handle the 'wbm_ev' GPU task | |
4. Submits parallel ev_scan tasks for all valid (structure, model) combinations | |
5. Collects and returns results from all tasks | |
Returns: | |
list: Results from all executed tasks (successful or failed) | |
""" | |
futures = [] | |
for atoms in load_wbm_structures(): | |
for model in MLIPEnum: | |
if "wbm_ev" not in REGISTRY[model.name].get("gpu-tasks", []): | |
continue | |
try: | |
result = ev_scan.submit(atoms, model) | |
except Exception as e: | |
print(f"Failed to submit task for {model.name}: {e}") | |
continue | |
futures.append(result) | |
return [f.result(raise_on_failure=False) for f in futures] | |
if __name__ == "__main__": | |
nodes_per_alloc = 1 | |
gpus_per_alloc = 1 | |
ntasks = 1 | |
cluster_kwargs = dict( | |
cores=1, | |
memory="64 GB", | |
processes=1, | |
shebang="#!/bin/bash", | |
account="m3828", | |
walltime="00:30:00", | |
# job_mem="0", | |
job_script_prologue=[ | |
"source ~/.bashrc", | |
"module load python", | |
"source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena", | |
], | |
job_directives_skip=["-n", "--cpus-per-task", "-J"], | |
job_extra_directives=[ | |
"-J wbm_ev", | |
"-q debug", | |
f"-N {nodes_per_alloc}", | |
"-C gpu", | |
f"-G {gpus_per_alloc}", | |
"--exclusive", | |
], | |
) | |
cluster = SLURMCluster(**cluster_kwargs) | |
print(cluster.job_script()) | |
cluster.adapt(minimum_jobs=2, maximum_jobs=2) | |
client = Client(cluster) | |
submit_tasks.with_options( | |
task_runner=DaskTaskRunner(address=client.scheduler.address), | |
log_prints=True, | |
)() | |