from collections import OrderedDict
from pathlib import Path
from typing import Any

import gradio as gr
import huggingface_hub
from gradio import utils
from gradio.flagging import HuggingFaceDatasetSaver, client_utils


class myHuggingFaceDatasetSaver(HuggingFaceDatasetSaver):
    """Custom HuggingFaceDatasetSaver to save images/audio to disk.

    Gradio's implementation seems to have a bug; this subclass overrides
    ``_deserialize_components`` with a more tolerant version.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to ``HuggingFaceDatasetSaver``."""
        super().__init__(*args, **kwargs)

    def _deserialize_components(
        self,
        data_dir: Path,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the row for the flagged sample.

        Images/audio are saved to disk as individual files.

        Args:
            data_dir: Directory under which one sub-directory per component
                label is created to hold that component's flagged data.
            flag_data: Flagged values, paired positionally with
                ``self.components``.
            flag_option: Value stored in the trailing ``"flag"`` column.
            username: Value stored in the trailing ``"username"`` column.

        Returns:
            A ``(features, row)`` tuple. ``features`` is an ordered mapping
            from column name to its feature description; ``row`` holds the
            column values in the same order. Audio and image components
            contribute an extra ``"<label> file"`` column holding the hub URL
            of the saved file, or ``""`` when no usable file path exists.
        """
        # Components that can have a preview on dataset repos
        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

        # Generate the row corresponding to the flagged sample
        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            # Get deserialized object (will save sample to disk if
            # applicable -file, audio, image,...-)
            label = component.label or ""
            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
            save_dir.mkdir(exist_ok=True, parents=True)
            if isinstance(component, gr.Chatbot):
                deserialized = sample  # dirty fix
            else:
                deserialized = utils.simplify_file_data_in_str(
                    component.flag(sample, save_dir)
                )

            # Add deserialized object to row
            features[label] = {"dtype": "string", "_type": "Value"}
            try:
                deserialized_path = Path(deserialized)
                if not deserialized_path.exists():
                    raise FileNotFoundError(f"File {deserialized} not found")
                row.append(str(deserialized_path.relative_to(self.dataset_dir)))
            except (AssertionError, TypeError, ValueError, OSError):
                # Not an existing file under the dataset directory (or not
                # path-like at all): store the value as plain text instead.
                deserialized = "" if deserialized is None else str(deserialized)
                row.append(deserialized)

            # If component is eligible for a preview, add the URL of the file
            # Be mindful that images and audio can be None
            preview_type = next(
                (
                    type_name
                    for component_cls, type_name in file_preview_types.items()
                    if isinstance(component, component_cls)  # type: ignore
                ),
                None,
            )
            if preview_type is None:
                continue

            features[label + " file"] = {"_type": preview_type}
            file_url = ""
            if deserialized:
                try:
                    path_in_repo = str(
                        # returned filepath is absolute, we want it relative
                        # to compute URL
                        Path(deserialized).relative_to(self.dataset_dir)
                    ).replace("\\", "/")
                except ValueError:
                    # The value is not a path inside the dataset directory
                    # (e.g. the plain-text fallback above was taken), so there
                    # is no file to link to. Keep the URL column empty rather
                    # than aborting the whole flag operation.
                    pass
                else:
                    file_url = huggingface_hub.hf_hub_url(
                        repo_id=self.dataset_id,
                        filename=path_in_repo,
                        repo_type="dataset",
                    )
            row.append(file_url)

        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row