Spaces:
Running
Running
import json | |
from collections import OrderedDict | |
from pathlib import Path | |
from typing import Any | |
import gradio as gr | |
from gradio.flagging import HuggingFaceDatasetSaver, client_utils | |
import huggingface_hub | |
class myHuggingFaceDatasetSaver(HuggingFaceDatasetSaver):
    """
    Custom HuggingFaceDatasetSaver to save images/audio to disk.

    Gradio's implementation seems to have a bug, so
    ``_deserialize_components`` is overridden here.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``HuggingFaceDatasetSaver``."""
        super().__init__(*args, **kwargs)

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Images/audio are saved to disk as individual files.

        Args:
            data_dir: Directory under which one sub-directory per component
                label is created to hold the flagged files.
            flag_data: Flagged values, paired positionally with
                ``self.components``.
            flag_option: Value stored in the ``"flag"`` column.
            username: Value stored in the ``"username"`` column.

        Returns:
            A ``(features, row)`` pair. ``features`` is an ordered mapping of
            column name to feature description; ``row`` holds the values in
            the same order. Audio/Image components contribute an extra
            ``"<label> file"`` column holding the file URL on the Hub, or
            ``""`` when no URL can be computed.
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        # Generate the row corresponding to the flagged sample
        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
            label = component.label or ""
            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            deserialized = component.flag(sample, save_dir)
            if isinstance(component, gr.Image) and isinstance(sample, dict):
                # dirty hack: for dict samples the flagged value is parsed as
                # JSON and only its "path" entry is kept.
                deserialized = json.loads(deserialized)["path"]

            # Add deserialized object to row
            features[label] = {"dtype": "string", "_type": "Value"}
            try:
                candidate = Path(deserialized)
                # Explicit check instead of `assert`, which is stripped under -O.
                relative = (
                    candidate.relative_to(self.dataset_dir)
                    if candidate.exists()
                    else None
                )
            except (OSError, TypeError, ValueError):
                # OSError: value cannot be probed as a path (e.g. name too long).
                # TypeError: value is None or not path-like.
                # ValueError: path is not located under the dataset directory.
                relative = None

            if relative is not None:
                row.append(str(relative))
            else:
                # Not an existing file inside the dataset: store the raw value.
                deserialized = "" if deserialized is None else str(deserialized)
                row.append(deserialized)

            # If component is eligible for a preview, add the URL of the file
            # Be mindful that images and audio can be None
            preview_type = next(
                (
                    _type
                    for _component, _type in file_preview_types.items()
                    if isinstance(component, _component)
                ),
                None,
            )
            if preview_type is None:
                continue

            features[label + " file"] = {"_type": preview_type}
            preview_url = ""
            if deserialized:
                try:
                    # returned filepath is absolute, we want it relative to compute URL
                    path_in_repo = str(
                        Path(deserialized).relative_to(self.dataset_dir)
                    ).replace("\\", "/")
                except (TypeError, ValueError):
                    # Value is not a path inside the dataset directory, so no
                    # URL can be built; keep the row aligned with `features`
                    # by storing an empty URL instead of aborting the flag.
                    pass
                else:
                    preview_url = huggingface_hub.hf_hub_url(
                        repo_id=self.dataset_id,
                        filename=path_in_repo,
                        repo_type="dataset",
                    )
            row.append(preview_url)

        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row