nishantguvvada commited on
Commit
8a0a303
·
1 Parent(s): 8887ce7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -0
app.py CHANGED
@@ -47,8 +47,18 @@ def load_image_model():
47
  # encoder=tf.keras.models.load_model('./encoder_model.h5')
48
  # return encoder
49
 
 
 
 
 
 
 
50
  # **** ENCODER ****
51
  image_input = Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
 
 
 
 
52
  encoder_output = Dense(ATTENTION_DIM, activation="relu")(x)
53
  encoder = tf.keras.Model(inputs=image_input, outputs=encoder_output)
54
  # **** ENCODER ****
 
47
  # encoder=tf.keras.models.load_model('./encoder_model.h5')
48
  # return encoder
49
 
50
+ # InceptionResNetV2 takes (299, 299, 3) image as inputs
51
+ # and return features in (8, 8, 1536) shape
52
+ FEATURE_EXTRACTOR = tf.keras.applications.inception_resnet_v2.InceptionResNetV2(
53
+ include_top=False, weights="imagenet"
54
+ )
55
+
56
  # **** ENCODER ****
57
  image_input = Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
58
+ image_features = FEATURE_EXTRACTOR(image_input)
59
+ x = Reshape((FEATURES_SHAPE[0] * FEATURES_SHAPE[1], FEATURES_SHAPE[2]))(
60
+ image_features
61
+ )
62
  encoder_output = Dense(ATTENTION_DIM, activation="relu")(x)
63
  encoder = tf.keras.Model(inputs=image_input, outputs=encoder_output)
64
  # **** ENCODER ****