lhallee committed
Commit 688ced4 · verified · 1 Parent(s): 2e03489

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1)
  1. modeling_esm_plusplus.py +2 -3
modeling_esm_plusplus.py CHANGED
@@ -316,12 +316,11 @@ class MultiHeadAttention(nn.Module):
         query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
 
         if output_attentions: # Manual attention computation
-            b, L, d = x.shape
+            b, h, l, d = query_BHLD.shape
             scale = 1 / math.sqrt(d)
-            attn_bias = torch.zeros(b, 1, L, L, dtype=query_BLD.dtype, device=query_BLD.device)
+            attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
             if attention_mask is not None:
                 attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
-
             attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
             attn_weights += attn_bias
             attn_weights = F.softmax(attn_weights, dim=-1)
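
For context, the fix replaces a reference to an undefined tensor `x` with the already-reshaped `query_BHLD`, and widens the additive attention bias from a single broadcast head `(b, 1, L, L)` to an explicit per-head `(b, h, l, l)` buffer. The sketch below reproduces the corrected manual-attention path in isolation; the tensor sizes and the boolean `(b, h, l, l)` shape chosen for `attention_mask` are illustrative assumptions, not taken from the model code.

import math
import torch
import torch.nn.functional as F

# Illustrative sizes (assumed, not from the model): batch, heads, length, head dim.
b, h, l, d = 2, 4, 8, 16
query_BHLD = torch.randn(b, h, l, d)
key_BHLD = torch.randn(b, h, l, d)
value_BHLD = torch.randn(b, h, l, d)

# Hypothetical boolean mask, already broadcast to (b, h, l, l); True = keep position.
attention_mask = torch.ones(b, h, l, l, dtype=torch.bool)

scale = 1 / math.sqrt(d)
# Per-head bias matching the fixed (b, h, l, l) shape from the commit.
attn_bias = torch.zeros(b, h, l, l, dtype=query_BHLD.dtype, device=query_BHLD.device)
attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))

attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
attn_weights += attn_bias
attn_weights = F.softmax(attn_weights, dim=-1)

# Weighted sum over values completes the attention step.
context_BHLD = torch.matmul(attn_weights, value_BHLD)
print(context_BHLD.shape)  # torch.Size([2, 4, 8, 16])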