CannaTech committed on
Commit
0900f62
·
1 Parent(s): 49e6487

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py CHANGED
@@ -8,6 +8,130 @@ import json
8
  # Set the OpenAI API key
9
  openai.api_key = os.getenv("OPENAI_API_KEY")
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Set up flagging callback function
12
  HF_TOKEN = os.getenv("HF_TOKEN")
13
  hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "CannaTech/Flagged")
 
8
  # Set the OpenAI API key
9
  openai.api_key = os.getenv("OPENAI_API_KEY")
10
 
11
+ ##################
12
+
13
class HuggingFaceDatasetSaver(FlaggingCallback):
    """
    A callback that saves each flagged sample (both the input and output data)
    to a HuggingFace dataset.
    Example:
        import gradio as gr
        hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            allow_flagging="manual", flagging_callback=hf_writer)
    Guides: using_flagging
    """

    def __init__(
        self,
        hf_token: str,
        dataset_name: str,
        organization: str | None = None,
        private: bool = False,
    ):
        """
        Parameters:
            hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset.
            dataset_name: The name of the dataset to save the data to, e.g. "image-classifier-1"
            organization: The organization to save the dataset under. The hf_token must provide write access to this organization. If not provided, saved under the name of the user corresponding to the hf_token.
            private: Whether the dataset should be private (defaults to False).
        """
        self.hf_token = hf_token
        self.dataset_name = dataset_name
        # NOTE(review): organization is stored but never passed to create_repo in
        # setup() below, so the dataset is always created under the token's user.
        # Kept as-is to avoid changing repo targeting behind callers' backs —
        # confirm intent before wiring it through.
        self.organization_name = organization
        self.dataset_private = private

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        Create (or reuse) the HF dataset repo and clone it locally.

        Params:
            components: the Gradio input/output components whose values will be logged.
            flagging_dir (str): local directory where the dataset is cloned,
                updated, and pushed from.
        Raises:
            ImportError: if `huggingface_hub` is not installed.
        """
        try:
            import huggingface_hub
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Package `huggingface_hub` not found. It is needed "
                "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
            )
        # exist_ok=True makes this idempotent across app restarts.
        path_to_dataset_repo = huggingface_hub.create_repo(
            name=self.dataset_name,
            token=self.hf_token,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        )
        self.path_to_dataset_repo = path_to_dataset_repo  # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
        self.components = components
        self.flagging_dir = flagging_dir
        self.dataset_dir = Path(flagging_dir) / self.dataset_name
        self.repo = huggingface_hub.Repository(
            local_dir=str(self.dataset_dir),
            clone_from=path_to_dataset_repo,
            use_auth_token=self.hf_token,
        )
        # Pull with LFS so previously-flagged media files are present locally.
        self.repo.git_pull(lfs=True)

        # Should filename be user-specified?
        self.log_file = Path(self.dataset_dir) / "data.csv"
        self.infos_file = Path(self.dataset_dir) / "dataset_infos.json"

    def flag(
        self,
        flag_data: List[Any],
        flag_option: str | None = None,
        flag_index: int | None = None,
        username: str | None = None,
    ) -> int:
        """
        Append one flagged sample to the dataset CSV and push it to the Hub.

        Parameters:
            flag_data: one value per component in self.components.
            flag_option: the flag label chosen by the user, if any.
            flag_index: unused here; part of the FlaggingCallback interface.
            username: unused here; part of the FlaggingCallback interface.
        Returns:
            The number of flagged samples recorded so far (excluding the header row).
        """
        # Sync first so we append to the latest upstream state.
        self.repo.git_pull(lfs=True)

        is_new = not Path(self.log_file).exists()

        with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)

            # File previews for certain input and output types
            infos, file_preview_types, headers = _get_dataset_features_info(
                is_new, self.components
            )

            # Generate the headers and dataset_infos
            if is_new:
                writer.writerow(utils.sanitize_list_for_csv(headers))

            # Generate the row corresponding to the flagged sample
            csv_data = []
            for component, sample in zip(self.components, flag_data):
                save_dir = Path(
                    self.dataset_dir
                ) / utils.strip_invalid_filename_characters(component.label or "")
                filepath = component.deserialize(sample, save_dir, None)
                csv_data.append(filepath)
                # Media components also get a resolvable preview URL column.
                if isinstance(component, tuple(file_preview_types)):
                    csv_data.append(
                        "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
                    )

            # Truncate flagged output to first 100 characters
            flagged_output = csv_data[-1][:100] if csv_data else ""
            csv_data.append(flagged_output)
            csv_data.append(flag_option if flag_option is not None else "")
            writer.writerow(utils.sanitize_list_for_csv(csv_data))

        if is_new:
            # Fix: the original passed an unclosed `open(...)` handle to json.dump,
            # leaking the file descriptor; a `with` block guarantees flush + close.
            with open(self.infos_file, "w", encoding="utf-8") as infos_file:
                json.dump(infos, infos_file)

        # Count data rows (subtract 1 for the header) to number the commit.
        with open(self.log_file, "r", encoding="utf-8") as csvfile:
            line_count = sum(1 for _ in csv.reader(csvfile)) - 1

        self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))

        return line_count
132
+
133
+ ##################
134
+
135
  # Set up flagging callback function
136
  HF_TOKEN = os.getenv("HF_TOKEN")
137
  hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "CannaTech/Flagged")