# Streamlit page: energy-volume (E-V) scans of WBM structures, one plot per
# selected model.

from pathlib import Path

import numpy as np
import pandas as pd
import plotly.colors as pcolors
import plotly.graph_objects as go
import streamlit as st
from ase.db import connect
from scipy import stats

from mlip_arena.models import REGISTRY as MODELS

DATA_DIR = Path("benchmarks/wbm_ev")

# Position (in the palette list built below) of the palette selected by default.
DEFAULT_PALETTE_INDEX = 22

# Column order of the per-structure analysis table built by generate_dataframe.
ANALYSIS_COLUMNS = [
    "model",
    "structure",
    "formula",
    "volume-ratio",
    "energy-delta-per-atom",
    "energy-diff-flip-times",
    "tortuosity",
    "spearman-compression-energy",
    "spearman-compression-derivative",
    "spearman-tension-energy",
    "missing",
]

st.markdown(
    """
# Energy-volume scans
"""
)

# Control panels at the top
st.markdown("### Methods")
methods_container = st.container(border=True)

# Models are offered only if this page's task (the file stem) is listed among
# the model's "gpu-tasks" in the registry metadata.
task_name = Path(__file__).stem
valid_models = [
    model
    for model, metadata in MODELS.items()
    if task_name in metadata.get("gpu-tasks", [])
]

# Model selection
selected_models = methods_container.multiselect(
    "Select Models", options=valid_models, default=valid_models
)

# Visualization settings
st.markdown("### Visualization Settings")
vis = st.container(border=True)

# Column settings
ncols = vis.select_slider("Number of columns", options=[1, 2, 3, 4], value=2)

# Color palette selection: every list-valued attribute of
# plotly.colors.qualitative is treated as a selectable color sequence.
color_palettes = {
    name: palette
    for name, palette in (
        (attr, getattr(pcolors.qualitative, attr))
        for attr in dir(pcolors.qualitative)
    )
    if isinstance(palette, list)
}
color_palettes.pop("__all__", None)

palette_names = list(color_palettes.keys())
palette_name = vis.selectbox(
    "Color sequence",
    options=palette_names,
    # Clamp so the default stays valid if fewer palettes are available.
    index=min(DEFAULT_PALETTE_INDEX, max(len(palette_names) - 1, 0)),
)
color_sequence = color_palettes[palette_name]

# Stop execution if no models selected
if not selected_models:
    st.warning("Please select at least one model to visualize.")
    st.stop()


def load_wbm_structures():
    """Yield the WBM structures stored in an ASE DB file.

    The database is read from ``wbm_structures.db`` in the parent directory
    of ``DATA_DIR``. Each row is converted with
    ``toatoms(add_additional_information=True)``, so the resulting object's
    ``info`` carries the row's key-value pairs.
    """
    with connect(DATA_DIR.parent / "wbm_structures.db") as db:
        for row in db.select():
            yield row.toatoms(add_additional_information=True)


def _missing_record(model_name, structure_id, formula):
    """Return a placeholder row for a structure whose analysis failed."""
    record = dict.fromkeys(ANALYSIS_COLUMNS)
    record.update(
        model=model_name,
        structure=structure_id,
        formula=formula,
        missing=True,
    )
    return record


def _analyze_eos(model_name, structure_id, atoms, eos):
    """Compute the E-V curve metrics of one structure from its raw EOS scan.

    ``eos`` must provide ``"energies"`` and ``"volumes"``; ``atoms`` supplies
    the reference volume, the atom count and the chemical formula.
    """
    es = np.array(eos["energies"])
    vols = np.array(eos["volumes"])
    vol0 = atoms.get_volume()

    # Sort the scan by volume so differences follow increasing volume.
    order = np.argsort(vols)
    vols = vols[order]
    es = es[order]

    # The middle point of the sorted scan is used as the reference energy.
    imine = len(es) // 2
    emin = es[imine]

    # Midpoints of consecutive volumes (one per finite-difference interval).
    interpolated_volumes = (vols[:-1] + vols[1:]) / 2

    # Energy differences with exact zeros dropped, so that a sign flip is
    # only counted between two non-zero slopes.
    ediff = np.diff(es)
    ediff_sign = np.sign(ediff)
    mask = ediff_sign != 0
    ediff = ediff[mask]
    ediff_sign = ediff_sign[mask]
    ediff_flip = np.diff(ediff_sign) != 0

    # Total variation of the energy along the scan.
    etv = np.sum(np.abs(np.diff(es)))

    return {
        "model": model_name,
        "structure": structure_id,
        "formula": atoms.get_chemical_formula(),
        "missing": False,
        "volume-ratio": vols / vol0,
        "energy-delta-per-atom": (es - emin) / len(atoms),
        "energy-diff-flip-times": np.sum(ediff_flip).astype(int),
        "tortuosity": etv / (abs(es[0] - emin) + abs(es[-1] - emin)),
        "spearman-compression-energy": stats.spearmanr(
            vols[:imine], es[:imine]
        ).statistic,
        # NOTE(review): `ediff` has had zero differences masked out above, so
        # it may be misaligned with (or shorter than) the interval midpoints;
        # a length mismatch raises and marks the structure as missing.
        # Confirm whether the unmasked differences were intended here.
        "spearman-compression-derivative": stats.spearmanr(
            interpolated_volumes[:imine], ediff[:imine]
        ).statistic,
        "spearman-tension-energy": stats.spearmanr(
            vols[imine:], es[imine:]
        ).statistic,
    }


