Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +1 -1
modeling_esm_plusplus.py
CHANGED
@@ -321,7 +321,7 @@ class MultiHeadAttention(nn.Module):
|
|
321 |
attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
|
322 |
if attention_mask is not None:
|
323 |
if attention_mask.dtype == torch.bool:
|
324 |
-
|
325 |
else:
|
326 |
attn_bias += attention_mask
|
327 |
|
|
|
321 |
attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
|
322 |
if attention_mask is not None:
|
323 |
if attention_mask.dtype == torch.bool:
|
324 |
+
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
|
325 |
else:
|
326 |
attn_bias += attention_mask
|
327 |
|