Spaces:
Running
on
Zero
Running
on
Zero
Update hy3dgen/shapegen/models/conditioner.py
Browse files
hy3dgen/shapegen/models/conditioner.py
CHANGED
@@ -103,7 +103,7 @@ class ImageEncoder(nn.Module):
|
|
103 |
|
104 |
return last_hidden_state
|
105 |
|
106 |
-
def unconditional_embedding(self, batch_size):
|
107 |
device = next(self.model.parameters()).device
|
108 |
dtype = next(self.model.parameters()).dtype
|
109 |
zero = torch.zeros(
|
@@ -159,9 +159,6 @@ class DinoImageEncoderMV(DinoImageEncoder):
|
|
159 |
image = image.to(self.model.device, dtype=self.model.dtype)
|
160 |
|
161 |
bs, num_views, c, h, w = image.shape
|
162 |
-
# TODO: find a better place to set view_num?
|
163 |
-
self.view_num = num_views
|
164 |
-
|
165 |
image = image.view(bs * num_views, c, h, w)
|
166 |
|
167 |
inputs = self.transform(image)
|
@@ -190,12 +187,12 @@ class DinoImageEncoderMV(DinoImageEncoder):
|
|
190 |
last_hidden_state.shape[-1])
|
191 |
return last_hidden_state
|
192 |
|
193 |
-
def unconditional_embedding(self, batch_size):
|
194 |
device = next(self.model.parameters()).device
|
195 |
dtype = next(self.model.parameters()).dtype
|
196 |
zero = torch.zeros(
|
197 |
batch_size,
|
198 |
-
self.num_patches *
|
199 |
self.model.config.hidden_size,
|
200 |
device=device,
|
201 |
dtype=dtype,
|
@@ -224,17 +221,17 @@ class DualImageEncoder(nn.Module):
|
|
224 |
self.main_image_encoder = build_image_encoder(main_image_encoder)
|
225 |
self.additional_image_encoder = build_image_encoder(additional_image_encoder)
|
226 |
|
227 |
-
def forward(self, image, mask=None):
|
228 |
outputs = {
|
229 |
-
'main': self.main_image_encoder(image, mask=mask),
|
230 |
-
'additional': self.additional_image_encoder(image, mask=mask),
|
231 |
}
|
232 |
return outputs
|
233 |
|
234 |
-
def unconditional_embedding(self, batch_size):
|
235 |
outputs = {
|
236 |
-
'main': self.main_image_encoder.unconditional_embedding(batch_size),
|
237 |
-
'additional': self.additional_image_encoder.unconditional_embedding(batch_size),
|
238 |
}
|
239 |
return outputs
|
240 |
|
@@ -253,8 +250,8 @@ class SingleImageEncoder(nn.Module):
|
|
253 |
}
|
254 |
return outputs
|
255 |
|
256 |
-
def unconditional_embedding(self, batch_size):
|
257 |
outputs = {
|
258 |
-
'main': self.main_image_encoder.unconditional_embedding(batch_size),
|
259 |
}
|
260 |
return outputs
|
|
|
103 |
|
104 |
return last_hidden_state
|
105 |
|
106 |
+
def unconditional_embedding(self, batch_size, **kwargs):
|
107 |
device = next(self.model.parameters()).device
|
108 |
dtype = next(self.model.parameters()).dtype
|
109 |
zero = torch.zeros(
|
|
|
159 |
image = image.to(self.model.device, dtype=self.model.dtype)
|
160 |
|
161 |
bs, num_views, c, h, w = image.shape
|
|
|
|
|
|
|
162 |
image = image.view(bs * num_views, c, h, w)
|
163 |
|
164 |
inputs = self.transform(image)
|
|
|
187 |
last_hidden_state.shape[-1])
|
188 |
return last_hidden_state
|
189 |
|
190 |
+
def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):
|
191 |
device = next(self.model.parameters()).device
|
192 |
dtype = next(self.model.parameters()).dtype
|
193 |
zero = torch.zeros(
|
194 |
batch_size,
|
195 |
+
self.num_patches * len(view_idxs[0]),
|
196 |
self.model.config.hidden_size,
|
197 |
device=device,
|
198 |
dtype=dtype,
|
|
|
221 |
self.main_image_encoder = build_image_encoder(main_image_encoder)
|
222 |
self.additional_image_encoder = build_image_encoder(additional_image_encoder)
|
223 |
|
224 |
+
def forward(self, image, mask=None, **kwargs):
|
225 |
outputs = {
|
226 |
+
'main': self.main_image_encoder(image, mask=mask, **kwargs),
|
227 |
+
'additional': self.additional_image_encoder(image, mask=mask, **kwargs),
|
228 |
}
|
229 |
return outputs
|
230 |
|
231 |
+
def unconditional_embedding(self, batch_size, **kwargs):
|
232 |
outputs = {
|
233 |
+
'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
|
234 |
+
'additional': self.additional_image_encoder.unconditional_embedding(batch_size, **kwargs),
|
235 |
}
|
236 |
return outputs
|
237 |
|
|
|
250 |
}
|
251 |
return outputs
|
252 |
|
253 |
+
def unconditional_embedding(self, batch_size, **kwargs):
|
254 |
outputs = {
|
255 |
+
'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
|
256 |
}
|
257 |
return outputs
|