Spaces:
Running
Running
File size: 1,873 Bytes
1fe249a e517f23 1fe249a e517f23 1fe249a e517f23 1fe249a e517f23 1fe249a e517f23 1fe249a e517f23 1fe249a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.colors as pcolors
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from mlip_arena.models import REGISTRY
# Root folder of the per-family stability results (resolved relative to the CWD).
DATA_DIR = Path("mlip_arena/tasks/stability")

st.markdown("# Stability")
st.markdown("### Methods")
container = st.container(border=True)

# Models pre-selected when the page loads. Streamlit's multiselect rejects a
# default that is not among its options, so names that are no longer present
# in REGISTRY are dropped instead of crashing the page.
DEFAULT_MODELS = ["MACE-MP(M)", "CHGNet", "EquiformerV2(OC22)"]
models = container.multiselect(
    "MLIPs",
    REGISTRY.keys(),
    [model for model in DEFAULT_MODELS if model in REGISTRY.keys()],
)
st.markdown("### Settings")
vis = st.container(border=True)

# Gather every qualitative palette plotly exposes as a module-level list.
# dir() returns names sorted alphabetically, which fixes the option order.
all_attributes = dir(pcolors.qualitative)
color_palettes = {}
for attr in all_attributes:
    value = getattr(pcolors.qualitative, attr)
    if isinstance(value, list):
        color_palettes[attr] = value
# `__all__`, if present, is also a list but is not a colour palette.
color_palettes.pop("__all__", None)

palette_names = list(color_palettes.keys())
palette_colors = list(color_palettes.values())  # not referenced below; kept for compatibility

# Pick the default palette by name rather than by a hard-coded position:
# positions in dir()'s sorted output shift whenever plotly adds or removes
# palettes, which would silently change (or break) the default.
DEFAULT_PALETTE = "Plotly"
default_palette_index = (
    palette_names.index(DEFAULT_PALETTE) if DEFAULT_PALETTE in palette_names else 0
)
palette_name = vis.selectbox(
    "Color sequence",
    options=palette_names,
    index=default_palette_index,
)
color_sequence = color_palettes[palette_name]
if not models:
    st.stop()

# Several selected models may belong to the same family; read each family's
# results file only once. dict.fromkeys de-duplicates while keeping order.
families = list(dict.fromkeys(REGISTRY[str(model)]["family"] for model in models))

dfs = []
for family in families:
    json_path = DATA_DIR / family.lower() / "chloride-salts.json"
    if not json_path.is_file():
        # Skip families without results instead of crashing the whole page.
        st.warning(f"No stability results found for {family} at {json_path}.")
        continue
    dfs.append(pd.read_json(json_path))

if not dfs:
    # pd.concat raises ValueError on an empty list, and there is nothing to plot.
    st.stop()

df = pd.concat(dfs, ignore_index=True)
# NOTE(review): a family file may contain rows for methods that were not
# selected in the multiselect; confirm whether df should be filtered to `models`.
df.drop_duplicates(inplace=True, subset=["material_id", "formula", "method"])

# Give each method a colour from the chosen sequence, in order of first
# appearance, cycling when there are more methods than colours.
method_color_mapping = {
    method: color_sequence[i % len(color_sequence)]
    for i, method in enumerate(df["method"].unique())
}

# Simulation throughput vs. system size, with an OLS trendline fitted on log(x).
fig = px.scatter(
    df,
    x="natoms",
    y="steps_per_second",
    color="method",
    color_discrete_map=method_color_mapping,
    trendline="ols",
    trendline_options=dict(log_x=True),
    log_x=True,
)

event = st.plotly_chart(
    fig,
    key="stability",
    on_select="rerun",
)
# Show the current selection payload explicitly (equivalent to Streamlit
# "magic" rendering of a bare `event` expression).
st.write(event)
|