# TerraNomaly / app.py
# Hugging Face Space file by dropbop; last commit "Update app.py" (c501c19, verified), 6 kB.
import base64
import json
import os
import random

import gradio as gr
import numpy as np
from pandas import DataFrame
from PIL import Image

import earthview as ev
import utils
# --- Configuration ---
# EarthView subset name handed to the `ev` loading/decoding helpers.
DATASET_SUBSET: str = "satellogic"
# Intended labeling budget (adjustable); nothing in this file reads it yet.
NUM_SAMPLES_TO_LABEL: int = 100
# JSON file where label records are persisted between runs.
LABELED_DATA_FILE: str = "labeled_data.json"
# Number of "cool" examples shown alongside each new image on the Display tab.
DISPLAY_N_COOL: int = 5
# Passed to ev.load_dataset as `shards=[SAMPLE_SEED]`, i.e. it picks which shard
# is loaded rather than seeding a random sampler.
SAMPLE_SEED: int = 10
# Batch size setting; nothing in this file reads it yet.
BATCH_SIZE: int = 10
# --- Load Dataset ---
# Runs at import time. SAMPLE_SEED is passed as the only entry of `shards=`, so it
# presumably selects which shard is loaded (not a random seed) — confirm against earthview.
dataset = ev.load_dataset(DATASET_SUBSET, shards=[SAMPLE_SEED])
# One shared iterator: every call to next(data_iter) anywhere in the app advances it.
data_iter = iter(dataset)
# --- Load Labeled Data (if it exists) ---
def _decode_image_field(record):
    """Return *record* with a base64 ``"image"`` string turned back into raw bytes.

    Records that are not dicts, have no string image, or hold a string that is
    not valid base64 are returned unchanged.
    """
    if isinstance(record, dict) and isinstance(record.get("image"), str):
        try:
            return {**record, "image": base64.b64decode(record["image"], validate=True)}
        except ValueError:  # binascii.Error: not base64, keep as-is
            pass
    return record


def load_labeled_data(path=None):
    """Load previously saved label records from a JSON file.

    Args:
        path: File to read; ``LABELED_DATA_FILE`` is used when omitted.

    Returns:
        A list of label records. ``[]`` is returned when the file does not
        exist, cannot be parsed as JSON, or does not contain a JSON list.
        Image fields stored as base64 text (raw bytes cannot be written to
        JSON directly) are decoded back to ``bytes``.
    """
    if path is None:
        path = LABELED_DATA_FILE
    try:
        with open(path, "r", encoding="utf-8") as f:
            records = json.load(f)
    except FileNotFoundError:
        return []
    except (json.JSONDecodeError, UnicodeDecodeError) as err:
        # An interrupted or failed write can leave a truncated file behind;
        # start with no labels instead of refusing to start the app.
        print(f"Ignoring unreadable labels file {path!r}: {err}")
        return []
    if not isinstance(records, list):
        print(f"Ignoring labels file {path!r}: expected a JSON list.")
        return []
    return [_decode_image_field(record) for record in records]
# Module-level list of label records, loaded once at import time. Labels are
# appended to it when saved, and the Display tab reads the "cool" ones from it.
labeled_data = load_labeled_data()
# --- Get Next Sample for Labeling ---
def get_next_sample():
    """Advance the shared ``data_iter`` and return the next sample converted to images.

    Returns:
        ``ev.item_to_images(DATASET_SUBSET, item)`` for the next raw dataset
        item, or ``None`` (after printing a notice) once the iterator is exhausted.
    """
    try:
        item = next(data_iter)
    except StopIteration:
        print("No more samples in the dataset.")
        return None
    # Convert outside the try so a StopIteration raised while converting is not
    # mistaken for the end of the dataset.
    return ev.item_to_images(DATASET_SUBSET, item)
def get_images(batch_size, state):
    """Pull up to *batch_size* samples; return their first RGB frames and a metadata table.

    Args:
        batch_size: Maximum number of samples to draw from the shared iterator.
        state: Dict whose ``"subset"`` entry is passed to the map-link helper.

    Returns:
        ``(frames, table)``: a list of ``sample["rgb"][0]`` objects and a pandas
        ``DataFrame`` with one row per sample's metadata plus an HTML ``"map"``
        link column. Fewer rows are returned if the dataset runs out early.
    """
    subset = state["subset"]
    frames, rows = [], []
    for _ in range(batch_size):
        item = get_next_sample()
        if item is None:  # dataset exhausted: keep whatever was collected
            break
        frame = item["rgb"][0]
        row = item["metadata"]
        link = utils.get_google_map_link(item, subset)
        row["map"] = f'<a href="{link}" target="_blank">🧭</a>'
        frames.append(frame)
        rows.append(row)
    return frames, DataFrame(rows)
# --- Save Labeled Data ---
def _json_safe_records(records):
    """Return shallow copies of *records* with raw ``"image"`` bytes base64-encoded.

    ``json`` cannot serialize ``bytes``; the in-memory records keep their raw
    bytes so code that rebuilds images from them keeps working.
    """
    safe = []
    for record in records:
        image = record.get("image")
        if isinstance(image, (bytes, bytearray)):
            record = {**record, "image": base64.b64encode(image).decode("ascii")}
        safe.append(record)
    return safe


