Summary
Distilled with the Distily library, using gpt2 as the teacher model, on the wikimedia/wikipedia dataset.
Model Architecture:
- Architecture: GPT2LMHeadModel
- Total Parameters: 81,912,576
- Data Type (dtype): torch.bfloat16
- Model Size: 0.16 GB
Evaluation Metrics Comparison
step | epoch | enwikippl | frwikippl | loss | runtime | samples_per_second | steps_per_second | tinystoriesppl | zhwikippl |
---|---|---|---|---|---|---|---|---|---|
teacher eval | | 43.25 | 61.25 | | | | | 11.6875 | 19.125 |
0 | 0 | 2018634629120.0 | 122045790683136.0 | 21.0022 | 102.1494 | 97.896 | 12.237 | 9999220736.0 | 43705587204096.0 |
2500 | 0.0101 | 299008.0 | 6422528.0 | 5.8065 | 101.9861 | 98.053 | 12.257 | 45824.0 | 14483456.0 |
5000 | 0.0202 | 6880.0 | 96256.0 | 3.3113 | 102.9516 | 97.133 | 12.142 | 4160.0 | 493568.0 |
7500 | 0.0303 | 1216.0 | 8096.0 | 2.1560 | 103.0236 | 97.065 | 12.133 | 692.0 | 42752.0 |
10000 | 0.0404 | 608.0 | 3664.0 | 1.7825 | 102.3752 | 97.68 | 12.21 | 388.0 | 888.0 |
12500 | 0.0505 | 358.0 | 1632.0 | 1.4664 | 102.1871 | 97.86 | 12.232 | 272.0 | 308.0 |
15000 | 0.0606 | 288.0 | 1176.0 | 1.3488 | 102.6007 | 97.465 | 12.183 | 228.0 | 260.0 |
17500 | 0.0707 | 255.0 | 1040.0 | 1.2932 | 102.1542 | 97.891 | 12.236 | 199.0 | 215.0 |
20000 | 0.0808 | 216.0 | 892.0 | 1.1570 | 102.1073 | 97.936 | 12.242 | 173.0 | 149.0 |
22500 | 0.0909 | 178.0 | 740.0 | 1.0350 | 102.0765 | 97.966 | 12.246 | 146.0 | 141.0 |
25000 | 0.1010 | 155.0 | 524.0 | 0.9676 | 102.1019 | 97.941 | 12.243 | 122.5 | 139.0 |
27500 | 0.1111 | 142.0 | 560.0 | 0.9230 | 102.0256 | 98.015 | 12.252 | 114.0 | 130.0 |
30000 | 0.1212 | 137.0 | 470.0 | 0.8998 | 102.3365 | 97.717 | 12.215 | 108.5 | 138.0 |
32500 | 0.1313 | 134.0 | 476.0 | 0.8740 | 102.3911 | 97.665 | 12.208 | 104.0 | 140.0 |
35000 | 0.1414 | 129.0 | 496.0 | 0.8657 | 102.2153 | 97.833 | 12.229 | 102.5 | 141.0 |
37500 | 0.1515 | 127.0 | 464.0 | 0.8513 | 102.0489 | 97.992 | 12.249 | 97.0 | 117.0 |
40000 | 0.1616 | 108.0 | 446.0 | 0.7522 | 102.9331 | 97.15 | 12.144 | 93.0 | 104.0 |
42500 | 0.1717 | 99.5 | 374.0 | 0.6850 | 103.1088 | 96.985 | 12.123 | 82.0 | 116.0 |
45000 | 0.1818 | 90.5 | 346.0 | 0.6316 | 102.7903 | 97.285 | 12.161 | 73.5 | 113.0 |
47500 | 0.1919 | 82.5 | 320.0 | 0.5960 | 102.5988 | 97.467 | 12.183 | 71.0 | 101.0 |
50000 | 0.2020 | 78.5 | 306.0 | 0.5676 | 102.5936 | 97.472 | 12.184 | 72.5 | 106.0 |
52500 | 0.2121 | 79.5 | 290.0 | 0.5424 | 102.5863 | 97.479 | 12.185 | 64.5 | 92.0 |
55000 | 0.2222 | 76.0 | 270.0 | 0.5280 | 102.6307 | 97.437 | 12.18 | 65.0 | 87.0 |
57500 | 0.2323 | 76.5 | 272.0 | 0.5278 | 101.9639 | 98.074 | 12.259 | 64.5 | 102.0 |
60000 | 0.2424 | 77.5 | 268.0 | 0.5286 | 102.0921 | 97.951 | 12.244 | 62.75 | 99.5 |
62500 | 0.2525 | 75.5 | 264.0 | 0.5204 | 102.0679 | 97.974 | 12.247 | 63.25 | 83.0 |
65000 | 0.2626 | 76.0 | 260.0 | 0.5176 | 102.1795 | 97.867 | 12.233 | 61.5 | 90.5 |
67500 | 0.2727 | 74.5 | 256.0 | 0.5112 | 102.5764 | 97.488 | 12.186 | 62.25 | 93.5 |
70000 | 0.2828 | 73.5 | 258.0 | 0.5128 | 101.9569 | 98.081 | 12.26 | 62.0 | 79.0 |
72500 | 0.2929 | 75.0 | 250.0 | 0.5053 | 101.9382 | 98.099 | 12.262 | 64.0 | 96.0 |
75000 | 0.3030 | 72.5 | 238.0 | 0.5068 | 102.0407 | 98.0 | 12.25 | 61.5 | 88.5 |
77500 | 0.3131 | 73.5 | 256.0 | 0.5085 | 102.0542 | 97.987 | 12.248 | 64.5 | 86.5 |
80000 | 0.3232 | 70.5 | 238.0 | 0.4699 | 102.4042 | 97.652 | 12.207 | 54.75 | 98.5 |
82500 | 0.3333 | 68.0 | 242.0 | 0.4574 | 102.2684 | 97.782 | 12.223 | 55.5 | 160.0 |
85000 | 0.3434 | 64.5 | 218.0 | 0.4490 | 102.3277 | 97.725 | 12.216 | 52.0 | 77.5 |
87500 | 0.3535 | 66.5 | 203.0 | 0.4394 | 102.1134 | 97.93 | 12.241 | 51.25 | 67.5 |
90000 | 0.3636 | 63.75 | 212.0 | 0.4310 | 102.0438 | 97.997 | 12.25 | 51.25 | 88.5 |
92500 | 0.3737 | 65.5 | 209.0 | 0.4262 | 101.9984 | 98.041 | 12.255 | 49.75 | 103.5 |
95000 | 0.3838 | 65.0 | 204.0 | 0.4274 | 102.0781 | 97.964 | 12.246 | 46.25 | 83.0 |
97500 | 0.3939 | 64.5 | 201.0 | 0.4192 | 102.0692 | 97.973 | 12.247 | 50.5 | 94.5 |
100000 | 0.4040 | 64.5 | 203.0 | 0.4207 | 102.1283 | 97.916 | 12.24 | 49.0 | 88.0 |
102500 | 0.4141 | 63.0 | 209.0 | 0.4184 | 102.224 | 97.824 | 12.228 | 48.0 | 125.0 |
105000 | 0.4242 | 62.75 | 193.0 | 0.4166 | 102.1918 | 97.855 | 12.232 | 46.0 | 76.0 |
107500 | 0.4343 | 62.75 | 197.0 | 0.4128 | 102.1719 | 97.874 | 12.234 | 47.0 | 113.0 |
110000 | 0.4444 | 64.5 | 191.0 | 0.4118 | 103.0992 | 96.994 | 12.124 | 49.0 | 82.0 |
112500 | 0.4545 | 65.0 | 213.0 | 0.4128 | 102.7296 | 97.343 | 12.168 | 47.0 | 111.5 |
115000 | 0.4646 | 68.5 | 207.0 | 0.4301 | 102.178 | 97.868 | 12.234 | 49.0 | 108.0 |
117500 | 0.4747 | 65.0 | 217.0 | 0.4372 | 102.2302 | 97.818 | 12.227 | 50.25 | 124.0 |
120000 | 0.4848 | 65.5 | 210.0 | 0.4351 | 102.2952 | 97.756 | 12.22 | 51.0 | 139.0 |
122500 | 0.4949 | 66.0 | 272.0 | 0.4352 | 102.1941 | 97.853 | 12.232 | 50.5 | 226.0 |
125000 | 0.5051 | 67.0 | 240.0 | 0.4387 | 101.978 | 98.06 | 12.258 | 49.0 | 71.0 |
127500 | 0.5152 | 66.5 | 224.0 | 0.4396 | 101.9014 | 98.134 | 12.267 | 49.75 | 100.0 |
130000 | 0.5253 | 65.5 | 227.0 | 0.4354 | 102.1244 | 97.92 | 12.24 | 50.75 | 146.0 |
132500 | 0.5354 | 66.0 | 209.0 | 0.4286 | 102.0218 | 98.018 | 12.252 | 52.25 | 101.5 |
135000 | 0.5455 | 64.5 | 220.0 | 0.4361 | 101.9074 | 98.128 | 12.266 | 51.25 | 181.0 |
137500 | 0.5556 | 66.5 | 223.0 | 0.4288 | 102.0744 | 97.968 | 12.246 | 49.0 | 103.0 |
140000 | 0.5657 | 66.5 | 232.0 | 0.4287 | 102.1162 | 97.928 | 12.241 | 49.25 | 127.5 |
142500 | 0.5758 | 66.5 | 220.0 | 0.4299 | 101.9461 | 98.091 | 12.261 | 49.5 | 88.5 |
145000 | 0.5859 | 65.5 | 217.0 | 0.4238 | 101.9572 | 98.08 | 12.26 | 48.75 | 177.0 |
147500 | 0.5960 | 64.0 | 205.0 | 0.4109 | 101.9497 | 98.088 | 12.261 | 48.75 | 128.0 |
150000 | 0.6061 | 63.5 | 224.0 | 0.4051 | 102.0205 | 98.02 | 12.252 | 48.5 | 117.5 |
152500 | 0.6162 | 63.25 | 202.0 | 0.4000 | 101.9318 | 98.105 | 12.263 | 47.5 | 160.0 |
155000 | 0.6263 | 63.75 | 195.0 | 0.4052 | 102.0203 | 98.02 | 12.252 | 48.75 | 100.0 |
157500 | 0.6364 | 63.75 | 212.0 | 0.4014 | 101.8935 | 98.142 | 12.268 | 49.25 | 113.0 |
160000 | 0.6465 | 62.75 | 198.0 | 0.3988 | 101.9178 | 98.118 | 12.265 | 44.5 | 132.0 |
162500 | 0.6566 | 64.5 | 192.0 | 0.3918 | 102.0303 | 98.01 | 12.251 | 45.5 | 100.0 |
165000 | 0.6667 | 62.5 | 202.0 | 0.3958 | 102.3627 | 97.692 | 12.211 | 47.75 | 88.5 |
167500 | 0.6768 | 62.5 | 191.0 | 0.3883 | 102.1537 | 97.892 | 12.236 | 44.75 | 80.5 |
170000 | 0.6869 | 63.5 | 195.0 | 0.3880 | 102.0728 | 97.969 | 12.246 | 51.0 | 91.5 |
172500 | 0.6970 | 60.75 | 201.0 | 0.3863 | 101.9235 | 98.113 | 12.264 | 47.5 | 90.5 |
175000 | 0.7071 | 61.5 | 189.0 | 0.3806 | 101.9376 | 98.099 | 12.262 | 46.5 | 82.5 |
177500 | 0.7172 | 58.75 | 171.0 | 0.3512 | 101.9844 | 98.054 | 12.257 | 42.75 | 66.0 |
180000 | 0.7273 | 55.5 | 161.0 | 0.3218 | 101.881 | 98.154 | 12.269 | 39.25 | 54.0 |
182500 | 0.7374 | 54.25 | 149.0 | 0.3148 | 101.9839 | 98.055 | 12.257 | 38.75 | 47.75 |
185000 | 0.7475 | 53.5 | 160.0 | 0.3133 | 101.9875 | 98.051 | 12.256 | 38.75 | 45.0 |
187500 | 0.7576 | 54.75 | 160.0 | 0.3114 | 101.9762 | 98.062 | 12.258 | 38.0 | 43.75 |
190000 | 0.7677 | 53.75 | 147.0 | 0.3075 | 101.9972 | 98.042 | 12.255 | 38.0 | 38.25 |
192500 | 0.7778 | 54.0 | 157.0 | 0.3057 | 101.9431 | 98.094 | 12.262 | 38.0 | 48.0 |
195000 | 0.7879 | 53.25 | 149.0 | 0.3058 | 101.9778 | 98.061 | 12.258 | 37.0 | 41.0 |
197500 | 0.7980 | 54.0 | 152.0 | 0.3032 | 102.0059 | 98.034 | 12.254 | 37.25 | 40.0 |
200000 | 0.8081 | 53.75 | 151.0 | 0.3033 | 102.0615 | 97.98 | 12.248 | 37.25 | 47.25 |
202500 | 0.8182 | 53.0 | 146.0 | 0.2957 | 102.0116 | 98.028 | 12.254 | 36.75 | 39.0 |
205000 | 0.8283 | 52.5 | 139.0 | 0.2903 | 102.1449 | 97.9 | 12.238 | 36.5 | 35.75 |
207500 | 0.8384 | 52.0 | 142.0 | 0.2894 | 102.0126 | 98.027 | 12.253 | 36.25 | 38.25 |
210000 | 0.8485 | 52.25 | 142.0 | 0.2883 | 102.0938 | 97.949 | 12.244 | 36.0 | 37.25 |
212500 | 0.8586 | 52.5 | 141.0 | 0.2874 | 101.9515 | 98.086 | 12.261 | 36.0 | 37.0 |
215000 | 0.8687 | 52.25 | 140.0 | 0.2873 | 101.9427 | 98.094 | 12.262 | 36.0 | 36.0 |
217500 | 0.8788 | 51.75 | 141.0 | 0.2863 | 102.0114 | 98.028 | 12.254 | 36.0 | 35.5 |
220000 | 0.8889 | 52.0 | 141.0 | 0.2854 | 102.0424 | 97.999 | 12.25 | 36.0 | 35.75 |
222500 | 0.8990 | 52.5 | 143.0 | 0.2853 | 102.0368 | 98.004 | 12.25 | 36.0 | 35.25 |
225000 | 0.9091 | 52.0 | 142.0 | 0.2849 | 102.115 | 97.929 | 12.241 | 35.75 | 35.0 |
227500 | 0.9192 | 52.0 | 141.0 | 0.2851 | 102.0455 | 97.996 | 12.249 | 36.0 | 35.25 |
230000 | 0.9293 | 52.0 | 141.0 | 0.2846 | 102.0273 | 98.013 | 12.252 | 35.75 | 35.25 |
232500 | 0.9394 | 52.0 | 141.0 | 0.2843 | 101.961 | 98.077 | 12.26 | 35.75 | 35.0 |
235000 | 0.9495 | 52.0 | 141.0 | 0.2844 | 102.0188 | 98.021 | 12.253 | 35.75 | 35.25 |
237500 | 0.9596 | 52.0 | 141.0 | 0.2845 | 102.0714 | 97.971 | 12.246 | 35.75 | 35.25 |
240000 | 0.9697 | 52.0 | 141.0 | 0.2844 | 102.0371 | 98.004 | 12.25 | 35.75 | 35.25 |
242500 | 0.9798 | 52.0 | 141.0 | 0.2844 | 102.0363 | 98.004 | 12.251 | 35.75 | 35.25 |
245000 | 0.9899 | 52.0 | 141.0 | 0.2844 | 102.0254 | 98.015 | 12.252 | 35.75 | 35.25 |
247500 | 1.0 | 52.0 | 141.0 | 0.2846 | 102.5728 | 97.492 | 12.186 | 35.75 | 35.25 |
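
To make the remaining gap to the teacher easier to read off, the final row can be compared against the teacher row as a ratio per metric. This is a minimal sketch; it assumes the four teacher values correspond, in order, to the four perplexity columns (enwikippl, frwikippl, tinystoriesppl, zhwikippl), and the helper name `ppl_ratios` is hypothetical.

```python
# Teacher perplexities from the "teacher eval" row.
# Assumption: the four values map, in order, to the four *ppl columns.
teacher = {
    "enwikippl": 43.25,
    "frwikippl": 61.25,
    "tinystoriesppl": 11.6875,
    "zhwikippl": 19.125,
}

# Student perplexities from the final row (step 247500, epoch 1.0).
student = {
    "enwikippl": 52.0,
    "frwikippl": 141.0,
    "tinystoriesppl": 35.75,
    "zhwikippl": 35.25,
}


def ppl_ratios(student, teacher):
    """Return the student/teacher perplexity ratio per metric (1.0 means parity)."""
    return {name: student[name] / teacher[name] for name in teacher}


ratios = ppl_ratios(student, teacher)
for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f}x teacher perplexity")
```

Under that column assumption, the student is closest to the teacher on English Wikipedia (the training distribution) and furthest on TinyStories.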
Resource Usage Comparison
- VRAM Use: 7.2012 GB
Distillation (Teacher -> Student) Architecture Difference:
- Architecture: GPT2LMHeadModel -> GPT2LMHeadModel
- Total Parameters: 124,439,808 -> 81,912,576
- Data Type (dtype): not recorded for the teacher (the source shows its parameter count, 124439808, in this field) -> torch.bfloat16
- Model Size: 0.24 GB -> 0.16 GB
Module Diff Details
--- teacher model modules
+++ student model modules
@@ -4,7 +4,7 @@
     (wpe): Embedding(1024, 768)
     (drop): Dropout(p=0.1, inplace=False)
     (h): ModuleList(
-      (0-11): 12 x GPT2Block(
+      (0-5): 6 x GPT2Block(
         (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
         (attn): GPT2FlashAttention2(
           (c_attn): Conv1D()
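The parameter totals above can be reproduced from the layer counts in the module diff. The sketch below is an illustration, not Distily code: the hidden size (768) and position count (1024) come from the diff, while the vocabulary size (50257), the 4x MLP width, and the tied `lm_head` are assumptions taken from the standard GPT-2 small configuration. The function name `gpt2_param_count` is hypothetical.

```python
def gpt2_param_count(n_layer, n_embd=768, n_positions=1024, vocab_size=50257):
    """Count parameters of a GPT-2 style LM whose lm_head shares the token embedding.

    Assumed (not stated in the card): vocab_size=50257, MLP width 4 * n_embd.
    """
    inner = 4 * n_embd
    # Token embedding (wte) plus position embedding (wpe).
    embeddings = vocab_size * n_embd + n_positions * n_embd
    # Each LayerNorm has a weight and a bias vector.
    layer_norm = 2 * n_embd
    # Attention: fused QKV projection (c_attn) plus output projection (c_proj).
    attn = (n_embd * 3 * n_embd + 3 * n_embd) + (n_embd * n_embd + n_embd)
    # MLP: up projection (c_fc) plus down projection (c_proj).
    mlp = (n_embd * inner + inner) + (inner * n_embd + n_embd)
    block = 2 * layer_norm + attn + mlp
    # Final LayerNorm (ln_f) is added once after all blocks.
    return embeddings + n_layer * block + layer_norm


teacher_params = gpt2_param_count(n_layer=12)
student_params = gpt2_param_count(n_layer=6)
print(f"teacher: {teacher_params:,}")  # 124,439,808
print(f"student: {student_params:,}")  # 81,912,576
print(f"student/teacher: {student_params / teacher_params:.3f}")
```

Halving the block count removes only about a third of the parameters, because the embedding matrices are shared by both models and do not shrink.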
Train Dataset
Trained on 521,350,663 tokens from the wikimedia/wikipedia dataset.
- Num Samples: 990,000
- Subset: 20231101.en
- Split: train
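The sample count and the final step number in the evaluation table follow from values listed under Hyperparameters (`dataset_sample_size`, `dataset_test_size`, `train_batch_size`). The arithmetic below is a sketch that assumes a single device and `gradient_accumulation_steps` of 1, so that one optimizer step consumes exactly one batch.

```python
import math

# Values taken from this card.
dataset_sample_size = 1_000_000
dataset_test_size = 0.01
train_batch_size = 4
train_tokens = 521_350_663

# Hold out the test fraction; the remainder is the training set.
test_samples = round(dataset_sample_size * dataset_test_size)
train_samples = dataset_sample_size - test_samples

# Assumption: one device, no gradient accumulation.
steps_per_epoch = math.ceil(train_samples / train_batch_size)

avg_tokens_per_sample = train_tokens / train_samples

print(f"train samples: {train_samples:,}")      # 990,000
print(f"test samples: {test_samples:,}")        # 10,000
print(f"steps per epoch: {steps_per_epoch:,}")  # 247,500
print(f"average tokens per sample: {avg_tokens_per_sample:.1f}")
```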
Training Objective
DistillationObjective(logits_loss_component=LossComponent(label=logits, weight=1, loss_fn=kl))
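The objective above applies a KL-divergence loss to the logits with weight 1. As a rough illustration of what such a loss computes for a single token position, here is a dependency-free sketch. It is not Distily's implementation: the direction KL(teacher || student), the absence of a temperature, and the function names are all assumptions.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_logits_loss(teacher_logits, student_logits):
    """KL(teacher || student) in nats for one token position.

    Assumption: the teacher distribution is the target; no temperature scaling.
    """
    t_logp = log_softmax(teacher_logits)
    s_logp = log_softmax(student_logits)
    return sum(math.exp(tl) * (tl - sl) for tl, sl in zip(t_logp, s_logp))


teacher_logits = [2.0, 0.5, -1.0]
print(kl_logits_loss(teacher_logits, teacher_logits))    # zero when the student matches
print(kl_logits_loss(teacher_logits, [0.0, 0.0, 0.0]))   # positive for a uniform student
```

In training, a per-position value like this would typically be averaged over all positions in the batch; how Distily reduces it is not stated in this card.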
Hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.0001
- train_batch_size: 4
- eval_batch_size: 8
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: cosine
- lr_scheduler_warmup_ratio: 0.5
- num_epochs: 1.0
- distillation_objective: DistillationObjective(logits_loss_component=LossComponent(label=logits, weight=1, loss_fn=kl))
- train_embeddings: True
- lr_scheduler: <torch.optim.lr_scheduler.LambdaLR object at 0x7fd9b01df220>
- student_model_name_or_path: None
- student_config_name_or_path: distilbert/distilgpt2
- student_model_config: None
- reinitialize_weights: None
- copy_teacher_modules: [('lm_head', False)]
- student_model_as_bitnet: False
- student_model_compile: False
- dropout: None
- teacher_model_name_or_path: gpt2
- teacher_load_in_8bit: False
- teacher_load_in_4bit: False
- teacher_model_compile: False
- dataset_uri: wikimedia/wikipedia
- dataset_subset: 20231101.en
- dataset_split: train
- dataset_column_name: text
- dataset_sample_size: 1000000
- dataset_test_size: 0.01
- gradient_accumulation_steps: 1
- weight_decay: 0.0
- max_grad_norm: 1.0
- warmup_ratio: 0.5
- warmup_steps: 0
- gradient_checkpointing: True
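The scheduler settings above (cosine, warmup ratio 0.5, learning rate 0.0001) imply that the learning rate rises for the first half of the run and decays for the second half. The sketch below assumes the common linear-warmup-then-half-cosine shape and a total of 247,500 steps (the last step in the evaluation table); the function name is hypothetical and the exact schedule used by the run is not confirmed by this card.

```python
import math


def cosine_lr_with_warmup(step, total_steps=247_500, base_lr=1e-4, warmup_ratio=0.5):
    """Learning rate at a given step: linear warmup, then cosine decay to zero.

    Assumption: warmup_steps=0 in the card means the ratio determines the warmup length.
    """
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr.
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the decay phase completed, from 0.0 to 1.0.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 61_875, 123_750, 185_625, 247_500):
    print(step, f"{cosine_lr_with_warmup(step):.2e}")
```

Under these assumptions the peak learning rate is reached at step 123,750, which is epoch 0.5 in the evaluation table.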
Framework Versions
- Distily 0.2.0
- Transformers 4.44.0
- Pytorch 2.3.0
- Datasets 2.21.0
Model tree for distily/short_gpt2
- Base model: distilbert/distilgpt2