Yanisadel committed
Commit 112bf64 · verified · 1 Parent(s): 6d6a20f

Update chatNT.py

Files changed (1)
  1. chatNT.py +4 -2
chatNT.py CHANGED
@@ -925,7 +925,11 @@ class TorchGptGroupedQueryAttention(nn.Module):
         )
 
         attention_weights = nn.functional.softmax(attention_logits, dim=-1)
+        attention_weights = attention_weights.to(values.dtype)
 
+        print(f"Attention weights type : ", attention_weights.dtype)
+        print(f"Values type : ", values.dtype)
+
         values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
         values = values.contiguous().view(batch_size, seq_len, -1)
 
@@ -1334,8 +1338,6 @@ class MultiHeadAttention(nn.Module):
         else:
             attention_weights = F.softmax(attention_weights, dim=-1)
 
-        print(f"Attention weights : {attention_weights.dtype}")
-        print(f"Value heads : {value_heads.dtype}")
         value_out = torch.einsum(
             "...htT, ...Thd->...thd", attention_weights, value_heads
         )
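
The substantive change is the cast of the softmax output back to values.dtype before the einsum contraction in TorchGptGroupedQueryAttention (the two added print calls, and the two removed from MultiHeadAttention, look like temporary debugging output). A minimal standalone sketch of why the cast matters, assuming a mixed-precision setup where the values tensor is bfloat16 while the attention logits are float32; the shapes and variable names below are illustrative, not taken from chatNT.py:

import torch

# Illustrative shapes only (not from chatNT.py).
batch, heads, seq, head_dim = 2, 4, 8, 16

# Under mixed precision, logits may end up float32 while values are bfloat16.
attention_logits = torch.randn(batch, heads, seq, seq, dtype=torch.float32)
values = torch.randn(batch, seq, heads, head_dim, dtype=torch.bfloat16)

# softmax preserves the input dtype, so the weights come out float32 here.
attention_weights = torch.nn.functional.softmax(attention_logits, dim=-1)

# Without this cast, torch.einsum on mixed float32/bfloat16 inputs raises a
# RuntimeError on typical PyTorch builds rather than promoting dtypes.
attention_weights = attention_weights.to(values.dtype)

values_out = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
print(values_out.dtype)  # torch.bfloat16

Casting the weights down to values.dtype, rather than upcasting values to float32, keeps the contraction in the lower-precision type, which is the usual choice in mixed-precision attention.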