File size: 1,273 Bytes
1cd43eb
1679fe8
1cd43eb
1679fe8
1cd43eb
 
 
 
 
1679fe8
1cd43eb
1679fe8
 
 
 
7e70156
1679fe8
1cd43eb
1679fe8
 
 
1cd43eb
 
1679fe8
1cd43eb
 
1679fe8
 
 
1cd43eb
 
1679fe8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import gradio as gr
import torch
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, PreTrainedTokenizerFast

# Load the model and preprocessing tools
# Encoder-decoder checkpoint: ViT image encoder + GPT-2 text decoder,
# fine-tuned for image captioning. Downloaded from the HF Hub at import time.
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# ViT preprocessor: resizes/normalizes a PIL image into a pixel tensor.
# NOTE(review): ViTFeatureExtractor is deprecated in recent transformers
# releases in favor of ViTImageProcessor — confirm against the pinned version.
vit_feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
# NOTE(review): the tokenizer is loaded from "distilgpt2" rather than the
# captioning checkpoint itself; both share the GPT-2 vocabulary so decoding
# works, but loading it from "nlpconnect/vit-gpt2-image-captioning" would be
# the safer, self-consistent choice — verify before changing.
tokenizer = PreTrainedTokenizerFast.from_pretrained("distilgpt2")

def vit2distilgpt2(img):
    """Generate a caption for *img* with the ViT encoder + GPT-2 decoder.

    Args:
        img: A PIL image (as supplied by the Gradio ``Image`` input).

    Returns:
        str: The single best beam-search caption, special tokens removed.
    """
    # Convert the PIL image into the normalized pixel tensor ViT expects.
    pixel_values = vit_feature_extractor(images=img, return_tensors="pt").pixel_values

    # Inference only: without no_grad() generate() builds an autograd graph,
    # wasting memory and time on every request.
    with torch.no_grad():
        # Beam search with 5 beams, keeping only the top sequence.
        # (Renamed from the misleading "encoder_outputs" — these are the
        # generated decoder token ids, not encoder activations.)
        generated_ids = model.generate(
            pixel_values.to("cpu"),
            num_beams=5,
            num_return_sequences=1,
        )

    # Decode token ids back to text, dropping BOS/EOS/pad tokens.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Gradio interface setup.
# Fix: the gr.inputs / gr.outputs namespaces were removed in Gradio 3.x —
# component classes now live at the top level (gr.Image, gr.Textbox) — and
# launch(enable_queue=...) was replaced by the Interface.queue() method.
# The old code raises AttributeError/TypeError on any current Gradio release.
inputs = gr.Image(type="pil", label="Original Image")
outputs = gr.Textbox(label="Caption")

title = "Image Captioning using ViT + GPT2"
description = "ViT and GPT2 are used to generate an image caption for the uploaded image. COCO dataset is used for training."

# Build the web UI around the captioning function and serve it;
# queue() serializes requests so concurrent users don't contend for the model.
gr.Interface(
    fn=vit2distilgpt2,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
).queue().launch(debug=True)