File size: 5,775 Bytes
d992c15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os

import datasets
import fuego
import gradio as gr
from datasets import load_dataset
from huggingface_hub import HfFolder, create_repo, delete_repo, login
from PIL import Image


# Turn off the `datasets` library's caching for this process.
datasets.disable_caching()

# Authenticate with the Hugging Face Hub. The HUGGING_FACE_HUB_TOKEN environment
# variable takes precedence; otherwise the locally stored token from
# HfFolder.get_token() is used. add_to_git_credential=True also registers the token
# as a git credential.
login(token=os.getenv("HUGGING_FACE_HUB_TOKEN", HfFolder.get_token()), add_to_git_credential=True)

# Make sure every Hub repo used by the active-learning loop exists and keep its full
# repo id. There are four dataset repos plus one model repo. exist_ok=True makes this
# safe to re-run. These calls come after login() so they run with the authenticated
# token.
# NOTE(review): `unlabled_samples_repo_id` is misspelled and is not read anywhere in
# this file; it is presumably used elsewhere, so confirm before renaming it.
labeled_samples_repo_id = create_repo("actlearn_labeled_samples", exist_ok=True, repo_type="dataset").repo_id
unlabled_samples_repo_id = create_repo("actlearn_unlabeled_samples", exist_ok=True, repo_type="dataset").repo_id
to_label_samples_repo_id = create_repo("actlearn_to_label_samples", exist_ok=True, repo_type="dataset").repo_id
test_dataset_repo_id = create_repo("actlearn_test_mnist", exist_ok=True, repo_type="dataset").repo_id
model_repo_id = create_repo("actlearn_mnist_model", exist_ok=True).repo_id


# Cursor into `imgs`: the index of the next image get_image() will hand out.
idx = 0
try:
    # Load the current batch of images waiting for a label. The repo may be empty,
    # or have no "train" split, for example right after save_data deleted and
    # recreated it. Treat any load failure as "nothing to label".
    data_to_label = load_dataset(to_label_samples_repo_id)
    imgs = data_to_label["train"]["image"]
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
    imgs = None
    data_to_label = None
else:
    # The rest of the app uses `imgs is None` to mean "nothing to label". Normalize
    # an empty batch to None so the UI starts in that state instead of indexing
    # into an empty list.
    if len(imgs) == 0:
        imgs = None


def get_image():
    """Return the next image to label and advance the module-level cursor.

    Reads the module-level ``imgs`` batch and the ``idx`` cursor. On success it
    returns ``imgs[idx]`` and increments ``idx``.

    Returns:
        The next image, or ``None`` in two cases: there is no batch
        (``imgs is None``), or every image in the batch has already been handed
        out. In the ``None`` case ``idx`` is left unchanged.
    """
    global idx
    # Guard against an empty or exhausted batch; indexing would raise IndexError.
    if imgs is None or idx >= len(imgs):
        return None
    new_img = imgs[idx]
    idx += 1
    return new_img


# Buffer of [image, label] pairs collected by save_data until the current batch
# is fully labeled.
labeled_data = []

# Markdown shown at the top of the demo explaining the active-learning loop.
information = """# Active Learning Demo
This demo showcases Active Learning, which is great when labeling is expensive. In this demo, you will label images by choosing a digit (0-9).
How does this work?
* There is a large pool of unlabeled images
* A model is trained with the few labeled images
* We can then use the model to pick the images it is least confident about, e.g. those whose most likely digit has the lowest predicted probability. These are the images for which the model is confused, so by labeling them, the quality of the model can improve much more than by labeling images for which the model was already doing well!
* In this UI, you will be provided a couple of images to label
* Once all the provided images are labeled, the model is retrained, and a new set of images is chosen!
"""

# Markdown shown once the last image of a batch has been labeled and retraining starts.
training_info = """## Model Retraining
There are new labeled images. The model is retraining. Follow progress in the "fuego" space that was spun up for you in your profile.
"""

