guangkaixu commited on
Commit
562c833
·
1 Parent(s): 1722ece
gradio_patches/examples.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import gradio
4
+ from gradio.utils import get_cache_folder
5
+
6
+
7
+ class Examples(gradio.helpers.Examples):
8
+ def __init__(self, *args, directory_name=None, **kwargs):
9
+ super().__init__(*args, **kwargs, _initiated_directly=False)
10
+ if directory_name is not None:
11
+ self.cached_folder = get_cache_folder() / directory_name
12
+ self.cached_file = Path(self.cached_folder) / "log.csv"
13
+ self.create()
gradio_patches/flagging.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import json
5
+ import time
6
+ import uuid
7
+ from collections import OrderedDict
8
+ from datetime import datetime, timezone
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ import gradio
13
+ import gradio as gr
14
+ import huggingface_hub
15
+ from gradio import FlaggingCallback
16
+ from gradio_client import utils as client_utils
17
+
18
+
19
+ class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
20
+ def flag(
21
+ self,
22
+ flag_data: list[Any],
23
+ flag_option: str = "",
24
+ username: str | None = None,
25
+ ) -> int:
26
+ if self.separate_dirs:
27
+ # JSONL files to support dataset preview on the Hub
28
+ current_utc_time = datetime.now(timezone.utc)
29
+ iso_format_without_microseconds = current_utc_time.strftime(
30
+ "%Y-%m-%dT%H:%M:%S"
31
+ )
32
+ milliseconds = int(current_utc_time.microsecond / 1000)
33
+ unique_id = f"{iso_format_without_microseconds}.{milliseconds:03}Z"
34
+ if username not in (None, ""):
35
+ unique_id += f"_U_{username}"
36
+ else:
37
+ unique_id += f"_{str(uuid.uuid4())[:8]}"
38
+ components_dir = self.dataset_dir / unique_id
39
+ data_file = components_dir / "metadata.jsonl"
40
+ path_in_repo = unique_id # upload in sub folder (safer for concurrency)
41
+ else:
42
+ # Unique CSV file
43
+ components_dir = self.dataset_dir
44
+ data_file = components_dir / "data.csv"
45
+ path_in_repo = None # upload at root level
46
+
47
+ return self._flag_in_dir(
48
+ data_file=data_file,
49
+ components_dir=components_dir,
50
+ path_in_repo=path_in_repo,
51
+ flag_data=flag_data,
52
+ flag_option=flag_option,
53
+ username=username or "",
54
+ )
55
+
56
+ def _deserialize_components(
57
+ self,
58
+ data_dir: Path,
59
+ flag_data: list[Any],
60
+ flag_option: str = "",
61
+ username: str = "",
62
+ ) -> tuple[dict[Any, Any], list[Any]]:
63
+ """Deserialize components and return the corresponding row for the flagged sample.
64
+ Images/audio are saved to disk as individual files.
65
+ """
66
+ # Components that can have a preview on dataset repos
67
+ file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
68
+
69
+ # Generate the row corresponding to the flagged sample
70
+ features = OrderedDict()
71
+ row = []
72
+ for component, sample in zip(self.components, flag_data):
73
+ # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
74
+ label = component.label or ""
75
+ save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
76
+ save_dir.mkdir(exist_ok=True, parents=True)
77
+ deserialized = component.flag(sample, save_dir)
78
+
79
+ # Base component .flag method returns JSON; extract path from it when it is FileData
80
+ if component.data_model:
81
+ data = component.data_model.from_json(json.loads(deserialized))
82
+ if component.data_model == gr.data_classes.FileData:
83
+ deserialized = data.path
84
+
85
+ # Add deserialized object to row
86
+ features[label] = {"dtype": "string", "_type": "Value"}
87
+ try:
88
+ deserialized_path = Path(deserialized)
89
+ if not deserialized_path.exists():
90
+ raise FileNotFoundError(f"File {deserialized} not found")
91
+ row.append(str(deserialized_path.relative_to(self.dataset_dir)))
92
+ except (FileNotFoundError, TypeError, ValueError):
93
+ deserialized = "" if deserialized is None else str(deserialized)
94
+ row.append(deserialized)
95
+
96
+ # If component is eligible for a preview, add the URL of the file
97
+ # Be mindful that images and audio can be None
98
+ if isinstance(component, tuple(file_preview_types)): # type: ignore
99
+ for _component, _type in file_preview_types.items():
100
+ if isinstance(component, _component):
101
+ features[label + " file"] = {"_type": _type}
102
+ break
103
+ if deserialized:
104
+ path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
105
+ Path(deserialized).relative_to(self.dataset_dir)
106
+ ).replace(
107
+ "\\", "/"
108
+ )
109
+ row.append(
110
+ huggingface_hub.hf_hub_url(
111
+ repo_id=self.dataset_id,
112
+ filename=path_in_repo,
113
+ repo_type="dataset",
114
+ )
115
+ )
116
+ else:
117
+ row.append("")
118
+ features["flag"] = {"dtype": "string", "_type": "Value"}
119
+ features["username"] = {"dtype": "string", "_type": "Value"}
120
+ row.append(flag_option)
121
+ row.append(username)
122
+ return features, row
123
+
124
+
125
+ class FlagMethod:
126
+ """
127
+ Helper class that contains the flagging options and calls the flagging method. Also
128
+ provides visual feedback to the user when flag is clicked.
129
+ """
130
+
131
+ def __init__(
132
+ self,
133
+ flagging_callback: FlaggingCallback,
134
+ label: str,
135
+ value: str,
136
+ visual_feedback: bool = True,
137
+ ):
138
+ self.flagging_callback = flagging_callback
139
+ self.label = label
140
+ self.value = value
141
+ self.__name__ = "Flag"
142
+ self.visual_feedback = visual_feedback
143
+
144
+ def __call__(
145
+ self,
146
+ request: gr.Request,
147
+ profile: gr.OAuthProfile | None,
148
+ *flag_data,
149
+ ):
150
+ username = None
151
+ if profile is not None:
152
+ username = profile.username
153
+ try:
154
+ self.flagging_callback.flag(
155
+ list(flag_data), flag_option=self.value, username=username
156
+ )
157
+ except Exception as e:
158
+ print(f"Error while sharing: {e}")
159
+ if self.visual_feedback:
160
+ return gr.Button(value="Sharing error", interactive=False)
161
+ if not self.visual_feedback:
162
+ return
163
+ time.sleep(0.8) # to provide enough time for the user to observe button change
164
+ return gr.Button(value="Sharing complete", interactive=False)