giskard-evaluator / app_text_classification.py
ZeroCommand's picture
clean up and change run btn|
4434857
raw
history blame
9.66 kB
import gradio as gr
import datasets
import os
import time
import subprocess
import logging
import json
from transformers.pipelines import TextClassificationPipeline
from text_classification import get_labels_and_features_from_dataset, check_model, get_example_prediction, check_column_mapping_keys_validity, text_classification_fix_column_mapping
from utils import read_scanners, write_scanners, read_inference_type, read_column_mapping, write_column_mapping, write_inference_type, convert_column_mapping_to_json
from wordings import CONFIRM_MAPPING_DETAILS_MD, CONFIRM_MAPPING_DETAILS_FAIL_MD, CONFIRM_MAPPING_DETAILS_FAIL_RAW
# Names of the environment variables consulted when launching an evaluation.
HF_REPO_ID = "HF_REPO_ID"
HF_SPACE_ID = "SPACE_ID"
HF_WRITE_TOKEN = "HF_WRITE_TOKEN"

# Number of label / feature mapping dropdowns created in the UI.
MAX_LABELS = 20
MAX_FEATURES = 20

# Example ids shown as placeholders in the model / dataset text boxes.
EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
EXAMPLE_DATA_ID = "tweet_eval"

# Local YAML config read/written through the utils helpers (column mapping etc.).
CONFIG_PATH = "./config.yaml"
def try_submit(m_id, d_id, config, split, local):
    """Run a scan of model ``m_id`` on dataset ``d_id`` using the saved column mapping.

    The label and feature mappings are read from ``CONFIG_PATH``; if either is
    missing the user is warned and nothing is launched. When ``local`` is true,
    ``cli.py`` is run synchronously in the ``cicd`` directory next to this file;
    otherwise only a placeholder message is shown.

    Args:
        m_id: Hugging Face model id.
        d_id: Hugging Face dataset id.
        config: Dataset config name.
        split: Dataset split name.
        local: Whether to run the evaluation in this Space.

    Returns:
        A ``gr.update`` that (re-)enables the submit button.
    """
    # NOTE(review): read_column_mapping presumably yields None/empty once the
    # mapping was cleared via write_column_mapping(None) — TODO confirm in utils.
    all_mappings = read_column_mapping(CONFIG_PATH) or {}
    if "labels" not in all_mappings or "features" not in all_mappings:
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return gr.update(interactive=True)
    label_mapping = all_mappings["labels"]
    feature_mapping = all_mappings["features"]

    # TODO: Set column mapping for some dataset such as `amazon_polarity`
    if not local:
        gr.Info("TODO: Submit task to an endpoint")
        return gr.update(interactive=True)

    hf_token = os.environ.get(HF_WRITE_TOKEN)
    discussion_repo = os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID)
    # subprocess.Popen raises TypeError on None arguments; report what is
    # missing to the user instead of crashing the callback.
    required = {
        "dataset config": config,
        "dataset split": split,
        HF_WRITE_TOKEN: hf_token,
        f"{HF_REPO_ID} / {HF_SPACE_ID}": discussion_repo,
    }
    missing = [name for name, value in required.items() if value is None]
    if missing:
        gr.Warning(f"Cannot start evaluation, missing: {', '.join(missing)}")
        return gr.update(interactive=True)

    command = [
        "python",
        "cli.py",
        "--loader", "huggingface",
        "--model", m_id,
        "--dataset", d_id,
        "--dataset_config", config,
        "--dataset_split", split,
        "--hf_token", hf_token,
        "--discussion_repo", discussion_repo,
        "--output_format", "markdown",
        "--output_portal", "huggingface",
        "--feature_mapping", json.dumps(feature_mapping),
        "--label_mapping", json.dumps(label_mapping),
        # Relative to the `cicd` working directory used below.
        "--scan_config", "../config.yaml",
    ]
    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    start = time.monotonic()
    logging.info(f"Start local evaluation on {eval_str}")
    evaluator = subprocess.Popen(
        command,
        cwd=os.path.join(os.path.dirname(os.path.realpath(__file__)), "cicd"),
        stderr=subprocess.STDOUT,
    )
    # Blocks until the scan finishes.
    result = evaluator.wait()
    message = f"Finished local evaluation exit code {result} on {eval_str}: {time.monotonic() - start:.2f}s"
    logging.info(message)
    gr.Info(message)
    return gr.update(interactive=True)  # Submit button
