tincri commited on
Commit
055ba7e
·
1 Parent(s): 2c87f60

Fix #9 app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -107,7 +107,7 @@ class MappingNetwork(nn.Module):
107
  return s
108
 
109
  class StyleEncoder(nn.Module):
110
- def __init__(self, img_size=256, style_dim=64, num_domains=3, max_conv_dim=512):
111
  super().__init__()
112
  dim_in = 64
113
  blocks = []
@@ -117,20 +117,23 @@ class StyleEncoder(nn.Module):
117
  dim_out = min(dim_in*2, max_conv_dim)
118
  blocks += [ResBlk(dim_in, dim_out, downsample=True)]
119
  dim_in = dim_out
 
120
  self.shared = nn.Sequential(*blocks)
 
121
  self.unshared = nn.ModuleList()
122
  for _ in range(num_domains):
123
- self.unshared += [nn.Linear(dim_in * (img_size // (2**repeat_num))**2, style_dim)]
124
 
125
  def forward(self, x, y):
126
  h = self.shared(x)
 
127
  h = h.view(h.size(0), -1)
128
  out = []
129
  for layer in self.unshared:
130
  out += [layer(h)]
131
- out = torch.stack(out, dim=1) # (batch, num_domains, style_dim)
132
- idx = torch.LongTensor(range(y.size(0))).unsqueeze(1).to(y.device)
133
- s = torch.gather(out, 1, idx.unsqueeze(2).expand(-1, -1, out.size(2))).squeeze(1)
134
  return s
135
 
136
  # DEFINICIÓN DEL GENERADOR
 
107
  return s
108
 
109
  class StyleEncoder(nn.Module):
110
+ def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
111
  super().__init__()
112
  dim_in = 64
113
  blocks = []
 
117
  dim_out = min(dim_in*2, max_conv_dim)
118
  blocks += [ResBlk(dim_in, dim_out, downsample=True)]
119
  dim_in = dim_out
120
+ blocks += [nn.LeakyReLU(0.2)]
121
  self.shared = nn.Sequential(*blocks)
122
+
123
  self.unshared = nn.ModuleList()
124
  for _ in range(num_domains):
125
+ self.unshared += [nn.Linear(dim_in, style_dim)]
126
 
127
  def forward(self, x, y):
128
  h = self.shared(x)
129
+ h = F.adaptive_avg_pool2d(h, (1,1))
130
  h = h.view(h.size(0), -1)
131
  out = []
132
  for layer in self.unshared:
133
  out += [layer(h)]
134
+ out = torch.stack(out, dim=1)
135
+ idx = torch.arange(y.size(0)).to(y.device)
136
+ s = out[idx, y]
137
  return s
138
 
139
  # DEFINICIÓN DEL GENERADOR