import base64
import io
from collections import defaultdict
from typing import Any, Dict

import torch
from PIL import Image
from transformers import (
    NougatProcessor,
    StoppingCriteria,
    StoppingCriteriaList,
    VisionEncoderDecoderModel,
)


class RunningVarTorch:
    """Sliding window over the last ``L`` pushed 1-D tensors, reporting their variance.

    Each pushed tensor is stored as one column of ``self.values``; once ``L``
    columns are held, the oldest column is dropped on every further push.
    """

    def __init__(self, L=15, norm=False):
        self.values = None  # 2-D tensor (len(x), <= L) after the first push
        self.L = L  # maximum number of columns kept in the window
        self.norm = norm  # when True, variance() divides by the current window length

    def push(self, x: torch.Tensor):
        """Append the 1-D tensor ``x`` as a new column, evicting the oldest if full."""
        assert x.dim() == 1
        column = x[:, None]
        if self.values is None:
            self.values = column
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, column), 1)
        else:
            self.values = torch.cat((self.values[:, 1:], column), 1)

    def variance(self):
        """Return the per-row variance across the window, or ``None`` before any push.

        ``torch.var`` is called with its default (unbiased) estimator, so a
        window holding a single column yields NaN.
        """
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        return torch.var(self.values, 1)


class StoppingCriteriaScores(StoppingCriteria):
    """Stop generation once the max-logit signal has been flat for long enough.

    Tracks, per batch element, the running variance of the highest score at each
    step, and then the running variance of that variance. After ``window_size``
    steps, a batch element whose variance-of-variance falls below ``threshold``
    is marked as stopped. Generation ends when every tracked element is stopped.
    """

    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)  # batch index -> step index (0 = unset)
        self.stopped = defaultdict(bool)  # batch index -> stopped flag
        self.size = 0  # number of generation steps observed so far
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        # Assumes `scores` is the sequence of per-step score tensors (generate is
        # called with output_scores=True) whose last element is (batch, vocab) -- TODO confirm.
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    # NOTE(review): stop_inds[b] is always set above self.size (for
                    # size < 4095), so this marks the element stopped on the next
                    # low-variance step. Kept as-is to match the Nougat authors'
                    # definition -- confirm the comparison direction is intended.
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0


class EndpointHandler:
    """Inference endpoint that transcribes a base64-encoded page image with Nougat."""

    def __init__(self, path="facebook/nougat-base"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = NougatProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Decode the image in ``data``, run generation and return the decoded text.

        Args:
            data: Request payload. The following keys are popped from it:
                - ``"inputs"``: base64-encoded image bytes (required).
                - ``"parameters"``: optional mapping of extra keyword arguments
                  for ``model.generate``; it may override the ``min_length`` and
                  ``bad_words_ids`` defaults. Omitted or ``None`` means no extras.
                - ``"fix_markdown"``: optional value forwarded as-is to
                  ``processor.post_process_generation`` (``None`` when omitted).

        Returns:
            ``{"generated_text": prediction}`` where ``prediction`` is the
            post-processed output of the first generated sequence.

        Raises:
            ValueError: if ``"inputs"`` is missing or empty. Malformed base64
                raises ``binascii.Error`` (a ``ValueError`` subclass). Errors from
                PIL on undecodable image bytes propagate unchanged.
        """
        image_b64 = data.pop("inputs", None)
        # A missing "parameters" key must not break the `**` expansion below.
        parameters = data.pop("parameters", None) or {}
        fix_markdown = data.pop("fix_markdown", None)
        if not image_b64:
            raise ValueError("Missing image.")

        binary_data = base64.b64decode(image_b64)
        # Normalise palette/greyscale/RGBA uploads to 3-channel RGB before
        # preprocessing (assumed to be what the Nougat image processor expects).
        image = Image.open(io.BytesIO(binary_data)).convert("RGB")
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        # Overridable defaults; caller-supplied parameters take precedence.
        generate_kwargs = {
            "min_length": 1,
            "bad_words_ids": [[self.processor.tokenizer.unk_token_id]],
            **parameters,
        }
        # Autoregressively generate tokens with the custom stopping criteria
        # (as defined by the Nougat authors), which reads the per-step scores.
        outputs = self.model.generate(
            pixel_values=pixel_values.to(self.model.device),
            return_dict_in_generate=True,
            output_scores=True,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
            **generate_kwargs,
        )
        generated = self.processor.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )[0]
        prediction = self.processor.post_process_generation(
            generated, fix_markdown=fix_markdown
        )
        return {"generated_text": prediction}