Commit
·
3a4d722
1
Parent(s):
30ae784
Add parameter grid config
Browse files
app.py
CHANGED
@@ -34,15 +34,6 @@ CATEGORIES = [
|
|
34 |
]
|
35 |
|
36 |
|
37 |
-
PARAMETER_GRID = {
|
38 |
-
"vect__max_df": (0.2, 0.4, 0.6, 0.8, 1.0),
|
39 |
-
"vect__min_df": (1, 3, 5, 10),
|
40 |
-
"vect__ngram_range": ((1, 1), (1, 2)), # unigrams or bigrams
|
41 |
-
"vect__norm": ("l1", "l2"),
|
42 |
-
"clf__alpha": np.logspace(-6, 6, 13),
|
43 |
-
}
|
44 |
-
|
45 |
-
|
46 |
def shorten_param(param_name):
|
47 |
"""Remove components' prefixes in param_name."""
|
48 |
if "__" in param_name:
|
@@ -50,7 +41,7 @@ def shorten_param(param_name):
|
|
50 |
return param_name
|
51 |
|
52 |
|
53 |
-
def train_model(categories):
|
54 |
pipeline = Pipeline(
|
55 |
[
|
56 |
("vect", TfidfVectorizer()),
|
@@ -58,6 +49,16 @@ def train_model(categories):
|
|
58 |
]
|
59 |
)
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
data_train = fetch_20newsgroups(
|
62 |
subset="train",
|
63 |
categories=categories,
|
@@ -83,7 +84,7 @@ def train_model(categories):
|
|
83 |
|
84 |
random_search = RandomizedSearchCV(
|
85 |
estimator=pipeline,
|
86 |
-
param_distributions=
|
87 |
n_iter=40,
|
88 |
random_state=0,
|
89 |
n_jobs=2,
|
@@ -103,7 +104,7 @@ def train_model(categories):
|
|
103 |
cv_results = pd.DataFrame(random_search.cv_results_)
|
104 |
cv_results = cv_results.rename(shorten_param, axis=1)
|
105 |
|
106 |
-
param_names = [shorten_param(name) for name in
|
107 |
labels = {
|
108 |
"mean_score_time": "CV Score time (s)",
|
109 |
"mean_test_score": "CV score (accuracy)",
|
@@ -156,28 +157,10 @@ def train_model(categories):
|
|
156 |
return fig, fig2, best_parameters, test_accuracy
|
157 |
|
158 |
|
159 |
-
|
160 |
-
"
|
161 |
-
|
162 |
-
"which will be automatically downloaded, cached and reused for the document classification example.",
|
163 |
-
]
|
164 |
|
165 |
-
DESCRIPTION_PART2 = [
|
166 |
-
"In this example, we tune the hyperparameters of",
|
167 |
-
"a particular classifier using a",
|
168 |
-
"[RandomizedSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html#sklearn.model_selection.RandomizedSearchCV).",
|
169 |
-
"For a demo on the performance of some other classifiers, see the",
|
170 |
-
"[Classification of text documents using sparse features](https://scikit-learn.org/stable/auto_examples/text/plot_document_classification_20newsgroups.html#sphx-glr-auto-examples-text-plot-document-classification-20newsgroups-py) notebook.",
|
171 |
-
]
|
172 |
-
|
173 |
-
CATEGORY_SELECTION_DESCRIPTION = [
|
174 |
-
"The task of text classification is easier when there is little overlap between the characteristic terms ",
|
175 |
-
"of different topics. This is because the presence of common terms can make it difficult to distinguish between ",
|
176 |
-
"different topics. On the other hand, when there is little overlap between the characteristic terms of different ",
|
177 |
-
"topics, the task of text classification becomes easier, as the unique terms of each topic provide a solid basis ",
|
178 |
-
"for accurately classifying the document into its respective category. Therefore, careful selection of characteristic",
|
179 |
-
" terms for each topic is crucial to ensure accuracy in text classification."
|
180 |
-
]
|
181 |
|
182 |
AUTHOR = """
|
183 |
Created by [@dominguesm](https://huggingface.co/dominguesm) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_text_feature_extraction.html)
|
@@ -188,14 +171,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
188 |
with gr.Row():
|
189 |
with gr.Column():
|
190 |
gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
|
191 |
-
gr.Markdown("
|
192 |
-
gr.Markdown("
|
193 |
gr.Markdown(AUTHOR)
|
194 |
|
195 |
with gr.Row():
|
196 |
with gr.Column():
|
197 |
gr.Markdown("""## CATEGORY SELECTION""")
|
198 |
-
gr.Markdown(""
|
199 |
drop_categories = gr.Dropdown(
|
200 |
CATEGORIES,
|
201 |
value=["alt.atheism", "talk.religion.misc"],
|
@@ -207,20 +190,70 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
207 |
)
|
208 |
with gr.Row():
|
209 |
with gr.Column():
|
210 |
-
gr.Markdown(
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
## MODEL PIPELINE
|
225 |
```python
|
226 |
pipeline = Pipeline(
|
@@ -231,7 +264,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
231 |
)
|
232 |
```
|
233 |
"""
|
234 |
-
|
235 |
with gr.Row():
|
236 |
with gr.Column():
|
237 |
gr.Markdown("""## TRAINING""")
|
@@ -248,7 +281,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
|
|
248 |
|
249 |
brn_train.click(
|
250 |
train_model,
|
251 |
-
[drop_categories],
|
252 |
[plot_trade, plot_coordinates, best_parameters, test_accuracy],
|
253 |
)
|
254 |
|
|
|
34 |
]
|
35 |
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
def shorten_param(param_name):
|
38 |
"""Remove components' prefixes in param_name."""
|
39 |
if "__" in param_name:
|
|
|
41 |
return param_name
|
42 |
|
43 |
|
44 |
+
def train_model(categories, vect__max_df, vect__min_df, vect__ngram_range, vect__norm):
|
45 |
pipeline = Pipeline(
|
46 |
[
|
47 |
("vect", TfidfVectorizer()),
|
|
|
49 |
]
|
50 |
)
|
51 |
|
52 |
+
parameters_grid = {
|
53 |
+
"vect__max_df": [eval(value) for value in vect__max_df.split(",")],
|
54 |
+
"vect__min_df": [eval(value) for value in vect__min_df.split(",")],
|
55 |
+
"vect__ngram_range": eval(vect__ngram_range), # unigrams or bigrams
|
56 |
+
"vect__norm": [value.strip() for value in vect__norm.split(",")],
|
57 |
+
"clf__alpha": np.logspace(-6, 6, 13),
|
58 |
+
}
|
59 |
+
|
60 |
+
print(parameters_grid)
|
61 |
+
|
62 |
data_train = fetch_20newsgroups(
|
63 |
subset="train",
|
64 |
categories=categories,
|
|
|
84 |
|
85 |
random_search = RandomizedSearchCV(
|
86 |
estimator=pipeline,
|
87 |
+
param_distributions=parameters_grid,
|
88 |
n_iter=40,
|
89 |
random_state=0,
|
90 |
n_jobs=2,
|
|
|
104 |
cv_results = pd.DataFrame(random_search.cv_results_)
|
105 |
cv_results = cv_results.rename(shorten_param, axis=1)
|
106 |
|
107 |
+
param_names = [shorten_param(name) for name in parameters_grid.keys()]
|
108 |
labels = {
|
109 |
"mean_score_time": "CV Score time (s)",
|
110 |
"mean_test_score": "CV score (accuracy)",
|
|
|
157 |
return fig, fig2, best_parameters, test_accuracy
|
158 |
|
159 |
|
160 |
+
def load_description(name):
|
161 |
+
with open(f"./descriptions/{name}.md", "r") as f:
|
162 |
+
return f.read()
|
|
|
|
|
163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
AUTHOR = """
|
166 |
Created by [@dominguesm](https://huggingface.co/dominguesm) based on [scikit-learn docs](https://scikit-learn.org/stable/auto_examples/model_selection/plot_grid_search_text_feature_extraction.html)
|
|
|
171 |
with gr.Row():
|
172 |
with gr.Column():
|
173 |
gr.Markdown("# Sample pipeline for text feature extraction and evaluation")
|
174 |
+
gr.Markdown(load_description("description_part1"))
|
175 |
+
gr.Markdown(load_description("description_part2"))
|
176 |
gr.Markdown(AUTHOR)
|
177 |
|
178 |
with gr.Row():
|
179 |
with gr.Column():
|
180 |
gr.Markdown("""## CATEGORY SELECTION""")
|
181 |
+
gr.Markdown(load_description("description_category_selection"))
|
182 |
drop_categories = gr.Dropdown(
|
183 |
CATEGORIES,
|
184 |
value=["alt.atheism", "talk.religion.misc"],
|
|
|
190 |
)
|
191 |
with gr.Row():
|
192 |
with gr.Column():
|
193 |
+
gr.Markdown("""## PARAMETERS GRID""")
|
194 |
+
gr.Markdown(load_description("description_parameter_grid"))
|
195 |
+
with gr.Column():
|
196 |
+
gr.Markdown("""### Classifier Alpha""")
|
197 |
+
gr.Markdown(load_description("parameter_grid/alpha"))
|
198 |
+
|
199 |
+
clf__alpha = gr.Textbox(
|
200 |
+
label="clf__alpha",
|
201 |
+
value="1.e-06, 1.e-05, 1.e-04",
|
202 |
+
info="Due to practical considerations, this parameter was kept constant.",
|
203 |
+
interactive=False,
|
204 |
+
)
|
205 |
+
|
206 |
+
with gr.Column():
|
207 |
+
gr.Markdown("""### Vectorizer max_df""")
|
208 |
+
gr.Markdown(load_description("parameter_grid/max_df"))
|
209 |
+
|
210 |
+
vect__max_df = gr.Textbox(
|
211 |
+
label="vect__max_df",
|
212 |
+
value="0.2, 0.4, 0.6, 0.8, 1.0",
|
213 |
+
info="Values ranging from 0 to 1.0, separated by a comma.",
|
214 |
+
interactive=True,
|
215 |
+
)
|
216 |
+
|
217 |
+
with gr.Column():
|
218 |
+
gr.Markdown("""### Vectorizer min_df""")
|
219 |
+
gr.Markdown(load_description("parameter_grid/min_df"))
|
220 |
+
|
221 |
+
vect__min_df = gr.Textbox(
|
222 |
+
label="vect__min_df",
|
223 |
+
value="1, 3, 5, 10",
|
224 |
+
info="Values ranging from 0 to 1.0, separated by a comma, or integers separated by a comma. If float, the parameter represents a proportion of documents, integer absolute counts.",
|
225 |
+
interactive=True,
|
226 |
+
)
|
227 |
+
with gr.Column():
|
228 |
+
gr.Markdown("""### Vectorizer ngram_range""")
|
229 |
+
gr.Markdown(load_description("parameter_grid/ngram_range"))
|
230 |
+
|
231 |
+
vect__ngram_range = gr.Textbox(
|
232 |
+
label="vect__ngram_range",
|
233 |
+
value="(1, 1), (1, 2)",
|
234 |
+
info="""Tuples of integer values separated by a comma. For example an ``ngram_range`` of ``(1, 1)`` means only unigrams, ``(1, 2)`` means unigrams and bigrams, and ``(2, 2)`` means only bigrams.""",
|
235 |
+
interactive=True,
|
236 |
+
)
|
237 |
+
with gr.Column():
|
238 |
+
gr.Markdown("""### Vectorizer norm""")
|
239 |
+
gr.Markdown(load_description("parameter_grid/norm"))
|
240 |
+
gr.Markdown(
|
241 |
+
"""- 'l2': Sum of squares of vector elements is 1. The cosine
|
242 |
+
similarity between two vectors is their dot product when l2 norm has
|
243 |
+
been applied.
|
244 |
+
- 'l1': Sum of absolute values of vector elements is 1."""
|
245 |
+
)
|
246 |
+
|
247 |
+
vect__norm = gr.Textbox(
|
248 |
+
label="vect__norm",
|
249 |
+
value="l1, l2",
|
250 |
+
info="'l1' or 'l2', separated by a comma",
|
251 |
+
interactive=True,
|
252 |
+
)
|
253 |
+
|
254 |
+
with gr.Row():
|
255 |
+
gr.Markdown(
|
256 |
+
"""
|
257 |
## MODEL PIPELINE
|
258 |
```python
|
259 |
pipeline = Pipeline(
|
|
|
264 |
)
|
265 |
```
|
266 |
"""
|
267 |
+
)
|
268 |
with gr.Row():
|
269 |
with gr.Column():
|
270 |
gr.Markdown("""## TRAINING""")
|
|
|
281 |
|
282 |
brn_train.click(
|
283 |
train_model,
|
284 |
+
[drop_categories, vect__max_df, vect__min_df, vect__ngram_range, vect__norm],
|
285 |
[plot_trade, plot_coordinates, best_parameters, test_accuracy],
|
286 |
)
|
287 |
|