kevin-pek
commited on
Commit
•
4b9dac6
1
Parent(s):
82eb988
add stopping criteria for inference
Browse files- README.md +6 -2
- handler.py +66 -4
README.md
CHANGED
@@ -6,7 +6,11 @@ tags:
|
|
6 |
pipeline_tag: image-to-text
|
7 |
---
|
8 |
|
9 |
-
# Nougat
|
|
|
|
|
|
|
|
|
10 |
|
11 |
Nougat model trained on PDF-to-markdown. It was introduced in the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Blecher et al. and first released in [this repository](https://github.com/facebookresearch/nougat/tree/main).
|
12 |
|
@@ -45,4 +49,4 @@ We refer to the [docs](https://huggingface.co/docs/transformers/main/en/model_do
|
|
45 |
archivePrefix={arXiv},
|
46 |
primaryClass={cs.LG}
|
47 |
}
|
48 |
-
```
|
|
|
6 |
pipeline_tag: image-to-text
|
7 |
---
|
8 |
|
9 |
+
# Nougat Huggingface Api
|
10 |
+
|
11 |
+
This repo adds the necessary handlers to allow the nougat model to be used with the huggingface hosted inference api.
|
12 |
+
|
13 |
+
The inference code is adapted from the [example](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Nougat/Inference_with_Nougat_to_read_scientific_PDFs.ipynb) provided by huggingface.
|
14 |
|
15 |
Nougat model trained on PDF-to-markdown. It was introduced in the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Blecher et al. and first released in [this repository](https://github.com/facebookresearch/nougat/tree/main).
|
16 |
|
|
|
49 |
archivePrefix={arXiv},
|
50 |
primaryClass={cs.LG}
|
51 |
}
|
52 |
+
```
|
handler.py
CHANGED
@@ -1,10 +1,69 @@
|
|
1 |
from io import BytesIO
|
2 |
from typing import Dict, Any
|
3 |
-
from transformers import NougatProcessor, VisionEncoderDecoderModel
|
4 |
from transformers.image_utils import base64
|
5 |
from PIL import Image
|
6 |
import torch
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
class EndpointHandler():
|
9 |
def __init__(self, path="facebook/nougat-base") -> None:
|
10 |
self.processor = NougatProcessor.from_pretrained(path)
|
@@ -21,11 +80,14 @@ class EndpointHandler():
|
|
21 |
outputs = self.model.generate(
|
22 |
pixel_values.to(self.device),
|
23 |
min_length=1,
|
24 |
-
|
25 |
-
bad_words_ids=[[self.processor.tokenizer.unk_token_id]]
|
|
|
|
|
|
|
26 |
)
|
27 |
|
28 |
-
text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
29 |
text = self.processor.post_process_generation(text, fix_markdown=False)
|
30 |
|
31 |
return text
|
|
|
1 |
from io import BytesIO
|
2 |
from typing import Dict, Any
|
3 |
+
from transformers import NougatProcessor, VisionEncoderDecoderModel, StoppingCriteria, StoppingCriteriaList
|
4 |
from transformers.image_utils import base64
|
5 |
from PIL import Image
|
6 |
import torch
|
7 |
|
8 |
+
|
9 |
+
class RunningVarTorch:
|
10 |
+
def __init__(self, L=15, norm=False):
|
11 |
+
self.values = None
|
12 |
+
self.L = L
|
13 |
+
self.norm = norm
|
14 |
+
|
15 |
+
def push(self, x: torch.Tensor):
|
16 |
+
assert x.dim() == 1
|
17 |
+
if self.values is None:
|
18 |
+
self.values = x[:, None]
|
19 |
+
elif self.values.shape[1] < self.L:
|
20 |
+
self.values = torch.cat((self.values, x[:, None]), 1)
|
21 |
+
else:
|
22 |
+
self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)
|
23 |
+
|
24 |
+
def variance(self):
|
25 |
+
if self.values is None:
|
26 |
+
return
|
27 |
+
if self.norm:
|
28 |
+
return torch.var(self.values, 1) / self.values.shape[1]
|
29 |
+
else:
|
30 |
+
return torch.var(self.values, 1)
|
31 |
+
|
32 |
+
|
33 |
+
class StoppingCriteriaScores(StoppingCriteria):
|
34 |
+
def __init__(self, threshold: float = 0.015, window_size: int = 200):
|
35 |
+
super().__init__()
|
36 |
+
self.threshold = threshold
|
37 |
+
self.vars = RunningVarTorch(norm=True)
|
38 |
+
self.varvars = RunningVarTorch(L=window_size)
|
39 |
+
self.stop_inds = defaultdict(int)
|
40 |
+
self.stopped = defaultdict(bool)
|
41 |
+
self.size = 0
|
42 |
+
self.window_size = window_size
|
43 |
+
|
44 |
+
@torch.no_grad()
|
45 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
46 |
+
last_scores = scores[-1]
|
47 |
+
self.vars.push(last_scores.max(1)[0].float().cpu())
|
48 |
+
self.varvars.push(self.vars.variance())
|
49 |
+
self.size += 1
|
50 |
+
if self.size < self.window_size:
|
51 |
+
return False
|
52 |
+
|
53 |
+
varvar = self.varvars.variance()
|
54 |
+
for b in range(len(last_scores)):
|
55 |
+
if varvar[b] < self.threshold:
|
56 |
+
if self.stop_inds[b] > 0 and not self.stopped[b]:
|
57 |
+
self.stopped[b] = self.stop_inds[b] >= self.size
|
58 |
+
else:
|
59 |
+
self.stop_inds[b] = int(
|
60 |
+
min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
|
61 |
+
)
|
62 |
+
else:
|
63 |
+
self.stop_inds[b] = 0
|
64 |
+
self.stopped[b] = False
|
65 |
+
return all(self.stopped.values()) and len(self.stopped) > 0
|
66 |
+
|
67 |
class EndpointHandler():
|
68 |
def __init__(self, path="facebook/nougat-base") -> None:
|
69 |
self.processor = NougatProcessor.from_pretrained(path)
|
|
|
80 |
outputs = self.model.generate(
|
81 |
pixel_values.to(self.device),
|
82 |
min_length=1,
|
83 |
+
max_length=3584,
|
84 |
+
bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
|
85 |
+
return_dict_in_generate=True,
|
86 |
+
output_scores=True,
|
87 |
+
stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()])
|
88 |
)
|
89 |
|
90 |
+
text = self.processor.batch_decode(outputs[0], skip_special_tokens=True)[0]
|
91 |
text = self.processor.post_process_generation(text, fix_markdown=False)
|
92 |
|
93 |
return text
|