File size: 7,684 Bytes
1a9eacc
 
baa7f8d
33df548
8689fa0
739c614
baa7f8d
0bf9e36
fbd13ac
baa7f8d
0bf9e36
baa7f8d
2af7535
 
 
 
70b543a
a0bb336
 
193b388
 
 
 
 
fbd13ac
9ec01d1
 
 
 
 
 
7f3ef59
9ec01d1
 
 
 
 
 
 
 
2200e20
9ec01d1
 
 
 
 
 
 
 
 
 
 
 
 
739c614
70b543a
739c614
1a9eacc
7f3ef59
739c614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f3ef59
aba1804
baa7f8d
2af7535
f6a3fec
2af7535
 
 
 
baa7f8d
1a9eacc
baa7f8d
1a9eacc
739c614
33df548
739c614
 
 
7f3ef59
739c614
 
 
9ec01d1
 
739c614
 
 
 
 
 
 
 
 
 
 
 
193b388
739c614
33df548
70b543a
f6a3fec
 
 
 
 
 
 
 
 
 
baa7f8d
0641fac
 
33df548
0641fac
 
 
 
baa7f8d
f6a3fec
baa7f8d
137bc98
baa7f8d
33df548
baa7f8d
bddaa7b
b88325e
 
1a9eacc
0641fac
 
 
1a9eacc
 
 
 
739c614
1a9eacc
 
 
 
 
 
 
 
 
baa7f8d
 
 
1a9eacc
 
baa7f8d
 
 
1a9eacc
 
 
 
 
 
 
 
 
baa7f8d
 
 
 
 
bddaa7b
 
 
f6a3fec
739c614
 
 
 
f6a3fec
 
1a9eacc
137bc98
1a9eacc
 
2af7535
 
 
 
 
 
 
 
 
1517af1
 
 
fbd13ac
 
 
739c614
8a53814
1a9eacc
 
 
 
 
bddaa7b
f6a3fec
1a9eacc
 
baa7f8d
1a9eacc
c61d5b5
4610d52
bd292ac
1a9eacc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import os

import gradio as gr
import numpy as np
import pandas as pd
import periodictable
import plotly.graph_objs as go
import polars as pl
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
from pymatgen.core import Composition, Element, Structure
from pymatgen.core.composition import Composition
from pymatgen.entries.computed_entries import (
    ComputedStructureEntry,
    GibbsComputedStructureEntry,
)

# Hugging Face access token read from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# LeMat1 dataset configurations loaded by this app, one per DFT functional.
subsets = ["compatible_pbe", "compatible_pbesol", "compatible_scan"]

# polars_dfs = {
#     subset: pl.read_parquet(
#         "hf://datasets/LeMaterial/LeMat1/{}/train-*.parquet".format(subset),
#         storage_options={
#             "token": HF_TOKEN,
#         },
#     )
#     for subset in subsets
# }

# Load each subset's "train" split as a pandas DataFrame, restricted to the
# columns the app reads below.
subsets_ds = {}
for subset in subsets:
    splits = load_dataset(
        "LeMaterial/leMat1",
        subset,
        token=HF_TOKEN,
        columns=[
            "lattice_vectors",
            "species_at_sites",
            "cartesian_site_positions",
            "energy",
            "energy_corrected",
            "immutable_id",
            "elements",
            "functional",
        ],
    )
    train_split = splits["train"]
    subsets_ds[subset] = train_split.to_pandas()

# Per-subset Series of element-symbol lists, one entry per material.
elements_df = {k: subset["elements"] for k, subset in subsets_ds.items()}


# Element symbol -> column index in the presence matrices below.
all_elements = {str(el): i for i, el in enumerate(periodictable.elements)}


