# File size: 1,672 bytes (file-viewer metadata; revision id shown as d21a961)
from transformers import NougatProcessor, VisionEncoderDecoderModel
import torch.cuda
import io
import base64
from PIL import Image
from typing import Dict, Any
class EndpointHandler():
    """Callable handler that transcribes a base64-encoded page image with Nougat.

    On construction it loads a ``NougatProcessor`` and a
    ``VisionEncoderDecoderModel`` from ``path`` and moves the model to CUDA
    when available, otherwise CPU. Calling the instance with a request
    payload returns the post-processed generated text.
    """

    def __init__(self, path="facebook/nougat-base"):
        """Load the processor and model found at ``path``.

        Args:
            path: Model identifier or local directory passed to
                ``from_pretrained`` for both the processor and the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = NougatProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        # Move the model that was just loaded; referencing a bare `model`
        # name here would raise NameError because no such name exists.
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the model on one image and return the generated text.

        Args:
            data: Request payload. Recognised keys (all are popped from
                the dict):
                ``"inputs"``: base64-encoded image bytes (required).
                ``"parameters"``: optional mapping of extra keyword
                arguments forwarded to ``model.generate``.
                ``"fix_markdown"``: optional flag forwarded to
                ``processor.post_process_generation``.

        Returns:
            ``{"generated_text": <post-processed text>}``.

        Raises:
            ValueError: If ``"inputs"`` is missing or empty.
        """
        encoded_image = data.pop("inputs", None)
        parameters = data.pop("parameters", None)
        fix_markdown = data.pop("fix_markdown", None)

        # Reject missing *and* empty payloads up front so the caller gets a
        # clear message instead of an image-decoding error further down.
        if not encoded_image:
            raise ValueError("Missing image.")

        # `parameters` is optional; copy it so the caller's mapping is not
        # mutated and so that unpacking never sees None.
        generation_kwargs = dict(parameters) if parameters else {}
        # Default to forbidding the <unk> token, but let the caller override
        # it rather than failing with a duplicate-keyword TypeError.
        generation_kwargs.setdefault(
            "bad_words_ids", [[self.processor.tokenizer.unk_token_id]]
        )

        binary_data = base64.b64decode(encoded_image)
        # NOTE(review): converting to RGB assumes the processor expects
        # 3-channel input; it keeps RGBA/palette uploads from reaching the
        # processor in an unexpected mode. Confirm against the processor.
        image = Image.open(io.BytesIO(binary_data)).convert("RGB")
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        # Autoregressively generate the token ids for the page.
        outputs = self.model.generate(
            inputs=pixel_values.to(self.device),
            **generation_kwargs,
        )
        # generate() yields either a tensor of ids or, when the caller asks
        # for a dict-style result, an object exposing `.sequences`.
        sequences = getattr(outputs, "sequences", outputs)

        # Decode the whole batch, then take the single sequence. Decoding
        # `sequences[0]` would treat every token id as its own sequence and
        # return only the first token's text.
        generated = self.processor.batch_decode(sequences, skip_special_tokens=True)[0]
        prediction = self.processor.post_process_generation(
            generated, fix_markdown=fix_markdown
        )
        return {"generated_text": prediction}