bryandts commited on
Commit
e8c19db
·
verified ·
1 Parent(s): d9b1896

Create discriminatorModel.py

Browse files
Files changed (1) hide show
  1. discriminatorModel.py +37 -0
discriminatorModel.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ # The Discriminator model
6
+ class Discriminator(nn.Module):
7
+ def __init__(self, channels, embed_dim=1024, embed_out_dim=128):
8
+ super(Discriminator, self).__init__()
9
+ self.channels = channels
10
+ self.embed_dim = embed_dim
11
+ self.embed_out_dim = embed_out_dim
12
+
13
+ # Discriminator architecture
14
+ self.model = nn.Sequential(
15
+ *self._create_layer(self.channels, 32, 4, 2, 1, normalize=False),
16
+ *self._create_layer(32, 64, 4, 2, 1),
17
+ *self._create_layer(64, 128, 4, 2, 1),
18
+ *self._create_layer(128, 256, 4, 2, 1),
19
+ *self._create_layer(256, 512, 4, 2, 1)
20
+ )
21
+ self.text_embedding = Embedding(self.embed_dim, self.embed_out_dim) # Text embedding module
22
+ self.output = nn.Sequential(
23
+ nn.Conv2d(512 + self.embed_out_dim, 1, 4, 1, 0, bias=False), nn.Sigmoid()
24
+ )
25
+
26
+ def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, normalize=True):
27
+ layers = [nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=stride, padding=padding)]
28
+ if normalize:
29
+ layers.append(nn.BatchNorm2d(size_out))
30
+ layers.append(nn.LeakyReLU(0.2, inplace=True))
31
+ return layers
32
+
33
+ def forward(self, x, text):
34
+ x_out = self.model(x) # Extract features from the input using the discriminator architecture
35
+ out = self.text_embedding(x_out, text) # Apply text embedding and concatenate with the input features
36
+ out = self.output(out) # Final discriminator output
37
+ return out.squeeze(), x_out