Spaces:
Sleeping
Sleeping
import base64
import json
import logging
import os
import random

import earthview as ev
import gradio as gr
import numpy as np
from PIL import Image
# --- Configuration ---
# EarthView subset that both tabs stream samples from.
DATASET_SUBSET = "satellogic"
# JSON file where labels are persisted between runs.
LABELED_DATA_FILE = "labeled_data.json"
# Not referenced by the code below; presumably a target label count (adjustable).
NUM_SAMPLES_TO_LABEL = 100
# Number of "cool" examples shown next to each new image on the Display tab.
DISPLAY_N_COOL = 5
# Passed as the only entry of ``shards=`` below, so it presumably selects which
# shard of the dataset the demo draws from (despite the name, not an RNG seed).
SAMPLE_SEED = 10

# --- Load Dataset ---
dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
# One iterator consumed (via next()) by both the labeling and display tabs.
data_iter = iter(dataset)
# --- Load Labeled Data (if it exists) ---
def load_labeled_data(path=None):
    """Load previously saved label records from disk.

    Args:
        path: JSON file to read. Defaults to ``LABELED_DATA_FILE``.

    Returns:
        The list of label records. Any ``"image"`` value stored as base64 text
        is decoded back to raw bytes. Returns ``[]`` when the file does not
        exist or does not contain valid JSON.
    """
    if path is None:
        path = LABELED_DATA_FILE
    try:
        with open(path, "r", encoding="utf-8") as f:
            records = json.load(f)
    except FileNotFoundError:
        return []
    except json.JSONDecodeError:
        # A save that fails mid-write can leave a truncated file behind;
        # start fresh instead of crashing the whole app at startup.
        logging.getLogger(__name__).warning(
            "Ignoring unreadable labeled data file %s", path
        )
        return []
    for record in records:
        image = record.get("image")
        if isinstance(image, str):
            # JSON has no bytes type, so image bytes are persisted as base64 text.
            record["image"] = base64.b64decode(image)
    return records
# In-memory list of label records, loaded once at startup. The save handler
# appends to it and the display tab reads the "cool" entries from it.
labeled_data = load_labeled_data()
# --- Get Next Sample for Labeling ---
def get_next_sample():
    """Pull the next item from the shared dataset iterator.

    Returns:
        ``(image, metadata, count)``: the first entry of the sample's
        ``"rgb"`` images, the sample's metadata, and the number of labels
        recorded so far. When the iterator is exhausted, ``image`` and
        ``metadata`` are ``None`` and ``count`` is still the current label
        count.
    """
    # Only next() should signal exhaustion. A StopIteration escaping from
    # item_to_images would be a real error, not "dataset finished".
    try:
        raw_item = next(data_iter)
    except StopIteration:
        return None, None, len(labeled_data)
    sample = ev.item_to_images(DATASET_SUBSET, raw_item)
    return sample["rgb"][0], sample["metadata"], len(labeled_data)
# --- Save Labeled Data ---
def _write_labeled_data(records):
    """Persist *records* to ``LABELED_DATA_FILE`` as JSON.

    JSON cannot represent bytes, so raw image bytes are written as base64 text.
    The data goes to a temporary sibling file first and is then moved into
    place, so a failed write cannot truncate previously saved labels.
    """
    serializable = []
    for record in records:
        image = record.get("image")
        if isinstance(image, (bytes, bytearray)):
            record = {**record, "image": base64.b64encode(image).decode("ascii")}
        serializable.append(record)
    tmp_path = LABELED_DATA_FILE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(serializable, f)
    os.replace(tmp_path, LABELED_DATA_FILE)


def save_labeled_data(image, metadata, label):
    """Record *label* for the current image, persist all labels, and advance.

    Args:
        image: The image currently shown (a PIL image), or ``None`` if nothing
            is loaded.
        metadata: Text from the metadata textbox (the sample's stringified bounds).
        label: The label to record, e.g. ``"cool"`` or ``"not cool"``.

    Returns:
        ``(status, next_image, next_bounds_text, label_count_text)`` for the
        labeling tab's outputs.
    """
    if image is None:
        # Nothing is loaded (e.g. the dataset ran out), so there is nothing to label.
        return "No image to label", None, "", f"Labeled {len(labeled_data)} samples."

    rgb = image.convert("RGB")
    labeled_data.append({
        "image": rgb.tobytes(),  # raw RGB bytes in memory; base64 only on disk
        "metadata": metadata,
        "label": label,
        "size": list(rgb.size),  # needed to rebuild the image from raw bytes
    })
    _write_labeled_data(labeled_data)

    next_image, next_metadata, _ = get_next_sample()
    count_text = f"Labeled {len(labeled_data)} samples."
    if next_image is None:
        return "No more samples", None, "", count_text
    return "", next_image, str(next_metadata["bounds"]), count_text
