Spaces:
Sleeping
Sleeping
File size: 4,963 Bytes
465c443 63711ea 465c443 63711ea 465c443 63711ea 465c443 63711ea 465c443 63711ea d3015aa 63711ea 465c443 63711ea 465c443 63711ea d3015aa 63711ea 465c443 63711ea 465c443 4f33c08 63711ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import base64
import json
import os
import random

import earthview as ev
import gradio as gr
import numpy as np
from PIL import Image
# --- Configuration ---
DATASET_SUBSET = "satellogic"  # earthview subset name, passed to ev.load_dataset and ev.item_to_images
# NOTE(review): not referenced by the code visible here; labeling simply runs
# until the dataset iterator is exhausted.
NUM_SAMPLES_TO_LABEL = 100 # You can adjust this
LABELED_DATA_FILE = "labeled_data.json"  # JSON file that label records are written to and read from
DISPLAY_N_COOL = 5 # How many "cool" examples to display alongside each new image.
# NOTE(review): despite its name, this value is passed as the only entry of
# `shards=` below, so it appears to select which dataset shard is loaded
# rather than seed a random generator - confirm against the earthview API.
SAMPLE_SEED = 10
# --- Load Dataset ---
# Executed at import time. The single iterator below is shared by the
# labeling tab and the display tab, so a sample consumed by one tab is never
# shown by the other.
dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
data_iter = iter(dataset)
# --- Load Labeled Data (if it exists) ---
def _record_to_json(record):
    """Return a JSON-serializable copy of an in-memory label record.

    In memory the "image" field holds raw RGB bytes (the form
    ``Image.frombytes`` consumes). JSON cannot represent bytes, so on disk the
    field is stored as base64 text instead.
    """
    encoded = dict(record)
    if isinstance(encoded.get("image"), (bytes, bytearray)):
        encoded["image"] = base64.b64encode(encoded["image"]).decode("ascii")
    return encoded


def _record_from_json(record):
    """Inverse of ``_record_to_json``: turn a base64 "image" string back into bytes."""
    decoded = dict(record)
    if isinstance(decoded.get("image"), str):
        decoded["image"] = base64.b64decode(decoded["image"])
    return decoded


def load_labeled_data():
    """Load previously saved label records from ``LABELED_DATA_FILE``.

    Returns:
        A list of record dicts whose "image" field is raw bytes, or ``[]``
        when the file does not exist. A file that is not valid JSON (older
        versions of this app left truncated files behind when saving failed)
        is moved aside to ``<file>.corrupt`` and an empty list is returned, so
        the app can still start.
    """
    if not os.path.exists(LABELED_DATA_FILE):
        return []
    try:
        with open(LABELED_DATA_FILE, "r", encoding="utf-8") as f:
            raw_records = json.load(f)
    except json.JSONDecodeError as err:
        backup_path = LABELED_DATA_FILE + ".corrupt"
        os.replace(LABELED_DATA_FILE, backup_path)
        print(
            f"Warning: {LABELED_DATA_FILE} is not valid JSON ({err}); "
            f"moved it to {backup_path} and starting with no labels."
        )
        return []
    return [_record_from_json(record) for record in raw_records]


labeled_data = load_labeled_data()


# --- Get Next Sample for Labeling ---
def get_next_sample():
    """Pull the next sample from the shared dataset iterator.

    Returns:
        ``(image, metadata, count)``: *image* is ``sample["rgb"][0]`` after
        ``ev.item_to_images``, *metadata* is ``sample["metadata"]``, and
        *count* is the number of records labeled so far. Returns
        ``(None, None, None)`` once the iterator is exhausted.
    """
    try:
        sample = next(data_iter)
    except StopIteration:
        return None, None, None
    sample = ev.item_to_images(DATASET_SUBSET, sample)
    return sample["rgb"][0], sample["metadata"], len(labeled_data)


# --- Save Labeled Data ---
def _write_labeled_data():
    """Rewrite ``LABELED_DATA_FILE`` with every record in ``labeled_data``.

    The data goes to a temporary file that is then renamed over the target,
    so a failure mid-write can never leave a truncated JSON file behind.
    """
    tmp_path = LABELED_DATA_FILE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump([_record_to_json(record) for record in labeled_data], f)
    os.replace(tmp_path, LABELED_DATA_FILE)


def save_labeled_data(image, metadata, label):
    """Record *label* for *image*, persist all labels, then load the next sample.

    Args:
        image: PIL image currently shown in the labeling UI; ``None`` when
            nothing is displayed (e.g. after the dataset ran out).
        metadata: value of the metadata textbox (the stringified bounds).
        label: label string to store, e.g. "cool" or "not cool".

    Returns:
        ``(debug_message, next_image, next_bounds, count_message)``, in the
        order of the labeling UI's output components.
    """
    if image is None:
        return "No image to label", None, "", f"Labeled {len(labeled_data)} samples."
    rgb_image = image.convert("RGB")
    labeled_data.append({
        "image": rgb_image.tobytes(),  # raw RGB bytes in memory; base64 on disk
        "size": list(rgb_image.size),  # [width, height], needed to rebuild the image
        "metadata": metadata,
        "label": label,
    })
    _write_labeled_data()

    next_image, next_metadata, count = get_next_sample()
    if next_image is None:
        # get_next_sample reports count=None when exhausted; use the real total.
        return "No more samples", None, "", f"Labeled {len(labeled_data)} samples."
    return "", next_image, str(next_metadata["bounds"]), f"Labeled {count} samples."
