Spaces:
Paused
Paused
add cis_2D preprocessing
Browse files
app.py
CHANGED
@@ -53,30 +53,29 @@ class GELU(nn.Module):
|
|
53 |
else:
|
54 |
return F.gelu(self.linear(x))
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
class Rope2D(nn.Module):
|
57 |
def __init__(self, dim, max_position_embeddings=1024, base=10000):
|
58 |
super().__init__()
|
59 |
-
|
60 |
-
self.
|
61 |
-
self.
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
67 |
-
|
68 |
-
def forward(self, x, seq_len=None):
|
69 |
-
if seq_len > self.max_seq_len_cached:
|
70 |
-
self.max_seq_len_cached = seq_len
|
71 |
-
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
|
72 |
-
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
73 |
-
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
74 |
-
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
75 |
-
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
76 |
-
return (
|
77 |
-
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
78 |
-
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
79 |
-
)
|
80 |
|
81 |
class VisionEncoder(nn.Module):
|
82 |
def __init__(self, config):
|
@@ -92,14 +91,13 @@ class VisionEncoder(nn.Module):
|
|
92 |
x = self.embed(pixel_values)
|
93 |
b, c, h, w = x.shape
|
94 |
x = x.flatten(2).transpose(1, 2)
|
95 |
-
|
96 |
for layer in self.layers:
|
97 |
x = layer(x)
|
98 |
x = self.norm(x)
|
99 |
x = self.gelu(x)
|
100 |
return x
|
101 |
|
102 |
-
|
103 |
class PixtralModel(nn.Module):
|
104 |
def __init__(self, params):
|
105 |
super().__init__()
|
|
|
53 |
else:
|
54 |
return F.gelu(self.linear(x))
|
55 |
|
def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute the 2D rotary position-embedding table as unit complex numbers.

    Even-indexed frequency slots rotate with the row (height) position and
    odd-indexed slots rotate with the column (width) position, so every patch
    in an (height x width) grid gets an angle that encodes both axes.

    Args:
        dim: Head dimension; ``dim // 2`` inverse frequencies are generated
            and split between the two axes.
        height: Number of patch rows in the grid.
        width: Number of patch columns in the grid.
        theta: Base of the inverse-frequency schedule. Defaults to 10000.0,
            matching the conventional RoPE base (and Rope2D's ``base``).

    Returns:
        Complex tensor of shape ``(height, width, dim // 2)`` holding
        ``cos(angle) + i*sin(angle)`` for every (row, col, frequency).
    """
    # Standard RoPE inverse-frequency schedule over dim/2 frequencies.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    # Interleave the axes: even frequency slots follow the row index,
    # odd slots follow the column index.
    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    # Unit-magnitude complex numbers e^{i*angle}.
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
class Rope2D(nn.Module):
    """Provider of 2D rotary position-embedding frequencies for vision patches.

    Wraps ``precompute_freqs_cis_2d`` and returns the complex frequency table
    for a given patch-grid size on the caller's device. The table is cached
    per (height, width) so repeated forwards over same-sized grids — the
    common case during training — do not recompute it every step.
    """

    def __init__(self, dim, max_position_embeddings=1024, base=10000):
        super().__init__()
        self.dim = dim
        # NOTE(review): not consulted by forward(); presumably an upper bound
        # enforced elsewhere — confirm against callers.
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Cache of the last computed table, keyed by grid size. Plain
        # attributes (not buffers) so the cache never lands in state_dict.
        self._cached_hw = None
        self._cached_freqs = None

    def forward(self, x, height, width):
        """Return the freqs_cis table of shape (height, width, dim // 2).

        Args:
            x: Activation tensor; only its device is used.
            height: Patch-grid height.
            width: Patch-grid width.
        """
        # Recompute only when the grid size changes; otherwise reuse the
        # cached table (the original recomputed on every call).
        if self._cached_hw != (height, width):
            self._cached_freqs = precompute_freqs_cis_2d(self.dim, height, width, self.base)
            self._cached_hw = (height, width)
        return self._cached_freqs.to(x.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
class VisionEncoder(nn.Module):
|
81 |
def __init__(self, config):
|
|
|
91 |
x = self.embed(pixel_values)
|
92 |
b, c, h, w = x.shape
|
93 |
x = x.flatten(2).transpose(1, 2)
|
94 |
+
freqs_cis = self.rope(x, h, w)
|
95 |
for layer in self.layers:
|
96 |
x = layer(x)
|
97 |
x = self.norm(x)
|
98 |
x = self.gelu(x)
|
99 |
return x
|
100 |
|
|
|
101 |
class PixtralModel(nn.Module):
|
102 |
def __init__(self, params):
|
103 |
super().__init__()
|