# TerraNomaly / app.py
# Last commit by dropbop: "Update app.py" (63711ea, verified)
# [Hugging Face file view: raw | history | blame | 4.96 kB]
import base64
import json
import os
import random

import gradio as gr
import numpy as np
from PIL import Image

import earthview as ev
# --- Configuration ---
DATASET_SUBSET = "satellogic"
# Intended number of samples to label. NOTE(review): not referenced anywhere in this file.
NUM_SAMPLES_TO_LABEL = 100
LABELED_DATA_FILE = "labeled_data.json"
DISPLAY_N_COOL = 5  # How many "cool" examples to display alongside each new image.
# Despite the name, this is not an RNG seed: it is passed to ev.load_dataset as the
# `shards` selection (presumably a shard index -- confirm against earthview).
SAMPLE_SEED = 10

# --- Load Dataset ---
dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
# One module-level iterator shared by the Labeling and Display tabs, so every
# sample is consumed exactly once across both.
data_iter = iter(dataset)
# --- Load Labeled Data (if it exists) ---
def load_labeled_data():
    """Load previously saved labels from ``LABELED_DATA_FILE``.

    Returns:
        A list of records with keys ``"image"`` (raw RGB pixel bytes), ``"metadata"``,
        ``"label"`` and, for records written by ``save_labeled_data``, ``"size"``.
        An empty list when the file does not exist.
    """
    if not os.path.exists(LABELED_DATA_FILE):
        return []
    with open(LABELED_DATA_FILE, "r", encoding="utf-8") as f:
        records = json.load(f)
    for record in records:
        # JSON cannot hold bytes, so images are stored base64-encoded on disk
        # (see _serialize_record); turn them back into the raw bytes used in memory.
        if isinstance(record.get("image"), str):
            record["image"] = base64.b64decode(record["image"])
    return records


labeled_data = load_labeled_data()


# --- Get Next Sample for Labeling ---
def get_next_sample():
    """Pull the next sample from the shared dataset iterator.

    Returns:
        ``(image, metadata, count)``: ``image`` is ``sample["rgb"][0]`` after
        ``ev.item_to_images``, ``metadata`` is ``sample["metadata"]`` and ``count``
        is the number of labels saved so far. When the dataset is exhausted,
        ``(None, None, count)`` is returned.
    """
    try:
        sample = next(data_iter)
    except StopIteration:
        return None, None, len(labeled_data)
    sample = ev.item_to_images(DATASET_SUBSET, sample)
    return sample["rgb"][0], sample["metadata"], len(labeled_data)


def _serialize_record(record):
    """Return a JSON-safe copy of ``record`` with its image bytes base64-encoded."""
    return {**record, "image": base64.b64encode(record["image"]).decode("ascii")}


# --- Save Labeled Data ---
def save_labeled_data(image, metadata, label):
    """Record ``label`` for the displayed image, persist all labels, then advance.

    Args:
        image: The PIL image currently shown, or ``None`` if nothing is displayed.
        metadata: Contents of the metadata textbox (the stringified bounds).
        label: The label to store (the UI passes ``"cool"`` or ``"not cool"``).

    Returns:
        ``(debug_message, next_image, next_bounds, count_message)`` for the Gradio outputs.
    """
    if image is None:
        # Nothing on screen (e.g. the dataset already ran out): nothing to save.
        return "No more samples", None, "", f"Labeled {len(labeled_data)} samples."

    rgb = image.convert("RGB")
    labeled_data.append({
        "image": rgb.tobytes(),  # raw RGB bytes, as Image.frombytes in the Display tab expects
        "size": list(rgb.size),  # (width, height) needed to rebuild the image from raw bytes
        "metadata": metadata,
        "label": label,
    })

    # Write to a temporary file and atomically swap it in, so a failure mid-write
    # cannot leave a truncated labels file that breaks the next startup.
    # NOTE(review): the whole file is rewritten on every click (O(n) per label).
    tmp_path = LABELED_DATA_FILE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump([_serialize_record(r) for r in labeled_data], f)
    os.replace(tmp_path, LABELED_DATA_FILE)

    next_image, next_metadata, count = get_next_sample()
    if next_image is None:
        return "No more samples", None, "", f"Labeled {len(labeled_data)} samples."
    return "", next_image, str(next_metadata["bounds"]), f"Labeled {count} samples."
