File size: 5,251 Bytes
465c443
63711ea
7d5835f
80acacc
7d5835f
63711ea
5ec1d6a
7c425bb
09f8048
 
ef924e1
09f8048
 
 
 
 
 
 
fb2c0c4
ef924e1
 
 
 
 
09f8048
 
ef924e1
09f8048
 
 
 
 
 
 
fb2c0c4
09f8048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb2c0c4
 
 
09f8048
 
 
 
 
 
ef924e1
 
09f8048
ef924e1
09f8048
ef924e1
 
 
 
09f8048
 
ef924e1
 
09f8048
 
 
7d5835f
ef924e1
 
 
 
 
 
13eef3f
 
 
fb2c0c4
 
 
 
 
 
 
 
 
 
 
13eef3f
2a6e97e
ef924e1
 
 
 
 
2a6e97e
ef924e1
fb2c0c4
ef924e1
 
2a6e97e
 
ef924e1
 
 
2a6e97e
 
 
ef924e1
 
 
2a6e97e
 
fb2c0c4
 
 
 
 
 
ef924e1
2a6e97e
 
ef924e1
 
 
 
 
2a6e97e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gradio as gr
import earthview as ev
import utils
import random
import pandas as pd
import os
from itertools import islice

# Configuration
chunk_size = 100  # Number of streamed samples pulled and shuffled together
label_file = os.path.join(os.path.dirname(__file__), "labels.csv")  # The labels CSV sits next to this script

# Streaming state shared by the functions below
dataset = ev.load_dataset("satellogic", streaming=True)  # Satellogic dataset, opened in streaming mode
data_iter = iter(dataset)
shuffled_chunk = []  # Samples of the chunk currently being served
chunk_iter = None  # Iterator over shuffled_chunk; None until a chunk has been loaded

# Resume from the saved labels when the CSV exists, otherwise start with an empty table
labels_df = (
    pd.read_csv(label_file)
    if os.path.exists(label_file)
    else pd.DataFrame(columns=["image_id", "bounds", "rating", "google_maps_link"])
)

def get_next_image():
    """Return the next image that has not been labeled yet.

    Samples are pulled from the streaming iterator in chunks of ``chunk_size``,
    shuffled within the chunk, and served one at a time. Samples whose id
    (``str(bounds)``) already appears in ``labels_df["image_id"]`` are skipped.

    Returns:
        A 4-tuple ``(image, image_id, bounds, google_maps_link)``. When no
        image can be served, ``(None, "Dataset exhausted", None, None)``.
    """
    global data_iter, labels_df, shuffled_chunk, chunk_iter

    # Allow a single restart per call; without this bound the loop would spin
    # forever once every image in the dataset has already been labeled.
    restarted = False

    while True:
        # If we don't have a current chunk or it's exhausted, get a new one
        if not shuffled_chunk or chunk_iter is None:
            chunk = list(islice(data_iter, chunk_size))
            if not chunk:
                if restarted:
                    print("No unlabeled images found after a full pass.")
                    return None, "Dataset exhausted", None, None

                print("Dataset exhausted, resetting iterator.")
                reset_dataset_iterator()
                restarted = True
                chunk = list(islice(data_iter, chunk_size))
                if not chunk:
                    print("Still no data after reset.")
                    return None, "Dataset exhausted", None, None

            random.shuffle(chunk)
            shuffled_chunk = chunk
            chunk_iter = iter(shuffled_chunk)

        # Only the next() call is guarded, so a StopIteration raised inside the
        # decoding helpers below is not mistaken for "chunk finished".
        try:
            sample = next(chunk_iter)
        except StopIteration:
            # Current chunk is exhausted; clear it so the next pass loads a new one
            shuffled_chunk = []
            chunk_iter = None
            continue

        sample = ev.item_to_images("satellogic", sample)
        image = sample["rgb"][0]
        bounds = sample["metadata"]["bounds"]
        google_maps_link = utils.get_google_map_link(sample, "satellogic")
        image_id = str(bounds)  # the bounds double as the image's identifier

        # Serve the image unless it has already been rated
        if labels_df is None or image_id not in labels_df["image_id"].values:
            return image, image_id, bounds, google_maps_link

