wondervictor committed on
Commit
1b32236
·
1 Parent(s): fc47e93

add requirements

Browse files
Files changed (1) hide show
  1. condition/midas/midas/vit.py +116 -82
condition/midas/midas/vit.py CHANGED
@@ -7,15 +7,17 @@ import torch.nn.functional as F
7
 
8
 
9
  class Slice(nn.Module):
 
10
  def __init__(self, start_index=1):
11
  super(Slice, self).__init__()
12
  self.start_index = start_index
13
 
14
  def forward(self, x):
15
- return x[:, self.start_index :]
16
 
17
 
18
  class AddReadout(nn.Module):
 
19
  def __init__(self, start_index=1):
20
  super(AddReadout, self).__init__()
21
  self.start_index = start_index
@@ -25,24 +27,27 @@ class AddReadout(nn.Module):
25
  readout = (x[:, 0] + x[:, 1]) / 2
26
  else:
27
  readout = x[:, 0]
28
- return x[:, self.start_index :] + readout.unsqueeze(1)
29
 
30
 
31
  class ProjectReadout(nn.Module):
 
32
  def __init__(self, in_features, start_index=1):
33
  super(ProjectReadout, self).__init__()
34
  self.start_index = start_index
35
 
36
- self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
 
37
 
38
  def forward(self, x):
39
- readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
40
- features = torch.cat((x[:, self.start_index :], readout), -1)
41
 
42
  return self.project(features)
43
 
44
 
45
  class Transpose(nn.Module):
 
46
  def __init__(self, dim0, dim1):
47
  super(Transpose, self).__init__()
48
  self.dim0 = dim0
@@ -58,10 +63,14 @@ def forward_vit(pretrained, x):
58
 
59
  glob = pretrained.model.forward_flex(x)
60
 
61
- layer_1 = pretrained.activations["1"]
62
- layer_2 = pretrained.activations["2"]
63
- layer_3 = pretrained.activations["3"]
64
- layer_4 = pretrained.activations["4"]
 
 
 
 
65
 
66
  layer_1 = pretrained.act_postprocess1[0:2](layer_1)
67
  layer_2 = pretrained.act_postprocess2[0:2](layer_2)
@@ -71,14 +80,11 @@ def forward_vit(pretrained, x):
71
  unflatten = nn.Sequential(
72
  nn.Unflatten(
73
  2,
74
- torch.Size(
75
- [
76
- h // pretrained.model.patch_size[1],
77
- w // pretrained.model.patch_size[0],
78
- ]
79
- ),
80
- )
81
- )
82
 
83
  if layer_1.ndim == 3:
84
  layer_1 = unflatten(layer_1)
@@ -89,24 +95,31 @@ def forward_vit(pretrained, x):
89
  if layer_4.ndim == 3:
90
  layer_4 = unflatten(layer_4)
91
 
92
- layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
93
- layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
94
- layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
95
- layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
 
 
 
 
96
 
97
  return layer_1, layer_2, layer_3, layer_4
98
 
99
 
100
  def _resize_pos_embed(self, posemb, gs_h, gs_w):
101
  posemb_tok, posemb_grid = (
102
- posemb[:, : self.start_index],
103
- posemb[0, self.start_index :],
104
  )
105
 
106
  gs_old = int(math.sqrt(len(posemb_grid)))
107
 
108
- posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
109
- posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
 
 
 
110
  posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
111
 
112
  posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
@@ -117,29 +130,27 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w):
117
  def forward_flex(self, x):
118
  b, c, h, w = x.shape
119
 
120
- pos_embed = self._resize_pos_embed(
121
- self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
122
- )
123
 
124
  B = x.shape[0]
125
 
126
  if hasattr(self.patch_embed, "backbone"):
127
  x = self.patch_embed.backbone(x)
128
  if isinstance(x, (list, tuple)):
129
- x = x[-1] # last feature if backbone outputs list/tuple of features
 
130
 
131
  x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
132
 
133
  if getattr(self, "dist_token", None) is not None:
134
  cls_tokens = self.cls_token.expand(
135
- B, -1, -1
136
- ) # stole cls_tokens impl from Phil Wang, thanks
137
  dist_token = self.dist_token.expand(B, -1, -1)
138
  x = torch.cat((cls_tokens, dist_token, x), dim=1)
139
  else:
140
  cls_tokens = self.cls_token.expand(
141
- B, -1, -1
142
- ) # stole cls_tokens impl from Phil Wang, thanks
143
  x = torch.cat((cls_tokens, x), dim=1)
144
 
145
  x = x + pos_embed
@@ -157,11 +168,15 @@ activations = {}
157
 
158
 
159
  def get_activation(name):
 
160
  def hook(model, input, output):
161
  activations[name] = output
162
 
163
  return hook