def save_labeled_data(label, state):
    """Record *label* for the sample on screen, persist all labels, and advance.

    Args:
        label: Label text attached to the current sample (e.g. ``"cool"``).
        state: Session dict holding ``"sample"`` (the sample on screen, or
            ``None``) and optionally ``"subset"`` for the map link.

    Returns:
        ``(status, gallery_images, metadata_table)`` for the Debug textbox,
        the gallery and the metadata table. ``gallery_images`` is a one-item
        list, or ``None`` when there is nothing left to show.
    """
    sample = state["sample"]
    if sample is None:
        return "No image to label", None, DataFrame()

    image = sample["rgb"][0].convert("RGB")
    labeled_data.append({
        "image": image.tobytes(),
        # (width, height) is needed to rebuild the image from its raw bytes.
        "size": list(image.size),
        "metadata": sample["metadata"],
        "label": label,
    })

    # Serialize fully before opening the file so a failure cannot leave a
    # truncated labels file behind.
    try:
        payload = json.dumps(_json_safe_records(labeled_data))
    except (TypeError, ValueError) as err:
        status = f"Could not save labels: {err}"
    else:
        with open(LABELED_DATA_FILE, "w", encoding="utf-8") as f:
            f.write(payload)
        status = ""

    new_sample = get_next_sample()
    state["sample"] = new_sample
    if new_sample is None:
        return " ".join(filter(None, (status, "Dataset exhausted."))), None, DataFrame()

    subset = state.get("subset", DATASET_SUBSET)
    new_metadata = new_sample["metadata"]
    new_metadata["map"] = f'<a href="{utils.get_google_map_link(new_sample, subset)}" target="_blank">🧭</a>'
    # Gallery components take a list of images, not a single image.
    return status, [new_sample["rgb"][0]], DataFrame([new_metadata])
# --- Gradio Interface ---
# --- Labeling UI ---
def labeling_ui():
    """Build the "Labeling" tab: show one sample and record a "cool"/"not cool" label for it.

    Must be called inside an active ``gr.Blocks`` context. The sample shown on
    screen is the same one stored in the session state, so each click labels
    exactly what the user is looking at.
    """
    first_sample = get_next_sample()
    if first_sample is None:
        first_gallery, first_table = [], DataFrame()
    else:
        first_metadata = first_sample["metadata"]
        first_metadata["map"] = f'<a href="{utils.get_google_map_link(first_sample, DATASET_SUBSET)}" target="_blank">🧭</a>'
        first_gallery, first_table = [first_sample["rgb"][0]], DataFrame([first_metadata])

    # Initial values go through the component constructors (rather than being
    # assigned to `.value` afterwards) so they get normal component handling.
    state = gr.State({"sample": first_sample, "subset": DATASET_SUBSET})
    with gr.Row():
        with gr.Column():
            gallery = gr.Gallery(
                label="Satellite Image",
                value=first_gallery,
                interactive=False,
                columns=1,
                object_fit="scale-down",
            )
            with gr.Row():
                cool_button = gr.Button("Cool")
                not_cool_button = gr.Button("Not Cool")
            table = gr.DataFrame(value=first_table, datatype="html")
    status = gr.Textbox(label="Debug")

    # The label is bound in each handler instead of being read from hidden inputs.
    outputs = [status, gallery, table]
    cool_button.click(
        fn=lambda s: save_labeled_data("cool", s),
        inputs=[state],
        outputs=outputs,
    )
    not_cool_button.click(
        fn=lambda s: save_labeled_data("not cool", s),
        inputs=[state],
        outputs=outputs,
    )
# --- Display UI ---
def display_ui():
    """Build the "Display" tab: a fresh unlabeled image beside a few "cool" examples.

    Must be called inside an active ``gr.Blocks`` context. Samples are drawn
    from the same shared ``data_iter`` as the labeling tab, so every refresh
    consumes a sample that will not be offered for labeling.
    """

    def record_to_image(record):
        """Rebuild a PIL image from a label record's raw RGB bytes."""
        raw = record["image"]
        if isinstance(raw, str):
            # Raw bytes cannot be stored in JSON; accept base64 text as well.
            raw = base64.b64decode(raw)
        # Fall back to 384x384 when the record carries no size.
        size = tuple(record.get("size", (384, 384)))
        return Image.frombytes("RGB", size, raw)

    def get_random_cool_images(n):
        """Return up to *n* randomly chosen images labeled "cool"."""
        cool_samples = [d for d in labeled_data if d["label"] == "cool"]
        picked = random.sample(cool_samples, min(n, len(cool_samples)))
        return [record_to_image(s) for s in picked]

    def get_new_unlabeled_image():
        """Return ``(image, JSON-encoded bounds)`` for the next sample, or ``(None, None)``."""
        sample = get_next_sample()
        if sample is None:
            return None, None
        return sample["rgb"][0], json.dumps(sample["metadata"]["bounds"])

    def refresh_display():
        """Return values for (debug text, new image, bounds text, cool gallery)."""
        new_image, bounds = get_new_unlabeled_image()
        if new_image is None:
            return "No more samples", None, "", []
        return "", new_image, bounds, get_random_cool_images(DISPLAY_N_COOL)

    initial_debug, initial_image, initial_bounds, initial_gallery = refresh_display()
    with gr.Row():
        new_image_component = gr.Image(label="New Image", type="pil", value=initial_image)
        metadata_display = gr.Textbox(label="Metadata (Bounds)", value=initial_bounds)
    with gr.Row():
        cool_images_gallery = gr.Gallery(
            label="Cool Examples", value=initial_gallery, columns=DISPLAY_N_COOL
        )
    refresh_button = gr.Button("Refresh")
    debug_box = gr.Textbox(label="Debug", value=initial_debug)
    refresh_button.click(
        fn=refresh_display,
        inputs=[],
        outputs=[debug_box, new_image_component, metadata_display, cool_images_gallery],
    )
# --- Main Interface ---
# Both tabs are built at import time; they draw their initial samples from the
# shared dataset iterator in the order the tabs are created.
with gr.Blocks() as demo:
    gr.Markdown("# TerraNomaly")
    with gr.Tabs():
        with gr.TabItem("Labeling"):
            labeling_ui()
        with gr.TabItem("Display"):
            display_ui()

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)