Vishu26 committed on
Commit
2bebf28
1 Parent(s): 44c57c9
Files changed (1)
  1. app.py +27 -0
app.py CHANGED
@@ -5,6 +5,33 @@ import torch.nn as nn
 from einops import rearrange
 
 
+class Attn(nn.Module):
+    def __init__(self, dim, dim_text, heads = 16, dim_head = 64):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_q = nn.Conv2d(dim, hidden_dim, 4, bias = False, stride=4)
+        self.to_kv = nn.Linear(dim_text, hidden_dim * 2, bias=False)
+        #self.norm = nn.LayerNorm(dim)
+        self.to_out = nn.Linear(hidden_dim, dim)
+
+    def forward(self, x, text):
+        b, c, h, w = x.shape
+        kv = self.to_kv(text).chunk(2, dim = -1)
+        k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), kv)
+        q = self.to_q(x)
+        q = rearrange(q, 'b (h c) x y -> b h (x y) c', h=self.heads)
+
+        #attn = torch.einsum('bhnd,bhed->bhne',q,k) * self.scale
+        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
+        attn = attn.softmax(dim=-1)
+        #print(attn.shape)
+        out = torch.matmul(attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        #print(out.shape)
+        return self.to_out(out)
+
 class RangeModel(nn.Module):
     def __init__(self):
         super(RangeModel, self).__init__()
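
For reference, a minimal smoke test of the cross-attention block added in this commit. It is a sketch only: it assumes the Attn class above is importable from app.py, and all sizes (dim=256 image channels, dim_text=512 text width, a 32x32 feature map, 10 text tokens) are made up for illustration. The only hard requirement is that the spatial size be divisible by 4, since queries come from a stride-4 convolution.

import torch

from app import Attn  # the module added in this commit (assumed importable)

# Hypothetical sizes, chosen only for this example.
attn = Attn(dim=256, dim_text=512, heads=16, dim_head=64)

x = torch.randn(2, 256, 32, 32)   # image features: (batch, dim, H, W), H and W divisible by 4
text = torch.randn(2, 10, 512)    # text embeddings: (batch, tokens, dim_text)

out = attn(x, text)
print(out.shape)  # torch.Size([2, 64, 256]): one query token per 4x4 patch, projected back to dim

Because to_q downsamples the feature map by 4x before attending, each query token summarizes a 4x4 patch, which keeps the attention matrix small while the text sequence is attended to at full length.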