# ImageProcessing / image_summary.py
# Author: ipvikas — last commit f812ffa (verified): "Update image_summary.py"
import os
import gradio as gr
from PIL import Image
import requests
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel

# Hugging Face checkpoint used for captioning. Its name suggests a ViT image
# encoder paired with a GPT-2 text decoder trained on COCO (English) — confirm
# on the model card. The image processor, tokenizer and model are all loaded
# from this same checkpoint so preprocessing and vocabulary match the weights.
repo_name = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTImageProcessor.from_pretrained(repo_name)
tokenizer = AutoTokenizer.from_pretrained(repo_name)
model = VisionEncoderDecoderModel.from_pretrained(repo_name)
def get_quote(image, *, max_length=16, num_beams=4):
    """Generate a caption describing *image*.

    Uses the module-level ``feature_extractor``, ``model`` and ``tokenizer``.

    Args:
        image: The input image, as delivered by the Gradio input component
            (a PIL image, since the interface uses ``type="pil"``). ``None``
            (nothing uploaded) yields an empty list instead of an error.
        max_length: Maximum length, in tokens, of each generated caption.
        num_beams: Number of beams used for beam-search decoding.

    Returns:
        A list of caption strings (one per image in the batch), each with
        surrounding whitespace stripped.
    """
    if image is None:
        return []
    # ViT preprocessing expects 3-channel input; PNG uploads can be RGBA or
    # palette-based and grayscale images have one channel, so normalize to RGB.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

    # Autoregressive decoding with beam search. With return_dict_in_generate=True
    # generate() returns an output object; read the token ids via the explicit
    # ``sequences`` field rather than by position.
    outputs = model.generate(
        pixel_values,
        max_length=max_length,
        num_beams=num_beams,
        return_dict_in_generate=True,
    )

    preds = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    return [pred.strip() for pred in preds]
# Gradio UI for image captioning: upload an image, get a generated sentence.
title = "Sentence, listing all the items present in the image file"
description = "Upload an image file and get text from it"
# Example inputs shown under the interface; each inner list holds one value
# per input component (here: a single image path).
examples = [["english.png"], ["Parag_Letter_j.jpg"]]


def _caption_text(image):
    """Return the captions for *image* as plain text, one caption per line.

    get_quote() returns a list of strings; passed straight to a "text" output
    it would be displayed as a Python list literal (e.g. "['a dog ...']").
    """
    return "\n".join(get_quote(image))


image_summary_demo = gr.Interface(
    fn=_caption_text,
    inputs=gr.Image(type="pil"),
    outputs=["text"],
    title=title,
    description=description,
    cache_examples=False,
    examples=examples,
)
# Enable request queuing via .queue(); the Interface(enable_queue=True)
# argument is deprecated in Gradio 3.x and gone in 4.x.
image_summary_demo.queue()
# To launch this demo on its own:
# if __name__ == "__main__":
#     image_summary_demo.launch(debug=True)