File size: 1,966 Bytes
1d305df
 
 
16cc5c7
 
 
 
1d305df
 
16cc5c7
1d305df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31e1d2c
1d305df
31e1d2c
1d305df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import gradio as gr

from transformers import ViTFeatureExtractor
from transformers import VisionEncoderDecoderModel

# Only the model classes are needed at this point. The captioning checkpoint
# that get_quote() actually uses is loaded from the Hub further down.
# Nothing is instantiated here on purpose. Building a ViT+BERT encoder-decoder
# (with randomly initialised cross-attention), saving it to ./vit-bert and
# reloading it cost several large downloads plus a disk write at startup.
# Every object produced that way was then overwritten before use.

# --- Captioning checkpoint consumed by get_quote() ---
# One Hub repo supplies the image preprocessor, the text tokenizer and the
# encoder-decoder weights. The repo name suggests a ViT encoder + GPT-2 decoder
# trained on COCO English captions (not verified here).
from transformers import AutoTokenizer

repo_name = "ydshieh/vit-gpt2-coco-en"
feature_extractor, tokenizer, model = (
    component_cls.from_pretrained(repo_name)
    for component_cls in (ViTFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel)
)

def get_quote(image_1):
    """Generate a caption for an image with the module-level encoder-decoder model.

    Args:
        image_1: The image to caption. Accepted forms:
            - a NumPy array, which is what Gradio's ``"image"`` input delivers by default;
            - a ``PIL.Image.Image``;
            - a path or binary file object readable by ``PIL.Image.open``.

    Returns:
        list[str]: The decoded captions with surrounding whitespace stripped,
        one entry per generated sequence.

    Raises:
        ValueError: If ``image_1`` is ``None`` (e.g. the Gradio input was empty).
    """
    # `Image` was referenced but never imported at module level, so import it
    # here. Pillow presumably is already installed as a dependency of the
    # transformers image preprocessing and of Gradio -- confirm in requirements.
    from PIL import Image

    if image_1 is None:
        raise ValueError("No image provided")

    # Check PIL first: PIL images also expose __array_interface__.
    if isinstance(image_1, Image.Image):
        image = image_1
    elif hasattr(image_1, "__array_interface__"):
        # NumPy arrays (Gradio's default image payload) cannot go through Image.open.
        image = Image.fromarray(image_1)
    else:
        image = Image.open(image_1, mode="r")

    # Normalise grayscale / RGBA / palette inputs to the 3 channels ViT expects.
    image = image.convert("RGB")

    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    # Beam search with 4 beams; decoder output is capped at max_length=16 tokens.
    outputs = model.generate(
        pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
    )

    # With return_dict_in_generate=True the token ids live in `.sequences`.
    preds = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    return [pred.strip() for pred in preds]


# Gradio UI: upload an image and receive the generated caption as text.
title = "Image to text generation"

demo = gr.Interface(
    fn=get_quote,
    inputs="image",
    outputs=["text"],
    title=title,
    description="Import an image file and get text from it",
    cache_examples=False,
)

if __name__ == "__main__":
    # Launch only when run as a script. `demo` stays bound to the Interface
    # itself, not to launch()'s return value, so importing this module
    # (e.g. by Gradio's reload mode) does not start a server.
    demo.launch(debug=True)