Add pytorch model at 240k steps
- README.md +1 -1
- flax_to_pytorch.py +22 -0
README.md CHANGED
@@ -14,7 +14,7 @@ datasets:
 
 Training details:
 
-* trained for
+* trained for 240k steps (29 dec 2021)
 * block size: 512
 * optimizer: adam, lr 8e-4, beta1 0.9, beta2 0.98
 * warmup 5000 steps
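The README hunk above only lists the hyperparameters. As a rough illustration, here is a minimal optax sketch of how that configuration (adam, peak lr 8e-4, beta1 0.9, beta2 0.98, 5000 warmup steps, 240k total steps) could be wired up; the post-warmup decay shape and all variable names are assumptions, not taken from the repo's actual training script.

```python
import optax

total_steps = 240_000   # from the README: trained for 240k steps
warmup_steps = 5_000    # from the README: warmup 5000 steps
peak_lr = 8e-4          # from the README: lr 8e-4

# Linear warmup to the peak lr, then linear decay to zero
# (the decay shape after warmup is an assumption).
lr_schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, peak_lr, warmup_steps),
        optax.linear_schedule(peak_lr, 0.0, total_steps - warmup_steps),
    ],
    boundaries=[warmup_steps],
)

optimizer = optax.adam(learning_rate=lr_schedule, b1=0.9, b2=0.98)
```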
flax_to_pytorch.py ADDED
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+import jax
+import jax.numpy as jnp
+from transformers import AutoTokenizer
+from transformers import FlaxGPT2LMHeadModel
+from transformers import GPT2LMHeadModel
+tokenizer = AutoTokenizer.from_pretrained(".")
+tokenizer.pad_token = tokenizer.eos_token
+model_fx = FlaxGPT2LMHeadModel.from_pretrained(".")
+# def to_f32(t):
+#     return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
+# model_fx.params = to_f32(model_fx.params)
+# model_fx.save_pretrained("./fx")
+model_pt = GPT2LMHeadModel.from_pretrained(".", from_flax=True)
+model_pt.save_pretrained("./pt")
+input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)
+input_ids_pt = torch.tensor(input_ids)
+logits_pt = model_pt(input_ids_pt).logits
+print(logits_pt)
+logits_fx = model_fx(input_ids).logits
+print(logits_fx)
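The committed script prints the Flax and PyTorch logits so they can be compared by eye. Below is a minimal sketch of turning that into an automated check; it is not part of the commit, the "." and "./pt" paths simply follow the script above, and the 1e-3 tolerance is an assumption.

```python
import numpy as np
import torch
from transformers import FlaxGPT2LMHeadModel, GPT2LMHeadModel

# Load the original Flax weights and the converted PyTorch weights
# (paths follow flax_to_pytorch.py above).
model_fx = FlaxGPT2LMHeadModel.from_pretrained(".")
model_pt = GPT2LMHeadModel.from_pretrained("./pt")

# Same dummy batch as in the conversion script: 2 sequences of 128 zero tokens.
input_ids = np.asarray(2 * [128 * [0]], dtype=np.int32)

logits_fx = np.asarray(model_fx(input_ids).logits)
with torch.no_grad():
    logits_pt = model_pt(torch.tensor(input_ids, dtype=torch.long)).logits.numpy()

max_diff = np.max(np.abs(logits_fx - logits_pt))
print(f"max abs logit difference: {max_diff:.2e}")
# Tolerance is an assumption; tighten or loosen as needed.
assert max_diff < 1e-3, "converted PyTorch weights diverge from the Flax originals"
```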