# Hugging Face Space status (scraped page banner): Running on CPU Upgrade
from datasets import load_dataset, get_dataset_config_names | |
from functools import partial | |
from pandas import DataFrame | |
from PIL import Image | |
import gradio as gr | |
import numpy as np | |
import tqdm | |
import json | |
import os | |
DATASET = "satellogic/EarthView"
DEBUG = False

# Subsets of DATASET selectable in the viewer, keyed by the name shown in the
# "Config" dropdown.
#   shards: number of parquet shard files in the subset; used to build shard
#           file names and as the shard slider maximum.
#   config: dataset config name to load; open_dataset falls back to the key.
#   path:   directory holding the parquet shards; open_dataset falls back to
#           the key.
sets = dict(
    satellogic=dict(shards=3676),
    sentinel_1=dict(shards=1763),
    neon=dict(config="default", shards=607, path="data"),
)
def open_dataset(dataset, set_name, split, batch_size, state, shard=-1):
    """Open a streaming iterator over a dataset subset and fetch the first batch.

    Args:
        dataset: Hub dataset id passed to ``load_dataset``.
        set_name: key into ``sets`` selecting the subset (config, shard count,
            parquet directory).
        split: split name; also used as the parquet file-name prefix.
        batch_size: number of dataset items to fetch for the first batch.
        state: per-session dict; ``"config"`` and ``"dsi"`` (the item
            iterator) are written into it.
        shard: index of the single parquet shard to open, or -1 to stream the
            whole subset.

    Returns:
        ``(slider_update, images, metadata_frame, state)``: a ``gr.update`` for
        the shard slider followed by the outputs of ``get_images`` and the
        updated state.
    """
    # Needed by both branches and by get_images. It used to be assigned only
    # for a single shard, so shard == -1 raised UnboundLocalError.
    config = sets[set_name].get("config", set_name)
    # Gradio Number/Slider values may arrive as floats; ":05d" needs an int.
    shard = int(shard)
    if shard == -1:
        # Trick to open the whole dataset: let load_dataset resolve all files.
        data_files = None
        shards = 100  # only used as the slider maximum below
    else:
        shards = sets[set_name]["shards"]
        path = sets[set_name].get("path", set_name)
        # Key the files by the requested split so load_dataset(split=split)
        # finds them (a fixed "train" key broke every other split).
        data_files = {split: [f"{path}/{split}-{shard:05d}-of-{shards:05d}.parquet"]}
    if DEBUG:
        # Offline stand-in; get_images does not read from it in DEBUG mode.
        dsi = iter(range(100))
    else:
        ds = load_dataset(
            dataset,
            config,
            split=split,
            cache_dir="dataset",
            data_files=data_files,
            streaming=True,
            token=os.environ.get("HF_TOKEN", None),
        )
        dsi = iter(ds)
    state["config"] = config
    state["dsi"] = dsi
    return (
        gr.update(label=f"Shards (max {shards})", value=shard, maximum=shards),
        *get_images(batch_size, state),
        state,
    )
def item_to_images(config, item):
    """Convert one raw dataset row into PIL images for the gallery.

    Every field except ``"metadata"`` is converted to a ``uint8`` numpy array.
    The image fields the viewer shows for ``config`` are then replaced by
    lists of ``PIL.Image`` objects, one per entry along the field's first
    axis. ``"metadata"`` is JSON-decoded when it arrives as a string.

    Args:
        config: dataset config name; ``"satellogic"``, ``"sentinel_1"`` and
            ``"default"`` get image conversion, anything else is only cast.
        item: one row as yielded by the dataset iterator.

    Returns:
        A new dict holding the converted fields plus the decoded metadata.
    """
    metadata = item["metadata"]
    if isinstance(metadata, str):
        metadata = json.loads(metadata)
    # NOTE(review): a plain uint8 cast wraps values above 255; fine if the
    # stored bands are already 8-bit — confirm per subset.
    converted = {
        key: np.asarray(value).astype("uint8")
        for key, value in item.items()
        if key != "metadata"
    }
    converted["metadata"] = metadata

    if config == "satellogic":
        # Each rgb entry is channel-first (transpose(1, 2, 0) -> H, W, C).
        converted["rgb"] = [
            Image.fromarray(image.transpose(1, 2, 0)) for image in converted["rgb"]
        ]
        # Only the first channel of each 1m entry is shown, as grayscale.
        converted["1m"] = [Image.fromarray(image[0, :, :]) for image in converted["1m"]]
    elif config == "sentinel_1":
        converted["10m"] = [
            Image.fromarray(image.transpose(1, 2, 0))
            for image in _sentinel1_composite(converted["10m"])
        ]
    elif config == "default":
        converted["rgb"] = [
            Image.fromarray(image.transpose(1, 2, 0)) for image in converted["rgb"]
        ]
        converted["chm"] = [Image.fromarray(image[0]) for image in converted["chm"]]
        converted["1m"] = [
            Image.fromarray(_hyperspectral_to_rgb(image)) for image in converted["1m"]
        ]
    return converted