164
 
 
 
 
165
 
166
  def get_readout_oper(vit_features, features, use_readout, start_index=1):
167
  if use_readout == "ignore":
@@ -191,15 +206,26 @@ def _make_vit_b16_backbone(
191
  ):
192
  pretrained = nn.Module()
193
 
 
194
  pretrained.model = model
195
- # pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
196
- # pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
197
- # pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
198
- # pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
199
-
200
- # pretrained.activations = activations
201
-
202
- readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
 
 
 
 
 
 
 
 
 
 
203
 
204
  # 32, 48, 136, 384
205
  pretrained.act_postprocess1 = nn.Sequential(
@@ -286,10 +312,10 @@ def _make_vit_b16_backbone(
286
 
287
  # We inject this function into the VisionTransformer instances so that
288
  # we can use it with interpolated position embeddings without modifying the library source.
289
- pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
 
290
  pretrained.model._resize_pos_embed = types.MethodType(
291
- _resize_pos_embed, pretrained.model
292
- )
293
 
294
  return pretrained
295
 
@@ -311,24 +337,28 @@ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
311
  model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
312
 
313
  hooks = [2, 5, 8, 11] if hooks == None else hooks
314
- return _make_vit_b16_backbone(
315
- model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
316
- )
 
317
 
318
 
319
  def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
320
- model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
 
321
 
322
  hooks = [2, 5, 8, 11] if hooks == None else hooks
323
- return _make_vit_b16_backbone(
324
- model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
325
- )
 
326
 
327
 
328
- def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
329
- model = timm.create_model(
330
- "vit_deit_base_distilled_patch16_384", pretrained=pretrained
331
- )
 
332
 
333
  hooks = [2, 5, 8, 11] if hooks == None else hooks
334
  return _make_vit_b16_backbone(
@@ -354,23 +384,26 @@ def _make_vit_b_rn50_backbone(
354
 
355
  pretrained.model = model
356
 
357
- # if use_vit_only == True:
358
- # pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
359
- # pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
360
- # else:
361
- # pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
362
- # get_activation("1")
363
- # )
364
- # pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
365
- # get_activation("2")
366
- # )
367
 
368
- # pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
369
- # pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
 
 
370
 
371
- # pretrained.activations = activations
372
 
373
- readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
 
374
 
375
  if use_vit_only == True:
376
  pretrained.act_postprocess1 = nn.Sequential(
@@ -419,12 +452,12 @@ def _make_vit_b_rn50_backbone(
419
  ),
420
  )
421
  else:
422
- pretrained.act_postprocess1 = nn.Sequential(
423
- nn.Identity(), nn.Identity(), nn.Identity()
424
- )
425
- pretrained.act_postprocess2 = nn.Sequential(
426
- nn.Identity(), nn.Identity(), nn.Identity()
427
- )
428
 
429
  pretrained.act_postprocess3 = nn.Sequential(
430
  readout_oper[2],
@@ -464,20 +497,21 @@ def _make_vit_b_rn50_backbone(
464
 
465
  # We inject this function into the VisionTransformer instances so that
466
  # we can use it with interpolated position embeddings without modifying the library source.
467
- pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
 
468
 
469
  # We inject this function into the VisionTransformer instances so that
470
  # we can use it with interpolated position embeddings without modifying the library source.
471
  pretrained.model._resize_pos_embed = types.MethodType(
472
- _resize_pos_embed, pretrained.model
473
- )
474
 
475
  return pretrained
476
 
477
 
478
- def _make_pretrained_vitb_rn50_384(
479
- pretrained, use_readout="ignore", hooks=None, use_vit_only=False
480
- ):
 
481
  model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
482
 
483
  hooks = [0, 1, 8, 11] if hooks == None else hooks
@@ -488,4 +522,4 @@ def _make_pretrained_vitb_rn50_384(
488
  hooks=hooks,
489
  use_vit_only=use_vit_only,
490
  use_readout=use_readout,
491
- )
 
7
 
8
 
9
  class Slice(nn.Module):
10
+
11
  def __init__(self, start_index=1):
12
  super(Slice, self).__init__()
13
  self.start_index = start_index
14
 
15
  def forward(self, x):
16
+ return x[:, self.start_index:]
17
 
18
 
19
  class AddReadout(nn.Module):
20
+
21
  def __init__(self, start_index=1):
22
  super(AddReadout, self).__init__()
23
  self.start_index = start_index
 
27
  readout = (x[:, 0] + x[:, 1]) / 2
28
  else:
29
  readout = x[:, 0]
30
+ return x[:, self.start_index:] + readout.unsqueeze(1)
31
 
32
 
33
  class ProjectReadout(nn.Module):
34
+
35
  def __init__(self, in_features, start_index=1):
36
  super(ProjectReadout, self).__init__()
37
  self.start_index = start_index
38
 
39
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features),
40
+ nn.GELU())
41
 
42
  def forward(self, x):
43
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
44
+ features = torch.cat((x[:, self.start_index:], readout), -1)
45
 
46
  return self.project(features)
47
 
48
 
49
  class Transpose(nn.Module):
50
+
51
  def __init__(self, dim0, dim1):
52
  super(Transpose, self).__init__()
53
  self.dim0 = dim0
 
63
 
64
  glob = pretrained.model.forward_flex(x)
65
 
66
+ # layer_1 = pretrained.activations["1"]
67
+ # layer_2 = pretrained.activations["2"]
68
+ # layer_3 = pretrained.activations["3"]
69
+ # layer_4 = pretrained.activations["4"]
70
+ layer_1 = pretrained.activations[0]
71
+ layer_2 = pretrained.activations[1]
72
+ layer_3 = pretrained.activations[2]
73
+ layer_4 = pretrained.activations[3]
74
 
75
  layer_1 = pretrained.act_postprocess1[0:2](layer_1)
76
  layer_2 = pretrained.act_postprocess2[0:2](layer_2)
 
80
  unflatten = nn.Sequential(
81
  nn.Unflatten(
82
  2,
83
+ torch.Size([
84
+ h // pretrained.model.patch_size[1],
85
+ w // pretrained.model.patch_size[0],
86
+ ]),
87
+ ))
 
 
 
88
 
89
  if layer_1.ndim == 3:
90
  layer_1 = unflatten(layer_1)
 
95
  if layer_4.ndim == 3:
96
  layer_4 = unflatten(layer_4)
97
 
98
+ layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](
99
+ layer_1)
100
+ layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](
101
+ layer_2)
102
+ layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](
103
+ layer_3)
104
+ layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](
105
+ layer_4)
106
 
