wondervictor committed on
Commit
6cd385f
·
1 Parent(s): 76ecb1e

update README

Browse files
Files changed (1) hide show
  1. condition/midas/midas/vit.py +46 -10
condition/midas/midas/vit.py CHANGED
@@ -61,7 +61,8 @@ class Transpose(nn.Module):
61
  def forward_vit(pretrained, x):
62
  b, c, h, w = x.shape
63
 
64
- glob = pretrained.model.forward_flex(x)
 
65
 
66
  # layer_1 = pretrained.activations["1"]
67
  # layer_2 = pretrained.activations["2"]
@@ -127,6 +128,44 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w):
127
  return posemb
128
 
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  def forward_flex(self, x):
131
  b, c, h, w = x.shape
132
 
@@ -174,6 +213,7 @@ def get_activation(name):
174
 
175
  return hook
176
 
 
177
  def hook_act(module, input, output):
178
  activations.append(output)
179
 
@@ -212,7 +252,7 @@ def _make_vit_b16_backbone(
212
  pretrained.model.blocks[hooks[1]].register_forward_hook(hook_act)
213
  pretrained.model.blocks[hooks[2]].register_forward_hook(hook_act)
214
  pretrained.model.blocks[hooks[3]].register_forward_hook(hook_act)
215
-
216
  # pretrained.model.blocks[hooks[0]].register_forward_hook(
217
  # get_activation("1"))
218
  # pretrained.model.blocks[hooks[1]].register_forward_hook(
@@ -386,20 +426,16 @@ def _make_vit_b_rn50_backbone(
386
  activations = []
387
 
388
  if use_vit_only == True:
389
- pretrained.model.blocks[hooks[0]].register_forward_hook(
390
- hook_act)
391
- pretrained.model.blocks[hooks[1]].register_forward_hook(
392
- hook_act)
393
  else:
394
  pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
395
  hook_act)
396
  pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
397
  hook_act)
398
 
399
- pretrained.model.blocks[hooks[2]].register_forward_hook(
400
- hook_act)
401
- pretrained.model.blocks[hooks[3]].register_forward_hook(
402
- hook_act)
403
 
404
  # if use_vit_only == True:
405
  # pretrained.model.blocks[hooks[0]].register_forward_hook(
 
61
  def forward_vit(pretrained, x):
62
  b, c, h, w = x.shape
63
 
64
+ # glob = pretrained.model.forward_flex(x)
65
+ glob = flat_forward_flex(pretrained.model, x)
66
 
67
  # layer_1 = pretrained.activations["1"]
68
  # layer_2 = pretrained.activations["2"]
 
128
  return posemb
129
 
130
 
131
+ def flat_forward_flex(model, x):
132
+ b, c, h, w = x.shape
133
+
134
+ pos_embed = model._resize_pos_embed(model.pos_embed,
135
+ h // model.patch_size[1],
136
+ w // model.patch_size[0])
137
+
138
+ B = x.shape[0]
139
+
140
+ if hasattr(model.patch_embed, "backbone"):
141
+ x = model.patch_embed.backbone(x)
142
+ if isinstance(x, (list, tuple)):
143
+ x = x[
144
+ -1] # last feature if backbone outputs list/tuple of features
145
+
146
+ x = model.patch_embed.proj(x).flatten(2).transpose(1, 2)
147
+
148
+ if getattr(model, "dist_token", None) is not None:
149
+ cls_tokens = model.cls_token.expand(
150
+ B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
151
+ dist_token = model.dist_token.expand(B, -1, -1)
152
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
153
+ else:
154
+ cls_tokens = model.cls_token.expand(
155
+ B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
156
+ x = torch.cat((cls_tokens, x), dim=1)
157
+
158
+ x = x + pos_embed
159
+ x = model.pos_drop(x)
160
+
161
+ for blk in model.blocks:
162
+ x = blk(x)
163
+
164
+ x = model.norm(x)
165
+
166
+ return x
167
+
168
+
169
  def forward_flex(self, x):
170
  b, c, h, w = x.shape
171
 
 
213
 
214
  return hook
215
 
216
+
217
  def hook_act(module, input, output):
218
  activations.append(output)
219
 
 
252
  pretrained.model.blocks[hooks[1]].register_forward_hook(hook_act)
253
  pretrained.model.blocks[hooks[2]].register_forward_hook(hook_act)
254
  pretrained.model.blocks[hooks[3]].register_forward_hook(hook_act)
255
+
256
  # pretrained.model.blocks[hooks[0]].register_forward_hook(
257
  # get_activation("1"))
258
  # pretrained.model.blocks[hooks[1]].register_forward_hook(
 
426
  activations = []
427
 
428
  if use_vit_only == True:
429
+ pretrained.model.blocks[hooks[0]].register_forward_hook(hook_act)
430
+ pretrained.model.blocks[hooks[1]].register_forward_hook(hook_act)
 
 
431
  else:
432
  pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
433
  hook_act)
434
  pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
435
  hook_act)
436
 
437
+ pretrained.model.blocks[hooks[2]].register_forward_hook(hook_act)
438
+ pretrained.model.blocks[hooks[3]].register_forward_hook(hook_act)
 
 
439
 
440
  # if use_vit_only == True:
441
  # pretrained.model.blocks[hooks[0]].register_forward_hook(