PatrickHaller committed on
Commit ddd178f
1 Parent(s): b95230d

Upload modeling_xlstm.py with huggingface_hub

Files changed (1)
  1. modeling_xlstm.py +210 -0
modeling_xlstm.py ADDED
@@ -0,0 +1,210 @@
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from xlstm.components.init import small_init_init_
from xlstm.utils import WeightDecayOptimGroupMixin
from xlstm.xlstm_block_stack import xLSTMBlockStack as _xLSTMBlockStack

from .configuration_xlstm import xLSTMConfig


class xLSTMPreTrainedModel(PreTrainedModel):
    """Base class for all models."""

    config_class = xLSTMConfig


class xLSTMBlockStack(_xLSTMBlockStack):
    """Small wrapper to expose hidden states"""

    def forward(
        self, x: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, Sequence[torch.Tensor]]:
        hidden_states = ()
        for block in self.blocks:
            x = block(x, **kwargs)
            hidden_states += (x,)
        return x, hidden_states


class xLSTMModel(xLSTMPreTrainedModel):
    def __init__(self, config: xLSTMConfig):
        super().__init__(config)
        self.config = config

        self.token_embedding = nn.Embedding(
            num_embeddings=config.vocab_size, embedding_dim=config.embedding_dim
        )
        _config = config.to_xlstm_config()

        self.emb_dropout = (
            nn.Dropout(_config.dropout)
            if _config.add_embedding_dropout
            else nn.Identity()
        )

        self.xlstm_block_stack = xLSTMBlockStack(config=_config)

    def forward(
        self,
        input_ids: torch.LongTensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        token_embedding = self.token_embedding(input_ids)
        x = self.emb_dropout(token_embedding)
        x, hidden_states = self.xlstm_block_stack(x)

        if output_hidden_states:
            hidden_states = (token_embedding,) + hidden_states

        if not return_dict:
            return x, hidden_states

        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states if output_hidden_states else None,
        )

class xLSTMForCausalLM(xLSTMPreTrainedModel, WeightDecayOptimGroupMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: xLSTMConfig, **kwargs):
        super().__init__(config)
        self.config = config
        self.vocab_size = config.vocab_size

        self.model = xLSTMModel(config)

        self.lm_head = nn.Linear(
            in_features=config.embedding_dim,
            out_features=config.vocab_size,
            bias=False,
        )

        self.post_init()
        # TODO: Add option for up-projection

    def get_input_embeddings(self):
        return self.model.token_embedding

    def set_input_embeddings(self, value: nn.Module):
        self.model.token_embedding = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, value):
        self.lm_head = value

    def reset_parameters(self):
        self.model.xlstm_block_stack.reset_parameters()

        small_init_init_(
            self.get_input_embeddings().weight, dim=self.config.embedding_dim
        )

        if not self.config.tie_word_embeddings:
            small_init_init_(
                self.get_output_embeddings().weight, dim=self.config.embedding_dim
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output = self.model(
            input_ids,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_state = output[0]

        logits = self.lm_head(hidden_state)
        logits = logits.float()

        loss = None

        if labels is not None:
            # Shift so that tokens at position t predict the token at position t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + output[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            hidden_states=output.hidden_states,
        )

    def step(
        self,
        idx: torch.Tensor,
        state: dict[str, dict[str, tuple[torch.Tensor, ...]]] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]:
        x = self.model.token_embedding(idx)
        x = self.model.emb_dropout(x)
        x, state = self.model.xlstm_block_stack.step(x, state=state, **kwargs)
        logits = self.lm_head(x)
        return logits, state

    def _create_weight_decay_optim_groups(
        self, **kwargs
    ) -> tuple[Sequence[nn.Parameter], Sequence[nn.Parameter]]:
        weight_decay, no_weight_decay = super()._create_weight_decay_optim_groups(
            **kwargs
        )
        # Remove the token embedding and add it to the correct group, according to the config.
        weight_decay = list(weight_decay)
        removed = 0
        for idx in range(len(weight_decay)):
            if weight_decay[idx - removed] is self.get_input_embeddings().weight:
                weight_decay.pop(idx - removed)
                removed += 1
        weight_decay = tuple(weight_decay)

        # TODO: Fix this
        # if self.config.weight_decay_on_embedding:
        if True:
            weight_decay += (self.get_input_embeddings().weight,)
        else:
            no_weight_decay += (self.get_input_embeddings().weight,)

        return weight_decay, no_weight_decay

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        new_embeddings = nn.Embedding(
            new_num_tokens, self.model.token_embedding.embedding_dim
        )
        self.model.token_embedding = new_embeddings.to(self.device)
        return new_embeddings

    def tie_weights(self):
        self.get_output_embeddings().weight = self.get_input_embeddings().weight

    def prepare_inputs_for_generation(
        self,
        input_ids,
        **kwargs,
    ):
        model_inputs = {
            "input_ids": input_ids.to(self.device),
        }
        return model_inputs
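
For reference, a minimal usage sketch of this modeling code (not part of the commit): it assumes the repository's config.json registers xLSTMConfig and xLSTMForCausalLM via auto_map so the auto classes can load this custom code with trust_remote_code=True, that a tokenizer is stored in the same repository, and a reasonably recent transformers version; the repo id below is a placeholder.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "PatrickHaller/some-xlstm-checkpoint"  # placeholder, not a real repo id

# trust_remote_code=True lets transformers import modeling_xlstm.py from the Hub repo.
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("xLSTM is a recurrent language model", return_tensors="pt")

# Forward pass with labels: the forward above shifts logits/labels internally
# and returns a CausalLMOutputWithPast carrying the cross-entropy loss.
outputs = model(input_ids=inputs["input_ids"], labels=inputs["input_ids"], return_dict=True)
print(outputs.loss.item(), outputs.logits.shape)

# Generation goes through the standard generate() API; prepare_inputs_for_generation
# re-feeds the full input_ids at every step, since no KV-style cache is kept.
with torch.no_grad():
    generated = model.generate(inputs["input_ids"], max_new_tokens=20)
print(tokenizer.decode(generated[0]))

For single-token recurrent inference, step() embeds one token, calls the block stack's step API, and threads the recurrent state dict through successive calls; the exact state layout comes from the xlstm package.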