Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import torch | |
import torch.nn.functional as F | |
# Attention kernel used by the processors below. Defaults to PyTorch's
# built-in F.scaled_dot_product_attention.
scaled_dot_product_attention = F.scaled_dot_product_attention

# Opt-in: setting the environment variable CA_USE_SAGEATTN=1 (exactly '1')
# swaps in SageAttention's `sageattn` as the attention kernel.
if os.environ.get('CA_USE_SAGEATTN', '0') == '1':
    try:
        from sageattention import sageattn
    except ImportError as err:
        # Chain the original error so the underlying import failure
        # (missing package, broken CUDA build, ...) stays visible.
        raise ImportError(
            'Please install the package "sageattention" to use CA_USE_SAGEATTN=1.'
        ) from err
    # NOTE(review): assumes sageattn accepts positional (q, k, v) with the
    # same tensor layout as F.scaled_dot_product_attention -- confirm against
    # the installed sageattention version.
    scaled_dot_product_attention = sageattn
class CrossAttentionProcessor:
    """Attention processor that delegates to the module-level attention kernel.

    The kernel is whatever ``scaled_dot_product_attention`` is bound to at
    module level when this is called: PyTorch's SDPA by default, or
    SageAttention's ``sageattn`` when CA_USE_SAGEATTN=1.
    """

    def __call__(self, attn, q, k, v):
        """Run attention on the given query/key/value tensors.

        Args:
            attn: The calling attention module. It is not used here;
                presumably it is accepted only to match a shared
                processor call signature.
            q: Query tensor.
            k: Key tensor.
            v: Value tensor.

        Returns:
            The kernel's output, returned unchanged.
        """
        # Look up the global on each call, so the env-selected backend is honoured.
        return scaled_dot_product_attention(q, k, v)