jjjlangem committed on
Commit
74e6743
·
verified ·
1 Parent(s): 5183bd8

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +60 -0
README.md CHANGED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - he
4
+ base_model:
5
+ - naver-clova-ix/donut-base
6
+ ---
# Inference example for the He-Donut checkpoint: download one image, run
# greedy decoding with the vision encoder-decoder model, and print the
# decoded text for each image in the batch.
import re
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import VisionEncoderDecoderConfig
from transformers import DonutProcessor, VisionEncoderDecoderModel


MODEL_ID = 'jjjlangem/He-Donut'

# Matches the first "<...>" tag left in a decoded sequence.
FIRST_TAG = re.compile(r"<.*?>")

url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRCeH216oW6FXeTpN4ijvakW8_frP3vnCBIKQ&s"

# Bound the wait and fail on an HTTP error status; otherwise an error page
# would be handed to PIL and surface as an unrelated "cannot identify image"
# failure.
response = requests.get(url, timeout=30)
response.raise_for_status()

# Normalise to three channels so palette, greyscale and RGBA downloads all
# reach the processor in the same layout.
img = Image.open(BytesIO(response.content)).convert("RGB")
img.show()

config = VisionEncoderDecoderConfig.from_pretrained(MODEL_ID)
processor = DonutProcessor.from_pretrained(MODEL_ID)
# Reuse the config loaded above instead of letting it be fetched a second time.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID, config=config)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

tokenizer = processor.tokenizer

pixel_values = processor(img, random_padding=False, return_tensors="pt").pixel_values
batch_size = pixel_values.shape[0]

# Seed generation with the decoder start token, one row per image.
decoder_input_ids = torch.full(
    (batch_size, 1),
    model.config.decoder_start_token_id,
    device=device,
)

with torch.no_grad():
    # num_beams=1 selects greedy decoding; `early_stopping` only affects beam
    # search, so it is not passed.
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids,
        max_length=768,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # Forbid the unknown token from being generated.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

# A tokenizer may leave some special tokens undefined (None); skip those so
# str.replace is never called with a non-string argument.
special_tokens = [
    token
    for token in (tokenizer.eos_token, tokenizer.pad_token, tokenizer.bos_token)
    if token
]

predictions = []
for seq in tokenizer.batch_decode(outputs.sequences):
    for token in special_tokens:
        seq = seq.replace(token, "")
    # Drop only the first remaining tag (presumably the task start token;
    # confirm against the checkpoint's vocabulary) and keep any later tags.
    seq = FIRST_TAG.sub("", seq, count=1).strip()
    predictions.append(seq)

print(predictions)