"""Gradio app for browsing samples of the satellogic/EarthView dataset.

Pick a subset ("satellogic", "sentinel_1" or "neon") and a shard; the app
streams that parquet shard from the Hugging Face Hub, converts the next
``batch_size`` samples to images for a gallery and shows their metadata in
a table.
"""

from datasets import load_dataset, get_dataset_config_names
from functools import partial
from pandas import DataFrame
from PIL import Image
import gradio as gr
import numpy as np
import tqdm
import json
import os

# Hugging Face Hub dataset repository to browse.
DATASET = "satellogic/EarthView"
# When True, the Hub is never contacted and random noise images are shown
# instead (useful for exercising the UI offline).
DEBUG = False

# Per-subset settings:
#   "shards": number of parquet shards; files are named
#             "<path>/<split>-<index:05d>-of-<shards:05d>.parquet" with
#             index running from 0 to shards - 1
#   "config": dataset config name, when it differs from the subset name
#   "path":   repository folder holding the shards, when it differs from
#             the subset name
sets = {
    "satellogic": {
        "shards": 3676,
    },
    "sentinel_1": {
        "shards": 1763,
    },
    "neon": {
        "config": "default",
        "shards": 607,
        "path": "data",
    },
}

# Image fields of a converted item (see item_to_images) that are shown in
# the gallery, in display order, per dataset config.
_GALLERY_FIELDS = {
    "satellogic": ("rgb", "1m"),
    "sentinel_1": ("10m",),
    "default": ("rgb", "chm", "1m"),
}


def open_dataset(dataset, set_name, split, batch_size, state, shard=-1):
    """Open a streaming iterator over a subset and load its first batch.

    Args:
        dataset: Hub repository id, e.g. ``DATASET``.
        set_name: key of ``sets`` ("satellogic", "sentinel_1" or "neon").
        split: split name; also the prefix of the shard file names.
        batch_size: number of samples to load for the first batch.
        state: per-session dict; ``"config"`` and ``"dsi"`` (the sample
            iterator) are stored in it for later ``get_images`` calls.
        shard: index of the shard to open, or -1 to stream the whole subset.

    Returns:
        ``(slider_update, gallery_images, metadata_table, state)``, matching
        the Gradio outputs ``[shard, gallery, table, state]``.
    """
    # Needed on both branches: previously the -1 branch left `config` unbound.
    config = sets[set_name].get("config", set_name)
    shards = sets[set_name]["shards"]
    # Gradio number widgets may hand over floats; file names need an int.
    shard = int(shard)

    if shard == -1:
        # Trick to open the whole dataset: without data_files every file of
        # the config is streamed.
        data_files = None
    else:
        path = sets[set_name].get("path", set_name)
        # Key by the requested split so split names other than "train" work.
        data_files = {split: [f"{path}/{split}-{shard:05d}-of-{shards:05d}.parquet"]}

    if DEBUG:
        # get_images never reads from this in DEBUG mode; it only has to be
        # a real iterator.
        dsi = iter(range(100))
    else:
        ds = load_dataset(
            dataset,
            config,
            split=split,
            cache_dir="dataset",
            data_files=data_files,
            streaming=True,
            token=os.environ.get("HF_TOKEN", None),
        )
        dsi = iter(ds)

    state["config"] = config
    state["dsi"] = dsi
    # Shard indices run from 0 to shards - 1, so that is the last valid value.
    last_shard = shards - 1
    return (
        gr.update(label=f"Shards (max {last_shard})", value=shard, maximum=last_shard),
        *get_images(batch_size, state),
        state,
    )


def _hyperspectral_to_rgb(image):
    """Collapse a band-first hyperspectral array into a uint8 (H, W, 3) array.

    This is a very arbitrary conversion from the 369 hyperspectral bands to
    RGB: it averages each third of the bands and assigns the result to one
    channel.
    """
    thirds = (image[:124], image[124:247], image[247:])
    return np.stack([part.mean(axis=0) for part in thirds], axis=-1).astype("uint8")


