Spaces:
Sleeping
Sleeping
import base64
import json
import os
import random

import earthview as ev
import gradio as gr
import numpy as np
from pandas import DataFrame
from PIL import Image

import utils
# --- Configuration ---
# Dataset subset name handed to `ev.load_dataset` and `ev.item_to_images`.
DATASET_SUBSET = "satellogic"
# Number of samples to label; you can adjust this.
NUM_SAMPLES_TO_LABEL = 100
# JSON file the label records are persisted to.
LABELED_DATA_FILE = "labeled_data.json"
# How many "cool" examples to display alongside each new image.
DISPLAY_N_COOL = 5
# Described as the seed used when sampling the dataset for the demo.
# NOTE(review): it is passed below as the single entry of `shards`, so it
# appears to select a shard rather than seed a RNG - confirm.
SAMPLE_SEED = 10
BATCH_SIZE = 10

# --- Load Dataset ---
dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
# One shared iterator: every consumer in this module advances the same stream.
data_iter = iter(dataset)
# --- Load Labeled Data (if it exists) --- | |
def load_labeled_data():
    """Return the label records stored in ``LABELED_DATA_FILE``.

    Returns:
        The decoded JSON content of the file, or an empty list when the file
        does not exist or does not contain valid JSON.
    """
    try:
        with open(LABELED_DATA_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Nothing has been labeled yet.
        return []
    except json.JSONDecodeError as err:
        # A write that failed midway leaves a truncated file behind; starting
        # from an empty list is better than refusing to start the app.
        print(f"Ignoring unreadable {LABELED_DATA_FILE}: {err}")
        return []


labeled_data = load_labeled_data()
# --- Get Next Sample for Labeling --- | |
def get_next_sample():
    """Advance the shared dataset iterator and return the next sample.

    The raw item is passed through ``ev.item_to_images`` before being
    returned. When the iterator is exhausted a notice is printed and
    ``None`` is returned instead.
    """
    try:
        return ev.item_to_images(DATASET_SUBSET, next(data_iter))
    except StopIteration:
        print("No more samples in the dataset.")
        return None
def get_images(batch_size, state):
    """Fetch up to ``batch_size`` samples and split them into images and metadata.

    Args:
        batch_size: Maximum number of samples to pull from the dataset.
        state: Mapping whose ``"subset"`` entry is forwarded to
            ``utils.get_google_map_link``.

    Returns:
        A ``(images, table)`` pair: the first ``"rgb"`` image of every sample
        and a ``DataFrame`` with one metadata row per sample. Each metadata
        dict is extended in place with a ``"map"`` HTML link. Fewer rows than
        requested are returned when the dataset runs out.
    """
    subset = state["subset"]
    collected = []
    for _ in range(batch_size):
        sample = get_next_sample()
        if sample is None:
            # Dataset exhausted: return whatever was gathered so far.
            break
        picture, info = sample["rgb"][0], sample["metadata"]
        info["map"] = f'<a href="{utils.get_google_map_link(sample, subset)}" target="_blank">🧭</a>'
        collected.append((picture, info))
    pictures = [picture for picture, _ in collected]
    rows = [info for _, info in collected]
    return pictures, DataFrame(rows)
# --- Save Labeled Data --- | |
def _encode_bytes_for_json(value):
    """``json`` ``default`` hook: emit bytes as base64 text, reject anything else."""
    if isinstance(value, (bytes, bytearray)):
        return base64.b64encode(value).decode("ascii")
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _write_labeled_data(records, path):
    """Dump *records* to a sibling temp file, then swap it in place of *path*."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(records, f, default=_encode_bytes_for_json)
    # Only replace the real file once the dump succeeded, so a failed
    # serialisation cannot truncate previously saved labels.
    os.replace(tmp_path, path)


def save_labeled_data(label, state):
    """Record ``label`` for the sample held in ``state`` and move to the next one.

    The record (raw RGB bytes, metadata, label) is appended to the module-level
    ``labeled_data`` list and the whole list is rewritten to
    ``LABELED_DATA_FILE``. Raw bytes cannot be represented in JSON, so on disk
    the image is stored as base64 text; the in-memory record keeps the bytes.

    Args:
        label: Label to attach to the current sample.
        state: Mapping with a ``"sample"`` entry; it is updated in place with
            the next sample (or ``None`` when the dataset is exhausted).

    Returns:
        A ``(message, image, table)`` tuple: a status string (empty on
        success), the next sample's first ``"rgb"`` image or ``None``, and a
        ``DataFrame`` holding the next sample's metadata (empty when there is
        no next sample).

    Raises:
        TypeError: If the record contains a value JSON cannot represent. The
            record is removed from ``labeled_data`` again in that case.
    """
    sample = state["sample"]
    if sample is None:
        return "No image to label", None, DataFrame()

    image = sample["rgb"][0]  # Get the PIL Image object
    labeled_data.append({
        "image": image.convert("RGB").tobytes(),
        "metadata": sample["metadata"],
        "label": label,
    })
    try:
        _write_labeled_data(labeled_data, LABELED_DATA_FILE)
    except Exception:
        # Keep memory and disk consistent: drop the record that failed to save.
        labeled_data.pop()
        raise

    new_sample = get_next_sample()
    if new_sample is None:
        state["sample"] = None
        return "Dataset exhausted.", None, DataFrame()

    state["sample"] = new_sample
    new_image = new_sample["rgb"][0]
    new_metadata = new_sample["metadata"]
    new_metadata["map"] = f'<a href="{utils.get_google_map_link(new_sample, DATASET_SUBSET)}" target="_blank">🧭</a>'
    return "", new_image, DataFrame([new_metadata])
# --- Gradio Interface --- | |
# --- Labeling UI --- | |
def labeling_ui():
    """Build the "Labeling" tab: one image, its metadata and two label buttons.

    Must be called inside an active ``gr.Blocks`` context. A single sample is
    pulled from the dataset up front; the same sample feeds the gallery, the
    metadata table and the session state, so the first label is applied to
    the image that is actually on screen.
    """
    initial_sample = get_next_sample()
    if initial_sample is None:
        initial_images = []
        initial_table = DataFrame()
    else:
        initial_metadata = initial_sample["metadata"]
        map_link = utils.get_google_map_link(initial_sample, DATASET_SUBSET)
        initial_metadata["map"] = f'<a href="{map_link}" target="_blank">🧭</a>'
        initial_images = [initial_sample["rgb"][0]]
        initial_table = DataFrame([initial_metadata])

    state = gr.State({"sample": initial_sample, "subset": DATASET_SUBSET})

    # NOTE(review): the nesting below is reconstructed; the table is placed
    # beside the image column - confirm against the intended layout.
    with gr.Row():
        with gr.Column():
            gallery = gr.Gallery(
                value=initial_images,
                label="Satellite Image",
                interactive=False,
                columns=1,
                object_fit="scale-down",
            )
            with gr.Row():
                cool_button = gr.Button("Cool")
                not_cool_button = gr.Button("Not Cool")
        table = gr.DataFrame(value=initial_table, datatype="html")
    debug_box = gr.Textbox(label="Debug")

    def make_label_handler(label):
        """Return a click handler that stores ``label`` for the current sample."""

        def handler(current_state):
            message, image, metadata = save_labeled_data(label, current_state)
            # A Gallery takes a list of images, not a bare image.
            if image is not None and not isinstance(image, list):
                image = [image]
            # Hand the mutated state back so the session keeps the new sample.
            return message, image, metadata, current_state

        return handler

    # Handle button clicks
    click_outputs = [debug_box, gallery, table, state]
    cool_button.click(
        fn=make_label_handler("cool"),
        inputs=[state],
        outputs=click_outputs,
    )
    not_cool_button.click(
        fn=make_label_handler("not cool"),
        inputs=[state],
        outputs=click_outputs,
    )
# --- Display UI --- | |
def display_ui():
    """Build the "Display" tab: a fresh unlabeled image next to "cool" examples.

    Must be called inside an active ``gr.Blocks`` context. One sample is
    pulled from the dataset while the tab is built to provide the initial
    content; the "Refresh" button pulls another one each time it is clicked.
    """
    # Label records hold raw RGB bytes without dimensions, so the size has to
    # be assumed. NOTE(review): confirm every stored image is 384x384.
    cool_image_size = (384, 384)

    def decode_cool_image(record):
        """Rebuild a PIL image from a label record."""
        raw = record["image"]
        if isinstance(raw, str):
            # JSON cannot carry raw bytes, so a record read back from disk
            # holds text; it is treated as base64.
            raw = base64.b64decode(raw)
        return Image.frombytes("RGB", cool_image_size, raw)

    def get_random_cool_images(n):
        """Return up to ``n`` randomly chosen images that were labeled "cool"."""
        cool_samples = [d for d in labeled_data if d["label"] == "cool"]
        chosen = random.sample(cool_samples, min(n, len(cool_samples)))
        return [decode_cool_image(s) for s in chosen]

    def get_new_unlabeled_image():
        """Return ``(image, bounds_json)`` for the next sample, or ``(None, None)``."""
        sample = get_next_sample()
        if sample is None:
            return None, None
        return sample["rgb"][0], json.dumps(sample["metadata"]["bounds"])

    def refresh_display():
        """Return new values for the debug box, image, metadata box and gallery."""
        new_image, new_metadata = get_new_unlabeled_image()
        if new_image is None:
            return "No more samples", None, "", []
        return "", new_image, new_metadata, get_random_cool_images(DISPLAY_N_COOL)

    # Compute the initial content first so it can be passed as `value=`.
    initial_debug, initial_image, initial_metadata, initial_gallery = refresh_display()

    with gr.Row():
        new_image_component = gr.Image(value=initial_image, label="New Image", type="pil")
        metadata_display = gr.Textbox(value=initial_metadata, label="Metadata (Bounds)")
    with gr.Row():
        cool_images_gallery = gr.Gallery(
            label="Cool Examples",
            value=initial_gallery,
            columns=DISPLAY_N_COOL,
        )
    refresh_button = gr.Button("Refresh")
    debug_box = gr.Textbox(value=initial_debug, label="Debug")

    refresh_button.click(
        fn=refresh_display,
        inputs=[],
        outputs=[debug_box, new_image_component, metadata_display, cool_images_gallery],
    )
# --- Main Interface --- | |
# `demo` stays at module level so tools that import this file can find it.
with gr.Blocks() as demo:
    gr.Markdown("# TerraNomaly")
    with gr.Tabs():
        with gr.TabItem("Labeling"):
            labeling_ui()
        with gr.TabItem("Display"):
            display_ui()

# Start the server only when the file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch(debug=True)