# Source: Hugging Face Space (status when captured: Sleeping); file size: 5,251 bytes.
# (The per-line commit hashes and line-number gutter from the web view were removed.)
import gradio as gr
import earthview as ev
import utils
import random
import pandas as pd
import os
from itertools import islice
# Configuration
# Number of streamed samples pulled at once and shuffled together.
chunk_size = 100
# The ratings CSV is kept in the same directory as this script.
label_file = os.path.join(os.path.dirname(__file__), "labels.csv")

# Satellogic dataset, loaded in streaming mode, and the iterator walking it.
dataset = ev.load_dataset("satellogic", streaming=True)
data_iter = iter(dataset)

# No chunk is loaded yet: an empty chunk list and no iterator over it.
shuffled_chunk, chunk_iter = [], None

# Labels table: resume from an existing CSV, otherwise start with an empty
# frame that already carries the expected columns.
if not os.path.exists(label_file):
    labels_df = pd.DataFrame(columns=["image_id", "bounds", "rating", "google_maps_link"])
else:
    labels_df = pd.read_csv(label_file)
def get_next_image():
    """Return the next image that has not been rated yet.

    Samples are pulled from the streaming dataset ``chunk_size`` at a time,
    shuffled within the chunk, and served one by one. A sample is skipped
    when its id (``str(bounds)``) already appears in
    ``labels_df["image_id"]``.

    Returns:
        A 4-tuple ``(image, image_id, bounds, google_maps_link)``. When no
        unlabeled image can be produced the tuple is
        ``(None, "Dataset exhausted", None, None)``.
    """
    global data_iter, labels_df, shuffled_chunk, chunk_iter

    # labels_df is not modified inside this function, so the set of already
    # rated ids can be built once instead of scanning the column per sample.
    labeled_ids = set() if labels_df is None else set(labels_df["image_id"])

    # Number of stream resets performed during this call. Once a reset has
    # happened, hitting the end of the stream again means a complete pass
    # found nothing unlabeled; resetting again would loop forever.
    resets = 0

    while True:
        # If we don't have a current chunk or it's exhausted, get a new one
        if not shuffled_chunk or chunk_iter is None:
            chunk = list(islice(data_iter, chunk_size))
            if not chunk:
                if resets:
                    print("No unlabeled images found in a full pass over the dataset.")
                    return None, "Dataset exhausted", None, None
                print("Dataset exhausted, resetting iterator.")
                reset_dataset_iterator()  # rebinds the global data_iter
                resets += 1
                chunk = list(islice(data_iter, chunk_size))
                if not chunk:
                    print("Still no data after reset.")
                    return None, "Dataset exhausted", None, None

            random.shuffle(chunk)
            shuffled_chunk = chunk
            chunk_iter = iter(shuffled_chunk)

        # Only next() is guarded: a StopIteration escaping from the helper
        # calls below must not be mistaken for "chunk exhausted".
        try:
            sample = next(chunk_iter)
        except StopIteration:
            # Current chunk is used up; clear it so the next loop iteration
            # pulls a new one.
            shuffled_chunk = []
            chunk_iter = None
            continue

        sample = ev.item_to_images("satellogic", sample)
        image = sample["rgb"][0]
        bounds = sample["metadata"]["bounds"]
        google_maps_link = utils.get_google_map_link(sample, "satellogic")
        image_id = str(bounds)

        if image_id not in labeled_ids:
            return image, image_id, bounds, google_maps_link
def rate_image(image_id, bounds, rating):
    """Record a rating for the current image and advance to the next one.

    The rating is appended to ``labels_df`` and the whole table is rewritten
    to ``label_file``. Incomplete submissions are not recorded: a missing
    rating, or an empty ``image_id``/``bounds`` (``get_next_image`` returns
    ``None`` bounds when it has no image to show), only advances the image.

    Args:
        image_id: Identifier of the image being rated.
        bounds: Bounds of the image being rated.
        rating: The chosen rating, or ``None`` when nothing was selected.

    Returns:
        The 4-tuple ``(image, image_id, bounds, google_maps_link)`` produced
        by ``get_next_image()``.
    """
    global labels_df

    if image_id and bounds and rating is not None:
        new_row = pd.DataFrame(
            {
                "image_id": [image_id],
                "bounds": [bounds],
                "rating": [rating],
                # The link is not passed to this function, so it is left blank.
                "google_maps_link": [""],
            }
        )
        if labels_df is None:
            labels_df = new_row
        else:
            labels_df = pd.concat([labels_df, new_row], ignore_index=True)
        labels_df.to_csv(label_file, index=False)

    return get_next_image()
