|
|
|
|
|
from inspect import getattr_static |
|
|
|
from ..bytecode_transformation import create_call_function |
|
from ..exc import Unsupported |
|
from .base import VariableTracker |
|
|
|
|
|
class SDPAParamsVariable(VariableTracker):
    """Tracker for the C++ scaled-dot-product-attention params struct.

    An instance keeps the FX proxy that stands for the struct together with
    the variable trackers of the arguments the struct is built from. The
    methods below only read from the struct; none of them mutates it.
    """

    @staticmethod
    def create(tx, value, source):
        """Wrap an existing ``SDPAParams`` object by tracing its construction.

        Every field of ``value`` is wrapped by a ``VariableBuilder`` whose
        source is the same-named attribute of ``source``. The wrapped fields
        are then handed, in the order listed below, to a traced call of
        ``SDPAParams``.

        Args:
            tx: Translator object forwarded to each builder and to the call.
            value: Object exposing ``query``, ``key``, ``value``,
                ``attn_mask``, ``dropout`` and ``is_causal`` attributes.
            source: Source of ``value``; per-field sources are derived from it.

        Returns:
            The result of
            ``TorchInGraphFunctionVariable(SDPAParams).call_function``.
        """
        # NOTE(review): these imports are function-local in the original as
        # well - presumably to sidestep import cycles; confirm before hoisting.
        from torch.backends.cuda import SDPAParams

        from ..source import AttrSource
        from .builder import VariableBuilder
        from .torch import TorchInGraphFunctionVariable

        # Positional order of the arguments passed to SDPAParams.
        field_names = (
            "query",
            "key",
            "value",
            "attn_mask",
            "dropout",
            "is_causal",
        )
        wrapped_fields = [
            VariableBuilder(tx, AttrSource(source, field))(getattr(value, field))
            for field in field_names
        ]

        sdpa_ctor = TorchInGraphFunctionVariable(SDPAParams)
        return sdpa_ctor.call_function(tx, wrapped_fields, {})

    def __init__(self, proxy, param_vars, **kwargs):
        """Store the proxy and constructor arguments, then defer to the base.

        Args:
            proxy: Proxy returned later by ``as_proxy``.
            param_vars: Variable trackers of the constructor arguments; used
                by ``reconstruct``.
            **kwargs: Forwarded unchanged to ``VariableTracker.__init__``.
        """
        self.proxy = proxy
        self.param_vars = param_vars
        super().__init__(**kwargs)

    def reconstruct(self, codegen):
        """Emit code that rebuilds the struct from ``param_vars``.

        Loads ``_SDPAParams`` from ``torch._C``, emits each entry of
        ``param_vars`` through ``codegen.foreach``, and appends a call taking
        that many arguments.

        Args:
            codegen: Code generator that receives the emitted instructions.
        """
        # This path requires a source-less tracker with known arguments.
        assert self.source is None
        assert self.param_vars is not None

        codegen.load_import_from("torch._C", "_SDPAParams")
        codegen.foreach(self.param_vars)

        arg_count = len(self.param_vars)
        call_instructions = create_call_function(arg_count, True)
        codegen.extend_output(call_instructions)

    def as_proxy(self):
        """Return the proxy given at construction time."""
        return self.proxy

    def var_getattr(self, tx, name: str) -> VariableTracker:
        """Trace an attribute read on the params struct.

        Args:
            tx: Translator object forwarded to ``wrap_fx_proxy``.
            name: Attribute name being read.

        Returns:
            The tracker produced by ``wrap_fx_proxy`` for the getattr proxy.
            It carries an ``AttrSource`` when this tracker has a source.

        Raises:
            Unsupported: If the static lookup of ``name`` on
                ``torch._C._SDPAParams`` raises ``AttributeError``.
        """
        import torch._C

        from ..source import AttrSource
        from .builder import wrap_fx_proxy
        from .misc import GetAttrVariable

        try:
            # Static lookup only: used as an existence check, result unused.
            getattr_static(torch._C._SDPAParams, name)
        except AttributeError:
            # `from None` keeps the AttributeError out of the reported chain.
            raise Unsupported(
                f"Unsupported torch._C._SDPAParams attribute {name}"
            ) from None

        attr_proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)

        if self.source is None:
            return wrap_fx_proxy(tx=tx, proxy=attr_proxy)

        attr_source = AttrSource(self.source, name)
        return wrap_fx_proxy(tx=tx, proxy=attr_proxy, source=attr_source)

    @staticmethod
    def is_sdpa_params(value):
        """Return whether ``value`` is the ``SDPAParams`` object itself.

        This is an identity comparison against the name exported by
        ``torch.backends.cuda``, not an ``isinstance`` check.
        """
        from torch.backends.cuda import SDPAParams as sdpa_params_obj

        return value is sdpa_params_obj
|
|