Switch to streamlit with markdown, add T5X pre-trained models

- README.md +1 -1
- REMARKS.md +0 -2
- app.py +167 -67
- data/pretrained.sqlite +0 -0
- adafactor_vs_adam_pretrain.png → img/adafactor_vs_adam_pretrain.png +0 -0
- bfloat16_loss.png → img/bfloat16_loss.png +0 -0
- eval_summ_rouge1_202302.png → img/eval_summ_rouge1_202302.png +0 -0
- eval_t5_dutch_english.png → img/eval_t5_dutch_english.png +0 -0
- eval_transl_bleu_202302.png → img/eval_transl_bleu_202302.png +0 -0
- evaluation_t5_dutch_english.png → img/evaluation_t5_dutch_english.png +0 -0
- optim_lr_summarization.png → img/optim_lr_summarization.png +0 -0
- t5v1_1eval_loss_and_accuracy.png → img/t5v1_1eval_loss_and_accuracy.png +0 -0
- train_loss_eval_summarization.png → img/train_loss_eval_summarization.png +0 -0
- train_loss_eval_t5_translation.png → img/train_loss_eval_t5_translation.png +0 -0
- training_base_36l_losses.png → img/training_base_36l_losses.png +0 -0
- training_losses_summarization_sweep.png → img/training_losses_summarization_sweep.png +0 -0
- requirements.txt +2 -0
README.md
CHANGED
@@ -1,5 +1,5 @@
---
-title: Pre-training Dutch T5 Models
+title: Pre-training Dutch T5 Models, evaluation and model lists
emoji: π
colorFrom: blue
colorTo: pink
REMARKS.md
CHANGED
@@ -6,8 +6,6 @@
intend to run many epochs on the same data, it's worth trying a training run without dropout.
If you want to compare losses, be sure to set the dropout rate equal.
The smaller models can probably always be trained without.
-* For the translation task, I am not sure that a 'deep-narrow' model (e.g. base-nl36) is better than a normal model
-or even a 'wide-deep' model.
* Training with more layers is much slower than you'd expect from the increased model size.
It is also more difficult to get batch size and learning rate right. Below is a section
about finding the right hyperparameters for the base-36L training.
app.py
CHANGED
@@ -1,9 +1,43 @@
+from functools import partial
import time

+import sqlite3
import psutil
import streamlit as st
-
-
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+IMAGE_WIDTHS = 900
+PRE_TRAINED_DB = "data/pretrained.sqlite"
+
+
+@st.cache
+def load_eval_data():
+    conn = sqlite3.connect(PRE_TRAINED_DB)
+    conn.row_factory = lambda c, r: {
+        col[0]: r[idx] for idx, col in enumerate(c.description)
+    }
+    df = pd.read_sql_query("SELECT * FROM pretrained", conn)
+    df.replace("None", np.nan, inplace=True)
+    df.rename(columns={"model": "name"}, inplace=True)
+    df = df.infer_objects()
+    int_columns = ["train_batch_size", "num_parameters"]
+    df[int_columns] = df[int_columns].astype("Int32")
+    plot_df = df[["name", "num_parameters", "summ_rouge1", "trans_en_nl_score"]]
+    plot_df[["num_parameters", "summ_rouge1", "trans_en_nl_score"]] = plot_df[
+        ["num_parameters", "summ_rouge1", "trans_en_nl_score"]
+    ].apply(pd.to_numeric)
+    plot_df["num params (M)"] = plot_df["num_parameters"].map(
+        lambda x: int(x / 10**6)
+    )
+    plot_df.dropna(subset=["summ_rouge1"], inplace=True)
+    plot_df.rename(
+        columns={"summ_rouge1": "summ Rouge1", "trans_en_nl_score": "en->nl Bleu"},
+        inplace=True,
+    )
+    return plot_df


def main():
@@ -13,6 +47,7 @@ def main():
        initial_sidebar_state="collapsed",  # Can be "auto", "expanded", "collapsed"
        page_icon="π",  # String, anything supported by st.image, or None.
    )
+    plot_df = load_eval_data()

    with open("style.css") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
@@ -23,9 +58,10 @@
    with open("PRETRAINING.md", "r") as f:
        st.markdown(f.read())

-    st.markdown(
+    st.markdown(
+        """## Evaluation

-###
+### Evaluation setup

Each pre-trained model was evaluated by fine-tuning on summarization and translation. The learning-rate was set to
a constant schedule after a small warmup of 32 steps.
@@ -33,7 +69,7 @@ Fine-tuning for evaluation was done on a limited set of 50K examples from the fi

| | Summarization | Translation |
|-----------------:|------------------|-------------------|
-| Dataset | CNN Dailymail
+| Dataset | [CNN Dailymail Dutch](https://huggingface.co/datasets/yhavinga/cnn_dailymail_dutch) | [CCMatrix En->NL](https://huggingface.co/datasets/yhavinga/ccmatrix_en_nl) |
| #train samples | 50K | 50K |
| Optimizer | AdamW | AdamW |
| learning rate | 0.001 | 0.0005 |
@@ -42,78 +78,127 @@ Fine-tuning for evaluation was done on a limited set of 50K examples from the fi
| #eval samples | 1000 | 1000 |
| wandb link | [eval_summ](https://wandb.ai/yepster/eval_dutch_cnndaily_202302_flax)|[eval_transl](https://wandb.ai/yepster/eval_dutch_ccmatrix_202302_flax) |
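The constant-after-warmup schedule from the table can be expressed with optax roughly as below. This is only a sketch: the helper name and the AdamW wiring are assumptions, not lines from the actual fine-tuning scripts.

```python
import optax

def warmup_then_constant(lr, warmup_steps=32):
    # Linear warmup over 32 steps (per the table above), then a constant rate.
    return optax.join_schedules(
        schedules=[
            optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps),
            optax.constant_schedule(lr),
        ],
        boundaries=[warmup_steps],
    )

optimizer = optax.adamw(learning_rate=warmup_then_constant(0.001))  # 0.0005 for translation
```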

-
+### Evaluation results

-The
-
-
-
-
+The figure below shows the evaluation scores for most models, with summarization Rouge1 on the x-axis (higher is better),
+and translation English to Dutch Bleu score on the y-axis (higher is better).
+The point size is proportional to the model size. UL2 models are blue, Flan models
+red, mT5 green and the other models black.
+"""
+    )
+    col1, col2 = st.columns(2)
+    with col1:
+        ul2_enabled = st.checkbox("UL2 Dutch (and English) (trained with T5X)", value=True)
+        t5_1_1_enabled = st.checkbox("t5_1_1 Dutch (trained with T5X)", value=True)
+        flan_enabled = st.checkbox("Flan T5 (google/flan-t5-*)", value=True)
+        mt5_enabled = st.checkbox("mt5 (google/mt5-*)", value=True)
+        long_t5_enabled = st.checkbox("Long T5 Dutch+English (trained with HuggingFace script)")
+        t5_v1_1_enabled = st.checkbox("T5 Dutch (and English) (trained with HuggingFace script)")
+    with col2:
+        small_enabled = st.checkbox("small model sizes")
+        base_enabled = st.checkbox("base model sizes")
+        large_enabled = st.checkbox("large model sizes")
+        _24_enabled = st.checkbox("small nl24 deep narrow sizes")
+        _36_enabled = st.checkbox("base nl36 deep narrow sizes")
+        _8l_enabled = st.checkbox("large nl8 deep wide sizes")
+        _4xl_enabled = st.checkbox("xlarge nl4 deep wide sizes")
+
+    plot_df = plot_df[
+        (plot_df["name"].str.contains("ul2") & ul2_enabled)
+        | (plot_df["name"].str.contains("flan") & flan_enabled)
+        | (plot_df["name"].str.contains("mt5") & mt5_enabled)
+        | (plot_df["name"].str.contains("long-t5") & long_t5_enabled)
+        | (plot_df["name"].str.contains("t5_1_1") & t5_1_1_enabled)
+        | ((plot_df["name"].str.startswith("t5") & ~plot_df["name"].str.startswith("t5_1_1")) & t5_v1_1_enabled)
+        | (plot_df["name"].str.contains("base") & base_enabled & ~plot_df["name"].str.contains("36"))
+        | (plot_df["name"].str.contains("small") & small_enabled & ~plot_df["name"].str.contains("24"))
+        | (plot_df["name"].str.contains("large") & large_enabled & ~plot_df["name"].str.contains("8"))
+        | ((plot_df["name"].str.contains("-36L") | plot_df["name"].str.contains("nl36")) & _36_enabled)
+        | ((plot_df["name"].str.contains("-24L") | plot_df["name"].str.contains("nl24")) & _24_enabled)
+        | ((plot_df["name"].str.contains("-8l") | plot_df["name"].str.contains("nl8")) & _8l_enabled)
+        | ((plot_df["name"].str.contains("-4L") | plot_df["name"].str.contains("nl4")) & _4xl_enabled)
+    ]
+
+    color_dict = {"flan": "red", "ul2": "blue", "mt5": "green", "t5_1_1": "orange"}
+    colors = [
+        color_dict[name.split("-")[0].lower()]
+        if name.split("-")[0].lower() in color_dict.keys()
+        else "black"
+        for name in plot_df["name"]
+    ]
+    fig = plt.figure(figsize=(15, 8))
+    sns.set_style("darkgrid")
+    ax = sns.scatterplot(
+        data=plot_df,
+        y="en->nl Bleu",
+        x="summ Rouge1",
+        size="num params (M)",
+        color=colors,
+        linewidth=0.7,
+    )
+    for i, row in plot_df.iterrows():
+        ax.annotate(
+            row["name"],
+            (row["summ Rouge1"], row["en->nl Bleu"]),
+            xytext=(0, 7),
+            textcoords="offset points",
+            ha="center",
+            va="center",
+            rotation=0,
+        )
+    plt.tight_layout()
+    st.pyplot(fig)
+    st.markdown("""* The `UL2` pre-trained Dutch(English) models consistently outperform the `T5-*` Dutch(English) models.
+* Flan models perform almost instantly well on the summarization task, with `flan-t5-small`
showing performance comparable to Dutch T5 base models.
-*
+* Fine-tuning of `t5-v1.1-large-dutch-cased` failed with the fixed hyperparameters across all models.
+Since the `UL2` models are better across the board, I've disabled this model on the hub.
* I am surprised by the consistent bad scores for the `long-t5` runs. I've retried the fine-tuning of these models with
`float32` instead of `bfloat16`, but the results were the same. Maybe this is normal behaviour for these models
targeted at dealing with longer sequence lengths.
-
-#### Translation
-
-The graph below ([WandB link](https://wandb.ai/yepster/eval_dutch_ccmatrix_202302_flax/reports/eval_ds_0-score-23-02-11-17-32-48---VmlldzozNTM0NTIy) shows the Bleu score for the translation runs, evaluated at step 25K and
-50K on the [CCMatrix](https://huggingface.co/datasets/yhavinga/ccmatrix_en_nl) dataset, from
-English to Dutch:
-""")
-    st.image("eval_transl_bleu_202302.png", width=IMAGE_WIDTHS)
-    st.markdown("""* For the translation task from English to Dutch, the Dutch+English pre-trained models perform well. Also
+* For the translation task from English to Dutch, the Dutch+English pre-trained models perform well. Also
`UL2 Dutch` pre-trained Dutch models are consistently better than their `Flan`, `T5 Dutch` and
`mT5` counterparts of the comparable size.
-*
-
+* For the translation task, I am not sure that a 'deep-narrow' model (e.g. base-nl36) is better than a normal model
+or even a 'wide-deep' model.
+* The `long-t5` models show bad performance on both tasks.
+I cannot explain this for the translation task. With a sequence length of 128 input and output
tokens, the sliding attention window with radius length 127 of the `long-t5` models should be able to handle this.
-
-The figure below shows the evaluation scores for most models, with summarization Rouge1 on the x-axis (higher is better),
-and translation English to Dutch Bleu score on the y-axis (higher is better).
-The point size is proportional to the model size. UL2 models are blue, Flan models
-red, mT5 green and the other models black.
""")
-    st.image("eval_t5_dutch_english.png", width=IMAGE_WIDTHS)
-    st.markdown("""* For clarity, not all models are shown.
-Among the omitted are `t5-base-36L-dutch-english-cased` with scores comparable to `ul2-large-dutch-english`, but slower inference.
-The `long-t5` models had such a bad performance that they could not be graphed without cluttering the other models together.
-Also fine-tuning `t5-v1.1-large-dutch-cased` with the fixed settings for learning rate and batch size diverged.
-* Across the board, for translation the models pre-trained with Dutch+English or Dutch converge faster than other models.
-I was surprised to see `t5-xl-4l` among the best models on translation, as it has only 4 layers, and previous tests
-showed that it had a very bad performance (In those tests I had forgot to force set the dropout rate to 0.0, and
-apparently this model was very sensitive to dropout).
-""")

    with open("REMARKS.md", "r") as f:
        st.markdown(f.read())

-    st.markdown(
+    st.markdown(
+        """### Bfloat16 datatype requires loss regularization

When training models with `bfloat16` and without loss regularization (default), the training losses would plateau or
diverge. The graph below displays the results of different attempts
to train [t5-small-24L-dutch-english](https://huggingface.co/yhavinga/t5-small-24L-dutch-english).
The legend indicates the optimizer, data type, learning rate, total batch size, and learning rate schedule used.
As you can see, all attempts to train with `bfloat16` failed.
-"""
-
-    st.
-
+"""
+    )
+    st.image("img/bfloat16_loss.png", width=IMAGE_WIDTHS)
+    st.markdown(
+        """The solution was found when peeking at T5X and the T5 gin configs, where I noticed a `z_loss` parameter,
+always set to 1e-4. This factor is used in the T5X [cross entropy loss](https://github.com/google-research/t5x/blob/a319e559b4f72bffab91821487382ef4c25dfcf4/t5x/losses.py#L26)
function, with the purpose to pull the weights towards zero.
I experimented with adding this regularization term in the HF pre-training script,
and the `bfloat16` training runs did not exhibit the problems illustrated above anymore.

-
-
-to the regularization term added in T5X's cross entropy loss function.
+The `z_loss` regularization term in the T5X loss function looks like L2 regularization.
+(See e.g. Andrej Karpathy [explaining regularization loss](https://youtu.be/PaCmpygFfXo?t=6720)).
The Optax optimizer, used in the HuggingFace script, mentions weight decay for AdaFactor (and AdamW)
but also mentions that L2 regularization does not work as expected with adaptive gradient
algorithms. It might be the case that setting a non-zero `weight_decay_rate` in the Optax Adafactor call
in the HuggingFace pre-training script is an alternative to adding the `z_loss` term, to solve the bfloat16 issues, but
I haven't tested this yet.
-"""
+"""
+    )
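The shape of that `z_loss` term can be sketched as follows. This is a minimal illustration of the idea, not the T5X implementation, and the helper name is made up:

```python
import jax.numpy as jnp
from jax.nn import log_softmax
from jax.scipy.special import logsumexp

def cross_entropy_with_z_loss(logits, onehot_targets, z_loss=1e-4):
    # Standard token-level cross entropy.
    ce = -jnp.sum(onehot_targets * log_softmax(logits, axis=-1), axis=-1)
    # log Z = logsumexp over the vocabulary; penalizing its square keeps the
    # softmax normalizer (and thus the logits) from drifting, which is what
    # stabilizes the bfloat16 runs described above.
    log_z = logsumexp(logits, axis=-1)
    return ce + z_loss * jnp.square(log_z)
```

The untested alternative mentioned above would be decoupled weight decay in the Optax call, e.g. `optax.adafactor(learning_rate=..., weight_decay_rate=...)`, with the decay value still to be chosen.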

-    st.markdown(
+    st.markdown(
+        """### Which optimizer and lr to use

During the Flax/Jax Community week in '21, our team quickly decided on using Adafactor with learning rate 5e-3.
I believed that a more optimal setting could be found with more time.
@@ -125,50 +210,64 @@ because the initial version of the training script had the optimizer as a boolea
changed to a string with the optimizer name.) --
All runs in the graph below that achieve a loss below 4 use **Adafactor**.
Peach-sweep-6 is represented by a dashed orange line and had a learning rate of **5e-3**.
-"""
+"""
+    )

-    st.image("adafactor_vs_adam_pretrain.png", width=IMAGE_WIDTHS)
-    st.markdown(
+    st.image("img/adafactor_vs_adam_pretrain.png", width=IMAGE_WIDTHS)
+    st.markdown(
+        """While there probably is a setting that will allow Adam and Shampoo to also converge fast below loss 4.0, I was unable
to find it. In a recent tweet Lucas Nestler had more success with Shampoo (https://twitter.com/_clashluke/status/1535994026876252160)
so maybe I need to revisit the attempt with the latest upstream code bases.

Later, when pre-training with T5X, I found that its custom Adafactor implementation with the default settings of the T5X gin configs,
a learning rate of 0.001 and inverse square root learning rate decay, worked well.
-"""
+"""
+    )
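For reference, an inverse square root decay can be sketched like this; the 10k warmup is the usual T5 default and is an assumption here, not a value read from the gin configs:

```python
import jax.numpy as jnp

def rsqrt_decay(base_lr=1e-3, warmup_steps=10_000):
    # Constant learning rate during warmup, then decay proportional to 1/sqrt(step).
    def schedule(step):
        return base_lr * jnp.sqrt(warmup_steps / jnp.maximum(step, warmup_steps))
    return schedule

# rsqrt_decay()(10_000) -> 1e-3, rsqrt_decay()(40_000) -> 5e-4
```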

-    st.markdown(
+    st.markdown(
+        """### Optimizer and learning rate used for summarization

Finetuning summarization requires more memory than translation due to the longer sequence lengths involved.
I wondered if I could use Adafactor instead of Adam and ran
a sweep to test this. The sweep was configured with Hyperband, so not all training runs completed to the end.
-"""
-
-    st.
-
+"""
+    )
235 |
+
st.image("img/optim_lr_summarization.png", width=IMAGE_WIDTHS)
|
236 |
+
st.markdown(
|
237 |
+
"""The training losses are graphed below:
|
238 |
+
"""
|
239 |
+
)
|
240 |
|
241 |
+
st.image("img/training_losses_summarization_sweep.png", width=IMAGE_WIDTHS)
|
242 |
+
st.markdown(
|
243 |
+
"""
|
244 |
While the Adafactor run with learning rate 7e-4 came close to the Adam runs, the consistent stability of training with Adam
|
245 |
made me stick with Adam as optimizer for evaluation runs on the several models. For translation the results were similar, though in the end I needed to configure a lower learning rate for all
|
246 |
models to converge during fine-tuning.
|
247 |
+
"""
|
248 |
+
)
|
249 |
|
250 |
+
st.markdown(
|
251 |
+
"""### Sequence length 512 or 1024
|
252 |
|
253 |
The models `t5-v1_1-base-dutch-english-cased` and `t5-v1_1-base-dutch-english-cased-1024` have the same model dimensions,
|
254 |
but are pre-trained on different sequence lenghts, 512 and 1024 respectively.
|
255 |
The evaluation loss and accuracy of the models do not look too different. Since training of the 1024 sequence length model was
|
256 |
very slow and didn't converge a was was very slow, I stopped it early. The figure below shows the evaluation
|
257 |
loss and accuracy.
|
258 |
+
"""
|
259 |
+
)
|
260 |
+
st.image("img/t5v1_1eval_loss_and_accuracy.png", width=IMAGE_WIDTHS)
|
261 |
+
st.markdown(
|
262 |
+
"""The 512 sequence length model was trained for 10 epochs of the `small` nl+en config (186B tokens total) and the 1024
|
263 |
sequence length model about 2 epochs of the `large` nl+en config (100B tokens total). While I expected both models to
|
264 |
perform similarly on downstream tasks, the 1024 sequence length model has better scores for both
|
265 |
summarization and translation.
|
266 |
+
"""
|
267 |
+
)
|
268 |
|
269 |
+
st.markdown(
|
270 |
+
"""## Model lists
|
271 |
|
272 |
### t5_1_1
|
273 |
|
|
|
401 |
and orchestrate hyperparameter sweeps with insightful visualizations.
|
402 |
|
403 |
Created by [Yeb Havinga](https://www.linkedin.com/in/yeb-havinga-86530825/)
|
404 |
+
"""
|
405 |
+
)
|
406 |
|
407 |
st.write(
|
408 |
f"""
|
data/pretrained.sqlite
ADDED
Binary file (24.6 kB)
|
|
adafactor_vs_adam_pretrain.png → img/adafactor_vs_adam_pretrain.png
RENAMED
File without changes

bfloat16_loss.png → img/bfloat16_loss.png
RENAMED
File without changes

eval_summ_rouge1_202302.png → img/eval_summ_rouge1_202302.png
RENAMED
File without changes

eval_t5_dutch_english.png → img/eval_t5_dutch_english.png
RENAMED
File without changes

eval_transl_bleu_202302.png → img/eval_transl_bleu_202302.png
RENAMED
File without changes

evaluation_t5_dutch_english.png → img/evaluation_t5_dutch_english.png
RENAMED
File without changes

optim_lr_summarization.png → img/optim_lr_summarization.png
RENAMED
File without changes

t5v1_1eval_loss_and_accuracy.png → img/t5v1_1eval_loss_and_accuracy.png
RENAMED
File without changes

train_loss_eval_summarization.png → img/train_loss_eval_summarization.png
RENAMED
File without changes

train_loss_eval_t5_translation.png → img/train_loss_eval_t5_translation.png
RENAMED
File without changes

training_base_36l_losses.png → img/training_base_36l_losses.png
RENAMED
File without changes

training_losses_summarization_sweep.png → img/training_losses_summarization_sweep.png
RENAMED
File without changes

requirements.txt
CHANGED
@@ -12,3 +12,5 @@ chex>=0.1.4
##jaxlib==0.1.67
flax>=0.5.3
sentencepiece
+matplotlib
+seaborn