glenn-jocher commited on
Commit
1b8e70f
·
unverified ·
1 Parent(s): 91c82d8

Add TFDWConv() `depth_multiplier` (#7858)

Browse files

Enabled grouped non c1 == c2 convolutions in TF YOLOv5 models.

Files changed (1) hide show
  1. models/tf.py +2 -1
models/tf.py CHANGED
@@ -91,9 +91,10 @@ class TFDWConv(keras.layers.Layer):
91
  def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
92
  # ch_in, ch_out, weights, kernel, stride, padding, groups
93
  super().__init__()
94
- assert c1 == c2, f'TFDWConv() input={c1} must equal output={c2} channels'
95
  conv = keras.layers.DepthwiseConv2D(
96
  kernel_size=k,
 
97
  strides=s,
98
  padding='SAME' if s == 1 else 'VALID',
99
  use_bias=not hasattr(w, 'bn'),
 
91
  def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
92
  # ch_in, ch_out, weights, kernel, stride, padding, groups
93
  super().__init__()
94
+ assert c2 % c1 == 0, f'TFDWConv() output={c2} must be a multiple of input={c1} channels'
95
  conv = keras.layers.DepthwiseConv2D(
96
  kernel_size=k,
97
+ depth_multiplier=c2 // c1,
98
  strides=s,
99
  padding='SAME' if s == 1 else 'VALID',
100
  use_bias=not hasattr(w, 'bn'),