Qwen2-VL-2B-Instruct / handler.py
Gabriel's picture
Create handler.py
25c91e1 verified
raw
history blame
1.76 kB
from typing import Dict, Any
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
import io
import base64
import requests
class EndpointHandler():
    """Custom inference handler that answers a text prompt about one image.

    The request payload is a dictionary with:

    * ``'image'``: either an ``http(s)://`` URL or a base64-encoded image.
    * ``'text'``: the prompt that accompanies the image (optional).

    The response is ``{"generated_text": ...}`` on success or
    ``{"error": ...}`` when the payload is malformed.
    """

    # Seconds to wait for a remote image before giving up, so a stalled
    # download cannot block the worker indefinitely.
    REQUEST_TIMEOUT = 30

    INVALID_INPUT_ERROR = (
        "Invalid input data. Expected binary image data or a dictionary with 'image' key."
    )

    def __init__(self, path="", max_new_tokens=128):
        """Load the processor and model stored at *path*.

        Args:
            path: Directory (or identifier) passed to ``from_pretrained``.
            max_new_tokens: Upper bound on the number of tokens generated
                per request, not counting the prompt.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(path)
        self.max_new_tokens = max_new_tokens

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Run one image + text request through the model.

        Args:
            data: Request payload; expected to be a dict with an ``'image'``
                key and an optional ``'text'`` key.

        Returns:
            ``{"generated_text": str}`` on success, or ``{"error": str}``
            when the payload is not a dict or has no usable ``'image'``.

        Raises:
            Exceptions from downloading or decoding the image (for example
            an HTTP error or an unreadable image) propagate to the caller.
        """
        # Validate the payload type before touching it; calling .get() on a
        # non-dict would raise instead of producing the error response.
        if not isinstance(data, dict):
            return {"error": self.INVALID_INPUT_ERROR}

        image_input = data.get('image')
        text_input = data.get('text')
        if not isinstance(image_input, str) or not image_input:
            return {"error": self.INVALID_INPUT_ERROR}

        image = self._load_image(image_input)

        content = [{"type": "image", "image": image}]
        if text_input is not None:
            # Only add a text entry when a prompt was supplied, so a literal
            # None is never handed to the chat template.
            content.append({"type": "text", "text": text_input})
        messages = [{"role": "user", "content": content}]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=[text],
            images=[image],
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        # Forward every processor output (attention mask and image tensors
        # included); passing only input_ids would hide the image from the
        # model. max_new_tokens bounds the answer alone, whereas max_length
        # would also count the prompt and its image placeholder tokens.
        generated_ids = self.model.generate(
            **inputs, max_new_tokens=self.max_new_tokens
        )

        # generate() returns prompt + continuation; keep only the new tokens
        # so the response does not echo the prompt back.
        new_token_ids = [
            output_ids[len(prompt_ids):]
            for prompt_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return {"generated_text": output_text}

    def _load_image(self, image_input):
        """Return an RGB PIL image from an http(s) URL or a base64 string."""
        if image_input.startswith(('http://', 'https://')):
            response = requests.get(image_input, timeout=self.REQUEST_TIMEOUT)
            # Fail loudly on 4xx/5xx rather than trying to parse an error page
            # as an image.
            response.raise_for_status()
            image_bytes = response.content
        else:
            image_bytes = base64.b64decode(image_input)
        return Image.open(io.BytesIO(image_bytes)).convert('RGB')