Pavithiran committed
Commit 7698a6d · verified · Parent: c98cb1d

Create sagan_model.py

Files changed (1)
  1. sagan_model.py +101 -0
sagan_model.py ADDED
@@ -0,0 +1,101 @@
+ # sagan_model.py
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.utils import spectral_norm
+
+ # -------------------------
+ # Self-Attention Module
+ # -------------------------
+ class Self_Attn(nn.Module):
+     def __init__(self, in_dim):
+         super().__init__()
+         self.query_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
+         self.key_conv = nn.Conv2d(in_dim, in_dim // 8, 1)
+         self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
+         self.gamma = nn.Parameter(torch.zeros(1))  # starts at 0: attention blends in gradually
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, x):
+         B, C, H, W = x.size()
+         proj_q = self.query_conv(x).view(B, -1, H*W).permute(0, 2, 1)
+         proj_k = self.key_conv(x).view(B, -1, H*W)
+         energy = torch.bmm(proj_q, proj_k)  # B × (HW) × (HW)
+         attention = self.softmax(energy)
+         proj_v = self.value_conv(x).view(B, -1, H*W)
+
+         out = torch.bmm(proj_v, attention.permute(0, 2, 1))
+         out = out.view(B, C, H, W)
+         return self.gamma * out + x
+
+ # -------------------------
+ # Generator & Discriminator
+ # -------------------------
+ class Generator(nn.Module):
+     def __init__(self, z_dim=128, img_channels=3, base_channels=64):
+         super().__init__()
+         self.net = nn.Sequential(
+             spectral_norm(nn.ConvTranspose2d(z_dim, base_channels*8, 4, 1, 0)),             # 1×1 -> 4×4
+             nn.BatchNorm2d(base_channels*8),
+             nn.ReLU(True),
+
+             spectral_norm(nn.ConvTranspose2d(base_channels*8, base_channels*4, 4, 2, 1)),   # 4×4 -> 8×8
+             nn.BatchNorm2d(base_channels*4),
+             nn.ReLU(True),
+
+             # self-attention on the 8×8 feature map
+             Self_Attn(base_channels*4),
+
+             spectral_norm(nn.ConvTranspose2d(base_channels*4, base_channels*2, 4, 2, 1)),   # 8×8 -> 16×16
+             nn.BatchNorm2d(base_channels*2),
+             nn.ReLU(True),
+
+             spectral_norm(nn.ConvTranspose2d(base_channels*2, base_channels, 4, 2, 1)),     # 16×16 -> 32×32
+             nn.BatchNorm2d(base_channels),
+             nn.ReLU(True),
+
+             spectral_norm(nn.ConvTranspose2d(base_channels, img_channels, 4, 2, 1)),        # 32×32 -> 64×64
+             nn.Tanh()
+         )
+
+     def forward(self, z):
+         # Expect z shape: (B, z_dim, 1, 1)
+         return self.net(z)
+
+ class Discriminator(nn.Module):
+     def __init__(self, img_channels=3, base_channels=64):
+         super().__init__()
+         self.net = nn.Sequential(
+             spectral_norm(nn.Conv2d(img_channels, base_channels, 4, 2, 1)),       # 64×64 -> 32×32
+             nn.LeakyReLU(0.1, True),
+
+             spectral_norm(nn.Conv2d(base_channels, base_channels*2, 4, 2, 1)),    # 32×32 -> 16×16
+             nn.LeakyReLU(0.1, True),
+
+             # self-attention on the 16×16 feature map
+             Self_Attn(base_channels*2),
+
+             spectral_norm(nn.Conv2d(base_channels*2, base_channels*4, 4, 2, 1)),  # 16×16 -> 8×8
+             nn.LeakyReLU(0.1, True),
+
+             spectral_norm(nn.Conv2d(base_channels*4, base_channels*8, 4, 2, 1)),  # 8×8 -> 4×4
+             nn.LeakyReLU(0.1, True),
+
+             spectral_norm(nn.Conv2d(base_channels*8, 1, 4, 1, 0))                 # 4×4 -> 1×1 logit
+         )
+
+     def forward(self, x):
+         return self.net(x).view(-1)  # one logit per image, shape (B,)
+
+ # -------------------------
+ # High-Level Wrapper
+ # -------------------------
+ class SAGANModel(nn.Module):
+     def __init__(self, z_dim=128, img_channels=3, base_channels=64):
+         super().__init__()
+         self.gen = Generator(z_dim, img_channels, base_channels)
+         self.dis = Discriminator(img_channels, base_channels)
+
+     def forward(self, z):
+         # Only the generator's forward is typically used during inference
+         return self.gen(z)