Molmo-7B-D-0924 / handler.py
davanstrien's picture
davanstrien HF staff
Update handler.py
d869c4e verified
raw
history blame
2.61 kB
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import requests
import torch
import gc
import base64
import io
class EndpointHandler:
    """Callable inference handler: image + text prompt in, generated text out.

    The processor and model are loaded once from ``path`` with
    ``trust_remote_code=True``; the handler relies on the
    ``processor.process`` and ``model.generate_from_batch`` methods that the
    loaded model's own code is expected to provide.

    Every call returns a one-element list holding either
    ``{"generated_text": ...}`` or ``{"error": ...}``; request-level failures
    are reported in that shape instead of being raised.
    """

    # Prompt used when the request carries no (or an empty) ``text_prompt``.
    DEFAULT_PROMPT = "Describe this image."
    # Upper bound on the number of tokens generated per request.
    MAX_NEW_TOKENS = 200
    # Generation stops as soon as this string is produced.
    STOP_STRING = "<|endoftext|>"
    # Seconds to wait on the remote server before giving up on an image URL.
    URL_TIMEOUT_SECONDS = 30

    def __init__(self, path=""):
        """Load the processor and the model found at ``path``.

        Args:
            path: Directory or model identifier handed to ``from_pretrained``.
        """
        self.processor = AutoProcessor.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run one image-to-text request.

        Args:
            data: Request body. ``data["inputs"]`` is a mapping with:
                ``image_url`` - URL of the image to fetch (takes precedence),
                ``image`` - base64-encoded image bytes,
                ``text_prompt`` - optional prompt, defaults to
                ``DEFAULT_PROMPT``.

        Returns:
            ``[{"generated_text": str}]`` on success, otherwise
            ``[{"error": str}]`` describing what went wrong.
        """
        # Drop anything left over from a previous request before allocating.
        self._release_memory()

        payload = data.get("inputs") if isinstance(data, dict) else None
        if payload is None:
            payload = {}
        if not isinstance(payload, dict):
            # A bare string/list payload would otherwise crash on `.get`.
            return [{"error": "inputs must be an object with image_url or image"}]

        image_url = payload.get("image_url")
        image_data = payload.get("image")
        # `or` also covers an explicit null / empty prompt.
        text_prompt = payload.get("text_prompt") or self.DEFAULT_PROMPT

        if image_url:
            try:
                image = self._load_rgb(self._fetch_url(image_url))
            except Exception as e:
                return [{"error": f"Failed to load image from URL: {str(e)}"}]
        elif image_data:
            try:
                image = self._load_rgb(base64.b64decode(image_data))
            except Exception as e:
                return [{"error": f"Failed to decode image data: {str(e)}"}]
        else:
            return [{"error": "No image_url or image data provided in inputs"}]

        try:
            return [{"generated_text": self._generate(image, text_prompt)}]
        except Exception as e:
            return [{"error": f"Error during generation: {str(e)}"}]
        finally:
            # Runs on success and failure alike; by now the tensors created
            # inside `_generate` are out of scope and can actually be freed.
            self._release_memory()

    def _fetch_url(self, url: str) -> bytes:
        """Download ``url`` and return the response body as bytes.

        Raises on timeouts, connection problems and non-2xx status codes so
        that an HTML error page is never handed to the image decoder.
        """
        # NOTE(review): the URL is caller-supplied and fetched from inside the
        # serving environment; restrict allowed hosts if that is a concern.
        response = requests.get(url, timeout=self.URL_TIMEOUT_SECONDS)
        response.raise_for_status()
        return response.content

    @staticmethod
    def _load_rgb(raw: bytes):
        """Decode ``raw`` image bytes and return a fully loaded RGB image."""
        image = Image.open(io.BytesIO(raw))
        # `Image.open` is lazy; force the decode here so corrupt or truncated
        # files are reported as image errors rather than generation errors.
        image.load()
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def _generate(self, image, text_prompt: str) -> str:
        """Generate text for ``image`` and ``text_prompt``; return the new text only."""
        processed = self.processor.process(images=[image], text=text_prompt)
        # Move every tensor to the model device and add a batch dimension.
        batch = {
            key: value.to(self.model.device).unsqueeze(0)
            for key, value in processed.items()
        }
        with torch.no_grad(), torch.autocast(
            device_type="cuda", enabled=True, dtype=torch.bfloat16
        ):
            output = self.model.generate_from_batch(
                batch,
                GenerationConfig(
                    max_new_tokens=self.MAX_NEW_TOKENS,
                    stop_strings=self.STOP_STRING,
                ),
                tokenizer=self.processor.tokenizer,
            )
        # The output starts with the prompt tokens; keep only what follows.
        prompt_length = batch["input_ids"].size(1)
        generated_tokens = output[0, prompt_length:]
        return self.processor.tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        )

    @staticmethod
    def _release_memory() -> None:
        """Collect unreachable objects, then release cached CUDA memory."""
        # Collect first so tensors freed by the GC can leave the CUDA cache.
        gc.collect()
        torch.cuda.empty_cache()