Yanisadel commited on
Commit
3bd05b8
·
verified ·
1 Parent(s): b9b892c

Update chatNT.py

Browse files
Files changed (1) hide show
  1. chatNT.py +3 -0
chatNT.py CHANGED
@@ -1330,6 +1330,9 @@ class MultiHeadAttention(nn.Module):
1330
  )
1331
  else:
1332
  attention_weights = F.softmax(attention_weights, dim=-1)
 
 
 
1333
  value_out = torch.einsum(
1334
  "...htT, ...Thd->...thd", attention_weights, value_heads
1335
  )
 
1330
  )
1331
  else:
1332
  attention_weights = F.softmax(attention_weights, dim=-1)
1333
+
1334
+ print(f"Attention weights : {attention_weights.dtype}")
1335
+ print(f"Value heads : {value_heads.dtype}")
1336
  value_out = torch.einsum(
1337
  "...htT, ...Thd->...thd", attention_weights, value_heads
1338
  )