nishantguvvada committed
Commit 4c488c4 · 1 Parent(s): 96a1315

Update app.py

Files changed (1): app.py (+103 -26)
app.py CHANGED
@@ -1,17 +1,53 @@
 import streamlit as st
+import pickle
 import tensorflow as tf
+import cv2
 import numpy as np
-from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
-import torch
-from PIL import Image
+from PIL import Image, ImageOps
+import imageio.v3 as iio
+import time
+from textwrap import wrap
+import matplotlib.pylab as plt
+from tensorflow.keras import Input
+from tensorflow.keras.layers import (
+    GRU,
+    Add,
+    AdditiveAttention,
+    Attention,
+    Concatenate,
+    Dense,
+    Embedding,
+    LayerNormalization,
+    Reshape,
+    StringLookup,
+    TextVectorization,
+)
 
-model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+MAX_CAPTION_LEN = 64
+MINIMUM_SENTENCE_LENGTH = 5
+IMG_HEIGHT = 299
+IMG_WIDTH = 299
+IMG_CHANNELS = 3
+ATTENTION_DIM = 512  # size of dense layer in Attention
+VOCAB_SIZE = 20000
+FEATURES_SHAPE = (8, 8, 1536)
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
+@st.cache_resource()
+def load_image_model():
+    image_model = tf.keras.models.load_model('./image_caption_model.h5')
+    return image_model
+
+@st.cache_resource()
+def load_decoder_model():
+    decoder_model = tf.keras.models.load_model('./decoder_pred_model.h5')
+    return decoder_model
+
+@st.cache_resource()
+def load_encoder_model():
+    encoder = tf.keras.models.load_model('./encoder_model.h5')
+    return encoder
+
 
 st.title(":blue[Nishant Guvvada's] :red[AI Journey] Image Caption Generation")
 image = Image.open('./title.jpg')
 st.image(image)
@@ -20,34 +56,75 @@ st.write("""
 """
 )
 
-file = st.file_uploader("Upload an image to generate captions!", type=['png', 'jpg'])
+file = st.file_uploader("Upload any image and the model will try to provide a caption to it!", type=['png', 'jpg'])
+
+
+# Override the default standardization of TextVectorization to keep the "<"
+# and ">" characters, so the <start> and <end> tokens are preserved.
+def standardize(inputs):
+    inputs = tf.strings.lower(inputs)
+    return tf.strings.regex_replace(
+        inputs, r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?", ""
+    )
+
+# Load the tokenizer vocabulary: the most frequent words, punctuation removed.
+with open('./tokenizer_vocab.txt', 'rb') as vocab:
+    tokenizer = pickle.load(vocab)
+
+# Lookup table: word -> index
+word_to_index = StringLookup(
+    mask_token="", vocabulary=tokenizer
+)
+
 
-max_length = 16
-num_beams = 4
-gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-def predict_step(image_paths):
-    images = []
-    for image_path in image_paths:
-        i_image = Image.open(image_path)
-        if i_image.mode != "RGB":
-            i_image = i_image.convert(mode="RGB")
-
-        images.append(i_image)
-
-    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
-    pixel_values = pixel_values.to(device)
-
-    output_ids = model.generate(pixel_values, **gen_kwargs)
-
-    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
-    preds = [pred.strip() for pred in preds]
-    return preds
+## Probabilistic prediction using the trained model
+def predict_caption(file):
+    image = Image.open(file).convert('RGB')
+    image = np.array(image)
+    gru_state = tf.zeros((1, ATTENTION_DIM))
+
+    resize = tf.image.resize(image, (IMG_HEIGHT, IMG_WIDTH))
+    img = resize / 255
+
+    encoder = load_encoder_model()
+    features = encoder(tf.expand_dims(img, axis=0))
+    dec_input = tf.expand_dims([word_to_index("<start>")], 1)
+    result = []
+    decoder_pred_model = load_decoder_model()
+    for i in range(MAX_CAPTION_LEN):
+        predictions, gru_state = decoder_pred_model(
+            [dec_input, gru_state, features]
+        )
+
+        # sample the next token from the ten highest-scoring candidates;
+        # tf.random.categorical treats the scores as unnormalized log-probabilities
+        top_probs, top_idxs = tf.math.top_k(
+            input=predictions[0][0], k=10, sorted=False
+        )
+        chosen_id = tf.random.categorical([top_probs], 1)[0].numpy()
+        predicted_id = top_idxs.numpy()[chosen_id][0]
+
+        result.append(tokenizer[predicted_id])
+
+        if predicted_id == word_to_index("<end>"):
+            return img, result
+
+        dec_input = tf.expand_dims([predicted_id], 1)
+
+    return img, result
 
 
 def on_click():
     if file is None:
         st.text("Please upload an image file")
     else:
-        predict_step(file)
+        image = Image.open(file)
+        st.image(image, use_column_width=True)
+        # sampling is stochastic, so five runs produce five candidate captions
+        for i in range(5):
+            image, caption = predict_caption(file)
+            st.write(" ".join(caption[:-1]) + ".")
 
 st.button('Generate', on_click=on_click)
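
A quick sanity check of the standardize() override above, runnable on its own; the sample sentence is made up for illustration and does not come from the repo. Because "<" and ">" are deliberately left out of the regex character class, the <start> and <end> markers survive while punctuation is stripped and the text is lowercased.

import tensorflow as tf

# Same regex as in the commit; "<" and ">" are not in the character class,
# so the sentinel tokens pass through untouched.
def standardize(inputs):
    inputs = tf.strings.lower(inputs)
    return tf.strings.regex_replace(
        inputs, r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?", ""
    )

print(standardize(tf.constant("<start> A man, riding a horse! <end>")).numpy())
# b'<start> a man riding a horse <end>'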
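And a minimal sketch of the top-k sampling step at the heart of predict_caption, isolated so it runs without the saved .h5 models; the uniform dummy scores stand in for the decoder output predictions[0][0] and are purely an assumption for illustration. Since tf.random.categorical draws a different token on each call, repeated calls yield different captions, which is why on_click samples five of them.

import tensorflow as tf

VOCAB_SIZE = 20000  # matches the constant in app.py

# Dummy scores standing in for the decoder output predictions[0][0].
scores = tf.random.uniform((VOCAB_SIZE,))

# Keep only the ten highest-scoring tokens.
top_probs, top_idxs = tf.math.top_k(input=scores, k=10, sorted=False)

# tf.random.categorical interprets its input as unnormalized
# log-probabilities, so higher-scoring tokens are drawn more often,
# but the argmax is not always chosen.
chosen = tf.random.categorical([top_probs], 1)[0].numpy()
predicted_id = int(top_idxs.numpy()[chosen][0])
print(predicted_id)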