def check_dataset_and_get_config(dataset_id):
    """Return a visible dropdown listing the configs of ``dataset_id``.

    Best-effort: if the configs cannot be fetched (e.g. the dataset does not
    exist) or the dataset reports none, the problem is logged and ``None`` is
    returned instead of raising.
    """
    try:
        configs = datasets.get_dataset_config_names(dataset_id)
    except Exception:
        # Dataset may not exist or be unreachable; keep the UI usable.
        logging.warning("Failed to get configs of dataset %s", dataset_id, exc_info=True)
        return None
    if not configs:
        logging.warning("Dataset %s reported no configs", dataset_id)
        return None
    return gr.Dropdown(configs, value=configs[0], visible=True)
def check_dataset_and_get_split(dataset_id, dataset_config):
    """Return a visible dropdown listing the splits of ``dataset_id``/``dataset_config``.

    Uses ``datasets.get_dataset_split_names`` so only the dataset metadata is
    needed, rather than downloading the whole dataset just to read its split
    names. Best-effort: on failure or when no split is reported, the problem
    is logged and ``None`` is returned instead of raising.
    """
    try:
        splits = list(datasets.get_dataset_split_names(dataset_id, dataset_config))
    except Exception:
        # Dataset/config may not exist or be unreachable; keep the UI usable.
        logging.warning(
            "Failed to get splits of dataset %s with config %s",
            dataset_id, dataset_config, exc_info=True,
        )
        return None
    if not splits:
        logging.warning("Dataset %s with config %s reported no splits", dataset_id, dataset_config)
        return None
    return gr.Dropdown(splits, value=splits[0], visible=True)