def rate_image(image_id, bounds, rating):
    """Store a rating for one image and move on to the next image.

    The rating is appended to ``labels_df`` and the whole table is written
    back to ``label_file`` before the next image is fetched.

    Args:
        image_id: Identifier of the image being rated.
        bounds: Bounds value stored alongside the rating.
        rating: The rating chosen for the image.

    Returns:
        The 4-tuple produced by ``get_next_image()``.
    """
    global labels_df

    # The link column stays in the table but is left blank here: this
    # function is not given the link, so there is nothing to store.
    record = {
        "image_id": [image_id],
        "bounds": [bounds],
        "rating": [rating],
        "google_maps_link": [""],
    }
    labels_df = pd.concat([labels_df, pd.DataFrame(record)], ignore_index=True)
    labels_df.to_csv(label_file, index=False)  # persist after every rating

    return get_next_image()

def save_labels_parquet():
    """Write the collected labels to a Parquet file.

    Returns:
        The path ``'labeled_data.parquet'`` once the file has been written, or
        ``None`` when there are no labels to export.
    """
    global labels_df
    if labels_df is None or labels_df.empty:
        return None

    parquet_path = 'labeled_data.parquet'
    # DataFrame.to_parquet is used because this file imports pandas but never
    # imports pyarrow; the previous pa/pq calls raised NameError at runtime.
    labels_df.to_parquet(parquet_path)
    return parquet_path

def reset_dataset_iterator():
    """Reload the streaming dataset and discard the chunk currently buffered."""
    global data_iter, shuffled_chunk, chunk_iter
    fresh_stream = ev.load_dataset("satellogic", streaming=True)
    data_iter = iter(fresh_stream)
    # Clearing both forces get_next_image() to pull a new chunk on its next call
    shuffled_chunk, chunk_iter = [], None

def load_different_batch():
    """Reset the dataset iterator and return the first image served afterwards.

    Returns:
        The 4-tuple produced by ``get_next_image()``.
    """
    # NOTE(review): the reset reloads the dataset from scratch, so this may
    # replay the same leading samples in a new shuffle order - confirm against
    # how ev.load_dataset orders a stream.
    print("Loading a different batch of images...")
    reset_dataset_iterator()
    first_of_batch = get_next_image()
    return first_of_batch

# Gradio interface
with gr.Blocks() as iface:
    image_out = gr.Image(label="Satellite Image")
    # The id and bounds travel through hidden textboxes so that rate_image
    # receives them back when the rating is submitted.
    image_id_out = gr.Textbox(label="Image ID", visible=False)
    bounds_out = gr.Textbox(label="Bounds", visible=False)
    google_maps_link_out = gr.Textbox(label="Google Maps Link", visible=True)
    rating_radio = gr.Radio(["Cool", "Not Cool"], label="Rating")
    with gr.Row():
        submit_button = gr.Button("Submit Rating")
        different_batch_button = gr.Button("Load Different Batch")
        download_button = gr.Button("Download Labels (Parquet)")
    download_output = gr.File(label="Download Labeled Data")

    # Every action that shows another image fills the same four components
    image_outputs = [image_out, image_id_out, bounds_out, google_maps_link_out]

    submit_button.click(
        fn=rate_image,
        inputs=[image_id_out, bounds_out, rating_radio],
        outputs=image_outputs,
    )

    download_button.click(
        fn=save_labels_parquet,
        inputs=[],
        outputs=[download_output],
    )

    different_batch_button.click(
        fn=load_different_batch,
        inputs=[],
        outputs=image_outputs,
    )

    # Fetch the first image when the page loads instead of once at build time.
    # A single pre-fetched image replayed on every load would reappear after
    # it had been rated, and testing it with `if image:` relies on truthiness,
    # which array-like images do not support.
    iface.load(
        fn=get_next_image,
        inputs=None,
        outputs=image_outputs,
    )

iface.launch(share=True)