davda54 committed on
Commit
20d2498
·
verified ·
1 Parent(s): f1bcb7e

Simplified the model by always computing batch-first

Browse files
Files changed (1) hide show
  1. modeling_norbert.py +2 -2
modeling_norbert.py CHANGED
@@ -156,8 +156,8 @@ class Attention(nn.Module):
156
  value = self.in_proj_v(hidden_states) # shape: [B, T, D]
157
 
158
  # Reshape to [B, num_heads, T, head_size]
159
- query = query.reshape(batch_size, query_len, self.num_heads, self.head_size).transpose(1, 2) # shape: [B, num_heads, T_q, head_size]
160
- key = key.reshape(batch_size, key_len, self.num_heads, self.head_size).permute(0, 2, 3, 1) # shape: [B, num_heads, head_size, T_k]
161
  value = value.view(batch_size, key_len, self.num_heads, self.head_size).transpose(1, 2) # shape: [B, num_heads, T_k, head_size]
162
 
163
  # Compute relative positional contributions
 
156
  value = self.in_proj_v(hidden_states) # shape: [B, T, D]
157
 
158
  # Reshape to [B, num_heads, T, head_size]
159
+ query = query.view(batch_size, query_len, self.num_heads, self.head_size).transpose(1, 2) # shape: [B, num_heads, T_q, head_size]
160
+ key = key.view(batch_size, key_len, self.num_heads, self.head_size).permute(0, 2, 3, 1) # shape: [B, num_heads, head_size, T_k]
161
  value = value.view(batch_size, key_len, self.num_heads, self.head_size).transpose(1, 2) # shape: [B, num_heads, T_k, head_size]
162
 
163
  # Compute relative positional contributions