import os
import torch
import torch.nn.functional as F

# Default attention backend: PyTorch's fused scaled dot-product attention.
scaled_dot_product_attention = F.scaled_dot_product_attention

# Optionally swap in SageAttention when the CA_USE_SAGEATTN environment
# variable is set to '1'.
if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
    try:
        from sageattention import sageattn
    except ImportError:
        raise ImportError('Please install the package "sageattention" to use CA_USE_SAGEATTN.')
    scaled_dot_product_attention = sageattn


class CrossAttentionProcessor:
    """Cross-attention processor that runs q/k/v through the selected attention backend."""

    def __call__(self, attn, q, k, v):
        out = scaled_dot_product_attention(q, k, v)
        return out
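

# A minimal usage sketch (assumptions: the `attn` module argument is unused by
# this processor, so None can be passed; tensor shapes are illustrative only):
#
#   processor = CrossAttentionProcessor()
#   q = torch.randn(1, 8, 64, 32)   # (batch, heads, query_len, head_dim)
#   k = torch.randn(1, 8, 77, 32)   # (batch, heads, key_len, head_dim)
#   v = torch.randn(1, 8, 77, 32)
#   out = processor(None, q, k, v)  # -> same shape as q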