from transformers import VisionEncoderDecoderConfig
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch 
import re
import requests
from PIL import Image
from io import BytesIO


# Donut OCR demo: download an image, run the jjjlangem/He-Donut
# VisionEncoderDecoder model on it with greedy decoding, and print the
# cleaned-up decoded text for each sequence in the batch.

MODEL_ID = "jjjlangem/He-Donut"
REQUEST_TIMEOUT_S = 30  # seconds; requests.get waits forever by default
MAX_LENGTH = 768  # upper bound on the generated sequence length (tokens)

url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRCeH216oW6FXeTpN4ijvakW8_frP3vnCBIKQ&s"

response = requests.get(url, timeout=REQUEST_TIMEOUT_S)
# Fail loudly on HTTP errors instead of handing an HTML error page to PIL.
response.raise_for_status()
# Normalise to 3-channel RGB: downloaded images may be palette, RGBA or
# grayscale, which the image processor may not handle as intended.
img = Image.open(BytesIO(response.content)).convert("RGB")
img.show()

config = VisionEncoderDecoderConfig.from_pretrained(MODEL_ID)
processor = DonutProcessor.from_pretrained(MODEL_ID)
# Reuse the config loaded above instead of leaving it unused.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID, config=config)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference only (from_pretrained already does this; made explicit)

tokenizer = processor.tokenizer
# Special tokens stripped from the decoded text, in the original order
# (eos, pad, bos). Tokens the tokenizer does not define (None) are skipped,
# because str.replace(None, "") would raise TypeError.
special_tokens = [
    tok
    for tok in (tokenizer.eos_token, tokenizer.pad_token, tokenizer.bos_token)
    if tok
]

with torch.no_grad():
    pixel_values = processor(
        img, random_padding=False, return_tensors="pt"
    ).pixel_values.to(device)
    batch_size = pixel_values.shape[0]
    # Seed every sequence in the batch with the decoder start token.
    decoder_input_ids = torch.full(
        (batch_size, 1), model.config.decoder_start_token_id, device=device
    )

    # early_stopping is omitted: it only affects beam search, and with
    # num_beams=1 it has no effect other than a generation warning.
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=MAX_LENGTH,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,  # greedy decoding
        bad_words_ids=[[tokenizer.unk_token_id]],  # never emit the unknown token
        return_dict_in_generate=True,
    )

predictions = []
for seq in tokenizer.batch_decode(outputs.sequences):
    for tok in special_tokens:
        seq = seq.replace(tok, "")
    # Remove the first remaining <...> tag (presumably the task-prompt token)
    # and surrounding whitespace.
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()
    predictions.append(seq)


print(predictions)
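# --- Hugging Face model-card metadata for jjjlangem/He-Donut (copied from the model page) ---
# Downloads last month: 42
# Weights format: Safetensors
# Model size: 211M params
# Tensor types: I64, F32
# Inference API: the serverless Inference API does not yet support transformers
#   models for this pipeline type.
# Model tree for jjjlangem/He-Donut: listed as one of the 367 fine-tunes of its base model.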