ZeqiangLai committed on
Commit
8c7926b
·
verified ·
1 Parent(s): 8839214

Update hy3dgen/shapegen/models/conditioner.py

Browse files
hy3dgen/shapegen/models/conditioner.py CHANGED
@@ -103,7 +103,7 @@ class ImageEncoder(nn.Module):
103
 
104
  return last_hidden_state
105
 
106
- def unconditional_embedding(self, batch_size):
107
  device = next(self.model.parameters()).device
108
  dtype = next(self.model.parameters()).dtype
109
  zero = torch.zeros(
@@ -159,9 +159,6 @@ class DinoImageEncoderMV(DinoImageEncoder):
159
  image = image.to(self.model.device, dtype=self.model.dtype)
160
 
161
  bs, num_views, c, h, w = image.shape
162
- # TODO: find a better place to set view_num?
163
- self.view_num = num_views
164
-
165
  image = image.view(bs * num_views, c, h, w)
166
 
167
  inputs = self.transform(image)
@@ -190,12 +187,12 @@ class DinoImageEncoderMV(DinoImageEncoder):
190
  last_hidden_state.shape[-1])
191
  return last_hidden_state
192
 
193
- def unconditional_embedding(self, batch_size):
194
  device = next(self.model.parameters()).device
195
  dtype = next(self.model.parameters()).dtype
196
  zero = torch.zeros(
197
  batch_size,
198
- self.num_patches * self.view_num,
199
  self.model.config.hidden_size,
200
  device=device,
201
  dtype=dtype,
@@ -224,17 +221,17 @@ class DualImageEncoder(nn.Module):
224
  self.main_image_encoder = build_image_encoder(main_image_encoder)
225
  self.additional_image_encoder = build_image_encoder(additional_image_encoder)
226
 
227
- def forward(self, image, mask=None):
228
  outputs = {
229
- 'main': self.main_image_encoder(image, mask=mask),
230
- 'additional': self.additional_image_encoder(image, mask=mask),
231
  }
232
  return outputs
233
 
234
- def unconditional_embedding(self, batch_size):
235
  outputs = {
236
- 'main': self.main_image_encoder.unconditional_embedding(batch_size),
237
- 'additional': self.additional_image_encoder.unconditional_embedding(batch_size),
238
  }
239
  return outputs
240
 
@@ -253,8 +250,8 @@ class SingleImageEncoder(nn.Module):
253
  }
254
  return outputs
255
 
256
- def unconditional_embedding(self, batch_size):
257
  outputs = {
258
- 'main': self.main_image_encoder.unconditional_embedding(batch_size),
259
  }
260
  return outputs
 
103
 
104
  return last_hidden_state
105
 
106
+ def unconditional_embedding(self, batch_size, **kwargs):
107
  device = next(self.model.parameters()).device
108
  dtype = next(self.model.parameters()).dtype
109
  zero = torch.zeros(
 
159
  image = image.to(self.model.device, dtype=self.model.dtype)
160
 
161
  bs, num_views, c, h, w = image.shape
 
 
 
162
  image = image.view(bs * num_views, c, h, w)
163
 
164
  inputs = self.transform(image)
 
187
  last_hidden_state.shape[-1])
188
  return last_hidden_state
189
 
190
+ def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):
191
  device = next(self.model.parameters()).device
192
  dtype = next(self.model.parameters()).dtype
193
  zero = torch.zeros(
194
  batch_size,
195
+ self.num_patches * len(view_idxs[0]),
196
  self.model.config.hidden_size,
197
  device=device,
198
  dtype=dtype,
 
221
  self.main_image_encoder = build_image_encoder(main_image_encoder)
222
  self.additional_image_encoder = build_image_encoder(additional_image_encoder)
223
 
224
+ def forward(self, image, mask=None, **kwargs):
225
  outputs = {
226
+ 'main': self.main_image_encoder(image, mask=mask, **kwargs),
227
+ 'additional': self.additional_image_encoder(image, mask=mask, **kwargs),
228
  }
229
  return outputs
230
 
231
+ def unconditional_embedding(self, batch_size, **kwargs):
232
  outputs = {
233
+ 'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
234
+ 'additional': self.additional_image_encoder.unconditional_embedding(batch_size, **kwargs),
235
  }
236
  return outputs
237
 
 
250
  }
251
  return outputs
252
 
253
+ def unconditional_embedding(self, batch_size, **kwargs):
254
  outputs = {
255
+ 'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
256
  }
257
  return outputs