DHEIVER commited on
Commit
671c67e
·
1 Parent(s): 8d05804

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -8,23 +8,24 @@ from tensorflow.keras import backend as K
8
 
9
  # Define the custom FixedDropout layer
10
  class FixedDropout(tf.keras.layers.Layer):
11
- def __init__(self, rate, noise_shape=None, **kwargs):
12
  super(FixedDropout, self).__init__(**kwargs)
13
  self.rate = rate
14
  self.noise_shape = noise_shape # Include the noise_shape argument
 
15
 
16
  def call(self, inputs, training=None):
17
  if training is None:
18
  training = K.learning_phase()
19
- return K.in_train_phase(K.dropout(inputs, self.rate, noise_shape=self.noise_shape), inputs, training=training)
20
 
21
  def get_config(self):
22
  config = super(FixedDropout, self).get_config()
23
  config['rate'] = self.rate # Serialize the rate argument
24
  config['noise_shape'] = self.noise_shape # Serialize the noise_shape argument
 
25
  return config
26
 
27
-
28
  class ImageClassifierApp:
29
  def __init__(self, model_path):
30
  self.model_path = model_path
 
8
 
9
  # Define the custom FixedDropout layer
10
  class FixedDropout(tf.keras.layers.Layer):
11
+ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
12
  super(FixedDropout, self).__init__(**kwargs)
13
  self.rate = rate
14
  self.noise_shape = noise_shape # Include the noise_shape argument
15
+ self.seed = seed # Include the seed argument
16
 
17
  def call(self, inputs, training=None):
18
  if training is None:
19
  training = K.learning_phase()
20
+ return K.in_train_phase(K.dropout(inputs, self.rate, noise_shape=self.noise_shape, seed=self.seed), inputs, training=training)
21
 
22
  def get_config(self):
23
  config = super(FixedDropout, self).get_config()
24
  config['rate'] = self.rate # Serialize the rate argument
25
  config['noise_shape'] = self.noise_shape # Serialize the noise_shape argument
26
+ config['seed'] = self.seed # Serialize the seed argument
27
  return config
28
 
 
29
  class ImageClassifierApp:
30
  def __init__(self, model_path):
31
  self.model_path = model_path