Spaces:
Running
Running
"""Script based on streamlit to visualize data from the Well that are hosted on Hugging Face hub. | |
Any time the state change (due to UI interaction and callbacks), | |
the script is evaluated again. | |
Based on the state attributes some UI component are rendered | |
(e.g. slider for field time step). | |
""" | |
import pathlib | |
import fsspec | |
import h5py | |
import numpy as np | |
import pyvista as pv | |
import streamlit as st | |
from stpyvista.trame_backend import stpyvista | |
# Dataset whose data will be visualized | |
DATASET_NAMES = [ | |
"acoustic_scattering_inclusions", | |
"active_matter", | |
"helmholtz_staircase", | |
"MHD_64", | |
"shear_flow", | |
] | |
DIM_SUFFIXES = ["x", "y", "z"] | |
# Options for HDF5 cloud optimized reads | |
IO_PARAMS = { | |
"fsspec_params": { | |
# "skip_instance_cache": True | |
"cache_type": "blockcache", # or "first" with enough space | |
"block_size": 2 * 1024 * 1024, # could be bigger | |
"token": st.secrets["HF_TOKEN"], | |
}, | |
"h5py_params": { | |
"driver_kwds": { # only recent versions of xarray and h5netcdf allow this correctly | |
"page_buf_size": 2 * 1024 * 1024, # this one only works in repacked files | |
"rdcc_nbytes": 2 * 1024 * 1024, # this one is to read the chunks | |
} | |
}, | |
} | |
# Instantiate streamlit state attributes | |
for key in ["file", "files", "field_names", "spatial_dim", "data"]: | |
if key not in st.session_state: | |
st.session_state[key] = None | |
def reset_state(key: str): | |
if key in st.session_state: | |
st.session_state[key] = None | |
del st.session_state[key] | |
def get_dataset_path(dataset_name: str) -> str: | |
"""Compose the path to the dataset on HF hub.""" | |
repo_id = "polymathic-ai" | |
dataset_path = f"hf://datasets/{repo_id}/{dataset_name}" | |
return dataset_path | |
def get_dataset_files(dataset_name: str): | |
"""Get the list of files in the dataset.""" | |
dataset_path = get_dataset_path(dataset_name) | |
fs, _ = fsspec.url_to_fs(dataset_path) | |
dataset_files = fs.glob(f"{dataset_path}/**/*.hdf5") | |
return dataset_files | |
def get_dataset_info(file_path: str) -> tuple([int, list[str]]): | |
"""Retrive spatial dimension and field names from the dataset.""" | |
file_path = f"hf://{file_path}" | |
with fsspec.open(file_path, "rb") as f, h5py.File(f, "r") as file: | |
spatial_dim = file.attrs["n_spatial_dims"] | |
field_names = [] | |
for field in file["t0_fields"].keys(): | |
field_names.append((field, "t0_fields")) | |
for field in file["t1_fields"].keys(): | |
for _, dim_suffix in zip(range(spatial_dim), DIM_SUFFIXES): | |
field_names.append((f"{field}_{dim_suffix}", "t1_fields")) | |
return spatial_dim, field_names | |
def dataset_info_callback(): | |
dataset_name = st.session_state.name | |
dataset_files = get_dataset_files(dataset_name) | |
st.session_state.files = dataset_files | |
spatial_dim, field_names = get_dataset_info(dataset_files[0]) | |
st.session_state.spatial_dim = spatial_dim | |
st.session_state.field_names = field_names | |
# Field data for previous dataset must be cleared | |
reset_state(key="data") | |
def get_field(file_path: str, field: tuple[str, str], spatial_dim: int) -> np.ndarray: | |
"""Load the first trajectory of a field in a given file.""" | |
file_path = f"hf://{file_path}" | |
field_name, field_tensor_order = field | |
if field_tensor_order == "t1_fields": | |
field_name_splits = field_name.split("_") | |
dim_suffix = field_name_splits[-1] | |
dim_index = DIM_SUFFIXES.index(dim_suffix) | |
field_name = "_".join(field_name_splits[:-1]) | |
else: | |
dim_index = None | |
with ( | |
fsspec.open(file_path, "rb", **IO_PARAMS["fsspec_params"]) as f, | |
h5py.File(f, "r", **IO_PARAMS["h5py_params"]) as file, | |
): | |
# Get the first trajectory of the file | |
# For tensor of order 1 take the relevant spatial dimension | |
if dim_index is not None: | |
take_indices = (0, ..., dim_index) | |
else: | |
take_indices = 0 | |
field_data = np.array(file[field_tensor_order][field_name][take_indices]) | |
return field_data | |
def field_callback(): | |
"""Callback to retrieve field data given file and field name state.""" | |
file = st.session_state.get("file", None) | |
if file: | |
field = st.session_state.field | |
spatial_dim = st.session_state.spatial_dim | |
field_data = get_field(file, field, spatial_dim) | |
st.session_state.data = field_data | |
# The field is constant | |
if st.session_state.data.ndim <= 2: | |
reset_state(key="time_step") | |
def create_plotter() -> pv.Plotter: | |
"""Create a pyvista.Plotter of the field in state.""" | |
# Check wether the field is dynamic | |
# to account for time in spatial dimension retrieval | |
time_step = st.session_state.get("time_step", None) | |
position_offset = 0 if time_step is None else 1 | |
# Create 2D or 3D grid | |
spatial_dim = st.session_state.spatial_dim | |
if spatial_dim == 2: | |
nx, ny = st.session_state.data.shape[position_offset:] | |
xrng = np.arange(0, nx) | |
yrng = np.arange(0, ny) | |
grid = pv.RectilinearGrid(xrng, yrng) | |
elif spatial_dim == 3: | |
nx, ny, nz = st.session_state.data.shape[position_offset:] | |
xrng = np.arange(0, nx) | |
yrng = np.arange(0, ny) | |
zrng = np.arange(0, nz) | |
grid = pv.RectilinearGrid(xrng, yrng, zrng) | |
# Set the grid scalar field | |
# If no time step is set the field is assumed to be constant | |
field_name = st.session_state.field[0] | |
if time_step is None: | |
grid[field_name] = st.session_state.data.ravel() | |
else: | |
grid[field_name] = st.session_state.data[time_step].ravel() | |
plotter = pv.Plotter(window_size=[400, 400]) | |
plotter.add_mesh(grid, scalars=field_name) | |
if spatial_dim == 2: | |
plotter.view_xy() | |
elif spatial_dim == 3: | |
plotter.view_isometric() | |
plotter.background_color = "white" | |
return plotter | |
st.set_page_config( | |
page_title="Tap into the Well", page_icon="assets/the_well_color_icon.svg" | |
) | |
st.image("assets/the_well_logo.png") | |
st.markdown(""" | |
[The Well](https://openreview.net/pdf?id=00Sx577BT3) is a collection of 15TB datasets of physics simulations. | |
This space allows you to tap into the Well by visualizing different datasets hosted on the [Hugging Face Hub](https://huggingface.co/polymathic-ai). | |
- Select a dataset | |
- Select a field | |
- Select a file | |
- Visualize different time steps | |
For field corresponding of higher tensor order (e.g. velocity) loading the data may be slow. | |
For this reason, we recommend downloading the data to work on the Well. | |
Check the [documentation](the-well.polymathic-ai.org) for more information. | |
""") | |
# The order of the following widget matters | |
# Field data is updated whenever a file or a field is selected | |
# Dataset selection | |
dataset = st.selectbox( | |
"Select a Dataset", | |
options=DATASET_NAMES, | |
index=None, | |
key="name", | |
on_change=dataset_info_callback, | |
) | |
# File selection | |
if st.session_state.name: | |
field_selector = st.selectbox( | |
"Select a field", | |
key="field", | |
options=st.session_state.field_names, | |
format_func=lambda option: option[0], # Fields are (name, tensor_order) | |
on_change=field_callback, | |
) | |
file_selector = st.selectbox( | |
"Select a file", | |
options=st.session_state.files, | |
key="file", | |
index=None, | |
format_func=lambda option: pathlib.Path(option).name, | |
on_change=field_callback, | |
) | |
if st.session_state.data is not None: | |
# Add a time step slider for dynamic fields | |
if st.session_state.data.ndim > 2: | |
time_step_slider = st.slider( | |
"Time step", | |
min_value=0, | |
value=0, | |
max_value=st.session_state.data.shape[0] - 1, | |
key="time_step", | |
) | |
if st.session_state.data is not None: | |
plotter = create_plotter() | |
stpyvista(plotter) | |