import base64
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image
from transformers import (
    NougatProcessor,
    StoppingCriteria,
    StoppingCriteriaList,
    VisionEncoderDecoderModel,
)


class RunningVarTorch:
    """Keep a sliding window of 1-D tensors and report their per-row variance.

    Each :meth:`push` appends the tensor as a new column of a ``(rows, window)``
    buffer; once ``L`` columns are held, the oldest column is dropped.
    """

    def __init__(self, L=15, norm=False):
        # Column buffer of shape (rows, <= L); None until the first push.
        self.values = None
        # Maximum number of columns kept in the window.
        self.L = L
        # When True, variance() additionally divides by the current window length.
        self.norm = norm

    def push(self, x: torch.Tensor) -> None:
        """Append the 1-D tensor ``x`` as a new column, evicting the oldest when full.

        Raises:
            ValueError: if ``x`` is not one-dimensional.
        """
        if x.dim() != 1:
            raise ValueError(f"expected a 1-D tensor, got {x.dim()} dimensions")
        column = x[:, None]
        if self.values is None:
            self.values = column
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, column), 1)
        else:
            self.values = torch.cat((self.values[:, 1:], column), 1)

    def variance(self):
        """Return the per-row variance over the window, or None if nothing was pushed.

        With a single column, ``torch.var`` (unbiased by default) yields NaN.
        """
        if self.values is None:
            return None
        var = torch.var(self.values, 1)
        if self.norm:
            return var / self.values.shape[1]
        return var


class StoppingCriteriaScores(StoppingCriteria):
    """Stop generation once the per-step max score has become "flat".

    For every generation step this tracks, per batch row, the maximum score of
    the latest step, the running variance of that maximum (short window) and
    the running variance of that variance (``window_size`` steps). When the
    latter drops below ``threshold``, the row is scheduled/marked as stopped;
    generation ends once every row is marked stopped.

    Requires ``generate(..., output_scores=True, return_dict_in_generate=True)``
    so that ``scores`` is the accumulated tuple of per-step score tensors.
    """

    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        # Running variance of the per-step max score (normalised by window length).
        self.vars = RunningVarTorch(norm=True)
        # Running variance of the above variance over `window_size` steps.
        self.varvars = RunningVarTorch(L=window_size)
        # Per-row scheduled stop step (0 = not scheduled) and stopped flag.
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        # Number of generation steps observed so far.
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Update the statistics with the newest step and report whether to stop.

        ``**kwargs`` is accepted (and ignored) to match the ``StoppingCriteria``
        call signature.
        """
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        # Not enough history yet to judge flatness.
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    # NOTE(review): stop_inds[b] is always > size when it is set,
                    # so this marks the row stopped on the second consecutive
                    # sub-threshold step; confirm the comparison direction is intended.
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                # Scores moved again: cancel any pending stop for this row.
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0


class EndpointHandler:
    """Inference handler: base64-encoded page image in, Nougat-decoded text out."""

    def __init__(self, path="facebook/nougat-small") -> None:
        self.processor = NougatProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Decode the image under ``data["inputs"]`` and return the generated text.

        Args:
            data: request payload; ``"inputs"`` must hold the base64-encoded
                image (str or bytes). The key is popped from ``data``.

        Returns:
            The first decoded sequence after ``post_process_generation``.

        Raises:
            ValueError: if no base64 string/bytes is found under ``"inputs"``.
        """
        image = data.pop("inputs", data)
        if not isinstance(image, (str, bytes)):
            raise ValueError("expected a base64-encoded image under the 'inputs' key")
        # Normalise mode so RGBA/palette/grayscale uploads match the 3-channel preprocessing.
        image_data = Image.open(BytesIO(base64.b64decode(image))).convert("RGB")
        pixel_values = self.processor(image_data, return_tensors="pt").pixel_values
        outputs = self.model.generate(
            pixel_values.to(self.device),
            min_length=1,
            max_length=3584,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            # Needed by StoppingCriteriaScores, which reads the accumulated scores.
            output_scores=True,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
        )
        text = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
        return self.processor.post_process_generation(text, fix_markdown=False)