def _element_presence_matrix(element_lists, element_index):
    """Return a boolean (n_rows, len(element_index)) element-presence matrix.

    Row ``r`` has ``True`` in column ``element_index[sym]`` for every symbol
    ``sym`` listed in ``element_lists[r]`` (rows are positional).

    Args:
        element_lists: Sequence (e.g. a pandas Series) whose items are
            iterables of element symbols.
        element_index: Mapping from element symbol to column index.

    Raises:
        KeyError: If a symbol is missing from ``element_index``.
    """
    n_rows = len(element_lists)
    counts = np.fromiter(
        (len(symbols) for symbols in element_lists), dtype=np.intp, count=n_rows
    )
    # One (row, column) pair per listed element, filled in a single
    # vectorized assignment instead of a per-row pandas apply.
    row_ids = np.repeat(np.arange(n_rows), counts)
    col_ids = np.fromiter(
        (element_index[sym] for symbols in element_lists for sym in symbols),
        dtype=np.intp,
        count=int(counts.sum()),
    )
    # bool instead of float64: 8x less memory, and downstream row sums and
    # equality comparisons behave the same.
    matrix = np.zeros((n_rows, len(element_index)), dtype=bool)
    matrix[row_ids, col_ids] = True
    return matrix


elements_indices = {}
for subset, df in elements_df.items():
    print("Processing subset: ", subset)
    elements_indices[subset] = _element_presence_matrix(df, all_elements)

# UI functional label -> LeMat1 subset name holding that functional's data.
map_functional = dict(
    PBE="compatible_pbe",
    PBESol="compatible_pbesol",
    SCAN="compatible_scan",
)


def _message_figure(text):
    """Return an empty plotly figure that only shows *text* as an annotation."""
    return go.Figure().add_annotation(text=text)


def _get_energy_correction(energy_correction, row):
    """Return the energy correction to attach to one database row.

    Args:
        energy_correction: Correction scheme name as chosen in the UI.
        row: A subset row with ``energy`` and ``energy_corrected`` fields.

    Returns:
        The value pymatgen adds to ``row["energy"]``; 0.0 when the scheme is
        unknown or the provider supplied no corrected energy.
    """
    if energy_correction == "Database specific, or MP2020":
        # Provider-corrected energy, expressed as an offset from the raw one.
        # pd.isna also tolerates None, unlike np.isnan.
        if pd.isna(row["energy_corrected"]):
            return 0.0
        return row["energy_corrected"] - row["energy"]
    if energy_correction == "The 110 PBE Method":
        # pymatgen adds the correction to the raw energy, so 0.1 * E yields a
        # corrected energy of 1.1 * E (110 %). The former 1.1 * E gave 2.1 * E.
        return 0.1 * row["energy"]
    # Unknown scheme: no correction (returning None used to crash the entry).
    return 0.0


