lhallee committed · Commit 2e03489 · verified · 1 Parent(s): e93e025

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1)
  1. modeling_esm_plusplus.py +5 -8
modeling_esm_plusplus.py CHANGED
@@ -316,15 +316,12 @@ class MultiHeadAttention(nn.Module):
         query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
 
         if output_attentions: # Manual attention computation
-            L, S = query_BLD.size(-2), key_BLD.size(-2)
-            scale = 1 / math.sqrt(query_BLD.size(-1))
-            attn_bias = torch.zeros(L, S, dtype=query_BLD.dtype, device=query_BLD.device)
+            b, L, d = x.shape
+            scale = 1 / math.sqrt(d)
+            attn_bias = torch.zeros(b, 1, L, L, dtype=query_BLD.dtype, device=query_BLD.device)
             if attention_mask is not None:
-                if attention_mask.dtype == torch.bool:
-                    attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
-                else:
-                    attn_bias += attention_mask
-
+                attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
+
             attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
             attn_weights += attn_bias
             attn_weights = F.softmax(attn_weights, dim=-1)
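
For context, below is a minimal, self-contained sketch of what the manual-attention path computes after this commit: a (b, 1, L, L) bias built from a boolean attention mask, broadcast across heads, and added to the scaled scores before the softmax. The reshaper helper, head count, and mask shape here are illustrative assumptions, not taken from the file itself; only the bias construction and masking lines mirror the diff.

import math
import torch
import torch.nn.functional as F

def manual_attention(query_BLD, key_BLD, value_BLD, attention_mask=None, n_heads=4):
    b, L, d = query_BLD.shape
    # Hypothetical stand-in for self.reshaper: (B, L, D) -> (B, H, L, D/H)
    reshaper = lambda t: t.view(b, L, n_heads, d // n_heads).transpose(1, 2)
    query_BHLD, key_BHLD, value_BHLD = map(reshaper, (query_BLD, key_BLD, value_BLD))

    # As in the commit: scale by 1/sqrt(model dim) taken from the input shape
    scale = 1 / math.sqrt(d)

    # One (L, L) bias per batch element; dim 1 broadcasts over the H heads
    attn_bias = torch.zeros(b, 1, L, L, dtype=query_BLD.dtype, device=query_BLD.device)
    if attention_mask is not None:
        # Boolean mask, True = attend; masked positions get -inf before softmax
        attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))

    attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
    attn_weights += attn_bias
    attn_weights = F.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value_BHLD), attn_weights

# Example: batch of 2, length 5, model dim 8; second sequence padded after position 3
q = k = v = torch.randn(2, 5, 8)
pad_mask = torch.tensor([[1, 1, 1, 1, 1],
                         [1, 1, 1, 0, 0]], dtype=torch.bool)
out, weights = manual_attention(q, k, v, attention_mask=pad_mask[:, None, None, :])
print(weights.shape)  # torch.Size([2, 4, 5, 5])

The shape change appears to be the substance of the commit: the old (L, S) bias had no batch dimension, so an in-place masked_fill_ with a per-sequence padding mask could not broadcast, whereas the new (b, 1, L, L) bias keeps a separate mask per sequence while still broadcasting over heads.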