File size: 7,174 Bytes
73d9a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import gradio as gr
from datasets import concatenate_datasets
from huggingface_hub import HfApi
from huggingface_hub.errors import HFValidationError
from requests.exceptions import HTTPError
from transformer_ranker import Result
from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
from transformer_ranker.embedder import Embedder
import math

# Gradio button variants used to signal whether a button can be clicked.
ENABLED_BUTTON_VARIANT = "primary"
DISABLED_BUTTON_VARIANT = "huggingface"

# HTML headline block: title, short description, badge links and affiliation.
# Presumably rendered at the top of the demo page — confirm against the layout code.
HEADLINE = """
<h1 align="center">TransformerRanker</h1>
<p align="center" style="max-width: 560px; margin: auto;">
    A very simple library that helps you find the best-suited language model for your NLP task.
    All you need to do is to select a dataset and a list of pre-trained language models (LMs) from the 🤗 HuggingFace Hub.
    TransformerRanker will quickly estimate which of these LMs will perform best on the given dataset!
</p>
<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
    <a href="https://github.com/flairNLP/transformer-ranker">
        <img src="https://img.shields.io/github/stars/flairNLP/transformer-ranker?style=social&label=Repository" alt="GitHub Badge">
    </a>
    <a href="https://pypi.org/project/transformer-ranker/">
        <img src="https://img.shields.io/badge/Package-orange?style=flat&logo=python" alt="Package Badge">
    </a>
    <a href="https://github.com/flairNLP/transformer-ranker/blob/main/examples/01-walkthrough.md">
        <img src="https://img.shields.io/badge/Tutorials-blue?style=flat&logo=readthedocs&logoColor=white" alt="Tutorials Badge">
    </a>
    <img src="https://img.shields.io/badge/license-MIT-green?style=flat" alt="License: MIT">
</p>
<p align="center">Developed at <a href="https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/">Humboldt University of Berlin</a>.</p>
"""

# Markdown footer with a CPU/hardware note, developer credits and a feedback link.
# NOTE: the two trailing spaces after "only." look like a deliberate Markdown
# hard line break — keep them.
FOOTER = """
**Note:** This demonstration currently runs on a CPU and is suited for smaller models only.  
**Developers:** [@plonerma](https://huggingface.co/plonerma) and [@lukasgarbas](https://huggingface.co/lukasgarbas). 
For feedback, suggestions, or contributions, reach out via GitHub or leave a message in the [discussions](https://huggingface.co/spaces/lukasgarbas/transformer-ranker/discussions).
"""

# Custom CSS: caps the app container width and sets link colours, with a
# separate colour when the user prefers a dark colour scheme.
CSS = """
.gradio-container{max-width: 800px !important}
a {color: #ff9d00;}
@media (prefers-color-scheme: dark) { a {color: #be185d;} }
"""


# Shared Hugging Face Hub client (used e.g. by check_dataset_exists to look up datasets).
hf_api = HfApi()


def check_dataset_exists(dataset_name):
    """Enable the dataset-loading button only if ``dataset_name`` is on the Hub.

    Looks the name up with ``hf_api.dataset_info``. If the lookup fails with an
    HTTP error or the name is not a valid repo id, the button is disabled and
    its label is reset to "Load dataset"; otherwise it is made interactive.

    Returns:
        A ``gr.update`` for the loading button.
    """
    try:
        hf_api.dataset_info(dataset_name)
    except (HTTPError, HFValidationError):
        # Unknown dataset or malformed name: disable and restore the label.
        return gr.update(value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT)
    else:
        return gr.update(interactive=True, variant=ENABLED_BUTTON_VARIANT)

def check_dataset_is_loaded(dataset, text_column, label_column, task_category):
    """Enable a button once a dataset is present and all settings are chosen.

    ``"-"`` is the placeholder for "not selected". The button is interactive
    only when ``dataset`` is truthy and none of the text column, label column
    or task category still holds the placeholder.

    Returns:
        A ``gr.update`` toggling ``interactive`` and the button variant.
    """
    ready = bool(dataset) and "-" not in (text_column, label_column, task_category)
    variant = ENABLED_BUTTON_VARIANT if ready else DISABLED_BUTTON_VARIANT
    return gr.update(interactive=ready, variant=variant)


def get_dataset_info(dataset):
    """Detect dataset settings and build the updates for the settings widgets.

    All splits of ``dataset`` are concatenated, and ``DatasetCleaner`` is used
    to guess the text column, the label column and (if a label column was
    found) the task category. Anything that cannot be detected is set to the
    placeholder ``"-"`` and a Gradio warning asks the user to select it
    manually.

    Args:
        dataset: mapping of split name to dataset (e.g. a ``DatasetDict``);
            its values are concatenated with ``concatenate_datasets``.

    Returns:
        A 5-tuple:
        ``(task category update, text column update, extra column update,
        label column update, number of samples)``. The extra column selector
        defaults to ``"-"`` (presumably the optional text-pair column —
        confirm against the UI wiring).
    """
    joined_dataset = concatenate_datasets(list(dataset.values()))
    column_names = joined_dataset.column_names
    datacleaner = DatasetCleaner()

    try:
        text_column = datacleaner._find_column(joined_dataset, "text column")
    except ValueError:
        gr.Warning("Text column can not be found. Select it in the dataset settings.")
        text_column = "-"

    try:
        label_column = datacleaner._find_column(joined_dataset, "label column")
    except ValueError:
        gr.Warning("Label column can not be found. Select it in the dataset settings.")
        label_column = "-"

    # The task category can only be inferred from a known label column.
    task_category = "-"
    if label_column != "-":
        try:
            task_category = datacleaner._find_task_category(joined_dataset, label_column)
        except ValueError:
            gr.Warning(
                "Task category could not be determined. The dataset must support classification or regression tasks.",
            )

    num_samples = len(joined_dataset)

    return (
        gr.update(
            # Stringify so the value is built exactly like the choices below,
            # whether the helper returns a TaskCategory member or the "-" placeholder.
            value=str(task_category),
            choices=[str(t) for t in TaskCategory],
            interactive=True,
        ),
        gr.update(value=text_column, choices=column_names, interactive=True),
        gr.update(value="-", choices=["-", *column_names], interactive=True),
        gr.update(value=label_column, choices=column_names, interactive=True),
        num_samples,
    )