def item_to_images(config, item):
    """Convert one raw dataset sample into displayable PIL images.

    Every field except ``metadata`` is cast to a uint8 array. For the known
    configs, the image fields are then replaced by lists of ``PIL.Image``
    (one per entry along the field's first axis). ``metadata`` is decoded
    from JSON when it is stored as a string.

    Args:
        config: dataset config name ("satellogic", "sentinel_1" or "default").
        item: one sample as yielded by the streaming dataset.

    Returns:
        A new dict with the converted fields and the decoded ``metadata``.
    """
    metadata = item["metadata"]
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    item = {k: np.asarray(v).astype("uint8") for k, v in item.items() if k != "metadata"}
    item["metadata"] = metadata

    if config == "satellogic":
        # Channel-first images -> channel-last for PIL.
        item["rgb"] = [Image.fromarray(image.transpose(1, 2, 0)) for image in item["rgb"]]
        item["1m"] = [Image.fromarray(image[0, :, :]) for image in item["1m"]]
    elif config == "sentinel_1":
        # Mapping of V and H to RGB. May not be correct
        # https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels
        i10m = item["10m"]
        # Extra channel: channel 0 / channel 1, scaled by 256. Clip before the
        # cast, because NumPy leaves out-of-range float -> uint8 casts undefined.
        ratio = i10m[:, 0, :, :] / (i10m[:, 1, :, :] + 0.01) * 256
        ratio = np.clip(ratio, 0, 255).astype("uint8")
        i10m = np.concatenate((i10m, ratio[:, np.newaxis]), axis=1)
        item["10m"] = [Image.fromarray(image.transpose(1, 2, 0)) for image in i10m]
    elif config == "default":
        item["rgb"] = [Image.fromarray(image.transpose(1, 2, 0)) for image in item["rgb"]]
        item["chm"] = [Image.fromarray(image[0]) for image in item["chm"]]
        item["1m"] = [Image.fromarray(_hyperspectral_to_rgb(image)) for image in item["1m"]]
    return item


def get_images(batch_size, state):
    """Pull up to ``batch_size`` samples from ``state["dsi"]`` for display.

    Args:
        batch_size: number of samples to read.
        state: per-session dict filled by ``open_dataset``.

    Returns:
        ``(images, metadata_frame)``: a flat list of gallery images (every
        displayed field of every sample, see ``_GALLERY_FIELDS``) and a
        DataFrame with one metadata row per sample. Fewer samples are returned
        when the iterator is exhausted.
    """
    config = state["config"]
    images = []
    metadatas = []
    for _ in tqdm.trange(int(batch_size), desc="Getting images"):
        if DEBUG:
            # Random noise stand-in so the UI works without the Hub.
            images.append(np.random.randint(0, 255, (384, 384, 3)).astype("uint8"))
            metadatas.append({"bounds": [[1, 1, 4, 4]]})
            continue
        try:
            item = next(state["dsi"])
        except StopIteration:
            break
        item = item_to_images(config, item)
        for field in _GALLERY_FIELDS.get(config, ()):
            images.extend(item[field])
        metadatas.append(item["metadata"])
    return images, DataFrame(metadatas)


def update_shape(rows, columns):
    """Return a Gradio update that resizes the gallery grid."""
    return gr.update(rows=rows, columns=columns)


def new_state():
    """Create the per-session state holder (an initially empty dict)."""
    return gr.State({})


if __name__ == "__main__":
    with gr.Blocks(title="Dataset Explorer", fill_height=True) as demo:
        state = new_state()
        gr.Markdown(f"# Viewer for [{DATASET}](https://huggingface.co/datasets/{DATASET}) Dataset")

        # Created unrendered so that event handlers declared earlier can
        # reference them; they are placed in the layout later with .render().
        batch_size = gr.Number(10, label="Batch Size", render=False)
        shard = gr.Slider(label="Shard", minimum=0, maximum=10000, step=1, render=False)
        table = gr.DataFrame(render=False)
        # headers=["Index","TimeStamp","Bounds","CRS"],
        gallery = gr.Gallery(label=DATASET, interactive=False, columns=5, rows=2, render=False)

        with gr.Row():
            dataset = gr.Textbox(label="Dataset", value=DATASET, interactive=False)
            config = gr.Dropdown(choices=list(sets), label="Config", value="satellogic")
            split = gr.Textbox(label="Split", value="train")
            initial_shard = gr.Number(label="Initial shard", value=0, info="-1 for whole dataset")
            gr.Button("Load (minutes)").click(
                open_dataset,
                inputs=[dataset, config, split, batch_size, state, initial_shard],
                outputs=[shard, gallery, table, state],
            )

        gallery.render()

        with gr.Row():
            batch_size.render()
            rows = gr.Number(2, label="Rows")
            columns = gr.Number(5, label="Columns")
            rows.change(update_shape, [rows, columns], [gallery])
            columns.change(update_shape, [rows, columns], [gallery])

        with gr.Row():
            shard.render()
            # Reload only when the user lets go of the slider, not on every drag step.
            shard.release(
                open_dataset,
                inputs=[dataset, config, split, batch_size, state, shard],
                outputs=[shard, gallery, table, state],
            )
            btn = gr.Button("Next Batch (same shard)", scale=0)
            btn.click(get_images, [batch_size, state], [gallery, table])

        table.render()

    demo.launch(show_api=False)