# FixedF1 / app.py
# Author: John Graham Reynolds
# Last commit f75e383: "add radio button for selecting averaging method" (2.14 kB)
import sys
import gradio as gr
import pandas as pd
import evaluate
from evaluate.utils import infer_gradio_input_types, json_to_string_type, parse_readme, parse_test_cases
# from evaluate.utils import launch_gradio_widget # using this directly is erroneous - lets fix this
from fixed_f1 import FixedF1
from pathlib import Path
# Module-level metric instance; its features, name and description configure the
# Gradio interface built below.
metric = FixedF1()

# `metric.features` is either a single feature mapping or a list of them; in the
# list case the first mapping defines the columns of the input table.
if isinstance(metric.features, list):
    (feature_names, feature_types) = zip(*metric.features[0].items())
else:
    (feature_names, feature_types) = zip(*metric.features.items())
gradio_input_types = infer_gradio_input_types(feature_types)

# Directory containing this script (README.md is read from here). `sys.path[0]`
# is only the script directory when the file is executed directly; it is "" in
# interactive/stdin sessions and a different directory when this module is
# imported, so derive the location from `__file__` instead.
local_path = Path(__file__).resolve().parent

# Example inputs shown in the interface.
# TODO: generate these (e.g. randomly with randint) from the feature names and
# input types, the way evaluate's launch_gradio_widget configures them.
test_cases = [{"predictions": [1, 2, 3, 4, 5], "references": [1, 2, 5, 4, 3]}]
def compute(input_df: pd.DataFrame, method: str):
    """Compute the F1 metric for the labels entered in the interface table.

    Args:
        input_df: Table from the Gradio Dataframe input. The first column holds
            the predicted labels and the second the reference labels; every
            cell must be convertible with ``int()``.
        method: Averaging method selected in the radio button. The string
            ``"None"`` is translated to Python ``None``; any other value is
            passed through to ``FixedF1(average=...)`` unchanged.

    Returns:
        A human-readable string with the computed metric values, displayed in
        the output Textbox.
    """
    # Build a metric per call so the selected averaging method is applied; the
    # local name avoids shadowing the module-level `metric`.
    f1 = FixedF1(average=method if method != "None" else None)
    # Columns are addressed by position (predictions first, references second),
    # which is also robust to duplicate or renamed headers.
    predicted = [int(num) for num in input_df.iloc[:, 0].to_list()]
    references = [int(num) for num in input_df.iloc[:, 1].to_list()]
    f1.add_batch(predictions=predicted, references=references)
    outputs = f1.compute()
    # The message must be returned for Gradio to display it.
    return f"Your metrics are as follows: \n {outputs}"
# Gradio UI: an editable table of predictions/references plus a radio button
# for the averaging method, both fed to `compute`; the result is shown as text.
space = gr.Interface(
    fn=compute,
    inputs=[
        gr.Dataframe(
            headers=feature_names,
            col_count=len(feature_names),
            row_count=5,
            datatype=json_to_string_type(gradio_input_types),
        ),
        gr.Radio(
            ["weighted", "micro", "macro", "None", "binary"],
            # Pre-select an option (matching the example below) so submitting
            # without touching the radio does not pass `None` to `compute`,
            # which would silently disable averaging.
            value="weighted",
            label="Averaging Method",
            info="Method for averaging the F1 score across labels. `Binary` only works if you are evaluating a binary classification model.",
        ),
    ],
    outputs=gr.Textbox(label=metric.name),
    description=metric.info.description,
    title=f"Metric: {metric.name}",
    article=parse_readme(local_path / "README.md"),
    examples=[
        [pd.DataFrame(parse_test_cases(test_cases, feature_names, gradio_input_types)[0]), "weighted"],
    ],
    # Do not pre-compute example outputs at startup.
    cache_examples=False,
)
space.launch()