lgcharpe commited on
Commit
afa86ca
·
verified ·
1 Parent(s): e8b9b83

Update modeling_nort5.py

Browse files
Files changed (1) hide show
  1. modeling_nort5.py +2 -2
modeling_nort5.py CHANGED
@@ -221,7 +221,7 @@ class Attention(nn.Module):
221
  - torch.arange(512, dtype=torch.long).unsqueeze(0)
222
  position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, 512)
223
  position_indices = config.position_bucket_size - 1 + position_indices
224
- self.register_buffer("position_indices", position_indices, persistent=True)
225
 
226
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
227
  self.scale = 1.0 / math.sqrt(3 * self.head_size)
@@ -271,7 +271,7 @@ class Attention(nn.Module):
271
  - torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(0)
272
  position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
273
  position_indices = self.config.position_bucket_size - 1 + position_indices
274
- self.register_buffer("position_indices", position_indices.to(q.device), persistent=True)
275
 
276
  q = self.pre_layer_norm(q)
277
  query = self.in_proj_q(q) # shape: [T, B, D]
 
221
  - torch.arange(512, dtype=torch.long).unsqueeze(0)
222
  position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, 512)
223
  position_indices = config.position_bucket_size - 1 + position_indices
224
+ self.register_buffer("position_indices", position_indices, persistent=False)
225
 
226
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
227
  self.scale = 1.0 / math.sqrt(3 * self.head_size)
 
271
  - torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(0)
272
  position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
273
  position_indices = self.config.position_bucket_size - 1 + position_indices
274
+ self.register_buffer("position_indices", position_indices.to(q.device), persistent=False)
275
 
276
  q = self.pre_layer_norm(q)
277
  query = self.in_proj_q(q) # shape: [T, B, D]