HugoVoxx committed (verified)
Commit 2c38483 · Parent: 4efea5f

Update ag4masses/alphageometry/decoder_stack.py

ag4masses/alphageometry/decoder_stack.py CHANGED
@@ -1,55 +1,55 @@
- # Copyright 2023 DeepMind Technologies Limited
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """The decoder stack in inference mode."""
-
- from typing import Any, Tuple
-
- import gin
- from meliad_lib.meliad.transformer import decoder_stack
- import transformer_layer as tl
-
-
- struct = decoder_stack.struct
- nn_components = decoder_stack.nn_components
- position = decoder_stack.position
- jnp = decoder_stack.jnp
- attention = decoder_stack.attention
-
- DStackWindowState = decoder_stack.DStackWindowState
-
- Array = Any
-
- TransformerTaskConfig = decoder_stack.TransformerTaskConfig
-
- DStackDecoderState = Tuple[tl.DecoderState, ...]
-
-
- @gin.configurable
- class DecoderStackGenerate(decoder_stack.DecoderStack):
-   """Stack of transformer decoder layers."""
-
-   layer_factory = tl.TransformerLayerGenerate
-
-   def init_decoder_state_vanilla(
-       self, sequence_length: int, start_of_sequence: Array
-   ) -> DStackDecoderState:
-     """Return initial state for autoregressive generation."""
-     return tuple(
-         [
-             layer.init_decoder_state_vanilla(sequence_length, start_of_sequence)
-             for layer in self.transformer_layers
-         ]
-     )
+ # Copyright 2023 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """The decoder stack in inference mode."""
+
+ from typing import Any, Tuple
+
+ import gin
+ from aglib.meliad.transformer import decoder_stack
+ import transformer_layer as tl
+
+
+ struct = decoder_stack.struct
+ nn_components = decoder_stack.nn_components
+ position = decoder_stack.position
+ jnp = decoder_stack.jnp
+ attention = decoder_stack.attention
+
+ DStackWindowState = decoder_stack.DStackWindowState
+
+ Array = Any
+
+ TransformerTaskConfig = decoder_stack.TransformerTaskConfig
+
+ DStackDecoderState = Tuple[tl.DecoderState, ...]
+
+
+ @gin.configurable
+ class DecoderStackGenerate(decoder_stack.DecoderStack):
+   """Stack of transformer decoder layers."""
+
+   layer_factory = tl.TransformerLayerGenerate
+
+   def init_decoder_state_vanilla(
+       self, sequence_length: int, start_of_sequence: Array
+   ) -> DStackDecoderState:
+     """Return initial state for autoregressive generation."""
+     return tuple(
+         [
+             layer.init_decoder_state_vanilla(sequence_length, start_of_sequence)
+             for layer in self.transformer_layers
+         ]
+     )
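The only functional change in this diff is the import path for the meliad decoder stack, from meliad_lib.meliad to aglib.meliad. As a minimal, hypothetical sketch that is not part of the commit, code that has to run against either package layout could resolve the module with a fallback import; the two package names below come straight from the two sides of the diff, and nothing else is assumed:

# Sketch only, not from the repository: prefer the path used by this commit
# and fall back to the previous layout if it is not importable.
try:
  from aglib.meliad.transformer import decoder_stack
except ImportError:
  from meliad_lib.meliad.transformer import decoder_stack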