kevinconka committed on
Commit
7091ddc
·
1 Parent(s): 94be7bc

updated flagger and gradio

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +13 -15
  3. flagging.py +291 -14
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 💬
4
  colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 4.37.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
4
  colorFrom: blue
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 5.0.0b3
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py CHANGED
@@ -3,8 +3,7 @@ import gradio as gr
3
  from huggingface_hub import get_token
4
 
5
  from chatbot import get_retrieval_qa
6
- from flagging import myHuggingFaceDatasetSaver as HuggingFaceDatasetSaver
7
- #from gradio.flagging import HuggingFaceDatasetSaver
8
 
9
 
10
  # get the html data and save it to a file
@@ -27,6 +26,7 @@ hf_writer = HuggingFaceDatasetSaver(get_token(), dataset_name)
27
 
28
 
29
  def answer_question(message, history, system):
 
30
  # concatenate the history, message and system
31
  query = " ".join([message, system])
32
  retrieval_qa = qa.invoke(query)
@@ -34,7 +34,7 @@ def answer_question(message, history, system):
34
  result = result.replace('"', "").strip() # clean up the result
35
 
36
  # save the query and result to the dataset
37
- hf_writer.flag(flag_data=[query, result])
38
  return result
39
 
40
 
@@ -62,15 +62,15 @@ theme = gr.themes.Default(primary_hue=gr.themes.colors.indigo)
62
 
63
  chatbot = gr.Chatbot(
64
  value=[
65
- [
66
- None,
67
- "I have memorized the entire SEA.AI FAQ page. Ask me anything about it! 🧠",
68
- ],
69
  ],
70
  label="SEA Dog",
 
71
  show_label=False,
72
  show_copy_button=True,
73
- likeable=True,
74
  )
75
 
76
 
@@ -80,19 +80,17 @@ def on_like(evt: gr.LikeData):
80
 
81
  with gr.ChatInterface(
82
  answer_question,
 
83
  chatbot=chatbot,
84
  title=title,
85
  description=description,
86
  additional_inputs=[gr.Textbox("", label="SYSTEM")],
87
- examples=[
88
- ["Can SEA.AI see at night?", "You are a helpful assistant."],
89
- ["Can SEA.AI see at night?", "Reply with sailor slang."],
90
- ],
91
  cache_examples=False,
92
  submit_btn=None,
93
- retry_btn=None,
94
- undo_btn=None,
95
- clear_btn=None,
96
  css=css,
97
  theme=theme,
98
  ) as demo:
 
3
  from huggingface_hub import get_token
4
 
5
  from chatbot import get_retrieval_qa
6
+ from flagging import HuggingFaceDatasetSaver
 
7
 
8
 
9
  # get the html data and save it to a file
 
26
 
27
 
28
  def answer_question(message, history, system):
29
+ print(f"{message=}, {history=}, {system=}")
30
  # concatenate the history, message and system
31
  query = " ".join([message, system])
32
  retrieval_qa = qa.invoke(query)
 
34
  result = result.replace('"', "").strip() # clean up the result
35
 
36
  # save the query and result to the dataset
37
+ hf_writer.flag(flag_data=[query, [dict(role="assistant", content=result)]])
38
  return result
39
 
40
 
 
62
 
63
  chatbot = gr.Chatbot(
64
  value=[
65
+ gr.ChatMessage(
66
+ role="assistant",
67
+ content="I have memorized the entire SEA.AI FAQ page. Ask me anything about it! 🧠",
68
+ ),
69
  ],
70
  label="SEA Dog",
71
+ type="messages",
72
  show_label=False,
73
  show_copy_button=True,
 
74
  )
75
 
76
 
 
80
 
81
  with gr.ChatInterface(
82
  answer_question,
83
+ type=chatbot.type,
84
  chatbot=chatbot,
85
  title=title,
86
  description=description,
87
  additional_inputs=[gr.Textbox("", label="SYSTEM")],
88
+ # examples=[
89
+ # ["Can SEA.AI see at night?", "You are a helpful assistant."],
90
+ # ["Can SEA.AI see at night?", "Reply with sailor slang."],
91
+ # ],
92
  cache_examples=False,
93
  submit_btn=None,
 
 
 
94
  css=css,
95
  theme=theme,
96
  ) as demo:
flagging.py CHANGED
@@ -1,20 +1,226 @@
 
 
 
1
  from collections import OrderedDict
2
  from pathlib import Path
3
- from typing import Any
 
 
 
 
4
  import gradio as gr
5
- from gradio.flagging import HuggingFaceDatasetSaver, client_utils
6
  from gradio import utils
7
- import huggingface_hub
 
 
8
 
9
 
10
- class myHuggingFaceDatasetSaver(HuggingFaceDatasetSaver):
 
