glenn-jocher pre-commit-ci[bot] committed on
Commit
5774a15
·
unverified ·
1 Parent(s): a9a92ae

Add `DWConvTranspose2d()` module (#7881)

Browse files

* Add DWConvTranspose2d() module

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add DWConvTranspose2d() module

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix

* Fix

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (3) hide show
  1. models/common.py +6 -0
  2. models/tf.py +36 -11
  3. models/yolo.py +1 -1
models/common.py CHANGED
@@ -56,6 +56,12 @@ class DWConv(Conv):
56
  super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
57
 
58
 
 
 
 
 
 
 
59
  class TransformerLayer(nn.Module):
60
  # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
61
  def __init__(self, c, num_heads):
 
56
  super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
57
 
58
 
59
+ class DWConvTranspose2d(nn.ConvTranspose2d):
60
+ # Depth-wise transpose convolution class
61
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
62
+ super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
63
+
64
+
65
  class TransformerLayer(nn.Module):
66
  # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
67
  def __init__(self, c, num_heads):
models/tf.py CHANGED
@@ -27,7 +27,8 @@ import torch
27
  import torch.nn as nn
28
  from tensorflow import keras
29
 
30
- from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, Focus, autopad
 
31
  from models.experimental import MixConv2d, attempt_load
32
  from models.yolo import Detect
33
  from utils.activations import SiLU
@@ -108,6 +109,29 @@ class TFDWConv(keras.layers.Layer):
108
  return self.act(self.bn(self.conv(inputs)))
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  class TFFocus(keras.layers.Layer):
112
  # Focus wh information into c-space
113
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
@@ -152,15 +176,14 @@ class TFConv2d(keras.layers.Layer):
152
  def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
153
  super().__init__()
154
  assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
155
- self.conv = keras.layers.Conv2D(
156
- c2,
157
- k,
158
- s,
159
- 'VALID',
160
- use_bias=bias,
161
- kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
162
- bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None,
163
- )
164
 
165
  def call(self, inputs):
166
  return self.conv(inputs)
@@ -340,7 +363,9 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
340
  pass
341
 
342
  n = max(round(n * gd), 1) if n > 1 else n # depth gain
343
- if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3x]:
 
 
344
  c1, c2 = ch[f], args[0]
345
  c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
346
 
 
27
  import torch.nn as nn
28
  from tensorflow import keras
29
 
30
+ from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
31
+ DWConvTranspose2d, Focus, autopad)
32
  from models.experimental import MixConv2d, attempt_load
33
  from models.yolo import Detect
34
  from utils.activations import SiLU
 
109
  return self.act(self.bn(self.conv(inputs)))
110
 
111
 
112
+ class TFDWConvTranspose2d(keras.layers.Layer):
113
+ # Depthwise ConvTranspose2d
114
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None):
115
+ # ch_in, ch_out, kernel, stride, padding, padding_out, weights
116
+ super().__init__()
117
+ assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels'
118
+ assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1'
119
+ weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy()
120
+ self.c1 = c1
121
+ self.conv = [
122
+ keras.layers.Conv2DTranspose(filters=1,
123
+ kernel_size=k,
124
+ strides=s,
125
+ padding='VALID',
126
+ output_padding=p2,
127
+ use_bias=True,
128
+ kernel_initializer=keras.initializers.Constant(weight[..., i:i + 1]),
129
+ bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c1)]
130
+
131
+ def call(self, inputs):
132
+ return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1]
133
+
134
+
135
  class TFFocus(keras.layers.Layer):
136
  # Focus wh information into c-space
137
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
 
176
  def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
177
  super().__init__()
178
  assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
179
+ self.conv = keras.layers.Conv2D(filters=c2,
180
+ kernel_size=k,
181
+ strides=s,
182
+ padding='VALID',
183
+ use_bias=bias,
184
+ kernel_initializer=keras.initializers.Constant(
185
+ w.weight.permute(2, 3, 1, 0).numpy()),
186
+ bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None)
 
187
 
188
  def call(self, inputs):
189
  return self.conv(inputs)
 
363
  pass
364
 
365
  n = max(round(n * gd), 1) if n > 1 else n # depth gain
366
+ if m in [
367
+ nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
368
+ BottleneckCSP, C3, C3x]:
369
  c1, c2 = ch[f], args[0]
370
  c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
371
 
models/yolo.py CHANGED
@@ -266,7 +266,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
266
 
267
  n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
268
  if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
269
- BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, C3x):
270
  c1, c2 = ch[f], args[0]
271
  if c2 != no: # if not output
272
  c2 = make_divisible(c2 * gw, 8)
 
266
 
267
  n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
268
  if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
269
+ BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x):
270
  c1, c2 = ch[f], args[0]
271
  if c2 != no: # if not output
272
  c2 = make_divisible(c2 * gw, 8)