# Qwen-VL-7B-2 / handler.py
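"""Custom inference handler for Qwen2-VL on Hugging Face Inference Endpoints.

Expects a JSON payload of the form
    {"inputs": {"image": "<base64-encoded image>", "text": "<prompt>"}}
and returns {"generated_text": "..."} or {"error": "..."}.
"""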
import torch
from PIL import Image
import base64
from io import BytesIO
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
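# process_vision_info comes from the qwen-vl-utils helper package
# (assumed installed via `pip install qwen-vl-utils`).
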
class EndpointHandler:
    def __init__(self, path=""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load the model and processor. device_map="auto" already places the
        # weights, so the model must not be moved again with .to() afterwards
        # (accelerate raises an error for dispatched models).
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(path)
    def __call__(self, data):
        # Extract the base64-encoded image and the text prompt from the payload
        image_data = data.get("inputs", {}).get("image", "")
        text_prompt = data.get("inputs", {}).get("text", "")
        if not image_data or not text_prompt:
            return {"error": "Both 'image' and 'text' must be provided in the input data."}

        # Decode the base64 image into an RGB PIL image
        try:
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return {"error": f"Failed to process image data: {e}"}
        # Build the chat-style message structure expected by Qwen2-VL
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_prompt},
                ],
            }
        ]

        # Render the chat template and collect the vision inputs
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move the input tensors to the device the model was dispatched to
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        # Generate the output sequence
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=2000,
                num_return_sequences=1,
                do_sample=True,
                temperature=0.7,
                top_p=0.95,
            )

        # Trim the prompt tokens so only the newly generated text is decoded;
        # decoding output_ids directly would echo the rendered chat template.
        generated_ids = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], output_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return {"generated_text": output_text}
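
# Minimal local smoke test: a sketch, not part of the endpoint contract. It
# encodes an image from disk, builds the payload the handler expects, and
# prints the result. The model id and image path below are placeholders.
if __name__ == "__main__":
    with open("example.jpg", "rb") as f:  # placeholder image path
        b64_image = base64.b64encode(f.read()).decode("utf-8")
    handler = EndpointHandler(path="Qwen/Qwen2-VL-7B-Instruct")  # assumed model id
    result = handler({"inputs": {"image": b64_image, "text": "Describe this image."}})
    print(result)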