File size: 1,873 Bytes
1fe249a
 
 
 
 
e517f23
1fe249a
 
 
e517f23
 
 
 
1fe249a
 
 
 
 
e517f23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fe249a
 
 
e517f23
 
 
 
 
 
 
1fe249a
 
e517f23
 
 
 
 
1fe249a
e517f23
1fe249a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.colors as pcolors
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st

from mlip_arena.models import REGISTRY

# Root folder of the stability task data; further down it is joined with a
# lower-cased model family name to locate that family's JSON file.
DATA_DIR = Path("mlip_arena/tasks/stability")


# Page title.
st.markdown("# Stability")

# "Methods" section: a bordered box that hosts the model picker. The picker
# offers every key of REGISTRY and starts with three models pre-selected.
st.markdown("### Methods")
container = st.container(border=True)
default_models = ["MACE-MP(M)", "CHGNet", "EquiformerV2(OC22)"]
models = container.multiselect(
    label="MLIPs",
    options=REGISTRY.keys(),
    default=default_models,
)

st.markdown("### Settings")
vis = st.container(border=True)

# Collect the public, list-valued attributes of plotly's qualitative colour
# module, keyed by attribute name. dir() yields names in alphabetical order,
# which is the order shown in the dropdown. Underscore-prefixed names
# (``__all__`` and any private module-level lists) are not palettes, so they
# are skipped up front instead of being popped afterwards.
color_palettes = {}
for attr in dir(pcolors.qualitative):
    if attr.startswith("_"):
        continue
    candidate = getattr(pcolors.qualitative, attr)
    if isinstance(candidate, list):
        color_palettes[attr] = candidate

palette_names = list(color_palettes)

# Pick the initially selected palette by name rather than by a hard-coded
# position, so the default does not silently change (or go out of range)
# when plotly adds or removes palettes.
# NOTE(review): the previous hard-coded index 22 presumably resolved to
# "Plotly" in the alphabetical listing -- confirm against the installed plotly.
DEFAULT_PALETTE = "Plotly"
if DEFAULT_PALETTE in palette_names:
    default_index = palette_names.index(DEFAULT_PALETTE)
else:
    # Fall back to the historical position, clamped to the valid range.
    default_index = max(0, min(22, len(palette_names) - 1))

palette_name = vis.selectbox(
    "Color sequence",
    options=palette_names,
    index=default_index,
)

# The list of colours of the chosen palette.
color_sequence = color_palettes[palette_name]

# Nothing to load or plot until at least one model is selected.
if not models:
    st.stop()

# Family name of every selected model, in selection order.
families = [REGISTRY[str(model)]["family"] for model in models]

# Several selected models can share a family, and every family maps to a
# single file. dict.fromkeys keeps the first occurrence of each lower-cased
# directory name in order, so each file is read exactly once.
family_dirs = dict.fromkeys(family.lower() for family in families)

dfs = [
    pd.read_json(DATA_DIR / family_dir / "chloride-salts.json")
    for family_dir in family_dirs
]
df = pd.concat(dfs, ignore_index=True)
# Keep only the first row for each (material_id, formula, method) combination.
df = df.drop_duplicates(subset=["material_id", "formula", "method"])

# Assign each distinct method a colour from the chosen palette, wrapping
# around when there are more methods than colours.
method_color_mapping = {
    method: color_sequence[i % len(color_sequence)]
    for i, method in enumerate(df["method"].unique())
}



# Scatter plot of the "steps_per_second" column against "natoms" (log-scaled
# x axis), coloured per method using the mapping built above, with an OLS
# trendline fitted against log(x).
scatter_kwargs = {
    "x": "natoms",
    "y": "steps_per_second",
    "color": "method",
    "color_discrete_map": method_color_mapping,
    "log_x": True,
    "trendline": "ols",
    "trendline_options": {"log_x": True},
}
fig = px.scatter(df, **scatter_kwargs)

# Render the figure. With on_select="rerun" the call returns a value, kept
# here as `event`; see the Streamlit docs for its exact contents.
event = st.plotly_chart(fig, key="stability", on_select="rerun")

# Bare expression on purpose: presumably shown on the page through
# Streamlit's "magic" display of standalone expressions -- confirm that magic
# is enabled for this app.
event