guipenedo HF staff commited on
Commit
e23c5c4
·
unverified ·
1 Parent(s): a06324b

added cumsum and %

Browse files
Files changed (1) hide show
  1. app.py +44 -20
app.py CHANGED
@@ -10,6 +10,7 @@ from pathlib import Path
10
  from typing import Literal
11
 
12
  import gradio as gr
 
13
  import plotly.express as px
14
  import plotly.graph_objects as go
15
  import tenacity
@@ -197,6 +198,8 @@ def plot_scatter(
197
  log_scale_y: bool,
198
  normalization: bool,
199
  rounding: int,
 
 
200
  progress: gr.Progress,
201
  ):
202
  fig = go.Figure()
@@ -207,6 +210,10 @@ def plot_scatter(
207
  histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
208
  x = sorted(histogram_prepared.keys())
209
  y = [histogram_prepared[k] for k in x]
 
 
 
 
210
 
211
  fig.add_trace(
212
  go.Scatter(
@@ -274,6 +281,7 @@ def plot_bars(
274
 
275
  return fig
276
 
 
277
  def get_desc(data):
278
  res = {name: list(dt.to_dict().keys()) for name, dt in data.items()}
279
  return "\n".join([
@@ -293,6 +301,8 @@ def update_graph(
293
  top_k,
294
  direction,
295
  regex,
 
 
296
  progress=gr.Progress(),
297
  ):
298
  if len(datasets) <= 0 or not metric_name or not grouping:
@@ -313,15 +323,16 @@ def update_graph(
313
 
314
  data = {path: result for path, result in zip(datasets, data)}
315
  return plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x,
316
- log_scale_y, progress), data, export_data(data, metric_name), get_desc(data)
317
 
318
 
319
  def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y,
 
320
  progress=gr.Progress()):
321
  if rounding is None or top_k is None:
322
  return None
323
  graph_fc = (
324
- partial(plot_scatter, normalization=normalization, rounding=rounding)
325
  if grouping == "histogram"
326
  else partial(plot_bars, top_k=top_k, direction=direction, regex=regex, rounding=rounding)
327
  )
@@ -399,6 +410,14 @@ The data might not be 100% representative, due to the sampling and optimistic me
399
 
400
  with gr.Row():
401
  with gr.Column(scale=1):
 
 
 
 
 
 
 
 
402
  log_scale_x_checkbox = gr.Checkbox(
403
  label="Log scale x",
404
  value=False,
@@ -491,28 +510,33 @@ The data might not be 100% representative, due to the sampling and optimistic me
491
  top_select,
492
  direction_checkbox,
493
  group_regex,
 
 
494
  ],
495
  outputs=[graph_output, exported_data, export_data_json, min_max_hist_data],
496
  )
497
 
498
- for inp in [normalization_checkbox, rounding, group_regex, direction_checkbox, top_select, log_scale_x_checkbox,
499
- log_scale_y_checkbox]:
500
- inp.change(
501
- fn=plot_data,
502
- inputs=[
503
- exported_data,
504
- metric_name_dropdown,
505
- normalization_checkbox,
506
- rounding,
507
- grouping_dropdown,
508
- top_select,
509
- direction_checkbox,
510
- group_regex,
511
- log_scale_x_checkbox,
512
- log_scale_y_checkbox,
513
- ],
514
- outputs=[graph_output],
515
- )
 
 
 
516
 
517
  datasets_selected.change(
518
  fn=fetch_groups,
 
10
  from typing import Literal
11
 
12
  import gradio as gr
13
+ import numpy as np
14
  import plotly.express as px
15
  import plotly.graph_objects as go
16
  import tenacity
 
198
  log_scale_y: bool,
199
  normalization: bool,
200
  rounding: int,
201
+ cumsum: bool,
202
+ perc: bool,
203
  progress: gr.Progress,
204
  ):
205
  fig = go.Figure()
 
210
  histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
211
  x = sorted(histogram_prepared.keys())
212
  y = [histogram_prepared[k] for k in x]
213
+ if cumsum:
214
+ y = np.cumsum(y).tolist()
215
+ if perc:
216
+ y = (np.array(y) * 100).tolist()
217
 
218
  fig.add_trace(
219
  go.Scatter(
 
281
 
282
  return fig
283
 
284
+
285
  def get_desc(data):
286
  res = {name: list(dt.to_dict().keys()) for name, dt in data.items()}
287
  return "\n".join([
 
301
  top_k,
302
  direction,
303
  regex,
304
+ cumsum,
305
+ perc,
306
  progress=gr.Progress(),
307
  ):
308
  if len(datasets) <= 0 or not metric_name or not grouping:
 
323
 
324
  data = {path: result for path, result in zip(datasets, data)}
325
  return plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x,
326
+ log_scale_y, cumsum, perc, progress), data, export_data(data, metric_name), get_desc(data)
327
 
328
 
329
  def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y,
330
+ cumsum, perc,
331
  progress=gr.Progress()):
332
  if rounding is None or top_k is None:
333
  return None
334
  graph_fc = (
335
+ partial(plot_scatter, normalization=normalization, rounding=rounding, cumsum=cumsum, perc=perc)
336
  if grouping == "histogram"
337
  else partial(plot_bars, top_k=top_k, direction=direction, regex=regex, rounding=rounding)
338
  )
 
410
 
411
  with gr.Row():
412
  with gr.Column(scale=1):
413
+ cumsum_checkbox = gr.Checkbox(
414
+ label="Cumsum",
415
+ value=False,
416
+ )
417
+ perc_checkbox = gr.Checkbox(
418
+ label="%",
419
+ value=False,
420
+ )
421
  log_scale_x_checkbox = gr.Checkbox(
422
  label="Log scale x",
423
  value=False,
 
510
  top_select,
511
  direction_checkbox,
512
  group_regex,
513
+ cumsum_checkbox,
514
+ perc_checkbox
515
  ],
516
  outputs=[graph_output, exported_data, export_data_json, min_max_hist_data],
517
  )
518
 
519
+ gr.on(
520
+ triggers=[normalization_checkbox.change, rounding.change, group_regex.change, direction_checkbox.change,
521
+ top_select.change, log_scale_x_checkbox.change,
522
+ log_scale_y_checkbox.change, cumsum_checkbox.change, perc_checkbox.change],
523
+ fn=plot_data,
524
+ inputs=[
525
+ exported_data,
526
+ metric_name_dropdown,
527
+ normalization_checkbox,
528
+ rounding,
529
+ grouping_dropdown,
530
+ top_select,
531
+ direction_checkbox,
532
+ group_regex,
533
+ log_scale_x_checkbox,
534
+ log_scale_y_checkbox,
535
+ cumsum_checkbox,
536
+ perc_checkbox
537
+ ],
538
+ outputs=[graph_output],
539
+ )
540
 
541
  datasets_selected.change(
542
  fn=fetch_groups,