def save_labels_parquet():
    """Write the collected labels to ``labeled_data.parquet``.

    Returns:
        The path ``'labeled_data.parquet'`` after writing, or ``None`` when
        ``labels_df`` is missing or empty.

    Raises:
        ImportError: If the ``pyarrow`` engine is not installed.
    """
    if labels_df is None or labels_df.empty:
        return None

    output_path = 'labeled_data.parquet'
    # The previous implementation went through ``pa``/``pq``, names this file
    # never imports, so it raised NameError as soon as any label existed.
    # DataFrame.to_parquet performs the same pandas -> Parquet write without
    # needing those module aliases in scope.
    labels_df.to_parquet(output_path, engine="pyarrow")
    return output_path
def reset_dataset_iterator():
    """Point ``data_iter`` at a freshly loaded stream and drop the current chunk."""
    global data_iter, shuffled_chunk, chunk_iter

    fresh_stream = ev.load_dataset("satellogic", streaming=True)
    data_iter = iter(fresh_stream)
    # With both chunk variables cleared, get_next_image() pulls a new chunk
    # from the fresh iterator the next time it runs.
    shuffled_chunk, chunk_iter = [], None
def load_different_batch():
    """Reset the dataset stream and return the first image served afterwards.

    Returns:
        The 4-tuple produced by ``get_next_image()``.
    """
    print("Loading a different batch of images...")
    # NOTE(review): this reloads the stream; whether the new stream yields
    # different samples depends on ev.load_dataset — confirm.
    reset_dataset_iterator()
    first_of_new_batch = get_next_image()
    return first_of_new_batch
# Gradio interface
with gr.Blocks() as iface:
    image_out = gr.Image(label="Satellite Image")
    # Hidden fields carry the current image's id and bounds back to rate_image.
    image_id_out = gr.Textbox(label="Image ID", visible=False)
    bounds_out = gr.Textbox(label="Bounds", visible=False)
    google_maps_link_out = gr.Textbox(label="Google Maps Link", visible=True)
    rating_radio = gr.Radio(["Cool", "Not Cool"], label="Rating")

    # NOTE(review): the original indentation was lost, so which buttons sit
    # inside this row is a reconstruction — confirm the intended layout.
    with gr.Row():
        submit_button = gr.Button("Submit Rating")
        different_batch_button = gr.Button("Load Different Batch")  # New button
        download_button = gr.Button("Download Labels (Parquet)")
    download_output = gr.File(label="Download Labeled Data")

    # Every image-changing callback fills the same four components.
    image_outputs = [image_out, image_id_out, bounds_out, google_maps_link_out]

    submit_button.click(
        fn=rate_image,
        inputs=[image_id_out, bounds_out, rating_radio],
        outputs=image_outputs,
    )
    download_button.click(
        fn=save_labels_parquet,
        inputs=[],
        outputs=[download_output],
    )
    different_batch_button.click(
        fn=load_different_batch,
        inputs=[],
        outputs=image_outputs,
    )

    # Load the first image
    initial_image, initial_image_id, initial_bounds, initial_google_maps_link = get_next_image()

    # Set initial values. The check is against None explicitly because
    # get_next_image() signals "no image" with None, and the truth value of
    # an image object is not a reliable test (array images raise on bool()).
    # The listener is registered while the Blocks context is still open.
    if initial_image is not None:
        iface.load(
            lambda: (initial_image, initial_image_id, initial_bounds, initial_google_maps_link),
            inputs=None,
            outputs=image_outputs,
        )

iface.launch(share=True)