Qwen
/

logicwong commited on
Commit
f157e4e
·
1 Parent(s): da187eb

Update modeling_qwen.py, fix logn bug

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +6 -5
modeling_qwen.py CHANGED
@@ -177,7 +177,8 @@ class QWenAttention(nn.Module):
177
  config.hidden_size, self.projection_size, bias=not config.no_bias
178
  )
179
 
180
- if self.use_flash_attn and flash_attn_unpadded_func is not None:
 
181
  self.core_attention_flash = FlashSelfAttention(
182
  causal=True, attention_dropout=config.attn_pdrop
183
  )
@@ -371,12 +372,12 @@ class QWenAttention(nn.Module):
371
  if self.use_logn_attn and not self.training:
372
  if self.logn_tensor.device != query.device:
373
  self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
374
- seq_start = key.size(0) - query.size(0)
375
- seq_end = key.size(0)
376
  logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
377
  query = query * logn_tensor.expand_as(query)
378
 
379
- if self.use_flash_attn and flash_attn_unpadded_func is not None:
380
  q, k, v = query, key, value
381
  context_layer = self.core_attention_flash(q, k, v)
382
 
@@ -397,7 +398,7 @@ class QWenAttention(nn.Module):
397
  attn_output = self.c_proj(context_layer)
398
  outputs = (attn_output, present)
399
  if output_attentions:
400
- if self.use_flash_attn and flash_attn_unpadded_func is not None:
401
  raise ValueError("Cannot output attentions while using flash-attn")
402
  else:
403
  outputs += (attn_weight,)
 
177
  config.hidden_size, self.projection_size, bias=not config.no_bias
178
  )
179
 
180
+ self.is_fp32 = not(config.bf16 or config.fp16)
181
+ if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
182
  self.core_attention_flash = FlashSelfAttention(
183
  causal=True, attention_dropout=config.attn_pdrop
184
  )
 
372
  if self.use_logn_attn and not self.training:
373
  if self.logn_tensor.device != query.device:
374
  self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
375
+ seq_start = key.size(1) - query.size(1)
376
+ seq_end = key.size(1)
377
  logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
378
  query = query * logn_tensor.expand_as(query)
379
 
380
+ if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
381
  q, k, v = query, key, value
382
  context_layer = self.core_attention_flash(q, k, v)
383
 
 
398
  attn_output = self.c_proj(context_layer)
399
  outputs = (attn_output, present)
400
  if output_attentions:
401
+ if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
402
  raise ValueError("Cannot output attentions while using flash-attn")
403
  else:
404
  outputs += (attn_weight,)