explore-label-concepts / src /label_interface.py
Xmaster6y's picture
bug
24b09f4 unverified
raw
history blame
6.96 kB
"""Interface for labeling concepts in images.
"""
from typing import Optional
import random
import gradio as gr
from src import global_variables
from src.constants import CONCEPTS, ASSETS_FOLDER, DATASET_NAME
def filter_sample(sample, concepts, username, sample_type):
    """Decide whether a metadata sample passes the active labelling filter.

    Args:
        sample: One metadata record. Reads ``sample["concepts"]`` (concept name
            -> aggregated label, where ``None`` marks a tie) and, if present,
            ``sample["votes"]`` (username -> that user's per-concept votes).
        concepts: Concept names that must all carry a truthy aggregated label.
        username: User whose votes decide "labelled" vs "unlabelled".
        sample_type: ``"labelled"`` or ``"unlabelled"``.

    Returns:
        ``True`` when the sample has every requested concept and its labelled
        status for ``username`` matches ``sample_type``. A sample counts as
        labelled only if the user's vote covers every concept in ``CONCEPTS``.

    Raises:
        ValueError: If ``sample_type`` is not recognised. This is only checked
            once the concept filter has passed.
    """
    # Generator form lets ``all`` stop at the first missing/falsy concept.
    if not all(sample["concepts"].get(c, False) for c in concepts):
        return False

    votes = sample["votes"] if "votes" in sample else {}
    if username in votes:
        user_votes = votes[username]
        is_labelled = all(c in user_votes for c in CONCEPTS)
    else:
        is_labelled = False

    if sample_type == "labelled":
        return is_labelled
    if sample_type == "unlabelled":
        return not is_labelled
    raise ValueError(f"Invalid sample type: {sample_type}")
def get_next_image(
    split: str,
    concepts: list,
    sample_type: str,
    filtered_indices: dict,
    selected_concepts: list,
    selected_sample_type: str,
    profile: gr.OAuthProfile
):
    """Pick a random sample that matches the current filter and describe it for the UI.

    Candidate indices for every split are cached in ``filtered_indices``. They
    are recomputed only when ``concepts`` or ``sample_type`` differ from the
    values the cache was built with (``selected_concepts`` /
    ``selected_sample_type``).

    Args:
        split: Split to draw from (a key of ``global_variables.all_metadata``).
        concepts: Concepts a sample must have. ``None`` means no concept filter.
        sample_type: ``"labelled"`` or ``"unlabelled"`` relative to the user.
        filtered_indices: Cached candidate indices per split (session state).
        selected_concepts: Concepts the cache was computed with.
        selected_sample_type: Sample type the cache was computed with.
        profile: Profile of the logged-in user.

    Returns:
        A 10-tuple in this order:
        image path, the user's positive votes, ``"split:index"`` id, class,
        aggregated concepts, concepts missing from the user's vote, tied
        concepts, then ``filtered_indices``, ``selected_concepts`` and
        ``selected_sample_type``. When no sample matches, the first seven
        entries are ``None``.
    """
    username = profile.username
    # Treat an empty (None) dropdown selection as "no concept filter" so that
    # filter_sample always receives an iterable.
    concepts = concepts or []
    if concepts != selected_concepts or sample_type != selected_sample_type:
        for key, samples in global_variables.all_metadata.items():
            filtered_indices[key] = [
                i
                for i, sample in enumerate(samples)
                if filter_sample(sample, concepts, username, sample_type)
            ]
        selected_concepts = concepts
        selected_sample_type = sample_type

    candidates = filtered_indices[split]
    # Check for emptiness explicitly instead of catching IndexError around the
    # whole body, which would also hide unrelated indexing bugs.
    if not candidates:
        gr.Warning("No image found for the selected filter.")
        return None, None, None, None, None, None, None, filtered_indices, selected_concepts, selected_sample_type

    sample_idx = random.choice(candidates)
    sample = global_variables.all_metadata[split][sample_idx]
    image_path = f"{ASSETS_FOLDER}/{DATASET_NAME}/data/{split}/{sample['file_name']}"
    try:
        username_votes = sample["votes"][username]
    except KeyError:
        # The sample has no votes at all, or none from this user.
        voted_concepts = []
        unseen_concepts = []
    else:
        voted_concepts = [c for c in CONCEPTS if username_votes.get(c, False)]
        # Concepts absent from the user's stored vote (presumably concepts
        # added after they voted; confirm).
        unseen_concepts = [c for c in CONCEPTS if c not in username_votes]
    # An aggregated label of None means the votes for that concept are tied.
    tie_concepts = [c for c, label in sample["concepts"].items() if label is None]
    return (
        image_path,
        voted_concepts,
        f"{split}:{sample_idx}",
        sample["class"],
        sample["concepts"],
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    )
def _tally_votes(votes, concepts):
    """Aggregate per-user votes into one label per concept by simple majority.

    Each user's ``True`` counts +1 and ``False`` counts -1. Users whose vote
    lacks a concept are skipped for it. A positive total yields ``True``, a
    negative total ``False``, and a zero total (tie or no votes) ``None``.
    """
    tallied = {}
    for concept in concepts:
        total = sum(2 * vote[concept] - 1 for vote in votes if concept in vote)
        tallied[concept] = total > 0 if total != 0 else None
    return tallied


def _refresh_filtered_index(filtered_indices, split, idx, sample, concepts, username, sample_type):
    """Re-apply the cached filter to one just-voted sample, adding or removing its index."""
    # No filter has been computed yet, or this split is not cached.
    if sample_type is None or split not in filtered_indices:
        return
    indices = filtered_indices[split]
    keep = filter_sample(sample, concepts or [], username, sample_type)
    if keep and idx not in indices:
        indices.append(idx)
    elif not keep and idx in indices:
        indices.remove(idx)