def _sentinel1_composite(i10m):
    """Append a band-0/band-1 ratio channel to a stack of 2-band images.

    Mapping of V and H to RGB. May not be correct:
    https://gis.stackexchange.com/questions/400726/creating-composite-rgb-images-from-sentinel-1-channels

    ``i10m`` is 4-D with bands on axis 1; the result has one extra band.
    """
    ratio = i10m[:, 0, :, :] / (i10m[:, 1, :, :] + 0.01) * 256
    # The scaled ratio easily exceeds 255. Casting out-of-range floats to
    # uint8 is undefined in NumPy, so saturate instead of letting it wrap.
    ratio = np.clip(ratio, 0, 255).astype("uint8")
    return np.concatenate((i10m, np.expand_dims(ratio, 1)), 1)


def _hyperspectral_to_rgb(image):
    """Very arbitrary hyperspectral -> RGB: average each third of the bands.

    Bands (axis 0) are split into three contiguous, near-equal groups. Each
    group's mean becomes one channel of an (H, W, 3) uint8 image. This works
    for any band count of at least three.
    """
    groups = np.array_split(image, 3, axis=0)
    return np.stack([np.average(group, 0) for group in groups], axis=2).astype("uint8")
def get_images(batch_size, state):
    """Pull up to ``batch_size`` items from the session iterator and render them.

    Args:
        batch_size: number of dataset items to consume. Each item may add
            several images, depending on the config.
        state: per-session dict filled by ``open_dataset`` with ``"config"``
            and ``"dsi"``.

    Returns:
        ``(images, DataFrame(metadatas))``: a flat list of gallery images and
        one metadata row per consumed item. Fewer items are returned when the
        iterator runs out. Both are empty if nothing has been loaded yet.
    """
    if "dsi" not in state:
        # e.g. "Next Batch" clicked before "Load": nothing to show yet.
        return [], DataFrame()
    config = state["config"]
    images = []
    metadatas = []
    # Gradio Number values may arrive as floats; trange/range needs an int.
    for _ in tqdm.trange(int(batch_size), desc="Getting images"):
        if DEBUG:
            # Random image plus dummy metadata, for offline UI testing.
            images.append(np.random.randint(0, 255, (384, 384, 3)).astype("uint8"))
            metadatas.append({"bounds": [[1, 1, 4, 4]]})
            continue
        try:
            item = next(state["dsi"])
        except StopIteration:
            break
        item = item_to_images(config, item)
        if config == "satellogic":
            images.extend(item["rgb"])
            images.extend(item["1m"])
        elif config == "sentinel_1":
            images.extend(item["10m"])
        elif config == "default":
            images.extend(item["rgb"])
            images.extend(item["chm"])
            images.extend(item["1m"])
        metadatas.append(item["metadata"])
    return images, DataFrame(metadatas)
def update_shape(rows, columns):
    """Build a Gradio update that resizes the gallery grid to rows x columns."""
    new_layout = dict(rows=rows, columns=columns)
    return gr.update(**new_layout)
def new_state():
    """Create a per-session ``gr.State`` whose value starts as an empty dict."""
    initial_value = {}
    return gr.State(initial_value)
if __name__ == "__main__":
    # Build and launch the viewer UI. Components created with render=False
    # are placed later via .render() so they can be referenced by earlier
    # event handlers.
    with gr.Blocks(title="Dataset Explorer", fill_height=True) as demo:
        state = new_state()
        gr.Markdown(
            f"# Viewer for [{DATASET}](https://huggingface.co/datasets/{DATASET}) Dataset"
        )
        # precision=0 makes Gradio deliver ints, which the callbacks need for
        # range() and the ":05d" shard file-name format.
        batch_size = gr.Number(10, label="Batch Size", precision=0, render=False)
        # minimum=-1 so the slider accepts the "whole dataset" value that
        # open_dataset echoes back when initial shard is -1.
        shard = gr.Slider(label="Shard", minimum=-1, maximum=10000, step=1, render=False)
        table = gr.DataFrame(render=False)
        gallery = gr.Gallery(
            label=DATASET,
            interactive=False,
            columns=5,
            rows=2,
            render=False,
        )

        with gr.Row():
            dataset = gr.Textbox(label="Dataset", value=DATASET, interactive=False)
            config = gr.Dropdown(choices=list(sets), label="Config", value="satellogic")
            split = gr.Textbox(label="Split", value="train")
            initial_shard = gr.Number(
                label="Initial shard", value=0, precision=0, info="-1 for whole dataset"
            )
            gr.Button("Load (minutes)").click(
                open_dataset,
                inputs=[dataset, config, split, batch_size, state, initial_shard],
                outputs=[shard, gallery, table, state],
            )

        gallery.render()

        with gr.Row():
            batch_size.render()
            rows = gr.Number(2, label="Rows", precision=0)
            columns = gr.Number(5, label="Columns", precision=0)
            rows.change(update_shape, [rows, columns], [gallery])
            columns.change(update_shape, [rows, columns], [gallery])

        with gr.Row():
            shard.render()
            # Re-open the dataset at the chosen shard when the slider is released.
            shard.release(
                open_dataset,
                inputs=[dataset, config, split, batch_size, state, shard],
                outputs=[shard, gallery, table, state],
            )
            btn = gr.Button("Next Batch (same shard)", scale=0)
            btn.click(get_images, [batch_size, state], [gallery, table])

        table.render()

    demo.launch(show_api=False)