# --- Gradio Interface ---
# --- Labeling UI ---
def labeling_ui():
    """Build the "Labeling" tab: show one sample at a time and record labels.

    Must be called inside an active ``gr.Blocks`` context.
    """
    # Fetch the first sample before building the components so they can be
    # created with it as their value. Assigning ``.value`` after construction
    # skips Gradio's postprocessing of the PIL image.
    first_image, first_metadata, _ = get_next_sample()
    first_bounds = str(first_metadata["bounds"]) if first_image is not None else ""

    with gr.Row():
        with gr.Column():
            image_component = gr.Image(label="Satellite Image", type="pil", value=first_image)
            metadata_text = gr.Textbox(label="Metadata (Bounds)", value=first_bounds)
            label_count_text = gr.Textbox(
                label="Label Count", value=f"Labeled {len(labeled_data)} samples."
            )
    with gr.Row():
        cool_button = gr.Button("Cool")
        not_cool_button = gr.Button("Not Cool")

    # One status box shared by both buttons. Creating it inside each
    # .click() call would render two separate "Debug" boxes.
    debug_text = gr.Textbox(label="Debug")
    inputs = [image_component, metadata_text]
    outputs = [debug_text, image_component, metadata_text, label_count_text]

    # Handle button clicks
    cool_button.click(
        fn=lambda image, metadata: save_labeled_data(image, metadata, "cool"),
        inputs=inputs,
        outputs=outputs,
    )
    not_cool_button.click(
        fn=lambda image, metadata: save_labeled_data(image, metadata, "not cool"),
        inputs=inputs,
        outputs=outputs,
    )
# --- Display UI ---
def display_ui():
    """Build the "Display" tab: a new unlabeled image next to random "cool" examples.

    Must be called inside an active ``gr.Blocks`` context.
    """

    def record_to_image(record):
        """Rebuild a PIL image from a stored label record's raw RGB bytes."""
        data = record["image"]
        if isinstance(data, str):
            # Records read straight from disk may still hold base64 text.
            data = base64.b64decode(data)
        # Records without a stored size fall back to the previously hard-coded 384x384.
        size = tuple(record.get("size", (384, 384)))
        return Image.frombytes("RGB", size, data)

    def get_random_cool_images(n):
        """Return up to *n* randomly chosen images labeled "cool"."""
        cool_samples = [d for d in labeled_data if d["label"] == "cool"]
        chosen = random.sample(cool_samples, min(n, len(cool_samples)))
        return [record_to_image(s) for s in chosen]

    def refresh_display():
        """Advance to the next sample; return (status, image, bounds text, gallery)."""
        new_image, new_metadata, _ = get_next_sample()
        if new_image is None:
            return "No more samples", None, "", []
        cool_images = get_random_cool_images(DISPLAY_N_COOL)
        return "", new_image, str(new_metadata["bounds"]), cool_images

    # Compute the initial contents first so the components are created with
    # them as values. Assigning ``.value`` after construction skips Gradio's
    # postprocessing of the PIL images.
    _, first_image, first_bounds, first_gallery = refresh_display()

    with gr.Row():
        new_image_component = gr.Image(label="New Image", type="pil", value=first_image)
        metadata_display = gr.Textbox(label="Metadata (Bounds)", value=first_bounds)
    with gr.Row():
        cool_images_gallery = gr.Gallery(label="Cool Examples", value=first_gallery)
        cool_images_gallery.style(grid=DISPLAY_N_COOL)
    with gr.Row():
        refresh_button = gr.Button("Refresh")
    debug_text = gr.Textbox(label="Debug")
    refresh_button.click(
        fn=refresh_display,
        inputs=[],
        outputs=[debug_text, new_image_component, metadata_display, cool_images_gallery],
    )
# --- Main Interface ---
# Each tab is built by calling its builder inside the matching TabItem.
_TAB_BUILDERS = (
    ("Labeling", labeling_ui),
    ("Display", display_ui),
)

with gr.Blocks() as demo:
    gr.Markdown("# TerraNomaly")
    with gr.Tabs():
        for _tab_title, _build_tab in _TAB_BUILDERS:
            with gr.TabItem(_tab_title):
                _build_tab()

demo.launch(debug=True)