def compute_ratio(num_samples_to_use, num_samples):
    """Return the fraction of the dataset that should be used.

    Args:
        num_samples_to_use: number of samples requested.
        num_samples: total number of samples available.

    Returns:
        ``num_samples_to_use / num_samples`` when the total is positive,
        otherwise ``0.0``. Also returns ``0.0`` when either argument is
        ``None`` (e.g. a cleared numeric input) instead of raising
        ``TypeError`` inside the UI callback.
    """
    if num_samples is None or num_samples_to_use is None:
        return 0.0
    if num_samples > 0:
        return num_samples_to_use / num_samples
    return 0.0


def ensure_one_lm_selected(checkbox_values, previous_values):
    """Prevent the selection from becoming empty.

    Returns ``checkbox_values`` if at least one entry is truthy; otherwise
    falls back to ``previous_values`` so the last valid selection is kept.
    """
    return checkbox_values if any(checkbox_values) else previous_values


# Monkey patch ``Embedder.embed`` so an attached progress tracker is told how
# many batches the upcoming embedding run will contain before it starts.
_old_embed = Embedder.embed


def _new_embed(embedder, sentences, batch_size: int = 32, **kw):
    """Report the batch count to ``embedder.tracker``, then run the original ``embed``.

    The reported count is ``ceil(len(sentences) / batch_size)``; nothing is
    reported when no tracker is attached (``tracker is None``).
    """
    tracker = embedder.tracker
    if tracker is not None:
        expected_batches = math.ceil(len(sentences) / batch_size)
        tracker.update_num_batches(expected_batches)
    return _old_embed(embedder, sentences, batch_size=batch_size, **kw)


Embedder.embed = _new_embed

# Monkey patch ``Embedder.embed_batch`` so the tracker is notified after each
# batch has been embedded.
_old_embed_batch = Embedder.embed_batch


def _new_embed_batch(embedder, *args, **kw):
    """Run the original ``embed_batch``, then report one completed batch.

    Returns whatever the original ``embed_batch`` returned.
    """
    batch_result = _old_embed_batch(embedder, *args, **kw)
    tracker = embedder.tracker
    if tracker is not None:
        tracker.update_batch_complete()
    return batch_result


Embedder.embed_batch = _new_embed_batch

# Monkey patch ``Embedder.__init__`` to accept an optional ``tracker`` keyword.
# Every instance ends up with a ``tracker`` attribute (``None`` by default),
# which the patched ``embed``/``embed_batch`` methods consult.
_old_init = Embedder.__init__


def _new_init(embedder, *args, tracker=None, **kw):
    """Run the original constructor, then attach ``tracker`` to the instance."""
    _old_init(embedder, *args, **kw)
    embedder.tracker = tracker


Embedder.__init__ = _new_init


class EmbeddingProgressTracker:
    """Report per-model, per-batch embedding progress to a progress bar.

    ``update_num_batches`` starts the next model's run (resetting the batch
    counter), and ``update_batch_complete`` advances it by one batch.
    Overall progress is split evenly across ``model_names``: model ``i``
    covers the interval ``[i / total, (i + 1) / total]``.

    Meant to be used as a context manager. Entering resets the counters;
    leaving reports a final "Done" or "Error" state and never suppresses
    exceptions.
    """

    def __init__(self, *, progress, model_names):
        """
        Args:
            progress: callable progress reporter (e.g. a ``gr.Progress``),
                called as ``progress(progress=..., desc=...)``. Note that
                ``__enter__`` replaces it with a fresh
                ``gr.Progress(track_tqdm=False)``.
            model_names: model names in the order they are embedded; used
                for the progress description.
        """
        self.model_names = model_names
        self.progress_bar = progress
        # Initialise counters here as well, so the update methods also work
        # when the tracker is used without ``with``; __enter__ resets them.
        self._reset_counters()

    def _reset_counters(self):
        """Reset model/batch counters to the 'nothing started yet' state."""
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None

    @property
    def total(self):
        """Number of models being tracked."""
        return len(self.model_names)

    def __enter__(self):
        # NOTE(review): the progress object given to __init__ is replaced here,
        # presumably on purpose (track_tqdm=False) — confirm before changing.
        self.progress_bar = gr.Progress(track_tqdm=False)
        self._reset_counters()
        return self

    def __exit__(self, typ, value, tb):
        if typ is None:
            self.progress_bar(1.0, desc="Done")
        else:
            self.progress_bar(1.0, desc="Error")

        # Do not suppress any errors
        return False

    def update_num_batches(self, total):
        """Move on to the next model, which will process ``total`` batches."""
        self.current_model += 1
        self.batches_complete = 0
        self.batches_total = total
        self.update_bar()

    def update_batch_complete(self):
        """Record one finished batch for the current model."""
        self.batches_complete += 1
        self.update_bar()

    def update_bar(self):
        """Push the current overall progress and a description to the bar."""
        i = self.current_model

        description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"

        progress = i / self.total
        # A run with zero batches (empty input) has no fractional part; the
        # truthiness check also avoids dividing by zero in that case.
        if self.batches_total:
            progress += (self.batches_complete / self.batches_total) / self.total

        self.progress_bar(progress=progress, desc=description)