debisoft commited on
Commit
a43d8a8
·
1 Parent(s): defbb61

Added output_attentions: bool=False to GroupedQueryAttention.forward() as a temporary fix for AWQ

Browse files
Files changed (1) hide show
  1. attention.py +1 -1
attention.py CHANGED
@@ -260,7 +260,7 @@ class GroupedQueryAttention(nn.Module):
260
  self.out_proj = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model, **fc_kwargs)
261
  self.out_proj._is_residual = True
262
 
263
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
264
  qkv = self.Wqkv(x)
265
  if self.clip_qkv:
266
  qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
 
260
  self.out_proj = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model, **fc_kwargs)
261
  self.out_proj._is_residual = True
262
 
263
+ def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, output_attentions: bool=False, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
264
  qkv = self.Wqkv(x)
265
  if self.clip_qkv:
266
  qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)