inoki-giskard committed on
Commit
d1e5b15
·
1 Parent(s): 029ed97

Add an HTML widget to show the error info for model id validation

Browse files
app_text_classification.py CHANGED
@@ -18,20 +18,21 @@ from text_classification_ui_helpers import (
18
 
19
  import logging
20
  from wordings import (
21
- CONFIRM_MAPPING_DETAILS_MD,
22
- INTRODUCTION_MD,
23
- USE_INFERENCE_API_TIP,
24
- CHECK_LOG_SECTION_RAW,
25
- HF_TOKEN_INVALID_STYLED
 
26
  )
27
 
28
  MAX_LABELS = 40
29
  MAX_FEATURES = 20
30
 
31
- EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
32
  CONFIG_PATH = "./config.yaml"
33
  logger = logging.getLogger(__name__)
34
 
 
35
  def get_demo():
36
  with gr.Row():
37
  gr.Markdown(INTRODUCTION_MD)
@@ -39,10 +40,14 @@ def get_demo():
39
  label="Evaluation ID:", value=uuid.uuid4, visible=False, interactive=False
40
  )
41
  with gr.Row():
42
- model_id_input = gr.Textbox(
43
- label="Hugging Face Model id",
44
- placeholder=EXAMPLE_MODEL_ID + " (press enter to confirm)",
45
- )
 
 
 
 
46
 
47
  with gr.Column():
48
  dataset_id_input = gr.Dropdown(
@@ -53,8 +58,12 @@ def get_demo():
53
  )
54
 
55
  with gr.Row():
56
- dataset_config_input = gr.Dropdown(label="Dataset Config", visible=False, allow_custom_value=True)
57
- dataset_split_input = gr.Dropdown(label="Dataset Split", visible=False, allow_custom_value=True)
 
 
 
 
58
 
59
  with gr.Row():
60
  first_line_ds = gr.DataFrame(label="Dataset Preview", visible=False)
@@ -72,7 +81,9 @@ def get_demo():
72
  with gr.Row():
73
  validation_result = gr.HTML(visible=False)
74
  with gr.Row():
75
- example_input = gr.Textbox(label="Example Input", visible=False, interactive=False)
 
 
76
  example_prediction = gr.Label(label="Model Sample Prediction", visible=False)
77
 
78
  with gr.Row():
@@ -119,15 +130,15 @@ def get_demo():
119
  # Reason: data_leakage barely raises any issues and takes too many requests
120
  # when using inference API, causing rate limit error
121
  scan_config = [
122
- "ethical_bias",
123
- "text_perturbation",
124
  "robustness",
125
  "performance",
126
  "underconfidence",
127
  "overconfidence",
128
  "spurious_correlation",
129
  "data_leakage",
130
- ]
131
  return gr.update(
132
  choices=scan_config, value=selected, label="Scan Settings", visible=True
133
  )
@@ -148,7 +159,6 @@ def get_demo():
148
  every=0.5,
149
  )
150
 
151
-
152
  scanners.change(write_scanners, inputs=[scanners, uid_label])
153
 
154
  gr.on(
@@ -161,20 +171,28 @@ def get_demo():
161
  inputs=[dataset_id_input],
162
  outputs=[dataset_config_input, dataset_split_input, loading_dataset_info],
163
  )
164
-
165
  gr.on(
166
  triggers=[dataset_id_input.input, dataset_id_input.select],
167
  fn=check_dataset,
168
  inputs=[dataset_id_input],
169
- outputs=[dataset_config_input, dataset_split_input, loading_dataset_info]
170
  )
171
 
172
- dataset_config_input.change(fn=get_dataset_splits, inputs=[dataset_id_input, dataset_config_input], outputs=[dataset_split_input])
 
 
 
 
173
 
174
  gr.on(
175
- triggers=[model_id_input.change, dataset_id_input.change, dataset_config_input.change],
 
 
 
 
176
  fn=empty_column_mapping,
177
- inputs=[uid_label]
178
  )
179
 
180
  gr.on(
@@ -199,7 +217,6 @@ def get_demo():
199
  gr.on(
200
  triggers=[
201
  model_id_input.change,
202
- model_id_input.input,
203
  dataset_id_input.change,
204
  dataset_config_input.change,
205
  dataset_split_input.change,
@@ -212,12 +229,14 @@ def get_demo():
212
  dataset_split_input,
213
  ],
214
  outputs=[
215
- example_btn,
216
  first_line_ds,
217
  validation_result,
218
  example_input,
219
  example_prediction,
220
- column_mapping_accordion,],
 
 
221
  )
222
 
223
  gr.on(
@@ -258,14 +277,14 @@ def get_demo():
258
  uid_label,
259
  ],
260
  outputs=[
261
- run_btn,
262
- logs,
263
- uid_label,
264
  validation_result,
265
  example_input,
266
  example_prediction,
267
  column_mapping_accordion,
268
- ],
269
  )
