ohayonguy
commited on
Commit
·
c50d79a
1
Parent(s):
eafdcb2
fixing kdiff
Browse files
arch/hourglass/axial_rope.py
CHANGED
@@ -21,7 +21,6 @@ def rotate_half(x):
|
|
21 |
return x.view(*shape, d * r)
|
22 |
|
23 |
|
24 |
-
@flags.compile_wrap
|
25 |
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
|
26 |
freqs = freqs.to(t)
|
27 |
rot_dim = freqs.shape[-1]
|
|
|
21 |
return x.view(*shape, d * r)
|
22 |
|
23 |
|
|
|
24 |
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
|
25 |
freqs = freqs.to(t)
|
26 |
rot_dim = freqs.shape[-1]
|
arch/hourglass/image_transformer_v2.py
CHANGED
@@ -87,7 +87,6 @@ def filter_params(function, module):
|
|
87 |
|
88 |
# Kernels
|
89 |
|
90 |
-
@flags.compile_wrap
|
91 |
def linear_geglu(x, weight, bias=None):
|
92 |
x = x @ weight.mT
|
93 |
if bias is not None:
|
@@ -96,7 +95,6 @@ def linear_geglu(x, weight, bias=None):
|
|
96 |
return x * F.gelu(gate)
|
97 |
|
98 |
|
99 |
-
@flags.compile_wrap
|
100 |
def rms_norm(x, scale, eps):
|
101 |
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
102 |
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
@@ -104,7 +102,6 @@ def rms_norm(x, scale, eps):
|
|
104 |
return x * scale.to(x.dtype)
|
105 |
|
106 |
|
107 |
-
@flags.compile_wrap
|
108 |
def scale_for_cosine_sim(q, k, scale, eps):
|
109 |
dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32))
|
110 |
sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True)
|
@@ -115,7 +112,6 @@ def scale_for_cosine_sim(q, k, scale, eps):
|
|
115 |
return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype)
|
116 |
|
117 |
|
118 |
-
@flags.compile_wrap
|
119 |
def scale_for_cosine_sim_qkv(qkv, scale, eps):
|
120 |
q, k, v = qkv.unbind(2)
|
121 |
q, k = scale_for_cosine_sim(q, k, scale[:, None], eps)
|
@@ -179,7 +175,6 @@ class AdaRMSNorm(nn.Module):
|
|
179 |
|
180 |
# Rotary position embeddings
|
181 |
|
182 |
-
@flags.compile_wrap
|
183 |
def apply_rotary_emb(x, theta, conj=False):
|
184 |
out_dtype = x.dtype
|
185 |
dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
|
@@ -195,7 +190,6 @@ def apply_rotary_emb(x, theta, conj=False):
|
|
195 |
return torch.cat((y1, y2, x3), dim=-1)
|
196 |
|
197 |
|
198 |
-
@flags.compile_wrap
|
199 |
def _apply_rotary_emb_inplace(x, theta, conj):
|
200 |
dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
|
201 |
d = theta.shape[-1]
|
|
|
87 |
|
88 |
# Kernels
|
89 |
|
|
|
90 |
def linear_geglu(x, weight, bias=None):
|
91 |
x = x @ weight.mT
|
92 |
if bias is not None:
|
|
|
95 |
return x * F.gelu(gate)
|
96 |
|
97 |
|
|
|
98 |
def rms_norm(x, scale, eps):
|
99 |
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
100 |
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
|
|
102 |
return x * scale.to(x.dtype)
|
103 |
|
104 |
|
|
|
105 |
def scale_for_cosine_sim(q, k, scale, eps):
|
106 |
dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32))
|
107 |
sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True)
|
|
|
112 |
return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype)
|
113 |
|
114 |
|
|
|
115 |
def scale_for_cosine_sim_qkv(qkv, scale, eps):
|
116 |
q, k, v = qkv.unbind(2)
|
117 |
q, k = scale_for_cosine_sim(q, k, scale[:, None], eps)
|
|
|
175 |
|
176 |
# Rotary position embeddings
|
177 |
|
|
|
178 |
def apply_rotary_emb(x, theta, conj=False):
|
179 |
out_dtype = x.dtype
|
180 |
dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
|
|
|
190 |
return torch.cat((y1, y2, x3), dim=-1)
|
191 |
|
192 |
|
|
|
193 |
def _apply_rotary_emb_inplace(x, theta, conj):
|
194 |
dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
|
195 |
d = theta.shape[-1]
|