import gradio as gr
import requests
from PIL import Image
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel

# --- Reference: initializing a vision-encoder-decoder model from scratch ---
# A VisionEncoderDecoderModel can be built from a pretrained ViT encoder and a
# pretrained BERT decoder. Note that the cross-attention layers are randomly
# initialized, so the combined model must be fine-tuned before it is useful.
# This app loads an already fine-tuned checkpoint instead (see below), so the
# snippet is kept only for reference:
#
#   model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
#       "google/vit-base-patch16-224-in21k", "bert-base-uncased"
#   )
#   model.save_pretrained("./vit-bert")  # save the model after fine-tuning ...
#   model = VisionEncoderDecoderModel.from_pretrained("./vit-bert")  # ... and reload it

# Load a fine-tuned ViT + GPT-2 image-captioning checkpoint from the Hub.
repo_name = "ydshieh/vit-gpt2-coco-en"
feature_extractor = ViTFeatureExtractor.from_pretrained(repo_name)
tokenizer = AutoTokenizer.from_pretrained(repo_name)
model = VisionEncoderDecoderModel.from_pretrained(repo_name)


def get_quote(image):
    # Preprocess the PIL image into the pixel values the ViT encoder expects.
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

    # Autoregressively generate caption token ids with beam search.
    generated_ids = model.generate(pixel_values, max_length=16, num_beams=4)

    # Decode the generated token ids into text; return the single caption as a
    # string so it displays cleanly in the text output component.
    preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return preds[0].strip()


title = "Image to text generation"
demo = gr.Interface(
    fn=get_quote,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title=title,
    description="Upload an image file and get a generated caption for it",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)
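
# --- Optional smoke test (a minimal added sketch, not part of the original app) ---
# Captions the sample COCO image from the original's commented-out code without
# going through the Gradio UI. `smoke_test` is a hypothetical helper added for
# illustration; call it from a REPL, e.g.:
#   >>> from app import smoke_test  # assumes this file is saved as app.py
#   >>> smoke_test()
def smoke_test():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # Download the image and run it through the same captioning function the
    # UI uses; requests and PIL are already imported at the top of the file.
    with Image.open(requests.get(url, stream=True).raw) as image:
        print(get_quote(image))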