bryandts commited on
Commit
bd223fb
·
verified ·
1 Parent(s): 8072e11

Update discriminatorModel.py

Browse files
Files changed (1) hide show
  1. discriminatorModel.py +16 -0
discriminatorModel.py CHANGED
@@ -2,6 +2,22 @@
2
  import torch
3
  import torch.nn as nn
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  # The Discriminator model
6
  class Discriminator(nn.Module):
7
  def __init__(self, channels, embed_dim=1024, embed_out_dim=128):
 
2
  import torch
3
  import torch.nn as nn
4
 
5
+ # The Embedding model
6
+ class Embedding(nn.Module):
7
+ def __init__(self, size_in, size_out):
8
+ super(Embedding, self).__init__()
9
+ self.text_embedding = nn.Sequential(
10
+ nn.Linear(size_in, size_out),
11
+ nn.BatchNorm1d(1),
12
+ nn.LeakyReLU(0.2, inplace=True)
13
+ )
14
+
15
+ def forward(self, x, text):
16
+ embed_out = self.text_embedding(text)
17
+ embed_out_resize = embed_out.repeat(4, 1, 4, 1).permute(1, 3, 0, 2) # Resize to match the discriminator input size
18
+ out = torch.cat([x, embed_out_resize], 1) # Concatenate text embedding with the input feature map
19
+ return out
20
+
21
  # The Discriminator model
22
  class Discriminator(nn.Module):
23
  def __init__(self, channels, embed_dim=1024, embed_out_dim=128):