11
  """
12
- Custom HuggingFaceDatasetSaver to save images/audio to disk.
13
- Gradio's implementation seems to have a bug.
 
 
 
 
 
 
 
 
14
  """
15
 
16
- def __init__(self, *args, **kwargs):
17
- super().__init__(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def _deserialize_components(
20
  self,
@@ -38,12 +244,12 @@ class myHuggingFaceDatasetSaver(HuggingFaceDatasetSaver):
38
  label = component.label or ""
39
  save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
40
  save_dir.mkdir(exist_ok=True, parents=True)
 
 
 
41
  if isinstance(component, gr.Chatbot):
42
- deserialized = sample # dirty fix
43
- else:
44
- deserialized = utils.simplify_file_data_in_str(
45
- component.flag(sample, save_dir)
46
- )
47
 
48
  # Add deserialized object to row
49
  features[label] = {"dtype": "string", "_type": "Value"}
@@ -52,7 +258,78 @@ class myHuggingFaceDatasetSaver(HuggingFaceDatasetSaver):
52
  if not deserialized_path.exists():
53
  raise FileNotFoundError(f"File {deserialized} not found")
54
  row.append(str(deserialized_path.relative_to(self.dataset_dir)))
55
- except (AssertionError, TypeError, ValueError, OSError):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  deserialized = "" if deserialized is None else str(deserialized)
57
  row.append(deserialized)
58
 
 
1
+ import csv
2
+ import json
3
+ import uuid
4
  from collections import OrderedDict
5
  from pathlib import Path
6
+ from typing import Any, Sequence
7
+
8
+ import filelock
9
+ import huggingface_hub
10
+
11
  import gradio as gr
 
12
  from gradio import utils
13
+ from gradio.flagging import client_utils, FlaggingCallback
14
+ from gradio_client.documentation import document
15
+ from gradio.components import Component
16
 
17
 
18
+ @document()
19
+ class HuggingFaceDatasetSaver(FlaggingCallback):
20
  """
21
+ A callback that saves each flagged sample (both the input and output data) to a HuggingFace dataset.
22
+
23
+ Example:
24
+ import gradio as gr
25
+ hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
26
+ def image_classifier(inp):
27
+ return {'cat': 0.3, 'dog': 0.7}
28
+ demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
29
+ allow_flagging="manual", flagging_callback=hf_writer)
30
+ Guides: using-flagging
31
  """
32
 
33
def __init__(
    self,
    hf_token: str,
    dataset_name: str,
    private: bool = False,
    info_filename: str = "dataset_info.json",
    separate_dirs: bool = False,
):
    """
    Record the settings used later by `setup()` and `flag()`; nothing is contacted or written here.

    Parameters:
        hf_token: The HuggingFace token used to create (and write the flagged samples to) the HuggingFace dataset.
        dataset_name: The repo_id of the dataset to save the data to, e.g. "image-classifier-1" or "username/image-classifier-1".
        private: Whether the dataset should be private (defaults to False).
        info_filename: The name of the file to save the dataset info (defaults to "dataset_info.json").
        separate_dirs: If True, each flagged item will be saved in a separate directory. This makes the flagging more robust to concurrent editing, but may be less convenient to use.
    """
    # How flagged samples are laid out locally and in the repo.
    self.separate_dirs = separate_dirs
    self.info_filename = info_filename

    # Hub credentials and target repository.
    self.dataset_private = private
    self.hf_token = hf_token
    # TODO: rename parameter (but ensure backward compatibility somehow)
    self.dataset_id = dataset_name
54
+
55
def setup(self, components: Sequence[Component], flagging_dir: str):
    """
    Prepare the Hub dataset repo and the local flagging directory.

    Creates the dataset repo if it does not exist yet, points its "default"
    config at the flagged data files, creates the local dataset directory and
    downloads the files that already exist remotely.

    Params:
    components: the components whose values are passed to `flag()`.
    flagging_dir (str): local directory where the dataset is cloned,
    updated, and pushed from.
    """
    # Setup dataset on the Hub (exist_ok=True: reuse the repo when it is already there).
    repo_url = huggingface_hub.create_repo(
        repo_id=self.dataset_id,
        token=self.hf_token,
        private=self.dataset_private,
        repo_type="dataset",
        exist_ok=True,
    )
    # Keep the repo_id reported by the Hub rather than the name given by the user.
    self.dataset_id = repo_url.repo_id

    # One JSONL file per sample directory, or a single CSV at the repo root.
    data_glob = "data.csv" if not self.separate_dirs else "**/*.jsonl"
    default_config = {
        "config_name": "default",
        "data_files": [{"split": "train", "path": data_glob}],
    }
    huggingface_hub.metadata_update(
        repo_id=self.dataset_id,
        repo_type="dataset",
        metadata={"configs": [default_config]},
        overwrite=True,
        token=self.hf_token,
    )

    # Setup flagging dir: <flagging_dir>/<repo name without the namespace>
    self.components = components
    repo_name = self.dataset_id.split("/")[-1]
    self.dataset_dir = Path(flagging_dir).absolute() / repo_name
    self.dataset_dir.mkdir(parents=True, exist_ok=True)
    self.infos_file = self.dataset_dir / self.info_filename

    # Download remote files to local. Without separate dirs all data lives in
    # one CSV file, so it is fetched as well to get its current content.
    wanted = [self.info_filename]
    if not self.separate_dirs:
        wanted.append("data.csv")

    for remote_name in wanted:
        try:
            huggingface_hub.hf_hub_download(
                repo_id=self.dataset_id,
                repo_type="dataset",
                filename=remote_name,
                local_dir=self.dataset_dir,
                token=self.hf_token,
            )
        except huggingface_hub.utils.EntryNotFoundError:
            # The file is not in the repo (yet): nothing to download.
            continue
110
+
111
+ def flag(
112
+ self,
113
+ flag_data: list[Any],
114
+ flag_option: str = "",
115
+ username: str | None = None,
116
+ ) -> int:
117
+ if self.separate_dirs:
118
+ # JSONL files to support dataset preview on the Hub
119
+ unique_id = str(uuid.uuid4())
120
+ components_dir = self.dataset_dir / unique_id
121
+ data_file = components_dir / "metadata.jsonl"
122
+ path_in_repo = unique_id # upload in sub folder (safer for concurrency)
123
+ else:
124
+ # Unique CSV file
125
+ components_dir = self.dataset_dir
126
+ data_file = components_dir / "data.csv"
127
+ path_in_repo = None # upload at root level
128
+
129
+ return self._flag_in_dir(
130
+ data_file=data_file,
131
+ components_dir=components_dir,
132
+ path_in_repo=path_in_repo,
133
+ flag_data=flag_data,
134
+ flag_option=flag_option,
135
+ username=username or "",
136
+ )
137
+
138
def _flag_in_dir(
    self,
    data_file: Path,
    components_dir: Path,
    path_in_repo: str | None,
    flag_data: list[Any],
    flag_option: str = "",
    username: str = "",
) -> int:
    """
    Write one flagged sample under `components_dir` and upload it to the dataset repo.

    Parameters:
        data_file: the CSV or JSONL file the sample row is written to.
        components_dir: the local folder that is uploaded afterwards.
        path_in_repo: destination folder in the repo (None uploads at root level).
        flag_data: the values to store for this sample.
        flag_option: the flag label stored alongside the sample.
        username: the user who flagged the sample.
    Returns:
        In CSV mode, the row number returned by `_save_as_csv`; in
        separate-dirs mode, the number of directories inside the dataset dir.
    """

    def upload_components(sample_name: str) -> None:
        # Push the whole components folder, leaving the lock files out.
        huggingface_hub.upload_folder(
            repo_id=self.dataset_id,
            repo_type="dataset",
            commit_message=f"Flagged sample #{sample_name}",
            path_in_repo=path_in_repo,
            ignore_patterns="*.lock",
            folder_path=components_dir,
            token=self.hf_token,
        )

    # Deserialize components (write images/audio to files)
    features, row = self._deserialize_components(
        components_dir, flag_data, flag_option, username
    )

    # Write generic info to the dataset info file + upload.
    # The file is only created when missing, but it is uploaded on every call.
    infos_path = self.infos_file
    with filelock.FileLock(str(infos_path) + ".lock"):
        if not infos_path.exists():
            infos_path.write_text(json.dumps({"flagged": {"features": features}}))

        huggingface_hub.upload_file(
            repo_id=self.dataset_id,
            repo_type="dataset",
            token=self.hf_token,
            path_in_repo=infos_path.name,
            path_or_fileobj=infos_path,
        )

    headers = list(features.keys())

    if self.separate_dirs:
        # One JSONL file in a uuid-named folder: no lock is taken here.
        sample_name = self._save_as_jsonl(data_file, headers=headers, row=row)
        sample_nb = sum(1 for entry in self.dataset_dir.iterdir() if entry.is_dir())
        upload_components(sample_name)
        return sample_nb

    # Shared CSV file: append and upload while holding the folder lock.
    with filelock.FileLock(components_dir / ".lock"):
        sample_nb = self._save_as_csv(data_file, headers=headers, row=row)
        upload_components(str(sample_nb))
    return sample_nb
198
+
199
@staticmethod
def _save_as_csv(data_file: Path, headers: list[str], row: list[Any]) -> int:
    """Append the sample to the CSV file and return the sample name (row number).

    The header row is written first when the file does not exist yet. The
    returned number is the count of CSV rows in the file minus one.
    """
    needs_header = not data_file.exists()

    with data_file.open("a", newline="", encoding="utf-8") as sink:
        write_row = csv.writer(sink).writerow

        # Write CSV headers if new file
        if needs_header:
            write_row(utils.sanitize_list_for_csv(headers))

        # Write CSV row for flagged sample
        write_row(utils.sanitize_list_for_csv(row))

    # Re-read the file: the index of its last row equals (row count - 1).
    last_index = -1
    with data_file.open(encoding="utf-8") as source:
        for last_index, _ in enumerate(csv.reader(source)):
            pass
    return last_index
216
+
217
@staticmethod
def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str:
    """Write the sample as one JSON object and return the sample name.

    The parent directory of `data_file` is created when missing, any existing
    file is overwritten, and the returned name is that directory's name.
    """
    sample_dir = data_file.parent
    sample_dir.mkdir(parents=True, exist_ok=True)
    with data_file.open("w", encoding="utf-8") as sink:
        # Pair every header with the value at the same position in the row.
        json.dump(dict(zip(headers, row)), sink)
    return sample_dir.name
224
 
225
  def _deserialize_components(
226
  self,
 
244
  label = component.label or ""
245
  save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
246
  save_dir.mkdir(exist_ok=True, parents=True)
247
+ deserialized = utils.simplify_file_data_in_str(
248
+ component.flag(sample, save_dir)
249
+ )
250
  if isinstance(component, gr.Chatbot):
251
+ messages = json.loads(deserialized)
252
+ deserialized = [msg.get("content") for msg in messages if msg.get("role") == "assistant"][0]
 
 
 
253
 
254
  # Add deserialized object to row
255
  features[label] = {"dtype": "string", "_type": "Value"}
 
258
  if not deserialized_path.exists():
259
  raise FileNotFoundError(f"File {deserialized} not found")
260
  row.append(str(deserialized_path.relative_to(self.dataset_dir)))
261
+ except (FileNotFoundError, TypeError, ValueError, OSError):
262
+ deserialized = "" if deserialized is None else str(deserialized)
263
+ row.append(deserialized)
264
+
265
+ # If component is eligible for a preview, add the URL of the file
266
+ # Be mindful that images and audio can be None
267
+ if isinstance(component, tuple(file_preview_types)): # type: ignore
268
+ for _component, _type in file_preview_types.items():
269
+ if isinstance(component, _component):
270
+ features[label + " file"] = {"_type": _type}
271
+ break
272
+ if deserialized:
273
+ path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
274
+ Path(deserialized).relative_to(self.dataset_dir)
275
+ ).replace("\\", "/")
276
+ row.append(
277
+ huggingface_hub.hf_hub_url(
278
+ repo_id=self.dataset_id,
279
+ filename=path_in_repo,
280
+ repo_type="dataset",
281
+ )
282
+ )
283
+ else:
284
+ row.append("")
285
+ features["flag"] = {"dtype": "string", "_type": "Value"}
286
+ features["username"] = {"dtype": "string", "_type": "Value"}
287
+ row.append(flag_option)
288
+ row.append(username)
289
+ return features, row
290
+
291
+
292
+ class MyHuggingFaceDatasetSaver(HuggingFaceDatasetSaver):
293
+ """
294
+ Custom HuggingFaceDatasetSaver to save images/audio to disk.
295
+ Gradio's implementation seems to have a bug.
296
+ """
297
+
298
+ def __init__(self, *args, **kwargs):
299
+ super().__init__(*args, **kwargs)
300
+
301
+ def _deserialize_components(
302
+ self,
303
+ data_dir: Path,
304
+ flag_data: list[Any],
305
+ flag_option: str = "",
306
+ username: str = "",
307
+ ) -> tuple[dict[Any, Any], list[Any]]:
308
+ """Deserialize components and return the corresponding row for the flagged sample.
309
+
310
+ Images/audio are saved to disk as individual files.
311
+ """
312
+ # Components that can have a preview on dataset repos
313
+ file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
314
+
315
+ # Generate the row corresponding to the flagged sample
316
+ features = OrderedDict()
317
+ row = []
318
+ for component, sample in zip(self.components, flag_data):
319
+ # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
320
+ label = component.label or ""
321
+ save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
322
+ save_dir.mkdir(exist_ok=True, parents=True)
323
+ deserialized = component.flag(sample, save_dir)
324
+ if isinstance(component, gr.Image) and isinstance(sample, dict):
325
+ deserialized = json.loads(deserialized)["path"] # dirty hack
326
+
327
+ # Add deserialized object to row
328
+ features[label] = {"dtype": "string", "_type": "Value"}
329
+ try:
330
+ assert Path(deserialized).exists()
331
+ row.append(str(Path(deserialized).relative_to(self.dataset_dir)))
332
+ except (AssertionError, TypeError, ValueError):
333
  deserialized = "" if deserialized is None else str(deserialized)
334
  row.append(deserialized)
335