File size: 2,067 Bytes
74e6743 f8f1873 74e6743 f8f1873 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
---
language:
- he
base_model:
- naver-clova-ix/donut-base
---
```python
from io import BytesIO
import re

import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Download a sample image and display it. raise_for_status() fails fast on a
# bad HTTP response instead of handing junk bytes to PIL.
url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRCeH216oW6FXeTpN4ijvakW8_frP3vnCBIKQ&s"
response = requests.get(url)
response.raise_for_status()
img = Image.open(BytesIO(response.content))
img.show()

# Load the fine-tuned Donut processor and model. (The separate
# VisionEncoderDecoderConfig download in the original was never used, so it
# has been dropped — from_pretrained already loads the model's config.)
processor = DonutProcessor.from_pretrained('jjjlangem/He-Donut')
model = VisionEncoderDecoderModel.from_pretrained('jjjlangem/He-Donut')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # inference mode: disables dropout etc.

with torch.no_grad():
    pixel_values = processor(img, random_padding=False, return_tensors="pt").pixel_values
    batch_size = pixel_values.shape[0]
    # Seed generation with the decoder start token for every batch item.
    # Token ids must be int64, so pin dtype explicitly.
    decoder_input_ids = torch.full((batch_size, 1),
                                   model.config.decoder_start_token_id,
                                   dtype=torch.long,
                                   device=device)
    outputs = model.generate(pixel_values.to(device),
                             decoder_input_ids=decoder_input_ids,
                             max_length=768,
                             early_stopping=True,
                             pad_token_id=processor.tokenizer.pad_token_id,
                             eos_token_id=processor.tokenizer.eos_token_id,
                             use_cache=True,
                             num_beams=1,
                             bad_words_ids=[[processor.tokenizer.unk_token_id]],
                             return_dict_in_generate=True)

# Post-process: strip special tokens, then drop the first task-prompt tag
# (e.g. "<s_...>") that Donut emits before the transcription.
predictions = []
for seq in processor.tokenizer.batch_decode(outputs.sequences):
    seq = (seq.replace(processor.tokenizer.eos_token, "")
              .replace(processor.tokenizer.pad_token, "")
              .replace(processor.tokenizer.bos_token, ""))
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()
    predictions.append(seq)
print(predictions)
``` |