Spaces:
Running
Running
add partition helpers
Browse files- dalle_mini/partitions.py +69 -0
dalle_mini/partitions.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
from flax.core.frozen_dict import freeze
|
4 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
5 |
+
from jax.experimental import PartitionSpec as P
|
6 |
+
|
7 |
+
|
8 |
+
# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
|
9 |
+
# Sentinels

# Initial value assigned to every flattened key before the partition rules are
# applied; a key still holding it afterwards was not covered by any rule.
_unmatched = object()

# For specifying empty leaf dict `{}`
# NOTE(review): not referenced by the code visible here - confirm where it is used.
empty_dict = object()
|
14 |
+
|
15 |
+
|
16 |
+
def _match(qs, ks):
|
17 |
+
"""Return True if regexes in qs match any window of strings in tuple ks."""
|
18 |
+
# compile regexes and force complete match
|
19 |
+
qts = tuple(map(lambda x: re.compile(x + "$"), qs))
|
20 |
+
for i in range(len(ks) - len(qs) + 1):
|
21 |
+
matches = [x.match(y) for x, y in zip(qts, ks[i:])]
|
22 |
+
if matches and all(matches):
|
23 |
+
return True
|
24 |
+
return False
|
25 |
+
|
26 |
+
|
27 |
+
def _replacement_rules(rules):
    """Build a function that maps a key to the replacement of its first matching rule.

    Args:
        rules: iterable of ``(patterns, replacement)`` pairs, tried in order.

    Returns:
        A callable ``replace(key, val)`` that returns the replacement of the
        first rule whose patterns match ``key`` (as decided by ``_match``), or
        ``val`` unchanged when no rule matches.
    """

    def replace(key, val):
        # Lazy generator: evaluation stops at the first rule that matches.
        hits = (spec for patterns, spec in rules if _match(patterns, key))
        return next(hits, val)

    return replace
|
35 |
+
|
36 |
+
|
37 |
+
def _get_partition_rules():
|
38 |
+
return [
|
39 |
+
# embeddings
|
40 |
+
((r"embed_positions", "embedding"), P("mp", None)),
|
41 |
+
((r"embed_tokens", "embedding"), P("mp", None)),
|
42 |
+
# self-attention
|
43 |
+
((r"self_attn", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
|
44 |
+
((r"self_attn", "out_proj", "kernel"), P("mp", None)),
|
45 |
+
# enc-dec attention
|
46 |
+
((r"encoder_attn", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
|
47 |
+
((r"encoder_attn", "out_proj", "kernel"), P("mp", None)),
|
48 |
+
# FFN
|
49 |
+
((r"fc1", "kernel"), P(None, "mp")),
|
50 |
+
((r"fc2", "kernel"), P("mp", None)),
|
51 |
+
# layer norms
|
52 |
+
((r"layernorm_embedding", "(bias|scale)"), None),
|
53 |
+
((r"self_attn_layer_norm", "(bias|scale)"), None),
|
54 |
+
((r"encoder_attn_layer_norm", "(bias|scale)"), None),
|
55 |
+
((r"final_layer_norm", "(bias|scale)"), None),
|
56 |
+
((r"lm_head", "kernel"), P(None, "mp")),
|
57 |
+
]
|
58 |
+
|
59 |
+
|
60 |
+
def set_partitions(in_dict):
    """Build a tree of partition specs mirroring the structure of ``in_dict``.

    ``in_dict`` is flattened with ``flatten_dict``; each flattened key is
    looked up in the rules from ``_get_partition_rules()`` and the first
    matching rule supplies its value. The leaf values of ``in_dict`` are not
    used, only its keys.

    Args:
        in_dict: nested mapping accepted by ``flatten_dict``.

    Returns:
        The frozen (``freeze``) un-flattened dict mapping every key of
        ``in_dict`` to the replacement chosen by the rules.

    Raises:
        AssertionError: if any flattened key is matched by no rule. The
            message lists the offending keys.
    """
    replace = _replacement_rules(_get_partition_rules())

    # Every key starts out as the sentinel; a rule hit overwrites it.
    result = {key: replace(key, _unmatched) for key in flatten_dict(in_dict)}

    # Identity check: the sentinel is a unique object, and `==` would invoke
    # the (arbitrary) __eq__ of whatever replacement a rule produced.
    missing = [key for key, spec in result.items() if spec is _unmatched]
    if missing:
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; the exception type and message prefix are unchanged.
        raise AssertionError(
            f"Incomplete partition spec. Unmatched keys: {missing}"
        )

    return freeze(unflatten_dict(result))
|