Simplified the model by always computing batch-first
Browse files- modeling_norbert.py +2 -2
modeling_norbert.py
CHANGED
@@ -156,8 +156,8 @@ class Attention(nn.Module):
|
|
156 |
value = self.in_proj_v(hidden_states) # shape: [B, T, D]
|
157 |
|
158 |
# Reshape to [B, num_heads, T, head_size]
|
159 |
-
query = query.  [NOTE(review): deleted line truncated in this capture — the pre-change right-hand side is not recoverable from this diff view; consult the repository history for the full removed line]
|
160 |
-
key = key.  [NOTE(review): deleted line truncated in this capture — the pre-change right-hand side is not recoverable from this diff view; consult the repository history for the full removed line]
|
161 |
value = value.view(batch_size, key_len, self.num_heads, self.head_size).transpose(1, 2) # shape: [B, num_heads, T_k, head_size]
|
162 |
|
163 |
# Compute relative positional contributions
|
|
|
156 |
value = self.in_proj_v(hidden_states) # shape: [B, T, D]
|
157 |
|
158 |
# Reshape to [B, num_heads, T, head_size]
|
159 |
+
query = query.view(batch_size, query_len, self.num_heads, self.head_size).transpose(1, 2) # shape: [B, num_heads, T_q, head_size]
|
160 |
+
key = key.view(batch_size, key_len, self.num_heads, self.head_size).permute(0, 2, 3, 1) # shape: [B, num_heads, head_size, T_k]
|
161 |
value = value.view(batch_size, key_len, self.num_heads, self.head_size).transpose(1, 2) # shape: [B, num_heads, T_k, head_size]
|
162 |
|
163 |
# Compute relative positional contributions
|