marriola committed (verified)
Commit: 4e01d79
Parent(s): 9368cea

Upload BD3LM

Files changed (1)
  1. modeling_bd3lm.py +12 -1
modeling_bd3lm.py CHANGED
@@ -16,6 +16,14 @@ try:
   FLEX_ATTN_AVAILABLE = True
 except:
   FLEX_ATTN_AVAILABLE = False
+# Flags required to enable jit fusion kernels
+try:
+  torch._C._jit_set_profiling_mode(False)
+  torch._C._jit_set_profiling_executor(False)
+  torch._C._jit_override_can_fuse_on_cpu(True)
+  torch._C._jit_override_can_fuse_on_gpu(True)
+except:
+  pass
 
 from .configuration_bd3lm import BD3LMConfig
 
@@ -69,6 +77,7 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
 def fused_flex_attention(q, k, v, mask=None):
   return flex_attention(q, k, v, block_mask=mask)
 
+
 def bias_dropout_add_scale(
     x: torch.Tensor,
     bias: typing.Optional[torch.Tensor],
@@ -93,13 +102,13 @@ def get_bias_dropout_add_scale(training):
 
   return _bias_dropout_add
 
-
 # function overload
 def modulate(x: torch.Tensor,
              shift: torch.Tensor,
              scale: torch.Tensor) -> torch.Tensor:
   return x * (1 + scale) + shift
 
+@torch.jit.script
 def bias_dropout_add_scale_fused_train(
     x: torch.Tensor,
     bias: typing.Optional[torch.Tensor],
@@ -109,6 +118,7 @@ def bias_dropout_add_scale_fused_train(
   return bias_dropout_add_scale(
     x, bias, scale, residual, prob, True)
 
+@torch.jit.script
 def bias_dropout_add_scale_fused_inference(
     x: torch.Tensor,
     bias: typing.Optional[torch.Tensor],
@@ -118,6 +128,7 @@ def bias_dropout_add_scale_fused_inference(
   return bias_dropout_add_scale(
     x, bias, scale, residual, prob, False)
 
+@torch.jit.script
 def modulate_fused(x: torch.Tensor,
                    shift: torch.Tensor,
                    scale: torch.Tensor) -> torch.Tensor:
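
The commit does two related things: it sets the internal torch._C flags that disable the profiling executor and re-enable the legacy CPU/GPU fusers, and it decorates the small element-wise helpers (bias_dropout_add_scale_fused_train, bias_dropout_add_scale_fused_inference, modulate_fused) with @torch.jit.script so those fusers can compile them. The sketch below illustrates the pattern in isolation and is not the file's exact code: the body of bias_dropout_add_scale is an assumption (the diff only shows its signature), and the tensor shapes in the usage lines are made up for the demo.

# Minimal, self-contained sketch of the pattern this commit enables.
import typing
import torch
import torch.nn.functional as F

# Flags required to enable jit fusion kernels (mirrors the added try block).
try:
  torch._C._jit_set_profiling_mode(False)
  torch._C._jit_set_profiling_executor(False)
  torch._C._jit_override_can_fuse_on_cpu(True)
  torch._C._jit_override_can_fuse_on_gpu(True)
except Exception:
  pass

def bias_dropout_add_scale(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: typing.Optional[torch.Tensor],
    prob: float,
    training: bool) -> torch.Tensor:
  # Assumed body for illustration: scaled dropout of (x + bias),
  # added back onto the residual stream.
  if bias is not None:
    x = x + bias
  out = scale * F.dropout(x, p=prob, training=training)
  if residual is not None:
    out = residual + out
  return out

@torch.jit.script
def bias_dropout_add_scale_fused_train(
    x: torch.Tensor,
    bias: typing.Optional[torch.Tensor],
    scale: torch.Tensor,
    residual: torch.Tensor,
    prob: float) -> torch.Tensor:
  # Scripting specializes the training=True path so the JIT fuser can
  # combine the element-wise ops into fewer kernels.
  return bias_dropout_add_scale(x, bias, scale, residual, prob, True)

# Demo shapes, chosen arbitrarily.
x = torch.randn(2, 8, 16)
scale = torch.ones(2, 1, 16)
residual = torch.zeros_like(x)
y = bias_dropout_add_scale_fused_train(x, None, scale, residual, 0.1)
print(type(bias_dropout_add_scale_fused_train))  # torch.jit.ScriptFunction

Without the @torch.jit.script decorators the helpers simply run eagerly; combining them with the fusion flags is the usual way to let TorchScript's legacy fuser merge the dropout, scale, and residual-add into a single kernel, which is presumably the motivation for this change.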