HugoVoxx commited on
Commit
9334a63
·
verified ·
1 Parent(s): c6eebc1

Update ag4masses/alphageometry/models.py

Browse files
Files changed (1) hide show
  1. ag4masses/alphageometry/models.py +178 -178
ag4masses/alphageometry/models.py CHANGED
@@ -1,178 +1,178 @@
1
- # Copyright 2023 DeepMind Technologies Limited
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Transformer language model generate mode."""
17
-
18
- from typing import Any, Tuple
19
- import beam_search
20
- import decoder_stack
21
- import gin
22
- import jax
23
- import jax.numpy as jnp
24
- from meliad_lib.meliad.transformer import models
25
-
26
-
27
- @gin.configurable
28
- class DecoderOnlyLanguageModelGenerate(models.DecoderOnlyLanguageModel):
29
- """Decoder only language modeling in inference mode."""
30
-
31
- decoder_factory = decoder_stack.DecoderStackGenerate
32
-
33
- num_heads: int = gin.REQUIRED
34
- head_size: int = gin.REQUIRED
35
-
36
- def get_fake_input(self) -> dict[str, Any]:
37
- fake_input_dict = super().get_fake_input()
38
- b = self.task_config.batch_size
39
- n = self.num_heads
40
- h = self.head_size
41
- fake_input_dict.update({
42
- 'dstate': tuple(
43
- [{
44
- 'current_index': jnp.array([0] * b, dtype=jnp.int32),
45
- 'keys': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
46
- 'values': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
47
- 'recurrent_kvq': None,
48
- 'relative_position_bias': jnp.zeros(
49
- (b, n, 1, 1024), dtype=jnp.bfloat16
50
- ),
51
- }]
52
- * 12
53
- ),
54
- 'eos': jnp.zeros([1024], dtype=jnp.bfloat16),
55
- 'mask': jnp.ones([1024], dtype=jnp.bfloat16),
56
- 'length': 1,
57
- 'temperature': 1.0,
58
- })
59
- return fake_input_dict
60
-
61
- def __call__(self, inputs: ...) -> tuple[Any, dict[str, Any]]:
62
- # Make sure this code is not used on untested cases.
63
- if self.mode not in ['init', 'beam_search']:
64
- raise ValueError(f'{type(self)} cannot do mode {self.mode}')
65
- if self.decoder.supports_generate():
66
- raise ValueError(f'{type(self)}.decoder cannot supports_generate()')
67
-
68
- self.decoder(
69
- input_tokens=inputs['targets'][:, 0:1],
70
- target_tokens=None,
71
- start_of_sequence=inputs['start_of_sequence'],
72
- )
73
-
74
- b = inputs['targets'].shape[0]
75
- no_start_of_seq = jnp.array([False] * b, dtype=jnp.bool_)
76
-
77
- # This fn is used in both beam_search or topk_sampling.
78
- def tokens_to_logits_fn(
79
- input_token: jnp.ndarray, dstate: tuple[dict[str, jnp.ndarray], ...]
80
- ) -> tuple[jnp.ndarray, tuple[dict[str, jnp.ndarray], ...]]:
81
- (logits, dstate, _) = self.decoder(
82
- input_tokens=input_token,
83
- target_tokens=None,
84
- start_of_sequence=no_start_of_seq,
85
- decoder_state=dstate,
86
- )
87
- return logits[:, -1, :], dstate
88
-
89
- last_token = jax.lax.dynamic_slice_in_dim(
90
- inputs['targets'], inputs['length'] - 1, 1, axis=1
91
- )
92
-
93
- # last token is used to seed beam_search
94
- inputs['targets'] = inputs['targets'][:, 0:-1]
95
- dstate = jax.lax.cond(
96
- inputs['start_of_sequence'][0],
97
- lambda: self.generate(inputs)[0],
98
- lambda: inputs['dstate'],
99
- )
100
-
101
- # Then we run beam search, init with last_token & dstate.
102
- finished_seqs, finished_scores, dstate = beam_search.beam_search_flat(
103
- last_token,
104
- dstate,
105
- tokens_to_logits_fn,
106
- max_decode_len=512,
107
- eos=inputs['eos'].reshape((1, 1, -1)),
108
- mask=inputs['mask'].reshape((1, 1, -1)),
109
- )
110
-
111
- return 0.0, {
112
- 'finished_seqs': finished_seqs,
113
- 'finished_scores': finished_scores,
114
- 'dstate': dstate,
115
- }
116
-
117
- def generate(
118
- self, inputs: ...
119
- ) -> tuple[tuple[dict[str, jnp.ndarray, ...], ...], jnp.ndarray]:
120
- """Generate an output sequence.
121
-
122
- Args:
123
- inputs: the same as argument to _call_.
124
-
125
- Returns:
126
- An array of generated tokens of shape (batch_size, sequence_length).
127
- """
128
- input_tokens = inputs['targets'] # [b,seq_len]
129
- start_of_sequence = inputs['start_of_sequence'] # [b]
130
- target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
131
- batch_size = target_tokens.shape[0]
132
-
133
- # Assuming all sequences start at the same time.
134
- start0 = inputs['start_of_sequence'][0]
135
- dstate = jax.lax.cond(
136
- start0,
137
- lambda: self.decoder.init_decoder_state_vanilla( # pylint: disable=g-long-lambda
138
- 1024, start_of_sequence
139
- ),
140
- lambda: inputs['dstate'],
141
- )
142
-
143
- first_token = input_tokens[:, 0:1]
144
- no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
145
- temperature = 1
146
- if 'temperature' in inputs:
147
- temperature = inputs['temperature']
148
-
149
- num_steps = inputs['length']
150
- if self.mode == 'beam_search':
151
- num_steps -= 1
152
-
153
- def cond_fn(scan_state) -> jnp.bool_:
154
- _, _, i, _ = scan_state
155
- return i < num_steps
156
-
157
- def loop_fn(scan_state: Any) -> Tuple[Any, Any, Any, Any]:
158
- (dstate, input_token, i, _) = scan_state
159
-
160
- (logits, dstate, _) = self.decoder(
161
- input_tokens=input_token,
162
- target_tokens=None,
163
- start_of_sequence=no_start_of_seq,
164
- decoder_state=dstate,
165
- )
166
-
167
- logits = logits / temperature
168
- output_token = jax.lax.dynamic_slice_in_dim(target_tokens, i, 1, axis=1)
169
-
170
- return (dstate, output_token, i + 1, logits)
171
-
172
- # Scan over the sequence length.
173
- dummy_logits = jnp.zeros((batch_size, 1, 1024))
174
- initial_scan_state = (dstate, first_token, 0, dummy_logits)
175
- dstate, _, _, logits = jax.lax.while_loop(
176
- cond_fn, loop_fn, initial_scan_state
177
- )
178
- return dstate, logits
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Transformer language model generate mode."""
17
+
18
+ from typing import Any, Tuple
19
+ import beam_search
20
+ import decoder_stack
21
+ import gin
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from aglib.meliad.transformer import models
25
+
26
+
27
@gin.configurable
class DecoderOnlyLanguageModelGenerate(models.DecoderOnlyLanguageModel):
  """Decoder only language modeling in inference mode.

  Overrides the training-time model to use a generate-capable decoder stack
  and to drive decoding with an externally threaded decoder state ('dstate').
  """

  # Swap in the decoder stack that supports incremental decoding state.
  decoder_factory = decoder_stack.DecoderStackGenerate

  # Attention geometry; must be supplied via gin configuration.
  num_heads: int = gin.REQUIRED
  head_size: int = gin.REQUIRED

  def get_fake_input(self) -> dict[str, Any]:
    """Return a fake input dict for model initialization / shape tracing.

    Extends the parent's fake input with the inference-only fields consumed
    by `__call__`: the per-layer decoder state ('dstate'), beam-search token
    masks ('eos', 'mask'), the seed-sequence 'length', and the sampling
    'temperature'.
    """
    fake_input_dict = super().get_fake_input()
    b = self.task_config.batch_size
    n = self.num_heads
    h = self.head_size
    fake_input_dict.update({
        # One state dict per decoder layer. The constants below (2048 KV
        # cache length, 12 layers, 1024 relative-position width) are assumed
        # to match the gin-configured architecture -- TODO confirm.
        'dstate': tuple(
            [{
                'current_index': jnp.array([0] * b, dtype=jnp.int32),
                'keys': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
                'values': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
                'recurrent_kvq': None,
                'relative_position_bias': jnp.zeros(
                    (b, n, 1, 1024), dtype=jnp.bfloat16
                ),
            }]
            * 12
        ),
        # 1024 here is presumably the vocabulary size; verify against the
        # tokenizer / model config.
        'eos': jnp.zeros([1024], dtype=jnp.bfloat16),
        'mask': jnp.ones([1024], dtype=jnp.bfloat16),
        'length': 1,
        'temperature': 1.0,
    })
    return fake_input_dict

  def __call__(self, inputs: ...) -> tuple[Any, dict[str, Any]]:
    """Run beam search seeded by `inputs['targets']` and `inputs['dstate']`.

    Returns:
      A (loss, outputs) pair where loss is a dummy 0.0 and outputs holds
      'finished_seqs', 'finished_scores' and the updated 'dstate'.
    """
    # Make sure this code is not used on untested cases.
    if self.mode not in ['init', 'beam_search']:
      raise ValueError(f'{type(self)} cannot do mode {self.mode}')
    # NOTE(review): this raises when supports_generate() returns True, yet
    # the message reads "cannot supports_generate()". Presumably the
    # generate-mode decoder stack is expected to report False here --
    # confirm the intended polarity.
    if self.decoder.supports_generate():
      raise ValueError(f'{type(self)}.decoder cannot supports_generate()')

    # Call the decoder once on the first token; the result is discarded.
    # Presumably this materializes the decoder's variables/shapes before the
    # stateful decoding below -- TODO confirm.
    self.decoder(
        input_tokens=inputs['targets'][:, 0:1],
        target_tokens=None,
        start_of_sequence=inputs['start_of_sequence'],
    )

    b = inputs['targets'].shape[0]
    no_start_of_seq = jnp.array([False] * b, dtype=jnp.bool_)

    # This fn is used in both beam_search or topk_sampling.
    def tokens_to_logits_fn(
        input_token: jnp.ndarray, dstate: tuple[dict[str, jnp.ndarray], ...]
    ) -> tuple[jnp.ndarray, tuple[dict[str, jnp.ndarray], ...]]:
      """One incremental decode step: single token in, next-token logits out."""
      (logits, dstate, _) = self.decoder(
          input_tokens=input_token,
          target_tokens=None,
          start_of_sequence=no_start_of_seq,
          decoder_state=dstate,
      )
      return logits[:, -1, :], dstate

    last_token = jax.lax.dynamic_slice_in_dim(
        inputs['targets'], inputs['length'] - 1, 1, axis=1
    )

    # last token is used to seed beam_search
    inputs['targets'] = inputs['targets'][:, 0:-1]
    # At the start of a sequence, build dstate by feeding the prompt through
    # `generate`; otherwise reuse the state passed in by the caller.
    dstate = jax.lax.cond(
        inputs['start_of_sequence'][0],
        lambda: self.generate(inputs)[0],
        lambda: inputs['dstate'],
    )

    # Then we run beam search, init with last_token & dstate.
    finished_seqs, finished_scores, dstate = beam_search.beam_search_flat(
        last_token,
        dstate,
        tokens_to_logits_fn,
        max_decode_len=512,
        eos=inputs['eos'].reshape((1, 1, -1)),
        mask=inputs['mask'].reshape((1, 1, -1)),
    )

    return 0.0, {
        'finished_seqs': finished_seqs,
        'finished_scores': finished_scores,
        'dstate': dstate,
    }

  def generate(
      self, inputs: ...
  ) -> tuple[tuple[dict[str, jnp.ndarray], ...], jnp.ndarray]:
    """Feed the prompt through the decoder to build up decoder state.

    Decoding here is teacher-forced: each step's next input token is taken
    from the (shifted) prompt, not sampled from the logits.

    Args:
      inputs: the same as argument to _call_.

    Returns:
      A (dstate, logits) tuple: the decoder state after consuming the
      prompt, and the logits from the final step.
    """
    input_tokens = inputs['targets']  # [b,seq_len]
    start_of_sequence = inputs['start_of_sequence']  # [b]
    # Shift left by one so target_tokens[:, i] is the token that follows
    # input_tokens[:, i]; the final position is zero-padded.
    target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
    batch_size = target_tokens.shape[0]

    # Assuming all sequences start at the same time.
    start0 = inputs['start_of_sequence'][0]
    # Fresh state at sequence start; otherwise continue from caller's state.
    dstate = jax.lax.cond(
        start0,
        lambda: self.decoder.init_decoder_state_vanilla(  # pylint: disable=g-long-lambda
            1024, start_of_sequence
        ),
        lambda: inputs['dstate'],
    )

    first_token = input_tokens[:, 0:1]
    no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
    temperature = 1
    if 'temperature' in inputs:
      temperature = inputs['temperature']

    num_steps = inputs['length']
    if self.mode == 'beam_search':
      # In beam-search mode the last prompt token is reserved to seed the
      # search (see __call__), so consume one step fewer here.
      num_steps -= 1

    def cond_fn(scan_state) -> jnp.bool_:
      """Continue while fewer than num_steps tokens have been consumed."""
      _, _, i, _ = scan_state
      return i < num_steps

    def loop_fn(scan_state: Any) -> Tuple[Any, Any, Any, Any]:
      """One decode step: advance dstate and pick the next forced token."""
      (dstate, input_token, i, _) = scan_state

      (logits, dstate, _) = self.decoder(
          input_tokens=input_token,
          target_tokens=None,
          start_of_sequence=no_start_of_seq,
          decoder_state=dstate,
      )

      logits = logits / temperature
      # Teacher forcing: next input comes from the prompt, not the logits.
      output_token = jax.lax.dynamic_slice_in_dim(target_tokens, i, 1, axis=1)

      return (dstate, output_token, i + 1, logits)

    # Scan over the sequence length.
    # 1024 is presumably the vocabulary size -- TODO confirm.
    dummy_logits = jnp.zeros((batch_size, 1, 1024))
    initial_scan_state = (dstate, first_token, 0, dummy_logits)
    dstate, _, _, logits = jax.lax.while_loop(
        cond_fn, loop_fn, initial_scan_state
    )
    return dstate, logits