msiron committed on
Commit
33df548
·
1 Parent(s): fa131a7

faster filtering

Browse files
Files changed (1) hide show
  1. app.py +20 -19
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
 
3
  import gradio as gr
 
4
  import plotly.graph_objs as go
5
  from datasets import load_dataset
6
  from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
@@ -39,41 +40,41 @@ def create_phase_diagram(
39
  # Split elements and remove any whitespace
40
  element_list = [el.strip() for el in elements.split("-")]
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  # Fetch all entries from the Materials Project database
43
  entries = [
44
  ComputedStructureEntry(
45
  Structure(
46
- [x.tolist() for x in train_df.iloc[0]["lattice_vectors"].tolist()],
47
  row["species_at_sites"],
48
  row["cartesian_site_positions"],
49
  coords_are_cartesian=True,
50
  ),
51
  energy=row["energy"],
52
- correction=row["energy_corrected"] - row["energy"],
 
 
53
  entry_id=row["immutable_id"],
54
  parameters={"run_type": row["functional"]},
55
  )
56
- for n, row in train_df.iterrows()
57
- if len(set(row["elements"]).intersection(element_list)) > 0
58
- and set(row["elements"]).issubset(element_list)
59
  ]
60
  # Fetch elemental entries (they are usually GGA calculations)
61
  elemental_entries = [e for e in entries if e.composition.is_element]
62
 
63
- # Filter entries based on functional
64
- if functional == "PBE":
65
- entries = [e for e in entries if e.parameters.get("run_type", "") == "pbe"]
66
- entries.extend([e for e in elemental_entries if e not in entries])
67
- elif functional == "PBESol":
68
- entries = [e for e in entries if e.parameters.get("run_type", "") == "pbesol"]
69
- # Add elemental entries to ensure they are included
70
- entries.extend([e for e in elemental_entries if e not in entries])
71
-
72
- elif functional == "SCAN":
73
- entries = [e for e in entries if e.parameters.get("run_type", "") == "scan"]
74
- # Add elemental entries to ensure they are included
75
- entries.extend([e for e in elemental_entries if e not in entries])
76
-
77
  if finite_temp:
78
  entries = GibbsComputedStructureEntry.from_entries(entries)
79
 
 
1
  import os
2
 
3
  import gradio as gr
4
+ import numpy as np
5
  import plotly.graph_objs as go
6
  from datasets import load_dataset
7
  from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
 
40
  # Split elements and remove any whitespace
41
  element_list = [el.strip() for el in elements.split("-")]
42
 
43
+ # Filter entries based on functional
44
+ if functional == "PBE":
45
+ entries_df = entries_df[train_df["functional"] == "pbe"]
46
+ elif functional == "PBESol":
47
+ entries_df = entries_df[train_df["functional"] == "pbe"]
48
+ elif functional == "SCAN":
49
+ entries_df = entries_df[train_df["functional"] == "pbe"]
50
+
51
+ isubset = lambda x: set(x).issubset(element_list)
52
+ isintersection = lambda x: len(set(x).intersection(element_list)) > 0
53
+ entries_df = entries_df[
54
+ [isintersection(l) and isubset(l) for l in entries_df.elements.values.tolist()]
55
+ ]
56
+
57
  # Fetch all entries from the Materials Project database
58
  entries = [
59
  ComputedStructureEntry(
60
  Structure(
61
+ [x.tolist() for x in row["lattice_vectors"].tolist()],
62
  row["species_at_sites"],
63
  row["cartesian_site_positions"],
64
  coords_are_cartesian=True,
65
  ),
66
  energy=row["energy"],
67
+ correction=row["energy_corrected"] - row["energy"]
68
+ if not np.isnan(row["energy_corrected"])
69
+ else 0,
70
  entry_id=row["immutable_id"],
71
  parameters={"run_type": row["functional"]},
72
  )
73
+ for n, row in entries_df.iterrows()
 
 
74
  ]
75
  # Fetch elemental entries (they are usually GGA calculations)
76
  elemental_entries = [e for e in entries if e.composition.is_element]
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  if finite_temp:
79
  entries = GibbsComputedStructureEntry.from_entries(entries)
80