kevin-pek committed on
Commit 4b9dac6
1 Parent(s): 82eb988

add stopping criteria for inference

Files changed (2):
  1. README.md +6 -2
  2. handler.py +66 -4
README.md CHANGED
````diff
@@ -6,7 +6,11 @@ tags:
 pipeline_tag: image-to-text
 ---
 
-# Nougat model, base-sized version
+# Nougat Huggingface Api
+
+This repo adds the necessary handlers to allow the nougat model to be used with the huggingface hosted inference api.
+
+The inference code is adapted from the [example](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Nougat/Inference_with_Nougat_to_read_scientific_PDFs.ipynb) provided by huggingface.
 
 Nougat model trained on PDF-to-markdown. It was introduced in the paper [Nougat: Neural Optical Understanding for Academic Documents](https://arxiv.org/abs/2308.13418) by Blecher et al. and first released in [this repository](https://github.com/facebookresearch/nougat/tree/main).
 
@@ -45,4 +49,4 @@ We refer to the [docs](https://huggingface.co/docs/transformers/main/en/model_do
 archivePrefix={arXiv},
 primaryClass={cs.LG}
 }
-```
+```
````
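For reference, a deployed endpoint built from this repo could be exercised roughly as follows. This is only a sketch: the endpoint URL and token are placeholders, and the payload shape (a base64-encoded page image under `"inputs"`) is an assumption based on the handler's imports (`base64`, `BytesIO`); the handler.py diff below does not show the `__call__` body that actually parses the request.

```python
import base64
import requests

# Placeholders: substitute your own Inference Endpoint URL and access token.
ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
HF_TOKEN = "hf_..."

# Encode a rendered PDF page as base64, matching the handler's base64/BytesIO imports.
with open("page.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# Assumed payload shape; adjust to whatever the handler's __call__ actually expects.
response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json={"inputs": image_b64},
)
print(response.json())  # the generated markdown for the page
```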
handler.py CHANGED
```diff
@@ -1,10 +1,69 @@
 from io import BytesIO
 from typing import Dict, Any
-from transformers import NougatProcessor, VisionEncoderDecoderModel
+from transformers import NougatProcessor, VisionEncoderDecoderModel, StoppingCriteria, StoppingCriteriaList
 from transformers.image_utils import base64
 from PIL import Image
 import torch
 
+
+class RunningVarTorch:
+    def __init__(self, L=15, norm=False):
+        self.values = None
+        self.L = L
+        self.norm = norm
+
+    def push(self, x: torch.Tensor):
+        assert x.dim() == 1
+        if self.values is None:
+            self.values = x[:, None]
+        elif self.values.shape[1] < self.L:
+            self.values = torch.cat((self.values, x[:, None]), 1)
+        else:
+            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)
+
+    def variance(self):
+        if self.values is None:
+            return
+        if self.norm:
+            return torch.var(self.values, 1) / self.values.shape[1]
+        else:
+            return torch.var(self.values, 1)
+
+
+class StoppingCriteriaScores(StoppingCriteria):
+    def __init__(self, threshold: float = 0.015, window_size: int = 200):
+        super().__init__()
+        self.threshold = threshold
+        self.vars = RunningVarTorch(norm=True)
+        self.varvars = RunningVarTorch(L=window_size)
+        self.stop_inds = defaultdict(int)
+        self.stopped = defaultdict(bool)
+        self.size = 0
+        self.window_size = window_size
+
+    @torch.no_grad()
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+        last_scores = scores[-1]
+        self.vars.push(last_scores.max(1)[0].float().cpu())
+        self.varvars.push(self.vars.variance())
+        self.size += 1
+        if self.size < self.window_size:
+            return False
+
+        varvar = self.varvars.variance()
+        for b in range(len(last_scores)):
+            if varvar[b] < self.threshold:
+                if self.stop_inds[b] > 0 and not self.stopped[b]:
+                    self.stopped[b] = self.stop_inds[b] >= self.size
+                else:
+                    self.stop_inds[b] = int(
+                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
+                    )
+            else:
+                self.stop_inds[b] = 0
+                self.stopped[b] = False
+        return all(self.stopped.values()) and len(self.stopped) > 0
+
 class EndpointHandler():
     def __init__(self, path="facebook/nougat-base") -> None:
         self.processor = NougatProcessor.from_pretrained(path)
@@ -21,11 +80,14 @@ class EndpointHandler():
         outputs = self.model.generate(
             pixel_values.to(self.device),
             min_length=1,
-            max_new_tokens=30,
-            bad_words_ids=[[self.processor.tokenizer.unk_token_id]]
+            max_length=3584,
+            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
+            return_dict_in_generate=True,
+            output_scores=True,
+            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()])
         )
 
-        text = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
+        text = self.processor.batch_decode(outputs[0], skip_special_tokens=True)[0]
         text = self.processor.post_process_generation(text, fix_markdown=False)
 
         return text
```
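To sanity-check the new stopping criteria outside the endpoint, the same generation call can be run locally. The sketch below mirrors the `generate()` arguments added above; the test image path and the import of `StoppingCriteriaScores` from handler.py are assumptions, not part of this commit. Two details worth flagging: `StoppingCriteriaScores` uses `defaultdict`, so handler.py also needs `from collections import defaultdict`, which the diff above does not show; and `return_dict_in_generate=True` together with `output_scores=True` appears to be what gives the criteria access to the per-step scores it reads via `scores[-1]`, since `generate()` only accumulates them when both flags are set.

```python
import torch
from PIL import Image
from transformers import NougatProcessor, VisionEncoderDecoderModel, StoppingCriteriaList

from handler import StoppingCriteriaScores  # assumes handler.py is importable

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = NougatProcessor.from_pretrained("facebook/nougat-base")
model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base").to(device)

# Any rendered PDF page works here; "page.png" is a placeholder test image.
image = Image.open("page.png").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values

outputs = model.generate(
    pixel_values.to(device),
    min_length=1,
    max_length=3584,
    bad_words_ids=[[processor.tokenizer.unk_token_id]],
    return_dict_in_generate=True,  # needed so the criteria can see per-step scores
    output_scores=True,
    stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
)

# outputs[0] holds the generated token sequences; decode and clean up the markdown.
text = processor.batch_decode(outputs[0], skip_special_tokens=True)[0]
print(processor.post_process_generation(text, fix_markdown=False))
```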