detection-demo / flagging.py
kevinconka's picture
save flagged data to HF dataset
955daea
raw
history blame
3.38 kB
import json
from collections import OrderedDict
from pathlib import Path
from typing import Any
import gradio as gr
from gradio.flagging import HuggingFaceDatasetSaver, client_utils
import huggingface_hub
class myHuggingFaceDatasetSaver(HuggingFaceDatasetSaver):
    """HuggingFaceDatasetSaver variant that records flagged images/audio as files.

    Only ``_deserialize_components`` is overridden; construction and every
    other behaviour come from ``HuggingFaceDatasetSaver``. The override exists
    because Gradio's own implementation seemed to mishandle image/audio
    samples (per the original author's note).
    """

    @staticmethod
    def _unwrap_image_path(serialized: Any) -> Any:
        """Extract the ``path`` entry from a JSON-encoded image payload.

        Returns ``serialized`` unchanged when it is not a JSON object with a
        ``path`` key (e.g. it is ``None`` or already a plain file path), so an
        unexpected payload never aborts the flagging of the whole sample.
        """
        try:
            payload = json.loads(serialized)
        except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
            return serialized
        if isinstance(payload, dict) and "path" in payload:
            return payload["path"]
        return serialized

    def _path_in_dataset(self, value: Any) -> str:
        """Return ``value`` relative to ``self.dataset_dir``, or ``""``.

        A non-empty result means ``value`` names an existing filesystem entry
        located under ``self.dataset_dir``. Anything else (``None``, free
        text, numbers, paths outside the dataset directory) yields ``""``.
        """
        try:
            if value is None or value == "":
                return ""
            candidate = Path(value)
            if not candidate.exists():
                return ""
            return str(candidate.relative_to(self.dataset_dir))
        except (OSError, TypeError, ValueError):
            # OSError: exists() may raise on some Python versions, e.g. when a
            #          long text value is probed as a file name.
            # TypeError: value is not str / path-like.
            # ValueError: embedded NUL, or path is not under dataset_dir.
            return ""

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Images/audio are saved to disk as individual files.

        Args:
            data_dir: Directory under which one sub-directory per component
                label is created and handed to ``component.flag``.
            flag_data: Flagged values, in the same order as ``self.components``.
            flag_option: Value written to the ``flag`` column.
            username: Value written to the ``username`` column.

        Returns:
            ``(features, row)``: ``features`` is an ordered mapping from column
            name to its dataset feature description, ``row`` the list of cell
            values for this sample. Each component contributes one string
            column; ``gr.Audio`` / ``gr.Image`` components additionally
            contribute a ``"<label> file"`` preview column holding the file's
            hub URL (or ``""`` when no file was saved).
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        # Generate the row corresponding to the flagged sample
        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            # Get deserialized object (will save sample to disk if applicable
            # -file, audio, image,...-)
            label = component.label or ""
            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            deserialized = component.flag(sample, save_dir)
            if isinstance(component, gr.Image) and isinstance(sample, dict):
                # HACK: for dict samples the image's flag() output is a JSON
                # document; only its "path" entry is of interest here.
                deserialized = self._unwrap_image_path(deserialized)

            # Add deserialized object to row: the dataset-relative path when a
            # file was saved, otherwise the value's string form.
            features[label] = {"dtype": "string", "_type": "Value"}
            relative_path = self._path_in_dataset(deserialized)
            if relative_path:
                row.append(relative_path)
            else:
                row.append("" if deserialized is None else str(deserialized))

            # If component is eligible for a preview, add the URL of the file.
            # Be mindful that images and audio can be None.
            preview_type = next(
                (
                    type_name
                    for component_cls, type_name in file_preview_types.items()
                    if isinstance(component, component_cls)
                ),
                None,
            )
            if preview_type is None:
                continue
            features[label + " file"] = {"_type": preview_type}
            if relative_path:
                row.append(
                    huggingface_hub.hf_hub_url(
                        repo_id=self.dataset_id,
                        # Repo paths always use forward slashes.
                        filename=relative_path.replace("\\", "/"),
                        repo_type="dataset",
                    )
                )
            else:
                row.append("")

        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row