Spaces:
Sleeping
Sleeping
Create discriminatorModel.py
Browse files — discriminatorModel.py (+37 lines, −0 lines)
discriminatorModel.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
# The Discriminator model
class Discriminator(nn.Module):
    """Text-conditional GAN discriminator.

    A DCGAN-style stack of stride-2 convolutions extracts image features,
    a text-embedding module fuses the caption embedding with those features,
    and a final 4x4 convolution + sigmoid produces a real/fake score.

    Args:
        channels: number of channels in the input image.
        embed_dim: dimensionality of the incoming text embedding (default 1024).
        embed_out_dim: dimensionality the text embedding is projected to before
            fusion with the image features (default 128).
    """

    def __init__(self, channels, embed_dim=1024, embed_out_dim=128):
        super().__init__()  # zero-arg super: the Python 3 idiom
        self.channels = channels
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Discriminator architecture: each layer halves spatial resolution
        # (4x4 conv, stride 2) while widening channels 32 -> 512. The first
        # layer skips BatchNorm, as is conventional for the input layer of a
        # DCGAN discriminator.
        self.model = nn.Sequential(
            *self._create_layer(self.channels, 32, 4, 2, 1, normalize=False),
            *self._create_layer(32, 64, 4, 2, 1),
            *self._create_layer(64, 128, 4, 2, 1),
            *self._create_layer(128, 256, 4, 2, 1),
            *self._create_layer(256, 512, 4, 2, 1),
        )
        # NOTE(review): `Embedding` is not defined or imported in this file —
        # presumably a project-local module that projects `text` to
        # embed_out_dim channels and concatenates it with the feature map
        # (the output conv expects 512 + embed_out_dim channels). Confirm
        # its import at file level.
        self.text_embedding = Embedding(self.embed_dim, self.embed_out_dim)  # Text embedding module
        # Collapse fused features to one score per sample.
        self.output = nn.Sequential(
            nn.Conv2d(512 + self.embed_out_dim, 1, 4, 1, 0, bias=False), nn.Sigmoid()
        )

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, normalize=True):
        """Build one conv block as a list: Conv2d, optional BatchNorm2d, LeakyReLU(0.2).

        Returned as a list (not a Sequential) so callers can splat the layers
        into a single flat nn.Sequential.
        """
        layers = [nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=stride, padding=padding)]
        if normalize:
            layers.append(nn.BatchNorm2d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, x, text):
        """Score an image batch `x` conditioned on text embedding `text`.

        Returns:
            (score, features): `score` is the squeezed sigmoid output of the
            final conv; `features` is the 512-channel map from the conv trunk
            (useful for feature-matching losses).
        """
        x_out = self.model(x)  # Extract features from the input using the discriminator architecture
        out = self.text_embedding(x_out, text)  # Apply text embedding and concatenate with the input features
        out = self.output(out)  # Final discriminator output
        return out.squeeze(), x_out