270
 
271
  gr.on(
@@ -276,11 +295,11 @@ def get_demo():
276
  fn=enable_run_btn,
277
  inputs=[
278
  uid_label,
279
- inference_token,
280
- model_id_input,
281
- dataset_id_input,
282
- dataset_config_input,
283
- dataset_split_input
284
  ],
285
  outputs=[run_btn],
286
  )
@@ -290,11 +309,11 @@ def get_demo():
290
  fn=enable_run_btn,
291
  inputs=[
292
  uid_label,
293
- inference_token,
294
- model_id_input,
295
- dataset_id_input,
296
- dataset_config_input,
297
- dataset_split_input
298
  ], # FIXME
299
  outputs=[run_btn],
300
  )
 
18
 
19
  import logging
20
  from wordings import (
21
+ EXAMPLE_MODEL_ID,
22
+ CONFIRM_MAPPING_DETAILS_MD,
23
+ INTRODUCTION_MD,
24
+ USE_INFERENCE_API_TIP,
25
+ CHECK_LOG_SECTION_RAW,
26
+ HF_TOKEN_INVALID_STYLED,
27
  )
28
 
29
  MAX_LABELS = 40
30
  MAX_FEATURES = 20
31
 
 
32
  CONFIG_PATH = "./config.yaml"
33
  logger = logging.getLogger(__name__)
34
 
35
+
36
  def get_demo():
37
  with gr.Row():
38
  gr.Markdown(INTRODUCTION_MD)
 
40
  label="Evaluation ID:", value=uuid.uuid4, visible=False, interactive=False
41
  )
42
  with gr.Row():
43
+ with gr.Column():
44
+ with gr.Row():
45
+ model_id_input = gr.Textbox(
46
+ label="Hugging Face Model id",
47
+ placeholder=f"e.g. {EXAMPLE_MODEL_ID}",
48
+ )
49
+ with gr.Row():
50
+ model_id_error_info = gr.HTML(visible=False)
51
 
52
  with gr.Column():
53
  dataset_id_input = gr.Dropdown(
 
58
  )
59
 
60
  with gr.Row():
61
+ dataset_config_input = gr.Dropdown(
62
+ label="Dataset Config", visible=False, allow_custom_value=True
63
+ )
64
+ dataset_split_input = gr.Dropdown(
65
+ label="Dataset Split", visible=False, allow_custom_value=True
66
+ )
67
 
68
  with gr.Row():
69
  first_line_ds = gr.DataFrame(label="Dataset Preview", visible=False)
 
81
  with gr.Row():
82
  validation_result = gr.HTML(visible=False)
83
  with gr.Row():
84
+ example_input = gr.Textbox(
85
+ label="Example Input", visible=False, interactive=False
86
+ )
87
  example_prediction = gr.Label(label="Model Sample Prediction", visible=False)
88
 
89
  with gr.Row():
 
130
  # Reason: data_leakage barely raises any issues and takes too many requests
131
  # when using inference API, causing rate limit error
132
  scan_config = [
133
+ "ethical_bias",
134
+ "text_perturbation",
135
  "robustness",
136
  "performance",
137
  "underconfidence",
138
  "overconfidence",
139
  "spurious_correlation",
140
  "data_leakage",
141
+ ]
142
  return gr.update(
143
  choices=scan_config, value=selected, label="Scan Settings", visible=True
144
  )
 
159
  every=0.5,
160
  )
161
 
 
162
  scanners.change(write_scanners, inputs=[scanners, uid_label])
163
 
164
  gr.on(
 
171
  inputs=[dataset_id_input],
172
  outputs=[dataset_config_input, dataset_split_input, loading_dataset_info],
173
  )
174
+
175
  gr.on(
176
  triggers=[dataset_id_input.input, dataset_id_input.select],
177
  fn=check_dataset,
178
  inputs=[dataset_id_input],
179
+ outputs=[dataset_config_input, dataset_split_input, loading_dataset_info],
180
  )
181
 
182
+ dataset_config_input.change(
183
+ fn=get_dataset_splits,
184
+ inputs=[dataset_id_input, dataset_config_input],
185
+ outputs=[dataset_split_input],
186
+ )
187
 
188
  gr.on(
189
+ triggers=[
190
+ model_id_input.change,
191
+ dataset_id_input.change,
192
+ dataset_config_input.change,
193
+ ],
194
  fn=empty_column_mapping,
195
+ inputs=[uid_label],
196
  )
