Upload modeling_esm_plusplus.py with huggingface_hub
Browse files- modeling_esm_plusplus.py +2 -3
modeling_esm_plusplus.py
CHANGED
@@ -316,12 +316,11 @@ class MultiHeadAttention(nn.Module):
|
|
316 |
query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
|
317 |
|
318 |
if output_attentions: # Manual attention computation
|
319 |
-
b,
|
320 |
scale = 1 / math.sqrt(d)
|
321 |
-
attn_bias = torch.zeros(b,
|
322 |
if attention_mask is not None:
|
323 |
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
|
324 |
-
|
325 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
|
326 |
attn_weights += attn_bias
|
327 |
attn_weights = F.softmax(attn_weights, dim=-1)
|
|
|
316 |
query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
|
317 |
|
318 |
if output_attentions: # Manual attention computation
|
319 |
+
b, h, l, d = query_BHLD.shape
|
320 |
scale = 1 / math.sqrt(d)
|
321 |
+
attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
|
322 |
if attention_mask is not None:
|
323 |
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
|
|
|
324 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
|
325 |
attn_weights += attn_bias
|
326 |
attn_weights = F.softmax(attn_weights, dim=-1)
|