from gradio.flagging import FlaggingCallback, _get_dataset_features_info from gradio.components import IOComponent from gradio import utils from typing import Any, List, Optional from dotenv import load_dotenv from datetime import datetime import csv, os, pytz # --- Load environments vars --- load_dotenv() # --- Classes declaration --- class DateLogs: def __init__( self, zone: str="America/Argentina/Cordoba" ) -> None: self.time_zone = pytz.timezone(zone) def full( self ) -> str: now = datetime.now(self.time_zone) return now.strftime("%H:%M:%S %d-%m-%Y") def day( self ) -> str: now = datetime.now(self.time_zone) return now.strftime("%d-%m-%Y") class HuggingFaceDatasetSaver(FlaggingCallback): """ A callback that saves each flagged sample (both the input and output data) to a HuggingFace dataset. Example: import gradio as gr hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes") def image_classifier(inp): return {'cat': 0.3, 'dog': 0.7} demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label", allow_flagging="manual", flagging_callback=hf_writer) Guides: using_flagging """ def __init__( self, dataset_name: str=None, hf_token: str=os.getenv('HF_TOKEN'), organization: Optional[str]=os.getenv('ORG_NAME'), private: bool=True, available_logs: bool=False ) -> None: """ Parameters: hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset. dataset_name: The name of the dataset to save the data to, e.g. "image-classifier-1" organization: The organization to save the dataset under. The hf_token must provide write access to this organization. If not provided, saved under the name of the user corresponding to the hf_token. private: Whether the dataset should be private (defaults to False). """ assert(dataset_name is not None), "Error: Parameter 'dataset_name' cannot be empty!." self.hf_token = hf_token self.dataset_name = dataset_name self.organization_name = organization self.dataset_private = private self.datetime = DateLogs() self.available_logs = available_logs if not available_logs: print("Push: logs DISABLED!...") def setup( self, components: List[IOComponent], flagging_dir: str ) -> None: """ Params: flagging_dir (str): local directory where the dataset is cloned, updated, and pushed from. """ if self.available_logs: try: import huggingface_hub except (ImportError, ModuleNotFoundError): raise ImportError( "Package `huggingface_hub` not found is needed " "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'." ) path_to_dataset_repo = huggingface_hub.create_repo( repo_id=os.path.join(self.organization_name, self.dataset_name), token=self.hf_token, private=self.dataset_private, repo_type="dataset", exist_ok=True, ) self.path_to_dataset_repo = path_to_dataset_repo self.components = components self.flagging_dir = flagging_dir self.dataset_dir = self.dataset_name self.repo = huggingface_hub.Repository( local_dir=self.dataset_dir, clone_from=path_to_dataset_repo, use_auth_token=self.hf_token, ) self.repo.git_pull(lfs=True) # Should filename be user-specified? # log_file_name = self.datetime.day()+"_"+self.flagging_dir+".csv" self.log_file = os.path.join(self.dataset_dir, self.flagging_dir+".csv") def flag( self, flag_data: List[Any], flag_option: Optional[str]=None, flag_index: Optional[int]=None, username: Optional[str]=None, ) -> int: if self.available_logs: self.repo.git_pull(lfs=True) is_new = not os.path.exists(self.log_file) with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile: writer = csv.writer(csvfile) # File previews for certain input and output types infos, file_preview_types, headers = _get_dataset_features_info( is_new, self.components ) # Generate the headers and dataset_infos if is_new: headers = [ component.label or f"component {idx}" for idx, component in enumerate(self.components) ] + [ "flag", "username", "timestamp", ] writer.writerow(utils.sanitize_list_for_csv(headers)) # Generate the row corresponding to the flagged sample csv_data = [] for component, sample in zip(self.components, flag_data): save_dir = os.path.join( self.dataset_dir, utils.strip_invalid_filename_characters(component.label), ) filepath = component.deserialize(sample, save_dir, None) csv_data.append(filepath) if isinstance(component, tuple(file_preview_types)): csv_data.append( "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath) ) csv_data.append(flag_option if flag_option is not None else "") csv_data.append(username if username is not None else "") csv_data.append(self.datetime.full()) writer.writerow(utils.sanitize_list_for_csv(csv_data)) with open(self.log_file, "r", encoding="utf-8") as csvfile: line_count = len([None for row in csv.reader(csvfile)]) - 1 self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count)) else: line_count = 0 print("Logs: Virtual push...") return line_count