def _row_to_entry(row, energy_correction):
    """Build a pymatgen ComputedStructureEntry from one subset row."""
    # NOTE(review): assumes lattice_vectors is an array of per-vector arrays,
    # as produced by Dataset.to_pandas() — same conversion as before.
    structure = Structure(
        [vector.tolist() for vector in row["lattice_vectors"].tolist()],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    return ComputedStructureEntry(
        structure,
        energy=row["energy"],
        correction=_get_energy_correction(energy_correction, row),
        entry_id=row["immutable_id"],
        parameters={"run_type": row["functional"]},
    )


def _select_entries_df(subset_name, element_list):
    """Return rows of *subset_name* made only of elements in *element_list*.

    Rows without an ``immutable_id`` or without any element are dropped.
    """
    presence = elements_indices[subset_name]
    query_mask = np.zeros(len(all_elements), dtype=bool)
    query_mask[[all_elements[el] for el in element_list]] = True

    n_elements = presence.sum(axis=1)
    n_in_system = presence[:, query_mask].sum(axis=1)
    # A row belongs to the chemical system when every element it contains is
    # one of the queried elements.
    positions = np.flatnonzero((n_in_system == n_elements) & (n_elements > 0))

    # Matrix rows are positional, so select positionally.
    entries_df = subsets_ds[subset_name].iloc[positions]
    return entries_df[entries_df["immutable_id"].notna()]


def create_phase_diagram(
    elements,
    energy_correction,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a plotly phase diagram for a chemical system from LeMat data.

    Args:
        elements: Hyphen-separated element symbols, e.g. ``"Li-Fe-O"``.
        energy_correction: Correction scheme name (see ``_get_energy_correction``).
        plot_style: ``"2D"``; any other value requests the 3D ternary plot.
        functional: Key of ``map_functional`` selecting the data subset.
        finite_temp: When truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries`` (pymatgen defaults).
        **kwargs: Accepted and ignored.

    Returns:
        A plotly Figure: the diagram, or an annotated empty figure describing
        why no diagram could be built.
    """
    element_list = [el.strip() for el in (elements or "").split("-")]
    unknown = [el for el in element_list if el not in all_elements]
    if unknown:
        return _message_figure(
            "Unknown element symbol(s): " + ", ".join(map(repr, unknown))
        )

    subset_name = map_functional.get(functional)
    if subset_name is None:
        return _message_figure(f"Unsupported functional: {functional!r}")

    # Reject unsupported 3D requests before doing any expensive work; count
    # distinct elements so inputs like "Li-Li-O" are not treated as ternary.
    if plot_style != "2D" and len(set(element_list)) != 3:
        return _message_figure("3D plots are only available for ternary systems.")

    entries_df = _select_entries_df(subset_name, element_list)
    entries = [
        _row_to_entry(row, energy_correction) for _, row in entries_df.iterrows()
    ]
    if not entries:
        return _message_figure(
            f"No {functional} entries found for the {'-'.join(element_list)} system."
        )

    # TODO: Fetch elemental entries (they are usually GGA calculations)
    # entries.extend([e for e in entries if e.composition.is_element])

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    try:
        phase_diagram = PhaseDiagram(entries)
    except ValueError as e:
        print(e)
        return _message_figure(str(e))

    plot_options = {} if plot_style == "2D" else {"ternary_style": "3d"}
    plotter = PDPlotter(
        phase_diagram, show_unstable=True, backend="plotly", **plot_options
    )

    # Adjust the maximum energy above hull
    # (This is a placeholder as PDPlotter does not support direct filtering)

    return plotter.get_plot()


# Define Gradio interface components
elements_input = gr.Textbox(
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
    value="Li-Fe-O",
)
# max_e_above_hull_slider = gr.Slider(
#     minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
# )

# Explicit defaults: an untouched dropdown without a value may submit None,
# which create_phase_diagram cannot map to a subset or correction scheme.
energy_correction_dropdown = gr.Dropdown(
    choices=[
        "The 110 PBE Method",
        "Database specific, or MP2020",
    ],
    value="The 110 PBE Method",
    label="Energy correction",
)
plot_style_dropdown = gr.Dropdown(choices=["2D", "3D"], value="2D", label="Plot Style")
functional_dropdown = gr.Dropdown(
    choices=["PBE", "PBESol", "SCAN"], value="PBE", label="Functional"
)
finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")

warning_message = (
    "This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria, MP"
    " we have noticed discrepancies in the data. While the correction"
    " scheme will be standardized in a soon to be released update, for"
    " now please take caution when analyzing the results of this"
    " application."
    "<br> Additionally, we have provided the 110 PBE correction method"
    " from <a href='https://chemrxiv.org/engage/api-gateway/chemrxiv/assets/orp/resource/item/67252d617be152b1d0b2c1ef/original/a-simple-linear-relation-solves-unphysical-dft-energy-corrections.pdf' target='_blank'>Rohr et al (2024)</a>.<br>"
)

# The onclick handler uses single quotes around 'none': nesting double quotes
# inside the double-quoted attribute ended the attribute early.
message = (
    '<div class="alert"><span class="closebtn" '
    "onclick=\"this.parentElement.style.display='none';\">&times;</span>"
    f"{warning_message}</div>"
    "Generate a phase diagram for a set of elements using LeMat-Bulk data."
    "<br>Built with <a href='https://pymatgen.org/' target='_blank'>Pymatgen</a>"
    " and  <a href='https://docs.crystaltoolkit.org/' target='_blank'>Crystal Toolkit</a>.<br>"
)

# Create Gradio interface
iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=[
        elements_input,
        # max_e_above_hull_slider,
        energy_correction_dropdown,
        plot_style_dropdown,
        functional_dropdown,
        finite_temp_toggle,
    ],
    outputs=gr.Plot(label="Phase Diagram"),
    title="LeMaterial - Phase Diagram Viewer",
    description=message,
)

# Launch the app
iface.launch()