Spaces:
Runtime error
Runtime error
File size: 1,154 Bytes
5dae26f 222b8b3 5dae26f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import os
import PIL.Image
import torch
from huggingface_hub import login
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import spaces
# Authenticate with the Hugging Face Hub using the HF_TOKEN environment
# variable (presumably needed to download the model checkpoint — confirm).
# Only call login() when a token is actually configured: login(token=None)
# falls back to an interactive prompt, which cannot be answered in a
# headless Space and aborts startup.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token, add_to_git_credential=True)
class PaliGemmaModel:
    """Load a PaliGemma checkpoint and run greedy image+text generation."""

    def __init__(self, model_id: str = "google/paligemma-3b-mix-448"):
        """Load the model and its processor.

        Args:
            model_id: Hugging Face Hub id of the PaliGemma checkpoint.
                Defaults to the checkpoint this class originally hard-coded.
        """
        self.model_id = model_id
        # Prefer the GPU when one is visible; otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = (
            PaliGemmaForConditionalGeneration.from_pretrained(self.model_id)
            .eval()
            .to(self.device)
        )
        self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)

    @spaces.GPU
    def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
        """Generate a text response for ``image`` conditioned on prompt ``text``.

        Decoding is greedy (``do_sample=False``).

        Args:
            image: Input image.
            text: Text prompt passed to the processor alongside the image.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The generated continuation only (the prompt is not included),
            with leading newlines removed.
        """
        inputs = self.processor(text=text, images=image, return_tensors="pt").to(self.device)
        # generate() returns the prompt tokens followed by the new tokens, so
        # remember how many prompt tokens there were.
        prompt_len = inputs["input_ids"].shape[-1]
        with torch.inference_mode():
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        # Drop the prompt at the token level and decode only what was
        # generated. Slicing the decoded string by len(text) assumed the
        # prompt survives a tokenize/decode round-trip unchanged, which is
        # not guaranteed and could truncate or pollute the answer.
        new_tokens = generated_ids[0, prompt_len:]
        result = self.processor.decode(new_tokens, skip_special_tokens=True)
        return result.lstrip("\n")