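# Stability page of the MLIP Arena Streamlit app.
# NOTE(review): the text "Spaces: Running Running" that preceded the imports
# appears to be hosting-page status chrome captured along with the source, not
# Python; it is preserved here as a comment so the module parses.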
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.colors as pcolors
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from mlip_arena.models import REGISTRY
# Root of the stability results; holds one sub-directory per model family.
DATA_DIR = Path("mlip_arena/tasks/stability")

st.markdown("# Stability")

# Model picker: every registered MLIP is selectable, three are pre-selected.
st.markdown("### Methods")
container = st.container(border=True)
models = container.multiselect(
    label="MLIPs",
    options=REGISTRY.keys(),
    default=["MACE-MP(M)", "CHGNet", "EquiformerV2(OC22)"],
)
st.markdown("### Settings")
vis = st.container(border=True)

# Build {palette name: list of colours} from plotly's qualitative module,
# in dir()'s sorted order. Only public list attributes count as palettes:
# underscore-prefixed module internals (e.g. ``__all__``) are skipped so they
# never appear as choices.
color_palettes = {}
for attr_name in dir(pcolors.qualitative):
    if attr_name.startswith("_"):
        continue
    attr_value = getattr(pcolors.qualitative, attr_name)
    if isinstance(attr_value, list):
        color_palettes[attr_name] = attr_value

palette_names = list(color_palettes)
# Default to plotly's own "Plotly" sequence, looked up by name. A hard-coded
# position would silently select a different palette (or raise) whenever the
# set of palettes shipped by plotly changes.
default_index = palette_names.index("Plotly") if "Plotly" in palette_names else 0
palette_name = vis.selectbox(
    "Color sequence",
    options=palette_names,
    index=default_index,
)
color_sequence = color_palettes[palette_name]
# Nothing to plot until at least one model is selected.
if not models:
    st.stop()

# Results are stored per model family, and several selected models can share
# one family. De-duplicate (keeping selection order) so each file is read once.
families = list(dict.fromkeys(REGISTRY[str(model)]["family"] for model in models))

dfs = []
for family in families:
    results_path = DATA_DIR / family.lower() / "chloride-salts.json"
    if not results_path.is_file():
        # Any registered model can be selected, but not every family
        # necessarily has stability results; report it instead of crashing.
        st.warning(f"No stability results found for the {family} family ({results_path}).")
        continue
    dfs.append(pd.read_json(results_path))

# pd.concat raises on an empty list; stop the page if no file could be loaded.
if not dfs:
    st.stop()

df = pd.concat(dfs, ignore_index=True)
# Keep only the first row for each (material_id, formula, method) triple.
df = df.drop_duplicates(subset=["material_id", "formula", "method"])

# One colour per method, cycling through the chosen sequence when there are
# more methods than colours.
method_color_mapping = {
    method: color_sequence[i % len(color_sequence)]
    for i, method in enumerate(df["method"].unique())
}
# steps_per_second against natoms, one colour per method. The OLS trendline
# (plotly's "ols" trendline requires statsmodels) is fitted against log10(natoms)
# so it appears straight on the log-scaled x axis.
fig = px.scatter(
    df,
    x="natoms",
    y="steps_per_second",
    color="method",
    color_discrete_map=method_color_mapping,
    trendline="ols",
    trendline_options=dict(log_x=True),
    log_x=True,
)

# With on_select="rerun", a selection on the chart reruns the script, and
# plotly_chart returns the current selection state.
event = st.plotly_chart(
    fig,
    key="stability",
    on_select="rerun",
)
# Render the selection state explicitly. This is what the former bare `event`
# line did via Streamlit "magic".
st.write(event)