NoteDance commited on
Commit
2868def
·
verified ·
1 Parent(s): 668e031

Update Segformer.py

Browse files
Files changed (1) hide show
  1. Segformer.py +13 -3
Segformer.py CHANGED
@@ -25,11 +25,21 @@ class DsConv2d:
25
  def __call__(self, x):
26
  return self.net(x)
27
 
28
- class LayerNorm:
29
  def __init__(self, dim, eps = 1e-5):
30
  self.eps = eps
31
- self.g = tf.Variable(tf.ones((1, dim, 1, 1)))
32
- self.b = tf.Variable(tf.zeros((1, dim, 1, 1)))
 
 
 
 
 
 
 
 
 
 
33
 
34
  def __call__(self, x):
35
  std = tf.math.sqrt(tf.math.reduce_variance(x, axis=1, keepdims=True))
 
25
  def __call__(self, x):
26
  return self.net(x)
27
 
28
+ class LayerNorm(tf.keras.layers.Layer):
29
  def __init__(self, dim, eps = 1e-5):
30
  self.eps = eps
31
+ self.g = self.add_weight(
32
+ name='g',
33
+ shape=(1, dim, 1, 1),
34
+ initializer=tf.keras.initializers.Ones(),
35
+ trainable=True
36
+ )
37
+ self.b = self.add_weight(
38
+ name='b',
39
+ shape=(1, dim, 1, 1),
40
+ initializer=tf.keras.initializers.Zeros(),
41
+ trainable=True
42
+ )
43
 
44
  def __call__(self, x):
45
  std = tf.math.sqrt(tf.math.reduce_variance(x, axis=1, keepdims=True))