Spaces:
Running
Running
inoki-giskard
commited on
Commit
·
583defc
1
Parent(s):
ea670d5
Attemp to match labels in model and in dataset
Browse files
app.py
CHANGED
@@ -6,6 +6,10 @@ import os
|
|
6 |
import time
|
7 |
from pathlib import Path
|
8 |
|
|
|
|
|
|
|
|
|
9 |
|
10 |
HF_REPO_ID = 'HF_REPO_ID'
|
11 |
HF_SPACE_ID = 'SPACE_ID'
|
@@ -54,15 +58,41 @@ def check_dataset(dataset_id, dataset_config="default", dataset_split="test"):
|
|
54 |
return dataset_id, dataset_config, dataset_split
|
55 |
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def try_validate(model_id, dataset_id, dataset_config, dataset_split):
|
58 |
# Validate model
|
59 |
m_id, ppl = check_model(model_id=model_id)
|
60 |
if m_id is None:
|
61 |
gr.Warning(f'Model "{model_id}" is not accessible. Please set your HF_TOKEN if it is a private model.')
|
62 |
-
return dataset_config, dataset_split, gr.update(interactive=False)
|
63 |
if isinstance(ppl, Exception):
|
64 |
gr.Warning(f'Failed to load "{model_id} model": {ppl}')
|
65 |
-
return dataset_config, dataset_split, gr.update(interactive=False)
|
66 |
|
67 |
# Validate dataset
|
68 |
d_id, config, split = check_dataset(dataset_id=dataset_id, dataset_config=dataset_config, dataset_split=dataset_split)
|
@@ -80,15 +110,42 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split):
|
|
80 |
dataset_ok = True
|
81 |
|
82 |
if not dataset_ok:
|
83 |
-
return config, split, gr.update(interactive=False)
|
84 |
|
85 |
# TODO: Validate column mapping by running once
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
del ppl
|
88 |
|
89 |
gr.Info("Model and dataset validations passed. Your can submit the evaluation task.")
|
90 |
|
91 |
-
return config, split, gr.update(interactive=True)
|
92 |
|
93 |
|
94 |
def try_submit(m_id, d_id, config, split, local):
|
@@ -133,7 +190,7 @@ def try_submit(m_id, d_id, config, split, local):
|
|
133 |
with open(output_dir / "report.html", "w") as f:
|
134 |
print(f'Writing to {output_dir / "report.html"}')
|
135 |
f.write(rendered_report)
|
136 |
-
|
137 |
print(f"Finished local evaluation on {eval_str}: {time.time() - start:.2f}s")
|
138 |
|
139 |
|
@@ -155,6 +212,7 @@ with gr.Blocks(theme=theme) as iface:
|
|
155 |
value=0,
|
156 |
)
|
157 |
run_local = gr.Checkbox(value=True, label="Run in this Space")
|
|
|
158 |
|
159 |
with gr.Column():
|
160 |
dataset_id_input = gr.Textbox(
|
@@ -180,6 +238,8 @@ with gr.Blocks(theme=theme) as iface:
|
|
180 |
value="test",
|
181 |
)
|
182 |
|
|
|
|
|
183 |
with gr.Row():
|
184 |
validate_btn = gr.Button("Validate model and dataset", variant="primary")
|
185 |
run_btn = gr.Button(
|
@@ -199,6 +259,8 @@ with gr.Blocks(theme=theme) as iface:
|
|
199 |
dataset_config_input,
|
200 |
dataset_split_input,
|
201 |
run_btn,
|
|
|
|
|
202 |
],
|
203 |
)
|
204 |
run_btn.click(
|
|
|
6 |
import time
|
7 |
from pathlib import Path
|
8 |
|
9 |
+
import pandas as pd
|
10 |
+
|
11 |
+
from transformers.pipelines import TextClassificationPipeline
|
12 |
+
|
13 |
|
14 |
HF_REPO_ID = 'HF_REPO_ID'
|
15 |
HF_SPACE_ID = 'SPACE_ID'
|
|
|
58 |
return dataset_id, dataset_config, dataset_split
|
59 |
|
60 |
|
61 |
+
def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
|
62 |
+
for model_label in id2label_mapping.keys():
|
63 |
+
if model_label.upper() == label.upper():
|
64 |
+
return model_label, label
|
65 |
+
|
66 |
+
|
67 |
+
def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
|
68 |
+
id2label_mapping = {id2label[k]: None for k in id2label.keys()}
|
69 |
+
for feature in dataset_features.values():
|
70 |
+
if not isinstance(feature, datasets.ClassLabel):
|
71 |
+
continue
|
72 |
+
if len(feature.names) != len(id2label_mapping.keys()):
|
73 |
+
continue
|
74 |
+
|
75 |
+
# Try to match labels
|
76 |
+
for label in feature.names:
|
77 |
+
if label in id2label_mapping.keys():
|
78 |
+
model_label = label
|
79 |
+
else:
|
80 |
+
# Try to find case unsensative
|
81 |
+
model_label, label = text_classificaiton_match_label_case_unsensative(id2label_mapping, label)
|
82 |
+
id2label_mapping[model_label] = label
|
83 |
+
|
84 |
+
return id2label_mapping
|
85 |
+
|
86 |
+
|
87 |
def try_validate(model_id, dataset_id, dataset_config, dataset_split):
|
88 |
# Validate model
|
89 |
m_id, ppl = check_model(model_id=model_id)
|
90 |
if m_id is None:
|
91 |
gr.Warning(f'Model "{model_id}" is not accessible. Please set your HF_TOKEN if it is a private model.')
|
92 |
+
return dataset_config, dataset_split, gr.update(interactive=False), gr.update(visible=False), gr.update(visible=False)
|
93 |
if isinstance(ppl, Exception):
|
94 |
gr.Warning(f'Failed to load "{model_id} model": {ppl}')
|
95 |
+
return dataset_config, dataset_split, gr.update(interactive=False), gr.update(visible=False), gr.update(visible=False)
|
96 |
|
97 |
# Validate dataset
|
98 |
d_id, config, split = check_dataset(dataset_id=dataset_id, dataset_config=dataset_config, dataset_split=dataset_split)
|
|
|
110 |
dataset_ok = True
|
111 |
|
112 |
if not dataset_ok:
|
113 |
+
return config, split, gr.update(interactive=False), gr.update(visible=False), gr.update(visible=False)
|
114 |
|
115 |
# TODO: Validate column mapping by running once
|
116 |
+
prediction_result = {}
|
117 |
+
id2label_df = None
|
118 |
+
if isinstance(ppl, TextClassificationPipeline):
|
119 |
+
# Retrieve all labels
|
120 |
+
id2label_mapping = {}
|
121 |
+
try:
|
122 |
+
results = ppl({"text": "Test"}, top_k=None)
|
123 |
+
prediction_result = {
|
124 |
+
result["label"]: result["score"] for result in results
|
125 |
+
}
|
126 |
+
except Exception as e:
|
127 |
+
# Pipeline is not executable
|
128 |
+
pass
|
129 |
+
|
130 |
+
# We assume dataset is ok here
|
131 |
+
ds = datasets.load_dataset(d_id, config)[split]
|
132 |
+
try:
|
133 |
+
id2label = ppl.model.config.id2label
|
134 |
+
id2label_mapping = text_classification_map_model_and_dataset_labels(ppl.model.config.id2label, ds.features)
|
135 |
+
id2label_df = pd.DataFrame({
|
136 |
+
"ID": [i for i in id2label.keys()],
|
137 |
+
"Model labels": [id2label[label] for label in id2label.keys()],
|
138 |
+
"Dataset labels": [id2label_mapping[id2label[label]] for label in id2label.keys()],
|
139 |
+
})
|
140 |
+
except AttributeError:
|
141 |
+
# Dataset does not have features
|
142 |
+
pass
|
143 |
|
144 |
del ppl
|
145 |
|
146 |
gr.Info("Model and dataset validations passed. Your can submit the evaluation task.")
|
147 |
|
148 |
+
return config, split, gr.update(interactive=True), gr.update(value=prediction_result, visible=True), gr.update(value=id2label_df, visible=True)
|
149 |
|
150 |
|
151 |
def try_submit(m_id, d_id, config, split, local):
|
|
|
190 |
with open(output_dir / "report.html", "w") as f:
|
191 |
print(f'Writing to {output_dir / "report.html"}')
|
192 |
f.write(rendered_report)
|
193 |
+
|
194 |
print(f"Finished local evaluation on {eval_str}: {time.time() - start:.2f}s")
|
195 |
|
196 |
|
|
|
212 |
value=0,
|
213 |
)
|
214 |
run_local = gr.Checkbox(value=True, label="Run in this Space")
|
215 |
+
example_labels = gr.Label(label='Model pipeline test prediction result', visible=False)
|
216 |
|
217 |
with gr.Column():
|
218 |
dataset_id_input = gr.Textbox(
|
|
|
238 |
value="test",
|
239 |
)
|
240 |
|
241 |
+
id2label_mapping_dataframe = gr.DataFrame(visible=False)
|
242 |
+
|
243 |
with gr.Row():
|
244 |
validate_btn = gr.Button("Validate model and dataset", variant="primary")
|
245 |
run_btn = gr.Button(
|
|
|
259 |
dataset_config_input,
|
260 |
dataset_split_input,
|
261 |
run_btn,
|
262 |
+
example_labels,
|
263 |
+
id2label_mapping_dataframe,
|
264 |
],
|
265 |
)
|
266 |
run_btn.click(
|