CannaTech committed
Commit 03435e9 · 1 Parent(s): 5382d52

Update app.py

Files changed (1)
  1. app.py +3 -169
app.py CHANGED
@@ -1,181 +1,15 @@
-# Import the necessary libraries
-from __future__ import annotations
 import os
 import openai
 import gradio as gr
 import csv
 import json
-import io
-import uuid
-import datetime
-from abc import ABC, abstractmethod
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, List
-from gradio.documentation import document, set_documentation_group
-from gradio.components import IOComponent
-
 
 # Set the OpenAI API key
 openai.api_key = os.getenv("OPENAI_API_KEY")
 
-##################
-
-class FlaggingCallback(ABC):
-    """
-    An abstract class for defining the methods that any FlaggingCallback should have.
-    """
-
-    @abstractmethod
-    def setup(self, components: List[IOComponent], flagging_dir: str):
-        """
-        This method should be overridden and ensure that everything is set up correctly for flag().
-        This method gets called once at the beginning of the Interface.launch() method.
-        Parameters:
-            components: Set of components that will provide flagged data.
-            flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
-        """
-        pass
-
-    @abstractmethod
-    def flag(
-        self,
-        flag_data: List[Any],
-        flag_option: str | None = None,
-        flag_index: int | None = None,
-        username: str | None = None,
-    ) -> int:
-        """
-        This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
-        This gets called every time the <flag> button is pressed.
-        Parameters:
-            interface: The Interface object that is being used to launch the flagging interface.
-            flag_data: The data to be flagged.
-            flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
-            flag_index (optional): The index of the sample that is being flagged.
-            username (optional): The username of the user that is flagging the data, if logged in.
-        Returns:
-            (int) The total number of samples that have been flagged.
-        """
-        pass
-
-
-class HuggingFaceDatasetSaver(FlaggingCallback):
-    """
-    A callback that saves each flagged sample (both the input and output data)
-    to a HuggingFace dataset.
-    Example:
-        import gradio as gr
-        hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
-        def image_classifier(inp):
-            return {'cat': 0.3, 'dog': 0.7}
-        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
-                            allow_flagging="manual", flagging_callback=hf_writer)
-    Guides: using_flagging
-    """
-
-    def __init__(
-        self,
-        hf_token: str,
-        dataset_name: str,
-        organization: str | None = None,
-        private: bool = False,
-    ):
-        """
-        Parameters:
-            hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset.
-            dataset_name: The name of the dataset to save the data to, e.g. "image-classifier-1"
-            organization: The organization to save the dataset under. The hf_token must provide write access to this organization. If not provided, saved under the name of the user corresponding to the hf_token.
-            private: Whether the dataset should be private (defaults to False).
-        """
-        self.hf_token = hf_token
-        self.dataset_name = dataset_name
-        self.organization_name = organization
-        self.dataset_private = private
-
-    def setup(self, components: List[IOComponent], flagging_dir: str):
-        """
-        Params:
-            flagging_dir (str): local directory where the dataset is cloned,
-            updated, and pushed from.
-        """
-        try:
-            import huggingface_hub
-        except (ImportError, ModuleNotFoundError):
-            raise ImportError(
-                "Package `huggingface_hub` not found is needed "
-                "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
-            )
-        path_to_dataset_repo = huggingface_hub.create_repo(
-            name=self.dataset_name,
-            token=self.hf_token,
-            private=self.dataset_private,
-            repo_type="dataset",
-            exist_ok=True,
-        )
-        self.path_to_dataset_repo = path_to_dataset_repo  # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
-        self.components = components
-        self.flagging_dir = flagging_dir
-        self.dataset_dir = Path(flagging_dir) / self.dataset_name
-        self.repo = huggingface_hub.Repository(
-            local_dir=str(self.dataset_dir),
-            clone_from=path_to_dataset_repo,
-            use_auth_token=self.hf_token,
-        )
-        self.repo.git_pull(lfs=True)
-
-        # Should filename be user-specified?
-        self.log_file = Path(self.dataset_dir) / "data.csv"
-        self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"
-
-    def flag(
-        self,
-        flag_data: List[Any],
-        flag_option: str | None = None,
-        flag_index: int | None = None,
-        username: str | None = None,
-    ) -> int:
-        self.repo.git_pull(lfs=True)
-
-        is_new = not Path(self.log_file).exists()
-
-        with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
-            writer = csv.writer(csvfile)
-
-            # File previews for certain input and output types
-            infos, file_preview_types, headers = _get_dataset_features_info(
-                is_new, self.components
-            )
-
-            # Generate the headers and dataset_infos
-            if is_new:
-                writer.writerow(utils.sanitize_list_for_csv(headers))
-
-            # Generate the row corresponding to the flagged sample
-            csv_data = []
-            for component, sample in zip(self.components, flag_data):
-                save_dir = Path(
-                    self.dataset_dir
-                ) / utils.strip_invalid_filename_characters(component.label or "")
-                filepath = component.deserialize(sample, save_dir, None)
-                csv_data.append(filepath)
-                if isinstance(component, tuple(file_preview_types)):
-                    csv_data.append(
-                        "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
-                    )
-            csv_data.append(flag_option if flag_option is not None else "")
-            writer.writerow(utils.sanitize_list_for_csv(csv_data))
-
-        if is_new:
-            json.dump(infos, open(self.infos_file, "w"))
-
-        with open(self.log_file, "r", encoding="utf-8") as csvfile:
-            line_count = len([None for row in csv.reader(csvfile)]) - 1
-
-        self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))
-
-        return line_count
-
-##################
+# Set up flagging callback function
+HF_TOKEN = os.getenv("HF_TOKEN")
+hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "CannaTech/Flagged")
 
 # Set up flagging callback function
 HF_TOKEN = os.getenv("HF_TOKEN")
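After this commit the app no longer vendors Gradio's FlaggingCallback and HuggingFaceDatasetSaver classes; it relies on the gr.HuggingFaceDatasetSaver that ships with Gradio. The sketch below shows how such a saver is typically wired into an Interface. Only the HF_TOKEN lookup, the "CannaTech/Flagged" dataset name, and the OpenAI key setup come from the diff above; the chat function, its text inputs/outputs, and the model name are illustrative assumptions, not taken from this Space's app.py.

import os
import openai
import gradio as gr

# Set the OpenAI API key (as in the diff above)
openai.api_key = os.getenv("OPENAI_API_KEY")

# Flagged samples are pushed to the "CannaTech/Flagged" dataset repo (as in the diff above)
HF_TOKEN = os.getenv("HF_TOKEN")
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "CannaTech/Flagged")

def chat(prompt):
    # Hypothetical prediction function; the real app.py may differ.
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
    )
    return response["choices"][0]["message"]["content"]

demo = gr.Interface(
    fn=chat,
    inputs="text",
    outputs="text",
    allow_flagging="manual",      # shows a Flag button in the UI
    flagging_callback=hf_writer,  # each flag is appended to the dataset repo
)

if __name__ == "__main__":
    demo.launch()

With this wiring, pressing the Flag button saves the flagged input/output pair to the Hugging Face dataset repo, which is exactly the behavior the removed in-file classes used to provide.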