krishnv committed (verified)
Commit c36694a · 1 Parent(s): d204657

Update app.py

Files changed (1)
  1. app.py +3 -4
app.py CHANGED
@@ -1,9 +1,8 @@
-#From
 import torch
 import gradio as gr
 from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
 
-device='cpu'
+device = 'cpu'
 encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
 decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
 model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
@@ -14,7 +13,7 @@ model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
 def predict(image, max_length=64, num_beams=4):
     image = image.convert('RGB')
     image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
-    clean_text = lambda x: x.replace('','').split('\n')[0]
+    clean_text = lambda x: x.replace('', '').split('\n')[0]
     caption_ids = model.generate(image, max_length=max_length, num_beams=num_beams)[0]
     caption_text = clean_text(tokenizer.decode(caption_ids, skip_special_tokens=True))
     return caption_text
@@ -41,4 +40,4 @@ interface = gr.Interface(
 )
 
 # Launch the interface
-interface.launch(share=True)
+interface.launch(share=True)
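For context, a minimal sketch of how the pieces touched by this commit typically fit together after the change. The tokenizer/feature-extractor loading (old lines 10-13) and the gr.Interface arguments are elided from the hunks above, so those parts, along with the widget choices and title, are assumptions rather than the Space's actual code:

# Sketch only, assuming the elided lines load the tokenizer, feature extractor,
# and model from the same nlpconnect/vit-gpt2-image-captioning checkpoints.
import torch
import gradio as gr
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel

device = 'cpu'
encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"

# Assumed loading step (not shown in the diff): feature extractor for the ViT
# encoder, tokenizer for the GPT-2 decoder, and the full encoder-decoder model.
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)

def predict(image, max_length=64, num_beams=4):
    image = image.convert('RGB')
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
    # As committed; note that .replace('', '') is a no-op, so this only keeps
    # the first line of the decoded caption.
    clean_text = lambda x: x.replace('', '').split('\n')[0]
    caption_ids = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)[0]
    return clean_text(tokenizer.decode(caption_ids, skip_special_tokens=True))

# Assumed interface arguments; the commit only shows the closing ")".
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Caption"),
    title="ViT-GPT2 Image Captioning",
)

# Launch the interface
interface.launch(share=True)

Keeping device = 'cpu' matches the free CPU hardware of a Space; both the model and the pixel tensors are moved with .to(device) so generate() sees its inputs on the same device.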