Spaces:
Running
on
Zero
Running
on
Zero
guangkaixu
commited on
Commit
·
562c833
1
Parent(s):
1722ece
upload
Browse files- gradio_patches/examples.py +13 -0
- gradio_patches/flagging.py +164 -0
gradio_patches/examples.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import gradio
|
4 |
+
from gradio.utils import get_cache_folder
|
5 |
+
|
6 |
+
|
7 |
+
class Examples(gradio.helpers.Examples):
|
8 |
+
def __init__(self, *args, directory_name=None, **kwargs):
|
9 |
+
super().__init__(*args, **kwargs, _initiated_directly=False)
|
10 |
+
if directory_name is not None:
|
11 |
+
self.cached_folder = get_cache_folder() / directory_name
|
12 |
+
self.cached_file = Path(self.cached_folder) / "log.csv"
|
13 |
+
self.create()
|
gradio_patches/flagging.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import time
|
6 |
+
import uuid
|
7 |
+
from collections import OrderedDict
|
8 |
+
from datetime import datetime, timezone
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Any
|
11 |
+
|
12 |
+
import gradio
|
13 |
+
import gradio as gr
|
14 |
+
import huggingface_hub
|
15 |
+
from gradio import FlaggingCallback
|
16 |
+
from gradio_client import utils as client_utils
|
17 |
+
|
18 |
+
|
19 |
+
class HuggingFaceDatasetSaver(gradio.HuggingFaceDatasetSaver):
|
20 |
+
def flag(
|
21 |
+
self,
|
22 |
+
flag_data: list[Any],
|
23 |
+
flag_option: str = "",
|
24 |
+
username: str | None = None,
|
25 |
+
) -> int:
|
26 |
+
if self.separate_dirs:
|
27 |
+
# JSONL files to support dataset preview on the Hub
|
28 |
+
current_utc_time = datetime.now(timezone.utc)
|
29 |
+
iso_format_without_microseconds = current_utc_time.strftime(
|
30 |
+
"%Y-%m-%dT%H:%M:%S"
|
31 |
+
)
|
32 |
+
milliseconds = int(current_utc_time.microsecond / 1000)
|
33 |
+
unique_id = f"{iso_format_without_microseconds}.{milliseconds:03}Z"
|
34 |
+
if username not in (None, ""):
|
35 |
+
unique_id += f"_U_{username}"
|
36 |
+
else:
|
37 |
+
unique_id += f"_{str(uuid.uuid4())[:8]}"
|
38 |
+
components_dir = self.dataset_dir / unique_id
|
39 |
+
data_file = components_dir / "metadata.jsonl"
|
40 |
+
path_in_repo = unique_id # upload in sub folder (safer for concurrency)
|
41 |
+
else:
|
42 |
+
# Unique CSV file
|
43 |
+
components_dir = self.dataset_dir
|
44 |
+
data_file = components_dir / "data.csv"
|
45 |
+
path_in_repo = None # upload at root level
|
46 |
+
|
47 |
+
return self._flag_in_dir(
|
48 |
+
data_file=data_file,
|
49 |
+
components_dir=components_dir,
|
50 |
+
path_in_repo=path_in_repo,
|
51 |
+
flag_data=flag_data,
|
52 |
+
flag_option=flag_option,
|
53 |
+
username=username or "",
|
54 |
+
)
|
55 |
+
|
56 |
+
def _deserialize_components(
|
57 |
+
self,
|
58 |
+
data_dir: Path,
|
59 |
+
flag_data: list[Any],
|
60 |
+
flag_option: str = "",
|
61 |
+
username: str = "",
|
62 |
+
) -> tuple[dict[Any, Any], list[Any]]:
|
63 |
+
"""Deserialize components and return the corresponding row for the flagged sample.
|
64 |
+
Images/audio are saved to disk as individual files.
|
65 |
+
"""
|
66 |
+
# Components that can have a preview on dataset repos
|
67 |
+
file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
|
68 |
+
|
69 |
+
# Generate the row corresponding to the flagged sample
|
70 |
+
features = OrderedDict()
|
71 |
+
row = []
|
72 |
+
for component, sample in zip(self.components, flag_data):
|
73 |
+
# Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
|
74 |
+
label = component.label or ""
|
75 |
+
save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
|
76 |
+
save_dir.mkdir(exist_ok=True, parents=True)
|
77 |
+
deserialized = component.flag(sample, save_dir)
|
78 |
+
|
79 |
+
# Base component .flag method returns JSON; extract path from it when it is FileData
|
80 |
+
if component.data_model:
|
81 |
+
data = component.data_model.from_json(json.loads(deserialized))
|
82 |
+
if component.data_model == gr.data_classes.FileData:
|
83 |
+
deserialized = data.path
|
84 |
+
|
85 |
+
# Add deserialized object to row
|
86 |
+
features[label] = {"dtype": "string", "_type": "Value"}
|
87 |
+
try:
|
88 |
+
deserialized_path = Path(deserialized)
|
89 |
+
if not deserialized_path.exists():
|
90 |
+
raise FileNotFoundError(f"File {deserialized} not found")
|
91 |
+
row.append(str(deserialized_path.relative_to(self.dataset_dir)))
|
92 |
+
except (FileNotFoundError, TypeError, ValueError):
|
93 |
+
deserialized = "" if deserialized is None else str(deserialized)
|
94 |
+
row.append(deserialized)
|
95 |
+
|
96 |
+
# If component is eligible for a preview, add the URL of the file
|
97 |
+
# Be mindful that images and audio can be None
|
98 |
+
if isinstance(component, tuple(file_preview_types)): # type: ignore
|
99 |
+
for _component, _type in file_preview_types.items():
|
100 |
+
if isinstance(component, _component):
|
101 |
+
features[label + " file"] = {"_type": _type}
|
102 |
+
break
|
103 |
+
if deserialized:
|
104 |
+
path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
|
105 |
+
Path(deserialized).relative_to(self.dataset_dir)
|
106 |
+
).replace(
|
107 |
+
"\\", "/"
|
108 |
+
)
|
109 |
+
row.append(
|
110 |
+
huggingface_hub.hf_hub_url(
|
111 |
+
repo_id=self.dataset_id,
|
112 |
+
filename=path_in_repo,
|
113 |
+
repo_type="dataset",
|
114 |
+
)
|
115 |
+
)
|
116 |
+
else:
|
117 |
+
row.append("")
|
118 |
+
features["flag"] = {"dtype": "string", "_type": "Value"}
|
119 |
+
features["username"] = {"dtype": "string", "_type": "Value"}
|
120 |
+
row.append(flag_option)
|
121 |
+
row.append(username)
|
122 |
+
return features, row
|
123 |
+
|
124 |
+
|
125 |
+
class FlagMethod:
|
126 |
+
"""
|
127 |
+
Helper class that contains the flagging options and calls the flagging method. Also
|
128 |
+
provides visual feedback to the user when flag is clicked.
|
129 |
+
"""
|
130 |
+
|
131 |
+
def __init__(
|
132 |
+
self,
|
133 |
+
flagging_callback: FlaggingCallback,
|
134 |
+
label: str,
|
135 |
+
value: str,
|
136 |
+
visual_feedback: bool = True,
|
137 |
+
):
|
138 |
+
self.flagging_callback = flagging_callback
|
139 |
+
self.label = label
|
140 |
+
self.value = value
|
141 |
+
self.__name__ = "Flag"
|
142 |
+
self.visual_feedback = visual_feedback
|
143 |
+
|
144 |
+
def __call__(
|
145 |
+
self,
|
146 |
+
request: gr.Request,
|
147 |
+
profile: gr.OAuthProfile | None,
|
148 |
+
*flag_data,
|
149 |
+
):
|
150 |
+
username = None
|
151 |
+
if profile is not None:
|
152 |
+
username = profile.username
|
153 |
+
try:
|
154 |
+
self.flagging_callback.flag(
|
155 |
+
list(flag_data), flag_option=self.value, username=username
|
156 |
+
)
|
157 |
+
except Exception as e:
|
158 |
+
print(f"Error while sharing: {e}")
|
159 |
+
if self.visual_feedback:
|
160 |
+
return gr.Button(value="Sharing error", interactive=False)
|
161 |
+
if not self.visual_feedback:
|
162 |
+
return
|
163 |
+
time.sleep(0.8) # to provide enough time for the user to observe button change
|
164 |
+
return gr.Button(value="Sharing complete", interactive=False)
|