Vageesh1 committed
Commit 389e486
1 Parent(s): 8af2f9e

Update app.py

Files changed (1): app.py (+30 -2)
app.py CHANGED

@@ -4,10 +4,32 @@ import PIL.Image
 import skimage.io as io
 import streamlit as st
 from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
+from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
 from model import generate2,ClipCaptionModel
 from engine import inference
 
-#model loading code
+
+# Load the base ViT-GPT2 captioning model, then overwrite it with the fine-tuned weights
+model_trained = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+model_trained.load_state_dict(torch.load('model_trained.pth', map_location=torch.device('cpu')))
+image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
+
+def show_n_generate(img, greedy=True, model=model_trained):
+    # Caption an image, either greedily or with top-k sampling
+    image = PIL.Image.open(img)
+    pixel_values = image_processor(image, return_tensors="pt").pixel_values
+
+    if greedy:
+        generated_ids = model.generate(pixel_values, max_new_tokens=30)
+    else:
+        generated_ids = model.generate(
+            pixel_values,
+            do_sample=True,
+            max_new_tokens=30,
+            top_k=5)
+    generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    return generated_text
 
 device = "cpu"
 clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
@@ -35,7 +57,7 @@ def ui():
     pil_image = PIL.Image.fromarray(image)
     image = preprocess(pil_image).unsqueeze(0).to(device)
 
-    option = st.selectbox('Please select the Model',('Model', 'COCO Model','PreTrained Model'))
+    option = st.selectbox('Please select the Model',('Model', 'COCO Model','PreTrained Model','Fine Tuned Model'))
 
     if option=='Model':
         with torch.no_grad():
@@ -60,6 +82,12 @@ def ui():
         st.image(uploaded_file, width = 500, channels = 'RGB')
         st.markdown("**PREDICTION:** " + out)
 
+    elif option=='Fine Tuned Model':
+        out = show_n_generate(uploaded_file, greedy = False, model = model_trained)
+        st.image(uploaded_file, width = 500, channels = 'RGB')
+        st.markdown("**PREDICTION:** " + out)
+
+
 
 if __name__ == '__main__':
     ui()
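
For reference, a minimal standalone sketch of the new 'Fine Tuned Model' path outside Streamlit. It assumes the fine-tuned checkpoint model_trained.pth from this commit sits next to the script (the base nlpconnect weights are used if it is absent) and a hypothetical test image cat.jpg:

import os
import PIL.Image
import torch
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel

# Base checkpoint used in the diff; the fine-tuned weights are optional here
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
if os.path.exists("model_trained.pth"):
    model.load_state_dict(torch.load("model_trained.pth", map_location=torch.device("cpu")))
model.eval()

image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

image = PIL.Image.open("cat.jpg").convert("RGB")  # hypothetical test image
pixel_values = image_processor(image, return_tensors="pt").pixel_values

# Top-k sampling, mirroring the greedy=False branch that the app's
# 'Fine Tuned Model' option selects in show_n_generate
with torch.no_grad():
    generated_ids = model.generate(pixel_values, do_sample=True, max_new_tokens=30, top_k=5)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])

Dropping do_sample=True and top_k reproduces the greedy branch of show_n_generate.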