valhalla commited on
Commit
2856356
·
1 Parent(s): 180ed1e

add partition helpers

Browse files
Files changed (1) hide show
  1. dalle_mini/partitions.py +69 -0
dalle_mini/partitions.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from flax.core.frozen_dict import freeze
4
+ from flax.traverse_util import flatten_dict, unflatten_dict
5
+ from jax.experimental import PartitionSpec as P
6
+
7
+
8
+ # utils adapted from https://gitihub.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
9
+ # Sentinels
10
+ _unmatched = object()
11
+
12
+ # For specifying empty leaf dict `{}`
13
+ empty_dict = object()
14
+
15
+
16
+ def _match(qs, ks):
17
+ """Return True if regexes in qs match any window of strings in tuple ks."""
18
+ # compile regexes and force complete match
19
+ qts = tuple(map(lambda x: re.compile(x + "$"), qs))
20
+ for i in range(len(ks) - len(qs) + 1):
21
+ matches = [x.match(y) for x, y in zip(qts, ks[i:])]
22
+ if matches and all(matches):
23
+ return True
24
+ return False
25
+
26
+
27
+ def _replacement_rules(rules):
28
+ def replace(key, val):
29
+ for rule, replacement in rules:
30
+ if _match(rule, key):
31
+ return replacement
32
+ return val
33
+
34
+ return replace
35
+
36
+
37
+ def _get_partition_rules():
38
+ return [
39
+ # embeddings
40
+ ((r"embed_positions", "embedding"), P("mp", None)),
41
+ ((r"embed_tokens", "embedding"), P("mp", None)),
42
+ # self-attention
43
+ ((r"self_attn", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
44
+ ((r"self_attn", "out_proj", "kernel"), P("mp", None)),
45
+ # enc-dec attention
46
+ ((r"encoder_attn", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
47
+ ((r"encoder_attn", "out_proj", "kernel"), P("mp", None)),
48
+ # FFN
49
+ ((r"fc1", "kernel"), P(None, "mp")),
50
+ ((r"fc2", "kernel"), P("mp", None)),
51
+ # layer norms
52
+ ((r"layernorm_embedding", "(bias|scale)"), None),
53
+ ((r"self_attn_layer_norm", "(bias|scale)"), None),
54
+ ((r"encoder_attn_layer_norm", "(bias|scale)"), None),
55
+ ((r"final_layer_norm", "(bias|scale)"), None),
56
+ ((r"lm_head", "kernel"), P(None, "mp")),
57
+ ]
58
+
59
+
60
+ def set_partitions(in_dict):
61
+ rules = _get_partition_rules()
62
+ replace = _replacement_rules(rules)
63
+ initd = {k: _unmatched for k in flatten_dict(in_dict)}
64
+ result = {k: replace(k, v) for k, v in initd.items()}
65
+ for k, v in result.items():
66
+ if v == _unmatched:
67
+ print(k)
68
+ assert _unmatched not in result.values(), "Incomplete partition spec."
69
+ return freeze(unflatten_dict(result))