---
language:
- he
base_model:
- naver-clova-ix/donut-base
---
|
|
|
|
|
```python
|
# Example: run the He-Donut checkpoint on a single image fetched over HTTP
# and print the decoded text.
import re
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers import VisionEncoderDecoderConfig

# Hub repository that provides the config, the processor and the weights.
MODEL_ID = "jjjlangem/He-Donut"
# Upper bound on the generated sequence length (prompt token included).
MAX_LENGTH = 768

url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRCeH216oW6FXeTpN4ijvakW8_frP3vnCBIKQ&s"

# Download the sample image. The timeout keeps the script from hanging on a
# dead connection, and raise_for_status() stops an HTTP error page from being
# handed to PIL as if it were image data.
response = requests.get(url, timeout=30)
response.raise_for_status()

# Normalise to 3-channel RGB so palette/grayscale/RGBA downloads are handled
# the same way by the image processor.
img = Image.open(BytesIO(response.content)).convert("RGB")
img.show()  # optional preview; opens the system image viewer

config = VisionEncoderDecoderConfig.from_pretrained(MODEL_ID)
processor = DonutProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID, config=config)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

with torch.no_grad():
    pixel_values = processor(img, random_padding=False, return_tensors="pt").pixel_values
    batch_size = pixel_values.shape[0]

    # Seed every sequence in the batch with the decoder start token.
    decoder_input_ids = torch.full(
        (batch_size, 1),
        model.config.decoder_start_token_id,
        device=device,
    )

    # Greedy decoding (num_beams=1). `early_stopping` only affects beam
    # search, so it is not passed here.
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids,
        max_length=MAX_LENGTH,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

tokenizer = processor.tokenizer
predictions = []
for seq in tokenizer.batch_decode(outputs.sequences):
    # Strip the end-of-sequence, padding and beginning-of-sequence markers.
    seq = seq.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "").replace(tokenizer.bos_token, "")
    # Remove the first remaining <...> tag (count=1), then trim whitespace.
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()
    predictions.append(seq)

print(predictions)
|
```