from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BlipForConditionalGeneration,
    BlipProcessor,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path=""):
        # Load the BLIP captioning model and its processor.
        self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        self.blip_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base"
        ).to(device)
        self.blip_model.eval()
        # Load the Flan model and tokenizer from the endpoint's repository path.
        self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device)
        self.flan_tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        # Unpack the request payload.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Preprocess the raw image bytes with the BLIP processor.
        raw_images = [Image.open(BytesIO(_img)).convert("RGB") for _img in inputs]
        processed_image = self.blip_processor(images=raw_images, return_tensors="pt")
        processed_image["pixel_values"] = processed_image["pixel_values"].to(device)
        processed_image = {**processed_image, **parameters}

        # Generate one caption per image with BLIP.
        with torch.no_grad():
            out = self.blip_model.generate(**processed_image)
        captions = self.blip_processor.batch_decode(out, skip_special_tokens=True)

        # Tokenize the captions for Flan and move them to the same device as the model.
        input_ids = self.flan_tokenizer(
            captions, return_tensors="pt", padding=True
        ).input_ids.to(device)

        # Generate text with Flan, reusing the request's generation parameters.
        with torch.no_grad():
            outputs = self.flan_model.generate(input_ids, **parameters)

        # Decode every generated sequence, one entry per input image.
        predictions = self.flan_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return [{"generated_text": prediction} for prediction in predictions]
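

# A minimal local smoke test, sketched as an assumption rather than part of the
# endpoint contract: it mirrors the payload shape {"inputs": [<raw image bytes>],
# "parameters": {...}} that __call__ above expects. The model id
# "google/flan-t5-base" and the file "cat.jpg" are placeholders to swap for the
# actual repository path and a real local image.
if __name__ == "__main__":
    with open("cat.jpg", "rb") as f:  # hypothetical local test image
        image_bytes = f.read()

    handler = EndpointHandler(path="google/flan-t5-base")  # placeholder checkpoint path
    result = handler({"inputs": [image_bytes], "parameters": {"max_new_tokens": 50}})
    print(result)  # e.g. [{"generated_text": "..."}]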