Update modeling_nort5.py
modeling_nort5.py  CHANGED  (+9 -8)

@@ -62,7 +62,7 @@ class Decoder(nn.Module):
         self_relative_embedding = self.self_relative_embedding()
         cross_relative_embedding = self.cross_relative_embedding()

-        if past_key_values is
+        if past_key_values is None:
             autoreg_mask = torch.triu(
                 torch.full((x.size(0), x.size(0)), True, device=x.device),
                 diagonal=1
@@ -259,12 +259,12 @@ class Attention(nn.Module):

         if past_key_value is not None:
             if not self.is_cross_attention:
-                key = torch.cat([past_key_value[0], key], dim=1)
-                value = torch.cat([past_key_value[1], value], dim=1)
+                key = torch.cat([past_key_value[0].flatten(0, 1), key], dim=1)
+                value = torch.cat([past_key_value[1].flatten(0, 1), value], dim=1)
                 key_len = key.size(1)
             elif past_key_value[0].size(1) == kv.size(0):
-                key = past_key_value[0]
-                value = past_key_value[1]
+                key = past_key_value[0].flatten(0, 1)
+                value = past_key_value[1].flatten(0, 1)

         if self.position_indices.size(0) < max(query_len, key_len):
             position_indices = torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(1) \
@@ -306,7 +306,10 @@ class Attention(nn.Module):
         context = self.post_layer_norm(context)
         context = self.dropout(context)

-
+        key = key.detach().unflatten(0, (-1, self.num_heads))
+        value = value.detach().unflatten(0, (-1, self.num_heads))
+
+        return context, attention_probs.detach(), (key, value)


 class WordEmbedding(nn.Module):
@@ -662,9 +665,7 @@ class NorT5ForConditionalGeneration(NorT5Model):
             reordered_layer_past_states = ()
             for layer_past_state in layer_past_states:
                 # need to set correct `past` for each of the four key / value states
-                layer_past_state = layer_past_state.unflatten(0, (-1, self.config.num_attention_heads))
                 layer_past_state = layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
-                layer_past_state = layer_past_state.flatten(0, 1)
                 reordered_layer_past_states = reordered_layer_past_states + (layer_past_state,)

             assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
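The net effect of the change, as far as the diff shows: each layer's cached key/value tensors are now returned by the attention module with the batch dimension split out, shape (batch, num_heads, seq_len, head_size), and flattened back to (batch * num_heads, seq_len, head_size) only when they are read on the next decoding step. That is why `_reorder_cache` can `index_select` beams on dim 0 directly and no longer needs the unflatten/flatten pair around the selection. Below is a minimal shape sketch of that round trip; the dimension values and variable names are illustrative only, not taken from the NorT5 config.

import torch

# Toy dimensions (illustrative, not from the model config).
batch, num_heads, seq_len, head_size = 2, 4, 5, 8

# Inside the attention module, key/value are computed with batch and heads
# merged into dim 0: (batch * num_heads, seq_len, head_size).
key = torch.randn(batch * num_heads, seq_len, head_size)

# After this change, the cache stores them with the batch dimension split out,
# as in `key.detach().unflatten(0, (-1, self.num_heads))` above.
cached_key = key.detach().unflatten(0, (-1, num_heads))
assert cached_key.shape == (batch, num_heads, seq_len, head_size)

# Beam reordering can now index_select directly on dim 0 (the beam/batch axis),
# which is all that `_reorder_cache` does after this change.
beam_idx = torch.tensor([1, 0])
reordered = cached_key.index_select(0, beam_idx)

# When the cache is consumed on the next decoding step, it is flattened back to
# the (batch * num_heads, seq_len, head_size) layout the attention math expects,
# as in `past_key_value[0].flatten(0, 1)` above.
key_for_attention = reordered.flatten(0, 1)
assert key_for_attention.shape == (batch * num_heads, seq_len, head_size)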