LTMeyer committed on
Commit
d8efd6b
·
0 Parent(s):

Add visualization application

Browse files
Files changed (4) hide show
  1. README.md +14 -0
  2. app.py +239 -0
  3. assets +1 -0
  4. requirements.txt +5 -0
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: TheWell
3
+ emoji: 🌍
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: streamlit
7
+ sdk_version: 1.40.2
8
+ app_file: app.py
9
+ pinned: false
10
+ license: bsd-3-clause-clear
11
+ short_description: Visualization of data from the Well
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Script based on streamlit to visualize data from the Well that are hosted on Hugging Face hub.
2
+
3
+ Any time the state changes (due to UI interaction and callbacks),
4
+ the script is evaluated again.
5
+ Based on the state attributes, some UI components are rendered
6
+ (e.g. slider for field time step).
7
+
8
+ """
9
+
10
+ import pathlib
11
+
12
+ import fsspec
13
+ import h5py
14
+ import numpy as np
15
+ import pyvista as pv
16
+ import streamlit as st
17
+ from stpyvista.trame_backend import stpyvista
18
+
19
+ # Datasets whose data will be visualized
20
+ DATASET_NAMES = [
21
+ "acoustic_scattering_inclusions",
22
+ "active_matter",
23
+ "helmholtz_staircase",
24
+ "MHD_64",
25
+ "shear_flow",
26
+ ]
27
+
28
+ DIM_SUFFIXES = ["x", "y", "z"]
29
+
30
+ # Options for HDF5 cloud optimized reads
31
+ IO_PARAMS = {
32
+ "fsspec_params": {
33
+ # "skip_instance_cache": True
34
+ "cache_type": "blockcache", # or "first" with enough space
35
+ "block_size": 2 * 1024 * 1024, # could be bigger
36
+ },
37
+ "h5py_params": {
38
+ "driver_kwds": { # only recent versions of xarray and h5netcdf allow this correctly
39
+ "page_buf_size": 2 * 1024 * 1024, # this one only works in repacked files
40
+ "rdcc_nbytes": 2 * 1024 * 1024, # this one is to read the chunks
41
+ }
42
+ },
43
+ }
44
+
45
+
46
+ # Instantiate streamlit state attributes
47
+ for key in ["file", "files", "field_names", "spatial_dim", "data"]:
48
+ if key not in st.session_state:
49
+ st.session_state[key] = None
50
+
51
+
52
+ def reset_state(key: str):
53
+ if key in st.session_state:
54
+ st.session_state[key] = None
55
+ del st.session_state[key]
56
+
57
+
58
+ @st.cache_data
59
+ def get_dataset_path(dataset_name: str) -> str:
60
+ """Compose the path to the dataset on HF hub."""
61
+ repo_id = "polymathic-ai"
62
+ dataset_path = f"hf://datasets/{repo_id}/{dataset_name}"
63
+ return dataset_path
64
+
65
+
66
+ @st.cache_data
67
+ def get_dataset_files(dataset_name: str):
68
+ """Get the list of files in the dataset."""
69
+ dataset_path = get_dataset_path(dataset_name)
70
+ fs, _ = fsspec.url_to_fs(dataset_path)
71
+ dataset_files = fs.glob(f"{dataset_path}/**/*.hdf5")
72
+ return dataset_files
73
+
74
+
75
+ @st.cache_data
76
+ def get_dataset_info(file_path: str) -> tuple([int, list[str]]):
77
+ """Retrieve spatial dimension and field names from the dataset."""
78
+ file_path = f"hf://{file_path}"
79
+ with fsspec.open(file_path, "rb") as f, h5py.File(f, "r") as file:
80
+ spatial_dim = file.attrs["n_spatial_dims"]
81
+ field_names = []
82
+ for field in file["t0_fields"].keys():
83
+ field_names.append((field, "t0_fields"))
84
+ for field in file["t1_fields"].keys():
85
+ for _, dim_suffix in zip(range(spatial_dim), DIM_SUFFIXES):
86
+ field_names.append((f"{field}_{dim_suffix}", "t1_fields"))
87
+
88
+ return spatial_dim, field_names
89
+
90
+
91
+ def dataset_info_callback():
92
+ dataset_name = st.session_state.name
93
+ dataset_files = get_dataset_files(dataset_name)
94
+ st.session_state.files = dataset_files
95
+ spatial_dim, field_names = get_dataset_info(dataset_files[0])
96
+ st.session_state.spatial_dim = spatial_dim
97
+ st.session_state.field_names = field_names
98
+ # Field data for previous dataset must be cleared
99
+ reset_state(key="data")
100
+
101
+
102
+ @st.cache_data
103
+ def get_field(file_path: str, field: tuple[str, str], spatial_dim: int) -> np.ndarray:
104
+ """Load the first trajectory of a field in a given file."""
105
+ file_path = f"hf://{file_path}"
106
+ field_name, field_tensor_order = field
107
+ if field_tensor_order == "t1_fields":
108
+ field_name_splits = field_name.split("_")
109
+ dim_suffix = field_name_splits[-1]
110
+ dim_index = DIM_SUFFIXES.index(dim_suffix)
111
+ field_name = "_".join(field_name_splits[:-1])
112
+ else:
113
+ dim_index = None
114
+ with (
115
+ fsspec.open(file_path, "rb", **IO_PARAMS["fsspec_params"]) as f,
116
+ h5py.File(f, "r", **IO_PARAMS["h5py_params"]) as file,
117
+ ):
118
+ # Get the first trajectory of the file
119
+ # For tensor of order 1 take the relevant spatial dimension
120
+ if dim_index is not None:
121
+ take_indices = (0, ..., dim_index)
122
+ else:
123
+ take_indices = 0
124
+ field_data = np.array(file[field_tensor_order][field_name][take_indices])
125
+
126
+ return field_data
127
+
128
+
129
+ def field_callback():
130
+ """Callback to retrieve field data given file and field name state."""
131
+ file = st.session_state.get("file", None)
132
+ if file:
133
+ field = st.session_state.field
134
+ spatial_dim = st.session_state.spatial_dim
135
+ field_data = get_field(file, field, spatial_dim)
136
+ st.session_state.data = field_data
137
+ # The field is constant
138
+ if st.session_state.data.ndim <= 2:
139
+ reset_state(key="time_step")
140
+
141
+
142
+ def create_plotter() -> pv.Plotter:
143
+ """Create a pyvista.Plotter of the field in state."""
144
+ # Check whether the field is dynamic
145
+ # to account for time in spatial dimension retrieval
146
+ time_step = st.session_state.get("time_step", None)
147
+ position_offset = 0 if time_step is None else 1
148
+ # Create 2D or 3D grid
149
+ spatial_dim = st.session_state.spatial_dim
150
+ if spatial_dim == 2:
151
+ nx, ny = st.session_state.data.shape[position_offset:]
152
+ xrng = np.arange(0, nx)
153
+ yrng = np.arange(0, ny)
154
+ grid = pv.RectilinearGrid(xrng, yrng)
155
+ elif spatial_dim == 3:
156
+ nx, ny, nz = st.session_state.data.shape[position_offset:]
157
+ xrng = np.arange(0, nx)
158
+ yrng = np.arange(0, ny)
159
+ zrng = np.arange(0, nz)
160
+ grid = pv.RectilinearGrid(xrng, yrng, zrng)
161
+ # Set the grid scalar field
162
+ # If no time step is set the field is assumed to be constant
163
+ field_name = st.session_state.field[0]
164
+ if time_step is None:
165
+ grid[field_name] = st.session_state.data.ravel()
166
+ else:
167
+ grid[field_name] = st.session_state.data[time_step].ravel()
168
+
169
+ plotter = pv.Plotter(window_size=[400, 400])
170
+ plotter.add_mesh(grid, scalars=field_name)
171
+ if spatial_dim == 2:
172
+ plotter.view_xy()
173
+ elif spatial_dim == 3:
174
+ plotter.view_isometric()
175
+ plotter.background_color = "white"
176
+ return plotter
177
+
178
+
179
+ st.set_page_config(
180
+ page_title="Tap into the Well", page_icon="assets/the_well_color_icon.svg"
181
+ )
182
+ st.image("assets/the_well_logo.png")
183
+ st.markdown("""
184
+ [The Well](https://openreview.net/pdf?id=00Sx577BT3) is a collection of 15TB datasets of physics simulations.
185
+
186
+ This space allows you to tap into the Well by visualizing different datasets hosted on the [Hugging Face Hub](https://huggingface.co/polymathic-ai).
187
+ - Select a dataset
188
+ - Select a field
189
+ - Select a file
190
+ - Visualize different time steps
191
+
192
+ For field corresponding of higher tensor order (e.g. velocity) loading the data may be slow.
193
+ For this reason, we recommend downloading the data to work on the Well.
194
+ Check the [documentation](the-well.polymathic-ai.org) for more information.
195
+
196
+ """)
197
+ # The order of the following widgets matters
198
+ # Field data is updated whenever a file or a field is selected
199
+
200
+ # Dataset selection
201
+ dataset = st.selectbox(
202
+ "Select a Dataset",
203
+ options=DATASET_NAMES,
204
+ index=None,
205
+ key="name",
206
+ on_change=dataset_info_callback,
207
+ )
208
+
209
+ # Field and file selection
210
+ if st.session_state.name:
211
+ field_selector = st.selectbox(
212
+ "Select a field",
213
+ key="field",
214
+ options=st.session_state.field_names,
215
+ format_func=lambda option: option[0], # Fields are (name, tensor_order)
216
+ on_change=field_callback,
217
+ )
218
+ file_selector = st.selectbox(
219
+ "Select a file",
220
+ options=st.session_state.files,
221
+ key="file",
222
+ index=None,
223
+ format_func=lambda option: pathlib.Path(option).name,
224
+ on_change=field_callback,
225
+ )
226
+ if st.session_state.data is not None:
227
+ # Add a time step slider for dynamic fields
228
+ if st.session_state.data.ndim > 2:
229
+ time_step_slider = st.slider(
230
+ "Time step",
231
+ min_value=0,
232
+ value=0,
233
+ max_value=st.session_state.data.shape[0] - 1,
234
+ key="time_step",
235
+ )
236
+
237
+ if st.session_state.data is not None:
238
+ plotter = create_plotter()
239
+ stpyvista(plotter)
assets ADDED
@@ -0,0 +1 @@
 
 
1
+ ../../assets
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fsspec
2
+ h5py >= 3.11.0
3
+ numpy >= 1.26.4
4
+ pyvista >= 0.44.1
5
+ stpyvista >= 0.1.4