# --- Gradio Interface ---
# --- Labeling UI ---
def labeling_ui():
    """Build the labeling tab: show one sample and record a "cool"/"not cool" label.

    Must be called inside an active ``gr.Blocks`` context. Consumes one sample
    from the shared dataset iterator to populate the initial view. Each
    button stores a label via ``save_labeled_data`` and displays the next
    sample.
    """
    # Fetch the first sample before building the components so it can be
    # handed to their constructors; assigning ``.value`` after construction
    # bypasses Gradio's processing of initial values.
    image, metadata, _ = get_next_sample()
    initial_bounds = str(metadata["bounds"]) if image is not None else ""

    with gr.Row():
        with gr.Column():
            image_component = gr.Image(label="Satellite Image", type="pil", value=image)
            metadata_text = gr.Textbox(label="Metadata (Bounds)", value=initial_bounds)
            label_count_text = gr.Textbox(
                label="Label Count", value=f"Labeled {len(labeled_data)} samples."
            )
            with gr.Row():
                cool_button = gr.Button("Cool")
                not_cool_button = gr.Button("Not Cool")
            # One shared status box; building it inline in each ``outputs``
            # list would render a separate Debug box per button.
            debug_text = gr.Textbox(label="Debug")

    inputs = [image_component, metadata_text]
    outputs = [debug_text, image_component, metadata_text, label_count_text]

    def label_cool(image, metadata):
        return save_labeled_data(image, metadata, "cool")

    def label_not_cool(image, metadata):
        return save_labeled_data(image, metadata, "not cool")

    # Handle button clicks
    cool_button.click(fn=label_cool, inputs=inputs, outputs=outputs)
    not_cool_button.click(fn=label_not_cool, inputs=inputs, outputs=outputs)
# --- Display UI ---
def display_ui():
    """Build the display tab: a fresh dataset image next to random "cool" examples.

    Must be called inside an active ``gr.Blocks`` context. Images come from
    the same dataset iterator as the labeling tab, so every image shown here
    (including the one used to populate the initial view) is one the labeling
    tab will not offer.
    """

    def get_random_cool_images(n):
        """Return up to *n* randomly chosen images labeled "cool", as PIL images."""
        cool_samples = [d for d in labeled_data if d["label"] == "cool"]
        chosen = random.sample(cool_samples, min(n, len(cool_samples)))
        # Use the record's [width, height] "size" when present; otherwise fall
        # back to the 384x384 resolution this app originally assumed.
        return [
            Image.frombytes("RGB", tuple(s.get("size", (384, 384))), s["image"])
            for s in chosen
        ]

    def get_new_unlabeled_image():
        """Return ``(image, bounds_str)`` for the next sample, or ``(None, None)``."""
        image, metadata, _ = get_next_sample()
        if image is None:
            return None, None
        return image, str(metadata["bounds"])

    def refresh_display():
        """Return ``(debug, image, bounds, gallery)`` values for the display outputs."""
        new_image, new_bounds = get_new_unlabeled_image()
        if new_image is None:
            # Plain values clear the components; ``Component.update`` no
            # longer exists in Gradio 4.
            return "No more samples", None, "", []
        return "", new_image, new_bounds, get_random_cool_images(DISPLAY_N_COOL)

    # Initialize through the constructors rather than assigning ``.value``.
    debug, image, bounds, gallery = refresh_display()
    with gr.Row():
        new_image_component = gr.Image(label="New Image", type="pil", value=image)
        metadata_display = gr.Textbox(label="Metadata (Bounds)", value=bounds)
    with gr.Row():
        # ``columns`` replaces the removed ``.style(grid=...)`` API.
        cool_images_gallery = gr.Gallery(
            label="Cool Examples", value=gallery, columns=DISPLAY_N_COOL
        )
    with gr.Row():
        refresh_button = gr.Button("Refresh")
        debug_text = gr.Textbox(label="Debug", value=debug)
    refresh_button.click(
        fn=refresh_display,
        inputs=[],
        outputs=[debug_text, new_image_component, metadata_display, cool_images_gallery],
    )
# --- Main Interface ---
# Two tabs: one for labeling samples, one for browsing new images next to
# previously labeled "cool" examples.
with gr.Blocks() as demo:
    gr.Markdown("# TerraNomaly")
    with gr.Tabs():
        with gr.TabItem("Labeling"):
            labeling_ui()
        with gr.TabItem("Display"):
            display_ui()

# Start the server only when run as a script, so importing this module builds
# `demo` without launching it. (The stray trailing "|" that made this line a
# SyntaxError has been removed.)
if __name__ == "__main__":
    demo.launch(debug=True)