Alex Birch
committed on
prefer NamedTuple
Browse files- attention.py +1 -1
attention.py
CHANGED
@@ -121,7 +121,7 @@ def scaled_multihead_dot_product_attention(
|
|
121 |
out = attn_weight.matmul(v)
|
122 |
out = rearrange(out, 'b h s d -> b s (h d)')
|
123 |
if needs_weights:
|
124 |
-
return (out, attn_weight)
|
125 |
return AttnFnOutput(out, None)
|
126 |
|
127 |
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|
|
|
121 |
out = attn_weight.matmul(v)
|
122 |
out = rearrange(out, 'b h s d -> b s (h d)')
|
123 |
if needs_weights:
|
124 |
+
return AttnFnOutput(out, attn_weight)
|
125 |
return AttnFnOutput(out, None)
|
126 |
|
127 |
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
|