107
  return layer_1, layer_2, layer_3, layer_4
108
 
109
 
110
  def _resize_pos_embed(self, posemb, gs_h, gs_w):
111
  posemb_tok, posemb_grid = (
112
+ posemb[:, :self.start_index],
113
+ posemb[0, self.start_index:],
114
  )
115
 
116
  gs_old = int(math.sqrt(len(posemb_grid)))
117
 
118
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
119
+ -1).permute(0, 3, 1, 2)
120
+ posemb_grid = F.interpolate(posemb_grid,
121
+ size=(gs_h, gs_w),
122
+ mode="bilinear")
123
  posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
124
 
125
  posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
 
130
  def forward_flex(self, x):
131
  b, c, h, w = x.shape
132
 
133
+ pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1],
134
+ w // self.patch_size[0])
 
135
 
136
  B = x.shape[0]
137
 
138
  if hasattr(self.patch_embed, "backbone"):
139
  x = self.patch_embed.backbone(x)
140
  if isinstance(x, (list, tuple)):
141
+ x = x[
142
+ -1] # last feature if backbone outputs list/tuple of features
143
 
144
  x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
145
 
146
  if getattr(self, "dist_token", None) is not None:
147
  cls_tokens = self.cls_token.expand(
148
+ B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
 
149
  dist_token = self.dist_token.expand(B, -1, -1)
150
  x = torch.cat((cls_tokens, dist_token, x), dim=1)
151
  else:
152
  cls_tokens = self.cls_token.expand(
153
+ B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
 
154
  x = torch.cat((cls_tokens, x), dim=1)
155
 
156
  x = x + pos_embed
 
168
 
169
 
170
  def get_activation(name):
171
+
172
  def hook(model, input, output):
173
  activations[name] = output
174
 
175
  return hook
176
 
177
def hook_act(module, input, output):
    """Forward hook that appends the hooked module's output to ``activations``.

    Args:
        module: The module the hook fired on (unused).
        input: The module's forward inputs (unused).
        output: The module's forward output; appended as-is.

    NOTE(review): ``activations`` is resolved as a module-level name here. The
    only module-level binding visible in this diff is ``activations = {}``, and
    the ``activations = []`` created inside ``_make_vit_b16_backbone`` is local
    to that function. If the global is still a dict, ``.append`` raises
    AttributeError on the first forward pass -- TODO confirm and bind the hook
    to the per-model list instead.
    NOTE(review): nothing in this hook clears earlier entries, so outputs
    accumulate across forward passes unless the caller resets the container.
    """
    activations.append(output)
179
+
180
 
181
  def get_readout_oper(vit_features, features, use_readout, start_index=1):
182
  if use_readout == "ignore":
 
206
  ):
207
  pretrained = nn.Module()
208
 
209
+ activations = []
210
  pretrained.model = model
211
+ pretrained.model.blocks[hooks[0]].register_forward_hook(hook_act)
212
+ pretrained.model.blocks[hooks[1]].register_forward_hook(hook_act)
213
+ pretrained.model.blocks[hooks[2]].register_forward_hook(hook_act)
214
+ pretrained.model.blocks[hooks[3]].register_forward_hook(hook_act)
215
+
216
+ # pretrained.model.blocks[hooks[0]].register_forward_hook(
217
+ # get_activation("1"))
218
+ # pretrained.model.blocks[hooks[1]].register_forward_hook(
219
+ # get_activation("2"))
220
+ # pretrained.model.blocks[hooks[2]].register_forward_hook(
221
+ # get_activation("3"))
222
+ # pretrained.model.blocks[hooks[3]].register_forward_hook(
223
+ # get_activation("4"))
224
+
225
+ pretrained.activations = activations
226
+
227
+ readout_oper = get_readout_oper(vit_features, features, use_readout,
228
+ start_index)
229
 
230
  # 32, 48, 136, 384
231
  pretrained.act_postprocess1 = nn.Sequential(
 
312
 
313
  # We inject this function into the VisionTransformer instances so that
314
  # we can use it with interpolated position embeddings without modifying the library source.
315
+ pretrained.model.forward_flex = types.MethodType(forward_flex,
316
+ pretrained.model)
317
  pretrained.model._resize_pos_embed = types.MethodType(
318
+ _resize_pos_embed, pretrained.model)
 
319
 
320
  return pretrained
321
 
 
337
  model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
338
 
339
  hooks = [2, 5, 8, 11] if hooks == None else hooks
340
+ return _make_vit_b16_backbone(model,
341
+ features=[96, 192, 384, 768],
342
+ hooks=hooks,
343
+ use_readout=use_readout)
344
 
345
 
346
  def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
347
+ model = timm.create_model("vit_deit_base_patch16_384",
348
+ pretrained=pretrained)
349
 
350
  hooks = [2, 5, 8, 11] if hooks == None else hooks
351
+ return _make_vit_b16_backbone(model,
352
+ features=[96, 192, 384, 768],
353
+ hooks=hooks,
354
+ use_readout=use_readout)
355
 
356
 
357
+ def _make_pretrained_deitb16_distil_384(pretrained,
358
+ use_readout="ignore",
359
+ hooks=None):
360
+ model = timm.create_model("vit_deit_base_distilled_patch16_384",
361
+ pretrained=pretrained)
362
 
363
  hooks = [2, 5, 8, 11] if hooks == None else hooks
364
  return _make_vit_b16_backbone(
 
384
 
385
  pretrained.model = model
386
 
387
+ if use_vit_only == True:
388
+ pretrained.model.blocks[hooks[0]].register_forward_hook(
389
+ get_activation("1"))
390
+ pretrained.model.blocks[hooks[1]].register_forward_hook(
391
+ get_activation("2"))
392
+ else:
393
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
394
+ get_activation("1"))
395
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
396
+ get_activation("2"))
397
 
398
+ pretrained.model.blocks[hooks[2]].register_forward_hook(
399
+ get_activation("3"))
400
+ pretrained.model.blocks[hooks[3]].register_forward_hook(
401
+ get_activation("4"))
402
 
403
+ pretrained.activations = activations
404
 
405
+ readout_oper = get_readout_oper(vit_features, features, use_readout,
406
+ start_index)
407
 
408
  if use_vit_only == True:
409
  pretrained.act_postprocess1 = nn.Sequential(
 
452
  ),
453
  )
