# Header captured from the file's web view: author Sravanth,
# commit 2c99d88 ("Update app.py"), file size 2.68 kB.
import torch
import re
import gradio as gr
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
from transformers import AutoProcessor, AutoTokenizer, BlipForConditionalGeneration
from huggingface_hub import hf_hub_download
# Every model in this app runs on the CPU.
device = 'cpu'

# The image encoder, the caption tokenizer and the full encoder-decoder model
# are all loaded from one and the same Hugging Face checkpoint.
model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
encoder_checkpoint = model_checkpoint
decoder_checkpoint = model_checkpoint

feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint)
model = model.to(device)
def predict(image, max_length=64, num_beams=4):
    """Caption ``image`` with the ViT-GPT2 encoder-decoder model.

    Args:
        image: PIL image supplied by the Gradio input, or None when the user
            submitted without uploading an image.
        max_length: Maximum caption length (in tokens) passed to ``generate``.
        num_beams: Beam width passed to ``generate``. Previously this was
            accepted but ignored, so generation was always greedy.

    Returns:
        The first line of the decoded caption with GPT-2's ``<|endoftext|>``
        marker removed, or "" when ``image`` is None.
    """
    if image is None:
        # Nothing to caption; avoid an AttributeError on None.convert().
        return ""

    pixel_values = feature_extractor(
        image.convert('RGB'), return_tensors="pt"
    ).pixel_values.to(device)
    caption_ids = model.generate(
        pixel_values, max_length=max_length, num_beams=num_beams
    )[0]
    decoded = tokenizer.decode(caption_ids)
    # Strip the end-of-text marker and keep only the first line of the output.
    return decoded.replace('<|endoftext|>', '').split('\n')[0]
# BLIP (large) captioning processor and model, placed on the same device as
# the ViT-GPT2 model.
# NOTE(review): this is loaded at import time, but its only user in this file
# is generate_captions(), which nothing here calls. Consider dropping it or
# loading it lazily to cut startup time and memory.
blip_checkpoint = "Salesforce/blip-image-captioning-large"
blip_processor_large = AutoProcessor.from_pretrained(blip_checkpoint)
blip_model_large = BlipForConditionalGeneration.from_pretrained(blip_checkpoint).to(device)
def generate_caption(processor, model, image, tokenizer=None, use_float_16=False, max_length=50):
    """Caption ``image`` with a Hugging Face processor/model pair.

    Args:
        processor: Callable as ``processor(images=..., return_tensors="pt")``.
            It is also used to decode the output when ``tokenizer`` is None.
        model: Model exposing ``generate(pixel_values=..., max_length=...)``.
        image: Image accepted by ``processor``.
        tokenizer: Optional tokenizer to decode with instead of ``processor``.
        use_float_16: If True, cast the processed inputs to ``torch.float16``.
            Presumably the model must then also be in half precision; confirm.
        max_length: Maximum caption length (in tokens) forwarded to
            ``generate``. Defaults to 50, the previously hard-coded value.

    Returns:
        The decoded caption of the first generated sequence, with special
        tokens skipped.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    if use_float_16:
        inputs = inputs.to(torch.float16)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=max_length)
    # Prefer an explicitly supplied tokenizer; otherwise the processor decodes.
    decoder = tokenizer if tokenizer is not None else processor
    return decoder.batch_decode(generated_ids, skip_special_tokens=True)[0]
def generate_captions(image):
    """Return the BLIP-large caption for ``image``.

    This is a convenience wrapper around generate_caption(). It uses the
    module-level BLIP-large processor and model, with default decoding options.
    """
    return generate_caption(
        processor=blip_processor_large,
        model=blip_model_large,
        image=image,
    )
# Gradio UI: one image input and one caption output box.
# gr.Image / gr.Textbox replace the deprecated gr.inputs / gr.outputs
# namespaces (removed in Gradio 4). An empty gr.Image already yields None,
# so the old `optional=True` flag is no longer needed.
image_input = gr.Image(label="Upload your Image", type='pil')
output_1 = gr.Textbox(label="Caption - 1")

# example1.png .. example3.png, assumed to ship alongside this script.
examples = [f"example{i}.png" for i in range(1, 4)]
description = "Image caption Generator"
title = "Deep Learning and AI Intern Assignment for Listed Inc"
article = "Created By : Sravanth Kurmala"

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=output_1,
    examples=examples,
    title=title,
    description=description,
    article=article,
    theme="grass",
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch(debug=True)