Upload modeling_esm_plusplus.py with huggingface_hub
modeling_esm_plusplus.py: +5 -8
@@ -316,15 +316,12 @@ class MultiHeadAttention(nn.Module):
         query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
 
         if output_attentions:  # Manual attention computation
-            L, S = query_BLD.size(-2), key_BLD.size(-2)
-            scale = 1 / math.sqrt(query_BLD.size(-1))
-            attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
+            b, L, d = x.shape
+            scale = 1 / math.sqrt(d)
+            attn_bias = torch.zeros(b, 1, L, L, dtype=query_BLD.dtype, device=query_BLD.device)
             if attention_mask is not None:
-                if attention_mask.dtype == torch.bool:
-                    attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
-                else:
-                    attn_bias += attention_mask
-
+                attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
+
             attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
             attn_weights += attn_bias
             attn_weights = F.softmax(attn_weights, dim=-1)
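For context, a minimal, self-contained sketch of what the new code path computes. The shapes and the boolean padding mask below are illustrative assumptions (the model's actual attention_mask may arrive in a different shape); the point of the change is that the additive bias now carries a batch dimension and broadcasts over heads, so each sequence in the batch gets its own padding mask.

import math
import torch
import torch.nn.functional as F

# Illustrative sizes, not taken from the model config.
b, h, L, d_head = 2, 4, 6, 8
d = h * d_head  # mirrors d taken from x.shape in the diff

query_BHLD = torch.randn(b, h, L, d_head)
key_BHLD = torch.randn(b, h, L, d_head)
value_BHLD = torch.randn(b, h, L, d_head)

# Assumed boolean padding mask, True = attend, broadcastable to (b, 1, L, L).
attention_mask = torch.ones(b, 1, 1, L, dtype=torch.bool)
attention_mask[1, ..., 4:] = False  # pretend sequence 1 is padded after position 4

# One additive bias per batch element, shared across heads via broadcasting.
scale = 1 / math.sqrt(d)
attn_bias = torch.zeros(b, 1, L, L, dtype=query_BHLD.dtype, device=query_BHLD.device)
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))

attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale  # (b, h, L, L)
attn_weights = F.softmax(attn_weights + attn_bias, dim=-1)
context_BHLD = torch.matmul(attn_weights, value_BHLD)  # padded key positions get zero weight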