454
  else:
455
+ pretrained.act_postprocess1 = nn.Sequential(nn.Identity(),
456
+ nn.Identity(),
457
+ nn.Identity())
458
+ pretrained.act_postprocess2 = nn.Sequential(nn.Identity(),
459
+ nn.Identity(),
460
+ nn.Identity())
461
 
462
  pretrained.act_postprocess3 = nn.Sequential(
463
  readout_oper[2],
 
497
 
498
  # We inject this function into the VisionTransformer instances so that
499
  # we can use it with interpolated position embeddings without modifying the library source.
500
+ pretrained.model.forward_flex = types.MethodType(forward_flex,
501
+ pretrained.model)
502
 
503
  # We inject this function into the VisionTransformer instances so that
504
  # we can use it with interpolated position embeddings without modifying the library source.
505
  pretrained.model._resize_pos_embed = types.MethodType(
506
+ _resize_pos_embed, pretrained.model)
 
507
 
508
  return pretrained
509
 
510
 
511
+ def _make_pretrained_vitb_rn50_384(pretrained,
512
+ use_readout="ignore",
513
+ hooks=None,
514
+ use_vit_only=False):
515
  model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
516
 
517
  hooks = [0, 1, 8, 11] if hooks == None else hooks
 
522
  hooks=hooks,
523
  use_vit_only=use_vit_only,
524
  use_readout=use_readout,
525
+ )