@st.cache_data
def generate_dataframe(model_name):
    """Build the per-structure analysis table for one model.

    Reads ``DATA_DIR / f"{model_name}.parquet"`` and, for every WBM structure,
    looks up the raw result whose ``id`` equals the structure's ``wbm_id`` and
    derives the metrics listed in ``ANALYSIS_COLUMNS``.

    Returns:
        A DataFrame with one row per WBM structure. Structures whose result
        is absent or cannot be analysed get ``missing=True`` and ``None`` for
        every metric. An empty DataFrame (no columns) is returned when the
        model's parquet file does not exist.
    """
    fpath = DATA_DIR / f"{model_name}.parquet"
    if not fpath.exists():
        return pd.DataFrame()

    df_raw_results = pd.read_parquet(fpath)

    # Rows are collected in a list and the frame is built once at the end;
    # concatenating inside the loop copies the growing frame every iteration.
    rows = []
    for wbm_struct in load_wbm_structures():
        structure_id = wbm_struct.info["key_value_pairs"]["wbm_id"]
        try:
            matches = df_raw_results.loc[df_raw_results["id"] == structure_id]
            eos = matches["eos"].values[0]
            record = _analyze_eos(model_name, structure_id, wbm_struct, eos)
        except Exception:
            # Deliberate best-effort: any failure (no matching result,
            # malformed scan, statistics error) marks the structure as
            # missing instead of aborting the whole model.
            record = _missing_record(
                model_name, structure_id, wbm_struct.get_chemical_formula()
            )
        rows.append(record)

    return pd.DataFrame(rows, columns=ANALYSIS_COLUMNS)


@st.cache_data
def get_plots(selected_models, colors=None):
    """Generate one figure per model containing the E-V curves of all structures.

    Args:
        selected_models: Names of the models to plot.
        colors: Color sequence cycled over the structures. It is a parameter
            so that it becomes part of the ``st.cache_data`` key; when
            ``None`` the module-level ``color_sequence`` is used.

    Returns:
        A list of ``(model_name, figure, valid_structure_ids)`` tuples. Models
        for which no data is available are skipped.
    """
    palette = color_sequence if colors is None else colors

    figs = []
    for model_name in selected_models:
        # Prefer a pre-processed table; otherwise analyse the raw results.
        fpath = DATA_DIR / f"{model_name}_processed.parquet"
        df = pd.read_parquet(fpath) if fpath.exists() else generate_dataframe(model_name)

        if df.empty:
            continue

        fig = go.Figure()
        valid_structures = []

        for i, (_, row) in enumerate(df.iterrows()):
            structure_id = row["structure"]
            formula = row.get("formula", "")

            # Rows of missing structures hold None instead of arrays.
            if not (
                isinstance(row["volume-ratio"], list | np.ndarray)
                and isinstance(row["energy-delta-per-atom"], list | np.ndarray)
            ):
                continue

            color = palette[i % len(palette)]
            fig.add_trace(
                go.Scatter(
                    x=row["volume-ratio"],
                    y=row["energy-delta-per-atom"],
                    mode="lines",
                    name=f"{structure_id}",
                    showlegend=False,
                    line=dict(color=color),
                    hoverlabel=dict(bgcolor=color, font=dict(color="black")),
                    # "<br>" separates hover lines; the trailing
                    # "<extra></extra>" hides the secondary trace-name box.
                    hovertemplate=(
                        structure_id + "<br>"
                        "Formula: " + str(formula) + "<br>"
                        "Volume ratio V/V₀: %{x:.3f}<br>"
                        "ΔEnergy: %{y:.3f} eV/atom<br>"
                        "<extra></extra>"
                    ),
                )
            )
            valid_structures.append(structure_id)

        fig.update_layout(
            title=f"{model_name} ({len(valid_structures)} / {len(df)} structures)",
            xaxis_title="Volume ratio V/V₀",
            yaxis_title="Relative energy E - E₀ (eV/atom)",
            height=500,
            showlegend=False,  # Disable legend for the whole plot
            yaxis=dict(range=[-1, 15]),  # Set y-axis limits
        )

        # Dashed guide at the reference volume (V/V₀ = 1).
        fig.add_vline(x=1, line_dash="dash", line_color="gray", opacity=0.7)

        figs.append((model_name, fig, valid_structures))

    return figs


# Generate all plots; the palette is passed explicitly so that changing it
# invalidates the cached figures.
all_plots = get_plots(selected_models, color_sequence)

# Display plots in the specified column layout
if all_plots:
    for i, (_model_name, fig, _structures) in enumerate(all_plots):
        # Start a new row of columns every `ncols` figures.
        if i % ncols == 0:
            cols = st.columns(ncols)
        cols[i % ncols].plotly_chart(fig, use_container_width=True)
else:
    st.warning("No data available for the selected models.")