msiron commited on
Commit
7f3ef59
·
1 Parent(s): 193b388

polars test

Browse files
Files changed (1) hide show
  1. app.py +64 -28
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
 
3
  import gradio as gr
4
  import numpy as np
@@ -6,7 +7,7 @@ import pandas as pd
6
  import plotly.graph_objs as go
7
  from datasets import concatenate_datasets, load_dataset
8
  from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
9
- from pymatgen.core import Composition, Structure
10
  from pymatgen.core.composition import Composition
11
  from pymatgen.entries.computed_entries import (
12
  ComputedStructureEntry,
@@ -21,26 +22,36 @@ subsets = [
21
  "compatible_scan",
22
  ]
23
 
 
 
 
 
 
 
 
 
 
 
24
  # Load only the train split of the dataset
25
 
26
- datasets = []
27
- for subset in subsets:
28
- dataset = load_dataset(
29
- "LeMaterial/leMat-Bulk",
30
- subset,
31
- token=HF_TOKEN,
32
- columns=[
33
- "lattice_vectors",
34
- "species_at_sites",
35
- "cartesian_site_positions",
36
- "energy",
37
- "energy_corrected",
38
- "immutable_id",
39
- "elements",
40
- "functional",
41
- ],
42
- )
43
- datasets.append(dataset["train"])
44
 
45
  # Convert the train split to a pandas DataFrame
46
  # df = pd.concat([x.to_pandas() for x in datasets])
@@ -49,6 +60,21 @@ for subset in subsets:
49
 
50
  dataset = concatenate_datasets(datasets)
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  def create_phase_diagram(
54
  elements,
@@ -64,23 +90,33 @@ def create_phase_diagram(
64
 
65
  # Filter entries based on functional
66
  if functional == "PBE":
67
- ds_filter = dataset.filter(lambda example: example["functional"] == "pbe")
68
  # entries_df = train_df[train_df["functional"] == "pbe"]
69
  elif functional == "PBESol":
70
- ds_filter = dataset.filter(lambda example: example["functional"] == "pbesol")
71
  # entries_df = train_df[train_df["functional"] == "pbesol"]
72
  elif functional == "SCAN":
73
- ds_filter = dataset.filter(lambda example: example["functional"] == "scan")
74
  # entries_df = train_df[train_df["functional"] == "scan"]
75
 
76
- isubset = lambda x: set(x).issubset(element_list)
77
- isintersection = lambda x: len(set(x).intersection(element_list)) > 0
78
- ds_filter = ds_filter.filter(
79
- lambda example: isintersection(example["functional"])
80
- and isubset(example["functional"])
 
 
 
 
 
 
 
 
 
 
81
  )
82
 
83
- entries_df = ds_filter.to_pandas()
84
 
85
  # Fetch all entries from the Materials Project database
86
  entries = [
 
1
  import os
2
+ import polars as pl
3
 
4
  import gradio as gr
5
  import numpy as np
 
7
  import plotly.graph_objs as go
8
  from datasets import concatenate_datasets, load_dataset
9
  from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram
10
+ from pymatgen.core import Composition, Structure, Element
11
  from pymatgen.core.composition import Composition
12
  from pymatgen.entries.computed_entries import (
13
  ComputedStructureEntry,
 
22
  "compatible_scan",
23
  ]
24
 
25
+ polars_dfs = {
26
+ subset: pl.read_parquet(
27
+ "hf://datasets/LeMaterial/LeMat1/{}/train-*.parquet".format(subset),
28
+ storage_options={
29
+ "token": HF_TOKEN,
30
+ },
31
+ )
32
+ for subset in subsets
33
+ }
34
+
35
  # Load only the train split of the dataset
36
 
37
+ # datasets = []
38
+ # for subset in subsets:
39
+ # dataset = load_dataset(
40
+ # "LeMaterial/leMat-Bulk",
41
+ # subset,
42
+ # token=HF_TOKEN,
43
+ # columns=[
44
+ # "lattice_vectors",
45
+ # "species_at_sites",
46
+ # "cartesian_site_positions",
47
+ # "energy",
48
+ # "energy_corrected",
49
+ # "immutable_id",
50
+ # "elements",
51
+ # "functional",
52
+ # ],
53
+ # )
54
+ # datasets.append(dataset["train"])
55
 
56
  # Convert the train split to a pandas DataFrame
57
  # df = pd.concat([x.to_pandas() for x in datasets])
 
60
 
61
  dataset = concatenate_datasets(datasets)
62
 
63
+ # dataset_element_combination_dict = {}
64
+
65
+ # isubset = lambda x: set(x).issubset(element_list)
66
+ # isintersection = lambda x: len(set(x).intersection(element_list)) > 0
67
+ # for element_1 in Element:
68
+ # for element_2 in Element:
69
+ # for element_3 in Element:
70
+ # if element_1 != element_2 and element_2 != element_3 and element_3 != element_1:
71
+ # print("processing {},{},{}".format(*element_list))
72
+ # element_list = [element_1.name, element_2.name, element_3.name]
73
+ # dataset_element_combination_dict(sorted(tuple(element_list))) = dataset.filter(
74
+ # lambda example: isintersection(example["elements"])
75
+ # and isubset(example["elements"])
76
+ # )
77
+
78
 
79
  def create_phase_diagram(
80
  elements,
 
90
 
91
  # Filter entries based on functional
92
  if functional == "PBE":
93
+ df = polars_dfs["compatible_pbe"].clone()
94
  # entries_df = train_df[train_df["functional"] == "pbe"]
95
  elif functional == "PBESol":
96
+ df = polars_dfs["compatible_pbesol"].clone()
97
  # entries_df = train_df[train_df["functional"] == "pbesol"]
98
  elif functional == "SCAN":
99
+ df = polars_dfs["compatible_scan"].clone()
100
  # entries_df = train_df[train_df["functional"] == "scan"]
101
 
102
+ # entries_df = df.to_pandas()
103
+
104
+ # isubset = lambda x: set(x).issubset(element_list)
105
+ # isintersection = lambda x: len(set(x).intersection(element_list)) > 0
106
+ # entries_df = entries_df[entries_df["elements"]](
107
+ # lambda example: isintersection(example["elements"])
108
+ # and isubset(example["elements"])
109
+ # )
110
+
111
+ df = df.filter((df.col("elements").list.contains(x) for x in element_list))
112
+ df = df.filter(
113
+ pl.col("elements")
114
+ .list.eval(pl.element().is_in(element_list))
115
+ .list.any()
116
+ .alias("check")
117
  )
118
 
119
+ entries_df = df.to_pandas()
120
 
121
  # Fetch all entries from the Materials Project database
122
  entries = [