Spaces:
Running
on
Zero
Running
on
Zero
from __future__ import annotations | |
import datetime | |
import json | |
import time | |
import uuid | |
from collections import OrderedDict | |
from datetime import datetime, timezone | |
from pathlib import Path | |
from typing import Any | |
import gradio | |
import gradio as gr | |
import huggingface_hub | |
from gradio import FlaggingCallback | |
from gradio_client import utils as client_utils | |
class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
    """Variant of ``gradio.HuggingFaceDatasetSaver`` that records the username.

    ``flag`` embeds the username (when one is given) in the per-sample directory
    name and forwards it to ``_flag_in_dir``; ``_deserialize_components`` appends
    ``flag`` and ``username`` columns to every generated row.

    NOTE(review): ``separate_dirs``, ``dataset_dir``, ``dataset_id``,
    ``components`` and ``_flag_in_dir`` are assumed to be provided by the gradio
    base class -- they are not defined in this file.
    """

    def flag(
        self,
        flag_data: list[Any],
        flag_option: str = "",
        username: str | None = None,
    ) -> int:
        """Pick the target data file for one flagged sample and store the sample.

        Args:
            flag_data: One value per component to store.
            flag_option: Flag label stored alongside the sample.
            username: Name of the flagging user; ``None`` or ``""`` means the
                user is unknown.

        Returns:
            Whatever ``self._flag_in_dir`` returns.
        """
        if self.separate_dirs:
            # JSONL files to support dataset preview on the Hub
            current_utc_time = datetime.now(timezone.utc)
            iso_format_without_microseconds = current_utc_time.strftime(
                "%Y-%m-%dT%H:%M:%S"
            )
            milliseconds = int(current_utc_time.microsecond / 1000)
            unique_id = f"{iso_format_without_microseconds}.{milliseconds:03}Z"
            if username not in (None, ""):
                unique_id += f"_U_{username}"
            else:
                # No username available: use a short random suffix instead.
                unique_id += f"_{str(uuid.uuid4())[:8]}"
            components_dir = self.dataset_dir / unique_id
            data_file = components_dir / "metadata.jsonl"
            path_in_repo = unique_id  # upload in sub folder (safer for concurrency)
        else:
            # Unique CSV file
            components_dir = self.dataset_dir
            data_file = components_dir / "data.csv"
            path_in_repo = None  # upload at root level

        return self._flag_in_dir(
            data_file=data_file,
            components_dir=components_dir,
            path_in_repo=path_in_repo,
            flag_data=flag_data,
            flag_option=flag_option,
            username=username or "",
        )

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Images/audio are saved to disk as individual files.

        Args:
            data_dir: Directory under which one sub-directory per component
                label is created.
            flag_data: Flagged values, paired positionally with
                ``self.components``.
            flag_option: Value written to the ``flag`` column.
            username: Value written to the ``username`` column.

        Returns:
            A ``(features, row)`` tuple: ``features`` maps column names to their
            feature descriptions and ``row`` holds the values in the same order.
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        # Generate the row corresponding to the flagged sample
        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
            label = component.label or ""
            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            deserialized = component.flag(sample, save_dir)

            # Base component .flag method returns JSON; extract path from it when it is FileData
            if component.data_model:
                data = component.data_model.from_json(json.loads(deserialized))
                if component.data_model == gr.data_classes.FileData:
                    deserialized = data.path

            # Add deserialized object to row
            features[label] = {"dtype": "string", "_type": "Value"}
            try:
                deserialized_path = Path(deserialized)
                if not deserialized_path.exists():
                    raise FileNotFoundError(f"File {deserialized} not found")
                row.append(str(deserialized_path.relative_to(self.dataset_dir)))
            except (OSError, TypeError, ValueError):
                # The value is not a file inside the dataset directory: store it
                # as a plain string. OSError (which includes FileNotFoundError)
                # is caught because Path.exists() can re-raise OS errors such as
                # "File name too long" when the value is a long piece of text.
                deserialized = "" if deserialized is None else str(deserialized)
                row.append(deserialized)

            # If component is eligible for a preview, add the URL of the file
            # Be mindful that images and audio can be None
            if isinstance(component, tuple(file_preview_types)):  # type: ignore
                for _component, _type in file_preview_types.items():
                    if isinstance(component, _component):
                        features[label + " file"] = {"_type": _type}
                        break
                if deserialized:
                    path_in_repo = str(  # returned filepath is absolute, we want it relative to compute URL
                        Path(deserialized).relative_to(self.dataset_dir)
                    ).replace("\\", "/")
                    row.append(
                        huggingface_hub.hf_hub_url(
                            repo_id=self.dataset_id,
                            filename=path_in_repo,
                            repo_type="dataset",
                        )
                    )
                else:
                    row.append("")

        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row
class FlagMethod:
    """
    Callable that binds one flagging option to a flagging callback. When invoked it
    forwards the flagged data to the callback and, if visual feedback is enabled,
    returns an updated button describing the outcome.
    """

    def __init__(
        self,
        flagging_callback: FlaggingCallback,
        label: str,
        value: str,
        visual_feedback: bool = True,
    ):
        self.__name__ = "Flag"
        self.flagging_callback = flagging_callback
        self.label, self.value = label, value
        self.visual_feedback = visual_feedback

    def __call__(
        self,
        request: gr.Request,
        profile: gr.OAuthProfile | None,
        *flag_data,
    ):
        """Flag ``flag_data`` with this option on behalf of ``profile`` (if any).

        Returns ``None`` when visual feedback is disabled; otherwise a
        non-interactive button reporting either the error or the success.
        """
        username = profile.username if profile is not None else None

        shared = True
        try:
            self.flagging_callback.flag(
                list(flag_data), flag_option=self.value, username=username
            )
        except Exception as e:
            # Best effort: report the failure instead of propagating it.
            print(f"Error while sharing: {e}")
            shared = False

        if not self.visual_feedback:
            return None
        if not shared:
            return gr.Button(value="Sharing error", interactive=False)

        time.sleep(0.8)  # to provide enough time for the user to observe button change
        return gr.Button(value="Sharing complete", interactive=False)