data
Browse files
app.py
CHANGED
@@ -5,6 +5,33 @@ import torch.nn as nn
|
|
5 |
from einops import rearrange
|
6 |
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
class RangeModel(nn.Module):
|
9 |
def __init__(self):
|
10 |
super(RangeModel, self).__init__()
|
|
|
5 |
from einops import rearrange
|
6 |
|
7 |
|
8 |
+
class Attn(nn.Module):
    """Cross-attention from an image feature map to text token embeddings.

    Queries come from the image: a 4x4 conv with stride 4 turns each 4x4
    patch into one query token. Keys and values come from a linear
    projection of the text embeddings. Standard scaled dot-product
    multi-head attention is applied, and the result is projected back to
    ``dim`` channels per patch.

    Args:
        dim:      channel count of the input feature map (and output size).
        dim_text: embedding size of each text token.
        heads:    number of attention heads.
        dim_head: per-head dimension.
    """

    def __init__(self, dim, dim_text, heads = 16, dim_head = 64):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # 4x4 kernel with stride 4: one query token per 4x4 image patch.
        self.to_q = nn.Conv2d(dim, hidden_dim, 4, bias = False, stride=4)
        # Single projection producing keys and values, split in forward().
        self.to_kv = nn.Linear(dim_text, hidden_dim * 2, bias=False)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x, text):
        """Attend from image patches (queries) to text tokens (keys/values).

        Args:
            x:    (b, dim, h, w) feature map; h and w are assumed divisible
                  by 4 (stride-4 conv) — TODO confirm with callers.
            text: (b, n, dim_text) text token embeddings.

        Returns:
            Tensor of shape (b, (h//4)*(w//4), dim): attended patch features.
        """
        b = x.shape[0]

        # Project text once, split into keys/values, then split heads:
        # (b, n, heads*dim_head) -> (b, heads, n, dim_head).
        k, v = self.to_kv(text).chunk(2, dim=-1)
        k = k.unflatten(-1, (self.heads, -1)).transpose(1, 2)
        v = v.unflatten(-1, (self.heads, -1)).transpose(1, 2)

        # Queries from the downsampled image; flatten space and split heads:
        # (b, heads*dim_head, h/4, w/4) -> (b, heads, (h/4)*(w/4), dim_head).
        q = self.to_q(x)
        q = q.view(b, self.heads, -1, q.shape[-2] * q.shape[-1]).transpose(-1, -2)

        # Scaled dot-product attention over the text tokens.
        attn = (torch.matmul(q, k.transpose(-1, -2)) * self.scale).softmax(dim=-1)
        out = torch.matmul(attn, v)

        # Merge heads: (b, heads, n_patches, dim_head) -> (b, n_patches, heads*dim_head).
        out = out.transpose(1, 2).flatten(-2)
        return self.to_out(out)
|
34 |
+
|
35 |
class RangeModel(nn.Module):
|
36 |
def __init__(self):
|
37 |
super(RangeModel, self).__init__()
|