# --- Gradio Interface ---
# --- Labeling UI ---
def labeling_ui():
    """Build the "Labeling" tab: one sample at a time with "Cool" / "Not Cool" buttons.

    Must be called inside an active ``gr.Blocks`` context. Each click stores the
    label via ``save_labeled_data`` and replaces the shown sample with the next one.
    """
    # Fetch the first sample before building the components so it can be passed
    # as their initial ``value``. Assigning ``.value`` afterwards would skip
    # Gradio's postprocessing of the PIL image.
    first_image, first_metadata, _ = get_next_sample()
    first_bounds = "" if first_image is None else str(first_metadata["bounds"])

    with gr.Row():
        with gr.Column():
            image_component = gr.Image(label="Satellite Image", type="pil", value=first_image)
            metadata_text = gr.Textbox(label="Metadata (Bounds)", value=first_bounds)
            label_count_text = gr.Textbox(
                label="Label Count", value=f"Labeled {len(labeled_data)} samples."
            )
            with gr.Row():
                cool_button = gr.Button("Cool")
                not_cool_button = gr.Button("Not Cool")
            # A single status box shared by both buttons.
            debug_text = gr.Textbox(label="Debug")

    # Handle button clicks
    inputs = [image_component, metadata_text]
    outputs = [debug_text, image_component, metadata_text, label_count_text]
    cool_button.click(
        fn=lambda image, metadata: save_labeled_data(image, metadata, "cool"),
        inputs=inputs,
        outputs=outputs,
    )
    not_cool_button.click(
        fn=lambda image, metadata: save_labeled_data(image, metadata, "not cool"),
        inputs=inputs,
        outputs=outputs,
    )
# --- Display UI ---
def display_ui():
    """Build the "Display" tab: a fresh sample next to a gallery of "cool" images.

    Must be called inside an active ``gr.Blocks`` context. Samples come from the
    same iterator as the Labeling tab (via ``get_next_sample``), so every sample
    shown here is skipped there.
    """

    def decode_image(record):
        """Rebuild a PIL image from a labeled record's raw RGB bytes."""
        # Records without a stored size fall back to 384x384, the size this tab
        # has always assumed (TODO confirm it matches the dataset's RGB tiles).
        size = tuple(record.get("size", (384, 384)))
        return Image.frombytes("RGB", size, record["image"])

    def get_random_cool_images(n):
        """Return up to ``n`` randomly chosen images labeled "cool"."""
        cool_samples = [d for d in labeled_data if d["label"] == "cool"]
        chosen = random.sample(cool_samples, min(n, len(cool_samples)))
        return [decode_image(s) for s in chosen]

    def refresh_display():
        """Return ``(debug_message, new_image, bounds_text, gallery_images)``."""
        new_image, new_metadata, _ = get_next_sample()
        if new_image is None:
            return "No more samples", None, "", []
        return "", new_image, str(new_metadata["bounds"]), get_random_cool_images(DISPLAY_N_COOL)

    # Compute the initial content first and pass it to the constructors, rather
    # than assigning ``.value`` afterwards (that bypasses Gradio's postprocessing).
    debug, initial_image, initial_bounds, initial_gallery = refresh_display()

    with gr.Row():
        new_image_component = gr.Image(label="New Image", type="pil", value=initial_image)
        metadata_display = gr.Textbox(label="Metadata (Bounds)", value=initial_bounds)
    with gr.Row():
        cool_images_gallery = gr.Gallery(label="Cool Examples", value=initial_gallery)
        # NOTE(review): `.style()` is Gradio 3.x API; Gradio 4 uses gr.Gallery(columns=...).
        cool_images_gallery.style(grid=DISPLAY_N_COOL)
    with gr.Row():
        refresh_button = gr.Button("Refresh")
    debug_text = gr.Textbox(label="Debug", value=debug)

    refresh_button.click(
        fn=refresh_display,
        inputs=[],
        outputs=[debug_text, new_image_component, metadata_display, cool_images_gallery],
    )
# --- Main Interface ---
# Assemble the two-tab app; each *_ui() helper builds its tab's components in place.
demo = gr.Blocks()
with demo:
    gr.Markdown("# TerraNomaly")
    with gr.Tabs():
        with gr.TabItem("Labeling"):
            labeling_ui()
        with gr.TabItem("Display"):
            display_ui()

# Launch the app with Gradio's debug mode enabled.
demo.launch(debug=True)