with gr.Blocks() as demo:
    gr.Markdown(information)

    # The labeling widgets (image, digit picker, save button) appear only when a
    # batch of images is available; otherwise only the Reload button is shown.
    img_to_label = gr.Image(shape=[28, 28], value=get_image(), visible=imgs is not None)
    label_dropdown = gr.Dropdown(
        choices=list(range(10)),
        interactive=True,
        value=0,
        visible=imgs is not None,
    )
    save_btn = gr.Button("Save label", visible=imgs is not None)
    # Hidden until retraining starts or there is nothing left to label.
    output_box = gr.Markdown(value=training_info, visible=False)
    reload_btn = gr.Button("Reload", visible=imgs is None)

    def save_data(img, label):
        """Record a label for the shown image and move on to the next one.

        Each call appends ``[img, label]`` to the module-level ``labeled_data``
        buffer. While images remain in the current batch, the next one is shown.
        After the last image of the batch is labeled, this function:

        1. appends the buffered samples to the labeled-samples dataset on the Hub;
        2. empties the to-label repo by deleting and recreating it;
        3. resets the buffer and cursor;
        4. launches a retraining run via ``fuego``;
        5. switches the UI to the "retraining" view.

        Args:
            img: Value of the ``gr.Image`` component. It is passed to
                ``PIL.Image.fromarray``, so it is presumably a numpy array
                (TODO confirm).
            label: Digit selected in the dropdown.

        Returns:
            A dict mapping Gradio components to ``gr.update`` values.
        """
        global labeled_data
        global idx

        labeled_data.append([img, label])

        if imgs is None or len(imgs) != idx:
            # More images remain in this batch: show the next one.
            return {img_to_label: gr.update(value=get_image())}

        # Push the new samples to the training dataset *before* clearing the
        # to-label queue. If the push fails, the queue of images still waiting
        # for labels is then not lost as well.
        labeled_dataset = load_dataset(labeled_samples_repo_id)["train"]
        feature = datasets.Image(decode=False)
        for sample_img, sample_label in labeled_data:
            # Hack due to https://github.com/huggingface/datasets/issues/4796
            labeled_dataset = labeled_dataset.add_item(
                {"image": feature.encode_example(Image.fromarray(sample_img)), "label": sample_label}
            )
        labeled_dataset.push_to_hub(labeled_samples_repo_id)

        # Remove dataset of queries to label.
        # The datasets library does not allow pushing an empty dataset, so as a
        # workaround we delete the repo and recreate it empty.
        delete_repo(repo_id=to_label_samples_repo_id, repo_type="dataset")
        create_repo(repo_id=to_label_samples_repo_id, repo_type="dataset")

        # Clean up data
        labeled_data = []
        idx = 0

        fuego.run("training/run.py", "training/requirements.txt", space_id="actlearn-fuego-runner")

        # Update UI: hide the labeling widgets and show the retraining message.
        return {
            img_to_label: gr.update(visible=False),
            label_dropdown: gr.update(visible=False),
            save_btn: gr.update(visible=False),
            output_box: gr.update(visible=True, value=training_info),
            reload_btn: gr.update(visible=True),
        }

    def reload_data():
        """Fetch a fresh batch of images to label from the Hub and refresh the UI.

        Reloads the to-label dataset into the module-level ``data_to_label`` and
        ``imgs`` globals.

        If the repo cannot be loaded, or its ``train`` split contains no images,
        ``imgs`` is set to ``None``. The UI then shows "No more images to label"
        and keeps the Reload button available.

        Otherwise the cursor ``idx`` is reset, the first image is shown, and the
        labeling widgets are made visible.

        Returns:
            A dict mapping every output component to a ``gr.update`` value.
        """
        global data_to_label
        global imgs
        global idx

        try:
            # See if there is new data to be labeled
            data_to_label = load_dataset(to_label_samples_repo_id)
            imgs = data_to_label["train"]["image"]
        except Exception:
            # Best effort: a missing or empty repo just means nothing to label yet.
            # This is expected, for example, right after save_data recreated the repo.
            imgs = None
            data_to_label = None

        if imgs is None or len(imgs) == 0:
            # Treat an empty batch like a missing one. The user sees a message
            # instead of the handler returning nothing for its declared outputs.
            imgs = None
            return {
                img_to_label: gr.update(visible=False, value=None),
                label_dropdown: gr.update(visible=False),
                save_btn: gr.update(visible=False),
                output_box: gr.update(visible=True, value="No more images to label"),
                reload_btn: gr.update(visible=True),
            }

        # A new batch is available: restart the cursor and show the labeling widgets.
        idx = 0
        return {
            img_to_label: gr.update(visible=True, value=get_image()),
            label_dropdown: gr.update(visible=True),
            save_btn: gr.update(visible=True),
            output_box: gr.update(visible=False),
            reload_btn: gr.update(visible=False),
        }

    # Either handler may update any of these components, so both events declare
    # the same set of outputs.
    _ui_outputs = [img_to_label, label_dropdown, save_btn, output_box, reload_btn]

    # "Save label" sends the shown image and the chosen digit to save_data.
    save_btn.click(save_data, inputs=[img_to_label, label_dropdown], outputs=_ui_outputs)
    # "Reload" takes no inputs; it only checks the Hub for a new batch.
    reload_btn.click(reload_data, outputs=_ui_outputs)


# Launch the Gradio app when this file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=True)