twodgirl committed
Commit d213e90
1 Parent(s): 3890141

Create model.

Files changed (1): tea_model.py +63 -0
tea_model.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+ from torch import nn
+
+ # VAE decoder, but with half-size output.
+ # The last upsample stage is omitted, so images are decoded at half resolution.
+
+ ###
+ # Code from madebyollin/taesd
+
+ class Recon(nn.Module):
+     def __init__(self, ch_in, ch_out):
+         super().__init__()
+         self.long = nn.Sequential(
+             nn.Conv2d(ch_in, ch_out, 3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(ch_out, ch_out, 3, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(ch_out, ch_out, 3, padding=1)
+         )
+         if ch_in != ch_out:
+             self.short = nn.Conv2d(ch_in, ch_out, 1, bias=False)
+         else:
+             # Channels already match, so the shortcut needs no projection.
+             self.short = nn.Identity()
+         self.fuse = nn.ReLU()
+
+     def forward(self, x):
+         return self.fuse(self.long(x) + self.short(x))
+
+ class TeaEncoder(nn.Module):
+     def __init__(self, ch_in):
+         super().__init__()
+         self.block_in = nn.Sequential(
+             nn.Conv2d(ch_in, 64, 3, padding=1),
+             nn.ReLU()
+         )
+         self.middle = nn.Sequential(
+             *[Recon(64, 64) for _ in range(3)],
+             # Nearest-neighbor upsampling, the opposite of a stride-2 downsample.
+             nn.Upsample(scale_factor=2),
+             # bias=False leads to a simpler model with fewer parameters.
+             # The 64 output channels of the previous layers match this conv's input,
+             # and the input is already well represented by those feature maps,
+             # so the bias may not add significant value.
+             nn.Conv2d(64, 64, 3, padding=1, bias=False),
+             *[Recon(64, 64) for _ in range(3)],
+             # Final upscale, reaching 1/2 the size of the original image.
+             nn.Upsample(scale_factor=2),
+             nn.Conv2d(64, 64, 3, padding=1, bias=False),
+         )
+         self.block_out = nn.Sequential(
+             Recon(64, 64),
+             # Convert to RGB, regardless of the latent channels.
+             nn.Conv2d(64, 3, 3, padding=1),
+         )
+
+
+     def forward(self, x):
+         # Clamp the input values to the range (-3, 3).
+         clamped = torch.tanh(x / 3) * 3
+         cooked = self.middle(self.block_in(clamped))
+
+         return self.block_out(cooked)
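
For reference, a minimal smoke test, not part of this commit, sketching how the module might be used. The 4 latent channels and the 32x32 latent size are assumptions matching Stable Diffusion's 8x-downsampled latents, where a full decoder would return a 256x256 image:

import torch
from tea_model import TeaEncoder

model = TeaEncoder(ch_in=4).eval()
latent = torch.randn(1, 4, 32, 32)
with torch.no_grad():
    image = model(latent)
# Two 2x upsamples give 4x the latent size: (1, 3, 128, 128),
# half of the 256x256 a full 8x decoder would produce.
print(image.shape)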