197
 
198
  gr.on(
 
217
  gr.on(
218
  triggers=[
219
  model_id_input.change,
 
220
  dataset_id_input.change,
221
  dataset_config_input.change,
222
  dataset_split_input.change,
 
229
  dataset_split_input,
230
  ],
231
  outputs=[
232
+ example_btn,
233
  first_line_ds,
234
  validation_result,
235
  example_input,
236
  example_prediction,
237
+ column_mapping_accordion,
238
+ model_id_error_info,
239
+ ],
240
  )
241
 
242
  gr.on(
 
277
  uid_label,
278
  ],
279
  outputs=[
280
+ run_btn,
281
+ logs,
282
+ uid_label,
283
  validation_result,
284
  example_input,
285
  example_prediction,
286
  column_mapping_accordion,
287
+ ],
288
  )
289
 
290
  gr.on(
 
295
  fn=enable_run_btn,
296
  inputs=[
297
  uid_label,
298
+ inference_token,
299
+ model_id_input,
300
+ dataset_id_input,
301
+ dataset_config_input,
302
+ dataset_split_input,
303
  ],
304
  outputs=[run_btn],
305
  )
 
309
  fn=enable_run_btn,
310
  inputs=[
311
  uid_label,
312
+ inference_token,
313
+ model_id_input,
314
+ dataset_id_input,
315
+ dataset_config_input,
316
+ dataset_split_input,
317
  ], # FIXME
318
  outputs=[run_btn],
319
  )
text_classification_ui_helpers.py CHANGED
@@ -9,10 +9,10 @@ import pandas as pd
9
 
10
  import leaderboard
11
  from io_utils import (
12
- read_column_mapping,
13
- write_column_mapping,
14
- read_scanners,
15
- write_scanners,
16
  )
17
  from run_jobs import save_job_to_pipe
