Update modeling_nort5.py
Browse files- modeling_nort5.py +2 -2
modeling_nort5.py
CHANGED
@@ -221,7 +221,7 @@ class Attention(nn.Module):
|
|
221 |
- torch.arange(512, dtype=torch.long).unsqueeze(0)
|
222 |
position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, 512)
|
223 |
position_indices = config.position_bucket_size - 1 + position_indices
|
224 |
-
self.register_buffer("position_indices", position_indices, persistent=
|
225 |
|
226 |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
227 |
self.scale = 1.0 / math.sqrt(3 * self.head_size)
|
@@ -271,7 +271,7 @@ class Attention(nn.Module):
|
|
271 |
- torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(0)
|
272 |
position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
|
273 |
position_indices = self.config.position_bucket_size - 1 + position_indices
|
274 |
-
self.register_buffer("position_indices", position_indices.to(q.device), persistent=
|
275 |
|
276 |
q = self.pre_layer_norm(q)
|
277 |
query = self.in_proj_q(q) # shape: [T, B, D]
|
|
|
221 |
- torch.arange(512, dtype=torch.long).unsqueeze(0)
|
222 |
position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, 512)
|
223 |
position_indices = config.position_bucket_size - 1 + position_indices
|
224 |
+
self.register_buffer("position_indices", position_indices, persistent=False)
|
225 |
|
226 |
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
227 |
self.scale = 1.0 / math.sqrt(3 * self.head_size)
|
|
|
271 |
- torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(0)
|
272 |
position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
|
273 |
position_indices = self.config.position_bucket_size - 1 + position_indices
|
274 |
+
self.register_buffer("position_indices", position_indices.to(q.device), persistent=False)
|
275 |
|
276 |
q = self.pre_layer_norm(q)
|
277 |
query = self.in_proj_q(q) # shape: [T, B, D]
|