Update chatNT.py
chatNT.py CHANGED
@@ -1330,6 +1330,9 @@ class MultiHeadAttention(nn.Module):
             )
         else:
             attention_weights = F.softmax(attention_weights, dim=-1)
+
+        print(f"Attention weights : {attention_weights.dtype}")
+        print(f"Value heads : {value_heads.dtype}")
         value_out = torch.einsum(
             "...htT, ...Thd->...thd", attention_weights, value_heads
         )
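The commit adds dtype diagnostics just before the attention-value contraction, presumably to chase a mixed-precision mismatch between the softmaxed attention weights and the value heads. Below is a minimal standalone sketch of that step; the tensor shapes and setup code are assumptions for illustration, and only the einsum pattern and the two print lines come from the diff.

import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration; the real module's shapes are not in the diff.
batch, heads, seq, head_dim = 2, 4, 8, 16

# Attention logits over (batch, heads, query_len, key_len), normalized as in the diff.
attention_weights = F.softmax(torch.randn(batch, heads, seq, seq), dim=-1)

# Value heads laid out as (batch, key_len, heads, head_dim) to match "...Thd".
value_heads = torch.randn(batch, seq, heads, head_dim)

# The two diagnostics added by the commit: under mixed precision these dtypes
# can diverge (e.g. float32 vs bfloat16), and torch.einsum then raises.
print(f"Attention weights : {attention_weights.dtype}")
print(f"Value heads : {value_heads.dtype}")

# "...htT, ...Thd->...thd" contracts over the key axis T, yielding
# (batch, query_len, heads, head_dim).
value_out = torch.einsum("...htT, ...Thd->...thd", attention_weights, value_heads)
print(value_out.shape)  # torch.Size([2, 8, 4, 16])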