18
  from text_classification import (
@@ -24,9 +24,11 @@ from text_classification import (
24
  HuggingFaceInferenceAPIResponse,
25
  )
26
  from wordings import (
 
27
  CHECK_CONFIG_OR_SPLIT_RAW,
28
  CONFIRM_MAPPING_DETAILS_FAIL_RAW,
29
  MAPPING_STYLED_ERROR_WARNING,
 
30
  NOT_TEXT_CLASSIFICATION_MODEL_RAW,
31
  UNMATCHED_MODEL_DATASET_STYLED_ERROR,
32
  CHECK_LOG_SECTION_RAW,
@@ -42,6 +44,7 @@ MAX_FEATURES = 20
42
  ds_dict = None
43
  ds_config = None
44
 
 
45
  def get_related_datasets_from_leaderboard(model_id, dataset_id_input):
46
  records = leaderboard.records
47
  model_records = records[records["model_id"] == model_id]
@@ -49,54 +52,56 @@ def get_related_datasets_from_leaderboard(model_id, dataset_id_input):
49
 
50
  if len(datasets_unique) == 0:
51
  return gr.update(choices=[])
52
-
53
  if dataset_id_input in datasets_unique:
54
  return gr.update(choices=datasets_unique)
55
-
56
  return gr.update(choices=datasets_unique, value="")
57
 
58
 
59
  logger = logging.getLogger(__file__)
60
 
 
61
  def get_dataset_splits(dataset_id, dataset_config):
62
  try:
63
- splits = datasets.get_dataset_split_names(dataset_id, dataset_config, trust_remote_code=True)
 
 
64
  return gr.update(choices=splits, value=splits[0], visible=True)
65
  except Exception as e:
66
- logger.warning(f"Check your dataset {dataset_id} and config {dataset_config}: {e}")
 
 
67
  return gr.update(visible=False)
68
 
 
69
  def check_dataset(dataset_id):
70
  logger.info(f"Loading {dataset_id}")
71
  try:
72
  configs = datasets.get_dataset_config_names(dataset_id, trust_remote_code=True)
73
  if len(configs) == 0:
74
- return (
75
- gr.update(visible=False),
76
- gr.update(visible=False),
77
- ""
78
- )
79
- splits = datasets.get_dataset_split_names(dataset_id, configs[0], trust_remote_code=True)
80
  return (
81
  gr.update(choices=configs, value=configs[0], visible=True),
82
  gr.update(choices=splits, value=splits[0], visible=True),
83
- ""
84
  )
85
  except Exception as e:
86
  logger.warning(f"Check your dataset {dataset_id}: {e}")
87
  if "doesn't exist" in str(e):
88
  gr.Warning(get_dataset_fetch_error_raw(e))
89
- if "forbidden" in str(e).lower(): # GSK-2770
90
  gr.Warning(get_dataset_fetch_error_raw(e))
91
- return (
92
- gr.update(visible=False),
93
- gr.update(visible=False),
94
- ""
95
- )
96
 
97
  def empty_column_mapping(uid):
98
  write_column_mapping(None, uid)
99
 
 
100
  def write_column_mapping_to_config(uid, *labels):
101
  # TODO: Substitute 'text' with more features for zero-shot
102
  # we are not using ds features because we only support "text" for now
@@ -114,13 +119,14 @@ def write_column_mapping_to_config(uid, *labels):
114
 
115
  write_column_mapping(all_mappings, uid)
116
 
 
117
  def export_mappings(all_mappings, key, subkeys, values):
118
  if key not in all_mappings.keys():
119
  all_mappings[key] = dict()
120
  if subkeys is None:
121
  subkeys = list(all_mappings[key].keys())
122
 
123
- if not subkeys:
124
  logging.debug(f"subkeys is empty for {key}")
125
  return all_mappings
126
 
@@ -139,7 +145,9 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels,
139
  ds_labels = list(shared_labels)
140
  if len(ds_labels) > MAX_LABELS:
141
  ds_labels = ds_labels[:MAX_LABELS]
142
- gr.Warning(f"Too many labels to display for this spcae. We do not support more than {MAX_LABELS} in this space. You can use cli tool at https://github.com/Giskard-AI/cicd.")
 
 
143
 
144
  # sort labels to make sure the order is consistent
145
  # prediction gives the order based on probability
@@ -183,11 +191,47 @@ def precheck_model_ds_enable_example_btn(
183
  model_id, dataset_id, dataset_config, dataset_split
184
  ):
185
  model_task = check_model_task(model_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  preload_hf_inference_api(model_id)
187
 
188
  if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
189
  return (
190
- gr.update(interactive=False),
 
191
  gr.update(visible=False),
192
  gr.update(visible=False),
193
  gr.update(visible=False),
@@ -198,41 +242,36 @@ def precheck_model_ds_enable_example_btn(
198
  try:
199
  ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
200
  df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
201
- ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds[dataset_split])
202
-
203
- if model_task is None or model_task != "text-classification":
204
- gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
205
- return (
206
- gr.update(interactive=False),
207
- gr.update(value=df, visible=True),
208
- gr.update(visible=False),
209
- gr.update(visible=False),
210
- gr.update(visible=False),
211
- gr.update(visible=False),
212
- )
213
 
214
  if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
215
  gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
216
  return (
217
- gr.update(interactive=False),
218
  gr.update(value=df, visible=True),
219
  gr.update(visible=False),
220
  gr.update(visible=False),
221
  gr.update(visible=False),
222
  gr.update(visible=False),
 
223
  )
224
 
225
  return (
226
- gr.update(interactive=True),
227
- gr.update(value=df, visible=True),
228
- gr.update(visible=False),
229
- gr.update(visible=False),
230
- gr.update(visible=False),
231
- gr.update(visible=False),
232
- )
 
233
  except Exception as e:
234
  # Config or split wrong
235
- logger.warning(f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}")
 
 
236
  return (
237
  gr.update(interactive=False),
238
  gr.update(visible=False),
@@ -240,6 +279,7 @@ def precheck_model_ds_enable_example_btn(
240
  gr.update(visible=False),
241
  gr.update(visible=False),
242
  gr.update(visible=False),
 
243
  )
244
 
245
 
@@ -266,7 +306,7 @@ def align_columns_and_show_prediction(
266
  dropdown_placement = [
267
  gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)
268
  ]
269
-
270
  hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
271
 
272
  prediction_input, prediction_response = get_example_prediction(
@@ -296,8 +336,10 @@ def align_columns_and_show_prediction(
296
  )
297
 
298
  model_labels = list(prediction_response.keys())
299
-
300
- ds = datasets.load_dataset(dataset_id, dataset_config, split=dataset_split, trust_remote_code=True)
 
 
301
  ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds)
302
 
303
  # when dataset does not have labels or features
@@ -312,7 +354,7 @@ def align_columns_and_show_prediction(
312
  "",
313
  *dropdown_placement,
314
  )
315
-
316
  if len(ds_labels) != len(model_labels):
317
  return (
318
  gr.update(value=UNMATCHED_MODEL_DATASET_STYLED_ERROR, visible=True),
@@ -339,7 +381,11 @@ def align_columns_and_show_prediction(
339
  ):
340
  return (
341
  gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
342
- gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
 
 
 
 
343
  gr.update(value=prediction_response, visible=True),
344
  gr.update(visible=True, open=True),
345
  gr.update(interactive=(inference_token != "")),
@@ -349,7 +395,11 @@ def align_columns_and_show_prediction(
349
 
350
  return (
351
  gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
352
- gr.update(value=prediction_input, lines=min(len(prediction_input)//225 + 1, 5), visible=True),
 
 
 
 
353
  gr.update(value=prediction_response, visible=True),
354
  gr.update(visible=True, open=False),
355
  gr.update(interactive=(inference_token != "")),
@@ -370,14 +420,22 @@ def check_column_mapping_keys_validity(all_mappings):
370
 
371
  return True
372
 
373
- def enable_run_btn(uid, inference_token, model_id, dataset_id, dataset_config, dataset_split):
 
 
 
374
  if inference_token == "":
375
  logger.warning("Inference API is not enabled")
376
  return gr.update(interactive=False)
377
- if model_id == "" or dataset_id == "" or dataset_config == "" or dataset_split == "":
 
 
 
 
 
378
  logger.warning("Model id or dataset id is not selected")
379
  return gr.update(interactive=False)
380
-
381
  all_mappings = read_column_mapping(uid)
382
  if not check_column_mapping_keys_validity(all_mappings):
383
  logger.warning("Column mapping is not valid")
@@ -388,17 +446,24 @@ def enable_run_btn(uid, inference_token, model_id, dataset_id, dataset_config, d
388
  return gr.update(interactive=False)
389
  return gr.update(interactive=True)
390
 
391
- def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys=None):
 
 
 
392
  label_mapping = {}
393
  if len(all_mappings["labels"].keys()) != len(ds_labels):
394
- logger.warning(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
395
- \nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")
396
-
 
 
397
  if len(all_mappings["features"].keys()) != len(ds_features):
398
- logger.warning(f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
399
- \nall_mappings: {all_mappings}\nds_features: {ds_features}""")
 
 
400
 
401
- for i, label in zip(range(len(ds_labels)), ds_labels):
402
  # align the saved labels with dataset labels order
403
  label_mapping.update({str(i): all_mappings["labels"][label]})
404
 
@@ -408,15 +473,17 @@ def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, la
408
 
409
  feature_mapping = all_mappings["features"]
410
  if len(label_keys) > 0:
411
- feature_mapping.update({"label": label_keys[0]})
412
  return label_mapping, feature_mapping
413
 
 
414
  def show_hf_token_info(token):
415
  valid = check_hf_token_validity(token)
416
  if not valid:
417
  return gr.update(visible=True)
418
  return gr.update(visible=False)
419
 
 
420
  def try_submit(m_id, d_id, config, split, inference_token, uid):
421
  all_mappings = read_column_mapping(uid)
422
  if not check_column_mapping_keys_validity(all_mappings):
@@ -425,7 +492,9 @@ def try_submit(m_id, d_id, config, split, inference_token, uid):
425
  # get ds labels and features again for alignment
426
  ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
427
  ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
428
- label_mapping, feature_mapping = construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features, label_keys)
 
 
429
 
430
  eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
431
  save_job_to_pipe(
@@ -451,7 +520,12 @@ def try_submit(m_id, d_id, config, split, inference_token, uid):
451
 
452
  return (
453
  gr.update(interactive=False), # Submit button
454
- gr.update(value=f"{CHECK_LOG_SECTION_RAW}Your job id is: {uid}. ", lines=5, visible=True, interactive=False),
 
 
 
 
 
455
  new_uid, # Allocate a new uuid
456
  gr.update(visible=False),
457
  gr.update(visible=False),
 
9
 
10
  import leaderboard
11
  from io_utils import (
12
+ read_column_mapping,
13
+ write_column_mapping,
14
+ read_scanners,
15
+ write_scanners,
16
  )
17
  from run_jobs import save_job_to_pipe
18
  from text_classification import (
 
24
  HuggingFaceInferenceAPIResponse,
25
  )
26
  from wordings import (
27
+ EXAMPLE_MODEL_ID,
28
  CHECK_CONFIG_OR_SPLIT_RAW,
29
  CONFIRM_MAPPING_DETAILS_FAIL_RAW,
30
  MAPPING_STYLED_ERROR_WARNING,
31
+ NOT_FOUND_MODEL_RAW,
32
  NOT_TEXT_CLASSIFICATION_MODEL_RAW,
33
  UNMATCHED_MODEL_DATASET_STYLED_ERROR,
34
  CHECK_LOG_SECTION_RAW,
 
44
  ds_dict = None
45
  ds_config = None
46
 
47
+
48
  def get_related_datasets_from_leaderboard(model_id, dataset_id_input):
49
  records = leaderboard.records
50
  model_records = records[records["model_id"] == model_id]
 
52
 
53
  if len(datasets_unique) == 0:
54
  return gr.update(choices=[])
55
+
56
  if dataset_id_input in datasets_unique:
57
  return gr.update(choices=datasets_unique)
58
+
59
  return gr.update(choices=datasets_unique, value="")
60
 
61
 
62
  logger = logging.getLogger(__file__)
63
 
64
+
65
  def get_dataset_splits(dataset_id, dataset_config):
66
  try:
67
+ splits = datasets.get_dataset_split_names(
68
+ dataset_id, dataset_config, trust_remote_code=True
69
+ )
70
  return gr.update(choices=splits, value=splits[0], visible=True)
71
  except Exception as e:
72
+ logger.warning(
73
+ f"Check your dataset {dataset_id} and config {dataset_config}: {e}"
74
+ )
75
  return gr.update(visible=False)
76
 
77
+
78
  def check_dataset(dataset_id):
79
  logger.info(f"Loading {dataset_id}")
80
  try:
81
  configs = datasets.get_dataset_config_names(dataset_id, trust_remote_code=True)
82
  if len(configs) == 0:
83
+ return (gr.update(visible=False), gr.update(visible=False), "")
84
+ splits = datasets.get_dataset_split_names(
85
+ dataset_id, configs[0], trust_remote_code=True
86
+ )
 
 
87
  return (
88
  gr.update(choices=configs, value=configs[0], visible=True),
89
  gr.update(choices=splits, value=splits[0], visible=True),
90
+ "",
91
  )
92
  except Exception as e:
93
  logger.warning(f"Check your dataset {dataset_id}: {e}")
94
  if "doesn't exist" in str(e):
95
  gr.Warning(get_dataset_fetch_error_raw(e))
96
+ if "forbidden" in str(e).lower(): # GSK-2770
97
  gr.Warning(get_dataset_fetch_error_raw(e))
98
+ return (gr.update(visible=False), gr.update(visible=False), "")
99
+
 
 
 
100
 
101
  def empty_column_mapping(uid):
102
  write_column_mapping(None, uid)
103
 
104
+
105
  def write_column_mapping_to_config(uid, *labels):
106
  # TODO: Substitute 'text' with more features for zero-shot
107
  # we are not using ds features because we only support "text" for now
 
119
 
120
  write_column_mapping(all_mappings, uid)
121
 
122
+
123
  def export_mappings(all_mappings, key, subkeys, values):
124
  if key not in all_mappings.keys():
125
  all_mappings[key] = dict()
126
  if subkeys is None:
127
  subkeys = list(all_mappings[key].keys())
128
 
129
+ if not subkeys:
130
  logging.debug(f"subkeys is empty for {key}")
131
  return all_mappings
132
 
 
145
  ds_labels = list(shared_labels)
146
  if len(ds_labels) > MAX_LABELS:
147
  ds_labels = ds_labels[:MAX_LABELS]
148
+ gr.Warning(
149
+ f"Too many labels to display for this spcae. We do not support more than {MAX_LABELS} in this space. You can use cli tool at https://github.com/Giskard-AI/cicd."
150
+ )
151
 
152
  # sort labels to make sure the order is consistent
153
  # prediction gives the order based on probability
 
191
  model_id, dataset_id, dataset_config, dataset_split
192
  ):
193
  model_task = check_model_task(model_id)
194
+ if not model_task:
195
+ # Model might be not found
196
+ error_msg_html = f"<p style='color: red;'>{NOT_FOUND_MODEL_RAW}</p>"
197
+ if model_id.startswith("http://") or model_id.startswith("https://"):
198
+ error_msg = f"Please input your model id, such as {EXAMPLE_MODEL_ID}, instead of URL"
199
+ gr.Warning(error_msg)
200
+ error_msg_html = f"<p style='color: red;'>{error_msg}</p>"
201
+ else:
202
+ gr.Warning(NOT_FOUND_MODEL_RAW)
203
+
204
+ return (
205
+ gr.update(interactive=False),
206
+ gr.update(visible=False),
207
+ gr.update(visible=False),
208
+ gr.update(visible=False),
209
+ gr.update(visible=False),
210
+ gr.update(visible=False),
211
+ gr.update(value=error_msg_html, visible=True),
212
+ )
213
+
214
+ if model_task != "text-classification":
215
+ gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
216
+ return (
217
+ gr.update(interactive=False),
218
+ gr.update(value=df, visible=True),
219
+ gr.update(visible=False),
220
+ gr.update(visible=False),
221
+ gr.update(visible=False),
222
+ gr.update(visible=False),
223
+ gr.update(
224
+ value=f"<p style='color: red;'>{NOT_TEXT_CLASSIFICATION_MODEL_RAW}",
225
+ visible=True,
226
+ ),
227
+ )
228
+
229
  preload_hf_inference_api(model_id)
230
 
231
  if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
232
  return (
233
+ gr.update(interactive=False),
234
+ gr.update(visible=False),
235
  gr.update(visible=False),
236
  gr.update(visible=False),
237
  gr.update(visible=False),
 
242
  try:
243
  ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
244
  df: pd.DataFrame = ds[dataset_split].to_pandas().head(5)
245
+ ds_labels, ds_features, _ = get_labels_and_features_from_dataset(
246
+ ds[dataset_split]
247
+ )
 
 
 
 
 
 
 
 
 
248
 
249
  if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
250
  gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
251
  return (
252
+ gr.update(interactive=False),
253
  gr.update(value=df, visible=True),
254
  gr.update(visible=False),
255
  gr.update(visible=False),
256
  gr.update(visible=False),
257
  gr.update(visible=False),
258
+ gr.update(visible=False),
259
  )
260
 
261
  return (
262
+ gr.update(interactive=True),
263
+ gr.update(value=df, visible=True),
264
+ gr.update(visible=False),
265
+ gr.update(visible=False),
266
+ gr.update(visible=False),
267
+ gr.update(visible=False),
268
+ gr.update(visible=False),
269
+ )
270
  except Exception as e:
271
  # Config or split wrong
272
+ logger.warning(
273
+ f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}"
274
+ )
275
  return (
276
  gr.update(interactive=False),
277
  gr.update(visible=False),
 
279
  gr.update(visible=False),
280
  gr.update(visible=False),
281
  gr.update(visible=False),
282
+ gr.update(visible=False),
283
  )
284
 
285
 
 
306
  dropdown_placement = [
307
  gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)
308
  ]
309
+
310
  hf_token = os.environ.get(HF_WRITE_TOKEN, default="")
311
 
312
  prediction_input, prediction_response = get_example_prediction(
 
336
  )
337
 
338
  model_labels = list(prediction_response.keys())
339
+
340
+ ds = datasets.load_dataset(
341
+ dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
342
+ )
343
  ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds)
344
 
345
  # when dataset does not have labels or features
 
354
  "",
355
  *dropdown_placement,
356
  )
357
+
358
  if len(ds_labels) != len(model_labels):
359
  return (
360
  gr.update(value=UNMATCHED_MODEL_DATASET_STYLED_ERROR, visible=True),
 
381
  ):
382
  return (
383
  gr.update(value=MAPPING_STYLED_ERROR_WARNING, visible=True),
384
+ gr.update(
385
+ value=prediction_input,
386
+ lines=min(len(prediction_input) // 225 + 1, 5),
387
+ visible=True,
388
+ ),
389
  gr.update(value=prediction_response, visible=True),
390
  gr.update(visible=True, open=True),
391
  gr.update(interactive=(inference_token != "")),
 
395
 
396
  return (
397
  gr.update(value=VALIDATED_MODEL_DATASET_STYLED, visible=True),
398
+ gr.update(
399
+ value=prediction_input,
400
+ lines=min(len(prediction_input) // 225 + 1, 5),
401
+ visible=True,
402
+ ),
403
  gr.update(value=prediction_response, visible=True),
404
  gr.update(visible=True, open=False),
405
  gr.update(interactive=(inference_token != "")),
 
420
 
421
  return True
422
 
423
+
424
+ def enable_run_btn(
425
+ uid, inference_token, model_id, dataset_id, dataset_config, dataset_split
426
+ ):
427
  if inference_token == "":
428
  logger.warning("Inference API is not enabled")
429
  return gr.update(interactive=False)
430
+ if (
431
+ model_id == ""
432
+ or dataset_id == ""
433
+ or dataset_config == ""
434
+ or dataset_split == ""
435
+ ):
436
  logger.warning("Model id or dataset id is not selected")
437
  return gr.update(interactive=False)
438
+
439
  all_mappings = read_column_mapping(uid)
440
  if not check_column_mapping_keys_validity(all_mappings):
441
  logger.warning("Column mapping is not valid")
 
446
  return gr.update(interactive=False)
447
  return gr.update(interactive=True)
448
 
449
+
450
+ def construct_label_and_feature_mapping(
451
+ all_mappings, ds_labels, ds_features, label_keys=None
452
+ ):
453
  label_mapping = {}
454
  if len(all_mappings["labels"].keys()) != len(ds_labels):
455
+ logger.warning(
456
+ f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
457
+ \nall_mappings: {all_mappings}\nds_labels: {ds_labels}"""
458
+ )
459
+
460
  if len(all_mappings["features"].keys()) != len(ds_features):
461
+ logger.warning(
462
+ f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
463
+ \nall_mappings: {all_mappings}\nds_features: {ds_features}"""
464
+ )
465
 
466
+ for i, label in zip(range(len(ds_labels)), ds_labels):
467
  # align the saved labels with dataset labels order
468
  label_mapping.update({str(i): all_mappings["labels"][label]})
469
 
 
473
 
474
  feature_mapping = all_mappings["features"]
475
  if len(label_keys) > 0:
476
+ feature_mapping.update({"label": label_keys[0]})
477
  return label_mapping, feature_mapping
478
 
479
+
480
  def show_hf_token_info(token):
481
  valid = check_hf_token_validity(token)
482
  if not valid:
483
  return gr.update(visible=True)
484
  return gr.update(visible=False)
485
 
486
+
487
  def try_submit(m_id, d_id, config, split, inference_token, uid):
488
  all_mappings = read_column_mapping(uid)
489
  if not check_column_mapping_keys_validity(all_mappings):
 
492
  # get ds labels and features again for alignment
493
  ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
494
  ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
495
+ label_mapping, feature_mapping = construct_label_and_feature_mapping(
496
+ all_mappings, ds_labels, ds_features, label_keys
497
+ )
498
 
499
  eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
500
  save_job_to_pipe(
 
520
 
521
  return (
522
  gr.update(interactive=False), # Submit button
523
+ gr.update(
524
+ value=f"{CHECK_LOG_SECTION_RAW}Your job id is: {uid}. ",
525
+ lines=5,
526
+ visible=True,
527
+ interactive=False,
528
+ ),
529
  new_uid, # Allocate a new uuid
530
  gr.update(visible=False),
531
  gr.update(visible=False),
wordings.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  INTRODUCTION_MD = """
2
  <div style="display: flex; justify-content: center;">
3
  <h1 style="text-align: center;">
@@ -49,6 +51,10 @@ UNMATCHED_MODEL_DATASET_STYLED_ERROR = """
49
  </h3>
50
  """
51
 
 
 
 
 
52
  NOT_TEXT_CLASSIFICATION_MODEL_RAW = """
53
  Your model does not fall under the category of text classification. This page is specifically designated for the evaluation of text classification models.
54
  """
@@ -61,7 +67,7 @@ USE_INFERENCE_API_TIP = """
61
  . Please input your <a href="https://huggingface.co/docs/hub/security-tokens#user-access-tokens">Hugging Face token</a> to do so. You can find it <a href="https://huggingface.co/settings/tokens">here</a>.
62
  """
63
 
64
- HF_TOKEN_INVALID_STYLED= """
65
  <p style="text-align: left;color: red; ">
66
  Your Hugging Face token is invalid. Please double check your token.
67
  </p>
@@ -72,5 +78,6 @@ VALIDATED_MODEL_DATASET_STYLED = """
72
  Your model and dataset have been validated!
73
  </h3>"""
74
 
 
75
  def get_dataset_fetch_error_raw(error):
76
  return f"""Sorry you cannot use this dataset because {error}. Contact HF team to support this dataset."""
 
1
+ EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
2
+
3
  INTRODUCTION_MD = """
4
  <div style="display: flex; justify-content: center;">
5
  <h1 style="text-align: center;">
 
51
  </h3>
52
  """
53
 
54
+ NOT_FOUND_MODEL_RAW = """
55
+ We cannot find your model on Hugging Face. Please ensure that the model is accessible.
56
+ """
57
+
58
  NOT_TEXT_CLASSIFICATION_MODEL_RAW = """
59
  Your model does not fall under the category of text classification. This page is specifically designated for the evaluation of text classification models.
60
  """
 
67
  . Please input your <a href="https://huggingface.co/docs/hub/security-tokens#user-access-tokens">Hugging Face token</a> to do so. You can find it <a href="https://huggingface.co/settings/tokens">here</a>.
68
  """
69
 
70
+ HF_TOKEN_INVALID_STYLED = """
71
  <p style="text-align: left;color: red; ">
72
  Your Hugging Face token is invalid. Please double check your token.
73
  </p>
 
78
  Your model and dataset have been validated!
79
  </h3>"""
80
 
81
+
82
  def get_dataset_fetch_error_raw(error):
83
  return f"""Sorry you cannot use this dataset because {error}. Contact HF team to support this dataset."""