def get_demo():
    """Build the text-classification evaluation UI and wire its event handlers.

    NOTE(review): presumably called inside a ``gr.Blocks`` context created by
    the caller — not visible here. Statement order matters: components are
    created (and laid out) in the order below.
    """
    # Model inputs that the feature dropdowns map onto dataset columns. Only a
    # single text input is supported for now; shared by the dropdown builder
    # and the config writer so both agree on what each feature dropdown means.
    model_features = ["text"]

    with gr.Row():
        gr.Markdown(CONFIRM_MAPPING_DETAILS_MD)
    with gr.Row():
        model_id_input = gr.Textbox(
            label="Hugging Face model id",
            placeholder=EXAMPLE_MODEL_ID + " (press enter to confirm)",
        )
        dataset_id_input = gr.Textbox(
            label="Hugging Face Dataset id",
            placeholder=EXAMPLE_DATA_ID + " (press enter to confirm)",
        )
    with gr.Row():
        dataset_config_input = gr.Dropdown(label='Dataset Config', visible=False)
        dataset_split_input = gr.Dropdown(label='Dataset Split', visible=False)
    with gr.Row():
        example_input = gr.Markdown('Example Input', visible=False)
    with gr.Row():
        example_prediction = gr.Label(label='Model Prediction Sample', visible=False)
    with gr.Row():
        # First MAX_LABELS dropdowns: dataset label -> model label.
        # Next MAX_FEATURES dropdowns: model input -> dataset column.
        column_mappings = []
        with gr.Column():
            for _ in range(MAX_LABELS):
                column_mappings.append(gr.Dropdown(visible=False))
        with gr.Column():
            for _ in range(MAX_FEATURES):
                column_mappings.append(gr.Dropdown(visible=False))
    with gr.Accordion(label='Model Wrap Advance Config (optional)', open=False):
        run_local = gr.Checkbox(value=True, label="Run in this Space")
        use_inference = read_inference_type(CONFIG_PATH) == 'hf_inference_api'
        run_inference = gr.Checkbox(value=use_inference, label="Run with Inference API")
    with gr.Accordion(label='Scanner Advance Config (optional)', open=False):
        selected = read_scanners(CONFIG_PATH)
        # Offer 'data_leakage' as an extra choice without duplicating it.
        scan_config = selected + [s for s in ['data_leakage'] if s not in selected]
        scanners = gr.CheckboxGroup(choices=scan_config, value=selected, label='Scan Settings', visible=True)
    with gr.Row():
        run_btn = gr.Button(
            "Get Evaluation Result",
            variant="primary",
            interactive=True,
            size="lg",
        )

    @gr.on(triggers=[label.change for label in column_mappings],
           inputs=[dataset_id_input, dataset_config_input, dataset_split_input, *column_mappings])
    def write_column_mapping_to_config(dataset_id, dataset_config, dataset_split, *labels):
        """Persist the current dropdown selections as the column mapping.

        The label and feature mappings are rebuilt from every dropdown's current
        value, so selections the user has since changed do not linger.
        """
        ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
        if ds_labels is None or ds_features is None:
            # Dataset could not be inspected; nothing meaningful to write.
            return
        label_values = labels[:MAX_LABELS]
        feature_values = labels[MAX_LABELS:MAX_LABELS + MAX_FEATURES]
        # NOTE(review): read_column_mapping presumably yields None/empty after
        # clear_column_mapping_config — TODO confirm in utils.
        all_mappings = read_column_mapping(CONFIG_PATH) or {}
        # Label dropdown i is titled with dataset label i and holds a model label.
        all_mappings["labels"] = {
            model_label: ds_labels[i]
            for i, model_label in enumerate(label_values)
            if model_label and i < len(ds_labels)
        }
        # Feature dropdown i is titled with model input i and holds a dataset column.
        all_mappings["features"] = {
            model_features[i]: ds_column
            for i, ds_column in enumerate(feature_values)
            if ds_column and i < len(model_features)
        }
        write_column_mapping(all_mappings)

    def list_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split, model_id2label, model_features):
        """Build dropdown updates: MAX_LABELS for labels, then MAX_FEATURES for features."""
        ds_labels, ds_features = get_labels_and_features_from_dataset(dataset_id, dataset_config, dataset_split)
        if ds_labels is None or ds_features is None:
            return [gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
        model_labels = list(model_id2label.values())
        # Preselect the model label with the same index; .get avoids a KeyError
        # when the dataset has more labels than the model.
        label_updates = [
            gr.Dropdown(label=f"{label}", choices=model_labels, value=model_id2label.get(i), interactive=True, visible=True)
            for i, label in enumerate(ds_labels[:MAX_LABELS])
        ]
        label_updates += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_updates))]
        default_column = ds_features[0] if ds_features else None
        feature_updates = [
            gr.Dropdown(label=f"{feature}", choices=ds_features, value=default_column, interactive=True, visible=True)
            for feature in model_features[:MAX_FEATURES]
        ]
        feature_updates += [gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(feature_updates))]
        return label_updates + feature_updates

    @gr.on(triggers=[model_id_input.change, dataset_config_input.change])
    def clear_column_mapping_config():
        write_column_mapping(None)

    @gr.on(triggers=[model_id_input.change, dataset_config_input.change, dataset_split_input.change],
           inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
           outputs=[example_input, example_prediction, *column_mappings])
    def check_model_and_show_prediction(model_id, dataset_id, dataset_config, dataset_split):
        """Validate the model, then show a sample prediction and the mapping dropdowns."""
        ppl = check_model(model_id)
        if ppl is None or not isinstance(ppl, TextClassificationPipeline):
            gr.Warning("Please check your model.")
            return (
                gr.update(visible=False),
                gr.update(visible=False),
                *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)]
            )
        mapping_updates = list_labels_and_features_from_dataset(
            dataset_id,
            dataset_config,
            dataset_split,
            ppl.model.config.id2label,
            model_features,
        )
        prediction_input, prediction_output = get_example_prediction(ppl, dataset_id, dataset_config, dataset_split)
        return (
            gr.update(value=prediction_input, visible=True),
            gr.update(value=prediction_output, visible=True),
            *mapping_updates
        )

    dataset_id_input.blur(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
    dataset_config_input.change(
        check_dataset_and_get_split,
        inputs=[dataset_id_input, dataset_config_input],
        outputs=[dataset_split_input])
    gr.on(
        triggers=[
            run_btn.click,
        ],
        fn=try_submit,
        inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input, run_local],
        outputs=[run_btn])