def submit_label(
    voted_concepts: list,
    current_image: Optional[str],
    split,
    concepts,
    sample_type,
    filtered_indices,
    selected_concepts,
    selected_sample_type,
    profile: gr.OAuthProfile
):
    """Store the user's vote for the displayed image, then serve the next image.

    The stored vote is ``True`` for each concept in ``voted_concepts`` and
    ``False`` for every other concept in ``CONCEPTS``. It replaces any earlier
    vote by the same user. The sample's ``"concepts"`` are then recomputed from
    all users' votes by majority (``None`` on a tie), and the split's metadata
    is saved.

    The cached candidate list in ``filtered_indices`` was built before this
    vote, so the voted sample is re-checked against the cached filter. This
    stops, e.g., a just-labelled image from being served again in
    "unlabelled" mode.

    Args:
        voted_concepts: Concepts the user ticked for the current image.
        current_image: ``"split:index"`` id of the displayed image, or ``None``.
        split: Split selector value, forwarded to ``get_next_image``.
        concepts: Concept filter selection, forwarded to ``get_next_image``.
        sample_type: Sample-type selector value, forwarded to ``get_next_image``.
        filtered_indices: Cached candidate indices per split (session state).
        selected_concepts: Concepts the cache was computed with.
        selected_sample_type: Sample type the cache was computed with.
        profile: Profile of the logged-in user.

    Returns:
        The 10-tuple produced by ``get_next_image``. When no image is selected,
        it returns seven ``None`` values followed by the unchanged state values.
    """
    username = profile.username
    if current_image is None:
        gr.Warning("No image selected.")
        return None, None, None, None, None, None, None, filtered_indices, selected_concepts, selected_sample_type
    current_split, idx = current_image.split(":")
    idx = int(idx)
    # NOTE(review): presumably reloads this split's metadata so that votes
    # saved by other sessions are not overwritten. Confirm in global_variables.
    global_variables.get_metadata(current_split)
    # Look the sample up only after the (possible) reload above.
    sample = global_variables.all_metadata[current_split][idx]
    votes = sample.setdefault("votes", {})
    ticked = voted_concepts or []
    votes[username] = {c: c in ticked for c in CONCEPTS}
    sample["concepts"] = _tally_votes(votes.values(), CONCEPTS)
    global_variables.save_metadata(current_split)
    gr.Info("Submit success")
    # Use the filter the cache was built with, not the live widget values.
    _refresh_filtered_index(
        filtered_indices,
        current_split,
        idx,
        sample,
        selected_concepts,
        username,
        selected_sample_type,
    )
    return get_next_image(
        split,
        concepts,
        sample_type,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
        profile
    )
# Labelling UI: filter widgets, voting widgets and image info on the left;
# the image on the right. Component creation order defines the layout.
with gr.Blocks() as interface:
    with gr.Row():
        with gr.Column():
            with gr.Group():
                gr.Markdown(
                    "## # Image Selection",
                )
                with gr.Row():
                    split = gr.Radio(
                        label="Split",
                        choices=["train", "validation", "test"],
                        value="train",
                    )
                    sample_type = gr.Radio(
                        label="Sample Type",
                        choices=["labelled", "unlabelled"],
                        value="unlabelled",
                    )
                concepts = gr.Dropdown(
                    label="Concepts",
                    multiselect=True,
                    choices=CONCEPTS,
                )
            with gr.Group():
                voted_concepts = gr.CheckboxGroup(
                    label="Voted Concepts",
                    choices=CONCEPTS,
                )
                unseen_concepts = gr.CheckboxGroup(
                    label="Previously Unseen Concepts",
                    choices=CONCEPTS,
                )
                tie_concepts = gr.CheckboxGroup(
                    label="Tie Concepts",
                    choices=CONCEPTS,
                )
            with gr.Row():
                next_button = gr.Button(
                    value="Next",
                )
                gr.LoginButton()
                submit_button = gr.Button(
                    value="Submit",
                )
            with gr.Group():
                gr.Markdown(
                    "## # Image Info",
                )
                im_class = gr.Textbox(
                    label="Class",
                )
                im_concepts = gr.JSON(
                    label="Concepts",
                )
        with gr.Column():
            image = gr.Image(
                label="Image",
            )

    # Per-session state.
    # current_image: "split:index" id of the displayed image (None before the first "Next").
    current_image = gr.State(None)
    # filtered_indices: cached candidate indices per split. Initially every
    # index is a candidate. The loop names differ from the `split` component
    # defined above to avoid shadowing it.
    filtered_indices = gr.State({
        split_name: list(range(len(samples)))
        for split_name, samples in global_variables.all_metadata.items()
    })
    # selected_*: the filter that filtered_indices was computed with.
    selected_concepts = gr.State([])
    selected_sample_type = gr.State(None)

    # Order must match the tuple returned by get_next_image / submit_label.
    common_output = [
        image,
        voted_concepts,
        current_image,
        im_class,
        im_concepts,
        unseen_concepts,
        tie_concepts,
        filtered_indices,
        selected_concepts,
        selected_sample_type,
    ]
    # Shared trailing inputs of both handlers. `profile` is not listed because
    # Gradio supplies it from the handlers' gr.OAuthProfile annotation.
    selection_inputs = [split, concepts, sample_type, filtered_indices, selected_concepts, selected_sample_type]
    next_button.click(
        get_next_image,
        inputs=selection_inputs,
        outputs=common_output
    )
    submit_button.click(
        submit_label,
        inputs=[voted_concepts, current_image, *selection_inputs],
        outputs=common_output
    )