# pix2struct_embedding / handler.py
# Last commit: 4efd111 ("update req") by Jasper Lu; file size at that revision: 1.31 kB.
import io
import pdb
from typing import Dict, List, Any

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Pix2StructVisionModel
# Hugging Face Hub id loaded by EndpointHandler for both its processor and its
# vision encoder.
MODEL = "google/pix2struct-screen2words-large"
class EndpointHandler:
    """Inference-endpoint handler that turns a screenshot URL into an embedding.

    The image is encoded with the vision tower of the Pix2Struct checkpoint
    named by ``MODEL``. The encoder's last hidden states are then averaged over
    the real (non-padding) patches into a single flat vector.
    """

    # Seconds to wait on the image server. Without an explicit timeout,
    # ``requests`` never times out, so a stalled connection would hang the call.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the processor and the vision encoder.

        Args:
            path: Model directory handed in by the endpoint runtime. Unused:
                weights always come from the Hub id in ``MODEL``. (Earlier,
                now-removed code loaded from ``jasper-lu/pix2struct_embedding``.)
        """
        self.processor = AutoProcessor.from_pretrained(MODEL)
        # NOTE(review): presumably switches the image processor out of VQA mode
        # so it does not require a text prompt with each image — confirm
        # against the transformers Pix2Struct image-processor docs.
        self.processor.image_processor.is_vqa = False
        # Prefer the GPU, but fall back to CPU instead of crashing so the
        # handler can also run on machines without CUDA (e.g. local tests).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = Pix2StructVisionModel.from_pretrained(MODEL).to(device)
        self.model.eval()

    def __call__(self, data: Any) -> Dict[str, List[float]]:
        """Embed the image referenced by ``data``.

        Args:
            data: Request payload, either ``{"inputs": "<image URL>"}`` or the
                bare URL string. For a dict payload the ``"inputs"`` key is
                popped, so the caller's dict is mutated.

        Returns:
            ``{"embedding": [...]}``: the pooled encoder output as a flat list
            of floats.

        Raises:
            ValueError: if no image URL string can be found in ``data``.
            requests.HTTPError: if the image download returns an error status.
        """
        url = data.pop("inputs", data) if isinstance(data, dict) else data
        if not isinstance(url, str):
            raise ValueError(
                "expected an image URL string, either bare or under the 'inputs' key"
            )
        image = self._load_image(url)

        # The processor returns a BatchFeature, which supports ``.to()`` but has
        # no ``.cuda()``; move it to whichever device the model lives on.
        inputs = self.processor(images=image, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # NOTE: embeddings differ from the earlier unmasked mean for any image
        # that needed padding; re-embed stored vectors if you compare them.
        pooled = self._masked_mean(
            outputs["last_hidden_state"], inputs.get("attention_mask")
        )
        return {"embedding": pooled.flatten().tolist()}

    def _load_image(self, url):
        """Download ``url`` and open it as a PIL image."""
        # SECURITY NOTE(review): ``url`` comes straight from the request payload,
        # so the endpoint fetches any address it can reach (SSRF risk). Restrict
        # schemes/hosts here if untrusted callers can reach this endpoint.
        response = requests.get(url, timeout=self.REQUEST_TIMEOUT)
        # Surface HTTP failures directly instead of letting PIL fail on an
        # error page body.
        response.raise_for_status()
        # ``response.content`` (unlike ``response.raw``) has any gzip/deflate
        # Content-Encoding already undone, so PIL receives the actual image bytes.
        return Image.open(io.BytesIO(response.content))

    @staticmethod
    def _masked_mean(hidden: torch.Tensor, mask) -> torch.Tensor:
        """Average ``hidden`` over dim 1, counting only positions where ``mask`` is non-zero.

        ``hidden`` is expected to be (batch, patches, dim) and ``mask`` to be
        (batch, patches) or ``None``. With no mask this is a plain mean over
        dim 1.
        """
        if mask is None:
            return hidden.mean(dim=1)
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        total = (hidden * weights).sum(dim=1)
        # Guard against division by zero if a row has no valid patches.
        count = weights.sum(dim=1).clamp(min=1)
        return total / count
"""
handler = EndpointHandler()
output = handler({"inputs": "https://figma-staging-api.s3.us-west-2.amazonaws.com/images/a8c6a0cc-c022-4f3a-9fc5-ac8582c964dd"})
print(output)
"""