---
library_name: transformers
language:
- ko
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->



## Model Details
A model that recognizes line-level text containing mathematical expressions.  
It was fine-tuned on a Korean + LaTeX dataset.

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

```python
from PIL import Image
import glob
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import IPython.display as ipd  # shows images inline in Jupyter

## Prepare the images (one line of text per image)
img_path_list = sorted(glob.glob('images/mathematical_expression_2-*.png'))
img_list = [Image.open(img_path).convert("RGB") for img_path in img_path_list]

## Prepare the model and processor
model_path = 'models/math_ocr'
processor = TrOCRProcessor.from_pretrained(model_path)
model = VisionEncoderDecoderModel.from_pretrained(model_path)
model.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Resize inputs to the resolution the vision encoder was configured for
# (`image_processor` replaces the deprecated `feature_extractor` attribute)
processor.image_processor.size = model.config.encoder.image_size

# Beam-search decoding settings
gc = model.generation_config
gc.max_length = 128
gc.early_stopping = True
gc.no_repeat_ngram_size = 3
gc.length_penalty = 2.0
gc.num_beams = 4
gc.eos_token_id = processor.tokenizer.sep_token_id  # stop at the tokenizer's SEP token

## TrOCR inference
pixel_values = processor(images=img_list, return_tensors="pt").pixel_values
generated_ids = model.generate(
    pixel_values.to(model.device),
    pad_token_id=processor.tokenizer.pad_token_id,
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)

for img, text in zip(img_list, generated_text):
    ipd.display(img)
    print(text)
```
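`sorted()` orders the paths returned by `glob.glob` lexicographically, so if the line images are numbered without zero padding, `…-10.png` comes before `…-2.png` and the printed results no longer follow the original line order. Below is a minimal sketch of a natural-sort key, assuming the file names differ only in unpadded integer suffixes; `natural_key` and the example paths are hypothetical and not part of this model's repository.

```python
import re
from pathlib import Path


def natural_key(path):
    """Hypothetical helper: split a file name into text and integer chunks
    so that 'expr-2.png' sorts before 'expr-10.png'."""
    name = Path(path).name
    # A capturing group makes re.split keep the digit runs, so the result
    # always alternates text, number, text, ... and compares position by position
    return [int(tok) if tok.isdigit() else tok.lower()
            for tok in re.split(r"(\d+)", name)]


# Hypothetical file names, for illustration only
paths = [
    "images/mathematical_expression_2-10.png",
    "images/mathematical_expression_2-2.png",
    "images/mathematical_expression_2-1.png",
]
print(sorted(paths))                   # lexicographic: -1, -10, -2
print(sorted(paths, key=natural_key))  # natural:       -1, -2, -10
```

To apply it to the example above, pass it to the existing call: `sorted(glob.glob('images/mathematical_expression_2-*.png'), key=natural_key)`.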